diff --git a/internal/token/usecase/token_ucase_test.go b/internal/token/usecase/token_ucase_test.go index 512110d..7c802f6 100644 --- a/internal/token/usecase/token_ucase_test.go +++ b/internal/token/usecase/token_ucase_test.go @@ -3,35 +3,86 @@ package usecase_test import ( "context" "errors" + "path" "sync" "testing" "github.com/stretchr/testify/assert" "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/profile" + profilerepo "source.toby3d.me/website/indieauth/internal/profile/repository/memory" + "source.toby3d.me/website/indieauth/internal/session" + sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory" "source.toby3d.me/website/indieauth/internal/token" - repository "source.toby3d.me/website/indieauth/internal/token/repository/memory" + tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory" usecase "source.toby3d.me/website/indieauth/internal/token/usecase" ) -/* TODO(toby3d) +type Dependencies struct { + config *domain.Config + profile *domain.Profile + profiles profile.Repository + session *domain.Session + sessions session.Repository + store *sync.Map + token *domain.Token + tokens token.Repository +} + func TestExchange(t *testing.T) { t.Parallel() + + deps := NewDependencies(t) + deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.session.Me.String()), deps.profile) + + if err := deps.sessions.Create(context.TODO(), deps.session); err != nil { + t.Fatal(err) + } + + opts := token.ExchangeOptions{ + ClientID: deps.session.ClientID, + Code: deps.session.Code, + CodeVerifier: deps.session.CodeChallenge, + RedirectURI: deps.session.RedirectURI, + } + + tkn, userInfo, err := usecase.NewTokenUseCase(usecase.Config{ + Config: deps.config, + Profiles: deps.profiles, + Sessions: deps.sessions, + Tokens: deps.tokens, + }).Exchange(context.TODO(), opts) + if err != nil { + t.Fatal(err) + } + + if tkn == nil { + t.Errorf("Exchange(ctx, %v) = nil, want not nil", opts) + } + + if userInfo == nil { + t.Errorf("Exchange(ctx, %v) = nil, want not nil", opts) + } } -*/ func TestVerify(t *testing.T) { t.Parallel() - repo := repository.NewMemoryTokenRepository(new(sync.Map)) - ucase := usecase.NewTokenUseCase(repo, nil, domain.TestConfig(t)) + deps := NewDependencies(t) + ucase := usecase.NewTokenUseCase(usecase.Config{ + Config: domain.TestConfig(t), + Profiles: deps.profiles, + Sessions: deps.sessions, + Tokens: deps.tokens, + }) t.Run("valid", func(t *testing.T) { t.Parallel() accessToken := domain.TestToken(t) - result, err := ucase.Verify(context.TODO(), accessToken.AccessToken) + result, _, err := ucase.Verify(context.TODO(), accessToken.AccessToken) if err != nil { t.Fatal(err) } @@ -46,11 +97,11 @@ func TestVerify(t *testing.T) { t.Parallel() accessToken := domain.TestToken(t) - if err := repo.Create(context.TODO(), accessToken); err != nil { + if err := deps.tokens.Create(context.TODO(), accessToken); err != nil { t.Fatal(err) } - result, err := ucase.Verify(context.TODO(), accessToken.AccessToken) + result, _, err := ucase.Verify(context.TODO(), accessToken.AccessToken) if !errors.Is(err, token.ErrRevoke) { t.Errorf("Verify(%s) = %v, want %v", accessToken.AccessToken, err, token.ErrRevoke) } @@ -64,21 +115,41 @@ func TestVerify(t *testing.T) { func TestRevoke(t *testing.T) { t.Parallel() - config := domain.TestConfig(t) - accessToken := domain.TestToken(t) - repo := repository.NewMemoryTokenRepository(new(sync.Map)) + deps := NewDependencies(t) - if err := usecase.NewTokenUseCase(repo, nil, config). - Revoke(context.TODO(), accessToken.AccessToken); err != nil { + if err := usecase.NewTokenUseCase(usecase.Config{ + Config: deps.config, + Profiles: deps.profiles, + Sessions: deps.sessions, + Tokens: deps.tokens, + }).Revoke(context.TODO(), deps.token.AccessToken); err != nil { t.Fatal(err) } - result, err := repo.Get(context.TODO(), accessToken.AccessToken) + result, err := deps.tokens.Get(context.TODO(), deps.token.AccessToken) if err != nil { t.Error(err) } - if result.AccessToken != accessToken.AccessToken { - t.Errorf("Get(%s) = %s, want %s", accessToken.AccessToken, result.AccessToken, accessToken.AccessToken) + if result.AccessToken != deps.token.AccessToken { + t.Errorf("Get(%s) = %s, want %s", deps.token.AccessToken, result.AccessToken, deps.token.AccessToken) + } +} + +func NewDependencies(tb testing.TB) Dependencies { + tb.Helper() + + store := new(sync.Map) + config := domain.TestConfig(tb) + + return Dependencies{ + config: config, + profile: domain.TestProfile(tb), + profiles: profilerepo.NewMemoryProfileRepository(store), + session: domain.TestSession(tb), + sessions: sessionrepo.NewMemorySessionRepository(store, config), + store: store, + token: domain.TestToken(tb), + tokens: tokenrepo.NewMemoryTokenRepository(store), } }