♻️ Refactored token package, store only revoked tokens

This commit is contained in:
Maxim Lebedev 2021-10-14 02:53:31 +05:00
parent dd43f0bf22
commit 4f63db7bfa
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
10 changed files with 253 additions and 347 deletions

View File

@ -1,37 +1,64 @@
package domain
import (
"strings"
"testing"
"time"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/require"
"source.toby3d.me/website/oauth/internal/random"
)
type Token struct {
Expiry time.Time
Scopes []string
AccessToken string
TokenType string
ClientID string
Me string
Profile *Profile
Scopes []string
Type string
}
func NewToken() *Token {
t := new(Token)
t.Scopes = make([]string, 0)
t.Expiry = time.Time{}
return t
}
//nolint: gomnd
func TestToken(tb testing.TB) *Token {
tb.Helper()
client := TestClient(tb)
profile := TestProfile(tb)
now := time.Now().UTC().Round(time.Second)
scopes := []string{"create", "update", "delete"}
t := jwt.New()
// required
t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id
t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour))
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
// optional
t.Set("scope", strings.Join(scopes, " "))
t.Set("nonce", random.New().String(32))
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
require.NoError(tb, err)
return &Token{
AccessToken: random.New().String(32),
ClientID: "https://app.example.com/",
Me: "https://user.example.net/",
Profile: TestProfile(tb),
Scopes: []string{"create", "update", "delete"},
Type: "Bearer",
AccessToken: string(accessToken),
ClientID: t.Issuer(),
Expiry: t.Expiration(),
Me: t.Subject(),
Scopes: scopes,
TokenType: "Bearer",
}
}

View File

@ -15,10 +15,6 @@ import (
)
type (
RequestHandler struct {
useCase token.UseCase
}
RevocationRequest struct {
Action string
Token string
@ -32,6 +28,10 @@ type (
}
RevocationResponse struct{}
RequestHandler struct {
tokener token.UseCase
}
)
const (
@ -39,9 +39,9 @@ const (
ActionRevoke string = "revoke"
)
func NewRequestHandler(useCase token.UseCase) *RequestHandler {
func NewRequestHandler(tokener token.UseCase) *RequestHandler {
return &RequestHandler{
useCase: useCase,
tokener: tokener,
}
}
@ -54,16 +54,11 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSON)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx)
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
token, err := h.useCase.Verify(ctx, string(bytes.TrimSpace(bytes.TrimPrefix(rawToken, []byte("Bearer")))))
token, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
if err = encoder.Encode(err); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
@ -74,16 +69,12 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) {
return
}
if err := encoder.Encode(&VerificationResponse{
Me: token.Me,
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
ClientID: token.ClientID,
Me: token.Me,
Scope: strings.Join(token.Scopes, " "),
}); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
if err = encoder.Encode(err); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
ctx.Error(err.Error(), http.StatusInternalServerError)
}
}
@ -102,30 +93,21 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
req := new(RevocationRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
if err = encoder.Encode(err); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
encoder.Encode(err)
return
}
if err := h.useCase.Revoke(ctx, req.Token); err != nil {
if err := h.tokener.Revoke(ctx, req.Token); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
if err = encoder.Encode(err); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
encoder.Encode(err)
return
}
if err := encoder.Encode(&RevocationResponse{}); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
if err = encoder.Encode(err); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}
encoder.Encode(err)
}
}

View File

@ -7,11 +7,14 @@ import (
"testing"
"github.com/goccy/go-json"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
http "github.com/valyala/fasthttp"
"source.toby3d.me/website/oauth/internal/common"
configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper"
configucase "source.toby3d.me/website/oauth/internal/config/usecase"
"source.toby3d.me/website/oauth/internal/domain"
delivery "source.toby3d.me/website/oauth/internal/token/delivery/http"
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
@ -22,12 +25,16 @@ import (
func TestVerification(t *testing.T) {
t.Parallel()
repo := repository.NewMemoryTokenRepository(new(sync.Map))
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Read)
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(
repository.NewMemoryTokenRepository(new(sync.Map)),
configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
)).Read)
t.Cleanup(cleanup)
req := http.AcquireRequest()
@ -56,12 +63,16 @@ func TestVerification(t *testing.T) {
func TestRevocation(t *testing.T) {
t.Parallel()
repo := repository.NewMemoryTokenRepository(new(sync.Map))
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
tokens := repository.NewMemoryTokenRepository(new(sync.Map))
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), domain.TestToken(t)))
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Update)
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(
usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))),
).Update)
t.Cleanup(cleanup)
req := http.AcquireRequest()
@ -81,7 +92,7 @@ func TestRevocation(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode())
assert.Equal(t, `{}`, strings.TrimSpace(string(resp.Body())))
token, err := repo.Get(context.TODO(), accessToken.AccessToken)
result, err := tokens.Get(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Nil(t, token)
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
}

View File

@ -11,8 +11,6 @@ import (
type Repository interface {
Get(ctx context.Context, accessToken string) (*domain.Token, error)
Create(ctx context.Context, accessToken *domain.Token) error
Update(ctx context.Context, accessToken *domain.Token) error
Remove(ctx context.Context, accessToken string) error
}
var ErrExist error = domain.Error{

View File

@ -2,9 +2,8 @@ package bolt
import (
"context"
"strings"
"time"
json "github.com/goccy/go-json"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
"golang.org/x/xerrors"
@ -13,21 +12,13 @@ import (
"source.toby3d.me/website/oauth/internal/token"
)
type (
Token struct {
AccessToken string `json:"accessToken"`
ClientID string `json:"clientId"`
Me string `json:"me"`
Scope string `json:"scope"`
Type string `json:"type"`
}
type boltTokenRepository struct {
db *bolt.DB
}
boltTokenRepository struct {
db *bolt.DB
}
)
var ErrNotExist error = errors.New("token not exist")
var ErrNotExist error = xerrors.New("key not exist")
var DefaultBucket []byte = []byte("tokens")
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
return &boltTokenRepository{
@ -35,97 +26,60 @@ func NewBoltTokenRepository(db *bolt.DB) token.Repository {
}
}
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
result := domain.NewToken()
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
find, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil {
return errors.Wrap(err, "cannot check token in database")
}
if err := repo.db.View(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
if src := tx.Bucket(Token{}.Bucket()).Get([]byte(accessToken)); src != nil {
return new(Token).Bind(src, result)
if find != nil {
return token.ErrExist
}
if err = repo.db.Update(func(tx *bolt.Tx) error {
bkt, err := tx.CreateBucketIfNotExists(DefaultBucket)
if err != nil {
return errors.Wrap(err, "cannot create bucket")
}
return ErrNotExist
return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
}); err != nil {
if !xerrors.Is(err, ErrNotExist) {
return nil, errors.Wrap(err, "failed to retrieve token from storage")
return errors.Wrap(err, "failed to batch token in database")
}
return nil
}
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
result := &domain.Token{
AccessToken: accessToken,
TokenType: "Bearer",
Expiry: time.Time{},
}
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
bkt := tx.Bucket(DefaultBucket)
if bkt == nil {
return ErrNotExist
}
return nil, nil
expiry := bkt.Get([]byte(accessToken))
if expiry == nil {
return ErrNotExist
}
if result.Expiry, err = time.Parse(time.RFC3339, string(expiry)); err != nil {
return err
}
return nil
}); err != nil {
if xerrors.Is(err, ErrNotExist) {
return nil, nil
}
return nil, errors.Wrap(err, "failed to view token in database")
}
return result, nil
}
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
t, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil {
return errors.Wrap(err, "failed to verify the existence of the token")
}
if t != nil {
return token.ErrExist
}
return repo.Update(ctx, accessToken)
}
func (repo *boltTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error {
dto := new(Token)
dto.Populate(accessToken)
src, err := json.Marshal(dto)
if err != nil {
return errors.Wrap(err, "failed to marshal token")
}
if err = repo.db.Update(func(tx *bolt.Tx) error {
if err := tx.Bucket(dto.Bucket()).Put([]byte(dto.AccessToken), src); err != nil {
return errors.Wrap(err, "failed to overwrite the token in the bucket")
}
return nil
}); err != nil {
return errors.Wrap(err, "failed to update the token in the repository")
}
return nil
}
func (repo *boltTokenRepository) Remove(ctx context.Context, accessToken string) error {
if err := repo.db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
if err := tx.Bucket(Token{}.Bucket()).Delete([]byte(accessToken)); err != nil {
return errors.Wrap(err, "failed to remove token in bucket")
}
return nil
}); err != nil {
return errors.Wrap(err, "failed to remove token from storage")
}
return nil
}
func (Token) Bucket() []byte { return []byte("tokens") }
func (t *Token) Populate(src *domain.Token) {
t.AccessToken = src.AccessToken
t.ClientID = src.ClientID
t.Me = src.Me
t.Scope = strings.Join(src.Scopes, " ")
t.Type = src.Type
}
func (t *Token) Bind(src []byte, dst *domain.Token) error {
if err := json.Unmarshal(src, t); err != nil {
return errors.Wrap(err, "cannot unmarshal token source")
}
dst.AccessToken = t.AccessToken
dst.Scopes = strings.Fields(t.Scope)
dst.Type = t.Type
dst.ClientID = t.ClientID
dst.Me = t.Me
return nil
}

View File

@ -1,11 +1,10 @@
//nolint: wrapcheck
package bolt_test
import (
"context"
"testing"
"time"
json "github.com/goccy/go-json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt"
@ -16,120 +15,55 @@ import (
"source.toby3d.me/website/oauth/internal/util"
)
func TestGet(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
accessToken := domain.TestToken(t)
accessToken.Profile = nil
dto := new(repository.Token)
dto.Populate(accessToken)
src, err := json.Marshal(dto)
require.NoError(t, err)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
}))
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken, result)
}
func TestCreate(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
t.Cleanup(cleanup)
repo := repository.NewBoltTokenRepository(db)
accessToken := domain.TestToken(t)
accessToken.Profile = nil
require.NoError(t, repo.Create(context.TODO(), accessToken))
result := new(domain.Token)
result := &domain.Token{
AccessToken: accessToken.AccessToken,
TokenType: accessToken.TokenType,
Expiry: time.Time{},
}
require.NoError(t, db.View(func(tx *bolt.Tx) error {
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
//nolint: exhaustivestruct
return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()).
Get([]byte(accessToken.AccessToken)), result)
}))
src := tx.Bucket(repository.DefaultBucket).Get([]byte(accessToken.AccessToken))
result.Expiry, err = time.Parse(time.RFC3339, string(src))
return
}))
assert.Equal(t, accessToken, result)
assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error())
assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist)
}
func TestUpdate(t *testing.T) {
func TestGet(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
t.Cleanup(cleanup)
repo := repository.NewBoltTokenRepository(db)
accessToken := domain.TestToken(t)
src, err := json.Marshal(accessToken)
require.NoError(t, err)
//nolint: exhaustivestruct
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
bkt, err := tx.CreateBucketIfNotExists(repository.DefaultBucket)
if err != nil {
return err
}
return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
}))
require.NoError(t, repository.NewBoltTokenRepository(db).Update(context.TODO(), &domain.Token{
AccessToken: accessToken.AccessToken,
ClientID: "https://client.example.net/",
Me: "https://toby3d.ru/",
Scopes: []string{"read"},
Type: "Bearer",
Profile: nil,
}))
result := domain.NewToken()
//nolint: exhaustivestruct
require.NoError(t, db.View(func(tx *bolt.Tx) error {
return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()).
Get([]byte(accessToken.AccessToken)), result)
}))
assert.Equal(t, &domain.Token{
AccessToken: accessToken.AccessToken,
ClientID: "https://client.example.net/",
Me: "https://toby3d.ru/",
Scopes: []string{"read"},
Type: "Bearer",
Profile: nil,
}, result)
}
func TestDelete(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
accessToken := domain.TestToken(t)
src, err := json.Marshal(accessToken)
require.NoError(t, err)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
}))
require.NoError(t, repository.NewBoltTokenRepository(db).Remove(context.TODO(), accessToken.AccessToken))
require.NoError(t, db.View(func(tx *bolt.Tx) error {
//nolint: exhaustivestruct
assert.Nil(t, tx.Bucket(repository.Token{}.Bucket()).Get([]byte(accessToken.AccessToken)))
return nil
}))
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
}

View File

@ -4,58 +4,45 @@ import (
"context"
"path"
"sync"
"time"
"source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/token"
)
type memoryTokenRepository struct {
tokens *sync.Map
store *sync.Map
}
const Key string = "tokens"
const DefaultPathPrefix string = "tokens"
func NewMemoryTokenRepository(tokens *sync.Map) token.Repository {
func NewMemoryTokenRepository(store *sync.Map) token.Repository {
return &memoryTokenRepository{
tokens: tokens,
store: store,
}
}
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
src, ok := repo.tokens.Load(path.Join(Key, accessToken))
if !ok {
return nil, nil
}
result, ok := src.(*domain.Token)
if !ok {
return nil, nil
}
return result, nil
}
func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
t, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil {
return err
}
key := path.Join(DefaultPathPrefix, accessToken.AccessToken)
if t != nil {
if _, ok := repo.store.Load(key); ok {
return token.ErrExist
}
return repo.Update(ctx, accessToken)
}
func (repo *memoryTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error {
repo.tokens.Store(path.Join(Key, accessToken.AccessToken), accessToken)
repo.store.Store(key, accessToken.Expiry)
return nil
}
func (repo *memoryTokenRepository) Remove(ctx context.Context, accessToken string) error {
repo.tokens.Delete(path.Join(Key, accessToken))
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
expiry, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken))
if !ok {
return nil, nil
}
return nil
return &domain.Token{
AccessToken: accessToken,
TokenType: "Bearer",
Expiry: expiry.(time.Time),
}, nil
}

View File

@ -14,66 +14,32 @@ import (
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
)
func TestGet(t *testing.T) {
t.Parallel()
store := new(sync.Map)
accessToken := domain.TestToken(t)
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
result, err := repository.NewMemoryTokenRepository(store).Get(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken, result)
}
func TestCreate(t *testing.T) {
t.Parallel()
store := new(sync.Map)
accessToken := domain.TestToken(t)
repo := repository.NewMemoryTokenRepository(store)
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
expiry, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken))
assert.True(t, ok)
assert.Equal(t, accessToken, result)
assert.Equal(t, accessToken.Expiry, expiry)
assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error())
}
func TestUpdate(t *testing.T) {
func TestGet(t *testing.T) {
t.Parallel()
store := new(sync.Map)
repo := repository.NewMemoryTokenRepository(store)
accessToken := domain.TestToken(t)
store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken.Expiry)
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
tokenCopy := *accessToken
tokenCopy.ClientID = "https://client.example.com/"
tokenCopy.Me = "https://toby3d.ru/"
require.NoError(t, repository.NewMemoryTokenRepository(store).Update(context.TODO(), &tokenCopy))
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
assert.True(t, ok)
assert.NotEqual(t, accessToken, result)
assert.Equal(t, &tokenCopy, result)
}
func TestDelete(t *testing.T) {
t.Parallel()
store := new(sync.Map)
accessToken := domain.TestToken(t)
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
require.NoError(t, repository.NewMemoryTokenRepository(store).Remove(context.TODO(), accessToken.AccessToken))
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
assert.False(t, ok)
assert.Nil(t, result)
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
}

View File

@ -2,35 +2,75 @@ package usecase
import (
"context"
"strings"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/pkg/errors"
"source.toby3d.me/website/oauth/internal/config"
"source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/token"
)
type tokenUseCase struct {
tokens token.Repository
tokens token.Repository
configer config.UseCase
}
func NewTokenUseCase(tokens token.Repository) token.UseCase {
func NewTokenUseCase(tokens token.Repository, configer config.UseCase) token.UseCase {
return &tokenUseCase{
tokens: tokens,
tokens: tokens,
configer: configer,
}
}
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
t, err := useCase.tokens.Get(ctx, accessToken)
token, err := useCase.tokens.Get(ctx, accessToken)
if err != nil {
return nil, errors.Wrap(err, "failed to retrieve token from storage")
return nil, errors.Wrap(err, "cannot find token in database")
}
return t, nil
if token != nil {
return nil, nil
}
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
[]byte(useCase.configer.GetIndieAuthJWTSecret()),
))
if err != nil {
return nil, errors.Wrap(err, "cannot parse JWT token")
}
if err = jwt.Validate(t); err != nil {
return nil, errors.Wrap(err, "cannot validate JWT token")
}
token = &domain.Token{
Expiry: t.Expiration(),
Scopes: []string{},
AccessToken: accessToken,
TokenType: "Bearer",
ClientID: t.Issuer(),
Me: t.Subject(),
}
if scope, ok := t.Get("scope"); ok {
token.Scopes = strings.Fields(scope.(string))
}
return token, nil
}
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
if err := useCase.tokens.Remove(ctx, accessToken); err != nil {
return errors.Wrap(err, "failed to delete a token in the vault")
t, err := useCase.Verify(ctx, accessToken)
if err != nil {
return errors.Wrap(err, "cannot verify token")
}
if err = useCase.tokens.Create(ctx, t); err != nil {
return errors.Wrap(err, "cannot save token in database")
}
return nil

View File

@ -5,9 +5,12 @@ import (
"sync"
"testing"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper"
configucase "source.toby3d.me/website/oauth/internal/config/usecase"
"source.toby3d.me/website/oauth/internal/domain"
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
"source.toby3d.me/website/oauth/internal/token/usecase"
@ -16,31 +19,35 @@ import (
func TestVerify(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
repo := repository.NewMemoryTokenRepository(new(sync.Map))
accessToken := domain.NewToken()
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
token, err := usecase.NewTokenUseCase(repo).Verify(context.TODO(), accessToken.AccessToken)
token, err := usecase.NewTokenUseCase(
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
).Verify(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken, token)
assert.Equal(t, accessToken.AccessToken, token.AccessToken)
}
func TestRevoke(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
repo := repository.NewMemoryTokenRepository(new(sync.Map))
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
require.NoError(t, usecase.NewTokenUseCase(
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
).Revoke(context.TODO(), accessToken.AccessToken))
token, err := repo.Get(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.NotNil(t, token)
require.NoError(t, usecase.NewTokenUseCase(repo).Revoke(context.TODO(), token.AccessToken))
token, err = repo.Get(context.TODO(), token.AccessToken)
require.NoError(t, err)
assert.Nil(t, token)
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
}