From 83dc4286eb04a69488a9ada0d7a41126b61f1826 Mon Sep 17 00:00:00 2001 From: Maxim Lebedev Date: Fri, 14 Jan 2022 01:49:41 +0500 Subject: [PATCH] :recycle: Refactored auth package --- internal/auth/delivery/http/auth_http.go | 613 +++++++++--------- internal/auth/delivery/http/auth_http_test.go | 81 +++ internal/auth/repository.go | 13 - internal/auth/repository/bolt/bolt_auth.go | 57 -- .../auth/repository/memory/memory_auth.go | 40 -- internal/auth/usecase/auth_ucase.go | 84 +++ internal/auth/usecase/auth_usecase.go | 137 ---- 7 files changed, 477 insertions(+), 548 deletions(-) create mode 100644 internal/auth/delivery/http/auth_http_test.go delete mode 100644 internal/auth/repository.go delete mode 100644 internal/auth/repository/bolt/bolt_auth.go delete mode 100644 internal/auth/repository/memory/memory_auth.go create mode 100644 internal/auth/usecase/auth_ucase.go delete mode 100644 internal/auth/usecase/auth_usecase.go diff --git a/internal/auth/delivery/http/auth_http.go b/internal/auth/delivery/http/auth_http.go index 7b47b3d..a060376 100644 --- a/internal/auth/delivery/http/auth_http.go +++ b/internal/auth/delivery/http/auth_http.go @@ -1,348 +1,359 @@ package http import ( - "net/url" - "time" + "fmt" + "path" + "strings" "github.com/fasthttp/router" json "github.com/goccy/go-json" http "github.com/valyala/fasthttp" - "gitlab.com/toby3d/indieauth/internal/auth" - "gitlab.com/toby3d/indieauth/internal/domain" - "gitlab.com/toby3d/indieauth/internal/middleware" - "gitlab.com/toby3d/indieauth/internal/pkce" - "gitlab.com/toby3d/indieauth/web" + "golang.org/x/text/language" + "golang.org/x/text/message" + "golang.org/x/xerrors" + + "source.toby3d.me/toby3d/form" + "source.toby3d.me/toby3d/middleware" + "source.toby3d.me/website/indieauth/internal/auth" + "source.toby3d.me/website/indieauth/internal/client" + "source.toby3d.me/website/indieauth/internal/common" + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/web" ) type ( - Handler struct { - useCase auth.UseCase - } - AuthorizeRequest struct { - RedirectURI string - ResponseType string - ClientID string - State []byte - Scope string - CodeChallenge string - CodeChallengeMethod string - Me string + // Indicates to the authorization server that an authorization + // code should be returned as the response. + ResponseType domain.ResponseType `form:"response_type"` // code + + // The client URL. + ClientID *domain.ClientID `form:"client_id"` + + // The redirect URL indicating where the user should be + // redirected to after approving the request. + RedirectURI *domain.URL `form:"redirect_uri"` + + // A parameter set by the client which will be included when the + // user is redirected back to the client. This is used to + // prevent CSRF attacks. The authorization server MUST return + // the unmodified state value back to the client. + State string `form:"state"` + + // The code challenge as previously described. + CodeChallenge string `form:"code_challenge"` + + // The hashing method used to calculate the code challenge. + CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"` + + // A space-separated list of scopes the client is requesting, + // e.g. "profile", or "profile create". If the client omits this + // value, the authorization server MUST NOT issue an access + // token for this authorization code. Only the user's profile + // URL may be returned without any scope requested. + Scope domain.Scopes `form:"scope"` + + // The URL that the user entered. + Me *domain.Me `form:"me"` } - RedirectRequest struct { - Authorize string - ClientID string - CodeChallenge string - CodeChallengeMethod string - Me string - RedirectURI string - ResponseType string - Scope string - State []byte + VerifyRequest struct { + ClientID *domain.ClientID `form:"client_id"` + Me *domain.Me `form:"me"` + RedirectURI *domain.URL `form:"redirect_uri"` + CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"` + ResponseType domain.ResponseType `form:"response_type"` + Scope domain.Scopes `form:"scope[]"` // TODO(toby3d): fix parsing in form pkg + Authorize string `form:"authorize"` + CodeChallenge string `form:"code_challenge"` + State string `form:"state"` } ExchangeRequest struct { - GrantType string - Code string - ClientID string - RedirectURI string - CodeVerifier string + GrantType domain.GrantType `form:"grant_type"` // authorization_code + + // The authorization code received from the authorization + // endpoint in the redirect. + Code string `form:"code"` + + // The client's URL, which MUST match the client_id used in the + // authentication request. + ClientID *domain.ClientID `form:"client_id"` + + // The client's redirect URL, which MUST match the initial + // authentication request. + RedirectURI *domain.URL `form:"redirect_uri"` + + // The original plaintext random string generated before + // starting the authorization request. + CodeVerifier string `form:"code_verifier"` } ExchangeResponse struct { - Me string `json:"me"` + Me *domain.Me `json:"me"` + } + + NewRequestHandlerOptions struct { + Auth auth.UseCase + Clients client.UseCase + Config *domain.Config + Matcher language.Matcher + } + + RequestHandler struct { + clients client.UseCase + config *domain.Config + matcher language.Matcher + useCase auth.UseCase } ) -func NewAuthHandler(useCase auth.UseCase) *Handler { - return &Handler{ - useCase: useCase, +func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { + return &RequestHandler{ + clients: opts.Clients, + config: opts.Config, + matcher: opts.Matcher, + useCase: opts.Auth, } } -func (h *Handler) Register(r *router.Router) { - chain := middleware.Chain{middleware.CSRFWithConfig(middleware.CSRFConfig{ - ContextKey: "csrf", - CookieHTTPOnly: true, - CookieName: "__Host-CSRF", - CookiePath: "/", - CookieSameSite: http.CookieSameSiteLaxMode, - CookieSecure: true, - TokenLookup: "form:_csrf", - Skipper: func(ctx *http.RequestCtx) bool { - return ctx.IsPost() && ctx.PostArgs().Has("grant_type") && - string(ctx.PostArgs().Peek("grant_type")) == "authorization_code" - }, - })} +func (h *RequestHandler) Register(r *router.Router) { + chain := middleware.Chain{ + middleware.CSRFWithConfig(middleware.CSRFConfig{ + Skipper: func(ctx *http.RequestCtx) bool { + matched, _ := path.Match("/api/*", string(ctx.Path())) - r.GET("/authorize", chain.RequestHandler(h.ClientInfo)) - r.POST("/authorize", chain.RequestHandler(h.Update)) + return ctx.IsPost() && matched + }, + CookieSameSite: http.CookieSameSiteLaxMode, + CookieName: "_csrf", + TokenLookup: "form:_csrf", + CookieSecure: true, + CookieHTTPOnly: true, + }), + middleware.LogFmt(), + } + + r.GET("/authorize", chain.RequestHandler(h.handleRender)) + r.POST("/api/authorize", chain.RequestHandler(h.handleVerify)) + r.POST("/authorize", chain.RequestHandler(h.handleExchange)) } -func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error { - if r.ClientID = string(ctx.QueryArgs().Peek("client_id")); r.ClientID == "" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'client_id' query is required", - } - } - - if r.ResponseType = string(ctx.QueryArgs().Peek("response_type")); r.ResponseType != "code" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'response_type' must be 'code'", - } - } - - if ctx.QueryArgs().Has("code_challenge") { - r.CodeChallenge = string(ctx.QueryArgs().Peek("code_challenge")) - if len(r.CodeChallenge) < 43 || len(r.CodeChallenge) > 128 { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "length of the 'code_challenge' value must be greater than 43 and less than 128 symbols", - } - } - - r.CodeChallengeMethod = pkce.DefaultMethod - if ctx.PostArgs().Has("code_challenge_method") { - r.CodeChallengeMethod = string(ctx.QueryArgs().Peek("code_challenge_method")) - } - - if _, err := pkce.New(r.CodeChallengeMethod); err != nil { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: err.Error(), - } - } - } - - r.RedirectURI = string(ctx.QueryArgs().Peek("redirect_uri")) - r.State = ctx.QueryArgs().Peek("state") - r.Scope = string(ctx.QueryArgs().Peek("scope")) - r.Me = string(ctx.QueryArgs().Peek("me")) - - return nil -} - -func (h *Handler) ClientInfo(ctx *http.RequestCtx) { - r := new(AuthorizeRequest) - r.Scope = "profile" - - if err := r.bind(ctx); err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) - - return - } - - client, err := h.useCase.Discovery(ctx, r.ClientID) - if err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) - - return - } - - csrf, _ := ctx.UserValue("csrf").([]byte) - - ctx.SetContentType("text/html") - web.WritePageTemplate(ctx, &web.AuthPage{ - Client: client, - CodeChallenge: r.CodeChallenge, - CodeChallengeMethod: r.CodeChallengeMethod, - CSRF: csrf, - Me: r.Me, - RedirectURI: r.RedirectURI, - ResponseType: r.ResponseType, - Scope: r.Scope, - State: r.State, - }) -} - -func (h *Handler) Update(ctx *http.RequestCtx) { - if ctx.PostArgs().Has("response_type") && string(ctx.PostArgs().Peek("response_type")) == "code" { - h.Redirect(ctx) - - return - } - - if ctx.PostArgs().Has("grant_type") && string(ctx.PostArgs().Peek("grant_type")) == "authorization_code" { - h.Exchange(ctx) - - return - } - - ctx.Error("please, restart your authoriztion flow", http.StatusBadRequest) -} - -func (r *RedirectRequest) bind(ctx *http.RequestCtx) (err error) { - r.RedirectURI = string(ctx.PostArgs().Peek("redirect_uri")) - - if r.ClientID = string(ctx.PostArgs().Peek("client_id")); r.ClientID == "" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'client_id' query is required", - } - } - - if r.Authorize = string(ctx.PostArgs().Peek("authorize")); r.Authorize != "allow" && r.Authorize != "deny" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "invalid prompt action, try starting the authorization flow again", - } - } - - if r.ResponseType = string(ctx.PostArgs().Peek("response_type")); r.ResponseType != "code" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'response_type' must be 'code', try starting the authorization flow again", - } - } - - if ctx.PostArgs().Has("code_challenge") { - r.CodeChallenge = string(ctx.PostArgs().Peek("code_challenge")) - - if len(r.CodeChallenge) < 43 || len(r.CodeChallenge) > 128 { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "length of the 'code_challenge' value must be greater than 43 and less than 128 symbols, try starting the authorization flow again", - } - } - - r.CodeChallengeMethod = pkce.DefaultMethod - if ctx.PostArgs().Has("code_challenge_method") { - r.CodeChallengeMethod = string(ctx.PostArgs().Peek("code_challenge_method")) - } - - _, err := pkce.New(r.CodeChallengeMethod) - if err != nil { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: err.Error(), - } - } - } - - r.State = ctx.PostArgs().Peek("state") - r.Scope = string(ctx.PostArgs().Peek("scope")) - r.Me = string(ctx.PostArgs().Peek("me")) - - return nil -} - -func (h *Handler) Redirect(ctx *http.RequestCtx) { - r := new(RedirectRequest) - if err := r.bind(ctx); err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) - - return - } - - redirectUri, err := url.Parse(r.RedirectURI) - if err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) - - return - } - - query := redirectUri.Query() - query.Set("state", string(r.State)) - - switch r.Authorize { - case "allow": - code, err := h.useCase.Approve(ctx, &domain.Login{ - CreatedAt: time.Now().UTC().Unix(), - ClientID: r.ClientID, - CodeChallenge: r.CodeChallenge, - CodeChallengeMethod: r.CodeChallengeMethod, - Me: r.Me, - RedirectURI: r.RedirectURI, - Scope: r.Scope, - }) - if err != nil { - query.Set("error", domain.ErrServerError.Code) - query.Set("error_description", err.Error()) - - redirectUri.RawQuery = query.Encode() - - ctx.Redirect(redirectUri.String(), http.StatusFound) - - return - } - - query.Set("code", code) - case "deny": - query.Set("error", domain.ErrAccessDenied.Code) - } - - redirectUri.RawQuery = query.Encode() - - ctx.Redirect(redirectUri.String(), http.StatusFound) -} - -func (r *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) { - if r.GrantType = string(ctx.PostArgs().Peek("grant_type")); r.GrantType != "authorization_code" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'grant_type' must be 'authorization_code'", - } - } - - if r.RedirectURI = string(ctx.PostArgs().Peek("redirect_uri")); r.RedirectURI == "" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'redirect_uri' query is required", - } - } - - if r.ClientID = string(ctx.PostArgs().Peek("client_id")); r.ClientID == "" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'client_id' query is required", - } - } - - if r.Code = string(ctx.PostArgs().Peek("code")); r.Code == "" { - return domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'code' query is required", - } - } - - r.CodeVerifier = string(ctx.PostArgs().Peek("code_verifier")) - - return nil -} - -func (h *Handler) Exchange(ctx *http.RequestCtx) { - encoder := json.NewEncoder(ctx) - req := new(ExchangeRequest) - +func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { + req := new(AuthorizeRequest) if err := req.bind(ctx); err != nil { ctx.Error(err.Error(), http.StatusBadRequest) return } - me, err := h.useCase.Exchange(ctx, &domain.ExchangeRequest{ - ClientID: req.ClientID, - Code: req.Code, - CodeVerifier: req.CodeVerifier, - RedirectURI: req.RedirectURI, - }) + client, err := h.clients.Discovery(ctx, req.ClientID) if err != nil { ctx.Error(err.Error(), http.StatusBadRequest) return } - if me == "" { - ctx.Error(domain.ErrUnauthorizedClient.Error(), http.StatusUnauthorized) + if !client.ValidateRedirectURI(req.RedirectURI) { + ctx.Error("requested redirect_uri is not registered on client_id side", http.StatusBadRequest) return } - ctx.SetContentType("application/json") - _ = encoder.Encode(&ExchangeResponse{ + csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte) + tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + tag, _, _ := h.matcher.Match(tags...) + + ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) + web.WriteTemplate(ctx, &web.AuthorizePage{ + BaseOf: web.BaseOf{ + Config: h.config, + Language: tag, + Printer: message.NewPrinter(tag), + }, + Client: client, + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, + CSRF: csrf, + Me: req.Me, + RedirectURI: req.RedirectURI, + ResponseType: req.ResponseType, + Scope: req.Scope, + State: req.State, + }) +} + +func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(ctx) + + req := new(VerifyRequest) + if err := req.bind(ctx); err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + + u := http.AcquireURI() + defer http.ReleaseURI(u) + req.RedirectURI.CopyTo(u) + + if strings.EqualFold(req.Authorize, "deny") { + u.QueryArgs().Set("error", "access_denied") + u.QueryArgs().Set("error_description", "user deny authorization request") + ctx.Redirect(u.String(), http.StatusFound) + + return + } + + code, err := h.useCase.Generate(ctx, auth.GenerateOptions{ + ClientID: req.ClientID, + RedirectURI: req.RedirectURI, + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, + Scope: req.Scope, + Me: req.Me, + }) + if err != nil { + ctx.SetStatusCode(http.StatusInternalServerError) + encoder.Encode(domain.Error{ + Description: err.Error(), + Frame: xerrors.Caller(1), + }) + + return + } + + for key, val := range map[string]string{ + "code": code, + "iss": h.config.Server.GetRootURL(), + "state": req.State, + } { + u.QueryArgs().Set(key, val) + } + + ctx.Redirect(u.String(), http.StatusFound) +} + +func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { + ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(ctx) + + req := new(ExchangeRequest) + if err := req.bind(ctx); err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(err) + + return + } + + me, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{ + Code: req.Code, + ClientID: req.ClientID, + RedirectURI: req.RedirectURI, + CodeVerifier: req.CodeVerifier, + }) + if err != nil { + ctx.SetStatusCode(http.StatusBadRequest) + encoder.Encode(err) + + return + } + + encoder.Encode(&ExchangeResponse{ Me: me, }) } + +func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil { + return domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + } + } + + r.Scope = make(domain.Scopes, 0) + parseScope(r.Scope, ctx.QueryArgs().Peek("scope")) + + if r.ResponseType == domain.ResponseTypeID { + r.ResponseType = domain.ResponseTypeCode + } + + return nil +} + +func (r *VerifyRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { + return domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + } + } + + r.Scope = make(domain.Scopes, 0) + parseScope(r.Scope, ctx.PostArgs().PeekMulti("scope[]")...) + + if r.ResponseType == domain.ResponseTypeID { + r.ResponseType = domain.ResponseTypeCode + } + + if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") { + return domain.Error{ + Code: "invalid_request", + Description: "cannot validate verification request", + Frame: xerrors.Caller(1), + } + } + + return nil +} + +func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { + return domain.Error{ + Code: "invalid_request", + Description: err.Error(), + Frame: xerrors.Caller(1), + } + } + + return nil +} + +// TODO(toby3d): fix this in form pkg. +func parseScope(dst domain.Scopes, src ...[]byte) error { + if len(src) == 0 { + return nil + } + + var scopes []string + + if len(src) == 1 { + scopes = strings.Fields(string(src[0])) + } + + for _, rawScope := range scopes { + scope, err := domain.ParseScope(string(rawScope)) + if err != nil { + return &domain.Error{ + Code: "invalid_request", + Description: fmt.Sprintf("cannot parse scope: %v", err), + Frame: xerrors.Caller(1), + } + } + + dst = append(dst, scope) + } + + return nil +} diff --git a/internal/auth/delivery/http/auth_http_test.go b/internal/auth/delivery/http/auth_http_test.go new file mode 100644 index 0000000..05c25fb --- /dev/null +++ b/internal/auth/delivery/http/auth_http_test.go @@ -0,0 +1,81 @@ +package http_test + +import ( + "path" + "sync" + "testing" + + "github.com/fasthttp/router" + "github.com/fasthttp/session/v2" + "github.com/fasthttp/session/v2/providers/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + http "github.com/valyala/fasthttp" + "golang.org/x/text/language" + "golang.org/x/text/message" + + delivery "source.toby3d.me/website/indieauth/internal/auth/delivery/http" + ucase "source.toby3d.me/website/indieauth/internal/auth/usecase" + clientrepo "source.toby3d.me/website/indieauth/internal/client/repository/memory" + clientucase "source.toby3d.me/website/indieauth/internal/client/usecase" + "source.toby3d.me/website/indieauth/internal/domain" + sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory" + "source.toby3d.me/website/indieauth/internal/testing/httptest" + userrepo "source.toby3d.me/website/indieauth/internal/user/repository/memory" +) + +func TestRender(t *testing.T) { + t.Parallel() + + provider, err := memory.New(memory.Config{}) + require.NoError(t, err) + + s := session.New(session.NewDefaultConfig()) + require.NoError(t, s.SetProvider(provider)) + + me := domain.TestMe(t) + c := domain.TestClient(t) + config := domain.TestConfig(t) + store := new(sync.Map) + store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), domain.TestUser(t)) + store.Store(path.Join(clientrepo.DefaultPathPrefix, c.ID.String()), c) + + r := router.New() + delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ + Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)), + Config: config, + Matcher: language.NewMatcher(message.DefaultCatalog.Languages()), + Auth: ucase.NewAuthUseCase(sessionrepo.NewMemorySessionRepository(config, store), config), + }).Register(r) + + client, _, cleanup := httptest.New(t, r.Handler) + t.Cleanup(cleanup) + + u := http.AcquireURI() + defer http.ReleaseURI(u) + u.Update("https://example.com/authorize") + + for k, v := range map[string]string{ + "client_id": c.ID.String(), + "code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo", + "code_challenge_method": domain.CodeChallengeMethodS256.String(), + "me": me.String(), + "redirect_uri": c.RedirectURI[0].String(), + "response_type": domain.ResponseTypeCode.String(), + "scope": "profile email", + "state": "1234567890", + } { + u.QueryArgs().Set(k, v) + } + + req := httptest.NewRequest(http.MethodGet, u.String(), nil) + defer http.ReleaseRequest(req) + + resp := http.AcquireResponse() + defer http.ReleaseResponse(resp) + + require.NoError(t, client.Do(req, resp)) + + assert.Equal(t, http.StatusOK, resp.StatusCode()) + assert.Contains(t, string(resp.Body()), `Authorize application`) +} diff --git a/internal/auth/repository.go b/internal/auth/repository.go deleted file mode 100644 index 750cc84..0000000 --- a/internal/auth/repository.go +++ /dev/null @@ -1,13 +0,0 @@ -package auth - -import ( - "context" - - "gitlab.com/toby3d/indieauth/internal/domain" -) - -type Repository interface { - Create(ctx context.Context, login *domain.Login) error - Get(ctx context.Context, code string) (*domain.Login, error) - Delete(ctx context.Context, code string) error -} diff --git a/internal/auth/repository/bolt/bolt_auth.go b/internal/auth/repository/bolt/bolt_auth.go deleted file mode 100644 index 84c5cb7..0000000 --- a/internal/auth/repository/bolt/bolt_auth.go +++ /dev/null @@ -1,57 +0,0 @@ -package bolt - -import ( - "context" - - json "github.com/goccy/go-json" - "gitlab.com/toby3d/indieauth/internal/auth" - "gitlab.com/toby3d/indieauth/internal/domain" - bolt "go.etcd.io/bbolt" -) - -type boltAuthRepository struct { - db *bolt.DB -} - -func NewBoltAuthRepository(db *bolt.DB) (auth.Repository, error) { - if err := db.Update(func(tx *bolt.Tx) (err error) { - _, err = tx.CreateBucketIfNotExists(domain.Login{}.Bucket()) - - return err - }); err != nil { - return nil, err - } - - return &boltAuthRepository{ - db: db, - }, nil -} - -func (repo *boltAuthRepository) Create(ctx context.Context, login *domain.Login) error { - jsonLogin, err := json.Marshal(login) - if err != nil { - return err - } - - return repo.db.Update(func(tx *bolt.Tx) error { - return tx.Bucket(domain.Login{}.Bucket()).Put([]byte(login.Code), jsonLogin) - }) -} - -func (repo *boltAuthRepository) Get(ctx context.Context, code string) (*domain.Login, error) { - login := new(domain.Login) - - if err := repo.db.View(func(tx *bolt.Tx) error { - return json.Unmarshal(tx.Bucket(domain.Login{}.Bucket()).Get([]byte(code)), login) - }); err != nil { - return nil, err - } - - return login, nil -} - -func (repo *boltAuthRepository) Delete(ctx context.Context, code string) error { - return repo.db.Update(func(tx *bolt.Tx) error { - return tx.Bucket(domain.Login{}.Bucket()).Delete([]byte(code)) - }) -} diff --git a/internal/auth/repository/memory/memory_auth.go b/internal/auth/repository/memory/memory_auth.go deleted file mode 100644 index 91ddb63..0000000 --- a/internal/auth/repository/memory/memory_auth.go +++ /dev/null @@ -1,40 +0,0 @@ -package memory - -import ( - "context" - "sync" - - "gitlab.com/toby3d/indieauth/internal/auth" - "gitlab.com/toby3d/indieauth/internal/domain" -) - -type memoryAuthRepository struct { - logins *sync.Map -} - -func NewMemoryAuthRepository() auth.Repository { - return &memoryAuthRepository{ - logins: new(sync.Map), - } -} - -func (repo *memoryAuthRepository) Create(ctx context.Context, login *domain.Login) error { - repo.logins.Store(login.Code, login) - - return nil -} - -func (repo *memoryAuthRepository) Get(ctx context.Context, code string) (*domain.Login, error) { - login, ok := repo.logins.LoadAndDelete(code) - if !ok { - return nil, nil - } - - return login.(*domain.Login), nil -} - -func (repo *memoryAuthRepository) Delete(ctx context.Context, code string) error { - repo.logins.Delete(code) - - return nil -} diff --git a/internal/auth/usecase/auth_ucase.go b/internal/auth/usecase/auth_ucase.go new file mode 100644 index 0000000..aa4246f --- /dev/null +++ b/internal/auth/usecase/auth_ucase.go @@ -0,0 +1,84 @@ +package usecase + +import ( + "context" + "fmt" + + "golang.org/x/xerrors" + + "source.toby3d.me/website/indieauth/internal/auth" + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/random" + "source.toby3d.me/website/indieauth/internal/session" +) + +type authUseCase struct { + config *domain.Config + sessions session.Repository +} + +func NewAuthUseCase(sessions session.Repository, config *domain.Config) auth.UseCase { + return &authUseCase{ + config: config, + sessions: sessions, + } +} + +func (useCase *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions) (string, error) { + code, err := random.String(useCase.config.Code.Length) + if err != nil { + return "", fmt.Errorf("cannot generate random code: %w", err) + } + + if err = useCase.sessions.Create(ctx, &domain.Session{ + ClientID: opts.ClientID, + Code: code, + CodeChallenge: opts.CodeChallenge, + CodeChallengeMethod: opts.CodeChallengeMethod, + Me: opts.Me, + RedirectURI: opts.RedirectURI, + Scope: opts.Scope, + }); err != nil { + return "", fmt.Errorf("cannot save session in store: %w", err) + } + + return code, nil +} + +func (useCase *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions) (*domain.Me, error) { + session, err := useCase.sessions.GetAndDelete(ctx, opts.Code) + if err != nil { + return nil, err + } + + if opts.ClientID.String() != session.ClientID.String() { + return nil, domain.Error{ + Code: "invalid_request", + Description: "client's URL MUST match the client_id used in the authentication request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } + } + + if opts.RedirectURI.String() != session.RedirectURI.String() { + return nil, domain.Error{ + Code: "invalid_request", + Description: "client's redirect URL MUST match the initial authentication request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } + } + + if session.CodeChallenge != "" && + !session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) { + return nil, domain.Error{ + Code: "invalid_request", + Description: "code_verifier is not hashes to the same value as given in " + + "the code_challenge in the original authorization request", + URI: "https://indieauth.net/source/#request", + Frame: xerrors.Caller(1), + } + } + + return session.Me, nil +} diff --git a/internal/auth/usecase/auth_usecase.go b/internal/auth/usecase/auth_usecase.go deleted file mode 100644 index e2a65cf..0000000 --- a/internal/auth/usecase/auth_usecase.go +++ /dev/null @@ -1,137 +0,0 @@ -package usecase - -import ( - "bytes" - "context" - "net/url" - "time" - - http "github.com/valyala/fasthttp" - "gitlab.com/toby3d/indieauth/internal/auth" - "gitlab.com/toby3d/indieauth/internal/domain" - "gitlab.com/toby3d/indieauth/internal/pkce" - "gitlab.com/toby3d/indieauth/internal/random" - "willnorris.com/go/microformats" -) - -type authUseCase struct { - client *http.Client - repo auth.Repository -} - -func NewAuthUseCase(repo auth.Repository) auth.UseCase { - return &authUseCase{ - client: new(http.Client), - repo: repo, - } -} - -func (useCase *authUseCase) Discovery(ctx context.Context, clientId string) (*domain.Client, error) { - _, src, err := useCase.client.Get(nil, clientId) - if err != nil { - return nil, err - } - - cid, err := url.Parse(clientId) - if err != nil { - return nil, err - } - - data := microformats.Parse(bytes.NewReader(src), cid) - - client := new(domain.Client) - client.RedirectURI = make([]string, 0) - - for i := range data.Items { - if len(data.Items[i].Type) == 0 || data.Items[i].Type[0] != "h-app" { - continue - } - - for key, values := range data.Items[i].Properties { - switch key { - case "logo": - for j := range values { - switch val := values[j].(type) { - case string: - client.Logo = val - case map[string]string: - client.Logo = val["value"] - } - } - case "name": - for j := range values { - client.Name, _ = values[j].(string) - } - case "url": - for j := range values { - client.URL, _ = values[j].(string) - } - } - } - } - - for key, values := range data.Rels { - if key != "redirect_uri" { - continue - } - - client.RedirectURI = append(client.RedirectURI, values...) - } - - if client.URL != clientId { - return nil, domain.Error{ - Code: domain.ErrInvalidRequest.Code, - Description: "'client_id' does not match the actual client URL", - } - } - - return client, nil -} - -func (useCase *authUseCase) Approve(ctx context.Context, login *domain.Login) (string, error) { - login.Code = random.New().String(32) - - if err := useCase.repo.Create(ctx, login); err != nil { - return "", err - } - - return login.Code, nil -} - -func (useCase *authUseCase) Exchange(ctx context.Context, req *domain.ExchangeRequest) (string, error) { - login, err := useCase.repo.Get(ctx, req.Code) - if err != nil { - return "", err - } - - if login == nil { - return "", nil - } - - _ = useCase.repo.Delete(ctx, req.Code) - - if time.Now().UTC().After(time.Unix(login.CreatedAt, 0).Add(10 * time.Minute)) { - return "", nil - } - - if login.ClientID != req.ClientID || login.RedirectURI != req.RedirectURI { - return "", domain.ErrInvalidRequest - } - - if login.CodeChallenge != "" { - codeChallenge, err := pkce.New(login.CodeChallengeMethod) - if err != nil { - return "", err - } - - codeChallenge.Verifier = req.CodeVerifier - - codeChallenge.Generate() - - if login.CodeChallenge != codeChallenge.Challenge { - return "", domain.ErrInvalidRequest - } - } - - return login.Me, nil -}