🏷️ Improved exists domains

This commit is contained in:
Maxim Lebedev 2021-12-30 01:08:30 +05:00
parent 9a1bbd4c2c
commit 20b517fdb7
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
12 changed files with 377 additions and 170 deletions

View File

@ -1,33 +1,47 @@
package domain
import "testing"
import (
"testing"
"github.com/stretchr/testify/require"
)
// Client describes the client requesting data about the user.
type Client struct {
RedirectURI []string
ID string
Logo string
Name string
URL string
}
func NewClient() *Client {
c := new(Client)
c.RedirectURI = make([]string, 0)
return c
ID *ClientID
Logo []*URL
RedirectURI []*URL
URL []*URL
Name []string
}
// TestClient returns a valid Client with the generated test data filled in.
func TestClient(tb testing.TB) *Client {
tb.Helper()
url, err := NewURL("https://app.example.com/")
require.NoError(tb, err)
logo, err := NewURL("https://app.example.com/logo.png")
require.NoError(tb, err)
redirects := make([]*URL, 0)
for _, redirect := range []string{
"https://app.example.net/redirect",
"https://app.example.com/redirect",
} {
u, err := NewURL(redirect)
require.NoError(tb, err)
redirects = append(redirects, u)
}
return &Client{
ID: "https://app.example.com/",
Name: "Example App",
Logo: "https://app.example.com/logo.png",
URL: "https://app.example.com/",
RedirectURI: []string{
"https://app.example.net/redirect",
"https://app.example.com/redirect",
},
ID: TestClientID(tb),
Name: []string{"Example App"},
URL: []*URL{url},
Logo: []*URL{logo},
RedirectURI: redirects,
}
}

View File

@ -3,6 +3,7 @@ package domain
import (
"fmt"
"net/url"
"strconv"
"strings"
"testing"
@ -14,8 +15,8 @@ import (
// ClientID is a URL client identifier.
type ClientID struct {
cid *http.URI
valid bool
clientID *http.URI
valid bool
}
//nolint: gochecknoglobals
@ -25,8 +26,8 @@ var (
)
func NewClientID(raw string) (*ClientID, error) {
cid := http.AcquireURI()
if err := cid.Parse(nil, []byte(raw)); err != nil {
clientID := http.AcquireURI()
if err := clientID.Parse(nil, []byte(raw)); err != nil {
return nil, Error{
Code: "invalid_request",
Description: err.Error(),
@ -35,7 +36,7 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
scheme := string(cid.Scheme())
scheme := string(clientID.Scheme())
if scheme != "http" && scheme != "https" {
return nil, Error{
Code: "invalid_request",
@ -45,7 +46,7 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
path := string(cid.PathOriginal())
path := string(clientID.PathOriginal())
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
return nil, Error{
Code: "invalid_request",
@ -56,7 +57,7 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
if cid.Hash() != nil {
if clientID.Hash() != nil {
return nil, Error{
Code: "invalid_request",
Description: "client identifier URL MUST NOT contain a fragment component",
@ -65,7 +66,7 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
if cid.Username() != nil || cid.Password() != nil {
if clientID.Username() != nil || clientID.Password() != nil {
return nil, Error{
Code: "invalid_request",
Description: "client identifier URL MUST NOT contain a username or password component",
@ -74,7 +75,7 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
domain := string(cid.Host())
domain := string(clientID.Host())
if domain == "" {
return nil, Error{
Code: "invalid_request",
@ -88,7 +89,7 @@ func NewClientID(raw string) (*ClientID, error) {
if err != nil {
ipPort, err := netaddr.ParseIPPort(domain)
if err != nil {
return &ClientID{cid: cid}, nil
return &ClientID{clientID: clientID}, nil
}
ip = ipPort.IP()
@ -104,28 +105,43 @@ func NewClientID(raw string) (*ClientID, error) {
}
}
return &ClientID{cid: cid}, nil
return &ClientID{clientID: clientID}, nil
}
// TestClientID returns a valid random generated ClientID for tests.
func TestClientID(tb testing.TB) *ClientID {
tb.Helper()
cid, err := NewClientID("https://app.example.com/")
clientID, err := NewClientID("https://app.example.com/")
require.NoError(tb, err)
return cid
return clientID
}
// UnmarshalForm implements a custom form.Unmarshaler.
func (cid *ClientID) UnmarshalForm(v []byte) error {
clientId, err := NewClientID(string(v))
clientID, err := NewClientID(string(v))
if err != nil {
return fmt.Errorf("UnmarshalForm: %w", err)
}
defer http.ReleaseURI(clientId.cid) //nolint: wsl
clientId.cid.CopyTo(cid.cid)
*cid = *clientID
return nil
}
func (cid *ClientID) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
clientID, err := NewClientID(src)
if err != nil {
return fmt.Errorf("UnmarshalJSON: %w", err)
}
*cid = *clientID
return nil
}
@ -134,23 +150,23 @@ func (cid *ClientID) UnmarshalForm(v []byte) error {
// This copy MUST be released via fasthttp.ReleaseURI.
func (cid *ClientID) URI() *http.URI {
u := http.AcquireURI()
cid.cid.CopyTo(u)
cid.clientID.CopyTo(u)
return u
}
func (cid *ClientID) URL() *url.URL {
return &url.URL{
Scheme: string(cid.cid.Scheme()),
Host: string(cid.cid.Host()),
Path: string(cid.cid.Path()),
RawPath: string(cid.cid.PathOriginal()),
RawQuery: string(cid.cid.QueryString()),
Fragment: string(cid.cid.Hash()),
Scheme: string(cid.clientID.Scheme()),
Host: string(cid.clientID.Host()),
Path: string(cid.clientID.Path()),
RawPath: string(cid.clientID.PathOriginal()),
RawQuery: string(cid.clientID.QueryString()),
Fragment: string(cid.clientID.Hash()),
}
}
// String returns string representation of client ID.
func (cid *ClientID) String() string {
return cid.cid.String()
return cid.clientID.String()
}

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"hash"
"strconv"
"strings"
)
@ -84,6 +85,22 @@ func (ccm *CodeChallengeMethod) UnmarshalForm(v []byte) error {
return nil
}
func (ccm *CodeChallengeMethod) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
method, err := ParseCodeChallengeMethod(src)
if err != nil {
return fmt.Errorf("code_challenge_method: %w", err)
}
*ccm = method
return nil
}
// String returns string representation of code challenge method.
func (ccm CodeChallengeMethod) String() string {
return ccm.slug

View File

@ -6,6 +6,7 @@ import (
"golang.org/x/xerrors"
)
// Error describes the data of a typical error.
//nolint: tagliatelle
type Error struct {
Code string `json:"error"`
@ -23,15 +24,17 @@ func (e Error) Format(s fmt.State, r rune) {
}
func (e Error) FormatError(p xerrors.Printer) error {
p.Printf("%s: %s", e.Code, e.Description)
p.Print(e.Description)
if e.URI != "" {
p.Printf(": %s", e.URI)
p.Print(": ", e.URI, "\n")
}
if p.Detail() {
e.Frame.Format(p)
if !p.Detail() {
return e
}
e.Frame.Format(p)
return nil
}

View File

@ -3,6 +3,7 @@ package domain
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -16,13 +17,19 @@ type GrantType struct {
var (
GrantTypeUndefined = GrantType{slug: ""}
GrantTypeAuthorizationCode = GrantType{slug: "authorization_code"}
// TicketAuth extension.
GrantTypeTicket = GrantType{slug: "ticket"}
)
var ErrGrantTypeUnknown = errors.New("unknown grant type")
func ParseGrantType(slug string) (GrantType, error) {
if strings.ToLower(slug) == GrantTypeAuthorizationCode.slug {
switch strings.ToLower(slug) {
case GrantTypeAuthorizationCode.slug:
return GrantTypeAuthorizationCode, nil
case GrantTypeTicket.slug:
return GrantTypeTicket, nil
}
return GrantTypeUndefined, fmt.Errorf("%w: %s", ErrGrantTypeUnknown, slug)
@ -39,6 +46,22 @@ func (gt *GrantType) UnmarshalForm(src []byte) error {
return nil
}
func (gt *GrantType) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
responseType, err := ParseGrantType(src)
if err != nil {
return fmt.Errorf("grant_type: %w", err)
}
*gt = responseType
return nil
}
func (gt GrantType) String() string {
return gt.slug
}

View File

@ -1,80 +0,0 @@
package domain
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/require"
"source.toby3d.me/website/oauth/internal/random"
)
type Login struct {
PKCE
CreatedAt time.Time
CompletedAt time.Time
Scopes []string
ClientID string
RedirectURI string
MeEntered string
MeResolved string
Code string
Provider string
IsCompleted bool
}
//nolint: gomnd
func TestLogin(tb testing.TB) *Login {
tb.Helper()
code, err := random.String(8)
require.NoError(tb, err)
return &Login{
CreatedAt: gofakeit.Date(),
CompletedAt: time.Time{},
PKCE: PKCE{
Method: PKCEMethodS256,
Challenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
Verifier: "a6128783714cfda1d388e2e98b6ae8221ac31aca31959e59512c59f5",
},
Scopes: []string{"profile", "create", "update", "delete"},
ClientID: "https://app.example.com/",
RedirectURI: "https://app.example.com/redirect",
MeEntered: "user.example.net",
MeResolved: "https://user.example.net/",
Code: code,
Provider: "mastodon",
IsCompleted: false,
}
}
//nolint: gomnd
func TestLoginInvalid(tb testing.TB) *Login {
tb.Helper()
challenge, err := random.String(42)
require.NoError(tb, err)
verifier, err := random.String(64)
require.NoError(tb, err)
return &Login{
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
CompletedAt: time.Time{},
PKCE: PKCE{
Method: "UNDEFINED",
Challenge: challenge,
Verifier: verifier,
},
Scopes: []string{},
ClientID: "whoisit",
RedirectURI: "redirect",
MeEntered: "whoami",
MeResolved: "",
Code: "",
Provider: "undefined",
IsCompleted: true,
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"testing"
@ -109,14 +110,29 @@ func TestMe(tb testing.TB) *Me {
}
// UnmarshalForm parses the value of the form key into the Me domain.
func (m *Me) UnmarshalForm(v []byte) (err error) {
func (m *Me) UnmarshalForm(v []byte) error {
me, err := NewMe(string(v))
if err != nil {
return fmt.Errorf("UnmarshalForm: %w", err)
}
defer http.ReleaseURI(me.me) //nolint: wsl
me.me.CopyTo(m.me)
*m = *me
return nil
}
func (m *Me) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
me, err := NewMe(src)
if err != nil {
return fmt.Errorf("UnmarshalForm: %w", err)
}
*m = *me
return nil
}
@ -124,6 +140,10 @@ func (m *Me) UnmarshalForm(v []byte) (err error) {
// URI returns copy of parsed Me in *fasthttp.URI representation.
// This copy MUST be released via fasthttp.ReleaseURI.
func (m *Me) URI() *http.URI {
if m.me == nil {
return nil
}
u := http.AcquireURI()
m.me.CopyTo(u)
@ -132,6 +152,10 @@ func (m *Me) URI() *http.URI {
// URL returns copy of parsed Me in *url.URL representation.
func (m *Me) URL() *url.URL {
if m.me == nil {
return nil
}
return &url.URL{
Scheme: string(m.me.Scheme()),
Host: string(m.me.Host()),
@ -143,6 +167,10 @@ func (m *Me) URL() *url.URL {
}
// String returns string representation of Me.
func (m Me) String() string {
func (m *Me) String() string {
if m.me == nil {
return ""
}
return m.me.String()
}

View File

@ -1,21 +1,43 @@
package domain
import "testing"
import (
"testing"
"github.com/stretchr/testify/require"
)
// Profile describes the data about the user.
type Profile struct {
Name string
URL string
Photo string
Email string
Photo []*URL
URL []*URL
Email []Email
Name []string
}
// NewProfile creates a new empty Profile.
func NewProfile() *Profile {
return &Profile{
Email: make([]Email, 0),
Name: make([]string, 0),
Photo: make([]*URL, 0),
URL: make([]*URL, 0),
}
}
// TestProfile returns a valid Profile with the generated test data filled in.
func TestProfile(tb testing.TB) *Profile {
tb.Helper()
photo, err := NewURL("https://user.example.net/photo.jpg")
require.NoError(tb, err)
url, err := NewURL("https://user.example.net/")
require.NoError(tb, err)
return &Profile{
Name: "Example User",
URL: "https://user.example.net/",
Photo: "https://user.example.net/photo.jpg",
Email: "user@example.net",
Email: []Email{"user@example.net"},
Name: []string{"Example User"},
Photo: []*URL{photo},
URL: []*URL{url},
}
}

View File

@ -3,6 +3,7 @@ package domain
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -58,6 +59,22 @@ func (rt *ResponseType) UnmarshalForm(src []byte) error {
return nil
}
func (rt *ResponseType) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
responseType, err := ParseResponseType(string(src))
if err != nil {
return fmt.Errorf("response_type: %w", err)
}
*rt = responseType
return nil
}
func (rt ResponseType) String() string {
return rt.slug
}

View File

@ -3,6 +3,7 @@ package domain
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -12,6 +13,7 @@ type (
slug string
}
// Scopes represent set of Scope domains.
Scopes []Scope
)
@ -74,7 +76,7 @@ var slugsScopes = map[string]Scope{
// ParseScope parses scope slug into Scope domain.
func ParseScope(slug string) (Scope, error) {
if scope, ok := slugsScopes[strings.ToLower(slug)]; !ok {
if scope, ok := slugsScopes[strings.ToLower(slug)]; ok {
return scope, nil
}
@ -93,7 +95,55 @@ func (s *Scope) UnmarshalForm(v []byte) (err error) {
return nil
}
func (s *Scope) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
scope, err := ParseScope(src)
if err != nil {
return fmt.Errorf("scope: %w", err)
}
*s = scope
return nil
}
func (s *Scopes) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
}
result := make([]Scope, 0)
for _, scope := range strings.Fields(src) {
s, err := ParseScope(scope)
if err != nil {
return fmt.Errorf("scope: %w", err)
}
result = append(result, s)
}
*s = result
return nil
}
// String returns scope slug as string.
func (s Scope) String() string {
return s.slug
}
func (s Scopes) String() string {
scopes := make([]string, len(s))
for i := range s {
scopes[i] = s[i].String()
}
return strings.Join(scopes, " ")
}

View File

@ -0,0 +1,24 @@
package domain_test
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"source.toby3d.me/website/oauth/internal/domain"
)
func TestScopesUnmarshalJSON(t *testing.T) {
t.Parallel()
result := &struct {
Scope domain.Scopes `json:"scope"`
}{}
require.NoError(t, json.Unmarshal([]byte(`{"scope": "read update delete"}`), result))
for _, scope := range []domain.Scope{domain.ScopeRead, domain.ScopeUpdate, domain.ScopeDelete} {
assert.Contains(t, result.Scope, scope)
}
}

View File

@ -1,48 +1,104 @@
package domain
import (
"strings"
"fmt"
"testing"
"time"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/require"
http "github.com/valyala/fasthttp"
"source.toby3d.me/website/oauth/internal/random"
)
type Token struct {
Scopes []string
AccessToken string
ClientID string
Me string
type (
// Token describes the data of the token used by the clients.
Token struct {
AccessToken string
ClientID *ClientID
Me *Me
Scope Scopes
}
NewTokenOptions struct {
Algorithm string
Expiration time.Duration
Issuer *ClientID
NonceLength int
Scope Scopes
Secret interface{}
Subject *Me
}
)
var DefaultNewTokenOptions = NewTokenOptions{
NonceLength: 32,
Algorithm: "HS256",
}
func NewToken() *Token {
t := new(Token)
t.Scopes = make([]string, 0)
func NewToken(opts NewTokenOptions) (*Token, error) {
if opts.NonceLength == 0 {
opts.NonceLength = DefaultNewTokenOptions.NonceLength
}
return t
if opts.Algorithm == "" {
opts.Algorithm = DefaultNewTokenOptions.Algorithm
}
now := time.Now().UTC().Round(time.Second)
nonce, err := random.String(opts.NonceLength)
if err != nil {
return nil, fmt.Errorf("cannot generate nonce: %w", err)
}
t := jwt.New()
t.Set(jwt.IssuerKey, opts.Issuer.String())
t.Set(jwt.SubjectKey, opts.Subject.String())
t.Set(jwt.NotBeforeKey, now)
t.Set(jwt.IssuedAtKey, now)
t.Set("scope", opts.Scope)
t.Set("nonce", nonce)
if opts.Expiration != 0 {
t.Set(jwt.ExpirationKey, now.Add(opts.Expiration))
}
accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
if err != nil {
return nil, fmt.Errorf("cannot sign a new access token: %w", err)
}
return &Token{
AccessToken: string(accessToken),
ClientID: opts.Issuer,
Me: opts.Subject,
Scope: opts.Scope,
}, err
}
// TestToken returns a valid Token with the generated test data filled in.
func TestToken(tb testing.TB) *Token {
tb.Helper()
require := require.New(tb)
nonce, err := random.String(42)
require.NoError(tb, err)
nonce, err := random.String(50)
require.NoError(err)
client := TestClient(tb)
profile := TestProfile(tb)
now := time.Now().UTC().Round(time.Second)
scopes := []string{"create", "update", "delete"}
t := jwt.New()
cid := TestClientID(tb)
me := TestMe(tb)
now := time.Now().UTC().Round(time.Second)
scope := []Scope{
ScopeCreate,
ScopeUpdate,
ScopeDelete,
}
// required
t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id
t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me
// NOTE(toby3d): required
t.Set(jwt.IssuerKey, cid.String())
t.Set(jwt.SubjectKey, me.me.String())
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
@ -50,16 +106,33 @@ func TestToken(tb testing.TB) *Token {
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
// optional
t.Set("scope", strings.Join(scopes, " "))
t.Set("scope", scope)
t.Set("nonce", nonce)
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
require.NoError(err)
require.NoError(tb, err)
return &Token{
ClientID: cid,
Me: me,
Scope: scope,
AccessToken: string(accessToken),
ClientID: t.Issuer(),
Me: t.Subject(),
Scopes: scopes,
}
}
// SetAuthHeader writes an Access Token to the request header.
func (t *Token) SetAuthHeader(r *http.Request) {
if t.AccessToken == "" {
return
}
r.Header.Set(http.HeaderAuthorization, t.String())
}
func (t *Token) String() string {
if t.AccessToken == "" {
return ""
}
return "Bearer " + string(t.AccessToken)
}