diff --git a/internal/session/repository/sqlite3/sqlite3_session.go b/internal/session/repository/sqlite3/sqlite3_session.go index 90b4132..6ee870b 100644 --- a/internal/session/repository/sqlite3/sqlite3_session.go +++ b/internal/session/repository/sqlite3/sqlite3_session.go @@ -58,13 +58,15 @@ const ( ) func NewSQLite3SessionRepository(config *domain.Config, db *sqlx.DB) session.Repository { + db.MustExec(QueryTable) + return &sqlite3SessionRepository{ db: db, } } func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domain.Session) error { - if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewSession(session)); err != nil { + if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewSession(session)); err != nil { return fmt.Errorf("cannot create session record in db: %w", err) } @@ -73,7 +75,7 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) { s := new(Session) - if err := repo.db.GetContext(ctx, s, QueryTable+QueryGet, code); err != nil { + if err := repo.db.GetContext(ctx, s, QueryGet, code); err != nil { return nil, fmt.Errorf("cannot find session in db: %w", err) } @@ -93,7 +95,7 @@ func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code str return nil, fmt.Errorf("failed to begin transaction: %w", err) } - if err = tx.GetContext(ctx, s, QueryTable+QueryGet, code); err != nil { + if err = tx.GetContext(ctx, s, QueryGet, code); err != nil { defer tx.Rollback() if errors.Is(err, sql.ErrNoRows) { diff --git a/internal/session/repository/sqlite3/sqlite3_session_test.go b/internal/session/repository/sqlite3/sqlite3_session_test.go index c3aa7b2..1a0f382 100644 --- a/internal/session/repository/sqlite3/sqlite3_session_test.go +++ b/internal/session/repository/sqlite3/sqlite3_session_test.go @@ -2,66 +2,126 @@ package sqlite3_test import ( "context" + "regexp" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/DATA-DOG/go-sqlmock" "source.toby3d.me/website/indieauth/internal/domain" repository "source.toby3d.me/website/indieauth/internal/session/repository/sqlite3" "source.toby3d.me/website/indieauth/internal/testing/sqltest" ) +//nolint: gochecknoglobals +var tableColumns = []string{ + "created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code", + "code_challenge", +} + func TestCreate(t *testing.T) { t.Parallel() - db, cleanup := sqltest.Open(t) + session := domain.TestSession(t) + model := repository.NewSession(session) + db, mock, cleanup := sqltest.Open(t) t.Cleanup(cleanup) - session := domain.TestSession(t) - require.NoError(t, repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). - Create(context.Background(), session)) + createTable(t, mock) + mock.ExpectExec(regexp.QuoteMeta(`INSERT INTO sessions`)). + WithArgs( + sqltest.Time{}, + model.ClientID, + model.Me, + model.RedirectURI, + model.CodeChallengeMethod, + model.Scope, + model.Code, + model.CodeChallenge, + ). + WillReturnResult(sqlmock.NewResult(1, 1)) - results := make([]*repository.Session, 0) - require.NoError(t, db.Select(&results, "SELECT * FROM sessions")) - require.Len(t, results, 1) - - result := new(domain.Session) - results[0].Populate(result) - - assert.Equal(t, session.Code, result.Code) + if err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). + Create(context.TODO(), session); err != nil { + t.Error(err) + } } func TestGet(t *testing.T) { t.Parallel() - db, cleanup := sqltest.Open(t) + session := domain.TestSession(t) + model := repository.NewSession(session) + db, mock, cleanup := sqltest.Open(t) t.Cleanup(cleanup) - session := domain.TestSession(t) - _, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewSession(session)) - require.NoError(t, err) + createTable(t, mock) + mock.ExpectQuery(regexp.QuoteMeta(`SELECT * FROM sessions`)). + WithArgs(session.Code). + WillReturnRows(sqlmock.NewRows(tableColumns). + AddRow( + model.CreatedAt.Time, + model.ClientID, + model.Me, + model.RedirectURI, + model.CodeChallengeMethod, + model.Scope, + model.Code, + model.CodeChallenge, + )) result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). - Get(context.Background(), session.Code) - require.NoError(t, err) - assert.Equal(t, session.Code, result.Code) + Get(context.TODO(), session.Code) + if err != nil { + t.Fatal(err) + } + + if result.Code != session.Code { + t.Errorf("Get(%s) = %+v, want %+v", session.Code, result, session) + } } func TestGetAndDelete(t *testing.T) { t.Parallel() - db, cleanup := sqltest.Open(t) + session := domain.TestSession(t) + model := repository.NewSession(session) + db, mock, cleanup := sqltest.Open(t) t.Cleanup(cleanup) - session := domain.TestSession(t) - _, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewSession(session)) - require.NoError(t, err) + createTable(t, mock) + mock.ExpectBegin() + mock.ExpectQuery(regexp.QuoteMeta(`SELECT * FROM sessions`)). + WithArgs(session.Code). + WillReturnRows(sqlmock.NewRows(tableColumns). + AddRow( + model.CreatedAt.Time, + model.ClientID, + model.Me, + model.RedirectURI, + model.CodeChallengeMethod, + model.Scope, + model.Code, + model.CodeChallenge, + )) + mock.ExpectExec(regexp.QuoteMeta(`DELETE FROM sessions`)). + WithArgs(model.Code). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). - GetAndDelete(context.Background(), session.Code) - require.NoError(t, err) - assert.Equal(t, session.Code, result.Code) + GetAndDelete(context.TODO(), session.Code) + if err != nil { + t.Fatal(err) + } - assert.Error(t, db.Get(result, repository.QueryGet, session.Code), "session MUST be destroyed after successful"+" query") + if result.Code != session.Code { + t.Errorf("GetAndDelete(%s) = %+v, want %+v", session.Code, result, session) + } +} + +func createTable(tb testing.TB, mock sqlmock.Sqlmock) { + tb.Helper() + + mock.ExpectExec(regexp.QuoteMeta(repository.QueryTable)). + WillReturnResult(sqlmock.NewResult(1, 1)) } diff --git a/internal/testing/sqltest/sqltest.go b/internal/testing/sqltest/sqltest.go index e5def5d..be86547 100644 --- a/internal/testing/sqltest/sqltest.go +++ b/internal/testing/sqltest/sqltest.go @@ -1,28 +1,40 @@ package sqltest import ( + "database/sql/driver" "testing" + "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/jmoiron/sqlx" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" _ "modernc.org/sqlite" ) +type Time struct{} + +func (Time) Match(v driver.Value) bool { + _, ok := v.(time.Time) + + return ok +} + // Open creates a new InMemory sqlite3 database for testing. -func Open(tb testing.TB) (*sqlx.DB, func()) { +func Open(tb testing.TB) (*sqlx.DB, sqlmock.Sqlmock, func()) { tb.Helper() - db, err := sqlx.Open("sqlite", ":memory:") - require.NoError(tb, err) - - if !assert.NoError(tb, db.Ping()) { - _ = db.Close() //nolint: errcheck - - tb.FailNow() + db, mock, err := sqlmock.New() + if err != nil { + tb.Fatalf("%+v", err) } - return db, func() { + xdb := sqlx.NewDb(db, "sqlite") + if err = xdb.Ping(); err != nil { + _ = db.Close() + + tb.Fatalf("%+v", err) + } + + return xdb, mock, func() { _ = db.Close() //nolint: errcheck } }