diff --git a/internal/domain/token.go b/internal/domain/token.go index 66195a5..7ceaaad 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -1,37 +1,64 @@ package domain import ( + "strings" "testing" + "time" + + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + "github.com/stretchr/testify/require" "source.toby3d.me/website/oauth/internal/random" ) type Token struct { + Expiry time.Time + Scopes []string AccessToken string + TokenType string ClientID string Me string - Profile *Profile - Scopes []string - Type string } func NewToken() *Token { t := new(Token) - t.Scopes = make([]string, 0) + t.Expiry = time.Time{} return t } -//nolint: gomnd func TestToken(tb testing.TB) *Token { tb.Helper() + client := TestClient(tb) + profile := TestProfile(tb) + now := time.Now().UTC().Round(time.Second) + scopes := []string{"create", "update", "delete"} + t := jwt.New() + + // required + t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id + t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me + // TODO(toby3d): t.Set(jwt.AudienceKey, nil) + t.Set(jwt.ExpirationKey, now.Add(1*time.Hour)) + t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour)) + t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) + // TODO(toby3d): t.Set(jwt.JwtIDKey, nil) + + // optional + t.Set("scope", strings.Join(scopes, " ")) + t.Set("nonce", random.New().String(32)) + + accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme")) + require.NoError(tb, err) + return &Token{ - AccessToken: random.New().String(32), - ClientID: "https://app.example.com/", - Me: "https://user.example.net/", - Profile: TestProfile(tb), - Scopes: []string{"create", "update", "delete"}, - Type: "Bearer", + AccessToken: string(accessToken), + ClientID: t.Issuer(), + Expiry: t.Expiration(), + Me: t.Subject(), + Scopes: scopes, + TokenType: "Bearer", } } diff --git a/internal/token/delivery/http/token_http.go b/internal/token/delivery/http/token_http.go index 875232e..70677cc 100644 --- a/internal/token/delivery/http/token_http.go +++ b/internal/token/delivery/http/token_http.go @@ -15,10 +15,6 @@ import ( ) type ( - RequestHandler struct { - useCase token.UseCase - } - RevocationRequest struct { Action string Token string @@ -32,6 +28,10 @@ type ( } RevocationResponse struct{} + + RequestHandler struct { + tokener token.UseCase + } ) const ( @@ -39,9 +39,9 @@ const ( ActionRevoke string = "revoke" ) -func NewRequestHandler(useCase token.UseCase) *RequestHandler { +func NewRequestHandler(tokener token.UseCase) *RequestHandler { return &RequestHandler{ - useCase: useCase, + tokener: tokener, } } @@ -54,16 +54,11 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) { ctx.SetContentType(common.MIMEApplicationJSON) ctx.SetStatusCode(http.StatusOK) - encoder := json.NewEncoder(ctx) rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization) - token, err := h.useCase.Verify(ctx, string(bytes.TrimSpace(bytes.TrimPrefix(rawToken, []byte("Bearer"))))) + token, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer ")))) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) - - if err = encoder.Encode(err); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } + ctx.Error(err.Error(), http.StatusBadRequest) return } @@ -74,16 +69,12 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) { return } - if err := encoder.Encode(&VerificationResponse{ - Me: token.Me, + if err := json.NewEncoder(ctx).Encode(&VerificationResponse{ ClientID: token.ClientID, + Me: token.Me, Scope: strings.Join(token.Scopes, " "), }); err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) - - if err = encoder.Encode(err); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } + ctx.Error(err.Error(), http.StatusInternalServerError) } } @@ -102,30 +93,21 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) { req := new(RevocationRequest) if err := req.bind(ctx); err != nil { ctx.SetStatusCode(http.StatusBadRequest) - - if err = encoder.Encode(err); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } + encoder.Encode(err) return } - if err := h.useCase.Revoke(ctx, req.Token); err != nil { + if err := h.tokener.Revoke(ctx, req.Token); err != nil { ctx.SetStatusCode(http.StatusBadRequest) - - if err = encoder.Encode(err); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } + encoder.Encode(err) return } if err := encoder.Encode(&RevocationResponse{}); err != nil { ctx.SetStatusCode(http.StatusInternalServerError) - - if err = encoder.Encode(err); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - } + encoder.Encode(err) } } diff --git a/internal/token/delivery/http/token_http_test.go b/internal/token/delivery/http/token_http_test.go index fa1ca22..d8c53fc 100644 --- a/internal/token/delivery/http/token_http_test.go +++ b/internal/token/delivery/http/token_http_test.go @@ -7,11 +7,14 @@ import ( "testing" "github.com/goccy/go-json" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" http "github.com/valyala/fasthttp" "source.toby3d.me/website/oauth/internal/common" + configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper" + configucase "source.toby3d.me/website/oauth/internal/config/usecase" "source.toby3d.me/website/oauth/internal/domain" delivery "source.toby3d.me/website/oauth/internal/token/delivery/http" repository "source.toby3d.me/website/oauth/internal/token/repository/memory" @@ -22,12 +25,16 @@ import ( func TestVerification(t *testing.T) { t.Parallel() - repo := repository.NewMemoryTokenRepository(new(sync.Map)) + v := viper.New() + v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") + v.SetDefault("indieauth.jwtSecret", "hackme") + accessToken := domain.TestToken(t) - require.NoError(t, repo.Create(context.TODO(), accessToken)) - - client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Read) + client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase( + repository.NewMemoryTokenRepository(new(sync.Map)), + configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), + )).Read) t.Cleanup(cleanup) req := http.AcquireRequest() @@ -56,12 +63,16 @@ func TestVerification(t *testing.T) { func TestRevocation(t *testing.T) { t.Parallel() - repo := repository.NewMemoryTokenRepository(new(sync.Map)) + v := viper.New() + v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") + v.SetDefault("indieauth.jwtSecret", "hackme") + + tokens := repository.NewMemoryTokenRepository(new(sync.Map)) accessToken := domain.TestToken(t) - require.NoError(t, repo.Create(context.TODO(), domain.TestToken(t))) - - client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Update) + client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler( + usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))), + ).Update) t.Cleanup(cleanup) req := http.AcquireRequest() @@ -81,7 +92,7 @@ func TestRevocation(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode()) assert.Equal(t, `{}`, strings.TrimSpace(string(resp.Body()))) - token, err := repo.Get(context.TODO(), accessToken.AccessToken) + result, err := tokens.Get(context.TODO(), accessToken.AccessToken) require.NoError(t, err) - assert.Nil(t, token) + assert.Equal(t, accessToken.AccessToken, result.AccessToken) } diff --git a/internal/token/repository.go b/internal/token/repository.go index bddf5c3..cf3952e 100644 --- a/internal/token/repository.go +++ b/internal/token/repository.go @@ -11,8 +11,6 @@ import ( type Repository interface { Get(ctx context.Context, accessToken string) (*domain.Token, error) Create(ctx context.Context, accessToken *domain.Token) error - Update(ctx context.Context, accessToken *domain.Token) error - Remove(ctx context.Context, accessToken string) error } var ErrExist error = domain.Error{ diff --git a/internal/token/repository/bolt/bolt_token.go b/internal/token/repository/bolt/bolt_token.go index 6900cce..99cccc5 100644 --- a/internal/token/repository/bolt/bolt_token.go +++ b/internal/token/repository/bolt/bolt_token.go @@ -2,9 +2,8 @@ package bolt import ( "context" - "strings" + "time" - json "github.com/goccy/go-json" "github.com/pkg/errors" bolt "go.etcd.io/bbolt" "golang.org/x/xerrors" @@ -13,21 +12,13 @@ import ( "source.toby3d.me/website/oauth/internal/token" ) -type ( - Token struct { - AccessToken string `json:"accessToken"` - ClientID string `json:"clientId"` - Me string `json:"me"` - Scope string `json:"scope"` - Type string `json:"type"` - } +type boltTokenRepository struct { + db *bolt.DB +} - boltTokenRepository struct { - db *bolt.DB - } -) +var ErrNotExist error = errors.New("token not exist") -var ErrNotExist error = xerrors.New("key not exist") +var DefaultBucket []byte = []byte("tokens") func NewBoltTokenRepository(db *bolt.DB) token.Repository { return &boltTokenRepository{ @@ -35,97 +26,60 @@ func NewBoltTokenRepository(db *bolt.DB) token.Repository { } } -func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - result := domain.NewToken() +func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { + find, err := repo.Get(ctx, accessToken.AccessToken) + if err != nil { + return errors.Wrap(err, "cannot check token in database") + } - if err := repo.db.View(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - if src := tx.Bucket(Token{}.Bucket()).Get([]byte(accessToken)); src != nil { - return new(Token).Bind(src, result) + if find != nil { + return token.ErrExist + } + + if err = repo.db.Update(func(tx *bolt.Tx) error { + bkt, err := tx.CreateBucketIfNotExists(DefaultBucket) + if err != nil { + return errors.Wrap(err, "cannot create bucket") } - return ErrNotExist + return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339))) }); err != nil { - if !xerrors.Is(err, ErrNotExist) { - return nil, errors.Wrap(err, "failed to retrieve token from storage") + return errors.Wrap(err, "failed to batch token in database") + } + + return nil +} + +func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { + result := &domain.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + Expiry: time.Time{}, + } + + if err := repo.db.View(func(tx *bolt.Tx) (err error) { + bkt := tx.Bucket(DefaultBucket) + if bkt == nil { + return ErrNotExist } - return nil, nil + expiry := bkt.Get([]byte(accessToken)) + if expiry == nil { + return ErrNotExist + } + + if result.Expiry, err = time.Parse(time.RFC3339, string(expiry)); err != nil { + return err + } + + return nil + }); err != nil { + if xerrors.Is(err, ErrNotExist) { + return nil, nil + } + + return nil, errors.Wrap(err, "failed to view token in database") } return result, nil } - -func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - t, err := repo.Get(ctx, accessToken.AccessToken) - if err != nil { - return errors.Wrap(err, "failed to verify the existence of the token") - } - - if t != nil { - return token.ErrExist - } - - return repo.Update(ctx, accessToken) -} - -func (repo *boltTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error { - dto := new(Token) - dto.Populate(accessToken) - - src, err := json.Marshal(dto) - if err != nil { - return errors.Wrap(err, "failed to marshal token") - } - - if err = repo.db.Update(func(tx *bolt.Tx) error { - if err := tx.Bucket(dto.Bucket()).Put([]byte(dto.AccessToken), src); err != nil { - return errors.Wrap(err, "failed to overwrite the token in the bucket") - } - - return nil - }); err != nil { - return errors.Wrap(err, "failed to update the token in the repository") - } - - return nil -} - -func (repo *boltTokenRepository) Remove(ctx context.Context, accessToken string) error { - if err := repo.db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - if err := tx.Bucket(Token{}.Bucket()).Delete([]byte(accessToken)); err != nil { - return errors.Wrap(err, "failed to remove token in bucket") - } - - return nil - }); err != nil { - return errors.Wrap(err, "failed to remove token from storage") - } - - return nil -} - -func (Token) Bucket() []byte { return []byte("tokens") } - -func (t *Token) Populate(src *domain.Token) { - t.AccessToken = src.AccessToken - t.ClientID = src.ClientID - t.Me = src.Me - t.Scope = strings.Join(src.Scopes, " ") - t.Type = src.Type -} - -func (t *Token) Bind(src []byte, dst *domain.Token) error { - if err := json.Unmarshal(src, t); err != nil { - return errors.Wrap(err, "cannot unmarshal token source") - } - - dst.AccessToken = t.AccessToken - dst.Scopes = strings.Fields(t.Scope) - dst.Type = t.Type - dst.ClientID = t.ClientID - dst.Me = t.Me - - return nil -} diff --git a/internal/token/repository/bolt/bolt_token_test.go b/internal/token/repository/bolt/bolt_token_test.go index bcc265c..9c1a5c6 100644 --- a/internal/token/repository/bolt/bolt_token_test.go +++ b/internal/token/repository/bolt/bolt_token_test.go @@ -1,11 +1,10 @@ -//nolint: wrapcheck package bolt_test import ( "context" "testing" + "time" - json "github.com/goccy/go-json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" bolt "go.etcd.io/bbolt" @@ -16,120 +15,55 @@ import ( "source.toby3d.me/website/oauth/internal/util" ) -func TestGet(t *testing.T) { - t.Parallel() - - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) - t.Cleanup(cleanup) - - accessToken := domain.TestToken(t) - accessToken.Profile = nil - - dto := new(repository.Token) - dto.Populate(accessToken) - - src, err := json.Marshal(dto) - require.NoError(t, err) - - require.NoError(t, db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src) - })) - - result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken) - require.NoError(t, err) - assert.Equal(t, accessToken, result) -} - func TestCreate(t *testing.T) { t.Parallel() - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) + db, cleanup := util.TestBolt(t, repository.DefaultBucket) t.Cleanup(cleanup) repo := repository.NewBoltTokenRepository(db) accessToken := domain.TestToken(t) - accessToken.Profile = nil require.NoError(t, repo.Create(context.TODO(), accessToken)) - result := new(domain.Token) + result := &domain.Token{ + AccessToken: accessToken.AccessToken, + TokenType: accessToken.TokenType, + Expiry: time.Time{}, + } - require.NoError(t, db.View(func(tx *bolt.Tx) error { + require.NoError(t, db.View(func(tx *bolt.Tx) (err error) { //nolint: exhaustivestruct - return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()). - Get([]byte(accessToken.AccessToken)), result) - })) + src := tx.Bucket(repository.DefaultBucket).Get([]byte(accessToken.AccessToken)) + result.Expiry, err = time.Parse(time.RFC3339, string(src)) + + return + })) assert.Equal(t, accessToken, result) - assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error()) + + assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist) } -func TestUpdate(t *testing.T) { +func TestGet(t *testing.T) { t.Parallel() - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) + db, cleanup := util.TestBolt(t, repository.DefaultBucket) t.Cleanup(cleanup) + repo := repository.NewBoltTokenRepository(db) accessToken := domain.TestToken(t) - src, err := json.Marshal(accessToken) - require.NoError(t, err) - - //nolint: exhaustivestruct require.NoError(t, db.Update(func(tx *bolt.Tx) error { - return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src) + bkt, err := tx.CreateBucketIfNotExists(repository.DefaultBucket) + if err != nil { + return err + } + + return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339))) })) - require.NoError(t, repository.NewBoltTokenRepository(db).Update(context.TODO(), &domain.Token{ - AccessToken: accessToken.AccessToken, - ClientID: "https://client.example.net/", - Me: "https://toby3d.ru/", - Scopes: []string{"read"}, - Type: "Bearer", - Profile: nil, - })) - - result := domain.NewToken() - - //nolint: exhaustivestruct - require.NoError(t, db.View(func(tx *bolt.Tx) error { - return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()). - Get([]byte(accessToken.AccessToken)), result) - })) - - assert.Equal(t, &domain.Token{ - AccessToken: accessToken.AccessToken, - ClientID: "https://client.example.net/", - Me: "https://toby3d.ru/", - Scopes: []string{"read"}, - Type: "Bearer", - Profile: nil, - }, result) -} - -func TestDelete(t *testing.T) { - t.Parallel() - - db, cleanup := util.TestBolt(t, repository.Token{}.Bucket()) - t.Cleanup(cleanup) - - accessToken := domain.TestToken(t) - - src, err := json.Marshal(accessToken) - require.NoError(t, err) - - require.NoError(t, db.Update(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src) - })) - - require.NoError(t, repository.NewBoltTokenRepository(db).Remove(context.TODO(), accessToken.AccessToken)) - - require.NoError(t, db.View(func(tx *bolt.Tx) error { - //nolint: exhaustivestruct - assert.Nil(t, tx.Bucket(repository.Token{}.Bucket()).Get([]byte(accessToken.AccessToken))) - - return nil - })) + result, err := repo.Get(context.TODO(), accessToken.AccessToken) + assert.NoError(t, err) + assert.Equal(t, accessToken, result) } diff --git a/internal/token/repository/memory/memory_token.go b/internal/token/repository/memory/memory_token.go index 2d991fe..cb88121 100644 --- a/internal/token/repository/memory/memory_token.go +++ b/internal/token/repository/memory/memory_token.go @@ -4,58 +4,45 @@ import ( "context" "path" "sync" + "time" "source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/token" ) type memoryTokenRepository struct { - tokens *sync.Map + store *sync.Map } -const Key string = "tokens" +const DefaultPathPrefix string = "tokens" -func NewMemoryTokenRepository(tokens *sync.Map) token.Repository { +func NewMemoryTokenRepository(store *sync.Map) token.Repository { return &memoryTokenRepository{ - tokens: tokens, + store: store, } } -func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - src, ok := repo.tokens.Load(path.Join(Key, accessToken)) - if !ok { - return nil, nil - } - - result, ok := src.(*domain.Token) - if !ok { - return nil, nil - } - - return result, nil -} - func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - t, err := repo.Get(ctx, accessToken.AccessToken) - if err != nil { - return err - } + key := path.Join(DefaultPathPrefix, accessToken.AccessToken) - if t != nil { + if _, ok := repo.store.Load(key); ok { return token.ErrExist } - return repo.Update(ctx, accessToken) -} - -func (repo *memoryTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error { - repo.tokens.Store(path.Join(Key, accessToken.AccessToken), accessToken) + repo.store.Store(key, accessToken.Expiry) return nil } -func (repo *memoryTokenRepository) Remove(ctx context.Context, accessToken string) error { - repo.tokens.Delete(path.Join(Key, accessToken)) +func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { + expiry, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken)) + if !ok { + return nil, nil + } - return nil + return &domain.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + Expiry: expiry.(time.Time), + }, nil } diff --git a/internal/token/repository/memory/memory_token_test.go b/internal/token/repository/memory/memory_token_test.go index 936e126..e0ebac1 100644 --- a/internal/token/repository/memory/memory_token_test.go +++ b/internal/token/repository/memory/memory_token_test.go @@ -14,66 +14,32 @@ import ( repository "source.toby3d.me/website/oauth/internal/token/repository/memory" ) -func TestGet(t *testing.T) { - t.Parallel() - - store := new(sync.Map) - accessToken := domain.TestToken(t) - - store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken) - - result, err := repository.NewMemoryTokenRepository(store).Get(context.TODO(), accessToken.AccessToken) - require.NoError(t, err) - assert.Equal(t, accessToken, result) -} - func TestCreate(t *testing.T) { t.Parallel() store := new(sync.Map) - accessToken := domain.TestToken(t) - repo := repository.NewMemoryTokenRepository(store) + + accessToken := domain.TestToken(t) require.NoError(t, repo.Create(context.TODO(), accessToken)) - result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken)) + expiry, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken)) assert.True(t, ok) - assert.Equal(t, accessToken, result) + assert.Equal(t, accessToken.Expiry, expiry) assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error()) } -func TestUpdate(t *testing.T) { +func TestGet(t *testing.T) { t.Parallel() store := new(sync.Map) + repo := repository.NewMemoryTokenRepository(store) + accessToken := domain.TestToken(t) + store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken.Expiry) - store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken) - - tokenCopy := *accessToken - tokenCopy.ClientID = "https://client.example.com/" - tokenCopy.Me = "https://toby3d.ru/" - - require.NoError(t, repository.NewMemoryTokenRepository(store).Update(context.TODO(), &tokenCopy)) - - result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken)) - assert.True(t, ok) - assert.NotEqual(t, accessToken, result) - assert.Equal(t, &tokenCopy, result) -} - -func TestDelete(t *testing.T) { - t.Parallel() - - store := new(sync.Map) - accessToken := domain.TestToken(t) - - store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken) - - require.NoError(t, repository.NewMemoryTokenRepository(store).Remove(context.TODO(), accessToken.AccessToken)) - - result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken)) - assert.False(t, ok) - assert.Nil(t, result) + result, err := repo.Get(context.TODO(), accessToken.AccessToken) + assert.NoError(t, err) + assert.Equal(t, accessToken, result) } diff --git a/internal/token/usecase/token_usecase.go b/internal/token/usecase/token_usecase.go index 669bf44..11327f7 100644 --- a/internal/token/usecase/token_usecase.go +++ b/internal/token/usecase/token_usecase.go @@ -2,35 +2,75 @@ package usecase import ( "context" + "strings" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" "github.com/pkg/errors" + "source.toby3d.me/website/oauth/internal/config" "source.toby3d.me/website/oauth/internal/domain" "source.toby3d.me/website/oauth/internal/token" ) type tokenUseCase struct { - tokens token.Repository + tokens token.Repository + configer config.UseCase } -func NewTokenUseCase(tokens token.Repository) token.UseCase { +func NewTokenUseCase(tokens token.Repository, configer config.UseCase) token.UseCase { return &tokenUseCase{ - tokens: tokens, + tokens: tokens, + configer: configer, } } func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) { - t, err := useCase.tokens.Get(ctx, accessToken) + token, err := useCase.tokens.Get(ctx, accessToken) if err != nil { - return nil, errors.Wrap(err, "failed to retrieve token from storage") + return nil, errors.Wrap(err, "cannot find token in database") } - return t, nil + if token != nil { + return nil, nil + } + + t, err := jwt.ParseString(accessToken, jwt.WithVerify( + jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()), + []byte(useCase.configer.GetIndieAuthJWTSecret()), + )) + if err != nil { + return nil, errors.Wrap(err, "cannot parse JWT token") + } + + if err = jwt.Validate(t); err != nil { + return nil, errors.Wrap(err, "cannot validate JWT token") + } + + token = &domain.Token{ + Expiry: t.Expiration(), + Scopes: []string{}, + AccessToken: accessToken, + TokenType: "Bearer", + ClientID: t.Issuer(), + Me: t.Subject(), + } + + if scope, ok := t.Get("scope"); ok { + token.Scopes = strings.Fields(scope.(string)) + } + + return token, nil } func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error { - if err := useCase.tokens.Remove(ctx, accessToken); err != nil { - return errors.Wrap(err, "failed to delete a token in the vault") + t, err := useCase.Verify(ctx, accessToken) + if err != nil { + return errors.Wrap(err, "cannot verify token") + } + + if err = useCase.tokens.Create(ctx, t); err != nil { + return errors.Wrap(err, "cannot save token in database") } return nil diff --git a/internal/token/usecase/token_usecase_test.go b/internal/token/usecase/token_usecase_test.go index b9477f0..c43c5ea 100644 --- a/internal/token/usecase/token_usecase_test.go +++ b/internal/token/usecase/token_usecase_test.go @@ -5,9 +5,12 @@ import ( "sync" "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper" + configucase "source.toby3d.me/website/oauth/internal/config/usecase" "source.toby3d.me/website/oauth/internal/domain" repository "source.toby3d.me/website/oauth/internal/token/repository/memory" "source.toby3d.me/website/oauth/internal/token/usecase" @@ -16,31 +19,35 @@ import ( func TestVerify(t *testing.T) { t.Parallel() + v := viper.New() + v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") + v.SetDefault("indieauth.jwtSecret", "hackme") + repo := repository.NewMemoryTokenRepository(new(sync.Map)) - accessToken := domain.NewToken() + accessToken := domain.TestToken(t) - require.NoError(t, repo.Create(context.TODO(), accessToken)) - - token, err := usecase.NewTokenUseCase(repo).Verify(context.TODO(), accessToken.AccessToken) + token, err := usecase.NewTokenUseCase( + repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), + ).Verify(context.TODO(), accessToken.AccessToken) require.NoError(t, err) - assert.Equal(t, accessToken, token) + assert.Equal(t, accessToken.AccessToken, token.AccessToken) } func TestRevoke(t *testing.T) { t.Parallel() + v := viper.New() + v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256") + v.SetDefault("indieauth.jwtSecret", "hackme") + repo := repository.NewMemoryTokenRepository(new(sync.Map)) accessToken := domain.TestToken(t) - require.NoError(t, repo.Create(context.TODO(), accessToken)) + require.NoError(t, usecase.NewTokenUseCase( + repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)), + ).Revoke(context.TODO(), accessToken.AccessToken)) - token, err := repo.Get(context.TODO(), accessToken.AccessToken) - require.NoError(t, err) - assert.NotNil(t, token) - - require.NoError(t, usecase.NewTokenUseCase(repo).Revoke(context.TODO(), token.AccessToken)) - - token, err = repo.Get(context.TODO(), token.AccessToken) - require.NoError(t, err) - assert.Nil(t, token) + result, err := repo.Get(context.TODO(), accessToken.AccessToken) + assert.NoError(t, err) + assert.Equal(t, accessToken.AccessToken, result.AccessToken) }