♻️ Refactored token package
This commit is contained in:
parent
83dc4286eb
commit
75f6cb168f
|
@ -1,7 +1,6 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
|
@ -9,99 +8,178 @@ import (
|
|||
http "github.com/valyala/fasthttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"source.toby3d.me/toby3d/form"
|
||||
"source.toby3d.me/toby3d/middleware"
|
||||
"source.toby3d.me/website/indieauth/internal/common"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
)
|
||||
|
||||
type (
|
||||
RevocationRequest struct {
|
||||
Action string
|
||||
Token string
|
||||
ExchangeRequest struct {
|
||||
ClientID *domain.ClientID `form:"client_id"`
|
||||
RedirectURI *domain.URL `form:"redirect_uri"`
|
||||
GrantType domain.GrantType `form:"grant_type"`
|
||||
Code string `form:"code"`
|
||||
CodeVerifier string `form:"code_verifier"`
|
||||
}
|
||||
|
||||
RevokeRequest struct {
|
||||
Action domain.Action `form:"action"`
|
||||
Token string `form:"token"`
|
||||
}
|
||||
|
||||
TicketRequest struct {
|
||||
Action domain.Action `form:"action"`
|
||||
Ticket string `form:"ticket"`
|
||||
}
|
||||
|
||||
//nolint: tagliatelle
|
||||
ExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
Me string `json:"me"`
|
||||
}
|
||||
|
||||
//nolint: tagliatelle
|
||||
VerificationResponse struct {
|
||||
Me string `json:"me"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
Me *domain.Me `json:"me"`
|
||||
ClientID *domain.ClientID `json:"client_id"`
|
||||
Scope domain.Scopes `json:"scope"`
|
||||
}
|
||||
|
||||
RevocationResponse struct{}
|
||||
|
||||
RequestHandler struct {
|
||||
tokener token.UseCase
|
||||
tokens token.UseCase
|
||||
// TODO(toby3d): tickets ticket.UseCase
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
Action string = "action"
|
||||
ActionRevoke string = "revoke"
|
||||
)
|
||||
|
||||
func NewRequestHandler(tokener token.UseCase) *RequestHandler {
|
||||
func NewRequestHandler(tokens token.UseCase /*, tickets ticket.UseCase*/) *RequestHandler {
|
||||
return &RequestHandler{
|
||||
tokener: tokener,
|
||||
tokens: tokens,
|
||||
// tickets: tickets,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Register(r *router.Router) {
|
||||
r.GET("/token", h.Read)
|
||||
r.POST("/token", h.Update)
|
||||
chain := middleware.Chain{
|
||||
middleware.LogFmt(),
|
||||
}
|
||||
|
||||
r.GET("/token", chain.RequestHandler(h.handleValidate))
|
||||
r.POST("/token", chain.RequestHandler(h.handleAction))
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Read(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSON)
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
|
||||
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
|
||||
|
||||
t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
|
||||
if err != nil {
|
||||
if xerrors.Is(err, token.ErrRevoke) {
|
||||
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
} else {
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
|
||||
ClientID: t.ClientID,
|
||||
Me: t.Me,
|
||||
Scope: strings.Join(t.Scopes, " "),
|
||||
}); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Update(ctx *http.RequestCtx) {
|
||||
if strings.EqualFold(string(ctx.FormValue(Action)), ActionRevoke) {
|
||||
h.Revocation(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSON)
|
||||
func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(RevocationRequest)
|
||||
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
|
||||
"Bearer "))
|
||||
if err != nil || t == nil {
|
||||
ctx.SetStatusCode(http.StatusUnauthorized)
|
||||
encoder.Encode(&domain.Error{
|
||||
Code: "unauthorized_client",
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
encoder.Encode(&VerificationResponse{
|
||||
ClientID: t.ClientID,
|
||||
Me: t.Me,
|
||||
Scope: t.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
switch {
|
||||
case ctx.PostArgs().Has("grant_type"):
|
||||
h.handleExchange(ctx)
|
||||
case ctx.PostArgs().Has("action"):
|
||||
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch action {
|
||||
case domain.ActionRevoke:
|
||||
h.handleRevoke(ctx)
|
||||
case domain.ActionTicket:
|
||||
h.handleTicket(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(ExchangeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
|
||||
ClientID: req.ClientID,
|
||||
RedirectURI: req.RedirectURI,
|
||||
Code: req.Code,
|
||||
CodeVerifier: req.CodeVerifier,
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(&domain.Error{
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
encoder.Encode(&ExchangeResponse{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
Scope: token.Scope.String(),
|
||||
Me: token.Me.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(RevokeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.tokener.Revoke(ctx, req.Token); err != nil {
|
||||
if err := h.tokens.Revoke(ctx, req.Token); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
|
@ -112,21 +190,65 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *RevocationRequest) bind(ctx *http.RequestCtx) error {
|
||||
if r.Action = string(ctx.FormValue(Action)); !strings.EqualFold(r.Action, ActionRevoke) {
|
||||
return domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: "request MUST contain 'action' key with value 'revoke'",
|
||||
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(TicketRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if r.Token = string(ctx.FormValue("token")); r.Token == "" {
|
||||
/* TODO(toby3d)
|
||||
token, err := h.tickets.Redeem(ctx, req.Ticket)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
encoder.Encode(domain.Error{
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
||||
encoder.Encode(ExchangeResponse{})
|
||||
}
|
||||
|
||||
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
return domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: "request MUST contain the 'token' key with the valid access token as its value",
|
||||
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
return domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TicketRequest) bind(ctx *http.RequestCtx) error {
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
return domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: err.Error(),
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,82 +6,74 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/fasthttp/router"
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
http "github.com/valyala/fasthttp"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/common"
|
||||
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
|
||||
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
|
||||
"source.toby3d.me/website/indieauth/internal/testing/httptest"
|
||||
delivery "source.toby3d.me/website/indieauth/internal/token/delivery/http"
|
||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||
"source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||
"source.toby3d.me/website/indieauth/internal/util"
|
||||
tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||
tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||
)
|
||||
|
||||
func TestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
store := new(sync.Map)
|
||||
config := domain.TestConfig(t)
|
||||
token := domain.TestToken(t)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
r := router.New()
|
||||
// TODO(toby3d): provide tickets
|
||||
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
|
||||
sessionrepo.NewMemorySessionRepository(config, store), config)).Register(r)
|
||||
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(
|
||||
repository.NewMemoryTokenRepository(new(sync.Map)),
|
||||
configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
||||
)).Read)
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := http.AcquireRequest()
|
||||
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
|
||||
defer http.ReleaseRequest(req)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.SetRequestURI("https://app.example.com/token")
|
||||
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
||||
req.Header.Set(http.HeaderAuthorization, "Bearer "+accessToken.AccessToken)
|
||||
token.SetAuthHeader(req)
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
defer http.ReleaseResponse(resp)
|
||||
|
||||
require.NoError(t, client.Do(req, resp))
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
||||
|
||||
token := new(delivery.VerificationResponse)
|
||||
require.NoError(t, json.Unmarshal(resp.Body(), token))
|
||||
assert.Equal(t, &delivery.VerificationResponse{
|
||||
Me: accessToken.Me,
|
||||
ClientID: accessToken.ClientID,
|
||||
Scope: strings.Join(accessToken.Scopes, " "),
|
||||
}, token)
|
||||
result := new(delivery.VerificationResponse)
|
||||
require.NoError(t, json.Unmarshal(resp.Body(), result))
|
||||
assert.Equal(t, token.ClientID.String(), result.ClientID.String())
|
||||
assert.Equal(t, token.Me.String(), result.Me.String())
|
||||
assert.Equal(t, token.Scope.String(), result.Scope.String())
|
||||
}
|
||||
|
||||
func TestRevocation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
|
||||
tokens := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
config := domain.TestConfig(t)
|
||||
store := new(sync.Map)
|
||||
tokens := tokenrepo.NewMemoryTokenRepository(store)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(
|
||||
usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))),
|
||||
).Update)
|
||||
r := router.New()
|
||||
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokens, sessionrepo.NewMemorySessionRepository(config,
|
||||
store), config)).Register(r)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := http.AcquireRequest()
|
||||
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)
|
||||
defer http.ReleaseRequest(req)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetRequestURI("https://app.example.com/token")
|
||||
req.Header.SetContentType(common.MIMEApplicationXWWWFormUrlencoded)
|
||||
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
||||
req.PostArgs().Set("action", "revoke")
|
||||
req.Header.SetContentType(common.MIMEApplicationForm)
|
||||
req.PostArgs().Set("action", domain.ActionRevoke.String())
|
||||
req.PostArgs().Set("token", accessToken.AccessToken)
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
|
|
|
@ -1,129 +0,0 @@
|
|||
package bolt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
)
|
||||
|
||||
type (
|
||||
Token struct {
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
DeletedAt time.Time `json:"deletedAt,omitempty"`
|
||||
Scopes []string `json:"scopes"`
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientID string `json:"clientId"`
|
||||
Me string `json:"me"`
|
||||
}
|
||||
|
||||
boltTokenRepository struct {
|
||||
db *bolt.DB
|
||||
}
|
||||
)
|
||||
|
||||
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
|
||||
return &boltTokenRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) {
|
||||
t, err := repo.Get(ctx, accessToken.AccessToken)
|
||||
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
|
||||
return errors.Wrap(err, "cannot get token in database")
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
return token.ErrExist
|
||||
}
|
||||
|
||||
if err = repo.db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot create bucket")
|
||||
}
|
||||
|
||||
token := new(Token)
|
||||
token.Populate(accessToken)
|
||||
|
||||
src, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot marshal token data")
|
||||
}
|
||||
|
||||
if err = bkt.Put([]byte(token.AccessToken), src); err != nil {
|
||||
return errors.Wrap(err, "cannot put token into bucket")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to put token into database")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
result := new(domain.Token)
|
||||
|
||||
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
|
||||
t := new(Token)
|
||||
|
||||
bkt := tx.Bucket(t.Bucket())
|
||||
if bkt == nil {
|
||||
return token.ErrNotExist
|
||||
}
|
||||
|
||||
src := bkt.Get([]byte(accessToken))
|
||||
if src == nil {
|
||||
return token.ErrNotExist
|
||||
}
|
||||
|
||||
if err = t.Bind(src, result); err != nil {
|
||||
return errors.Wrap(err, "cannot parse token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to view token in database")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (Token) Bucket() []byte { return []byte("tokens") }
|
||||
|
||||
func (t *Token) Populate(accessToken *domain.Token) {
|
||||
t.AccessToken = accessToken.AccessToken
|
||||
t.ClientID = accessToken.ClientID
|
||||
t.CreatedAt = time.Now().UTC().Round(time.Second)
|
||||
t.Me = accessToken.Me
|
||||
t.Scopes = make([]string, len(accessToken.Scopes))
|
||||
t.UpdatedAt = t.CreatedAt
|
||||
|
||||
for i := range accessToken.Scopes {
|
||||
t.Scopes[i] = accessToken.Scopes[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Token) Bind(src []byte, accessToken *domain.Token) error {
|
||||
if err := json.Unmarshal(src, t); err != nil {
|
||||
return errors.Wrap(err, "cannot unmarshal token")
|
||||
}
|
||||
|
||||
accessToken.AccessToken = t.AccessToken
|
||||
accessToken.ClientID = t.ClientID
|
||||
accessToken.Me = t.Me
|
||||
accessToken.Scopes = t.Scopes
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
package bolt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/bolt"
|
||||
"source.toby3d.me/website/indieauth/internal/util"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint: exhaustivestruct
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
_, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
|
||||
|
||||
//nolint: wrapcheck
|
||||
return err
|
||||
}))
|
||||
|
||||
repo := repository.NewBoltTokenRepository(db)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
result := domain.NewToken()
|
||||
|
||||
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
|
||||
dto := new(repository.Token)
|
||||
|
||||
//nolint: wrapcheck
|
||||
return dto.Bind(tx.Bucket(dto.Bucket()).Get([]byte(accessToken.AccessToken)), result)
|
||||
}))
|
||||
assert.Equal(t, accessToken, result)
|
||||
|
||||
assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint: exhaustivestruct
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot create bucket")
|
||||
}
|
||||
|
||||
t := new(repository.Token)
|
||||
t.Populate(accessToken)
|
||||
|
||||
src, err := json.Marshal(t)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot marshal token data")
|
||||
}
|
||||
|
||||
if err = bkt.Put([]byte(t.AccessToken), src); err != nil {
|
||||
return errors.Wrap(err, "cannot put token into bucket")
|
||||
}
|
||||
|
||||
return nil
|
||||
}))
|
||||
|
||||
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
)
|
||||
|
||||
type (
|
||||
Token struct {
|
||||
CreatedAt sql.NullTime `db:"created_at"`
|
||||
AccessToken string `db:"access_token"`
|
||||
ClientID string `db:"client_id"`
|
||||
Me string `db:"me"`
|
||||
Scope string `db:"scope"`
|
||||
}
|
||||
|
||||
sqlite3TokenRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
QueryTable string = `CREATE TABLE IF NOT EXISTS tokens (
|
||||
access_token TEXT UNIQUE PRIMARY KEY NOT NULL,
|
||||
client_id TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
me TEXT NOT NULL,
|
||||
scope TEXT
|
||||
);`
|
||||
|
||||
QueryGet string = `SELECT *
|
||||
FROM tokens
|
||||
WHERE access_token=$1;`
|
||||
|
||||
QueryCreate string = `INSERT INTO tokens (created_at, access_token, client_id, me, scope)
|
||||
VALUES (:created_at, :access_token, :client_id, :me, :scope);`
|
||||
)
|
||||
|
||||
func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository {
|
||||
return &sqlite3TokenRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
|
||||
if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewToken(accessToken)); err != nil {
|
||||
return fmt.Errorf("cannot create token record in db: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
t := new(Token)
|
||||
if err := repo.db.GetContext(ctx, t, QueryTable+QueryGet, accessToken); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, token.ErrNotExist
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cannot find token in db: %w", err)
|
||||
}
|
||||
|
||||
result := new(domain.Token)
|
||||
t.Populate(result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func NewToken(src *domain.Token) *Token {
|
||||
return &Token{
|
||||
CreatedAt: sql.NullTime{
|
||||
Time: time.Now().UTC(),
|
||||
Valid: true,
|
||||
},
|
||||
AccessToken: src.AccessToken,
|
||||
ClientID: src.ClientID.String(),
|
||||
Me: src.Me.String(),
|
||||
Scope: src.Scope.String(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Token) Populate(dst *domain.Token) {
|
||||
dst.AccessToken = t.AccessToken
|
||||
dst.ClientID, _ = domain.NewClientID(t.ClientID)
|
||||
dst.Me, _ = domain.NewMe(t.Me)
|
||||
dst.Scope = make(domain.Scopes, 0)
|
||||
|
||||
for _, scope := range strings.Fields(t.Scope) {
|
||||
s, err := domain.ParseScope(scope)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Scope = append(dst.Scope, s)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/testing/sqltest"
|
||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := sqltest.Open(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
token := domain.TestToken(t)
|
||||
require.NoError(t, repository.NewSQLite3TokenRepository(db).Create(context.Background(), token))
|
||||
|
||||
results := make([]*repository.Token, 0)
|
||||
require.NoError(t, db.Select(&results, "SELECT * FROM tokens;"))
|
||||
require.Len(t, results, 1)
|
||||
|
||||
result := new(domain.Token)
|
||||
results[0].Populate(result)
|
||||
|
||||
assert.Equal(t, token.AccessToken, result.AccessToken)
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := sqltest.Open(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
token := domain.TestToken(t)
|
||||
_, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewToken(token))
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := repository.NewSQLite3TokenRepository(db).Get(context.Background(), token.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.AccessToken, result.AccessToken)
|
||||
}
|
|
@ -2,108 +2,115 @@ package usecase
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/config"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/random"
|
||||
"source.toby3d.me/website/indieauth/internal/session"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
Configer config.UseCase
|
||||
Tokens token.Repository
|
||||
}
|
||||
type tokenUseCase struct {
|
||||
sessions session.Repository
|
||||
config *domain.Config
|
||||
tokens token.Repository
|
||||
}
|
||||
|
||||
tokenUseCase struct {
|
||||
configer config.UseCase
|
||||
tokens token.Repository
|
||||
}
|
||||
)
|
||||
//nolint: gochecknoinits
|
||||
func init() {
|
||||
jwt.RegisterCustomField("scope", make(domain.Scopes, 0))
|
||||
}
|
||||
|
||||
func NewTokenUseCase(config Config) token.UseCase {
|
||||
func NewTokenUseCase(tokens token.Repository, sessions session.Repository, config *domain.Config) token.UseCase {
|
||||
return &tokenUseCase{
|
||||
configer: config.Configer,
|
||||
tokens: config.Tokens,
|
||||
sessions: sessions,
|
||||
config: config,
|
||||
tokens: tokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Generate generates a new Token based on the session data.
|
||||
func (useCase *tokenUseCase) Generate(ctx context.Context, opts token.GenerateOptions) (*domain.Token, error) {
|
||||
nonce, err := random.String(opts.NonceLength)
|
||||
func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOptions) (*domain.Token, error) {
|
||||
session, err := useCase.sessions.GetAndDelete(ctx, opts.Code)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot generate code")
|
||||
return nil, fmt.Errorf("cannot get session from store: %w", err)
|
||||
}
|
||||
|
||||
t := jwt.New()
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
|
||||
t.Set(jwt.IssuerKey, opts.ClientID)
|
||||
t.Set(jwt.SubjectKey, opts.Me)
|
||||
t.Set(jwt.ExpirationKey, now.Add(useCase.configer.GetIndieAuthAccessTokenExpirationTime()))
|
||||
t.Set(jwt.NotBeforeKey, now)
|
||||
t.Set(jwt.IssuedAtKey, now)
|
||||
t.Set("scope", strings.Join(opts.Scopes, " "))
|
||||
t.Set("nonce", nonce)
|
||||
|
||||
token, err := jwt.Sign(t,
|
||||
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
|
||||
[]byte(useCase.configer.GetIndieAuthJWTSecret()))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot sign a new access token")
|
||||
if opts.ClientID.String() != session.ClientID.String() {
|
||||
return nil, domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: "client's URL MUST match the client_id used in the authentication request",
|
||||
URI: "https://indieauth.net/source/#request",
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.Token{
|
||||
Scopes: opts.Scopes,
|
||||
AccessToken: string(token),
|
||||
ClientID: opts.ClientID,
|
||||
Me: opts.Me,
|
||||
}, nil
|
||||
if opts.RedirectURI.String() != session.RedirectURI.String() {
|
||||
return nil, domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: "client's redirect URL MUST match the initial authentication request",
|
||||
URI: "https://indieauth.net/source/#request",
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
||||
if session.CodeChallenge != "" &&
|
||||
!session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) {
|
||||
return nil, domain.Error{
|
||||
Code: "invalid_request",
|
||||
Description: "code_verifier is not hashes to the same value as given in " +
|
||||
"the code_challenge in the original authorization request",
|
||||
URI: "https://indieauth.net/source/#request",
|
||||
Frame: xerrors.Caller(1),
|
||||
}
|
||||
}
|
||||
|
||||
t, err := domain.NewToken(domain.NewTokenOptions{
|
||||
Algorithm: useCase.config.JWT.Algorithm,
|
||||
Expiration: useCase.config.JWT.Expiry,
|
||||
Issuer: session.ClientID,
|
||||
NonceLength: useCase.config.JWT.NonceLength,
|
||||
Scope: session.Scope,
|
||||
Secret: []byte(useCase.config.JWT.Secret),
|
||||
Subject: session.Me,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot generate a new access token: %w", err)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
find, err := useCase.tokens.Get(ctx, accessToken)
|
||||
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
|
||||
return nil, errors.Wrap(err, "cannot ckeck token in store")
|
||||
return nil, fmt.Errorf("cannot check token in store: %w", err)
|
||||
}
|
||||
|
||||
if find != nil {
|
||||
return nil, token.ErrRevoke
|
||||
}
|
||||
|
||||
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
|
||||
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
|
||||
[]byte(useCase.configer.GetIndieAuthJWTSecret()),
|
||||
))
|
||||
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
|
||||
[]byte(useCase.config.JWT.Secret)))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse JWT token")
|
||||
return nil, fmt.Errorf("cannot parse JWT token: %w", err)
|
||||
}
|
||||
|
||||
if err = jwt.Validate(t); err != nil {
|
||||
return nil, errors.Wrap(err, "cannot validate JWT token")
|
||||
return nil, fmt.Errorf("cannot validate JWT token: %w", err)
|
||||
}
|
||||
|
||||
result := &domain.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientID: t.Issuer(),
|
||||
Me: t.Subject(),
|
||||
Scopes: make([]string, 0),
|
||||
}
|
||||
result.ClientID, _ = domain.NewClientID(t.Issuer())
|
||||
result.Me, _ = domain.NewMe(t.Subject())
|
||||
|
||||
rawScope, ok := t.Get("scope")
|
||||
if !ok {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if scope, ok := rawScope.(string); ok {
|
||||
result.Scopes = strings.Fields(scope)
|
||||
if scope, ok := t.Get("scope"); ok {
|
||||
result.Scope, _ = scope.(domain.Scopes)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
@ -112,11 +119,11 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
|
|||
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
|
||||
t, err := useCase.Verify(ctx, accessToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot verify token")
|
||||
return fmt.Errorf("cannot verify token: %w", err)
|
||||
}
|
||||
|
||||
if err = useCase.tokens.Create(ctx, t); err != nil {
|
||||
return errors.Wrap(err, "cannot save token in database")
|
||||
return fmt.Errorf("cannot save token in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -2,75 +2,48 @@ package usecase_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
|
||||
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||
ucase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||
usecase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
func TestExchange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configer := configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(domain.TestConfig(t)))
|
||||
options := token.GenerateOptions{
|
||||
ClientID: "https://app.example.com/",
|
||||
Me: "https://user.example.net/",
|
||||
Scopes: []string{"create", "update", "delete"},
|
||||
NonceLength: 42,
|
||||
}
|
||||
|
||||
result, err := ucase.NewTokenUseCase(ucase.Config{
|
||||
Configer: configer,
|
||||
Tokens: nil,
|
||||
}).Generate(context.TODO(), options)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, options.ClientID, result.ClientID)
|
||||
assert.Equal(t, options.Me, result.Me)
|
||||
assert.Equal(t, options.Scopes, result.Scopes)
|
||||
|
||||
token, err := jwt.ParseString(result.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, options.Me, token.Subject())
|
||||
assert.Equal(t, options.ClientID, token.Issuer())
|
||||
|
||||
scope, ok := token.Get("scope")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, strings.Join(options.Scopes, " "), scope)
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)))
|
||||
ucase := usecase.NewTokenUseCase(repo, nil, domain.TestConfig(t))
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
|
||||
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
|
||||
assert.Equal(t, accessToken.Scope, result.Scope)
|
||||
assert.Equal(t, accessToken.ClientID.String(), result.ClientID.String())
|
||||
assert.Equal(t, accessToken.Me.String(), result.Me.String())
|
||||
})
|
||||
|
||||
t.Run("revoke", func(t *testing.T) {
|
||||
t.Run("revoked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
|
||||
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
|
||||
require.ErrorIs(t, err, token.ErrRevoke)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
@ -79,16 +52,12 @@ func TestVerify(t *testing.T) {
|
|||
func TestRevoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := viper.New()
|
||||
v.Set("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.Set("indieauth.jwtSecret", "hackme")
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
config := domain.TestConfig(t)
|
||||
accessToken := domain.TestToken(t)
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
|
||||
require.NoError(t, ucase.NewTokenUseCase(
|
||||
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
||||
).Revoke(context.TODO(), accessToken.AccessToken))
|
||||
require.NoError(t, usecase.NewTokenUseCase(repo, nil, config).
|
||||
Revoke(context.TODO(), accessToken.AccessToken))
|
||||
|
||||
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue