diff --git a/internal/domain/token.go b/internal/domain/token.go index f4eb20e..74b7ed9 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -13,17 +13,15 @@ import ( ) type Token struct { - Expiry time.Time Scopes []string AccessToken string - TokenType string ClientID string Me string } func NewToken() *Token { t := new(Token) - t.Expiry = time.Time{} + t.Scopes = make([]string, 0) return t } @@ -61,9 +59,7 @@ func TestToken(tb testing.TB) *Token { return &Token{ AccessToken: string(accessToken), ClientID: t.Issuer(), - Expiry: t.Expiration(), Me: t.Subject(), Scopes: scopes, - TokenType: "Bearer", } } diff --git a/internal/token/delivery/http/token_http.go b/internal/token/delivery/http/token_http.go index 803e70e..a211a6a 100644 --- a/internal/token/delivery/http/token_http.go +++ b/internal/token/delivery/http/token_http.go @@ -56,23 +56,27 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) { rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization) - token, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer ")))) + t, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer ")))) if err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) + if xerrors.Is(err, token.ErrRevoke) { + ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized) + } else { + ctx.Error(err.Error(), http.StatusBadRequest) + } return } - if token == nil { - ctx.SetStatusCode(http.StatusUnauthorized) + if t == nil { + ctx.Error(http.StatusMessage(http.StatusUnauthorized), http.StatusUnauthorized) return } if err := json.NewEncoder(ctx).Encode(&VerificationResponse{ - ClientID: token.ClientID, - Me: token.Me, - Scope: strings.Join(token.Scopes, " "), + ClientID: t.ClientID, + Me: t.Me, + Scope: strings.Join(t.Scopes, " "), }); err != nil { ctx.Error(err.Error(), http.StatusInternalServerError) } diff --git a/internal/token/repository.go b/internal/token/repository.go index cf3952e..296cffd 100644 --- a/internal/token/repository.go +++ b/internal/token/repository.go @@ -2,8 +2,7 @@ package token import ( "context" - - "golang.org/x/xerrors" + "errors" "source.toby3d.me/website/oauth/internal/domain" ) @@ -13,9 +12,7 @@ type Repository interface { Create(ctx context.Context, accessToken *domain.Token) error } -var ErrExist error = domain.Error{ - Code: "invalid_request", - Description: "this token is already exists", - URI: "", - Frame: xerrors.Caller(1), -} +var ( + ErrExist error = errors.New("token already exist") + ErrNotExist error = errors.New("token not exist") +) diff --git a/internal/token/repository/bolt/bolt_token.go b/internal/token/repository/bolt/bolt_token.go index b7a4db7..b3be95a 100644 --- a/internal/token/repository/bolt/bolt_token.go +++ b/internal/token/repository/bolt/bolt_token.go @@ -2,6 +2,7 @@ package bolt import ( "context" + "encoding/json" "time" "github.com/pkg/errors" @@ -12,13 +13,21 @@ import ( "source.toby3d.me/website/oauth/internal/token" ) -type boltTokenRepository struct { - db *bolt.DB -} +type ( + Token struct { + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + DeletedAt time.Time `json:"deletedAt,omitempty"` + Scopes []string `json:"scopes"` + AccessToken string `json:"accessToken"` + ClientID string `json:"clientId"` + Me string `json:"me"` + } -var ErrNotExist error = errors.New("token not exist") - -var DefaultBucket = []byte("tokens") //nolint: gochecknoglobals + boltTokenRepository struct { + db *bolt.DB + } +) func NewBoltTokenRepository(db *bolt.DB) token.Repository { return &boltTokenRepository{ @@ -26,68 +35,94 @@ func NewBoltTokenRepository(db *bolt.DB) token.Repository { } } -func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - find, err := repo.Get(ctx, accessToken.AccessToken) - if err != nil { - return errors.Wrap(err, "cannot check token in database") +func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) (err error) { + t, err := repo.Get(ctx, accessToken.AccessToken) + if err != nil && !xerrors.Is(err, token.ErrNotExist) { + return errors.Wrap(err, "cannot get token in database") } - if find != nil { + if t != nil { return token.ErrExist } if err = repo.db.Update(func(tx *bolt.Tx) error { - bkt, err := tx.CreateBucketIfNotExists(DefaultBucket) + bkt, err := tx.CreateBucketIfNotExists(Token{}.Bucket()) if err != nil { return errors.Wrap(err, "cannot create bucket") } - err = bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339))) + token := new(Token) + token.Populate(accessToken) + + src, err := json.Marshal(token) if err != nil { + return errors.Wrap(err, "cannot marshal token data") + } + + if err = bkt.Put([]byte(token.AccessToken), src); err != nil { return errors.Wrap(err, "cannot put token into bucket") } return nil }); err != nil { - return errors.Wrap(err, "failed to batch token in database") + return errors.Wrap(err, "failed to put token into database") } return nil } func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - result := &domain.Token{ - Expiry: time.Time{}, - Scopes: []string{}, - AccessToken: accessToken, - TokenType: "Bearer", - ClientID: "", - Me: "", - } + result := new(domain.Token) if err := repo.db.View(func(tx *bolt.Tx) (err error) { - bkt := tx.Bucket(DefaultBucket) + t := new(Token) + + bkt := tx.Bucket(t.Bucket()) if bkt == nil { - return ErrNotExist + return token.ErrNotExist } - expiry := bkt.Get([]byte(accessToken)) - if expiry == nil { - return ErrNotExist + src := bkt.Get([]byte(accessToken)) + if src == nil { + return token.ErrNotExist } - if result.Expiry, err = time.Parse(time.RFC3339, string(expiry)); err != nil { - return errors.Wrap(err, "cannot parse expiry date") + if err = t.Bind(src, result); err != nil { + return errors.Wrap(err, "cannot parse token") } return nil }); err != nil { - if xerrors.Is(err, ErrNotExist) { - return nil, nil - } - return nil, errors.Wrap(err, "failed to view token in database") } return result, nil } + +func (Token) Bucket() []byte { return []byte("tokens") } + +func (t *Token) Populate(accessToken *domain.Token) { + t.AccessToken = accessToken.AccessToken + t.ClientID = accessToken.ClientID + t.CreatedAt = time.Now().UTC().Round(time.Second) + t.Me = accessToken.Me + t.Scopes = make([]string, len(accessToken.Scopes)) + t.UpdatedAt = t.CreatedAt + + for i := range accessToken.Scopes { + t.Scopes[i] = accessToken.Scopes[i] + } +} + +func (t *Token) Bind(src []byte, accessToken *domain.Token) error { + if err := json.Unmarshal(src, t); err != nil { + return errors.Wrap(err, "cannot unmarshal token") + } + + accessToken.AccessToken = t.AccessToken + accessToken.ClientID = t.ClientID + accessToken.Me = t.Me + accessToken.Scopes = t.Scopes + + return nil +} diff --git a/internal/token/repository/bolt/bolt_token_test.go b/internal/token/repository/bolt/bolt_token_test.go index 0fd2cce..21c5e23 100644 --- a/internal/token/repository/bolt/bolt_token_test.go +++ b/internal/token/repository/bolt/bolt_token_test.go @@ -2,8 +2,8 @@ package bolt_test import ( "context" + "encoding/json" "testing" - "time" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -19,29 +19,26 @@ import ( func TestCreate(t *testing.T) { t.Parallel() - db, cleanup := util.TestBolt(t, repository.DefaultBucket) + db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) t.Cleanup(cleanup) + require.NoError(t, db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket()) + + return err + })) + repo := repository.NewBoltTokenRepository(db) accessToken := domain.TestToken(t) require.NoError(t, repo.Create(context.TODO(), accessToken)) - result := &domain.Token{ - Expiry: time.Time{}, - Scopes: []string{}, - AccessToken: accessToken.AccessToken, - TokenType: accessToken.TokenType, - ClientID: "", - Me: "", - } + result := domain.NewToken() require.NoError(t, db.View(func(tx *bolt.Tx) (err error) { - src := tx.Bucket(repository.DefaultBucket).Get([]byte(accessToken.AccessToken)) + dto := new(repository.Token) - result.Expiry, err = time.Parse(time.RFC3339, string(src)) - - return + return dto.Bind(tx.Bucket(repository.Token{}.Bucket()).Get([]byte(accessToken.AccessToken)), result) })) assert.Equal(t, accessToken, result) @@ -51,27 +48,33 @@ func TestCreate(t *testing.T) { func TestGet(t *testing.T) { t.Parallel() - db, cleanup := util.TestBolt(t, repository.DefaultBucket) + db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) t.Cleanup(cleanup) - repo := repository.NewBoltTokenRepository(db) accessToken := domain.TestToken(t) require.NoError(t, db.Update(func(tx *bolt.Tx) error { - bkt, err := tx.CreateBucketIfNotExists(repository.DefaultBucket) + bkt, err := tx.CreateBucketIfNotExists(repository.Token{}.Bucket()) if err != nil { return errors.Wrap(err, "cannot create bucket") } - err = bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339))) + t := new(repository.Token) + t.Populate(accessToken) + + src, err := json.Marshal(t) if err != nil { + return errors.Wrap(err, "cannot marshal token data") + } + + if err = bkt.Put([]byte(t.AccessToken), src); err != nil { return errors.Wrap(err, "cannot put token into bucket") } return nil })) - result, err := repo.Get(context.TODO(), accessToken.AccessToken) + result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken) assert.NoError(t, err) assert.Equal(t, accessToken, result) } diff --git a/internal/token/repository/memory/memory_token.go b/internal/token/repository/memory/memory_token.go index 715baab..5b76efa 100644 --- a/internal/token/repository/memory/memory_token.go +++ b/internal/token/repository/memory/memory_token.go @@ -2,9 +2,11 @@ package memory import ( "context" + "errors" "path" "sync" - "time" + + "golang.org/x/xerrors" "source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/token" @@ -16,6 +18,8 @@ type memoryTokenRepository struct { const DefaultPathPrefix string = "tokens" +var ErrExist error = errors.New("token already exist") + func NewMemoryTokenRepository(store *sync.Map) token.Repository { return &memoryTokenRepository{ store: store, @@ -23,29 +27,30 @@ func NewMemoryTokenRepository(store *sync.Map) token.Repository { } func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - key := path.Join(DefaultPathPrefix, accessToken.AccessToken) - - if _, ok := repo.store.Load(key); ok { - return token.ErrExist + t, err := repo.Get(ctx, accessToken.AccessToken) + if err != nil && !xerrors.Is(err, token.ErrNotExist) { + return err } - repo.store.Store(key, accessToken.Expiry) + if t != nil { + return ErrExist + } + + repo.store.Store(path.Join(DefaultPathPrefix, accessToken.AccessToken), accessToken) return nil } func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - expiry, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken)) + t, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken)) if !ok { - return nil, nil + return nil, token.ErrNotExist } - return &domain.Token{ - Expiry: expiry.(time.Time), - Scopes: []string{}, - AccessToken: accessToken, - TokenType: "Bearer", - ClientID: "", - Me: "", - }, nil + result, ok := t.(*domain.Token) + if !ok { + return nil, token.ErrNotExist + } + + return result, nil } diff --git a/internal/token/repository/memory/memory_token_test.go b/internal/token/repository/memory/memory_token_test.go index e0ebac1..42f8b16 100644 --- a/internal/token/repository/memory/memory_token_test.go +++ b/internal/token/repository/memory/memory_token_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/require" "source.toby3d.me/website/oauth/internal/domain" - "source.toby3d.me/website/oauth/internal/token" repository "source.toby3d.me/website/oauth/internal/token/repository/memory" ) @@ -18,28 +17,27 @@ func TestCreate(t *testing.T) { t.Parallel() store := new(sync.Map) + token := domain.TestToken(t) + repo := repository.NewMemoryTokenRepository(store) + require.NoError(t, repo.Create(context.TODO(), token)) - accessToken := domain.TestToken(t) - require.NoError(t, repo.Create(context.TODO(), accessToken)) - - expiry, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken)) + result, ok := store.Load(path.Join(repository.DefaultPathPrefix, token.AccessToken)) assert.True(t, ok) - assert.Equal(t, accessToken.Expiry, expiry) + assert.Equal(t, token, result) - assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error()) + assert.ErrorIs(t, repo.Create(context.TODO(), token), repository.ErrExist) } func TestGet(t *testing.T) { t.Parallel() store := new(sync.Map) - repo := repository.NewMemoryTokenRepository(store) + token := domain.TestToken(t) - accessToken := domain.TestToken(t) - store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken.Expiry) + store.Store(path.Join(repository.DefaultPathPrefix, token.AccessToken), token) - result, err := repo.Get(context.TODO(), accessToken.AccessToken) + result, err := repository.NewMemoryTokenRepository(store).Get(context.TODO(), token.AccessToken) assert.NoError(t, err) - assert.Equal(t, accessToken, result) + assert.Equal(t, token, result) } diff --git a/internal/token/usecase.go b/internal/token/usecase.go index 0f9b388..1efefc7 100644 --- a/internal/token/usecase.go +++ b/internal/token/usecase.go @@ -2,6 +2,7 @@ package token import ( "context" + "errors" "source.toby3d.me/website/oauth/internal/domain" ) @@ -10,3 +11,5 @@ type UseCase interface { Verify(ctx context.Context, accessToken string) (*domain.Token, error) Revoke(ctx context.Context, accessToken string) error } + +var ErrRevoke error = errors.New("this token has been revoked") diff --git a/internal/token/usecase/token_usecase.go b/internal/token/usecase/token_ucase.go similarity index 77% rename from internal/token/usecase/token_usecase.go rename to internal/token/usecase/token_ucase.go index 11327f7..2678f5b 100644 --- a/internal/token/usecase/token_usecase.go +++ b/internal/token/usecase/token_ucase.go @@ -7,6 +7,7 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" "github.com/pkg/errors" + "golang.org/x/xerrors" "source.toby3d.me/website/oauth/internal/config" "source.toby3d.me/website/oauth/internal/domain" @@ -26,13 +27,13 @@ func NewTokenUseCase(tokens token.Repository, configer config.UseCase) token.Use } func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) { - token, err := useCase.tokens.Get(ctx, accessToken) - if err != nil { - return nil, errors.Wrap(err, "cannot find token in database") + find, err := useCase.tokens.Get(ctx, accessToken) + if err != nil && !xerrors.Is(err, token.ErrNotExist) { + return nil, err } - if token != nil { - return nil, nil + if find != nil { + return nil, token.ErrRevoke } t, err := jwt.ParseString(accessToken, jwt.WithVerify( @@ -47,20 +48,23 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d return nil, errors.Wrap(err, "cannot validate JWT token") } - token = &domain.Token{ - Expiry: t.Expiration(), - Scopes: []string{}, + result := &domain.Token{ AccessToken: accessToken, - TokenType: "Bearer", ClientID: t.Issuer(), Me: t.Subject(), + Scopes: make([]string, 0), } - if scope, ok := t.Get("scope"); ok { - token.Scopes = strings.Fields(scope.(string)) + rawScope, ok := t.Get("scope") + if !ok { + return result, nil } - return token, nil + if scope, ok := rawScope.(string); ok { + result.Scopes = strings.Fields(scope) + } + + return result, nil } func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error { diff --git a/internal/token/usecase/token_usecase_test.go b/internal/token/usecase/token_ucase_test.go similarity index 51% rename from internal/token/usecase/token_usecase_test.go rename to internal/token/usecase/token_ucase_test.go index c43c5ea..3adb48f 100644 --- a/internal/token/usecase/token_usecase_test.go +++ b/internal/token/usecase/token_ucase_test.go @@ -12,38 +12,54 @@ import ( configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper" configucase "source.toby3d.me/website/oauth/internal/config/usecase" "source.toby3d.me/website/oauth/internal/domain" + "source.toby3d.me/website/oauth/internal/token" repository "source.toby3d.me/website/oauth/internal/token/repository/memory" - "source.toby3d.me/website/oauth/internal/token/usecase" + ucase "source.toby3d.me/website/oauth/internal/token/usecase" ) func TestVerify(t *testing.T) { t.Parallel() v := viper.New() - v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") - v.SetDefault("indieauth.jwtSecret", "hackme") + v.Set("indieauth.jwtSigningAlgorithm", "HS256") + v.Set("indieauth.jwtSecret", "hackme") repo := repository.NewMemoryTokenRepository(new(sync.Map)) - accessToken := domain.TestToken(t) + useCase := ucase.NewTokenUseCase(repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))) - token, err := usecase.NewTokenUseCase( - repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), - ).Verify(context.TODO(), accessToken.AccessToken) - require.NoError(t, err) - assert.Equal(t, accessToken.AccessToken, token.AccessToken) + t.Run("valid", func(t *testing.T) { + t.Parallel() + + accessToken := domain.TestToken(t) + + result, err := useCase.Verify(context.TODO(), accessToken.AccessToken) + require.NoError(t, err) + assert.Equal(t, accessToken, result) + }) + + t.Run("revoke", func(t *testing.T) { + t.Parallel() + + accessToken := domain.TestToken(t) + require.NoError(t, repo.Create(context.TODO(), accessToken)) + + result, err := useCase.Verify(context.TODO(), accessToken.AccessToken) + require.ErrorIs(t, err, token.ErrRevoke) + assert.Nil(t, result) + }) } func TestRevoke(t *testing.T) { t.Parallel() v := viper.New() - v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") - v.SetDefault("indieauth.jwtSecret", "hackme") + v.Set("indieauth.jwtSigningAlgorithm", "HS256") + v.Set("indieauth.jwtSecret", "hackme") repo := repository.NewMemoryTokenRepository(new(sync.Map)) accessToken := domain.TestToken(t) - require.NoError(t, usecase.NewTokenUseCase( + require.NoError(t, ucase.NewTokenUseCase( repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), ).Revoke(context.TODO(), accessToken.AccessToken))