diff --git a/internal/token/repository/sqlite3/sqlite3_token.go b/internal/token/repository/sqlite3/sqlite3_token.go index 2e2e456..73565bc 100644 --- a/internal/token/repository/sqlite3/sqlite3_token.go +++ b/internal/token/repository/sqlite3/sqlite3_token.go @@ -46,13 +46,15 @@ const ( ) func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository { + db.MustExec(QueryTable) + return &sqlite3TokenRepository{ db: db, } } func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewToken(accessToken)); err != nil { + if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(accessToken)); err != nil { return fmt.Errorf("cannot create token record in db: %w", err) } @@ -61,7 +63,7 @@ func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *dom func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { t := new(Token) - if err := repo.db.GetContext(ctx, t, QueryTable+QueryGet, accessToken); err != nil { + if err := repo.db.GetContext(ctx, t, QueryGet, accessToken); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, token.ErrNotExist } diff --git a/internal/token/repository/sqlite3/sqlite3_token_test.go b/internal/token/repository/sqlite3/sqlite3_token_test.go index 49fc618..09a41a6 100644 --- a/internal/token/repository/sqlite3/sqlite3_token_test.go +++ b/internal/token/repository/sqlite3/sqlite3_token_test.go @@ -2,46 +2,76 @@ package sqlite3_test import ( "context" + "regexp" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/DATA-DOG/go-sqlmock" "source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/testing/sqltest" repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3" ) +//nolint: gochecknoglobals +var tableColumns []string = []string{"created_at", "access_token", "client_id", "me", "scope"} + func TestCreate(t *testing.T) { t.Parallel() - db, cleanup := sqltest.Open(t) + token := domain.TestToken(t) + model := repository.NewToken(token) + db, mock, cleanup := sqltest.Open(t) t.Cleanup(cleanup) - token := domain.TestToken(t) - require.NoError(t, repository.NewSQLite3TokenRepository(db).Create(context.Background(), token)) + createTable(t, mock) + mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO tokens`)). + WithArgs( + sqltest.Time{}, + model.AccessToken, + model.ClientID, + model.Me, + model.Scope, + ). + WillReturnResult(sqlmock.NewResult(1, 1)) - results := make([]*repository.Token, 0) - require.NoError(t, db.Select(&results, "SELECT * FROM tokens;")) - require.Len(t, results, 1) - - result := new(domain.Token) - results[0].Populate(result) - - assert.Equal(t, token.AccessToken, result.AccessToken) + if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), token); err != nil { + t.Error(err) + } } func TestGet(t *testing.T) { t.Parallel() - db, cleanup := sqltest.Open(t) + token := domain.TestToken(t) + model := repository.NewToken(token) + db, mock, cleanup := sqltest.Open(t) t.Cleanup(cleanup) - token := domain.TestToken(t) - _, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewToken(token)) - require.NoError(t, err) + createTable(t, mock) + mock.ExpectQuery(regexp.QuoteMeta(`SELECT * FROM tokens`)). + WithArgs(model.AccessToken). + WillReturnRows(sqlmock.NewRows(tableColumns). + AddRow( + model.CreatedAt.Time, + model.AccessToken, + model.ClientID, + model.Me, + model.Scope, + )) result, err := repository.NewSQLite3TokenRepository(db).Get(context.Background(), token.AccessToken) - require.NoError(t, err) - assert.Equal(t, token.AccessToken, result.AccessToken) + if err != nil { + t.Fatal(err) + } + + if result.AccessToken != token.AccessToken { + t.Errorf("Get(%s) = %+v, want %+v", token.AccessToken, result, token) + } +} + +func createTable(tb testing.TB, mock sqlmock.Sqlmock) { + tb.Helper() + + mock.ExpectExec(regexp.QuoteMeta(repository.QueryTable)). + WillReturnResult(sqlmock.NewResult(1, 1)) }