♻️ Refactored token package
This commit is contained in:
parent
83dc4286eb
commit
75f6cb168f
|
@ -1,7 +1,6 @@
|
||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/fasthttp/router"
|
"github.com/fasthttp/router"
|
||||||
|
@ -9,99 +8,178 @@ import (
|
||||||
http "github.com/valyala/fasthttp"
|
http "github.com/valyala/fasthttp"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/form"
|
||||||
|
"source.toby3d.me/toby3d/middleware"
|
||||||
"source.toby3d.me/website/indieauth/internal/common"
|
"source.toby3d.me/website/indieauth/internal/common"
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
"source.toby3d.me/website/indieauth/internal/token"
|
"source.toby3d.me/website/indieauth/internal/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
RevocationRequest struct {
|
ExchangeRequest struct {
|
||||||
Action string
|
ClientID *domain.ClientID `form:"client_id"`
|
||||||
Token string
|
RedirectURI *domain.URL `form:"redirect_uri"`
|
||||||
|
GrantType domain.GrantType `form:"grant_type"`
|
||||||
|
Code string `form:"code"`
|
||||||
|
CodeVerifier string `form:"code_verifier"`
|
||||||
|
}
|
||||||
|
|
||||||
|
RevokeRequest struct {
|
||||||
|
Action domain.Action `form:"action"`
|
||||||
|
Token string `form:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
TicketRequest struct {
|
||||||
|
Action domain.Action `form:"action"`
|
||||||
|
Ticket string `form:"ticket"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint: tagliatelle
|
||||||
|
ExchangeResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
Me string `json:"me"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint: tagliatelle
|
//nolint: tagliatelle
|
||||||
VerificationResponse struct {
|
VerificationResponse struct {
|
||||||
Me string `json:"me"`
|
Me *domain.Me `json:"me"`
|
||||||
ClientID string `json:"client_id"`
|
ClientID *domain.ClientID `json:"client_id"`
|
||||||
Scope string `json:"scope"`
|
Scope domain.Scopes `json:"scope"`
|
||||||
}
|
}
|
||||||
|
|
||||||
RevocationResponse struct{}
|
RevocationResponse struct{}
|
||||||
|
|
||||||
RequestHandler struct {
|
RequestHandler struct {
|
||||||
tokener token.UseCase
|
tokens token.UseCase
|
||||||
|
// TODO(toby3d): tickets ticket.UseCase
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
func NewRequestHandler(tokens token.UseCase /*, tickets ticket.UseCase*/) *RequestHandler {
|
||||||
Action string = "action"
|
|
||||||
ActionRevoke string = "revoke"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewRequestHandler(tokener token.UseCase) *RequestHandler {
|
|
||||||
return &RequestHandler{
|
return &RequestHandler{
|
||||||
tokener: tokener,
|
tokens: tokens,
|
||||||
|
// tickets: tickets,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RequestHandler) Register(r *router.Router) {
|
func (h *RequestHandler) Register(r *router.Router) {
|
||||||
r.GET("/token", h.Read)
|
chain := middleware.Chain{
|
||||||
r.POST("/token", h.Update)
|
middleware.LogFmt(),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.GET("/token", chain.RequestHandler(h.handleValidate))
|
||||||
|
r.POST("/token", chain.RequestHandler(h.handleAction))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RequestHandler) Read(ctx *http.RequestCtx) {
|
func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
|
||||||
ctx.SetContentType(common.MIMEApplicationJSON)
|
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||||
ctx.SetStatusCode(http.StatusOK)
|
|
||||||
|
|
||||||
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
|
|
||||||
|
|
||||||
t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
|
|
||||||
if err != nil {
|
|
||||||
if xerrors.Is(err, token.ErrRevoke) {
|
|
||||||
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
} else {
|
|
||||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if t == nil {
|
|
||||||
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
|
|
||||||
ClientID: t.ClientID,
|
|
||||||
Me: t.Me,
|
|
||||||
Scope: strings.Join(t.Scopes, " "),
|
|
||||||
}); err != nil {
|
|
||||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RequestHandler) Update(ctx *http.RequestCtx) {
|
|
||||||
if strings.EqualFold(string(ctx.FormValue(Action)), ActionRevoke) {
|
|
||||||
h.Revocation(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
|
|
||||||
ctx.SetContentType(common.MIMEApplicationJSON)
|
|
||||||
ctx.SetStatusCode(http.StatusOK)
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
|
||||||
encoder := json.NewEncoder(ctx)
|
encoder := json.NewEncoder(ctx)
|
||||||
|
|
||||||
req := new(RevocationRequest)
|
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
|
||||||
|
"Bearer "))
|
||||||
|
if err != nil || t == nil {
|
||||||
|
ctx.SetStatusCode(http.StatusUnauthorized)
|
||||||
|
encoder.Encode(&domain.Error{
|
||||||
|
Code: "unauthorized_client",
|
||||||
|
Description: err.Error(),
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.Encode(&VerificationResponse{
|
||||||
|
ClientID: t.ClientID,
|
||||||
|
Me: t.Me,
|
||||||
|
Scope: t.Scope,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
|
||||||
|
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||||
|
|
||||||
|
encoder := json.NewEncoder(ctx)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.PostArgs().Has("grant_type"):
|
||||||
|
h.handleExchange(ctx)
|
||||||
|
case ctx.PostArgs().Has("action"):
|
||||||
|
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(http.StatusBadRequest)
|
||||||
|
encoder.Encode(domain.Error{
|
||||||
|
Code: "invalid_request",
|
||||||
|
Description: err.Error(),
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case domain.ActionRevoke:
|
||||||
|
h.handleRevoke(ctx)
|
||||||
|
case domain.ActionTicket:
|
||||||
|
h.handleTicket(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
||||||
|
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||||
|
|
||||||
|
encoder := json.NewEncoder(ctx)
|
||||||
|
|
||||||
|
req := new(ExchangeRequest)
|
||||||
|
if err := req.bind(ctx); err != nil {
|
||||||
|
ctx.SetStatusCode(http.StatusBadRequest)
|
||||||
|
encoder.Encode(err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
RedirectURI: req.RedirectURI,
|
||||||
|
Code: req.Code,
|
||||||
|
CodeVerifier: req.CodeVerifier,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(http.StatusBadRequest)
|
||||||
|
encoder.Encode(&domain.Error{
|
||||||
|
Description: err.Error(),
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.Encode(&ExchangeResponse{
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
Scope: token.Scope.String(),
|
||||||
|
Me: token.Me.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
|
||||||
|
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||||
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
|
|
||||||
|
encoder := json.NewEncoder(ctx)
|
||||||
|
|
||||||
|
req := new(RevokeRequest)
|
||||||
if err := req.bind(ctx); err != nil {
|
if err := req.bind(ctx); err != nil {
|
||||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.tokener.Revoke(ctx, req.Token); err != nil {
|
if err := h.tokens.Revoke(ctx, req.Token); err != nil {
|
||||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -112,21 +190,65 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RevocationRequest) bind(ctx *http.RequestCtx) error {
|
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
|
||||||
if r.Action = string(ctx.FormValue(Action)); !strings.EqualFold(r.Action, ActionRevoke) {
|
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||||
return domain.Error{
|
ctx.SetStatusCode(http.StatusOK)
|
||||||
Code: "invalid_request",
|
|
||||||
Description: "request MUST contain 'action' key with value 'revoke'",
|
encoder := json.NewEncoder(ctx)
|
||||||
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
|
|
||||||
Frame: xerrors.Caller(1),
|
req := new(TicketRequest)
|
||||||
}
|
if err := req.bind(ctx); err != nil {
|
||||||
|
ctx.SetStatusCode(http.StatusBadRequest)
|
||||||
|
encoder.Encode(err)
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Token = string(ctx.FormValue("token")); r.Token == "" {
|
/* TODO(toby3d)
|
||||||
|
token, err := h.tickets.Redeem(ctx, req.Ticket)
|
||||||
|
if err != nil {
|
||||||
|
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||||
|
encoder.Encode(domain.Error{
|
||||||
|
Description: err.Error(),
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
encoder.Encode(ExchangeResponse{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||||
|
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||||
return domain.Error{
|
return domain.Error{
|
||||||
Code: "invalid_request",
|
Code: "invalid_request",
|
||||||
Description: "request MUST contain the 'token' key with the valid access token as its value",
|
Description: err.Error(),
|
||||||
URI: "https://indieauth.spec.indieweb.org/#token-revocation-request",
|
Frame: xerrors.Caller(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
|
||||||
|
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||||
|
return domain.Error{
|
||||||
|
Code: "invalid_request",
|
||||||
|
Description: err.Error(),
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TicketRequest) bind(ctx *http.RequestCtx) error {
|
||||||
|
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||||
|
return domain.Error{
|
||||||
|
Code: "invalid_request",
|
||||||
|
Description: err.Error(),
|
||||||
Frame: xerrors.Caller(1),
|
Frame: xerrors.Caller(1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,82 +6,74 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/fasthttp/router"
|
||||||
"github.com/spf13/viper"
|
json "github.com/goccy/go-json"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
http "github.com/valyala/fasthttp"
|
http "github.com/valyala/fasthttp"
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/common"
|
"source.toby3d.me/website/indieauth/internal/common"
|
||||||
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
|
|
||||||
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
|
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
|
||||||
|
"source.toby3d.me/website/indieauth/internal/testing/httptest"
|
||||||
delivery "source.toby3d.me/website/indieauth/internal/token/delivery/http"
|
delivery "source.toby3d.me/website/indieauth/internal/token/delivery/http"
|
||||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||||
"source.toby3d.me/website/indieauth/internal/token/usecase"
|
tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||||
"source.toby3d.me/website/indieauth/internal/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVerification(t *testing.T) {
|
func TestVerification(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
v := viper.New()
|
store := new(sync.Map)
|
||||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
config := domain.TestConfig(t)
|
||||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
token := domain.TestToken(t)
|
||||||
|
|
||||||
accessToken := domain.TestToken(t)
|
r := router.New()
|
||||||
|
// TODO(toby3d): provide tickets
|
||||||
|
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
|
||||||
|
sessionrepo.NewMemorySessionRepository(config, store), config)).Register(r)
|
||||||
|
|
||||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(
|
client, _, cleanup := httptest.New(t, r.Handler)
|
||||||
repository.NewMemoryTokenRepository(new(sync.Map)),
|
|
||||||
configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
|
||||||
)).Read)
|
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
req := http.AcquireRequest()
|
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
|
||||||
defer http.ReleaseRequest(req)
|
defer http.ReleaseRequest(req)
|
||||||
req.Header.SetMethod(http.MethodGet)
|
|
||||||
req.SetRequestURI("https://app.example.com/token")
|
|
||||||
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
||||||
req.Header.Set(http.HeaderAuthorization, "Bearer "+accessToken.AccessToken)
|
token.SetAuthHeader(req)
|
||||||
|
|
||||||
resp := http.AcquireResponse()
|
resp := http.AcquireResponse()
|
||||||
defer http.ReleaseResponse(resp)
|
defer http.ReleaseResponse(resp)
|
||||||
|
|
||||||
require.NoError(t, client.Do(req, resp))
|
require.NoError(t, client.Do(req, resp))
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
||||||
|
|
||||||
token := new(delivery.VerificationResponse)
|
result := new(delivery.VerificationResponse)
|
||||||
require.NoError(t, json.Unmarshal(resp.Body(), token))
|
require.NoError(t, json.Unmarshal(resp.Body(), result))
|
||||||
assert.Equal(t, &delivery.VerificationResponse{
|
assert.Equal(t, token.ClientID.String(), result.ClientID.String())
|
||||||
Me: accessToken.Me,
|
assert.Equal(t, token.Me.String(), result.Me.String())
|
||||||
ClientID: accessToken.ClientID,
|
assert.Equal(t, token.Scope.String(), result.Scope.String())
|
||||||
Scope: strings.Join(accessToken.Scopes, " "),
|
|
||||||
}, token)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRevocation(t *testing.T) {
|
func TestRevocation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
v := viper.New()
|
config := domain.TestConfig(t)
|
||||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
store := new(sync.Map)
|
||||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
tokens := tokenrepo.NewMemoryTokenRepository(store)
|
||||||
|
|
||||||
tokens := repository.NewMemoryTokenRepository(new(sync.Map))
|
|
||||||
accessToken := domain.TestToken(t)
|
accessToken := domain.TestToken(t)
|
||||||
|
|
||||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(
|
r := router.New()
|
||||||
usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))),
|
delivery.NewRequestHandler(tokenucase.NewTokenUseCase(tokens, sessionrepo.NewMemorySessionRepository(config,
|
||||||
).Update)
|
store), config)).Register(r)
|
||||||
|
|
||||||
|
client, _, cleanup := httptest.New(t, r.Handler)
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
req := http.AcquireRequest()
|
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)
|
||||||
defer http.ReleaseRequest(req)
|
defer http.ReleaseRequest(req)
|
||||||
req.Header.SetMethod(http.MethodPost)
|
|
||||||
req.SetRequestURI("https://app.example.com/token")
|
|
||||||
req.Header.SetContentType(common.MIMEApplicationXWWWFormUrlencoded)
|
|
||||||
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
||||||
req.PostArgs().Set("action", "revoke")
|
req.Header.SetContentType(common.MIMEApplicationForm)
|
||||||
|
req.PostArgs().Set("action", domain.ActionRevoke.String())
|
||||||
req.PostArgs().Set("token", accessToken.AccessToken)
|
req.PostArgs().Set("token", accessToken.AccessToken)
|
||||||
|
|
||||||
resp := http.AcquireResponse()
|
resp := http.AcquireResponse()
|
||||||
|
|
|
@ -1,129 +0,0 @@
|
||||||
package bolt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
bolt "go.etcd.io/bbolt"
|
|
||||||
"golang.org/x/xerrors"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/token"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
Token struct {
|
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
|
||||||
UpdatedAt time.Time `json:"updatedAt"`
|
|
||||||
DeletedAt time.Time `json:"deletedAt,omitempty"`
|
|
||||||
Scopes []string `json:"scopes"`
|
|
||||||
AccessToken string `json:"accessToken"`
|
|
||||||
ClientID string `json:"clientId"`
|
|
||||||
Me string `json:"me"`
|
|
||||||
}
|
|
||||||
|
|
||||||
boltTokenRepository struct {
|
|
||||||
db *bolt.DB
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
|
|
||||||
return &boltTokenRepository{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) {
|
|
||||||
t, err := repo.Get(ctx, accessToken.AccessToken)
|
|
||||||
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
|
|
||||||
return errors.Wrap(err, "cannot get token in database")
|
|
||||||
}
|
|
||||||
|
|
||||||
if t != nil {
|
|
||||||
return token.ErrExist
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = repo.db.Update(func(tx *bolt.Tx) error {
|
|
||||||
//nolint: exhaustivestruct
|
|
||||||
bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket())
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cannot create bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
token := new(Token)
|
|
||||||
token.Populate(accessToken)
|
|
||||||
|
|
||||||
src, err := json.Marshal(token)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cannot marshal token data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = bkt.Put([]byte(token.AccessToken), src); err != nil {
|
|
||||||
return errors.Wrap(err, "cannot put token into bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return errors.Wrap(err, "failed to put token into database")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
|
||||||
result := new(domain.Token)
|
|
||||||
|
|
||||||
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
|
|
||||||
t := new(Token)
|
|
||||||
|
|
||||||
bkt := tx.Bucket(t.Bucket())
|
|
||||||
if bkt == nil {
|
|
||||||
return token.ErrNotExist
|
|
||||||
}
|
|
||||||
|
|
||||||
src := bkt.Get([]byte(accessToken))
|
|
||||||
if src == nil {
|
|
||||||
return token.ErrNotExist
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = t.Bind(src, result); err != nil {
|
|
||||||
return errors.Wrap(err, "cannot parse token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, errors.Wrap(err, "failed to view token in database")
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (Token) Bucket() []byte { return []byte("tokens") }
|
|
||||||
|
|
||||||
func (t *Token) Populate(accessToken *domain.Token) {
|
|
||||||
t.AccessToken = accessToken.AccessToken
|
|
||||||
t.ClientID = accessToken.ClientID
|
|
||||||
t.CreatedAt = time.Now().UTC().Round(time.Second)
|
|
||||||
t.Me = accessToken.Me
|
|
||||||
t.Scopes = make([]string, len(accessToken.Scopes))
|
|
||||||
t.UpdatedAt = t.CreatedAt
|
|
||||||
|
|
||||||
for i := range accessToken.Scopes {
|
|
||||||
t.Scopes[i] = accessToken.Scopes[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Token) Bind(src []byte, accessToken *domain.Token) error {
|
|
||||||
if err := json.Unmarshal(src, t); err != nil {
|
|
||||||
return errors.Wrap(err, "cannot unmarshal token")
|
|
||||||
}
|
|
||||||
|
|
||||||
accessToken.AccessToken = t.AccessToken
|
|
||||||
accessToken.ClientID = t.ClientID
|
|
||||||
accessToken.Me = t.Me
|
|
||||||
accessToken.Scopes = t.Scopes
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,86 +0,0 @@
|
||||||
package bolt_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
bolt "go.etcd.io/bbolt"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/token"
|
|
||||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/bolt"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreate(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
//nolint: exhaustivestruct
|
|
||||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
|
||||||
//nolint: exhaustivestruct
|
|
||||||
_, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
|
|
||||||
|
|
||||||
//nolint: wrapcheck
|
|
||||||
return err
|
|
||||||
}))
|
|
||||||
|
|
||||||
repo := repository.NewBoltTokenRepository(db)
|
|
||||||
accessToken := domain.TestToken(t)
|
|
||||||
|
|
||||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
|
||||||
|
|
||||||
result := domain.NewToken()
|
|
||||||
|
|
||||||
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
|
|
||||||
dto := new(repository.Token)
|
|
||||||
|
|
||||||
//nolint: wrapcheck
|
|
||||||
return dto.Bind(tx.Bucket(dto.Bucket()).Get([]byte(accessToken.AccessToken)), result)
|
|
||||||
}))
|
|
||||||
assert.Equal(t, accessToken, result)
|
|
||||||
|
|
||||||
assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGet(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
//nolint: exhaustivestruct
|
|
||||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
|
|
||||||
accessToken := domain.TestToken(t)
|
|
||||||
|
|
||||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
|
||||||
//nolint: exhaustivestruct
|
|
||||||
bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cannot create bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
t := new(repository.Token)
|
|
||||||
t.Populate(accessToken)
|
|
||||||
|
|
||||||
src, err := json.Marshal(t)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "cannot marshal token data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = bkt.Put([]byte(t.AccessToken), src); err != nil {
|
|
||||||
return errors.Wrap(err, "cannot put token into bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
|
|
||||||
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, accessToken, result)
|
|
||||||
}
|
|
|
@ -0,0 +1,105 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
|
||||||
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
|
"source.toby3d.me/website/indieauth/internal/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Token struct {
|
||||||
|
CreatedAt sql.NullTime `db:"created_at"`
|
||||||
|
AccessToken string `db:"access_token"`
|
||||||
|
ClientID string `db:"client_id"`
|
||||||
|
Me string `db:"me"`
|
||||||
|
Scope string `db:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3TokenRepository struct {
|
||||||
|
db *sqlx.DB
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
QueryTable string = `CREATE TABLE IF NOT EXISTS tokens (
|
||||||
|
access_token TEXT UNIQUE PRIMARY KEY NOT NULL,
|
||||||
|
client_id TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
me TEXT NOT NULL,
|
||||||
|
scope TEXT
|
||||||
|
);`
|
||||||
|
|
||||||
|
QueryGet string = `SELECT *
|
||||||
|
FROM tokens
|
||||||
|
WHERE access_token=$1;`
|
||||||
|
|
||||||
|
QueryCreate string = `INSERT INTO tokens (created_at, access_token, client_id, me, scope)
|
||||||
|
VALUES (:created_at, :access_token, :client_id, :me, :scope);`
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository {
|
||||||
|
return &sqlite3TokenRepository{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
|
||||||
|
if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewToken(accessToken)); err != nil {
|
||||||
|
return fmt.Errorf("cannot create token record in db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||||
|
t := new(Token)
|
||||||
|
if err := repo.db.GetContext(ctx, t, QueryTable+QueryGet, accessToken); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, token.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("cannot find token in db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := new(domain.Token)
|
||||||
|
t.Populate(result)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewToken(src *domain.Token) *Token {
|
||||||
|
return &Token{
|
||||||
|
CreatedAt: sql.NullTime{
|
||||||
|
Time: time.Now().UTC(),
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
AccessToken: src.AccessToken,
|
||||||
|
ClientID: src.ClientID.String(),
|
||||||
|
Me: src.Me.String(),
|
||||||
|
Scope: src.Scope.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Token) Populate(dst *domain.Token) {
|
||||||
|
dst.AccessToken = t.AccessToken
|
||||||
|
dst.ClientID, _ = domain.NewClientID(t.ClientID)
|
||||||
|
dst.Me, _ = domain.NewMe(t.Me)
|
||||||
|
dst.Scope = make(domain.Scopes, 0)
|
||||||
|
|
||||||
|
for _, scope := range strings.Fields(t.Scope) {
|
||||||
|
s, err := domain.ParseScope(scope)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Scope = append(dst.Scope, s)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package sqlite3_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
|
"source.toby3d.me/website/indieauth/internal/testing/sqltest"
|
||||||
|
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, cleanup := sqltest.Open(t)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
token := domain.TestToken(t)
|
||||||
|
require.NoError(t, repository.NewSQLite3TokenRepository(db).Create(context.Background(), token))
|
||||||
|
|
||||||
|
results := make([]*repository.Token, 0)
|
||||||
|
require.NoError(t, db.Select(&results, "SELECT * FROM tokens;"))
|
||||||
|
require.Len(t, results, 1)
|
||||||
|
|
||||||
|
result := new(domain.Token)
|
||||||
|
results[0].Populate(result)
|
||||||
|
|
||||||
|
assert.Equal(t, token.AccessToken, result.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, cleanup := sqltest.Open(t)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
token := domain.TestToken(t)
|
||||||
|
_, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewToken(token))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := repository.NewSQLite3TokenRepository(db).Get(context.Background(), token.AccessToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, token.AccessToken, result.AccessToken)
|
||||||
|
}
|
|
@ -2,108 +2,115 @@ package usecase
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/jwa"
|
"github.com/lestrrat-go/jwx/jwa"
|
||||||
"github.com/lestrrat-go/jwx/jwt"
|
"github.com/lestrrat-go/jwx/jwt"
|
||||||
"github.com/pkg/errors"
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/config"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
"source.toby3d.me/website/indieauth/internal/random"
|
"source.toby3d.me/website/indieauth/internal/session"
|
||||||
"source.toby3d.me/website/indieauth/internal/token"
|
"source.toby3d.me/website/indieauth/internal/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type tokenUseCase struct {
|
||||||
Config struct {
|
sessions session.Repository
|
||||||
Configer config.UseCase
|
config *domain.Config
|
||||||
Tokens token.Repository
|
tokens token.Repository
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenUseCase struct {
|
//nolint: gochecknoinits
|
||||||
configer config.UseCase
|
func init() {
|
||||||
tokens token.Repository
|
jwt.RegisterCustomField("scope", make(domain.Scopes, 0))
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
func NewTokenUseCase(config Config) token.UseCase {
|
func NewTokenUseCase(tokens token.Repository, sessions session.Repository, config *domain.Config) token.UseCase {
|
||||||
return &tokenUseCase{
|
return &tokenUseCase{
|
||||||
configer: config.Configer,
|
sessions: sessions,
|
||||||
tokens: config.Tokens,
|
config: config,
|
||||||
|
tokens: tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate generates a new Token based on the session data.
|
func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOptions) (*domain.Token, error) {
|
||||||
func (useCase *tokenUseCase) Generate(ctx context.Context, opts token.GenerateOptions) (*domain.Token, error) {
|
session, err := useCase.sessions.GetAndDelete(ctx, opts.Code)
|
||||||
nonce, err := random.String(opts.NonceLength)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot generate code")
|
return nil, fmt.Errorf("cannot get session from store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t := jwt.New()
|
if opts.ClientID.String() != session.ClientID.String() {
|
||||||
now := time.Now().UTC().Round(time.Second)
|
return nil, domain.Error{
|
||||||
|
Code: "invalid_request",
|
||||||
t.Set(jwt.IssuerKey, opts.ClientID)
|
Description: "client's URL MUST match the client_id used in the authentication request",
|
||||||
t.Set(jwt.SubjectKey, opts.Me)
|
URI: "https://indieauth.net/source/#request",
|
||||||
t.Set(jwt.ExpirationKey, now.Add(useCase.configer.GetIndieAuthAccessTokenExpirationTime()))
|
Frame: xerrors.Caller(1),
|
||||||
t.Set(jwt.NotBeforeKey, now)
|
}
|
||||||
t.Set(jwt.IssuedAtKey, now)
|
|
||||||
t.Set("scope", strings.Join(opts.Scopes, " "))
|
|
||||||
t.Set("nonce", nonce)
|
|
||||||
|
|
||||||
token, err := jwt.Sign(t,
|
|
||||||
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
|
|
||||||
[]byte(useCase.configer.GetIndieAuthJWTSecret()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cannot sign a new access token")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &domain.Token{
|
if opts.RedirectURI.String() != session.RedirectURI.String() {
|
||||||
Scopes: opts.Scopes,
|
return nil, domain.Error{
|
||||||
AccessToken: string(token),
|
Code: "invalid_request",
|
||||||
ClientID: opts.ClientID,
|
Description: "client's redirect URL MUST match the initial authentication request",
|
||||||
Me: opts.Me,
|
URI: "https://indieauth.net/source/#request",
|
||||||
}, nil
|
Frame: xerrors.Caller(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.CodeChallenge != "" &&
|
||||||
|
!session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) {
|
||||||
|
return nil, domain.Error{
|
||||||
|
Code: "invalid_request",
|
||||||
|
Description: "code_verifier is not hashes to the same value as given in " +
|
||||||
|
"the code_challenge in the original authorization request",
|
||||||
|
URI: "https://indieauth.net/source/#request",
|
||||||
|
Frame: xerrors.Caller(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := domain.NewToken(domain.NewTokenOptions{
|
||||||
|
Algorithm: useCase.config.JWT.Algorithm,
|
||||||
|
Expiration: useCase.config.JWT.Expiry,
|
||||||
|
Issuer: session.ClientID,
|
||||||
|
NonceLength: useCase.config.JWT.NonceLength,
|
||||||
|
Scope: session.Scope,
|
||||||
|
Secret: []byte(useCase.config.JWT.Secret),
|
||||||
|
Subject: session.Me,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot generate a new access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
|
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||||
find, err := useCase.tokens.Get(ctx, accessToken)
|
find, err := useCase.tokens.Get(ctx, accessToken)
|
||||||
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
|
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
|
||||||
return nil, errors.Wrap(err, "cannot ckeck token in store")
|
return nil, fmt.Errorf("cannot check token in store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if find != nil {
|
if find != nil {
|
||||||
return nil, token.ErrRevoke
|
return nil, token.ErrRevoke
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
|
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
|
||||||
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
|
[]byte(useCase.config.JWT.Secret)))
|
||||||
[]byte(useCase.configer.GetIndieAuthJWTSecret()),
|
|
||||||
))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot parse JWT token")
|
return nil, fmt.Errorf("cannot parse JWT token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = jwt.Validate(t); err != nil {
|
if err = jwt.Validate(t); err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot validate JWT token")
|
return nil, fmt.Errorf("cannot validate JWT token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := &domain.Token{
|
result := &domain.Token{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
ClientID: t.Issuer(),
|
|
||||||
Me: t.Subject(),
|
|
||||||
Scopes: make([]string, 0),
|
|
||||||
}
|
}
|
||||||
|
result.ClientID, _ = domain.NewClientID(t.Issuer())
|
||||||
|
result.Me, _ = domain.NewMe(t.Subject())
|
||||||
|
|
||||||
rawScope, ok := t.Get("scope")
|
if scope, ok := t.Get("scope"); ok {
|
||||||
if !ok {
|
result.Scope, _ = scope.(domain.Scopes)
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope, ok := rawScope.(string); ok {
|
|
||||||
result.Scopes = strings.Fields(scope)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
@ -112,11 +119,11 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
|
||||||
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
|
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
|
||||||
t, err := useCase.Verify(ctx, accessToken)
|
t, err := useCase.Verify(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cannot verify token")
|
return fmt.Errorf("cannot verify token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = useCase.tokens.Create(ctx, t); err != nil {
|
if err = useCase.tokens.Create(ctx, t); err != nil {
|
||||||
return errors.Wrap(err, "cannot save token in database")
|
return fmt.Errorf("cannot save token in database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -2,75 +2,48 @@ package usecase_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/jwt"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
configrepo "source.toby3d.me/website/indieauth/internal/config/repository/viper"
|
|
||||||
configucase "source.toby3d.me/website/indieauth/internal/config/usecase"
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
"source.toby3d.me/website/indieauth/internal/token"
|
"source.toby3d.me/website/indieauth/internal/token"
|
||||||
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
repository "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||||
ucase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
usecase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerate(t *testing.T) {
|
func TestExchange(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
configer := configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(domain.TestConfig(t)))
|
|
||||||
options := token.GenerateOptions{
|
|
||||||
ClientID: "https://app.example.com/",
|
|
||||||
Me: "https://user.example.net/",
|
|
||||||
Scopes: []string{"create", "update", "delete"},
|
|
||||||
NonceLength: 42,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := ucase.NewTokenUseCase(ucase.Config{
|
|
||||||
Configer: configer,
|
|
||||||
Tokens: nil,
|
|
||||||
}).Generate(context.TODO(), options)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, options.ClientID, result.ClientID)
|
|
||||||
assert.Equal(t, options.Me, result.Me)
|
|
||||||
assert.Equal(t, options.Scopes, result.Scopes)
|
|
||||||
|
|
||||||
token, err := jwt.ParseString(result.AccessToken)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, options.Me, token.Subject())
|
|
||||||
assert.Equal(t, options.ClientID, token.Issuer())
|
|
||||||
|
|
||||||
scope, ok := token.Get("scope")
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, strings.Join(options.Scopes, " "), scope)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestVerify(t *testing.T) {
|
func TestVerify(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||||
useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)))
|
ucase := usecase.NewTokenUseCase(repo, nil, domain.TestConfig(t))
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
accessToken := domain.TestToken(t)
|
accessToken := domain.TestToken(t)
|
||||||
|
|
||||||
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
|
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, accessToken, result)
|
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
|
||||||
|
assert.Equal(t, accessToken.Scope, result.Scope)
|
||||||
|
assert.Equal(t, accessToken.ClientID.String(), result.ClientID.String())
|
||||||
|
assert.Equal(t, accessToken.Me.String(), result.Me.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("revoke", func(t *testing.T) {
|
t.Run("revoked", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
accessToken := domain.TestToken(t)
|
accessToken := domain.TestToken(t)
|
||||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||||
|
|
||||||
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
|
result, err := ucase.Verify(context.TODO(), accessToken.AccessToken)
|
||||||
require.ErrorIs(t, err, token.ErrRevoke)
|
require.ErrorIs(t, err, token.ErrRevoke)
|
||||||
assert.Nil(t, result)
|
assert.Nil(t, result)
|
||||||
})
|
})
|
||||||
|
@ -79,16 +52,12 @@ func TestVerify(t *testing.T) {
|
||||||
func TestRevoke(t *testing.T) {
|
func TestRevoke(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
v := viper.New()
|
config := domain.TestConfig(t)
|
||||||
v.Set("indieauth.jwtSigningAlgorithm", "HS256")
|
|
||||||
v.Set("indieauth.jwtSecret", "hackme")
|
|
||||||
|
|
||||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
|
||||||
accessToken := domain.TestToken(t)
|
accessToken := domain.TestToken(t)
|
||||||
|
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||||
|
|
||||||
require.NoError(t, ucase.NewTokenUseCase(
|
require.NoError(t, usecase.NewTokenUseCase(repo, nil, config).
|
||||||
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
Revoke(context.TODO(), accessToken.AccessToken))
|
||||||
).Revoke(context.TODO(), accessToken.AccessToken))
|
|
||||||
|
|
||||||
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
Loading…
Reference in New Issue