From ed55c8cded2a56566000d8edc3a93d5c50620baf Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Sun, 30 Jan 2022 01:30:37 +0500 Subject: [PATCH] :recycle: Simplify error usage in ticket package --- internal/ticket/delivery/http/ticket_http.go | 109 ++++++++---------- .../ticket/delivery/http/ticket_http_test.go | 2 +- internal/ticket/repository.go | 3 +- .../ticket/repository/memory/memory_ticket.go | 5 +- .../repository/sqlite3/sqlite3_ticket.go | 15 ++- internal/ticket/usecase.go | 15 ++- internal/ticket/usecase/ticket_ucase.go | 52 ++++++--- internal/ticket/usecase/ticket_ucase_test.go | 5 +- 8 files changed, 116 insertions(+), 90 deletions(-) diff --git a/internal/ticket/delivery/http/ticket_http.go b/internal/ticket/delivery/http/ticket_http.go index 01fcd2f..754e6bb 100644 --- a/internal/ticket/delivery/http/ticket_http.go +++ b/internal/ticket/delivery/http/ticket_http.go @@ -1,6 +1,7 @@ package http import ( + "errors" "fmt" "path" @@ -9,7 +10,6 @@ import ( http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" - "golang.org/x/xerrors" "source.toby3d.me/toby3d/form" "source.toby3d.me/toby3d/middleware" @@ -75,7 +75,7 @@ func (h *RequestHandler) Register(r *router.Router) { // TODO(toby3d): secure this via JWT middleware r.GET("/ticket", chain.RequestHandler(h.handleRender)) r.POST("/api/ticket", chain.RequestHandler(h.handleSend)) - r.POST("/ticket", chain.RequestHandler(h.handleExchange)) + r.POST("/ticket", chain.RequestHandler(h.handleRedeem)) } func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { @@ -120,22 +120,14 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) { var err error if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil { ctx.SetStatusCode(http.StatusInternalServerError) - encoder.Encode(&domain.Error{ - Code: "unauthorized_client", - Description: err.Error(), - Frame: xerrors.Caller(1), - }) + encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) return } if err = h.tickets.Generate(ctx, ticket); err != nil { ctx.SetStatusCode(http.StatusInternalServerError) - encoder.Encode(&domain.Error{ - Code: "unauthorized_client", - Description: err.Error(), - Frame: xerrors.Caller(1), - }) + encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) return } @@ -143,7 +135,7 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) { ctx.SetStatusCode(http.StatusOK) } -func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { +func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) { ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) ctx.SetStatusCode(http.StatusOK) @@ -157,18 +149,14 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { return } - token, err := h.tickets.Exchange(ctx, &domain.Ticket{ + token, err := h.tickets.Redeem(ctx, &domain.Ticket{ Ticket: req.Ticket, Resource: req.Resource, Subject: req.Subject, }) if err != nil { ctx.SetStatusCode(http.StatusBadRequest) - encoder.Encode(domain.Error{ - Code: "invalid_request", - Description: err.Error(), - Frame: xerrors.Caller(1), - }) + encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) return } @@ -184,71 +172,74 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { } func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) { + indieAuthError := new(domain.Error) if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { - return domain.Error{ - Code: "invalid_request", - Description: err.Error(), - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), + if errors.As(err, indieAuthError) { + return indieAuthError } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } if req.Resource == nil { - return domain.Error{ - Code: "invalid_request", - Description: "resource value MUST be set", - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), - } + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "resource value MUST be set", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } if req.Subject == nil { - return domain.Error{ - Code: "invalid_request", - Description: "subject value MUST be set", - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), - } + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "subject value MUST be set", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } return nil } func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) { + indieAuthError := new(domain.Error) if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { - return domain.Error{ - Code: "invalid_request", - Description: err.Error(), - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), + if errors.As(err, indieAuthError) { + return indieAuthError } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } if req.Ticket == "" { - return domain.Error{ - Code: "invalid_request", - Description: "ticket parameter is required", - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), - } + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "ticket parameter is required", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } if req.Resource == nil { - return domain.Error{ - Code: "invalid_request", - Description: "resource value MUST be set", - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), - } + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "resource parameter is required", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } if req.Subject == nil { - return domain.Error{ - Code: "invalid_request", - Description: "subject value MUST be set", - URI: "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - Frame: xerrors.Caller(1), - } + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "subject parameter is required", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) } return nil diff --git a/internal/ticket/delivery/http/ticket_http_test.go b/internal/ticket/delivery/http/ticket_http_test.go index a60e0dc..9699d5c 100644 --- a/internal/ticket/delivery/http/ticket_http_test.go +++ b/internal/ticket/delivery/http/ticket_http_test.go @@ -48,7 +48,7 @@ func TestUpdate(t *testing.T) { r := router.New() delivery.NewRequestHandler( - ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient), + ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config), language.NewMatcher(message.DefaultCatalog.Languages()), config, ).Register(r) diff --git a/internal/ticket/repository.go b/internal/ticket/repository.go index 7d5bb86..83133ba 100644 --- a/internal/ticket/repository.go +++ b/internal/ticket/repository.go @@ -2,7 +2,6 @@ package ticket import ( "context" - "errors" "source.toby3d.me/website/indieauth/internal/domain" ) @@ -13,4 +12,4 @@ type Repository interface { GC() } -var ErrNotExist = errors.New("token_endpoint not found on resource URL") +var ErrNotExist error = domain.NewError(domain.ErrorCodeInvalidRequest, "ticket not exist or expired", "") diff --git a/internal/ticket/repository/memory/memory_ticket.go b/internal/ticket/repository/memory/memory_ticket.go index 265478e..d433a23 100644 --- a/internal/ticket/repository/memory/memory_ticket.go +++ b/internal/ticket/repository/memory/memory_ticket.go @@ -2,6 +2,7 @@ package memory import ( "context" + "fmt" "path" "sync" "time" @@ -43,12 +44,12 @@ func (repo *memoryTicketRepository) Create(_ context.Context, t *domain.Ticket) func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) { src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t)) if !ok { - return nil, ticket.ErrNotExist + return nil, fmt.Errorf("cannot find ticket in store: %w", ticket.ErrNotExist) } result, ok := src.(*Ticket) if !ok { - return nil, ticket.ErrNotExist + return nil, fmt.Errorf("cannot decode ticket in store: %w", ticket.ErrNotExist) } return result.Ticket, nil diff --git a/internal/ticket/repository/sqlite3/sqlite3_ticket.go b/internal/ticket/repository/sqlite3/sqlite3_ticket.go index 1b0a198..d4e2c78 100644 --- a/internal/ticket/repository/sqlite3/sqlite3_ticket.go +++ b/internal/ticket/repository/sqlite3/sqlite3_ticket.go @@ -11,7 +11,6 @@ import ( "source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/ticket" - "source.toby3d.me/website/indieauth/internal/token" ) type ( @@ -62,9 +61,7 @@ func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticke return nil } -func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error) { - t := new(Ticket) - +func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string) (*domain.Ticket, error) { tx, err := repo.db.Beginx() if err != nil { tx.Rollback() @@ -72,17 +69,18 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket st return nil, fmt.Errorf("failed to begin transaction: %w", err) } - if err = tx.GetContext(ctx, t, QueryTable+QueryGet, ticket); err != nil { + tkt := new(Ticket) + if err = tx.GetContext(ctx, tkt, QueryTable+QueryGet, t); err != nil { defer tx.Rollback() if errors.Is(err, sql.ErrNoRows) { - return nil, token.ErrNotExist + return nil, ticket.ErrNotExist } return nil, fmt.Errorf("cannot find ticket in db: %w", err) } - if _, err = tx.ExecContext(ctx, QueryDelete, ticket); err != nil { + if _, err = tx.ExecContext(ctx, QueryDelete, t); err != nil { tx.Rollback() return nil, fmt.Errorf("cannot remove ticket from db: %w", err) @@ -93,7 +91,8 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, ticket st } result := new(domain.Ticket) - t.Populate(result) + + tkt.Populate(result) return result, nil } diff --git a/internal/ticket/usecase.go b/internal/ticket/usecase.go index 5e9589f..019639f 100644 --- a/internal/ticket/usecase.go +++ b/internal/ticket/usecase.go @@ -9,6 +9,17 @@ import ( type UseCase interface { Generate(ctx context.Context, ticket *domain.Ticket) error - // Exchange transform received ticket into access token. - Exchange(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) + // Redeem transform received ticket into access token. + Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) + + Exchange(ctx context.Context, ticket string) (*domain.Token, error) } + +var ( + ErrTicketEndpointNotExist error = domain.NewError( + domain.ErrorCodeServerError, "ticket_endpoint not found on ticket resource", "", + ) + ErrTokenEndpointNotExist error = domain.NewError( + domain.ErrorCodeServerError, "token_endpoint not found on ticket resource", "", + ) +) diff --git a/internal/ticket/usecase/ticket_ucase.go b/internal/ticket/usecase/ticket_ucase.go index 0d2e658..bad0d3f 100644 --- a/internal/ticket/usecase/ticket_ucase.go +++ b/internal/ticket/usecase/ticket_ucase.go @@ -23,23 +23,25 @@ type ( } ticketUseCase struct { + config *domain.Config client *http.Client tickets ticket.Repository } ) -func NewTicketUseCase(tickets ticket.Repository, client *http.Client) ticket.UseCase { +func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *domain.Config) ticket.UseCase { return &ticketUseCase{ client: client, tickets: tickets, + config: config, } } -func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticket) error { +func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) error { req := http.AcquireRequest() defer http.ReleaseRequest(req) req.Header.SetMethod(http.MethodGet) - req.SetRequestURIBytes(ticket.Subject.RequestURI()) + req.SetRequestURI(t.Subject.String()) resp := http.AcquireResponse() defer http.ReleaseResponse(resp) @@ -54,16 +56,16 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { ticketEndpoint = metadata.TicketEndpoint } else { // NOTE(toby3d): fallback to old links searching - if endpoints := util.ExtractEndpoints(resp, "ticket_endpoint"); endpoints != nil && len(endpoints) > 0 { + if endpoints := util.ExtractEndpoints(resp, "ticket_endpoint"); len(endpoints) > 0 { ticketEndpoint = endpoints[len(endpoints)-1] } } if ticketEndpoint == nil { - return fmt.Errorf("cannot discovery ticket_endpoint on ticket resource") + return ticket.ErrTicketEndpointNotExist } - if err := useCase.tickets.Create(ctx, ticket); err != nil { + if err := useCase.tickets.Create(ctx, t); err != nil { return fmt.Errorf("cannot save ticket in store: %w", err) } @@ -71,9 +73,9 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke req.Header.SetMethod(http.MethodPost) req.SetRequestURIBytes(ticketEndpoint.RequestURI()) req.Header.SetContentType(common.MIMEApplicationForm) - req.PostArgs().Set("ticket", ticket.Ticket) - req.PostArgs().Set("subject", ticket.Subject.String()) - req.PostArgs().Set("resource", ticket.Resource.String()) + req.PostArgs().Set("ticket", t.Ticket) + req.PostArgs().Set("subject", t.Subject.String()) + req.PostArgs().Set("resource", t.Resource.String()) resp.Reset() if err := useCase.client.Do(req, resp); err != nil { @@ -83,10 +85,10 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, ticket *domain.Ticke return nil } -func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) { +func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*domain.Token, error) { req := http.AcquireRequest() defer http.ReleaseRequest(req) - req.SetRequestURI(ticket.Resource.String()) + req.SetRequestURI(t.Resource.String()) req.Header.SetMethod(http.MethodGet) resp := http.AcquireResponse() @@ -102,13 +104,13 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke if metadata, err := util.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { tokenEndpoint = metadata.TokenEndpoint } else { // NOTE(toby3d): fallback to old links searching - if endpoints := util.ExtractEndpoints(resp, "token_endpoint"); endpoints != nil && len(endpoints) > 0 { + if endpoints := util.ExtractEndpoints(resp, "token_endpoint"); len(endpoints) > 0 { tokenEndpoint = endpoints[len(endpoints)-1] } } if tokenEndpoint == nil { - return nil, fmt.Errorf("cannot discovery token_endpoint on ticket resource") + return nil, ticket.ErrTokenEndpointNotExist } req.Reset() @@ -117,7 +119,7 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke req.Header.SetContentType(common.MIMEApplicationForm) req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String()) - req.PostArgs().Set("ticket", ticket.Ticket) + req.PostArgs().Set("ticket", t.Ticket) resp.Reset() if err := useCase.client.Do(req, resp); err != nil { @@ -138,3 +140,25 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket *domain.Ticke Scope: data.Scope, }, nil } + +func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) { + t, err := useCase.tickets.GetAndDelete(ctx, ticket) + if err != nil { + return nil, fmt.Errorf("cannot find provided ticket: %w", err) + } + + token, err := domain.NewToken(domain.NewTokenOptions{ + Algorithm: useCase.config.JWT.Algorithm, + Expiration: useCase.config.JWT.Expiry, + // TODO(toby3d): Issuer: &domain.ClientID{}, + NonceLength: useCase.config.JWT.NonceLength, + Scope: domain.Scopes{domain.ScopeRead}, + Secret: []byte(useCase.config.JWT.Secret), + Subject: t.Subject, + }) + if err != nil { + return nil, fmt.Errorf("cannot generate a new access token: %w", err) + } + + return token, nil +} diff --git a/internal/ticket/usecase/ticket_ucase_test.go b/internal/ticket/usecase/ticket_ucase_test.go index 580ecc0..c6a0e39 100644 --- a/internal/ticket/usecase/ticket_ucase_test.go +++ b/internal/ticket/usecase/ticket_ucase_test.go @@ -16,7 +16,7 @@ import ( ucase "source.toby3d.me/website/indieauth/internal/ticket/usecase" ) -func TestExchange(t *testing.T) { +func TestRedeem(t *testing.T) { t.Parallel() token := domain.TestToken(t) @@ -39,7 +39,8 @@ func TestExchange(t *testing.T) { client, _, cleanup := httptest.New(t, r.Handler) t.Cleanup(cleanup) - result, err := ucase.NewTicketUseCase(nil, client).Exchange(context.Background(), ticket) + result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)). + Redeem(context.Background(), ticket) require.NoError(t, err) assert.Equal(t, token.AccessToken, result.AccessToken) assert.Equal(t, token.Me.String(), result.Me.String())