diff --git a/internal/client/repository/http/http_client.go b/internal/client/repository/http/http_client.go index 8c73a29..e7f9b2d 100644 --- a/internal/client/repository/http/http_client.go +++ b/internal/client/repository/http/http_client.go @@ -3,6 +3,7 @@ package http import ( "context" "fmt" + "net" http "github.com/valyala/fasthttp" @@ -33,6 +34,19 @@ func NewHTTPClientRepository(c *http.Client) client.Repository { } func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) { + ips, err := net.LookupIP(cid.URL().Hostname()) + if err != nil { + return nil, fmt.Errorf("cannot resolve client IP by id: %w", err) + } + + for _, ip := range ips { + if !ip.IsLoopback() { + continue + } + + return nil, client.ErrNotExist + } + req := http.AcquireRequest() defer http.ReleaseRequest(req) req.SetRequestURI(cid.String()) diff --git a/internal/client/repository/memory/memory_client.go b/internal/client/repository/memory/memory_client.go index 8f91db3..e4d433f 100644 --- a/internal/client/repository/memory/memory_client.go +++ b/internal/client/repository/memory/memory_client.go @@ -3,6 +3,7 @@ package memory import ( "context" "fmt" + "net" "path" "sync" @@ -29,6 +30,18 @@ func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.C } func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { + // WARN(toby3d): more often than not, we will work from tests with + // non-existent clients, almost guaranteed to cause a resolution error. + ips, _ := net.LookupIP(id.URL().Hostname()) + + for _, ip := range ips { + if !ip.IsLoopback() { + continue + } + + return nil, client.ErrNotExist + } + src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String())) if !ok { return nil, fmt.Errorf("cannot find client in store: %w", client.ErrNotExist) diff --git a/internal/client/usecase/client_ucase_test.go b/internal/client/usecase/client_ucase_test.go index 386fffb..ae263f7 100644 --- a/internal/client/usecase/client_ucase_test.go +++ b/internal/client/usecase/client_ucase_test.go @@ -2,11 +2,13 @@ package usecase_test import ( "context" + "errors" "path" "reflect" "sync" "testing" + "source.toby3d.me/toby3d/auth/internal/client" repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory" "source.toby3d.me/toby3d/auth/internal/client/usecase" "source.toby3d.me/toby3d/auth/internal/domain" @@ -15,18 +17,48 @@ import ( func TestDiscovery(t *testing.T) { t.Parallel() - client := domain.TestClient(t) - store := new(sync.Map) - store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client) + testClient, localhostClient := domain.TestClient(t), domain.TestClient(t) + localhostClient.ID, _ = domain.ParseClientID("http://localhost/") - result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)). - Discovery(context.Background(), client.ID) - if err != nil { - t.Fatal(err) + for _, client := range []*domain.Client{testClient, localhostClient} { + store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client) } - if !reflect.DeepEqual(result, client) { - t.Errorf("Discovery(%s) = %+v, want %+v", client.ID, result, client) + for _, tc := range []struct { + name string + in *domain.Client + out *domain.Client + expError error + }{{ + name: "default", + in: testClient, + out: testClient, + }, { + name: "localhost", + in: localhostClient, + expError: client.ErrNotExist, + }} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)). + Discovery(context.Background(), tc.in.ID) + if tc.expError != nil && !errors.Is(err, tc.expError) { + t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError) + + return + } + + if tc.expError == nil && err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(result, tc.out) { + t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, result, tc.out) + } + }) } } diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go index 4b41165..8bc5179 100644 --- a/internal/domain/client_id.go +++ b/internal/domain/client_id.go @@ -107,7 +107,7 @@ func ParseClientID(src string) (*ClientID, error) { func TestClientID(tb testing.TB) *ClientID { tb.Helper() - clientID, err := ParseClientID("https://indieauth.example.com/") + clientID, err := ParseClientID("https://example.com/") if err != nil { tb.Fatal(err) }