♻️ Refactored token package

This commit is contained in:
Maxim Lebedev 2022-01-14 01:50:40 +05:00
parent 83dc4286eb
commit 75f6cb168f
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
8 changed files with 460 additions and 433 deletions

View File

@ -1,7 +1,6 @@
package http
import (
"bytes"
"strings"
"github.com/fasthttp/router"
@ -9,99 +8,178 @@ import (
http "github.com/valyala/fasthttp"
"golang.org/x/xerrors"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
"source.toby3d.me/website/indieauth/internal/common"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
)
type (
RevocationRequest struct {
Action string
Token string
ExchangeRequest struct {
ClientID *domain.ClientID `form:"client_id"`
RedirectURI *domain.URL `form:"redirect_uri"`
GrantType domain.GrantType `form:"grant_type"`
Code string `form:"code"`
CodeVerifier string `form:"code_verifier"`
}
RevokeRequest struct {
Action domain.Action `form:"action"`
Token string `form:"token"`
}
TicketRequest struct {
Action domain.Action `form:"action"`
Ticket string `form:"ticket"`
}
//nolint: tagliatelle
ExchangeResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
Me string `json:"me"`
}
//nolint: tagliatelle
VerificationResponse struct {
Me string `json:"me"`
ClientID string `json:"client_id"`
Scope string `json:"scope"`
Me *domain.Me `json:"me"`
ClientID *domain.ClientID `json:"client_id"`
Scope domain.Scopes `json:"scope"`
}
RevocationResponse struct{}
RequestHandler struct {
tokener token.UseCase
tokens token.UseCase
// TODO(toby3d): tickets ticket.UseCase
}
)
const (
Action string = "action"
ActionRevoke string = "revoke"
)
func NewRequestHandler(tokener token.UseCase) *RequestHandler {
func NewRequestHandler(tokens token.UseCase /*, tickets ticket.UseCase*/) *RequestHandler {
return &RequestHandler{
tokener: tokener,
tokens: tokens,
// tickets: tickets,
}
}
func (h *RequestHandler) Register(r *router.Router) {
r.GET("/token", h.Read)
r.POST("/token", h.Update)
chain := middleware.Chain{
middleware.LogFmt(),
}
r.GET("/token", chain.RequestHandler(h.handleValidate))
r.POST("/token", chain.RequestHandler(h.handleAction))
}
func (h *RequestHandler) Read(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSON)
ctx.SetStatusCode(http.StatusOK)
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
if err != nil {
if xerrors.Is(err, token.ErrRevoke) {
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
} else {
ctx.Error(err.Error(), http.StatusBadRequest)
}
return
}
if t == nil {
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
ClientID: t.ClientID,
Me: t.Me,
Scope: strings.Join(t.Scopes, " "),
}); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
}
func (h *RequestHandler) Update(ctx *http.RequestCtx) {
if strings.EqualFold(string(ctx.FormValue(Action)), ActionRevoke) {
h.Revocation(ctx)
}
}
func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSON)
func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx)
req := new(RevocationRequest)
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
"Bearer "))
if err != nil || t == nil {
ctx.SetStatusCode(http.StatusUnauthorized)
encoder.Encode(&domain.Error{
Code: "unauthorized_client",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
encoder.Encode(&VerificationResponse{
ClientID: t.ClientID,
Me: t.Me,
Scope: t.Scope,
})
}
func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx)
switch {
case ctx.PostArgs().Has("grant_type"):
h.handleExchange(ctx)
case ctx.PostArgs().Has("action"):
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
switch action {
case domain.ActionRevoke:
h.handleRevoke(ctx)
case domain.ActionTicket:
h.handleTicket(ctx)
}
}
}
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
return
}
token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
Code: req.Code,
CodeVerifier: req.CodeVerifier,
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(&domain.Error{
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
encoder.Encode(&ExchangeResponse{
AccessToken: token.AccessToken,
TokenType: "Bearer",
Scope: token.Scope.String(),
Me: token.Me.String(),
})
}
func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx)
req := new(RevokeRequest)
if err := req.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
if err := h.tokener.Revoke(ctx, req.Token); err != nil {
if err := h.tokens.Revoke(ctx, req.Token); err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
@ -112,21 +190,65 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
}
}
func (r *RevocationRequest) bind(ctx *http.RequestCtx) error {
if r.Action = string(ctx.FormValue(Action)); !strings.EqualFold(r.Action, ActionRevoke) {
return domain.Error{
Code: "invalid_request",
Description: "request MUST contain 'action' key with value 'revoke'",
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
Frame: xerrors.Caller(1),
}
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx)
req := new(TicketRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
return
}
if r.Token = string(ctx.FormValue("token")); r.Token == "" {
/* TODO(toby3d)
token, err := h.tickets.Redeem(ctx, req.Ticket)
if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.Error{
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
*/
encoder.Encode(ExchangeResponse{})
}
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: "request MUST contain the 'token' key with the valid access token as its value",
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}
return nil
}
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}
return nil
}
func (r *TicketRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}

View File

@ -6,82 +6,74 @@ import (
"sync"
"testing"
"github.com/goccy/go-json"
"github.com/spf13/viper"
"github.com/fasthttp/router"
json "github.com/goccy/go-json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
http "github.com/valyala/fasthttp"
"source.toby3d.me/website/indieauth/internal/common"
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
"source.toby3d.me/website/indieauth/internal/domain"
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
"source.toby3d.me/website/indieauth/internal/testing/httptest"
delivery "source.toby3d.me/website/indieauth/internal/token/delivery/http"
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
"source.toby3d.me/website/indieauth/internal/token/usecase"
"source.toby3d.me/website/indieauth/internal/util"
tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory"
tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase"
)
func TestVerification(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
store := new(sync.Map)
config := domain.TestConfig(t)
token := domain.TestToken(t)
accessToken := domain.TestToken(t)
r := router.New()
// TODO(toby3d): provide tickets
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
sessionrepo.NewMemorySessionRepository(config, store), config)).Register(r)
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(
repository.NewMemoryTokenRepository(new(sync.Map)),
configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
)).Read)
client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
req := http.AcquireRequest()
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
defer http.ReleaseRequest(req)
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI("https://app.example.com/token")
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.Header.Set(http.HeaderAuthorization, "Bearer "+accessToken.AccessToken)
token.SetAuthHeader(req)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
require.NoError(t, client.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode())
token := new(delivery.VerificationResponse)
require.NoError(t, json.Unmarshal(resp.Body(), token))
assert.Equal(t, &delivery.VerificationResponse{
Me: accessToken.Me,
ClientID: accessToken.ClientID,
Scope: strings.Join(accessToken.Scopes, " "),
}, token)
result := new(delivery.VerificationResponse)
require.NoError(t, json.Unmarshal(resp.Body(), result))
assert.Equal(t, token.ClientID.String(), result.ClientID.String())
assert.Equal(t, token.Me.String(), result.Me.String())
assert.Equal(t, token.Scope.String(), result.Scope.String())
}
func TestRevocation(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
tokens := repository.NewMemoryTokenRepository(new(sync.Map))
config := domain.TestConfig(t)
store := new(sync.Map)
tokens := tokenrepo.NewMemoryTokenRepository(store)
accessToken := domain.TestToken(t)
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(
usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))),
).Update)
r := router.New()
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokens, sessionrepo.NewMemorySessionRepository(config,
store), config)).Register(r)
client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
req := http.AcquireRequest()
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)
defer http.ReleaseRequest(req)
req.Header.SetMethod(http.MethodPost)
req.SetRequestURI("https://app.example.com/token")
req.Header.SetContentType(common.MIMEApplicationXWWWFormUrlencoded)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.PostArgs().Set("action", "revoke")
req.Header.SetContentType(common.MIMEApplicationForm)
req.PostArgs().Set("action", domain.ActionRevoke.String())
req.PostArgs().Set("token", accessToken.AccessToken)
resp := http.AcquireResponse()

View File

@ -1,129 +0,0 @@
package bolt
import (
"context"
"encoding/json"
"time"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
"golang.org/x/xerrors"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
)
type (
Token struct {
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt time.Time `json:"deletedAt,omitempty"`
Scopes []string `json:"scopes"`
AccessToken string `json:"accessToken"`
ClientID string `json:"clientId"`
Me string `json:"me"`
}
boltTokenRepository struct {
db *bolt.DB
}
)
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
return &boltTokenRepository{
db: db,
}
}
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) {
t, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
return errors.Wrap(err, "cannot get token in database")
}
if t != nil {
return token.ErrExist
}
if err = repo.db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket())
if err != nil {
return errors.Wrap(err, "cannot create bucket")
}
token := new(Token)
token.Populate(accessToken)
src, err := json.Marshal(token)
if err != nil {
return errors.Wrap(err, "cannot marshal token data")
}
if err = bkt.Put([]byte(token.AccessToken), src); err != nil {
return errors.Wrap(err, "cannot put token into bucket")
}
return nil
}); err != nil {
return errors.Wrap(err, "failed to put token into database")
}
return nil
}
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
result := new(domain.Token)
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
t := new(Token)
bkt := tx.Bucket(t.Bucket())
if bkt == nil {
return token.ErrNotExist
}
src := bkt.Get([]byte(accessToken))
if src == nil {
return token.ErrNotExist
}
if err = t.Bind(src, result); err != nil {
return errors.Wrap(err, "cannot parse token")
}
return nil
}); err != nil {
return nil, errors.Wrap(err, "failed to view token in database")
}
return result, nil
}
func (Token) Bucket() []byte { return []byte("tokens") }
func (t *Token) Populate(accessToken *domain.Token) {
t.AccessToken = accessToken.AccessToken
t.ClientID = accessToken.ClientID
t.CreatedAt = time.Now().UTC().Round(time.Second)
t.Me = accessToken.Me
t.Scopes = make([]string, len(accessToken.Scopes))
t.UpdatedAt = t.CreatedAt
for i := range accessToken.Scopes {
t.Scopes[i] = accessToken.Scopes[i]
}
}
func (t *Token) Bind(src []byte, accessToken *domain.Token) error {
if err := json.Unmarshal(src, t); err != nil {
return errors.Wrap(err, "cannot unmarshal token")
}
accessToken.AccessToken = t.AccessToken
accessToken.ClientID = t.ClientID
accessToken.Me = t.Me
accessToken.Scopes = t.Scopes
return nil
}

View File

@ -1,86 +0,0 @@
package bolt_test
import (
"context"
"encoding/json"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
repository "source.toby3d.me/website/indieauth/internal/token/repository/bolt"
"source.toby3d.me/website/indieauth/internal/util"
)
func TestCreate(t *testing.T) {
t.Parallel()
//nolint: exhaustivestruct
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
_, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
//nolint: wrapcheck
return err
}))
repo := repository.NewBoltTokenRepository(db)
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
result := domain.NewToken()
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
dto := new(repository.Token)
//nolint: wrapcheck
return dto.Bind(tx.Bucket(dto.Bucket()).Get([]byte(accessToken.AccessToken)), result)
}))
assert.Equal(t, accessToken, result)
assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist)
}
func TestGet(t *testing.T) {
t.Parallel()
//nolint: exhaustivestruct
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
accessToken := domain.TestToken(t)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
if err != nil {
return errors.Wrap(err, "cannot create bucket")
}
t := new(repository.Token)
t.Populate(accessToken)
src, err := json.Marshal(t)
if err != nil {
return errors.Wrap(err, "cannot marshal token data")
}
if err = bkt.Put([]byte(t.AccessToken), src); err != nil {
return errors.Wrap(err, "cannot put token into bucket")
}
return nil
}))
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
}

View File

@ -0,0 +1,105 @@
package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/jmoiron/sqlx"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
)
type (
Token struct {
CreatedAt sql.NullTime `db:"created_at"`
AccessToken string `db:"access_token"`
ClientID string `db:"client_id"`
Me string `db:"me"`
Scope string `db:"scope"`
}
sqlite3TokenRepository struct {
db *sqlx.DB
}
)
const (
QueryTable string = `CREATE TABLE IF NOT EXISTS tokens (
access_token TEXT UNIQUE PRIMARY KEY NOT NULL,
client_id TEXT NOT NULL,
created_at DATETIME NOT NULL,
me TEXT NOT NULL,
scope TEXT
);`
QueryGet string = `SELECT *
FROM tokens
WHERE access_token=$1;`
QueryCreate string = `INSERT INTO tokens (created_at, access_token, client_id, me, scope)
VALUES (:created_at, :access_token, :client_id, :me, :scope);`
)
func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository {
return &sqlite3TokenRepository{
db: db,
}
}
func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewToken(accessToken)); err != nil {
return fmt.Errorf("cannot create token record in db: %w", err)
}
return nil
}
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
t := new(Token)
if err := repo.db.GetContext(ctx, t, QueryTable+QueryGet, accessToken); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, token.ErrNotExist
}
return nil, fmt.Errorf("cannot find token in db: %w", err)
}
result := new(domain.Token)
t.Populate(result)
return result, nil
}
func NewToken(src *domain.Token) *Token {
return &Token{
CreatedAt: sql.NullTime{
Time: time.Now().UTC(),
Valid: true,
},
AccessToken: src.AccessToken,
ClientID: src.ClientID.String(),
Me: src.Me.String(),
Scope: src.Scope.String(),
}
}
func (t *Token) Populate(dst *domain.Token) {
dst.AccessToken = t.AccessToken
dst.ClientID, _ = domain.NewClientID(t.ClientID)
dst.Me, _ = domain.NewMe(t.Me)
dst.Scope = make(domain.Scopes, 0)
for _, scope := range strings.Fields(t.Scope) {
s, err := domain.ParseScope(scope)
if err != nil {
continue
}
dst.Scope = append(dst.Scope, s)
}
}

View File

@ -0,0 +1,47 @@
package sqlite3_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/testing/sqltest"
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
)
func TestCreate(t *testing.T) {
t.Parallel()
db, cleanup := sqltest.Open(t)
t.Cleanup(cleanup)
token := domain.TestToken(t)
require.NoError(t, repository.NewSQLite3TokenRepository(db).Create(context.Background(), token))
results := make([]*repository.Token, 0)
require.NoError(t, db.Select(&results, "SELECT * FROM tokens;"))
require.Len(t, results, 1)
result := new(domain.Token)
results[0].Populate(result)
assert.Equal(t, token.AccessToken, result.AccessToken)
}
func TestGet(t *testing.T) {
t.Parallel()
db, cleanup := sqltest.Open(t)
t.Cleanup(cleanup)
token := domain.TestToken(t)
_, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewToken(token))
require.NoError(t, err)
result, err := repository.NewSQLite3TokenRepository(db).Get(context.Background(), token.AccessToken)
require.NoError(t, err)
assert.Equal(t, token.AccessToken, result.AccessToken)
}

View File

@ -2,108 +2,115 @@ package usecase
import (
"context"
"strings"
"time"
"fmt"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/pkg/errors"
"golang.org/x/xerrors"
"source.toby3d.me/website/indieauth/internal/config"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/random"
"source.toby3d.me/website/indieauth/internal/session"
"source.toby3d.me/website/indieauth/internal/token"
)
type (
Config struct {
Configer config.UseCase
Tokens token.Repository
}
type tokenUseCase struct {
sessions session.Repository
config *domain.Config
tokens token.Repository
}
tokenUseCase struct {
configer config.UseCase
tokens token.Repository
}
)
//nolint: gochecknoinits
func init() {
jwt.RegisterCustomField("scope", make(domain.Scopes, 0))
}
func NewTokenUseCase(config Config) token.UseCase {
func NewTokenUseCase(tokens token.Repository, sessions session.Repository, config *domain.Config) token.UseCase {
return &tokenUseCase{
configer: config.Configer,
tokens: config.Tokens,
sessions: sessions,
config: config,
tokens: tokens,
}
}
// Generate generates a new Token based on the session data.
func (useCase *tokenUseCase) Generate(ctx context.Context, opts token.GenerateOptions) (*domain.Token, error) {
nonce, err := random.String(opts.NonceLength)
func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOptions) (*domain.Token, error) {
session, err := useCase.sessions.GetAndDelete(ctx, opts.Code)
if err != nil {
return nil, errors.Wrap(err, "cannot generate code")
return nil, fmt.Errorf("cannot get session from store: %w", err)
}
t := jwt.New()
now := time.Now().UTC().Round(time.Second)
t.Set(jwt.IssuerKey, opts.ClientID)
t.Set(jwt.SubjectKey, opts.Me)
t.Set(jwt.ExpirationKey, now.Add(useCase.configer.GetIndieAuthAccessTokenExpirationTime()))
t.Set(jwt.NotBeforeKey, now)
t.Set(jwt.IssuedAtKey, now)
t.Set("scope", strings.Join(opts.Scopes, " "))
t.Set("nonce", nonce)
token, err := jwt.Sign(t,
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
[]byte(useCase.configer.GetIndieAuthJWTSecret()))
if err != nil {
return nil, errors.Wrap(err, "cannot sign a new access token")
if opts.ClientID.String() != session.ClientID.String() {
return nil, domain.Error{
Code: "invalid_request",
Description: "client's URL MUST match the client_id used in the authentication request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
return &domain.Token{
Scopes: opts.Scopes,
AccessToken: string(token),
ClientID: opts.ClientID,
Me: opts.Me,
}, nil
if opts.RedirectURI.String() != session.RedirectURI.String() {
return nil, domain.Error{
Code: "invalid_request",
Description: "client's redirect URL MUST match the initial authentication request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
if session.CodeChallenge != "" &&
!session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) {
return nil, domain.Error{
Code: "invalid_request",
Description: "code_verifier is not hashes to the same value as given in " +
"the code_challenge in the original authorization request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
t, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm,
Expiration: useCase.config.JWT.Expiry,
Issuer: session.ClientID,
NonceLength: useCase.config.JWT.NonceLength,
Scope: session.Scope,
Secret: []byte(useCase.config.JWT.Secret),
Subject: session.Me,
})
if err != nil {
return nil, fmt.Errorf("cannot generate a new access token: %w", err)
}
return t, nil
}
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
find, err := useCase.tokens.Get(ctx, accessToken)
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
return nil, errors.Wrap(err, "cannot ckeck token in store")
return nil, fmt.Errorf("cannot check token in store: %w", err)
}
if find != nil {
return nil, token.ErrRevoke
}
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
[]byte(useCase.configer.GetIndieAuthJWTSecret()),
))
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
[]byte(useCase.config.JWT.Secret)))
if err != nil {
return nil, errors.Wrap(err, "cannot parse JWT token")
return nil, fmt.Errorf("cannot parse JWT token: %w", err)
}
if err = jwt.Validate(t); err != nil {
return nil, errors.Wrap(err, "cannot validate JWT token")
return nil, fmt.Errorf("cannot validate JWT token: %w", err)
}
result := &domain.Token{
AccessToken: accessToken,
ClientID: t.Issuer(),
Me: t.Subject(),
Scopes: make([]string, 0),
}
result.ClientID, _ = domain.NewClientID(t.Issuer())
result.Me, _ = domain.NewMe(t.Subject())
rawScope, ok := t.Get("scope")
if !ok {
return result, nil
}
if scope, ok := rawScope.(string); ok {
result.Scopes = strings.Fields(scope)
if scope, ok := t.Get("scope"); ok {
result.Scope, _ = scope.(domain.Scopes)
}
return result, nil
@ -112,11 +119,11 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
t, err := useCase.Verify(ctx, accessToken)
if err != nil {
return errors.Wrap(err, "cannot verify token")
return fmt.Errorf("cannot verify token: %w", err)
}
if err = useCase.tokens.Create(ctx, t); err != nil {
return errors.Wrap(err, "cannot save token in database")
return fmt.Errorf("cannot save token in database: %w", err)
}
return nil

View File

@ -2,75 +2,48 @@ package usecase_test
import (
"context"
"strings"
"sync"
"testing"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
ucase "source.toby3d.me/website/indieauth/internal/token/usecase"
usecase "source.toby3d.me/website/indieauth/internal/token/usecase"
)
func TestGenerate(t *testing.T) {
func TestExchange(t *testing.T) {
t.Parallel()
configer := configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(domain.TestConfig(t)))
options := token.GenerateOptions{
ClientID: "https://app.example.com/",
Me: "https://user.example.net/",
Scopes: []string{"create", "update", "delete"},
NonceLength: 42,
}
result, err := ucase.NewTokenUseCase(ucase.Config{
Configer: configer,
Tokens: nil,
}).Generate(context.TODO(), options)
require.NoError(t, err)
assert.Equal(t, options.ClientID, result.ClientID)
assert.Equal(t, options.Me, result.Me)
assert.Equal(t, options.Scopes, result.Scopes)
token, err := jwt.ParseString(result.AccessToken)
require.NoError(t, err)
assert.Equal(t, options.Me, token.Subject())
assert.Equal(t, options.ClientID, token.Issuer())
scope, ok := token.Get("scope")
require.True(t, ok)
assert.Equal(t, strings.Join(options.Scopes, " "), scope)
}
func TestVerify(t *testing.T) {
t.Parallel()
repo := repository.NewMemoryTokenRepository(new(sync.Map))
useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)))
ucase := usecase.NewTokenUseCase(repo, nil, domain.TestConfig(t))
t.Run("valid", func(t *testing.T) {
t.Parallel()
accessToken := domain.TestToken(t)
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken, result)
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
assert.Equal(t, accessToken.Scope, result.Scope)
assert.Equal(t, accessToken.ClientID.String(), result.ClientID.String())
assert.Equal(t, accessToken.Me.String(), result.Me.String())
})
t.Run("revoke", func(t *testing.T) {
t.Run("revoked", func(t *testing.T) {
t.Parallel()
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
require.ErrorIs(t, err, token.ErrRevoke)
assert.Nil(t, result)
})
@ -79,16 +52,12 @@ func TestVerify(t *testing.T) {
func TestRevoke(t *testing.T) {
t.Parallel()
v := viper.New()
v.Set("indieauth.jwtSigningAlgorithm", "HS256")
v.Set("indieauth.jwtSecret", "hackme")
repo := repository.NewMemoryTokenRepository(new(sync.Map))
config := domain.TestConfig(t)
accessToken := domain.TestToken(t)
repo := repository.NewMemoryTokenRepository(new(sync.Map))
require.NoError(t, ucase.NewTokenUseCase(
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
).Revoke(context.TODO(), accessToken.AccessToken))
require.NoError(t, usecase.NewTokenUseCase(repo, nil, config).
Revoke(context.TODO(), accessToken.AccessToken))
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)