commit
2466cb02f0
4 changed files with 69 additions and 10 deletions
|
@ -3,6 +3,7 @@ package http
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
http "github.com/valyala/fasthttp"
|
||||
|
||||
|
@ -33,6 +34,19 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
|
|||
}
|
||||
|
||||
func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
|
||||
ips, err := net.LookupIP(cid.URL().Hostname())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot resolve client IP by id: %w", err)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if !ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, client.ErrNotExist
|
||||
}
|
||||
|
||||
req := http.AcquireRequest()
|
||||
defer http.ReleaseRequest(req)
|
||||
req.SetRequestURI(cid.String())
|
||||
|
|
|
@ -3,6 +3,7 @@ package memory
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
|
@ -29,6 +30,18 @@ func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.C
|
|||
}
|
||||
|
||||
func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
|
||||
// WARN(toby3d): more often than not, we will work from tests with
|
||||
// non-existent clients, almost guaranteed to cause a resolution error.
|
||||
ips, _ := net.LookupIP(id.URL().Hostname())
|
||||
|
||||
for _, ip := range ips {
|
||||
if !ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, client.ErrNotExist
|
||||
}
|
||||
|
||||
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String()))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find client in store: %w", client.ErrNotExist)
|
||||
|
|
|
@ -2,11 +2,13 @@ package usecase_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"source.toby3d.me/toby3d/auth/internal/client"
|
||||
repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory"
|
||||
"source.toby3d.me/toby3d/auth/internal/client/usecase"
|
||||
"source.toby3d.me/toby3d/auth/internal/domain"
|
||||
|
@ -15,18 +17,48 @@ import (
|
|||
func TestDiscovery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := domain.TestClient(t)
|
||||
|
||||
store := new(sync.Map)
|
||||
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
|
||||
testClient, localhostClient := domain.TestClient(t), domain.TestClient(t)
|
||||
localhostClient.ID, _ = domain.ParseClientID("http://localhost/")
|
||||
|
||||
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).
|
||||
Discovery(context.Background(), client.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
for _, client := range []*domain.Client{testClient, localhostClient} {
|
||||
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result, client) {
|
||||
t.Errorf("Discovery(%s) = %+v, want %+v", client.ID, result, client)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
in *domain.Client
|
||||
out *domain.Client
|
||||
expError error
|
||||
}{{
|
||||
name: "default",
|
||||
in: testClient,
|
||||
out: testClient,
|
||||
}, {
|
||||
name: "localhost",
|
||||
in: localhostClient,
|
||||
expError: client.ErrNotExist,
|
||||
}} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).
|
||||
Discovery(context.Background(), tc.in.ID)
|
||||
if tc.expError != nil && !errors.Is(err, tc.expError) {
|
||||
t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expError == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result, tc.out) {
|
||||
t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, result, tc.out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ func ParseClientID(src string) (*ClientID, error) {
|
|||
func TestClientID(tb testing.TB) *ClientID {
|
||||
tb.Helper()
|
||||
|
||||
clientID, err := ParseClientID("https://indieauth.example.com/")
|
||||
clientID, err := ParseClientID("https://example.com/")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue