diff --git a/internal/domain/client.go b/internal/domain/client.go index 5e72d20..2b95416 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -1,6 +1,9 @@ package domain import ( + "bytes" + "net" + "strings" "testing" ) @@ -33,3 +36,66 @@ func TestClient(tb testing.TB) *Client { RedirectURI: redirects, } } + +// ValidateRedirectURI validates RedirectURI from request to ClientID or +// registered set of client RedirectURI. +// +// If the URL scheme, host or port of the redirect_uri in the request do not +// match that of the client_id, then the authorization endpoint SHOULD verify +// that the requested redirect_uri matches one of the redirect URLs published by +// the client, and SHOULD block the request from proceeding if not. +func (c *Client) ValidateRedirectURI(redirectURI *URL) bool { + if redirectURI == nil { + return false + } + + rHost, rPort, err := net.SplitHostPort(string(redirectURI.Host())) + if err != nil { + rHost = string(redirectURI.Host()) + } + + cHost, cPort, err := net.SplitHostPort(string(c.ID.clientID.Host())) + if err != nil { + cHost = string(c.ID.clientID.Host()) + } + + if bytes.EqualFold(redirectURI.Scheme(), c.ID.clientID.Scheme()) && + strings.EqualFold(rHost, cHost) && + strings.EqualFold(rPort, cPort) { + return true + } + + for i := range c.RedirectURI { + if redirectURI.String() != c.RedirectURI[i].String() { + continue + } + + return true + } + + return false +} + +func (c *Client) GetName() string { + if len(c.Name) < 1 { + return "" + } + + return c.Name[0] +} + +func (c *Client) GetURL() *URL { + if len(c.URL) < 1 { + return nil + } + + return c.URL[0] +} + +func (c *Client) GetLogo() *URL { + if len(c.Logo) < 1 { + return nil + } + + return c.Logo[0] +} diff --git a/internal/domain/client_test.go b/internal/domain/client_test.go new file mode 100644 index 0000000..f16f46d --- /dev/null +++ b/internal/domain/client_test.go @@ -0,0 +1,48 @@ +package domain_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + http "github.com/valyala/fasthttp" + + "source.toby3d.me/website/indieauth/internal/domain" +) + +func TestClient_ValidateRedirectURI(t *testing.T) { + t.Parallel() + + client := domain.TestClient(t) + + for _, testCase := range []struct { + name string + input func() *domain.URL + expResult bool + }{{ + name: "client_id prefix", + input: func() *domain.URL { + u := &domain.URL{ + URI: http.AcquireURI(), + } + client.ID.URI().CopyTo(u.URI) + u.SetPath("/callback") + + return u + }, + expResult: true, + }, { + name: "registered redirect_uri", + input: func() *domain.URL { + return client.RedirectURI[len(client.RedirectURI)-1] + }, + expResult: true, + }} { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, testCase.expResult, client.ValidateRedirectURI(testCase.input())) + }) + } +}