♻️ Improved token verification and revocation support

This commit is contained in:
Maxim Lebedev 2021-10-18 03:51:10 +05:00
parent f9ec91c246
commit 35b4ae9e23
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
10 changed files with 185 additions and 124 deletions

View File

@ -13,17 +13,15 @@ import (
)
type Token struct {
Expiry time.Time
Scopes []string
AccessToken string
TokenType string
ClientID string
Me string
}
func NewToken() *Token {
t := new(Token)
t.Expiry = time.Time{}
t.Scopes = make([]string, 0)
return t
}
@ -61,9 +59,7 @@ func TestToken(tb testing.TB) *Token {
return &Token{
AccessToken: string(accessToken),
ClientID: t.Issuer(),
Expiry: t.Expiration(),
Me: t.Subject(),
Scopes: scopes,
TokenType: "Bearer",
}
}

View File

@ -56,23 +56,27 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) {
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
token, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
if xerrors.Is(err, token.ErrRevoke) {
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
} else {
ctx.Error(err.Error(), http.StatusBadRequest)
}
return
}
if token == nil {
ctx.SetStatusCode(http.StatusUnauthorized)
if t == nil {
ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
ClientID: token.ClientID,
Me: token.Me,
Scope: strings.Join(token.Scopes, " "),
ClientID: t.ClientID,
Me: t.Me,
Scope: strings.Join(t.Scopes, " "),
}); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
}

View File

@ -2,8 +2,7 @@ package token
import (
"context"
"golang.org/x/xerrors"
"errors"
"source.toby3d.me/website/oauth/internal/domain"
)
@ -13,9 +12,7 @@ type Repository interface {
Create(ctx context.Context, accessToken *domain.Token) error
}
var ErrExist error = domain.Error{
Code: "invalid_request",
Description: "this token is already exists",
URI: "",
Frame: xerrors.Caller(1),
}
var (
ErrExist error = errors.New("token already exist")
ErrNotExist error = errors.New("token not exist")
)

View File

@ -2,6 +2,7 @@ package bolt
import (
"context"
"encoding/json"
"time"
"github.com/pkg/errors"
@ -12,13 +13,21 @@ import (
"source.toby3d.me/website/oauth/internal/token"
)
type boltTokenRepository struct {
db *bolt.DB
}
type (
Token struct {
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
DeletedAt time.Time `json:"deletedAt,omitempty"`
Scopes []string `json:"scopes"`
AccessToken string `json:"accessToken"`
ClientID string `json:"clientId"`
Me string `json:"me"`
}
var ErrNotExist error = errors.New("token not exist")
var DefaultBucket = []byte("tokens") //nolint: gochecknoglobals
boltTokenRepository struct {
db *bolt.DB
}
)
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
return &boltTokenRepository{
@ -26,68 +35,94 @@ func NewBoltTokenRepository(db *bolt.DB) token.Repository {
}
}
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
find, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil {
return errors.Wrap(err, "cannot check token in database")
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) {
t, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
return errors.Wrap(err, "cannot get token in database")
}
if find != nil {
if t != nil {
return token.ErrExist
}
if err = repo.db.Update(func(tx *bolt.Tx) error {
bkt, err := tx.CreateBucketIfNotExists(DefaultBucket)
bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket())
if err != nil {
return errors.Wrap(err, "cannot create bucket")
}
err = bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
token := new(Token)
token.Populate(accessToken)
src, err := json.Marshal(token)
if err != nil {
return errors.Wrap(err, "cannot marshal token data")
}
if err = bkt.Put([]byte(token.AccessToken), src); err != nil {
return errors.Wrap(err, "cannot put token into bucket")
}
return nil
}); err != nil {
return errors.Wrap(err, "failed to batch token in database")
return errors.Wrap(err, "failed to put token into database")
}
return nil
}
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
result := &domain.Token{
Expiry: time.Time{},
Scopes: []string{},
AccessToken: accessToken,
TokenType: "Bearer",
ClientID: "",
Me: "",
}
result := new(domain.Token)
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
bkt := tx.Bucket(DefaultBucket)
t := new(Token)
bkt := tx.Bucket(t.Bucket())
if bkt == nil {
return ErrNotExist
return token.ErrNotExist
}
expiry := bkt.Get([]byte(accessToken))
if expiry == nil {
return ErrNotExist
src := bkt.Get([]byte(accessToken))
if src == nil {
return token.ErrNotExist
}
if result.Expiry, err = time.Parse(time.RFC3339, string(expiry)); err != nil {
return errors.Wrap(err, "cannot parse expiry date")
if err = t.Bind(src, result); err != nil {
return errors.Wrap(err, "cannot parse token")
}
return nil
}); err != nil {
if xerrors.Is(err, ErrNotExist) {
return nil, nil
}
return nil, errors.Wrap(err, "failed to view token in database")
}
return result, nil
}
func (Token) Bucket() []byte { return []byte("tokens") }
func (t *Token) Populate(accessToken *domain.Token) {
t.AccessToken = accessToken.AccessToken
t.ClientID = accessToken.ClientID
t.CreatedAt = time.Now().UTC().Round(time.Second)
t.Me = accessToken.Me
t.Scopes = make([]string, len(accessToken.Scopes))
t.UpdatedAt = t.CreatedAt
for i := range accessToken.Scopes {
t.Scopes[i] = accessToken.Scopes[i]
}
}
func (t *Token) Bind(src []byte, accessToken *domain.Token) error {
if err := json.Unmarshal(src, t); err != nil {
return errors.Wrap(err, "cannot unmarshal token")
}
accessToken.AccessToken = t.AccessToken
accessToken.ClientID = t.ClientID
accessToken.Me = t.Me
accessToken.Scopes = t.Scopes
return nil
}

View File

@ -2,8 +2,8 @@ package bolt_test
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
@ -19,29 +19,26 @@ import (
func TestCreate(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
return err
}))
repo := repository.NewBoltTokenRepository(db)
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
result := &domain.Token{
Expiry: time.Time{},
Scopes: []string{},
AccessToken: accessToken.AccessToken,
TokenType: accessToken.TokenType,
ClientID: "",
Me: "",
}
result := domain.NewToken()
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
src := tx.Bucket(repository.DefaultBucket).Get([]byte(accessToken.AccessToken))
dto := new(repository.Token)
result.Expiry, err = time.Parse(time.RFC3339, string(src))
return
return dto.Bind(tx.Bucket(repository.Token{}.Bucket()).Get([]byte(accessToken.AccessToken)), result)
}))
assert.Equal(t, accessToken, result)
@ -51,27 +48,33 @@ func TestCreate(t *testing.T) {
func TestGet(t *testing.T) {
t.Parallel()
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
t.Cleanup(cleanup)
repo := repository.NewBoltTokenRepository(db)
accessToken := domain.TestToken(t)
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
bkt, err := tx.CreateBucketIfNotExists(repository.DefaultBucket)
bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket())
if err != nil {
return errors.Wrap(err, "cannot create bucket")
}
err = bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
t := new(repository.Token)
t.Populate(accessToken)
src, err := json.Marshal(t)
if err != nil {
return errors.Wrap(err, "cannot marshal token data")
}
if err = bkt.Put([]byte(t.AccessToken), src); err != nil {
return errors.Wrap(err, "cannot put token into bucket")
}
return nil
}))
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
}

View File

@ -2,9 +2,11 @@ package memory
import (
"context"
"errors"
"path"
"sync"
"time"
"golang.org/x/xerrors"
"source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/token"
@ -16,6 +18,8 @@ type memoryTokenRepository struct {
const DefaultPathPrefix string = "tokens"
var ErrExist error = errors.New("token already exist")
func NewMemoryTokenRepository(store *sync.Map) token.Repository {
return &memoryTokenRepository{
store: store,
@ -23,29 +27,30 @@ func NewMemoryTokenRepository(store *sync.Map) token.Repository {
}
func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
key := path.Join(DefaultPathPrefix, accessToken.AccessToken)
if _, ok := repo.store.Load(key); ok {
return token.ErrExist
t, err := repo.Get(ctx, accessToken.AccessToken)
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
return err
}
repo.store.Store(key, accessToken.Expiry)
if t != nil {
return ErrExist
}
repo.store.Store(path.Join(DefaultPathPrefix, accessToken.AccessToken), accessToken)
return nil
}
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
expiry, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken))
t, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken))
if !ok {
return nil, nil
return nil, token.ErrNotExist
}
return &domain.Token{
Expiry: expiry.(time.Time),
Scopes: []string{},
AccessToken: accessToken,
TokenType: "Bearer",
ClientID: "",
Me: "",
}, nil
result, ok := t.(*domain.Token)
if !ok {
return nil, token.ErrNotExist
}
return result, nil
}

View File

@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/token"
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
)
@ -18,28 +17,27 @@ func TestCreate(t *testing.T) {
t.Parallel()
store := new(sync.Map)
token := domain.TestToken(t)
repo := repository.NewMemoryTokenRepository(store)
require.NoError(t, repo.Create(context.TODO(), token))
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
expiry, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken))
result, ok := store.Load(path.Join(repository.DefaultPathPrefix, token.AccessToken))
assert.True(t, ok)
assert.Equal(t, accessToken.Expiry, expiry)
assert.Equal(t, token, result)
assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error())
assert.ErrorIs(t, repo.Create(context.TODO(), token), repository.ErrExist)
}
func TestGet(t *testing.T) {
t.Parallel()
store := new(sync.Map)
repo := repository.NewMemoryTokenRepository(store)
token := domain.TestToken(t)
accessToken := domain.TestToken(t)
store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken.Expiry)
store.Store(path.Join(repository.DefaultPathPrefix, token.AccessToken), token)
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
result, err := repository.NewMemoryTokenRepository(store).Get(context.TODO(), token.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
assert.Equal(t, token, result)
}

View File

@ -2,6 +2,7 @@ package token
import (
"context"
"errors"
"source.toby3d.me/website/oauth/internal/domain"
)
@ -10,3 +11,5 @@ type UseCase interface {
Verify(ctx context.Context, accessToken string) (*domain.Token, error)
Revoke(ctx context.Context, accessToken string) error
}
var ErrRevoke error = errors.New("this token has been revoked")

View File

@ -7,6 +7,7 @@ import (
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/pkg/errors"
"golang.org/x/xerrors"
"source.toby3d.me/website/oauth/internal/config"
"source.toby3d.me/website/oauth/internal/domain"
@ -26,13 +27,13 @@ func NewTokenUseCase(tokens token.Repository, configer config.UseCase) token.Use
}
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
token, err := useCase.tokens.Get(ctx, accessToken)
if err != nil {
return nil, errors.Wrap(err, "cannot find token in database")
find, err := useCase.tokens.Get(ctx, accessToken)
if err != nil && !xerrors.Is(err, token.ErrNotExist) {
return nil, err
}
if token != nil {
return nil, nil
if find != nil {
return nil, token.ErrRevoke
}
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
@ -47,20 +48,23 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
return nil, errors.Wrap(err, "cannot validate JWT token")
}
token = &domain.Token{
Expiry: t.Expiration(),
Scopes: []string{},
result := &domain.Token{
AccessToken: accessToken,
TokenType: "Bearer",
ClientID: t.Issuer(),
Me: t.Subject(),
Scopes: make([]string, 0),
}
if scope, ok := t.Get("scope"); ok {
token.Scopes = strings.Fields(scope.(string))
rawScope, ok := t.Get("scope")
if !ok {
return result, nil
}
return token, nil
if scope, ok := rawScope.(string); ok {
result.Scopes = strings.Fields(scope)
}
return result, nil
}
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {

View File

@ -12,38 +12,54 @@ import (
configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper"
configucase "source.toby3d.me/website/oauth/internal/config/usecase"
"source.toby3d.me/website/oauth/internal/domain"
"source.toby3d.me/website/oauth/internal/token"
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
"source.toby3d.me/website/oauth/internal/token/usecase"
ucase "source.toby3d.me/website/oauth/internal/token/usecase"
)
func TestVerify(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
v.Set("indieauth.jwtSigningAlgorithm", "HS256")
v.Set("indieauth.jwtSecret", "hackme")
repo := repository.NewMemoryTokenRepository(new(sync.Map))
accessToken := domain.TestToken(t)
useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)))
token, err := usecase.NewTokenUseCase(
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
).Verify(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken.AccessToken, token.AccessToken)
t.Run("valid", func(t *testing.T) {
t.Parallel()
accessToken := domain.TestToken(t)
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
require.NoError(t, err)
assert.Equal(t, accessToken, result)
})
t.Run("revoke", func(t *testing.T) {
t.Parallel()
accessToken := domain.TestToken(t)
require.NoError(t, repo.Create(context.TODO(), accessToken))
result, err := useCase.Verify(context.TODO(), accessToken.AccessToken)
require.ErrorIs(t, err, token.ErrRevoke)
assert.Nil(t, result)
})
}
func TestRevoke(t *testing.T) {
t.Parallel()
v := viper.New()
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
v.SetDefault("indieauth.jwtSecret", "hackme")
v.Set("indieauth.jwtSigningAlgorithm", "HS256")
v.Set("indieauth.jwtSecret", "hackme")
repo := repository.NewMemoryTokenRepository(new(sync.Map))
accessToken := domain.TestToken(t)
require.NoError(t, usecase.NewTokenUseCase(
require.NoError(t, ucase.NewTokenUseCase(
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
).Revoke(context.TODO(), accessToken.AccessToken))