♻️ Refactored token package, store only revoked tokens
This commit is contained in:
parent
dd43f0bf22
commit
4f63db7bfa
|
@ -1,37 +1,64 @@
|
|||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/random"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Expiry time.Time
|
||||
Scopes []string
|
||||
AccessToken string
|
||||
TokenType string
|
||||
ClientID string
|
||||
Me string
|
||||
Profile *Profile
|
||||
Scopes []string
|
||||
Type string
|
||||
}
|
||||
|
||||
func NewToken() *Token {
|
||||
t := new(Token)
|
||||
t.Scopes = make([]string, 0)
|
||||
t.Expiry = time.Time{}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
//nolint: gomnd
|
||||
func TestToken(tb testing.TB) *Token {
|
||||
tb.Helper()
|
||||
|
||||
client := TestClient(tb)
|
||||
profile := TestProfile(tb)
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
scopes := []string{"create", "update", "delete"}
|
||||
t := jwt.New()
|
||||
|
||||
// required
|
||||
t.Set(jwt.IssuerKey, client.ID) // NOTE(toby3d): client_id
|
||||
t.Set(jwt.SubjectKey, profile.URL) // NOTE(toby3d): me
|
||||
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
|
||||
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
|
||||
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
|
||||
t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour))
|
||||
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
|
||||
|
||||
// optional
|
||||
t.Set("scope", strings.Join(scopes, " "))
|
||||
t.Set("nonce", random.New().String(32))
|
||||
|
||||
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return &Token{
|
||||
AccessToken: random.New().String(32),
|
||||
ClientID: "https://app.example.com/",
|
||||
Me: "https://user.example.net/",
|
||||
Profile: TestProfile(tb),
|
||||
Scopes: []string{"create", "update", "delete"},
|
||||
Type: "Bearer",
|
||||
AccessToken: string(accessToken),
|
||||
ClientID: t.Issuer(),
|
||||
Expiry: t.Expiration(),
|
||||
Me: t.Subject(),
|
||||
Scopes: scopes,
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,10 +15,6 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
RequestHandler struct {
|
||||
useCase token.UseCase
|
||||
}
|
||||
|
||||
RevocationRequest struct {
|
||||
Action string
|
||||
Token string
|
||||
|
@ -32,6 +28,10 @@ type (
|
|||
}
|
||||
|
||||
RevocationResponse struct{}
|
||||
|
||||
RequestHandler struct {
|
||||
tokener token.UseCase
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -39,9 +39,9 @@ const (
|
|||
ActionRevoke string = "revoke"
|
||||
)
|
||||
|
||||
func NewRequestHandler(useCase token.UseCase) *RequestHandler {
|
||||
func NewRequestHandler(tokener token.UseCase) *RequestHandler {
|
||||
return &RequestHandler{
|
||||
useCase: useCase,
|
||||
tokener: tokener,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -54,16 +54,11 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) {
|
|||
ctx.SetContentType(common.MIMEApplicationJSON)
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
rawToken := ctx.Request.Header.Peek(http.HeaderAuthorization)
|
||||
|
||||
token, err := h.useCase.Verify(ctx, string(bytes.TrimSpace(bytes.TrimPrefix(rawToken, []byte("Bearer")))))
|
||||
token, err := h.tokener.Verify(ctx, string(bytes.TrimPrefix(rawToken, []byte("Bearer "))))
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
|
||||
if err = encoder.Encode(err); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -74,16 +69,12 @@ func (h *RequestHandler) Read(ctx *http.RequestCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := encoder.Encode(&VerificationResponse{
|
||||
Me: token.Me,
|
||||
if err := json.NewEncoder(ctx).Encode(&VerificationResponse{
|
||||
ClientID: token.ClientID,
|
||||
Me: token.Me,
|
||||
Scope: strings.Join(token.Scopes, " "),
|
||||
}); err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
|
||||
if err = encoder.Encode(err); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,30 +93,21 @@ func (h *RequestHandler) Revocation(ctx *http.RequestCtx) {
|
|||
req := new(RevocationRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
|
||||
if err = encoder.Encode(err); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.useCase.Revoke(ctx, req.Token); err != nil {
|
||||
if err := h.tokener.Revoke(ctx, req.Token); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
|
||||
if err = encoder.Encode(err); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := encoder.Encode(&RevocationResponse{}); err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
|
||||
if err = encoder.Encode(err); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
encoder.Encode(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,11 +7,14 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
http "github.com/valyala/fasthttp"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/common"
|
||||
configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper"
|
||||
configucase "source.toby3d.me/website/oauth/internal/config/usecase"
|
||||
"source.toby3d.me/website/oauth/internal/domain"
|
||||
delivery "source.toby3d.me/website/oauth/internal/token/delivery/http"
|
||||
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
|
||||
|
@ -22,12 +25,16 @@ import (
|
|||
func TestVerification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Read)
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(
|
||||
repository.NewMemoryTokenRepository(new(sync.Map)),
|
||||
configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
||||
)).Read)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := http.AcquireRequest()
|
||||
|
@ -56,12 +63,16 @@ func TestVerification(t *testing.T) {
|
|||
func TestRevocation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
|
||||
tokens := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), domain.TestToken(t)))
|
||||
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(usecase.NewTokenUseCase(repo)).Update)
|
||||
client, _, cleanup := util.TestServe(t, delivery.NewRequestHandler(
|
||||
usecase.NewTokenUseCase(tokens, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v))),
|
||||
).Update)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := http.AcquireRequest()
|
||||
|
@ -81,7 +92,7 @@ func TestRevocation(t *testing.T) {
|
|||
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
||||
assert.Equal(t, `{}`, strings.TrimSpace(string(resp.Body())))
|
||||
|
||||
token, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
result, err := tokens.Get(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, token)
|
||||
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
|
||||
}
|
||||
|
|
|
@ -11,8 +11,6 @@ import (
|
|||
type Repository interface {
|
||||
Get(ctx context.Context, accessToken string) (*domain.Token, error)
|
||||
Create(ctx context.Context, accessToken *domain.Token) error
|
||||
Update(ctx context.Context, accessToken *domain.Token) error
|
||||
Remove(ctx context.Context, accessToken string) error
|
||||
}
|
||||
|
||||
var ErrExist error = domain.Error{
|
||||
|
|
|
@ -2,9 +2,8 @@ package bolt
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"golang.org/x/xerrors"
|
||||
|
@ -13,21 +12,13 @@ import (
|
|||
"source.toby3d.me/website/oauth/internal/token"
|
||||
)
|
||||
|
||||
type (
|
||||
Token struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientID string `json:"clientId"`
|
||||
Me string `json:"me"`
|
||||
Scope string `json:"scope"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
type boltTokenRepository struct {
|
||||
db *bolt.DB
|
||||
}
|
||||
|
||||
boltTokenRepository struct {
|
||||
db *bolt.DB
|
||||
}
|
||||
)
|
||||
var ErrNotExist error = errors.New("token not exist")
|
||||
|
||||
var ErrNotExist error = xerrors.New("key not exist")
|
||||
var DefaultBucket []byte = []byte("tokens")
|
||||
|
||||
func NewBoltTokenRepository(db *bolt.DB) token.Repository {
|
||||
return &boltTokenRepository{
|
||||
|
@ -35,97 +26,60 @@ func NewBoltTokenRepository(db *bolt.DB) token.Repository {
|
|||
}
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
result := domain.NewToken()
|
||||
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
|
||||
find, err := repo.Get(ctx, accessToken.AccessToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot check token in database")
|
||||
}
|
||||
|
||||
if err := repo.db.View(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
if src := tx.Bucket(Token{}.Bucket()).Get([]byte(accessToken)); src != nil {
|
||||
return new(Token).Bind(src, result)
|
||||
if find != nil {
|
||||
return token.ErrExist
|
||||
}
|
||||
|
||||
if err = repo.db.Update(func(tx *bolt.Tx) error {
|
||||
bkt, err := tx.CreateBucketIfNotExists(DefaultBucket)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot create bucket")
|
||||
}
|
||||
|
||||
return ErrNotExist
|
||||
return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
|
||||
}); err != nil {
|
||||
if !xerrors.Is(err, ErrNotExist) {
|
||||
return nil, errors.Wrap(err, "failed to retrieve token from storage")
|
||||
return errors.Wrap(err, "failed to batch token in database")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
result := &domain.Token{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
Expiry: time.Time{},
|
||||
}
|
||||
|
||||
if err := repo.db.View(func(tx *bolt.Tx) (err error) {
|
||||
bkt := tx.Bucket(DefaultBucket)
|
||||
if bkt == nil {
|
||||
return ErrNotExist
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
expiry := bkt.Get([]byte(accessToken))
|
||||
if expiry == nil {
|
||||
return ErrNotExist
|
||||
}
|
||||
|
||||
if result.Expiry, err = time.Parse(time.RFC3339, string(expiry)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
if xerrors.Is(err, ErrNotExist) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, errors.Wrap(err, "failed to view token in database")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
|
||||
t, err := repo.Get(ctx, accessToken.AccessToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to verify the existence of the token")
|
||||
}
|
||||
|
||||
if t != nil {
|
||||
return token.ErrExist
|
||||
}
|
||||
|
||||
return repo.Update(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error {
|
||||
dto := new(Token)
|
||||
dto.Populate(accessToken)
|
||||
|
||||
src, err := json.Marshal(dto)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal token")
|
||||
}
|
||||
|
||||
if err = repo.db.Update(func(tx *bolt.Tx) error {
|
||||
if err := tx.Bucket(dto.Bucket()).Put([]byte(dto.AccessToken), src); err != nil {
|
||||
return errors.Wrap(err, "failed to overwrite the token in the bucket")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to update the token in the repository")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *boltTokenRepository) Remove(ctx context.Context, accessToken string) error {
|
||||
if err := repo.db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
if err := tx.Bucket(Token{}.Bucket()).Delete([]byte(accessToken)); err != nil {
|
||||
return errors.Wrap(err, "failed to remove token in bucket")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to remove token from storage")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Token) Bucket() []byte { return []byte("tokens") }
|
||||
|
||||
func (t *Token) Populate(src *domain.Token) {
|
||||
t.AccessToken = src.AccessToken
|
||||
t.ClientID = src.ClientID
|
||||
t.Me = src.Me
|
||||
t.Scope = strings.Join(src.Scopes, " ")
|
||||
t.Type = src.Type
|
||||
}
|
||||
|
||||
func (t *Token) Bind(src []byte, dst *domain.Token) error {
|
||||
if err := json.Unmarshal(src, t); err != nil {
|
||||
return errors.Wrap(err, "cannot unmarshal token source")
|
||||
}
|
||||
|
||||
dst.AccessToken = t.AccessToken
|
||||
dst.Scopes = strings.Fields(t.Scope)
|
||||
dst.Type = t.Type
|
||||
dst.ClientID = t.ClientID
|
||||
dst.Me = t.Me
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
//nolint: wrapcheck
|
||||
package bolt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
|
@ -16,120 +15,55 @@ import (
|
|||
"source.toby3d.me/website/oauth/internal/util"
|
||||
)
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
accessToken.Profile = nil
|
||||
|
||||
dto := new(repository.Token)
|
||||
dto.Populate(accessToken)
|
||||
|
||||
src, err := json.Marshal(dto)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
|
||||
}))
|
||||
|
||||
result, err := repository.NewBoltTokenRepository(db).Get(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
repo := repository.NewBoltTokenRepository(db)
|
||||
accessToken := domain.TestToken(t)
|
||||
accessToken.Profile = nil
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
result := new(domain.Token)
|
||||
result := &domain.Token{
|
||||
AccessToken: accessToken.AccessToken,
|
||||
TokenType: accessToken.TokenType,
|
||||
Expiry: time.Time{},
|
||||
}
|
||||
|
||||
require.NoError(t, db.View(func(tx *bolt.Tx) error {
|
||||
require.NoError(t, db.View(func(tx *bolt.Tx) (err error) {
|
||||
//nolint: exhaustivestruct
|
||||
return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()).
|
||||
Get([]byte(accessToken.AccessToken)), result)
|
||||
}))
|
||||
src := tx.Bucket(repository.DefaultBucket).Get([]byte(accessToken.AccessToken))
|
||||
|
||||
result.Expiry, err = time.Parse(time.RFC3339, string(src))
|
||||
|
||||
return
|
||||
}))
|
||||
assert.Equal(t, accessToken, result)
|
||||
assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error())
|
||||
|
||||
assert.ErrorIs(t, repo.Create(context.TODO(), accessToken), token.ErrExist)
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
db, cleanup := util.TestBolt(t, repository.DefaultBucket)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
repo := repository.NewBoltTokenRepository(db)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
src, err := json.Marshal(accessToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint: exhaustivestruct
|
||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
|
||||
bkt, err := tx.CreateBucketIfNotExists(repository.DefaultBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bkt.Put([]byte(accessToken.AccessToken), []byte(accessToken.Expiry.Format(time.RFC3339)))
|
||||
}))
|
||||
|
||||
require.NoError(t, repository.NewBoltTokenRepository(db).Update(context.TODO(), &domain.Token{
|
||||
AccessToken: accessToken.AccessToken,
|
||||
ClientID: "https://client.example.net/",
|
||||
Me: "https://toby3d.ru/",
|
||||
Scopes: []string{"read"},
|
||||
Type: "Bearer",
|
||||
Profile: nil,
|
||||
}))
|
||||
|
||||
result := domain.NewToken()
|
||||
|
||||
//nolint: exhaustivestruct
|
||||
require.NoError(t, db.View(func(tx *bolt.Tx) error {
|
||||
return new(repository.Token).Bind(tx.Bucket(repository.Token{}.Bucket()).
|
||||
Get([]byte(accessToken.AccessToken)), result)
|
||||
}))
|
||||
|
||||
assert.Equal(t, &domain.Token{
|
||||
AccessToken: accessToken.AccessToken,
|
||||
ClientID: "https://client.example.net/",
|
||||
Me: "https://toby3d.ru/",
|
||||
Scopes: []string{"read"},
|
||||
Type: "Bearer",
|
||||
Profile: nil,
|
||||
}, result)
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, cleanup := util.TestBolt(t, repository.Token{}.Bucket())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
src, err := json.Marshal(accessToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Update(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
return tx.Bucket(repository.Token{}.Bucket()).Put([]byte(accessToken.AccessToken), src)
|
||||
}))
|
||||
|
||||
require.NoError(t, repository.NewBoltTokenRepository(db).Remove(context.TODO(), accessToken.AccessToken))
|
||||
|
||||
require.NoError(t, db.View(func(tx *bolt.Tx) error {
|
||||
//nolint: exhaustivestruct
|
||||
assert.Nil(t, tx.Bucket(repository.Token{}.Bucket()).Get([]byte(accessToken.AccessToken)))
|
||||
|
||||
return nil
|
||||
}))
|
||||
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
}
|
||||
|
|
|
@ -4,58 +4,45 @@ import (
|
|||
"context"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/domain"
|
||||
"source.toby3d.me/website/oauth/internal/token"
|
||||
)
|
||||
|
||||
type memoryTokenRepository struct {
|
||||
tokens *sync.Map
|
||||
store *sync.Map
|
||||
}
|
||||
|
||||
const Key string = "tokens"
|
||||
const DefaultPathPrefix string = "tokens"
|
||||
|
||||
func NewMemoryTokenRepository(tokens *sync.Map) token.Repository {
|
||||
func NewMemoryTokenRepository(store *sync.Map) token.Repository {
|
||||
return &memoryTokenRepository{
|
||||
tokens: tokens,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
src, ok := repo.tokens.Load(path.Join(Key, accessToken))
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, ok := src.(*domain.Token)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
|
||||
t, err := repo.Get(ctx, accessToken.AccessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := path.Join(DefaultPathPrefix, accessToken.AccessToken)
|
||||
|
||||
if t != nil {
|
||||
if _, ok := repo.store.Load(key); ok {
|
||||
return token.ErrExist
|
||||
}
|
||||
|
||||
return repo.Update(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (repo *memoryTokenRepository) Update(ctx context.Context, accessToken *domain.Token) error {
|
||||
repo.tokens.Store(path.Join(Key, accessToken.AccessToken), accessToken)
|
||||
repo.store.Store(key, accessToken.Expiry)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *memoryTokenRepository) Remove(ctx context.Context, accessToken string) error {
|
||||
repo.tokens.Delete(path.Join(Key, accessToken))
|
||||
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
expiry, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken))
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return &domain.Token{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
Expiry: expiry.(time.Time),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -14,66 +14,32 @@ import (
|
|||
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
|
||||
)
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := new(sync.Map)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
|
||||
|
||||
result, err := repository.NewMemoryTokenRepository(store).Get(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := new(sync.Map)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(store)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
|
||||
expiry, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken))
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, accessToken, result)
|
||||
assert.Equal(t, accessToken.Expiry, expiry)
|
||||
|
||||
assert.EqualError(t, repo.Create(context.TODO(), accessToken), token.ErrExist.Error())
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
func TestGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := new(sync.Map)
|
||||
repo := repository.NewMemoryTokenRepository(store)
|
||||
|
||||
accessToken := domain.TestToken(t)
|
||||
store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken.Expiry)
|
||||
|
||||
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
|
||||
|
||||
tokenCopy := *accessToken
|
||||
tokenCopy.ClientID = "https://client.example.com/"
|
||||
tokenCopy.Me = "https://toby3d.ru/"
|
||||
|
||||
require.NoError(t, repository.NewMemoryTokenRepository(store).Update(context.TODO(), &tokenCopy))
|
||||
|
||||
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
|
||||
assert.True(t, ok)
|
||||
assert.NotEqual(t, accessToken, result)
|
||||
assert.Equal(t, &tokenCopy, result)
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := new(sync.Map)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
store.Store(path.Join(repository.Key, accessToken.AccessToken), accessToken)
|
||||
|
||||
require.NoError(t, repository.NewMemoryTokenRepository(store).Remove(context.TODO(), accessToken.AccessToken))
|
||||
|
||||
result, ok := store.Load(path.Join(repository.Key, accessToken.AccessToken))
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, result)
|
||||
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken, result)
|
||||
}
|
||||
|
|
|
@ -2,35 +2,75 @@ package usecase
|
|||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"source.toby3d.me/website/oauth/internal/config"
|
||||
"source.toby3d.me/website/oauth/internal/domain"
|
||||
"source.toby3d.me/website/oauth/internal/token"
|
||||
)
|
||||
|
||||
type tokenUseCase struct {
|
||||
tokens token.Repository
|
||||
tokens token.Repository
|
||||
configer config.UseCase
|
||||
}
|
||||
|
||||
func NewTokenUseCase(tokens token.Repository) token.UseCase {
|
||||
func NewTokenUseCase(tokens token.Repository, configer config.UseCase) token.UseCase {
|
||||
return &tokenUseCase{
|
||||
tokens: tokens,
|
||||
tokens: tokens,
|
||||
configer: configer,
|
||||
}
|
||||
}
|
||||
|
||||
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
t, err := useCase.tokens.Get(ctx, accessToken)
|
||||
token, err := useCase.tokens.Get(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to retrieve token from storage")
|
||||
return nil, errors.Wrap(err, "cannot find token in database")
|
||||
}
|
||||
|
||||
return t, nil
|
||||
if token != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
t, err := jwt.ParseString(accessToken, jwt.WithVerify(
|
||||
jwa.SignatureAlgorithm(useCase.configer.GetIndieAuthJWTSigningAlgorithm()),
|
||||
[]byte(useCase.configer.GetIndieAuthJWTSecret()),
|
||||
))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse JWT token")
|
||||
}
|
||||
|
||||
if err = jwt.Validate(t); err != nil {
|
||||
return nil, errors.Wrap(err, "cannot validate JWT token")
|
||||
}
|
||||
|
||||
token = &domain.Token{
|
||||
Expiry: t.Expiration(),
|
||||
Scopes: []string{},
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
ClientID: t.Issuer(),
|
||||
Me: t.Subject(),
|
||||
}
|
||||
|
||||
if scope, ok := t.Get("scope"); ok {
|
||||
token.Scopes = strings.Fields(scope.(string))
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
|
||||
if err := useCase.tokens.Remove(ctx, accessToken); err != nil {
|
||||
return errors.Wrap(err, "failed to delete a token in the vault")
|
||||
t, err := useCase.Verify(ctx, accessToken)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot verify token")
|
||||
}
|
||||
|
||||
if err = useCase.tokens.Create(ctx, t); err != nil {
|
||||
return errors.Wrap(err, "cannot save token in database")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -5,9 +5,12 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
configrepo "source.toby3d.me/website/oauth/internal/config/repository/viper"
|
||||
configucase "source.toby3d.me/website/oauth/internal/config/usecase"
|
||||
"source.toby3d.me/website/oauth/internal/domain"
|
||||
repository "source.toby3d.me/website/oauth/internal/token/repository/memory"
|
||||
"source.toby3d.me/website/oauth/internal/token/usecase"
|
||||
|
@ -16,31 +19,35 @@ import (
|
|||
func TestVerify(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
accessToken := domain.NewToken()
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
|
||||
token, err := usecase.NewTokenUseCase(repo).Verify(context.TODO(), accessToken.AccessToken)
|
||||
token, err := usecase.NewTokenUseCase(
|
||||
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
||||
).Verify(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken, token)
|
||||
assert.Equal(t, accessToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := viper.New()
|
||||
v.SetDefault("indieauth.jwtSigningAlgorithm", "HS256")
|
||||
v.SetDefault("indieauth.jwtSecret", "hackme")
|
||||
|
||||
repo := repository.NewMemoryTokenRepository(new(sync.Map))
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
require.NoError(t, repo.Create(context.TODO(), accessToken))
|
||||
require.NoError(t, usecase.NewTokenUseCase(
|
||||
repo, configucase.NewConfigUseCase(configrepo.NewViperConfigRepository(v)),
|
||||
).Revoke(context.TODO(), accessToken.AccessToken))
|
||||
|
||||
token, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
|
||||
require.NoError(t, usecase.NewTokenUseCase(repo).Revoke(context.TODO(), token.AccessToken))
|
||||
|
||||
token, err = repo.Get(context.TODO(), token.AccessToken)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, token)
|
||||
result, err := repo.Get(context.TODO(), accessToken.AccessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accessToken.AccessToken, result.AccessToken)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue