♻️ Simplify error usage in ticket package

This commit is contained in:
Maxim Lebedev 2022-01-30 01:30:37 +05:00
parent 0495d4c72f
commit ed55c8cded
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
8 changed files with 116 additions and 90 deletions

View File

@ -1,6 +1,7 @@
package http package http
import ( import (
"errors"
"fmt" "fmt"
"path" "path"
@ -9,7 +10,6 @@ import (
http "github.com/valyala/fasthttp" http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
"golang.org/x/xerrors"
"source.toby3d.me/toby3d/form" "source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware" "source.toby3d.me/toby3d/middleware"
@ -75,7 +75,7 @@ func (h *RequestHandler) Register(r *router.Router) {
// TODO(toby3d): secure this via JWT middleware // TODO(toby3d): secure this via JWT middleware
r.GET("/ticket", chain.RequestHandler(h.handleRender)) r.GET("/ticket", chain.RequestHandler(h.handleRender))
r.POST("/api/ticket", chain.RequestHandler(h.handleSend)) r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
r.POST("/ticket", chain.RequestHandler(h.handleExchange)) r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
} }
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
@ -120,22 +120,14 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
var err error var err error
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil { if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(&domain.Error{ encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
Code: "unauthorized_client",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return return
} }
if err = h.tickets.Generate(ctx, ticket); err != nil { if err = h.tickets.Generate(ctx, ticket); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(&domain.Error{ encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
Code: "unauthorized_client",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return return
} }
@ -143,7 +135,7 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
ctx.SetStatusCode(http.StatusOK) ctx.SetStatusCode(http.StatusOK)
} }
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK) ctx.SetStatusCode(http.StatusOK)
@ -157,18 +149,14 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
return return
} }
token, err := h.tickets.Exchange(ctx, &domain.Ticket{ token, err := h.tickets.Redeem(ctx, &domain.Ticket{
Ticket: req.Ticket, Ticket: req.Ticket,
Resource: req.Resource, Resource: req.Resource,
Subject: req.Subject, Subject: req.Subject,
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.Error{ encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return return
} }
@ -184,71 +172,74 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
} }
func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) { func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
return domain.Error{ if errors.As(err, indieAuthError) {
Code: "invalid_request", return indieAuthError
Description: err.Error(),
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1),
} }
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
} }
if req.Resource == nil { if req.Resource == nil {
return domain.Error{ return domain.NewError(
Code: "invalid_request", domain.ErrorCodeInvalidRequest,
Description: "resource value MUST be set", "resource value MUST be set",
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1), )
}
} }
if req.Subject == nil { if req.Subject == nil {
return domain.Error{ return domain.NewError(
Code: "invalid_request", domain.ErrorCodeInvalidRequest,
Description: "subject value MUST be set", "subject value MUST be set",
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1), )
}
} }
return nil return nil
} }
func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) { func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
return domain.Error{ if errors.As(err, indieAuthError) {
Code: "invalid_request", return indieAuthError
Description: err.Error(),
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1),
} }
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
} }
if req.Ticket == "" { if req.Ticket == "" {
return domain.Error{ return domain.NewError(
Code: "invalid_request", domain.ErrorCodeInvalidRequest,
Description: "ticket parameter is required", "ticket parameter is required",
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1), )
}
} }
if req.Resource == nil { if req.Resource == nil {
return domain.Error{ return domain.NewError(
Code: "invalid_request", domain.ErrorCodeInvalidRequest,
Description: "resource value MUST be set", "resource parameter is required",
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1), )
}
} }
if req.Subject == nil { if req.Subject == nil {
return domain.Error{ return domain.NewError(
Code: "invalid_request", domain.ErrorCodeInvalidRequest,
Description: "subject value MUST be set", "subject parameter is required",
URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
Frame: xerrors.Caller(1), )
}
} }
return nil return nil

View File

@ -48,7 +48,7 @@ func TestUpdate(t *testing.T) {
r := router.New() r := router.New()
delivery.NewRequestHandler( delivery.NewRequestHandler(
ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient), ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config),
language.NewMatcher(message.DefaultCatalog.Languages()), config, language.NewMatcher(message.DefaultCatalog.Languages()), config,
).Register(r) ).Register(r)

View File

@ -2,7 +2,6 @@ package ticket
import ( import (
"context" "context"
"errors"
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
) )
@ -13,4 +12,4 @@ type Repository interface {
GC() GC()
} }
var ErrNotExist = errors.New("token_endpoint not found on resource URL") var ErrNotExist error = domain.NewError(domain.ErrorCodeInvalidRequest, "ticket not exist or expired", "")

View File

@ -2,6 +2,7 @@ package memory
import ( import (
"context" "context"
"fmt"
"path" "path"
"sync" "sync"
"time" "time"
@ -43,12 +44,12 @@ func (repo *memoryTicketRepository) Create(_ context.Context, t *domain.Ticket)
func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) { func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) {
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t)) src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t))
if !ok { if !ok {
return nil, ticket.ErrNotExist return nil, fmt.Errorf("cannot find ticket in store: %w", ticket.ErrNotExist)
} }
result, ok := src.(*Ticket) result, ok := src.(*Ticket)
if !ok { if !ok {
return nil, ticket.ErrNotExist return nil, fmt.Errorf("cannot decode ticket in store: %w", ticket.ErrNotExist)
} }
return result.Ticket, nil return result.Ticket, nil

View File

@ -11,7 +11,6 @@ import (
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/ticket" "source.toby3d.me/website/indieauth/internal/ticket"
"source.toby3d.me/website/indieauth/internal/token"
) )
type ( type (
@ -62,9 +61,7 @@ func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticke
return nil return nil
} }
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error) { func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string) (*domain.Ticket, error) {
t := new(Ticket)
tx, err := repo.db.Beginx() tx, err := repo.db.Beginx()
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
@ -72,17 +69,18 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket st
return nil, fmt.Errorf("failed to begin transaction: %w", err) return nil, fmt.Errorf("failed to begin transaction: %w", err)
} }
if err = tx.GetContext(ctx, t, QueryTable+QueryGet, ticket); err != nil { tkt := new(Ticket)
if err = tx.GetContext(ctx, tkt, QueryTable+QueryGet, t); err != nil {
defer tx.Rollback() defer tx.Rollback()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, token.ErrNotExist return nil, ticket.ErrNotExist
} }
return nil, fmt.Errorf("cannot find ticket in db: %w", err) return nil, fmt.Errorf("cannot find ticket in db: %w", err)
} }
if _, err = tx.ExecContext(ctx, QueryDelete, ticket); err != nil { if _, err = tx.ExecContext(ctx, QueryDelete, t); err != nil {
tx.Rollback() tx.Rollback()
return nil, fmt.Errorf("cannot remove ticket from db: %w", err) return nil, fmt.Errorf("cannot remove ticket from db: %w", err)
@ -93,7 +91,8 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket st
} }
result := new(domain.Ticket) result := new(domain.Ticket)
t.Populate(result)
tkt.Populate(result)
return result, nil return result, nil
} }

View File

@ -9,6 +9,17 @@ import (
type UseCase interface { type UseCase interface {
Generate(ctx context.Context, ticket *domain.Ticket) error Generate(ctx context.Context, ticket *domain.Ticket) error
// Exchange transform received ticket into access token. // Redeem transform received ticket into access token.
Exchange(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error)
Exchange(ctx context.Context, ticket string) (*domain.Token, error)
} }
var (
ErrTicketEndpointNotExist error = domain.NewError(
domain.ErrorCodeServerError, "ticket_endpoint not found on ticket resource", "",
)
ErrTokenEndpointNotExist error = domain.NewError(
domain.ErrorCodeServerError, "token_endpoint not found on ticket resource", "",
)
)

View File

@ -23,23 +23,25 @@ type (
} }
ticketUseCase struct { ticketUseCase struct {
config *domain.Config
client *http.Client client *http.Client
tickets ticket.Repository tickets ticket.Repository
} }
) )
func NewTicketUseCase(tickets ticket.Repository, client *http.Client) ticket.UseCase { func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *domain.Config) ticket.UseCase {
return &ticketUseCase{ return &ticketUseCase{
client: client, client: client,
tickets: tickets, tickets: tickets,
config: config,
} }
} }
func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticket) error { func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) error {
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
req.SetRequestURIBytes(ticket.Subject.RequestURI()) req.SetRequestURI(t.Subject.String())
resp := http.AcquireResponse() resp := http.AcquireResponse()
defer http.ReleaseResponse(resp) defer http.ReleaseResponse(resp)
@ -54,16 +56,16 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke
if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil {
ticketEndpoint = metadata.TicketEndpoint ticketEndpoint = metadata.TicketEndpoint
} else { // NOTE(toby3d): fallback to old links searching } else { // NOTE(toby3d): fallback to old links searching
if endpoints := util.ExtractEndpoints(resp, "ticket_endpoint"); endpoints != nil && len(endpoints) > 0 { if endpoints := util.ExtractEndpoints(resp, "ticket_endpoint"); len(endpoints) > 0 {
ticketEndpoint = endpoints[len(endpoints)-1] ticketEndpoint = endpoints[len(endpoints)-1]
} }
} }
if ticketEndpoint == nil { if ticketEndpoint == nil {
return fmt.Errorf("cannot discovery ticket_endpoint on ticket resource") return ticket.ErrTicketEndpointNotExist
} }
if err := useCase.tickets.Create(ctx, ticket); err != nil { if err := useCase.tickets.Create(ctx, t); err != nil {
return fmt.Errorf("cannot save ticket in store: %w", err) return fmt.Errorf("cannot save ticket in store: %w", err)
} }
@ -71,9 +73,9 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke
req.Header.SetMethod(http.MethodPost) req.Header.SetMethod(http.MethodPost)
req.SetRequestURIBytes(ticketEndpoint.RequestURI()) req.SetRequestURIBytes(ticketEndpoint.RequestURI())
req.Header.SetContentType(common.MIMEApplicationForm) req.Header.SetContentType(common.MIMEApplicationForm)
req.PostArgs().Set("ticket", ticket.Ticket) req.PostArgs().Set("ticket", t.Ticket)
req.PostArgs().Set("subject", ticket.Subject.String()) req.PostArgs().Set("subject", t.Subject.String())
req.PostArgs().Set("resource", ticket.Resource.String()) req.PostArgs().Set("resource", t.Resource.String())
resp.Reset() resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { if err := useCase.client.Do(req, resp); err != nil {
@ -83,10 +85,10 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke
return nil return nil
} }
func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) { func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*domain.Token, error) {
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.SetRequestURI(ticket.Resource.String()) req.SetRequestURI(t.Resource.String())
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse() resp := http.AcquireResponse()
@ -102,13 +104,13 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke
if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil {
tokenEndpoint = metadata.TokenEndpoint tokenEndpoint = metadata.TokenEndpoint
} else { // NOTE(toby3d): fallback to old links searching } else { // NOTE(toby3d): fallback to old links searching
if endpoints := util.ExtractEndpoints(resp, "token_endpoint"); endpoints != nil && len(endpoints) > 0 { if endpoints := util.ExtractEndpoints(resp, "token_endpoint"); len(endpoints) > 0 {
tokenEndpoint = endpoints[len(endpoints)-1] tokenEndpoint = endpoints[len(endpoints)-1]
} }
} }
if tokenEndpoint == nil { if tokenEndpoint == nil {
return nil, fmt.Errorf("cannot discovery token_endpoint on ticket resource") return nil, ticket.ErrTokenEndpointNotExist
} }
req.Reset() req.Reset()
@ -117,7 +119,7 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke
req.Header.SetContentType(common.MIMEApplicationForm) req.Header.SetContentType(common.MIMEApplicationForm)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String()) req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
req.PostArgs().Set("ticket", ticket.Ticket) req.PostArgs().Set("ticket", t.Ticket)
resp.Reset() resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { if err := useCase.client.Do(req, resp); err != nil {
@ -138,3 +140,25 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke
Scope: data.Scope, Scope: data.Scope,
}, nil }, nil
} }
func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) {
t, err := useCase.tickets.GetAndDelete(ctx, ticket)
if err != nil {
return nil, fmt.Errorf("cannot find provided ticket: %w", err)
}
token, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm,
Expiration: useCase.config.JWT.Expiry,
// TODO(toby3d): Issuer: &domain.ClientID{},
NonceLength: useCase.config.JWT.NonceLength,
Scope: domain.Scopes{domain.ScopeRead},
Secret: []byte(useCase.config.JWT.Secret),
Subject: t.Subject,
})
if err != nil {
return nil, fmt.Errorf("cannot generate a new access token: %w", err)
}
return token, nil
}

View File

@ -16,7 +16,7 @@ import (
ucase "source.toby3d.me/website/indieauth/internal/ticket/usecase" ucase "source.toby3d.me/website/indieauth/internal/ticket/usecase"
) )
func TestExchange(t *testing.T) { func TestRedeem(t *testing.T) {
t.Parallel() t.Parallel()
token := domain.TestToken(t) token := domain.TestToken(t)
@ -39,7 +39,8 @@ func TestExchange(t *testing.T) {
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
result, err := ucase.NewTicketUseCase(nil, client).Exchange(context.Background(), ticket) result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)).
Redeem(context.Background(), ticket)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, token.AccessToken, result.AccessToken) assert.Equal(t, token.AccessToken, result.AccessToken)
assert.Equal(t, token.Me.String(), result.Me.String()) assert.Equal(t, token.Me.String(), result.Me.String())