diff --git a/internal/domain/client.go b/internal/domain/client.go index 8b7cae4..a28693a 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -1,33 +1,47 @@ package domain -import "testing" +import ( + "testing" + "github.com/stretchr/testify/require" +) + +// Client describes the client requesting data about the user. type Client struct { - RedirectURI []string - ID string - Logo string - Name string - URL string -} - -func NewClient() *Client { - c := new(Client) - c.RedirectURI = make([]string, 0) - - return c + ID *ClientID + Logo []*URL + RedirectURI []*URL + URL []*URL + Name []string } +// TestClient returns a valid Client with the generated test data filled in. func TestClient(tb testing.TB) *Client { tb.Helper() + url, err := NewURL("https://app.example.com/") + require.NoError(tb, err) + + logo, err := NewURL("https://app.example.com/logo.png") + require.NoError(tb, err) + + redirects := make([]*URL, 0) + + for _, redirect := range []string{ + "https://app.example.net/redirect", + "https://app.example.com/redirect", + } { + u, err := NewURL(redirect) + require.NoError(tb, err) + + redirects = append(redirects, u) + } + return &Client{ - ID: "https://app.example.com/", - Name: "Example App", - Logo: "https://app.example.com/logo.png", - URL: "https://app.example.com/", - RedirectURI: []string{ - "https://app.example.net/redirect", - "https://app.example.com/redirect", - }, + ID: TestClientID(tb), + Name: []string{"Example App"}, + URL: []*URL{url}, + Logo: []*URL{logo}, + RedirectURI: redirects, } } diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go index 2edf7a9..2d7b2ff 100644 --- a/internal/domain/client_id.go +++ b/internal/domain/client_id.go @@ -3,6 +3,7 @@ package domain import ( "fmt" "net/url" + "strconv" "strings" "testing" @@ -14,8 +15,8 @@ import ( // ClientID is a URL client identifier. type ClientID struct { - cid *http.URI - valid bool + clientID *http.URI + valid bool } //nolint: gochecknoglobals @@ -25,8 +26,8 @@ var ( ) func NewClientID(raw string) (*ClientID, error) { - cid := http.AcquireURI() - if err := cid.Parse(nil, []byte(raw)); err != nil { + clientID := http.AcquireURI() + if err := clientID.Parse(nil, []byte(raw)); err != nil { return nil, Error{ Code: "invalid_request", Description: err.Error(), @@ -35,7 +36,7 @@ func NewClientID(raw string) (*ClientID, error) { } } - scheme := string(cid.Scheme()) + scheme := string(clientID.Scheme()) if scheme != "http" && scheme != "https" { return nil, Error{ Code: "invalid_request", @@ -45,7 +46,7 @@ func NewClientID(raw string) (*ClientID, error) { } } - path := string(cid.PathOriginal()) + path := string(clientID.PathOriginal()) if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") { return nil, Error{ Code: "invalid_request", @@ -56,7 +57,7 @@ func NewClientID(raw string) (*ClientID, error) { } } - if cid.Hash() != nil { + if clientID.Hash() != nil { return nil, Error{ Code: "invalid_request", Description: "client identifier URL MUST NOT contain a fragment component", @@ -65,7 +66,7 @@ func NewClientID(raw string) (*ClientID, error) { } } - if cid.Username() != nil || cid.Password() != nil { + if clientID.Username() != nil || clientID.Password() != nil { return nil, Error{ Code: "invalid_request", Description: "client identifier URL MUST NOT contain a username or password component", @@ -74,7 +75,7 @@ func NewClientID(raw string) (*ClientID, error) { } } - domain := string(cid.Host()) + domain := string(clientID.Host()) if domain == "" { return nil, Error{ Code: "invalid_request", @@ -88,7 +89,7 @@ func NewClientID(raw string) (*ClientID, error) { if err != nil { ipPort, err := netaddr.ParseIPPort(domain) if err != nil { - return &ClientID{cid: cid}, nil + return &ClientID{clientID: clientID}, nil } ip = ipPort.IP() @@ -104,28 +105,43 @@ func NewClientID(raw string) (*ClientID, error) { } } - return &ClientID{cid: cid}, nil + return &ClientID{clientID: clientID}, nil } // TestClientID returns a valid random generated ClientID for tests. func TestClientID(tb testing.TB) *ClientID { tb.Helper() - cid, err := NewClientID("https://app.example.com/") + clientID, err := NewClientID("https://app.example.com/") require.NoError(tb, err) - return cid + return clientID } // UnmarshalForm implements a custom form.Unmarshaler. func (cid *ClientID) UnmarshalForm(v []byte) error { - clientId, err := NewClientID(string(v)) + clientID, err := NewClientID(string(v)) if err != nil { return fmt.Errorf("UnmarshalForm: %w", err) } - defer http.ReleaseURI(clientId.cid) //nolint: wsl - clientId.cid.CopyTo(cid.cid) + *cid = *clientID + + return nil +} + +func (cid *ClientID) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + clientID, err := NewClientID(src) + if err != nil { + return fmt.Errorf("UnmarshalJSON: %w", err) + } + + *cid = *clientID return nil } @@ -134,23 +150,23 @@ func (cid *ClientID) UnmarshalForm(v []byte) error { // This copy MUST be released via fasthttp.ReleaseURI. func (cid *ClientID) URI() *http.URI { u := http.AcquireURI() - cid.cid.CopyTo(u) + cid.clientID.CopyTo(u) return u } func (cid *ClientID) URL() *url.URL { return &url.URL{ - Scheme: string(cid.cid.Scheme()), - Host: string(cid.cid.Host()), - Path: string(cid.cid.Path()), - RawPath: string(cid.cid.PathOriginal()), - RawQuery: string(cid.cid.QueryString()), - Fragment: string(cid.cid.Hash()), + Scheme: string(cid.clientID.Scheme()), + Host: string(cid.clientID.Host()), + Path: string(cid.clientID.Path()), + RawPath: string(cid.clientID.PathOriginal()), + RawQuery: string(cid.clientID.QueryString()), + Fragment: string(cid.clientID.Hash()), } } // String returns string representation of client ID. func (cid *ClientID) String() string { - return cid.cid.String() + return cid.clientID.String() } diff --git a/internal/domain/code_challenge_method.go b/internal/domain/code_challenge_method.go index 1178ec9..066a029 100644 --- a/internal/domain/code_challenge_method.go +++ b/internal/domain/code_challenge_method.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "hash" + "strconv" "strings" ) @@ -84,6 +85,22 @@ func (ccm *CodeChallengeMethod) UnmarshalForm(v []byte) error { return nil } +func (ccm *CodeChallengeMethod) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + method, err := ParseCodeChallengeMethod(src) + if err != nil { + return fmt.Errorf("code_challenge_method: %w", err) + } + + *ccm = method + + return nil +} + // String returns string representation of code challenge method. func (ccm CodeChallengeMethod) String() string { return ccm.slug diff --git a/internal/domain/error.go b/internal/domain/error.go index 6f8c864..77812f0 100644 --- a/internal/domain/error.go +++ b/internal/domain/error.go @@ -6,6 +6,7 @@ import ( "golang.org/x/xerrors" ) +// Error describes the data of a typical error. //nolint: tagliatelle type Error struct { Code string `json:"error"` @@ -23,15 +24,17 @@ func (e Error) Format(s fmt.State, r rune) { } func (e Error) FormatError(p xerrors.Printer) error { - p.Printf("%s: %s", e.Code, e.Description) + p.Print(e.Description) if e.URI != "" { - p.Printf(": %s", e.URI) + p.Print(": ", e.URI, "\n") } - if p.Detail() { - e.Frame.Format(p) + if !p.Detail() { + return e } + e.Frame.Format(p) + return nil } diff --git a/internal/domain/grant_type.go b/internal/domain/grant_type.go index 2dea568..3de39e3 100644 --- a/internal/domain/grant_type.go +++ b/internal/domain/grant_type.go @@ -3,6 +3,7 @@ package domain import ( "errors" "fmt" + "strconv" "strings" ) @@ -16,13 +17,19 @@ type GrantType struct { var ( GrantTypeUndefined = GrantType{slug: ""} GrantTypeAuthorizationCode = GrantType{slug: "authorization_code"} + + // TicketAuth extension. + GrantTypeTicket = GrantType{slug: "ticket"} ) var ErrGrantTypeUnknown = errors.New("unknown grant type") func ParseGrantType(slug string) (GrantType, error) { - if strings.ToLower(slug) == GrantTypeAuthorizationCode.slug { + switch strings.ToLower(slug) { + case GrantTypeAuthorizationCode.slug: return GrantTypeAuthorizationCode, nil + case GrantTypeTicket.slug: + return GrantTypeTicket, nil } return GrantTypeUndefined, fmt.Errorf("%w: %s", ErrGrantTypeUnknown, slug) @@ -39,6 +46,22 @@ func (gt *GrantType) UnmarshalForm(src []byte) error { return nil } +func (gt *GrantType) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + responseType, err := ParseGrantType(src) + if err != nil { + return fmt.Errorf("grant_type: %w", err) + } + + *gt = responseType + + return nil +} + func (gt GrantType) String() string { return gt.slug } diff --git a/internal/domain/login.go b/internal/domain/login.go deleted file mode 100644 index 960bc2c..0000000 --- a/internal/domain/login.go +++ /dev/null @@ -1,80 +0,0 @@ -package domain - -import ( - "testing" - "time" - - "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/require" - - "source.toby3d.me/website/oauth/internal/random" -) - -type Login struct { - PKCE - CreatedAt time.Time - CompletedAt time.Time - Scopes []string - ClientID string - RedirectURI string - MeEntered string - MeResolved string - Code string - Provider string - IsCompleted bool -} - -//nolint: gomnd -func TestLogin(tb testing.TB) *Login { - tb.Helper() - - code, err := random.String(8) - require.NoError(tb, err) - - return &Login{ - CreatedAt: gofakeit.Date(), - CompletedAt: time.Time{}, - PKCE: PKCE{ - Method: PKCEMethodS256, - Challenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo", - Verifier: "a6128783714cfda1d388e2e98b6ae8221ac31aca31959e59512c59f5", - }, - Scopes: []string{"profile", "create", "update", "delete"}, - ClientID: "https://app.example.com/", - RedirectURI: "https://app.example.com/redirect", - MeEntered: "user.example.net", - MeResolved: "https://user.example.net/", - Code: code, - Provider: "mastodon", - IsCompleted: false, - } -} - -//nolint: gomnd -func TestLoginInvalid(tb testing.TB) *Login { - tb.Helper() - - challenge, err := random.String(42) - require.NoError(tb, err) - - verifier, err := random.String(64) - require.NoError(tb, err) - - return &Login{ - CreatedAt: time.Now().UTC().Add(-1 * time.Hour), - CompletedAt: time.Time{}, - PKCE: PKCE{ - Method: "UNDEFINED", - Challenge: challenge, - Verifier: verifier, - }, - Scopes: []string{}, - ClientID: "whoisit", - RedirectURI: "redirect", - MeEntered: "whoami", - MeResolved: "", - Code: "", - Provider: "undefined", - IsCompleted: true, - } -} diff --git a/internal/domain/me.go b/internal/domain/me.go index 702457c..42a4503 100644 --- a/internal/domain/me.go +++ b/internal/domain/me.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/url" + "strconv" "strings" "testing" @@ -109,14 +110,29 @@ func TestMe(tb testing.TB) *Me { } // UnmarshalForm parses the value of the form key into the Me domain. -func (m *Me) UnmarshalForm(v []byte) (err error) { +func (m *Me) UnmarshalForm(v []byte) error { me, err := NewMe(string(v)) if err != nil { return fmt.Errorf("UnmarshalForm: %w", err) } - defer http.ReleaseURI(me.me) //nolint: wsl - me.me.CopyTo(m.me) + *m = *me + + return nil +} + +func (m *Me) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + me, err := NewMe(src) + if err != nil { + return fmt.Errorf("UnmarshalForm: %w", err) + } + + *m = *me return nil } @@ -124,6 +140,10 @@ func (m *Me) UnmarshalForm(v []byte) (err error) { // URI returns copy of parsed Me in *fasthttp.URI representation. // This copy MUST be released via fasthttp.ReleaseURI. func (m *Me) URI() *http.URI { + if m.me == nil { + return nil + } + u := http.AcquireURI() m.me.CopyTo(u) @@ -132,6 +152,10 @@ func (m *Me) URI() *http.URI { // URL returns copy of parsed Me in *url.URL representation. func (m *Me) URL() *url.URL { + if m.me == nil { + return nil + } + return &url.URL{ Scheme: string(m.me.Scheme()), Host: string(m.me.Host()), @@ -143,6 +167,10 @@ func (m *Me) URL() *url.URL { } // String returns string representation of Me. -func (m Me) String() string { +func (m *Me) String() string { + if m.me == nil { + return "" + } + return m.me.String() } diff --git a/internal/domain/profile.go b/internal/domain/profile.go index 00d9877..a2bc396 100644 --- a/internal/domain/profile.go +++ b/internal/domain/profile.go @@ -1,21 +1,43 @@ package domain -import "testing" +import ( + "testing" + "github.com/stretchr/testify/require" +) + +// Profile describes the data about the user. type Profile struct { - Name string - URL string - Photo string - Email string + Photo []*URL + URL []*URL + Email []Email + Name []string } +// NewProfile creates a new empty Profile. +func NewProfile() *Profile { + return &Profile{ + Email: make([]Email, 0), + Name: make([]string, 0), + Photo: make([]*URL, 0), + URL: make([]*URL, 0), + } +} + +// TestProfile returns a valid Profile with the generated test data filled in. func TestProfile(tb testing.TB) *Profile { tb.Helper() + photo, err := NewURL("https://user.example.net/photo.jpg") + require.NoError(tb, err) + + url, err := NewURL("https://user.example.net/") + require.NoError(tb, err) + return &Profile{ - Name: "Example User", - URL: "https://user.example.net/", - Photo: "https://user.example.net/photo.jpg", - Email: "user@example.net", + Email: []Email{"user@example.net"}, + Name: []string{"Example User"}, + Photo: []*URL{photo}, + URL: []*URL{url}, } } diff --git a/internal/domain/response_type.go b/internal/domain/response_type.go index 41a688e..8d5d23f 100644 --- a/internal/domain/response_type.go +++ b/internal/domain/response_type.go @@ -3,6 +3,7 @@ package domain import ( "errors" "fmt" + "strconv" "strings" ) @@ -58,6 +59,22 @@ func (rt *ResponseType) UnmarshalForm(src []byte) error { return nil } +func (rt *ResponseType) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + responseType, err := ParseResponseType(string(src)) + if err != nil { + return fmt.Errorf("response_type: %w", err) + } + + *rt = responseType + + return nil +} + func (rt ResponseType) String() string { return rt.slug } diff --git a/internal/domain/scope.go b/internal/domain/scope.go index 0e26320..e1d40f0 100644 --- a/internal/domain/scope.go +++ b/internal/domain/scope.go @@ -3,6 +3,7 @@ package domain import ( "errors" "fmt" + "strconv" "strings" ) @@ -12,6 +13,7 @@ type ( slug string } + // Scopes represent set of Scope domains. Scopes []Scope ) @@ -74,7 +76,7 @@ var slugsScopes = map[string]Scope{ // ParseScope parses scope slug into Scope domain. func ParseScope(slug string) (Scope, error) { - if scope, ok := slugsScopes[strings.ToLower(slug)]; !ok { + if scope, ok := slugsScopes[strings.ToLower(slug)]; ok { return scope, nil } @@ -93,7 +95,55 @@ func (s *Scope) UnmarshalForm(v []byte) (err error) { return nil } +func (s *Scope) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + scope, err := ParseScope(src) + if err != nil { + return fmt.Errorf("scope: %w", err) + } + + *s = scope + + return nil +} + +func (s *Scopes) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return err + } + + result := make([]Scope, 0) + + for _, scope := range strings.Fields(src) { + s, err := ParseScope(scope) + if err != nil { + return fmt.Errorf("scope: %w", err) + } + + result = append(result, s) + } + + *s = result + + return nil +} + // String returns scope slug as string. func (s Scope) String() string { return s.slug } + +func (s Scopes) String() string { + scopes := make([]string, len(s)) + + for i := range s { + scopes[i] = s[i].String() + } + + return strings.Join(scopes, " ") +} diff --git a/internal/domain/scope_test.go b/internal/domain/scope_test.go new file mode 100644 index 0000000..2e1a13a --- /dev/null +++ b/internal/domain/scope_test.go @@ -0,0 +1,24 @@ +package domain_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/oauth/internal/domain" +) + +func TestScopesUnmarshalJSON(t *testing.T) { + t.Parallel() + + result := &struct { + Scope domain.Scopes `json:"scope"` + }{} + require.NoError(t, json.Unmarshal([]byte(`{"scope": "read update delete"}`), result)) + + for _, scope := range []domain.Scope{domain.ScopeRead, domain.ScopeUpdate, domain.ScopeDelete} { + assert.Contains(t, result.Scope, scope) + } +} diff --git a/internal/domain/token.go b/internal/domain/token.go index 74b7ed9..7b1b387 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -1,48 +1,104 @@ package domain import ( - "strings" + "fmt" "testing" "time" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" "github.com/stretchr/testify/require" + http "github.com/valyala/fasthttp" "source.toby3d.me/website/oauth/internal/random" ) -type Token struct { - Scopes []string - AccessToken string - ClientID string - Me string +type ( + // Token describes the data of the token used by the clients. + Token struct { + AccessToken string + ClientID *ClientID + Me *Me + Scope Scopes + } + + NewTokenOptions struct { + Algorithm string + Expiration time.Duration + Issuer *ClientID + NonceLength int + Scope Scopes + Secret interface{} + Subject *Me + } +) + +var DefaultNewTokenOptions = NewTokenOptions{ + NonceLength: 32, + Algorithm: "HS256", } -func NewToken() *Token { - t := new(Token) - t.Scopes = make([]string, 0) +func NewToken(opts NewTokenOptions) (*Token, error) { + if opts.NonceLength == 0 { + opts.NonceLength = DefaultNewTokenOptions.NonceLength + } - return t + if opts.Algorithm == "" { + opts.Algorithm = DefaultNewTokenOptions.Algorithm + } + + now := time.Now().UTC().Round(time.Second) + + nonce, err := random.String(opts.NonceLength) + if err != nil { + return nil, fmt.Errorf("cannot generate nonce: %w", err) + } + + t := jwt.New() + t.Set(jwt.IssuerKey, opts.Issuer.String()) + t.Set(jwt.SubjectKey, opts.Subject.String()) + t.Set(jwt.NotBeforeKey, now) + t.Set(jwt.IssuedAtKey, now) + t.Set("scope", opts.Scope) + t.Set("nonce", nonce) + + if opts.Expiration != 0 { + t.Set(jwt.ExpirationKey, now.Add(opts.Expiration)) + } + + accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret) + if err != nil { + return nil, fmt.Errorf("cannot sign a new access token: %w", err) + } + + return &Token{ + AccessToken: string(accessToken), + ClientID: opts.Issuer, + Me: opts.Subject, + Scope: opts.Scope, + }, err } +// TestToken returns a valid Token with the generated test data filled in. func TestToken(tb testing.TB) *Token { tb.Helper() - require := require.New(tb) + nonce, err := random.String(42) + require.NoError(tb, err) - nonce, err := random.String(50) - require.NoError(err) - - client := TestClient(tb) - profile := TestProfile(tb) - now := time.Now().UTC().Round(time.Second) - scopes := []string{"create", "update", "delete"} t := jwt.New() + cid := TestClientID(tb) + me := TestMe(tb) + now := time.Now().UTC().Round(time.Second) + scope := []Scope{ + ScopeCreate, + ScopeUpdate, + ScopeDelete, + } - // required - t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id - t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me + // NOTE(toby3d): required + t.Set(jwt.IssuerKey, cid.String()) + t.Set(jwt.SubjectKey, me.me.String()) // TODO(toby3d): t.Set(jwt.AudienceKey, nil) t.Set(jwt.ExpirationKey, now.Add(1*time.Hour)) t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour)) @@ -50,16 +106,33 @@ func TestToken(tb testing.TB) *Token { // TODO(toby3d): t.Set(jwt.JwtIDKey, nil) // optional - t.Set("scope", strings.Join(scopes, " ")) + t.Set("scope", scope) t.Set("nonce", nonce) accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme")) - require.NoError(err) + require.NoError(tb, err) return &Token{ + ClientID: cid, + Me: me, + Scope: scope, AccessToken: string(accessToken), - ClientID: t.Issuer(), - Me: t.Subject(), - Scopes: scopes, } } + +// SetAuthHeader writes an Access Token to the request header. +func (t *Token) SetAuthHeader(r *http.Request) { + if t.AccessToken == "" { + return + } + + r.Header.Set(http.HeaderAuthorization, t.String()) +} + +func (t *Token) String() string { + if t.AccessToken == "" { + return "" + } + + return "Bearer " + string(t.AccessToken) +}