♻️ Refactored client package

This commit is contained in:
Maxim Lebedev 2021-12-30 01:53:31 +05:00
parent fbb68d995b
commit 32e9618730
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
8 changed files with 170 additions and 98 deletions

View File

@ -2,10 +2,13 @@ package client
import ( import (
"context" "context"
"errors"
"source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/domain"
) )
type Repository interface { type Repository interface {
Get(ctx context.Context, id string) (*domain.Client, error) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
} }
var ErrNotExist = errors.New("client with the specified ID does not exist")

View File

@ -3,10 +3,10 @@ package http
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"net/url" "net/url"
"strings" "strings"
"github.com/pkg/errors"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
http "github.com/valyala/fasthttp" http "github.com/valyala/fasthttp"
"willnorris.com/go/microformats" "willnorris.com/go/microformats"
@ -20,16 +20,14 @@ type httpClientRepository struct {
} }
const ( const (
HApp string = "h-app" relRedirectURI string = "redirect_uri"
HXApp string = "h-x-app"
KeyName string = "name" hApp string = "h-app"
KeyLogo string = "logo" hXApp string = "h-x-app"
KeyURL string = "url"
ValueValue string = "value" propertyLogo string = "logo"
propertyName string = "name"
RelRedirectURI string = "redirect_uri" propertyURL string = "url"
) )
func NewHTTPClientRepository(c *http.Client) client.Repository { func NewHTTPClientRepository(c *http.Client) client.Repository {
@ -38,93 +36,134 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
} }
} }
func (repo *httpClientRepository) Get(ctx context.Context, id string) (*domain.Client, error) { func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
u, err := url.Parse(id)
if err != nil {
return nil, errors.Wrap(err, "failed to parse id as url")
}
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.SetRequestURI(u.String()) req.SetRequestURI(id.String())
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse() resp := http.AcquireResponse()
defer http.ReleaseResponse(resp) defer http.ReleaseResponse(resp)
if err := repo.client.Do(req, resp); err != nil { if err := repo.client.Do(req, resp); err != nil {
return nil, errors.Wrap(err, "failed to make a request to the client") return nil, fmt.Errorf("failed to make a request to the client: %w", err)
} }
client := domain.NewClient() if resp.StatusCode() == http.StatusNotFound {
client.ID = id return nil, client.ErrNotExist
}
for _, l := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) { client := &domain.Client{
if !strings.Contains(l.Rel, "redirect_uri") { ID: id,
Logo: make([]*domain.URL, 0),
Name: extractValues(resp, propertyName),
RedirectURI: extractEndpoints(resp, relRedirectURI),
URL: make([]*domain.URL, 0),
}
for _, v := range extractValues(resp, propertyLogo) {
u, err := domain.NewURL(v)
if err != nil {
continue continue
} }
client.RedirectURI = append(client.RedirectURI, l.URL) client.Logo = append(client.Logo, u)
} }
data := microformats.Parse(bytes.NewReader(resp.Body()), u) for _, v := range extractValues(resp, propertyURL) {
u, err := domain.NewURL(v)
for _, item := range data.Items { if err != nil {
if len(item.Type) == 0 && !strings.EqualFold(item.Type[0], HApp) &&
!strings.EqualFold(item.Type[0], HXApp) {
continue continue
} }
populateProperties(item.Properties, client) client.URL = append(client.URL, u)
} }
populateRels(data.Rels, client)
return client, nil return client, nil
} }
func populateProperties(src map[string][]interface{}, dst *domain.Client) { func extractEndpoints(resp *http.Response, name string) []*domain.URL {
for key, property := range src { results := make([]*domain.URL, 0)
if len(property) == 0 { endpoints, _ := extractEndpointsFromHeader(resp, name)
results = append(results, endpoints...)
endpoints, _ = extractEndpointsFromBody(resp, name)
results = append(results, endpoints...)
return results
}
func extractValues(resp *http.Response, key string) []string {
results := make([]string, 0)
for _, item := range microformats.Parse(bytes.NewReader(resp.Body()), nil).Items {
if len(item.Type) == 0 || (item.Type[0] != hApp && item.Type[0] != hXApp) {
continue continue
} }
switch key { properties, ok := item.Properties[key]
case KeyName: if !ok || len(properties) == 0 {
dst.Name = getString(property) return nil
case KeyLogo: }
for i := range property {
switch val := property[i].(type) { for j := range properties {
case string: switch p := properties[j].(type) {
dst.Logo = val case string:
case map[string]string: results = append(results, p)
dst.Logo = val[ValueValue] case map[string][]interface{}:
for _, val := range p["value"] {
v, ok := val.(string)
if !ok {
continue
}
results = append(results, v)
} }
} }
case KeyURL:
dst.URL = getString(property)
} }
return results
} }
return nil
} }
func populateRels(src map[string][]string, dst *domain.Client) { func extractEndpointsFromHeader(resp *http.Response, name string) ([]*domain.URL, error) {
for key, values := range src { results := make([]*domain.URL, 0)
if !strings.EqualFold(key, RelRedirectURI) {
for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
if !strings.EqualFold(link.Rel, name) {
continue continue
} }
for i := range values { u := http.AcquireURI()
dst.RedirectURI = append(dst.RedirectURI, values[i]) if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
return nil, err
} }
}
}
func getString(property []interface{}) string { results = append(results, &domain.URL{URI: u})
for i := range property {
val, _ := property[i].(string)
return val
} }
return "" return results, nil
}
func extractEndpointsFromBody(resp *http.Response, name string) ([]*domain.URL, error) {
host, err := url.Parse(string(resp.Header.Peek(http.HeaderHost)))
if err != nil {
return nil, fmt.Errorf("cannot parse host header: %w", err)
}
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), host).Rels[name]
if !ok || len(endpoints) == 0 {
return nil, nil
}
results := make([]*domain.URL, 0)
for i := range endpoints {
u := http.AcquireURI()
u.Update(endpoints[i])
results = append(results, &domain.URL{URI: u})
}
return results, nil
} }

View File

@ -2,6 +2,7 @@ package http_test
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -11,7 +12,7 @@ import (
repository "source.toby3d.me/website/oauth/internal/client/repository/http" repository "source.toby3d.me/website/oauth/internal/client/repository/http"
"source.toby3d.me/website/oauth/internal/common" "source.toby3d.me/website/oauth/internal/common"
"source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/util" "source.toby3d.me/website/oauth/internal/testing/httptest"
) )
const testBody string = ` const testBody string = `
@ -20,13 +21,13 @@ const testBody string = `
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Example App</title> <title>%[1]s</title>
<link rel="redirect_uri" href="/redirect"> <link rel="redirect_uri" href="%[4]s">
</head> </head>
<body> <body>
<div class="h-app"> <div class="h-app h-x-app">
<img src="/logo.png" class="u-logo"> <img class="u-logo" src="%[3]s">
<a href="/" class="u-url p-name">Example App</a> <a class="u-url p-name" href="%[2]s">%[1]s</a>
</div> </div>
</body> </body>
</html> </html>
@ -35,17 +36,37 @@ const testBody string = `
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Parallel() t.Parallel()
client, _, cleanup := util.TestServe(t, func(ctx *http.RequestCtx) { client := domain.TestClient(t)
ctx.Response.Header.Set(http.HeaderLink, `<https://app.example.net/redirect>; rel="redirect_uri">`) httpClient, _, cleanup := httptest.New(t, testHandler(t, client))
ctx.SetStatusCode(http.StatusOK)
ctx.SetContentType(common.MIMETextHTML)
ctx.SetBodyString(testBody)
})
t.Cleanup(cleanup) t.Cleanup(cleanup)
c := domain.TestClient(t) result, err := repository.NewHTTPClientRepository(httpClient).Get(context.TODO(), client.ID)
result, err := repository.NewHTTPClientRepository(client).Get(context.TODO(), c.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, c, result)
assert.Equal(t, client.Name, result.Name)
assert.Equal(t, client.ID.String(), result.ID.String())
for i := range client.URL {
assert.Equal(t, client.URL[i].String(), result.URL[i].String())
}
for i := range client.Logo {
assert.Equal(t, client.Logo[i].String(), result.Logo[i].String())
}
for i := range client.RedirectURI {
assert.Equal(t, client.RedirectURI[i].String(), result.RedirectURI[i].String())
}
}
func testHandler(tb testing.TB, client *domain.Client) http.RequestHandler {
tb.Helper()
return func(ctx *http.RequestCtx) {
ctx.Response.Header.Set(http.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`)
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf(
testBody, client.Name[0], client.URL[0].String(), client.Logo[0].String(),
client.RedirectURI[1].String(),
))
}
} }

View File

@ -10,26 +10,32 @@ import (
) )
type memoryClientRepository struct { type memoryClientRepository struct {
clients *sync.Map store *sync.Map
} }
const Key string = "clients" const DefaultPathPrefix string = "clients"
func NewMemoryClientRepository(clients *sync.Map) client.Repository { func NewMemoryClientRepository(store *sync.Map) client.Repository {
return &memoryClientRepository{ return &memoryClientRepository{
clients: clients, store: store,
} }
} }
func (repo *memoryClientRepository) Get(ctx context.Context, id string) (*domain.Client, error) { func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.Client) error {
src, ok := repo.clients.Load(path.Join(Key, id)) repo.store.Store(path.Join(DefaultPathPrefix, client.ID.String()), client)
return nil
}
func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String()))
if !ok { if !ok {
return nil, nil return nil, client.ErrNotExist
} }
c, ok := src.(*domain.Client) c, ok := src.(*domain.Client)
if !ok { if !ok {
return nil, nil return nil, client.ErrNotExist
} }
return c, nil return c, nil

View File

@ -16,10 +16,10 @@ import (
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Parallel() t.Parallel()
store := new(sync.Map)
client := domain.TestClient(t) client := domain.TestClient(t)
store.Store(path.Join(repository.Key, client.ID), client) store := new(sync.Map)
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
result, err := repository.NewMemoryClientRepository(store).Get(context.TODO(), client.ID) result, err := repository.NewMemoryClientRepository(store).Get(context.TODO(), client.ID)
require.NoError(t, err) require.NoError(t, err)

View File

@ -2,10 +2,14 @@ package client
import ( import (
"context" "context"
"errors"
"source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/domain"
) )
type UseCase interface { type UseCase interface {
Discovery(ctx context.Context, clientID string) (*domain.Client, error) // Discovery returns client public information bu ClientID URL.
Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
} }
var ErrInvalidMe = errors.New("provided me is invalid")

View File

@ -2,27 +2,26 @@ package usecase
import ( import (
"context" "context"
"fmt"
"github.com/pkg/errors"
"source.toby3d.me/website/oauth/internal/client" "source.toby3d.me/website/oauth/internal/client"
"source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/domain"
) )
type clientUseCase struct { type clientUseCase struct {
clients client.Repository repo client.Repository
} }
func NewClientUseCase(clients client.Repository) client.UseCase { func NewClientUseCase(repo client.Repository) client.UseCase {
return &clientUseCase{ return &clientUseCase{
clients: clients, repo: repo,
} }
} }
func (useCase *clientUseCase) Discovery(ctx context.Context, clientID string) (*domain.Client, error) { func (useCase *clientUseCase) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
c, err := useCase.clients.Get(ctx, clientID) c, err := useCase.repo.Get(ctx, id)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get client information") return nil, fmt.Errorf("cannot discovery client by id: %w", err)
} }
return c, nil return c, nil

View File

@ -17,13 +17,13 @@ import (
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
t.Parallel() t.Parallel()
store := new(sync.Map)
client := domain.TestClient(t) client := domain.TestClient(t)
store.Store(path.Join(repository.Key, client.ID), client) store := new(sync.Map)
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).Discovery(context.TODO(), result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).
client.ID) Discovery(context.TODO(), client.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, client, result) assert.Equal(t, client, result)
} }