♻️ Refactored client package
This commit is contained in:
parent
fbb68d995b
commit
32e9618730
|
@ -2,10 +2,13 @@ package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"source.toby3d.me/website/oauth/internal/domain"
|
"source.toby3d.me/website/oauth/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Repository interface {
|
type Repository interface {
|
||||||
Get(ctx context.Context, id string) (*domain.Client, error)
|
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrNotExist = errors.New("client with the specified ID does not exist")
|
||||||
|
|
|
@ -3,10 +3,10 @@ package http
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/tomnomnom/linkheader"
|
"github.com/tomnomnom/linkheader"
|
||||||
http "github.com/valyala/fasthttp"
|
http "github.com/valyala/fasthttp"
|
||||||
"willnorris.com/go/microformats"
|
"willnorris.com/go/microformats"
|
||||||
|
@ -20,16 +20,14 @@ type httpClientRepository struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HApp string = "h-app"
|
relRedirectURI string = "redirect_uri"
|
||||||
HXApp string = "h-x-app"
|
|
||||||
|
|
||||||
KeyName string = "name"
|
hApp string = "h-app"
|
||||||
KeyLogo string = "logo"
|
hXApp string = "h-x-app"
|
||||||
KeyURL string = "url"
|
|
||||||
|
|
||||||
ValueValue string = "value"
|
propertyLogo string = "logo"
|
||||||
|
propertyName string = "name"
|
||||||
RelRedirectURI string = "redirect_uri"
|
propertyURL string = "url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewHTTPClientRepository(c *http.Client) client.Repository {
|
func NewHTTPClientRepository(c *http.Client) client.Repository {
|
||||||
|
@ -38,93 +36,134 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (repo *httpClientRepository) Get(ctx context.Context, id string) (*domain.Client, error) {
|
func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
|
||||||
u, err := url.Parse(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to parse id as url")
|
|
||||||
}
|
|
||||||
|
|
||||||
req := http.AcquireRequest()
|
req := http.AcquireRequest()
|
||||||
defer http.ReleaseRequest(req)
|
defer http.ReleaseRequest(req)
|
||||||
req.SetRequestURI(u.String())
|
req.SetRequestURI(id.String())
|
||||||
req.Header.SetMethod(http.MethodGet)
|
req.Header.SetMethod(http.MethodGet)
|
||||||
|
|
||||||
resp := http.AcquireResponse()
|
resp := http.AcquireResponse()
|
||||||
defer http.ReleaseResponse(resp)
|
defer http.ReleaseResponse(resp)
|
||||||
|
|
||||||
if err := repo.client.Do(req, resp); err != nil {
|
if err := repo.client.Do(req, resp); err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to make a request to the client")
|
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := domain.NewClient()
|
if resp.StatusCode() == http.StatusNotFound {
|
||||||
client.ID = id
|
return nil, client.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
for _, l := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
|
client := &domain.Client{
|
||||||
if !strings.Contains(l.Rel, "redirect_uri") {
|
ID: id,
|
||||||
|
Logo: make([]*domain.URL, 0),
|
||||||
|
Name: extractValues(resp, propertyName),
|
||||||
|
RedirectURI: extractEndpoints(resp, relRedirectURI),
|
||||||
|
URL: make([]*domain.URL, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range extractValues(resp, propertyLogo) {
|
||||||
|
u, err := domain.NewURL(v)
|
||||||
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
client.RedirectURI = append(client.RedirectURI, l.URL)
|
client.Logo = append(client.Logo, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := microformats.Parse(bytes.NewReader(resp.Body()), u)
|
for _, v := range extractValues(resp, propertyURL) {
|
||||||
|
u, err := domain.NewURL(v)
|
||||||
for _, item := range data.Items {
|
if err != nil {
|
||||||
if len(item.Type) == 0 && !strings.EqualFold(item.Type[0], HApp) &&
|
|
||||||
!strings.EqualFold(item.Type[0], HXApp) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
populateProperties(item.Properties, client)
|
client.URL = append(client.URL, u)
|
||||||
}
|
}
|
||||||
|
|
||||||
populateRels(data.Rels, client)
|
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func populateProperties(src map[string][]interface{}, dst *domain.Client) {
|
func extractEndpoints(resp *http.Response, name string) []*domain.URL {
|
||||||
for key, property := range src {
|
results := make([]*domain.URL, 0)
|
||||||
if len(property) == 0 {
|
endpoints, _ := extractEndpointsFromHeader(resp, name)
|
||||||
|
results = append(results, endpoints...)
|
||||||
|
endpoints, _ = extractEndpointsFromBody(resp, name)
|
||||||
|
results = append(results, endpoints...)
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractValues(resp *http.Response, key string) []string {
|
||||||
|
results := make([]string, 0)
|
||||||
|
|
||||||
|
for _, item := range microformats.Parse(bytes.NewReader(resp.Body()), nil).Items {
|
||||||
|
if len(item.Type) == 0 || (item.Type[0] != hApp && item.Type[0] != hXApp) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
switch key {
|
properties, ok := item.Properties[key]
|
||||||
case KeyName:
|
if !ok || len(properties) == 0 {
|
||||||
dst.Name = getString(property)
|
return nil
|
||||||
case KeyLogo:
|
}
|
||||||
for i := range property {
|
|
||||||
switch val := property[i].(type) {
|
for j := range properties {
|
||||||
case string:
|
switch p := properties[j].(type) {
|
||||||
dst.Logo = val
|
case string:
|
||||||
case map[string]string:
|
results = append(results, p)
|
||||||
dst.Logo = val[ValueValue]
|
case map[string][]interface{}:
|
||||||
|
for _, val := range p["value"] {
|
||||||
|
v, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case KeyURL:
|
|
||||||
dst.URL = getString(property)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func populateRels(src map[string][]string, dst *domain.Client) {
|
func extractEndpointsFromHeader(resp *http.Response, name string) ([]*domain.URL, error) {
|
||||||
for key, values := range src {
|
results := make([]*domain.URL, 0)
|
||||||
if !strings.EqualFold(key, RelRedirectURI) {
|
|
||||||
|
for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
|
||||||
|
if !strings.EqualFold(link.Rel, name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range values {
|
u := http.AcquireURI()
|
||||||
dst.RedirectURI = append(dst.RedirectURI, values[i])
|
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getString(property []interface{}) string {
|
results = append(results, &domain.URL{URI: u})
|
||||||
for i := range property {
|
|
||||||
val, _ := property[i].(string)
|
|
||||||
|
|
||||||
return val
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractEndpointsFromBody(resp *http.Response, name string) ([]*domain.URL, error) {
|
||||||
|
host, err := url.Parse(string(resp.Header.Peek(http.HeaderHost)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot parse host header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), host).Rels[name]
|
||||||
|
if !ok || len(endpoints) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]*domain.URL, 0)
|
||||||
|
for i := range endpoints {
|
||||||
|
u := http.AcquireURI()
|
||||||
|
u.Update(endpoints[i])
|
||||||
|
|
||||||
|
results = append(results, &domain.URL{URI: u})
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package http_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -11,7 +12,7 @@ import (
|
||||||
repository "source.toby3d.me/website/oauth/internal/client/repository/http"
|
repository "source.toby3d.me/website/oauth/internal/client/repository/http"
|
||||||
"source.toby3d.me/website/oauth/internal/common"
|
"source.toby3d.me/website/oauth/internal/common"
|
||||||
"source.toby3d.me/website/oauth/internal/domain"
|
"source.toby3d.me/website/oauth/internal/domain"
|
||||||
"source.toby3d.me/website/oauth/internal/util"
|
"source.toby3d.me/website/oauth/internal/testing/httptest"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testBody string = `
|
const testBody string = `
|
||||||
|
@ -20,13 +21,13 @@ const testBody string = `
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>Example App</title>
|
<title>%[1]s</title>
|
||||||
<link rel="redirect_uri" href="/redirect">
|
<link rel="redirect_uri" href="%[4]s">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="h-app">
|
<div class="h-app h-x-app">
|
||||||
<img src="/logo.png" class="u-logo">
|
<img class="u-logo" src="%[3]s">
|
||||||
<a href="/" class="u-url p-name">Example App</a>
|
<a class="u-url p-name" href="%[2]s">%[1]s</a>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
@ -35,17 +36,37 @@ const testBody string = `
|
||||||
func TestGet(t *testing.T) {
|
func TestGet(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
client, _, cleanup := util.TestServe(t, func(ctx *http.RequestCtx) {
|
client := domain.TestClient(t)
|
||||||
ctx.Response.Header.Set(http.HeaderLink, `<https://app.example.net/redirect>; rel="redirect_uri">`)
|
httpClient, _, cleanup := httptest.New(t, testHandler(t, client))
|
||||||
ctx.SetStatusCode(http.StatusOK)
|
|
||||||
ctx.SetContentType(common.MIMETextHTML)
|
|
||||||
ctx.SetBodyString(testBody)
|
|
||||||
})
|
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
c := domain.TestClient(t)
|
result, err := repository.NewHTTPClientRepository(httpClient).Get(context.TODO(), client.ID)
|
||||||
|
|
||||||
result, err := repository.NewHTTPClientRepository(client).Get(context.TODO(), c.ID)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, c, result)
|
|
||||||
|
assert.Equal(t, client.Name, result.Name)
|
||||||
|
assert.Equal(t, client.ID.String(), result.ID.String())
|
||||||
|
|
||||||
|
for i := range client.URL {
|
||||||
|
assert.Equal(t, client.URL[i].String(), result.URL[i].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range client.Logo {
|
||||||
|
assert.Equal(t, client.Logo[i].String(), result.Logo[i].String())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range client.RedirectURI {
|
||||||
|
assert.Equal(t, client.RedirectURI[i].String(), result.RedirectURI[i].String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testHandler(tb testing.TB, client *domain.Client) http.RequestHandler {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
return func(ctx *http.RequestCtx) {
|
||||||
|
ctx.Response.Header.Set(http.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`)
|
||||||
|
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf(
|
||||||
|
testBody, client.Name[0], client.URL[0].String(), client.Logo[0].String(),
|
||||||
|
client.RedirectURI[1].String(),
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,26 +10,32 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type memoryClientRepository struct {
|
type memoryClientRepository struct {
|
||||||
clients *sync.Map
|
store *sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
const Key string = "clients"
|
const DefaultPathPrefix string = "clients"
|
||||||
|
|
||||||
func NewMemoryClientRepository(clients *sync.Map) client.Repository {
|
func NewMemoryClientRepository(store *sync.Map) client.Repository {
|
||||||
return &memoryClientRepository{
|
return &memoryClientRepository{
|
||||||
clients: clients,
|
store: store,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (repo *memoryClientRepository) Get(ctx context.Context, id string) (*domain.Client, error) {
|
func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.Client) error {
|
||||||
src, ok := repo.clients.Load(path.Join(Key, id))
|
repo.store.Store(path.Join(DefaultPathPrefix, client.ID.String()), client)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
|
||||||
|
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String()))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, client.ErrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := src.(*domain.Client)
|
c, ok := src.(*domain.Client)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, client.ErrNotExist
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
|
|
|
@ -16,10 +16,10 @@ import (
|
||||||
func TestGet(t *testing.T) {
|
func TestGet(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
store := new(sync.Map)
|
|
||||||
client := domain.TestClient(t)
|
client := domain.TestClient(t)
|
||||||
|
|
||||||
store.Store(path.Join(repository.Key, client.ID), client)
|
store := new(sync.Map)
|
||||||
|
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
|
||||||
|
|
||||||
result, err := repository.NewMemoryClientRepository(store).Get(context.TODO(), client.ID)
|
result, err := repository.NewMemoryClientRepository(store).Get(context.TODO(), client.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -2,10 +2,14 @@ package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"source.toby3d.me/website/oauth/internal/domain"
|
"source.toby3d.me/website/oauth/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UseCase interface {
|
type UseCase interface {
|
||||||
Discovery(ctx context.Context, clientID string) (*domain.Client, error)
|
// Discovery returns client public information bu ClientID URL.
|
||||||
|
Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrInvalidMe = errors.New("provided me is invalid")
|
||||||
|
|
|
@ -2,27 +2,26 @@ package usecase
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/pkg/errors"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/oauth/internal/client"
|
"source.toby3d.me/website/oauth/internal/client"
|
||||||
"source.toby3d.me/website/oauth/internal/domain"
|
"source.toby3d.me/website/oauth/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientUseCase struct {
|
type clientUseCase struct {
|
||||||
clients client.Repository
|
repo client.Repository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientUseCase(clients client.Repository) client.UseCase {
|
func NewClientUseCase(repo client.Repository) client.UseCase {
|
||||||
return &clientUseCase{
|
return &clientUseCase{
|
||||||
clients: clients,
|
repo: repo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (useCase *clientUseCase) Discovery(ctx context.Context, clientID string) (*domain.Client, error) {
|
func (useCase *clientUseCase) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
|
||||||
c, err := useCase.clients.Get(ctx, clientID)
|
c, err := useCase.repo.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to get client information")
|
return nil, fmt.Errorf("cannot discovery client by id: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
|
|
|
@ -17,13 +17,13 @@ import (
|
||||||
func TestDiscovery(t *testing.T) {
|
func TestDiscovery(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
store := new(sync.Map)
|
|
||||||
client := domain.TestClient(t)
|
client := domain.TestClient(t)
|
||||||
|
|
||||||
store.Store(path.Join(repository.Key, client.ID), client)
|
store := new(sync.Map)
|
||||||
|
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
|
||||||
|
|
||||||
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).Discovery(context.TODO(),
|
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).
|
||||||
client.ID)
|
Discovery(context.TODO(), client.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, client, result)
|
assert.Equal(t, client, result)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue