🏷️ Improved exists domains
This commit is contained in:
parent
9a1bbd4c2c
commit
20b517fdb7
|
@ -1,33 +1,47 @@
|
|||
package domain
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Client describes the client requesting data about the user.
|
||||
type Client struct {
|
||||
RedirectURI []string
|
||||
ID string
|
||||
Logo string
|
||||
Name string
|
||||
URL string
|
||||
}
|
||||
|
||||
func NewClient() *Client {
|
||||
c := new(Client)
|
||||
c.RedirectURI = make([]string, 0)
|
||||
|
||||
return c
|
||||
ID *ClientID
|
||||
Logo []*URL
|
||||
RedirectURI []*URL
|
||||
URL []*URL
|
||||
Name []string
|
||||
}
|
||||
|
||||
// TestClient returns a valid Client with the generated test data filled in.
|
||||
func TestClient(tb testing.TB) *Client {
|
||||
tb.Helper()
|
||||
|
||||
url, err := NewURL("https://app.example.com/")
|
||||
require.NoError(tb, err)
|
||||
|
||||
logo, err := NewURL("https://app.example.com/logo.png")
|
||||
require.NoError(tb, err)
|
||||
|
||||
redirects := make([]*URL, 0)
|
||||
|
||||
for _, redirect := range []string{
|
||||
"https://app.example.net/redirect",
|
||||
"https://app.example.com/redirect",
|
||||
} {
|
||||
u, err := NewURL(redirect)
|
||||
require.NoError(tb, err)
|
||||
|
||||
redirects = append(redirects, u)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ID: "https://app.example.com/",
|
||||
Name: "Example App",
|
||||
Logo: "https://app.example.com/logo.png",
|
||||
URL: "https://app.example.com/",
|
||||
RedirectURI: []string{
|
||||
"https://app.example.net/redirect",
|
||||
"https://app.example.com/redirect",
|
||||
},
|
||||
ID: TestClientID(tb),
|
||||
Name: []string{"Example App"},
|
||||
URL: []*URL{url},
|
||||
Logo: []*URL{logo},
|
||||
RedirectURI: redirects,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package domain
|
|||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -14,8 +15,8 @@ import (
|
|||
|
||||
// ClientID is a URL client identifier.
|
||||
type ClientID struct {
|
||||
cid *http.URI
|
||||
valid bool
|
||||
clientID *http.URI
|
||||
valid bool
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
|
@ -25,8 +26,8 @@ var (
|
|||
)
|
||||
|
||||
func NewClientID(raw string) (*ClientID, error) {
|
||||
cid := http.AcquireURI()
|
||||
if err := cid.Parse(nil, []byte(raw)); err != nil {
|
||||
clientID := http.AcquireURI()
|
||||
if err := clientID.Parse(nil, []byte(raw)); err != nil {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
Description: err.Error(),
|
||||
|
@ -35,7 +36,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
scheme := string(cid.Scheme())
|
||||
scheme := string(clientID.Scheme())
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
|
@ -45,7 +46,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
path := string(cid.PathOriginal())
|
||||
path := string(clientID.PathOriginal())
|
||||
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
|
@ -56,7 +57,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if cid.Hash() != nil {
|
||||
if clientID.Hash() != nil {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
Description: "client identifier URL MUST NOT contain a fragment component",
|
||||
|
@ -65,7 +66,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if cid.Username() != nil || cid.Password() != nil {
|
||||
if clientID.Username() != nil || clientID.Password() != nil {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
Description: "client identifier URL MUST NOT contain a username or password component",
|
||||
|
@ -74,7 +75,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
domain := string(cid.Host())
|
||||
domain := string(clientID.Host())
|
||||
if domain == "" {
|
||||
return nil, Error{
|
||||
Code: "invalid_request",
|
||||
|
@ -88,7 +89,7 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
if err != nil {
|
||||
ipPort, err := netaddr.ParseIPPort(domain)
|
||||
if err != nil {
|
||||
return &ClientID{cid: cid}, nil
|
||||
return &ClientID{clientID: clientID}, nil
|
||||
}
|
||||
|
||||
ip = ipPort.IP()
|
||||
|
@ -104,28 +105,43 @@ func NewClientID(raw string) (*ClientID, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return &ClientID{cid: cid}, nil
|
||||
return &ClientID{clientID: clientID}, nil
|
||||
}
|
||||
|
||||
// TestClientID returns a valid random generated ClientID for tests.
|
||||
func TestClientID(tb testing.TB) *ClientID {
|
||||
tb.Helper()
|
||||
|
||||
cid, err := NewClientID("https://app.example.com/")
|
||||
clientID, err := NewClientID("https://app.example.com/")
|
||||
require.NoError(tb, err)
|
||||
|
||||
return cid
|
||||
return clientID
|
||||
}
|
||||
|
||||
// UnmarshalForm implements a custom form.Unmarshaler.
|
||||
func (cid *ClientID) UnmarshalForm(v []byte) error {
|
||||
clientId, err := NewClientID(string(v))
|
||||
clientID, err := NewClientID(string(v))
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalForm: %w", err)
|
||||
}
|
||||
defer http.ReleaseURI(clientId.cid) //nolint: wsl
|
||||
|
||||
clientId.cid.CopyTo(cid.cid)
|
||||
*cid = *clientID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cid *ClientID) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientID, err := NewClientID(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
||||
*cid = *clientID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -134,23 +150,23 @@ func (cid *ClientID) UnmarshalForm(v []byte) error {
|
|||
// This copy MUST be released via fasthttp.ReleaseURI.
|
||||
func (cid *ClientID) URI() *http.URI {
|
||||
u := http.AcquireURI()
|
||||
cid.cid.CopyTo(u)
|
||||
cid.clientID.CopyTo(u)
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (cid *ClientID) URL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: string(cid.cid.Scheme()),
|
||||
Host: string(cid.cid.Host()),
|
||||
Path: string(cid.cid.Path()),
|
||||
RawPath: string(cid.cid.PathOriginal()),
|
||||
RawQuery: string(cid.cid.QueryString()),
|
||||
Fragment: string(cid.cid.Hash()),
|
||||
Scheme: string(cid.clientID.Scheme()),
|
||||
Host: string(cid.clientID.Host()),
|
||||
Path: string(cid.clientID.Path()),
|
||||
RawPath: string(cid.clientID.PathOriginal()),
|
||||
RawQuery: string(cid.clientID.QueryString()),
|
||||
Fragment: string(cid.clientID.Hash()),
|
||||
}
|
||||
}
|
||||
|
||||
// String returns string representation of client ID.
|
||||
func (cid *ClientID) String() string {
|
||||
return cid.cid.String()
|
||||
return cid.clientID.String()
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -84,6 +85,22 @@ func (ccm *CodeChallengeMethod) UnmarshalForm(v []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ccm *CodeChallengeMethod) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
method, err := ParseCodeChallengeMethod(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("code_challenge_method: %w", err)
|
||||
}
|
||||
|
||||
*ccm = method
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns string representation of code challenge method.
|
||||
func (ccm CodeChallengeMethod) String() string {
|
||||
return ccm.slug
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Error describes the data of a typical error.
|
||||
//nolint: tagliatelle
|
||||
type Error struct {
|
||||
Code string `json:"error"`
|
||||
|
@ -23,15 +24,17 @@ func (e Error) Format(s fmt.State, r rune) {
|
|||
}
|
||||
|
||||
func (e Error) FormatError(p xerrors.Printer) error {
|
||||
p.Printf("%s: %s", e.Code, e.Description)
|
||||
p.Print(e.Description)
|
||||
|
||||
if e.URI != "" {
|
||||
p.Printf(": %s", e.URI)
|
||||
p.Print(": ", e.URI, "\n")
|
||||
}
|
||||
|
||||
if p.Detail() {
|
||||
e.Frame.Format(p)
|
||||
if !p.Detail() {
|
||||
return e
|
||||
}
|
||||
|
||||
e.Frame.Format(p)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package domain
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -16,13 +17,19 @@ type GrantType struct {
|
|||
var (
|
||||
GrantTypeUndefined = GrantType{slug: ""}
|
||||
GrantTypeAuthorizationCode = GrantType{slug: "authorization_code"}
|
||||
|
||||
// TicketAuth extension.
|
||||
GrantTypeTicket = GrantType{slug: "ticket"}
|
||||
)
|
||||
|
||||
var ErrGrantTypeUnknown = errors.New("unknown grant type")
|
||||
|
||||
func ParseGrantType(slug string) (GrantType, error) {
|
||||
if strings.ToLower(slug) == GrantTypeAuthorizationCode.slug {
|
||||
switch strings.ToLower(slug) {
|
||||
case GrantTypeAuthorizationCode.slug:
|
||||
return GrantTypeAuthorizationCode, nil
|
||||
case GrantTypeTicket.slug:
|
||||
return GrantTypeTicket, nil
|
||||
}
|
||||
|
||||
return GrantTypeUndefined, fmt.Errorf("%w: %s", ErrGrantTypeUnknown, slug)
|
||||
|
@ -39,6 +46,22 @@ func (gt *GrantType) UnmarshalForm(src []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (gt *GrantType) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
responseType, err := ParseGrantType(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant_type: %w", err)
|
||||
}
|
||||
|
||||
*gt = responseType
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gt GrantType) String() string {
|
||||
return gt.slug
|
||||
}
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/random"
|
||||
)
|
||||
|
||||
type Login struct {
|
||||
PKCE
|
||||
CreatedAt time.Time
|
||||
CompletedAt time.Time
|
||||
Scopes []string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
MeEntered string
|
||||
MeResolved string
|
||||
Code string
|
||||
Provider string
|
||||
IsCompleted bool
|
||||
}
|
||||
|
||||
//nolint: gomnd
|
||||
func TestLogin(tb testing.TB) *Login {
|
||||
tb.Helper()
|
||||
|
||||
code, err := random.String(8)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return &Login{
|
||||
CreatedAt: gofakeit.Date(),
|
||||
CompletedAt: time.Time{},
|
||||
PKCE: PKCE{
|
||||
Method: PKCEMethodS256,
|
||||
Challenge: "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
|
||||
Verifier: "a6128783714cfda1d388e2e98b6ae8221ac31aca31959e59512c59f5",
|
||||
},
|
||||
Scopes: []string{"profile", "create", "update", "delete"},
|
||||
ClientID: "https://app.example.com/",
|
||||
RedirectURI: "https://app.example.com/redirect",
|
||||
MeEntered: "user.example.net",
|
||||
MeResolved: "https://user.example.net/",
|
||||
Code: code,
|
||||
Provider: "mastodon",
|
||||
IsCompleted: false,
|
||||
}
|
||||
}
|
||||
|
||||
//nolint: gomnd
|
||||
func TestLoginInvalid(tb testing.TB) *Login {
|
||||
tb.Helper()
|
||||
|
||||
challenge, err := random.String(42)
|
||||
require.NoError(tb, err)
|
||||
|
||||
verifier, err := random.String(64)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return &Login{
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
CompletedAt: time.Time{},
|
||||
PKCE: PKCE{
|
||||
Method: "UNDEFINED",
|
||||
Challenge: challenge,
|
||||
Verifier: verifier,
|
||||
},
|
||||
Scopes: []string{},
|
||||
ClientID: "whoisit",
|
||||
RedirectURI: "redirect",
|
||||
MeEntered: "whoami",
|
||||
MeResolved: "",
|
||||
Code: "",
|
||||
Provider: "undefined",
|
||||
IsCompleted: true,
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -109,14 +110,29 @@ func TestMe(tb testing.TB) *Me {
|
|||
}
|
||||
|
||||
// UnmarshalForm parses the value of the form key into the Me domain.
|
||||
func (m *Me) UnmarshalForm(v []byte) (err error) {
|
||||
func (m *Me) UnmarshalForm(v []byte) error {
|
||||
me, err := NewMe(string(v))
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalForm: %w", err)
|
||||
}
|
||||
defer http.ReleaseURI(me.me) //nolint: wsl
|
||||
|
||||
me.me.CopyTo(m.me)
|
||||
*m = *me
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Me) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
me, err := NewMe(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalForm: %w", err)
|
||||
}
|
||||
|
||||
*m = *me
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -124,6 +140,10 @@ func (m *Me) UnmarshalForm(v []byte) (err error) {
|
|||
// URI returns copy of parsed Me in *fasthttp.URI representation.
|
||||
// This copy MUST be released via fasthttp.ReleaseURI.
|
||||
func (m *Me) URI() *http.URI {
|
||||
if m.me == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
u := http.AcquireURI()
|
||||
m.me.CopyTo(u)
|
||||
|
||||
|
@ -132,6 +152,10 @@ func (m *Me) URI() *http.URI {
|
|||
|
||||
// URL returns copy of parsed Me in *url.URL representation.
|
||||
func (m *Me) URL() *url.URL {
|
||||
if m.me == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: string(m.me.Scheme()),
|
||||
Host: string(m.me.Host()),
|
||||
|
@ -143,6 +167,10 @@ func (m *Me) URL() *url.URL {
|
|||
}
|
||||
|
||||
// String returns string representation of Me.
|
||||
func (m Me) String() string {
|
||||
func (m *Me) String() string {
|
||||
if m.me == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return m.me.String()
|
||||
}
|
||||
|
|
|
@ -1,21 +1,43 @@
|
|||
package domain
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Profile describes the data about the user.
|
||||
type Profile struct {
|
||||
Name string
|
||||
URL string
|
||||
Photo string
|
||||
Email string
|
||||
Photo []*URL
|
||||
URL []*URL
|
||||
Email []Email
|
||||
Name []string
|
||||
}
|
||||
|
||||
// NewProfile creates a new empty Profile.
|
||||
func NewProfile() *Profile {
|
||||
return &Profile{
|
||||
Email: make([]Email, 0),
|
||||
Name: make([]string, 0),
|
||||
Photo: make([]*URL, 0),
|
||||
URL: make([]*URL, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfile returns a valid Profile with the generated test data filled in.
|
||||
func TestProfile(tb testing.TB) *Profile {
|
||||
tb.Helper()
|
||||
|
||||
photo, err := NewURL("https://user.example.net/photo.jpg")
|
||||
require.NoError(tb, err)
|
||||
|
||||
url, err := NewURL("https://user.example.net/")
|
||||
require.NoError(tb, err)
|
||||
|
||||
return &Profile{
|
||||
Name: "Example User",
|
||||
URL: "https://user.example.net/",
|
||||
Photo: "https://user.example.net/photo.jpg",
|
||||
Email: "user@example.net",
|
||||
Email: []Email{"user@example.net"},
|
||||
Name: []string{"Example User"},
|
||||
Photo: []*URL{photo},
|
||||
URL: []*URL{url},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package domain
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -58,6 +59,22 @@ func (rt *ResponseType) UnmarshalForm(src []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (rt *ResponseType) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
responseType, err := ParseResponseType(string(src))
|
||||
if err != nil {
|
||||
return fmt.Errorf("response_type: %w", err)
|
||||
}
|
||||
|
||||
*rt = responseType
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rt ResponseType) String() string {
|
||||
return rt.slug
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package domain
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -12,6 +13,7 @@ type (
|
|||
slug string
|
||||
}
|
||||
|
||||
// Scopes represent set of Scope domains.
|
||||
Scopes []Scope
|
||||
)
|
||||
|
||||
|
@ -74,7 +76,7 @@ var slugsScopes = map[string]Scope{
|
|||
|
||||
// ParseScope parses scope slug into Scope domain.
|
||||
func ParseScope(slug string) (Scope, error) {
|
||||
if scope, ok := slugsScopes[strings.ToLower(slug)]; !ok {
|
||||
if scope, ok := slugsScopes[strings.ToLower(slug)]; ok {
|
||||
return scope, nil
|
||||
}
|
||||
|
||||
|
@ -93,7 +95,55 @@ func (s *Scope) UnmarshalForm(v []byte) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Scope) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scope, err := ParseScope(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("scope: %w", err)
|
||||
}
|
||||
|
||||
*s = scope
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scopes) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make([]Scope, 0)
|
||||
|
||||
for _, scope := range strings.Fields(src) {
|
||||
s, err := ParseScope(scope)
|
||||
if err != nil {
|
||||
return fmt.Errorf("scope: %w", err)
|
||||
}
|
||||
|
||||
result = append(result, s)
|
||||
}
|
||||
|
||||
*s = result
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns scope slug as string.
|
||||
func (s Scope) String() string {
|
||||
return s.slug
|
||||
}
|
||||
|
||||
func (s Scopes) String() string {
|
||||
scopes := make([]string, len(s))
|
||||
|
||||
for i := range s {
|
||||
scopes[i] = s[i].String()
|
||||
}
|
||||
|
||||
return strings.Join(scopes, " ")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
package domain_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/domain"
|
||||
)
|
||||
|
||||
func TestScopesUnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := &struct {
|
||||
Scope domain.Scopes `json:"scope"`
|
||||
}{}
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"scope": "read update delete"}`), result))
|
||||
|
||||
for _, scope := range []domain.Scope{domain.ScopeRead, domain.ScopeUpdate, domain.ScopeDelete} {
|
||||
assert.Contains(t, result.Scope, scope)
|
||||
}
|
||||
}
|
|
@ -1,48 +1,104 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
http "github.com/valyala/fasthttp"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/random"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Scopes []string
|
||||
AccessToken string
|
||||
ClientID string
|
||||
Me string
|
||||
type (
|
||||
// Token describes the data of the token used by the clients.
|
||||
Token struct {
|
||||
AccessToken string
|
||||
ClientID *ClientID
|
||||
Me *Me
|
||||
Scope Scopes
|
||||
}
|
||||
|
||||
NewTokenOptions struct {
|
||||
Algorithm string
|
||||
Expiration time.Duration
|
||||
Issuer *ClientID
|
||||
NonceLength int
|
||||
Scope Scopes
|
||||
Secret interface{}
|
||||
Subject *Me
|
||||
}
|
||||
)
|
||||
|
||||
var DefaultNewTokenOptions = NewTokenOptions{
|
||||
NonceLength: 32,
|
||||
Algorithm: "HS256",
|
||||
}
|
||||
|
||||
func NewToken() *Token {
|
||||
t := new(Token)
|
||||
t.Scopes = make([]string, 0)
|
||||
func NewToken(opts NewTokenOptions) (*Token, error) {
|
||||
if opts.NonceLength == 0 {
|
||||
opts.NonceLength = DefaultNewTokenOptions.NonceLength
|
||||
}
|
||||
|
||||
return t
|
||||
if opts.Algorithm == "" {
|
||||
opts.Algorithm = DefaultNewTokenOptions.Algorithm
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
|
||||
nonce, err := random.String(opts.NonceLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot generate nonce: %w", err)
|
||||
}
|
||||
|
||||
t := jwt.New()
|
||||
t.Set(jwt.IssuerKey, opts.Issuer.String())
|
||||
t.Set(jwt.SubjectKey, opts.Subject.String())
|
||||
t.Set(jwt.NotBeforeKey, now)
|
||||
t.Set(jwt.IssuedAtKey, now)
|
||||
t.Set("scope", opts.Scope)
|
||||
t.Set("nonce", nonce)
|
||||
|
||||
if opts.Expiration != 0 {
|
||||
t.Set(jwt.ExpirationKey, now.Add(opts.Expiration))
|
||||
}
|
||||
|
||||
accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot sign a new access token: %w", err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
AccessToken: string(accessToken),
|
||||
ClientID: opts.Issuer,
|
||||
Me: opts.Subject,
|
||||
Scope: opts.Scope,
|
||||
}, err
|
||||
}
|
||||
|
||||
// TestToken returns a valid Token with the generated test data filled in.
|
||||
func TestToken(tb testing.TB) *Token {
|
||||
tb.Helper()
|
||||
|
||||
require := require.New(tb)
|
||||
nonce, err := random.String(42)
|
||||
require.NoError(tb, err)
|
||||
|
||||
nonce, err := random.String(50)
|
||||
require.NoError(err)
|
||||
|
||||
client := TestClient(tb)
|
||||
profile := TestProfile(tb)
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
scopes := []string{"create", "update", "delete"}
|
||||
t := jwt.New()
|
||||
cid := TestClientID(tb)
|
||||
me := TestMe(tb)
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
scope := []Scope{
|
||||
ScopeCreate,
|
||||
ScopeUpdate,
|
||||
ScopeDelete,
|
||||
}
|
||||
|
||||
// required
|
||||
t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id
|
||||
t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me
|
||||
// NOTE(toby3d): required
|
||||
t.Set(jwt.IssuerKey, cid.String())
|
||||
t.Set(jwt.SubjectKey, me.me.String())
|
||||
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
|
||||
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
|
||||
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
|
||||
|
@ -50,16 +106,33 @@ func TestToken(tb testing.TB) *Token {
|
|||
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
|
||||
|
||||
// optional
|
||||
t.Set("scope", strings.Join(scopes, " "))
|
||||
t.Set("scope", scope)
|
||||
t.Set("nonce", nonce)
|
||||
|
||||
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
|
||||
require.NoError(err)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return &Token{
|
||||
ClientID: cid,
|
||||
Me: me,
|
||||
Scope: scope,
|
||||
AccessToken: string(accessToken),
|
||||
ClientID: t.Issuer(),
|
||||
Me: t.Subject(),
|
||||
Scopes: scopes,
|
||||
}
|
||||
}
|
||||
|
||||
// SetAuthHeader writes an Access Token to the request header.
|
||||
func (t *Token) SetAuthHeader(r *http.Request) {
|
||||
if t.AccessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Set(http.HeaderAuthorization, t.String())
|
||||
}
|
||||
|
||||
func (t *Token) String() string {
|
||||
if t.AccessToken == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "Bearer " + string(t.AccessToken)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue