🗃️ Added Get method for sessions repository

This commit is contained in:
Maxim Lebedev 2022-01-30 00:56:27 +05:00
parent 93ba01be84
commit 0495d4c72f
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
4 changed files with 51 additions and 6 deletions

View File

@ -2,15 +2,15 @@ package session
import (
"context"
"errors"
"source.toby3d.me/website/indieauth/internal/domain"
)
type Repository interface {
Get(ctx context.Context, code string) (*domain.Session, error)
Create(ctx context.Context, session *domain.Session) error
GetAndDelete(ctx context.Context, code string) (*domain.Session, error)
GC()
}
var ErrNotExist = errors.New("session not exist")
var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "session with this code not exist", "")

View File

@ -2,6 +2,7 @@ package memory
import (
"context"
"fmt"
"path"
"sync"
"time"
@ -40,15 +41,29 @@ func (repo *memorySessionRepository) Create(_ context.Context, state *domain.Ses
return nil
}
func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) {
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code))
func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) {
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, code))
if !ok {
return nil, session.ErrNotExist
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
}
result, ok := src.(*Session)
if !ok {
return nil, session.ErrNotExist
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
}
return result.Session, nil
}
func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) {
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code))
if !ok {
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
}
result, ok := src.(*Session)
if !ok {
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
}
return result.Session, nil

View File

@ -71,6 +71,18 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai
return nil
}
func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session)
if err := repo.db.GetContext(ctx, s, QueryTable+QueryGet, code); err != nil {
return nil, fmt.Errorf("cannot find session in db: %w", err)
}
result := new(domain.Session)
s.Populate(result)
return result, nil
}
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session)

View File

@ -32,6 +32,22 @@ func TestCreate(t *testing.T) {
assert.Equal(t, session.Code, result.Code)
}
func TestGet(t *testing.T) {
t.Parallel()
db, cleanup := sqltest.Open(t)
t.Cleanup(cleanup)
session := domain.TestSession(t)
_, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewSession(session))
require.NoError(t, err)
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
Get(context.Background(), session.Code)
require.NoError(t, err)
assert.Equal(t, session.Code, result.Code)
}
func TestGetAndDelete(t *testing.T) {
t.Parallel()
@ -46,4 +62,6 @@ func TestGetAndDelete(t *testing.T) {
GetAndDelete(context.Background(), session.Code)
require.NoError(t, err)
assert.Equal(t, session.Code, result.Code)
assert.Error(t, db.Get(result, repository.QueryGet, session.Code), "session MUST be destroyed after successful"+" query")
}