diff --git a/internal/token/delivery/http/token_http.go b/internal/token/delivery/http/token_http.go index d16b906..806bbb7 100644 --- a/internal/token/delivery/http/token_http.go +++ b/internal/token/delivery/http/token_http.go @@ -1,7 +1,6 @@ package http import ( - "bytes" "strings" "github.com/fasthttp/router" @@ -9,99 +8,178 @@ import ( http "github.com/valyala/fasthttp" "golang.org/x/xerrors" + "source.toby3d.me/toby3d/form" + "source.toby3d.me/toby3d/middleware" "source.toby3d.me/website/indieauth/internal/common" "source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/token" ) type ( - RevocationRequest struct { - Action string - Token string + ExchangeRequest struct { + ClientID *domain.ClientID `form:"client_id"` + RedirectURI *domain.URL `form:"redirect_uri"` + GrantType domain.GrantType `form:"grant_type"` + Code string `form:"code"` + CodeVerifier string `form:"code_verifier"` + } + + RevokeRequest struct { + Action domain.Action `form:"action"` + Token string `form:"token"` + } + + TicketRequest struct { + Action domain.Action `form:"action"` + Ticket string `form:"ticket"` + } + + //nolint: tagliatelle + ExchangeResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Me string `json:"me"` } //nolint: tagliatelle VerificationResponse struct { - Me string `json:"me"` - ClientID string `json:"client_id"` - Scope string `json:"scope"` + Me *domain.Me `json:"me"` + ClientID *domain.ClientID `json:"client_id"` + Scope domain.Scopes `json:"scope"` } RevocationResponse struct{} RequestHandler struct { - tokener token.UseCase + tokens token.UseCase + // TODO(toby3d): tickets ticket.UseCase } ) -const ( - Action string = "action" - ActionRevoke string = "revoke" -) - -func NewRequestHandler(tokener token.UseCase) *RequestHandler { +func NewRequestHandler(tokens token.UseCase /*, tickets ticket.UseCase*/) *RequestHandler { return &RequestHandler{ - tokener: tokener, + tokens: tokens, + // tickets: tickets, } } func (h *RequestHandler) Register(r *router.Router) { - r.GET("/token", h.Read) - r.POST("/token", h.Update) + chain := middleware.Chain{ + middleware.LogFmt(), + } + + r.GET("/token", chain.RequestHandler(h.handleValidate)) + r.POST("/token", chain.RequestHandler(h.handleAction)) } -func (h *RequestHandler) Read(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSON) - ctx.SetStatusCode(http.StatusOK) - - rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization) - - t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer ")))) - if err != nil { - if xerrors.Is(err, token.ErrRevoke) { - ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized) - } else { - ctx.Error(err.Error(), http.StatusBadRequest) - } - - return - } - - if t == nil { - ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized) - - return - } - - if err := json.NewEncoder(ctx).Encode(&VerificationResponse{ - ClientID: t.ClientID, - Me: t.Me, - Scope: strings.Join(t.Scopes, " "), - }); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } -} - -func (h *RequestHandler) Update(ctx *http.RequestCtx) { - if strings.EqualFold(string(ctx.FormValue(Action)), ActionRevoke) { - h.Revocation(ctx) - } -} - -func (h *RequestHandler) Revocation(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSON) +func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) ctx.SetStatusCode(http.StatusOK) encoder := json.NewEncoder(ctx) - req := new(RevocationRequest) + t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)), + "Bearer ")) + if err != nil || t == nil { + ctx.SetStatusCode(http.StatusUnauthorized) + encoder.Encode(&domain.Error{ + Code: "unauthorized_client", + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + + encoder.Encode(&VerificationResponse{ + ClientID: t.ClientID, + Me: t.Me, + Scope: t.Scope, + }) +} + +func (h *RequestHandler) handleAction(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(ctx) + + switch { + case ctx.PostArgs().Has("grant_type"): + h.handleExchange(ctx) + case ctx.PostArgs().Has("action"): + action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action"))) + if err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + + switch action { + case domain.ActionRevoke: + h.handleRevoke(ctx) + case domain.ActionTicket: + h.handleTicket(ctx) + } + } +} + +func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(ctx) + + req := new(ExchangeRequest) + if err := req.bind(ctx); err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(err) + + return + } + + token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ + ClientID: req.ClientID, + RedirectURI: req.RedirectURI, + Code: req.Code, + CodeVerifier: req.CodeVerifier, + }) + if err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(&domain.Error{ + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + + encoder.Encode(&ExchangeResponse{ + AccessToken: token.AccessToken, + TokenType: "Bearer", + Scope: token.Scope.String(), + Me: token.Me.String(), + }) +} + +func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + ctx.SetStatusCode(http.StatusOK) + + encoder := json.NewEncoder(ctx) + + req := new(RevokeRequest) if err := req.bind(ctx); err != nil { ctx.Error(err.Error(), http.StatusBadRequest) return } - if err := h.tokener.Revoke(ctx, req.Token); err != nil { + if err := h.tokens.Revoke(ctx, req.Token); err != nil { ctx.Error(err.Error(), http.StatusBadRequest) return @@ -112,21 +190,65 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) { } } -func (r *RevocationRequest) bind(ctx *http.RequestCtx) error { - if r.Action = string(ctx.FormValue(Action)); !strings.EqualFold(r.Action, ActionRevoke) { - return domain.Error{ - Code: "invalid_request", - Description: "request MUST contain 'action' key with value 'revoke'", - URI: "https://indieauth.spec.indieweb.org/#token-revocation-request", - Frame: xerrors.Caller(1), - } +func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + ctx.SetStatusCode(http.StatusOK) + + encoder := json.NewEncoder(ctx) + + req := new(TicketRequest) + if err := req.bind(ctx); err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(err) + + return } - if r.Token = string(ctx.FormValue("token")); r.Token == "" { + /* TODO(toby3d) + token, err := h.tickets.Redeem(ctx, req.Ticket) + if err != nil { + ctx.SetStatusCode(http.StatusInternalServerError) + encoder.Encode(domain.Error{ + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + */ + + encoder.Encode(ExchangeResponse{}) +} + +func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { return domain.Error{ Code: "invalid_request", - Description: "request MUST contain the 'token' key with the valid access token as its value", - URI: "https://indieauth.spec.indieweb.org/#token-revocation-request", + Description: err.Error(), + Frame: xerrors.Caller(1), + } + } + + return nil +} + +func (r *RevokeRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { + return domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + } + } + + return nil +} + +func (r *TicketRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { + return domain.Error{ + Code: "invalid_request", + Description: err.Error(), Frame: xerrors.Caller(1), } } diff --git a/internal/token/delivery/http/token_http_test.go b/internal/token/delivery/http/token_http_test.go index 07eea0c..1f80d96 100644 --- a/internal/token/delivery/http/token_http_test.go +++ b/internal/token/delivery/http/token_http_test.go @@ -6,82 +6,74 @@ import ( "sync" "testing" - "github.com/goccy/go-json" - "github.com/spf13/viper" + "github.com/fasthttp/router" + json "github.com/goccy/go-json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" http "github.com/valyala/fasthttp" "source.toby3d.me/website/indieauth/internal/common" - configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper" - configucase "source.toby3d.me/website/indieauth/internal/config/usecase" "source.toby3d.me/website/indieauth/internal/domain" + sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory" + "source.toby3d.me/website/indieauth/internal/testing/httptest" delivery "source.toby3d.me/website/indieauth/internal/token/delivery/http" - repository "source.toby3d.me/website/indieauth/internal/token/repository/memory" - "source.toby3d.me/website/indieauth/internal/token/usecase" - "source.toby3d.me/website/indieauth/internal/util" + tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory" + tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase" ) func TestVerification(t *testing.T) { t.Parallel() - v := viper.New() - v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") - v.SetDefault("indieauth.jwtSecret", "hackme") + store := new(sync.Map) + config := domain.TestConfig(t) + token := domain.TestToken(t) - accessToken := domain.TestToken(t) + r := router.New() + // TODO(toby3d): provide tickets + delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store), + sessionrepo.NewMemorySessionRepository(config, store), config)).Register(r) - client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase( - repository.NewMemoryTokenRepository(new(sync.Map)), - configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), - )).Read) + client, _, cleanup := httptest.New(t, r.Handler) t.Cleanup(cleanup) - req := http.AcquireRequest() + req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil) defer http.ReleaseRequest(req) - req.Header.SetMethod(http.MethodGet) - req.SetRequestURI("https://app.example.com/token") req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) - req.Header.Set(http.HeaderAuthorization, "Bearer "+accessToken.AccessToken) + token.SetAuthHeader(req) resp := http.AcquireResponse() defer http.ReleaseResponse(resp) require.NoError(t, client.Do(req, resp)) - assert.Equal(t, http.StatusOK, resp.StatusCode()) - token := new(delivery.VerificationResponse) - require.NoError(t, json.Unmarshal(resp.Body(), token)) - assert.Equal(t, &delivery.VerificationResponse{ - Me: accessToken.Me, - ClientID: accessToken.ClientID, - Scope: strings.Join(accessToken.Scopes, " "), - }, token) + result := new(delivery.VerificationResponse) + require.NoError(t, json.Unmarshal(resp.Body(), result)) + assert.Equal(t, token.ClientID.String(), result.ClientID.String()) + assert.Equal(t, token.Me.String(), result.Me.String()) + assert.Equal(t, token.Scope.String(), result.Scope.String()) } func TestRevocation(t *testing.T) { t.Parallel() - v := viper.New() - v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") - v.SetDefault("indieauth.jwtSecret", "hackme") - - tokens := repository.NewMemoryTokenRepository(new(sync.Map)) + config := domain.TestConfig(t) + store := new(sync.Map) + tokens := tokenrepo.NewMemoryTokenRepository(store) accessToken := domain.TestToken(t) - client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler( - usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))), - ).Update) + r := router.New() + delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokens, sessionrepo.NewMemorySessionRepository(config, + store), config)).Register(r) + + client, _, cleanup := httptest.New(t, r.Handler) t.Cleanup(cleanup) - req := http.AcquireRequest() + req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil) defer http.ReleaseRequest(req) - req.Header.SetMethod(http.MethodPost) - req.SetRequestURI("https://app.example.com/token") - req.Header.SetContentType(common.MIMEApplicationXWWWFormUrlencoded) req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) - req.PostArgs().Set("action", "revoke") + req.Header.SetContentType(common.MIMEApplicationForm) + req.PostArgs().Set("action", domain.ActionRevoke.String()) req.PostArgs().Set("token", accessToken.AccessToken) resp := http.AcquireResponse() diff --git a/internal/token/repository/bolt/bolt_token.go b/internal/token/repository/bolt/bolt_token.go deleted file mode 100644 index 4606593..0000000 --- a/internal/token/repository/bolt/bolt_token.go +++ /dev/null @@ -1,129 +0,0 @@ -package bolt - -import ( - "context" - "encoding/json" - "time" - - "github.com/pkg/errors" - bolt "go.etcd.io/bbolt" - "golang.org/x/xerrors" - - "source.toby3d.me/website/indieauth/internal/domain" - "source.toby3d.me/website/indieauth/internal/token" -) - -type ( - Token struct { - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` - DeletedAt time.Time `json:"deletedAt,omitempty"` - Scopes []string `json:"scopes"` - AccessToken string `json:"accessToken"` - ClientID string `json:"clientId"` - Me string `json:"me"` - } - - boltTokenRepository struct { - db *bolt.DB - } -) - -func NewBoltTokenRepository(db *bolt.DB) token.Repository { - return &boltTokenRepository{ - db: db, - } -} - -func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) { - t, err := repo.Get(ctx, accessToken.AccessToken) - if err != nil && !xerrors.Is(err, token.ErrNotExist) { - return errors.Wrap(err, "cannot get token in database") - } - - if t != nil { - return token.ErrExist - } - - if err = repo.db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket()) - if err != nil { - return errors.Wrap(err, "cannot create bucket") - } - - token := new(Token) - token.Populate(accessToken) - - src, err := json.Marshal(token) - if err != nil { - return errors.Wrap(err, "cannot marshal token data") - } - - if err = bkt.Put([]byte(token.AccessToken), src); err != nil { - return errors.Wrap(err, "cannot put token into bucket") - } - - return nil - }); err != nil { - return errors.Wrap(err, "failed to put token into database") - } - - return nil -} - -func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - result := new(domain.Token) - - if err := repo.db.View(func(tx *bolt.Tx) (err error) { - t := new(Token) - - bkt := tx.Bucket(t.Bucket()) - if bkt == nil { - return token.ErrNotExist - } - - src := bkt.Get([]byte(accessToken)) - if src == nil { - return token.ErrNotExist - } - - if err = t.Bind(src, result); err != nil { - return errors.Wrap(err, "cannot parse token") - } - - return nil - }); err != nil { - return nil, errors.Wrap(err, "failed to view token in database") - } - - return result, nil -} - -func (Token) Bucket() []byte { return []byte("tokens") } - -func (t *Token) Populate(accessToken *domain.Token) { - t.AccessToken = accessToken.AccessToken - t.ClientID = accessToken.ClientID - t.CreatedAt = time.Now().UTC().Round(time.Second) - t.Me = accessToken.Me - t.Scopes = make([]string, len(accessToken.Scopes)) - t.UpdatedAt = t.CreatedAt - - for i := range accessToken.Scopes { - t.Scopes[i] = accessToken.Scopes[i] - } -} - -func (t *Token) Bind(src []byte, accessToken *domain.Token) error { - if err := json.Unmarshal(src, t); err != nil { - return errors.Wrap(err, "cannot unmarshal token") - } - - accessToken.AccessToken = t.AccessToken - accessToken.ClientID = t.ClientID - accessToken.Me = t.Me - accessToken.Scopes = t.Scopes - - return nil -} diff --git a/internal/token/repository/bolt/bolt_token_test.go b/internal/token/repository/bolt/bolt_token_test.go deleted file mode 100644 index 1ed75bf..0000000 --- a/internal/token/repository/bolt/bolt_token_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package bolt_test - -import ( - "context" - "encoding/json" - "testing" - - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - bolt "go.etcd.io/bbolt" - - "source.toby3d.me/website/indieauth/internal/domain" - "source.toby3d.me/website/indieauth/internal/token" - repository "source.toby3d.me/website/indieauth/internal/token/repository/bolt" - "source.toby3d.me/website/indieauth/internal/util" -) - -func TestCreate(t *testing.T) { - t.Parallel() - - //nolint: exhaustivestruct - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) - t.Cleanup(cleanup) - - require.NoError(t, db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - _, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket()) - - //nolint: wrapcheck - return err - })) - - repo := repository.NewBoltTokenRepository(db) - accessToken := domain.TestToken(t) - - require.NoError(t, repo.Create(context.TODO(), accessToken)) - - result := domain.NewToken() - - require.NoError(t, db.View(func(tx *bolt.Tx) (err error) { - dto := new(repository.Token) - - //nolint: wrapcheck - return dto.Bind(tx.Bucket(dto.Bucket()).Get([]byte(accessToken.AccessToken)), result) - })) - assert.Equal(t, accessToken, result) - - assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist) -} - -func TestGet(t *testing.T) { - t.Parallel() - - //nolint: exhaustivestruct - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) - t.Cleanup(cleanup) - - accessToken := domain.TestToken(t) - - require.NoError(t, db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket()) - if err != nil { - return errors.Wrap(err, "cannot create bucket") - } - - t := new(repository.Token) - t.Populate(accessToken) - - src, err := json.Marshal(t) - if err != nil { - return errors.Wrap(err, "cannot marshal token data") - } - - if err = bkt.Put([]byte(t.AccessToken), src); err != nil { - return errors.Wrap(err, "cannot put token into bucket") - } - - return nil - })) - - result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken) - assert.NoError(t, err) - assert.Equal(t, accessToken, result) -} diff --git a/internal/token/repository/sqlite3/sqlite3_token.go b/internal/token/repository/sqlite3/sqlite3_token.go new file mode 100644 index 0000000..fb0629e --- /dev/null +++ b/internal/token/repository/sqlite3/sqlite3_token.go @@ -0,0 +1,105 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/jmoiron/sqlx" + + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/token" +) + +type ( + Token struct { + CreatedAt sql.NullTime `db:"created_at"` + AccessToken string `db:"access_token"` + ClientID string `db:"client_id"` + Me string `db:"me"` + Scope string `db:"scope"` + } + + sqlite3TokenRepository struct { + db *sqlx.DB + } +) + +const ( + QueryTable string = `CREATE TABLE IF NOT EXISTS tokens ( + access_token TEXT UNIQUE PRIMARY KEY NOT NULL, + client_id TEXT NOT NULL, + created_at DATETIME NOT NULL, + me TEXT NOT NULL, + scope TEXT + );` + + QueryGet string = `SELECT * + FROM tokens + WHERE access_token=$1;` + + QueryCreate string = `INSERT INTO tokens (created_at, access_token, client_id, me, scope) + VALUES (:created_at, :access_token, :client_id, :me, :scope);` +) + +func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository { + return &sqlite3TokenRepository{ + db: db, + } +} + +func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { + if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewToken(accessToken)); err != nil { + return fmt.Errorf("cannot create token record in db: %w", err) + } + + return nil +} + +func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { + t := new(Token) + if err := repo.db.GetContext(ctx, t, QueryTable+QueryGet, accessToken); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, token.ErrNotExist + } + + return nil, fmt.Errorf("cannot find token in db: %w", err) + } + + result := new(domain.Token) + t.Populate(result) + + return result, nil +} + +func NewToken(src *domain.Token) *Token { + return &Token{ + CreatedAt: sql.NullTime{ + Time: time.Now().UTC(), + Valid: true, + }, + AccessToken: src.AccessToken, + ClientID: src.ClientID.String(), + Me: src.Me.String(), + Scope: src.Scope.String(), + } +} + +func (t *Token) Populate(dst *domain.Token) { + dst.AccessToken = t.AccessToken + dst.ClientID, _ = domain.NewClientID(t.ClientID) + dst.Me, _ = domain.NewMe(t.Me) + dst.Scope = make(domain.Scopes, 0) + + for _, scope := range strings.Fields(t.Scope) { + s, err := domain.ParseScope(scope) + if err != nil { + continue + } + + dst.Scope = append(dst.Scope, s) + } +} diff --git a/internal/token/repository/sqlite3/sqlite3_token_test.go b/internal/token/repository/sqlite3/sqlite3_token_test.go new file mode 100644 index 0000000..49fc618 --- /dev/null +++ b/internal/token/repository/sqlite3/sqlite3_token_test.go @@ -0,0 +1,47 @@ +package sqlite3_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/testing/sqltest" + repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3" +) + +func TestCreate(t *testing.T) { + t.Parallel() + + db, cleanup := sqltest.Open(t) + t.Cleanup(cleanup) + + token := domain.TestToken(t) + require.NoError(t, repository.NewSQLite3TokenRepository(db).Create(context.Background(), token)) + + results := make([]*repository.Token, 0) + require.NoError(t, db.Select(&results, "SELECT * FROM tokens;")) + require.Len(t, results, 1) + + result := new(domain.Token) + results[0].Populate(result) + + assert.Equal(t, token.AccessToken, result.AccessToken) +} + +func TestGet(t *testing.T) { + t.Parallel() + + db, cleanup := sqltest.Open(t) + t.Cleanup(cleanup) + + token := domain.TestToken(t) + _, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewToken(token)) + require.NoError(t, err) + + result, err := repository.NewSQLite3TokenRepository(db).Get(context.Background(), token.AccessToken) + require.NoError(t, err) + assert.Equal(t, token.AccessToken, result.AccessToken) +} diff --git a/internal/token/usecase/token_ucase.go b/internal/token/usecase/token_ucase.go index 3e4d661..bb8dbce 100644 --- a/internal/token/usecase/token_ucase.go +++ b/internal/token/usecase/token_ucase.go @@ -2,108 +2,115 @@ package usecase import ( "context" - "strings" - "time" + "fmt" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" - "github.com/pkg/errors" "golang.org/x/xerrors" - "source.toby3d.me/website/indieauth/internal/config" "source.toby3d.me/website/indieauth/internal/domain" - "source.toby3d.me/website/indieauth/internal/random" + "source.toby3d.me/website/indieauth/internal/session" "source.toby3d.me/website/indieauth/internal/token" ) -type ( - Config struct { - Configer config.UseCase - Tokens token.Repository - } +type tokenUseCase struct { + sessions session.Repository + config *domain.Config + tokens token.Repository +} - tokenUseCase struct { - configer config.UseCase - tokens token.Repository - } -) +//nolint: gochecknoinits +func init() { + jwt.RegisterCustomField("scope", make(domain.Scopes, 0)) +} -func NewTokenUseCase(config Config) token.UseCase { +func NewTokenUseCase(tokens token.Repository, sessions session.Repository, config *domain.Config) token.UseCase { return &tokenUseCase{ - configer: config.Configer, - tokens: config.Tokens, + sessions: sessions, + config: config, + tokens: tokens, } } -// Generate generates a new Token based on the session data. -func (useCase *tokenUseCase) Generate(ctx context.Context, opts token.GenerateOptions) (*domain.Token, error) { - nonce, err := random.String(opts.NonceLength) +func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOptions) (*domain.Token, error) { + session, err := useCase.sessions.GetAndDelete(ctx, opts.Code) if err != nil { - return nil, errors.Wrap(err, "cannot generate code") + return nil, fmt.Errorf("cannot get session from store: %w", err) } - t := jwt.New() - now := time.Now().UTC().Round(time.Second) - - t.Set(jwt.IssuerKey, opts.ClientID) - t.Set(jwt.SubjectKey, opts.Me) - t.Set(jwt.ExpirationKey, now.Add(useCase.configer.GetIndieAuthAccessTokenExpirationTime())) - t.Set(jwt.NotBeforeKey, now) - t.Set(jwt.IssuedAtKey, now) - t.Set("scope", strings.Join(opts.Scopes, " ")) - t.Set("nonce", nonce) - - token, err := jwt.Sign(t, - jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()), - []byte(useCase.configer.GetIndieAuthJWTSecret())) - if err != nil { - return nil, errors.Wrap(err, "cannot sign a new access token") + if opts.ClientID.String() != session.ClientID.String() { + return nil, domain.Error{ + Code: "invalid_request", + Description: "client's URL MUST match the client_id used in the authentication request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } } - return &domain.Token{ - Scopes: opts.Scopes, - AccessToken: string(token), - ClientID: opts.ClientID, - Me: opts.Me, - }, nil + if opts.RedirectURI.String() != session.RedirectURI.String() { + return nil, domain.Error{ + Code: "invalid_request", + Description: "client's redirect URL MUST match the initial authentication request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } + } + + if session.CodeChallenge != "" && + !session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) { + return nil, domain.Error{ + Code: "invalid_request", + Description: "code_verifier is not hashes to the same value as given in " + + "the code_challenge in the original authorization request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } + } + + t, err := domain.NewToken(domain.NewTokenOptions{ + Algorithm: useCase.config.JWT.Algorithm, + Expiration: useCase.config.JWT.Expiry, + Issuer: session.ClientID, + NonceLength: useCase.config.JWT.NonceLength, + Scope: session.Scope, + Secret: []byte(useCase.config.JWT.Secret), + Subject: session.Me, + }) + if err != nil { + return nil, fmt.Errorf("cannot generate a new access token: %w", err) + } + + return t, nil } func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) { find, err := useCase.tokens.Get(ctx, accessToken) if err != nil && !xerrors.Is(err, token.ErrNotExist) { - return nil, errors.Wrap(err, "cannot ckeck token in store") + return nil, fmt.Errorf("cannot check token in store: %w", err) } if find != nil { return nil, token.ErrRevoke } - t, err := jwt.ParseString(accessToken, jwt.WithVerify( - jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()), - []byte(useCase.configer.GetIndieAuthJWTSecret()), - )) + t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm), + []byte(useCase.config.JWT.Secret))) if err != nil { - return nil, errors.Wrap(err, "cannot parse JWT token") + return nil, fmt.Errorf("cannot parse JWT token: %w", err) } if err = jwt.Validate(t); err != nil { - return nil, errors.Wrap(err, "cannot validate JWT token") + return nil, fmt.Errorf("cannot validate JWT token: %w", err) } result := &domain.Token{ AccessToken: accessToken, - ClientID: t.Issuer(), - Me: t.Subject(), - Scopes: make([]string, 0), } + result.ClientID, _ = domain.NewClientID(t.Issuer()) + result.Me, _ = domain.NewMe(t.Subject()) - rawScope, ok := t.Get("scope") - if !ok { - return result, nil - } - - if scope, ok := rawScope.(string); ok { - result.Scopes = strings.Fields(scope) + if scope, ok := t.Get("scope"); ok { + result.Scope, _ = scope.(domain.Scopes) } return result, nil @@ -112,11 +119,11 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error { t, err := useCase.Verify(ctx, accessToken) if err != nil { - return errors.Wrap(err, "cannot verify token") + return fmt.Errorf("cannot verify token: %w", err) } if err = useCase.tokens.Create(ctx, t); err != nil { - return errors.Wrap(err, "cannot save token in database") + return fmt.Errorf("cannot save token in database: %w", err) } return nil diff --git a/internal/token/usecase/token_ucase_test.go b/internal/token/usecase/token_ucase_test.go index edb59ba..300b2f9 100644 --- a/internal/token/usecase/token_ucase_test.go +++ b/internal/token/usecase/token_ucase_test.go @@ -2,75 +2,48 @@ package usecase_test import ( "context" - "strings" "sync" "testing" - "github.com/lestrrat-go/jwx/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper" - configucase "source.toby3d.me/website/indieauth/internal/config/usecase" "source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/token" repository "source.toby3d.me/website/indieauth/internal/token/repository/memory" - ucase "source.toby3d.me/website/indieauth/internal/token/usecase" + usecase "source.toby3d.me/website/indieauth/internal/token/usecase" ) -func TestGenerate(t *testing.T) { +func TestExchange(t *testing.T) { t.Parallel() - - configer := configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(domain.TestConfig(t))) - options := token.GenerateOptions{ - ClientID: "https://app.example.com/", - Me: "https://user.example.net/", - Scopes: []string{"create", "update", "delete"}, - NonceLength: 42, - } - - result, err := ucase.NewTokenUseCase(ucase.Config{ - Configer: configer, - Tokens: nil, - }).Generate(context.TODO(), options) - require.NoError(t, err) - assert.Equal(t, options.ClientID, result.ClientID) - assert.Equal(t, options.Me, result.Me) - assert.Equal(t, options.Scopes, result.Scopes) - - token, err := jwt.ParseString(result.AccessToken) - require.NoError(t, err) - assert.Equal(t, options.Me, token.Subject()) - assert.Equal(t, options.ClientID, token.Issuer()) - - scope, ok := token.Get("scope") - require.True(t, ok) - assert.Equal(t, strings.Join(options.Scopes, " "), scope) } func TestVerify(t *testing.T) { t.Parallel() repo := repository.NewMemoryTokenRepository(new(sync.Map)) - useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))) + ucase := usecase.NewTokenUseCase(repo, nil, domain.TestConfig(t)) t.Run("valid", func(t *testing.T) { t.Parallel() accessToken := domain.TestToken(t) - result, err := useCase.Verify(context.TODO(), accessToken.AccessToken) + result, err := ucase.Verify(context.TODO(), accessToken.AccessToken) require.NoError(t, err) - assert.Equal(t, accessToken, result) + assert.Equal(t, accessToken.AccessToken, result.AccessToken) + assert.Equal(t, accessToken.Scope, result.Scope) + assert.Equal(t, accessToken.ClientID.String(), result.ClientID.String()) + assert.Equal(t, accessToken.Me.String(), result.Me.String()) }) - t.Run("revoke", func(t *testing.T) { + t.Run("revoked", func(t *testing.T) { t.Parallel() accessToken := domain.TestToken(t) require.NoError(t, repo.Create(context.TODO(), accessToken)) - result, err := useCase.Verify(context.TODO(), accessToken.AccessToken) + result, err := ucase.Verify(context.TODO(), accessToken.AccessToken) require.ErrorIs(t, err, token.ErrRevoke) assert.Nil(t, result) }) @@ -79,16 +52,12 @@ func TestVerify(t *testing.T) { func TestRevoke(t *testing.T) { t.Parallel() - v := viper.New() - v.Set("indieauth.jwtSigningAlgorithm", "HS256") - v.Set("indieauth.jwtSecret", "hackme") - - repo := repository.NewMemoryTokenRepository(new(sync.Map)) + config := domain.TestConfig(t) accessToken := domain.TestToken(t) + repo := repository.NewMemoryTokenRepository(new(sync.Map)) - require.NoError(t, ucase.NewTokenUseCase( - repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), - ).Revoke(context.TODO(), accessToken.AccessToken)) + require.NoError(t, usecase.NewTokenUseCase(repo, nil, config). + Revoke(context.TODO(), accessToken.AccessToken)) result, err := repo.Get(context.TODO(), accessToken.AccessToken) assert.NoError(t, err)