🗃️ Added Get method for sessions repository
This commit is contained in:
parent
93ba01be84
commit
0495d4c72f
|
@ -2,15 +2,15 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"source.toby3d.me/website/indieauth/internal/domain"
|
"source.toby3d.me/website/indieauth/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Repository interface {
|
type Repository interface {
|
||||||
|
Get(ctx context.Context, code string) (*domain.Session, error)
|
||||||
Create(ctx context.Context, session *domain.Session) error
|
Create(ctx context.Context, session *domain.Session) error
|
||||||
GetAndDelete(ctx context.Context, code string) (*domain.Session, error)
|
GetAndDelete(ctx context.Context, code string) (*domain.Session, error)
|
||||||
GC()
|
GC()
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrNotExist = errors.New("session not exist")
|
var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "session with this code not exist", "")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -40,15 +41,29 @@ func (repo *memorySessionRepository) Create(_ context.Context, state *domain.Ses
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) {
|
func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) {
|
||||||
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code))
|
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, code))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, session.ErrNotExist
|
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, ok := src.(*Session)
|
result, ok := src.(*Session)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, session.ErrNotExist
|
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) {
|
||||||
|
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code))
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, ok := src.(*Session)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.Session, nil
|
return result.Session, nil
|
||||||
|
|
|
@ -71,6 +71,18 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) {
|
||||||
|
s := new(Session)
|
||||||
|
if err := repo.db.GetContext(ctx, s, QueryTable+QueryGet, code); err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot find session in db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := new(domain.Session)
|
||||||
|
s.Populate(result)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
|
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
|
||||||
s := new(Session)
|
s := new(Session)
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,22 @@ func TestCreate(t *testing.T) {
|
||||||
assert.Equal(t, session.Code, result.Code)
|
assert.Equal(t, session.Code, result.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, cleanup := sqltest.Open(t)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
session := domain.TestSession(t)
|
||||||
|
_, err := db.NamedExec(repository.QueryTable+repository.QueryCreate, repository.NewSession(session))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
|
||||||
|
Get(context.Background(), session.Code)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, session.Code, result.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetAndDelete(t *testing.T) {
|
func TestGetAndDelete(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -46,4 +62,6 @@ func TestGetAndDelete(t *testing.T) {
|
||||||
GetAndDelete(context.Background(), session.Code)
|
GetAndDelete(context.Background(), session.Code)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, session.Code, result.Code)
|
assert.Equal(t, session.Code, result.Code)
|
||||||
|
|
||||||
|
assert.Error(t, db.Get(result, repository.QueryGet, session.Code), "session MUST be destroyed after successful"+" query")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue