diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go new file mode 100644 index 0000000..2edf7a9 --- /dev/null +++ b/internal/domain/client_id.go @@ -0,0 +1,156 @@ +package domain + +import ( + "fmt" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/require" + http "github.com/valyala/fasthttp" + "golang.org/x/xerrors" + "inet.af/netaddr" +) + +// ClientID is a URL client identifier. +type ClientID struct { + cid *http.URI + valid bool +} + +//nolint: gochecknoglobals +var ( + localhostIPv4 = netaddr.MustParseIP("127.0.0.1") + localhostIPv6 = netaddr.MustParseIP("::1") +) + +func NewClientID(raw string) (*ClientID, error) { + cid := http.AcquireURI() + if err := cid.Parse(nil, []byte(raw)); err != nil { + return nil, Error{ + Code: "invalid_request", + Description: err.Error(), + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + scheme := string(cid.Scheme()) + if scheme != "http" && scheme != "https" { + return nil, Error{ + Code: "invalid_request", + Description: "client identifier URL MUST have either an https or http scheme", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + path := string(cid.PathOriginal()) + if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") { + return nil, Error{ + Code: "invalid_request", + Description: "client identifier URL MUST contain a path component and MUST NOT contain " + + "single-dot or double-dot path segments", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + if cid.Hash() != nil { + return nil, Error{ + Code: "invalid_request", + Description: "client identifier URL MUST NOT contain a fragment component", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + if cid.Username() != nil || cid.Password() != nil { + return nil, Error{ + Code: "invalid_request", + Description: "client identifier URL MUST NOT contain a username or password component", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + domain := string(cid.Host()) + if domain == "" { + return nil, Error{ + Code: "invalid_request", + Description: "client host name MUST be domain name or a loopback interface", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + ip, err := netaddr.ParseIP(domain) + if err != nil { + ipPort, err := netaddr.ParseIPPort(domain) + if err != nil { + return &ClientID{cid: cid}, nil + } + + ip = ipPort.IP() + } + + if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 { + return nil, Error{ + Code: "invalid_request", + Description: "client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 " + + "127.0.0.1 or IPv6 [::1]", + URI: "https://indieauth.net/source/#client-identifier", + Frame: xerrors.Caller(1), + } + } + + return &ClientID{cid: cid}, nil +} + +// TestClientID returns a valid random generated ClientID for tests. +func TestClientID(tb testing.TB) *ClientID { + tb.Helper() + + cid, err := NewClientID("https://app.example.com/") + require.NoError(tb, err) + + return cid +} + +// UnmarshalForm implements a custom form.Unmarshaler. +func (cid *ClientID) UnmarshalForm(v []byte) error { + clientId, err := NewClientID(string(v)) + if err != nil { + return fmt.Errorf("UnmarshalForm: %w", err) + } + defer http.ReleaseURI(clientId.cid) //nolint: wsl + + clientId.cid.CopyTo(cid.cid) + + return nil +} + +// URI returns copy of parsed *fasthttp.URI. +// This copy MUST be released via fasthttp.ReleaseURI. +func (cid *ClientID) URI() *http.URI { + u := http.AcquireURI() + cid.cid.CopyTo(u) + + return u +} + +func (cid *ClientID) URL() *url.URL { + return &url.URL{ + Scheme: string(cid.cid.Scheme()), + Host: string(cid.cid.Host()), + Path: string(cid.cid.Path()), + RawPath: string(cid.cid.PathOriginal()), + RawQuery: string(cid.cid.QueryString()), + Fragment: string(cid.cid.Hash()), + } +} + +// String returns string representation of client ID. +func (cid *ClientID) String() string { + return cid.cid.String() +} diff --git a/internal/domain/client_id_test.go b/internal/domain/client_id_test.go new file mode 100644 index 0000000..75f5969 --- /dev/null +++ b/internal/domain/client_id_test.go @@ -0,0 +1,79 @@ +package domain_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/oauth/internal/domain" +) + +//nolint: funlen +func TestClientID(t *testing.T) { + t.Parallel() + + for _, testCase := range []struct { + name string + input string + isValid bool + }{{ + name: "valid", + input: "https://example.com/", + isValid: true, + }, { + name: "valid with path", + input: "https://example.com/username", + isValid: true, + }, { + name: "valid with query", + input: "https://example.com/users?id=100", + isValid: true, + }, { + name: "valid with port", + input: "https://example.com:8443/", + isValid: true, + }, { + name: "valid loopback", + input: "https://127.0.0.1:8443/", + isValid: true, + }, { + name: "missing scheme", + input: "example.com", + isValid: false, + }, { + name: "invalid scheme", + input: "mailto:user@example.com", + isValid: false, + }, { + name: "contains a double-dot path segment", + input: "https://example.com/foo/../bar", + isValid: false, + }, { + name: "contains a fragment", + input: "https://example.com/#me", + isValid: false, + }, { + name: "contains a username and password", + input: "https://user:pass@example.com/", + isValid: false, + }, { + name: "host is an IP address", + input: "https://172.28.92.51/", + isValid: false, + }} { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + result, err := domain.NewClientID(testCase.input) + if testCase.isValid { + require.NoError(t, err) + assert.Equal(t, testCase.input, result.String()) + } else { + assert.Error(t, err) + } + }) + } +}