From 4f1e63e221d01ad8b52298564b96af34f777c32b Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Fri, 14 Jan 2022 01:51:33 +0500 Subject: [PATCH] :card_file_box: Added SQLite3 repository for sessions --- .../repository/sqlite3/sqlite3_session.go | 144 ++++++++++++++++++ .../sqlite3/sqlite3_session_test.go | 49 ++++++ 2 files changed, 193 insertions(+) create mode 100644 internal/session/repository/sqlite3/sqlite3_session.go create mode 100644 internal/session/repository/sqlite3/sqlite3_session_test.go diff --git a/internal/session/repository/sqlite3/sqlite3_session.go b/internal/session/repository/sqlite3/sqlite3_session.go new file mode 100644 index 0000000..a7e1690 --- /dev/null +++ b/internal/session/repository/sqlite3/sqlite3_session.go @@ -0,0 +1,144 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/jmoiron/sqlx" + + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/session" +) + +type ( + Session struct { + CreatedAt sql.NullTime `db:"created_at"` + ClientID string `db:"client_id"` + Me string `db:"me"` + RedirectURI string `db:"redirect_uri"` + CodeChallengeMethod string `db:"code_challenge_method"` + Scope string `db:"scope"` + Code string `db:"code"` + CodeChallenge string `db:"code_challenge"` + } + + sqlite3SessionRepository struct { + config *domain.Config + db *sqlx.DB + } +) + +const ( + QueryTable string = `CREATE TABLE IF NOT EXISTS sessions ( + created_at DATETIME NOT NULL, + client_id TEXT NOT NULL, + me TEXT NOT NULL, + redirect_uri TEXT NOT NULL, + code_challenge_method TEXT, + scope TEXT, + code TEXT UNIQUE PRIMARY KEY NOT NULL, + code_challenge TEXT + );` + + QueryGet string = `SELECT * + FROM sessions + WHERE code=$1;` + + QueryCreate string = `INSERT INTO sessions (created_at, client_id, me, redirect_uri, code_challenge_method, + scope, code, code_challenge) + VALUES (:created_at, :client_id, :me, :redirect_uri, :code_challenge_method, :scope, :code, + :code_challenge);` + + QueryDelete string = `DELETE FROM sessions + WHERE code=$1;` +) + +func NewSQLite3SessionRepository(config *domain.Config, db *sqlx.DB) session.Repository { + return &sqlite3SessionRepository{ + db: db, + } +} + +func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domain.Session) error { + if _, err := repo.db.NamedExecContext(ctx, QueryTable+QueryCreate, NewSession(session)); err != nil { + return fmt.Errorf("cannot create session record in db: %w", err) + } + + return nil +} + +func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) { + s := new(Session) + + tx, err := repo.db.Beginx() + if err != nil { + tx.Rollback() + + return nil, fmt.Errorf("failed to begin transaction: %w", err) + } + + if err = tx.GetContext(ctx, s, QueryTable+QueryGet, code); err != nil { + defer tx.Rollback() + + if errors.Is(err, sql.ErrNoRows) { + return nil, session.ErrNotExist + } + + return nil, fmt.Errorf("cannot find session in db: %w", err) + } + + if _, err = tx.ExecContext(ctx, QueryDelete, code); err != nil { + tx.Rollback() + + return nil, fmt.Errorf("cannot remove session from db: %w", err) + } + + if err = tx.Commit(); err != nil { + return nil, fmt.Errorf("failed to commit transaction: %w", err) + } + + result := new(domain.Session) + s.Populate(result) + + return result, nil +} + +func (repo *sqlite3SessionRepository) GC() {} + +func NewSession(src *domain.Session) *Session { + return &Session{ + CreatedAt: sql.NullTime{ + Time: time.Now().UTC(), + Valid: true, + }, + ClientID: src.ClientID.String(), + Code: src.Code, + CodeChallenge: src.CodeChallenge, + CodeChallengeMethod: src.CodeChallengeMethod.String(), + Me: src.Me.String(), + RedirectURI: src.RedirectURI.String(), + Scope: src.Scope.String(), + } +} + +func (t *Session) Populate(dst *domain.Session) { + dst.ClientID, _ = domain.NewClientID(t.ClientID) + dst.Code = t.Code + dst.CodeChallenge = t.CodeChallenge + dst.CodeChallengeMethod, _ = domain.ParseCodeChallengeMethod(t.CodeChallengeMethod) + dst.Me, _ = domain.NewMe(t.Me) + dst.RedirectURI, _ = domain.NewURL(t.RedirectURI) + + for _, scope := range strings.Fields(t.Scope) { + s, err := domain.ParseScope(scope) + if err != nil { + continue + } + + dst.Scope = append(dst.Scope, s) + } +} diff --git a/internal/session/repository/sqlite3/sqlite3_session_test.go b/internal/session/repository/sqlite3/sqlite3_session_test.go new file mode 100644 index 0000000..c8cb9d2 --- /dev/null +++ b/internal/session/repository/sqlite3/sqlite3_session_test.go @@ -0,0 +1,49 @@ +package sqlite3_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "source.toby3d.me/website/indieauth/internal/domain" + repository "source.toby3d.me/website/indieauth/internal/session/repository/sqlite3" + "source.toby3d.me/website/indieauth/internal/testing/sqltest" +) + +func TestCreate(t *testing.T) { + t.Parallel() + + db, cleanup := sqltest.Open(t) + t.Cleanup(cleanup) + + session := domain.TestSession(t) + require.NoError(t, repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). + Create(context.Background(), session)) + + results := make([]*repository.Session, 0) + require.NoError(t, db.Select(&results, "SELECT * FROM sessions")) + require.Len(t, results, 1) + + result := new(domain.Session) + results[0].Populate(result) + + assert.Equal(t, session.Code, result.Code) +} + +func TestGetAndDelete(t *testing.T) { + t.Parallel() + + db, cleanup := sqltest.Open(t) + t.Cleanup(cleanup) + + session := domain.TestSession(t) + _, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewSession(session)) + require.NoError(t, err) + + result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). + GetAndDelete(context.Background(), session.Code) + require.NoError(t, err) + assert.Equal(t, session.Code, result.Code) +}