diff --git a/internal/auth/delivery/http/auth_http.go b/internal/auth/delivery/http/auth_http.go index c760960..caa61bb 100644 --- a/internal/auth/delivery/http/auth_http.go +++ b/internal/auth/delivery/http/auth_http.go @@ -2,13 +2,10 @@ package http import ( "crypto/subtle" - "errors" - "path" + "net/http" "strings" - "github.com/fasthttp/router" - json "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" + "github.com/goccy/go-json" "golang.org/x/text/language" "golang.org/x/text/message" @@ -16,111 +13,31 @@ import ( "source.toby3d.me/toby3d/auth/internal/client" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/auth/internal/middleware" "source.toby3d.me/toby3d/auth/internal/profile" + "source.toby3d.me/toby3d/auth/internal/urlutil" "source.toby3d.me/toby3d/auth/web" - "source.toby3d.me/toby3d/form" - "source.toby3d.me/toby3d/middleware" ) type ( - AuthAuthorizationRequest struct { - // Indicates to the authorization server that an authorization - // code should be returned as the response. - ResponseType domain.ResponseType `form:"response_type"` // code - - // The client URL. - ClientID *domain.ClientID `form:"client_id"` - - // The redirect URL indicating where the user should be - // redirected to after approving the request. - RedirectURI *domain.URL `form:"redirect_uri"` - - // A parameter set by the client which will be included when the - // user is redirected back to the client. This is used to - // prevent CSRF attacks. The authorization server MUST return - // the unmodified state value back to the client. - State string `form:"state"` - - // The code challenge as previously described. - CodeChallenge string `form:"code_challenge"` - - // The hashing method used to calculate the code challenge. - CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"` - - // A space-separated list of scopes the client is requesting, - // e.g. "profile", or "profile create". If the client omits this - // value, the authorization server MUST NOT issue an access - // token for this authorization code. Only the user's profile - // URL may be returned without any scope requested. - Scope domain.Scopes `form:"scope,omitempty"` - - // The URL that the user entered. - Me *domain.Me `form:"me"` - } - - AuthVerifyRequest struct { - ClientID *domain.ClientID `form:"client_id"` - Me *domain.Me `form:"me"` - RedirectURI *domain.URL `form:"redirect_uri"` - CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"` - ResponseType domain.ResponseType `form:"response_type"` - Scope domain.Scopes `form:"scope[],omitempty"` - Authorize string `form:"authorize"` - CodeChallenge string `form:"code_challenge"` - State string `form:"state"` - Provider string `form:"provider"` - } - - AuthExchangeRequest struct { - GrantType domain.GrantType `form:"grant_type"` // authorization_code - - // The authorization code received from the authorization - // endpoint in the redirect. - Code string `form:"code"` - - // The client's URL, which MUST match the client_id used in the - // authentication request. - ClientID *domain.ClientID `form:"client_id"` - - // The client's redirect URL, which MUST match the initial - // authentication request. - RedirectURI *domain.URL `form:"redirect_uri"` - - // The original plaintext random string generated before - // starting the authorization request. - CodeVerifier string `form:"code_verifier"` - } - - AuthExchangeResponse struct { - Me *domain.Me `json:"me"` - Profile *AuthProfileResponse `json:"profile,omitempty"` - } - - AuthProfileResponse struct { - Email *domain.Email `json:"email,omitempty"` - Photo *domain.URL `json:"photo,omitempty"` - URL *domain.URL `json:"url,omitempty"` - Name string `json:"name,omitempty"` - } - - NewRequestHandlerOptions struct { + NewHandlerOptions struct { Auth auth.UseCase Clients client.UseCase - Config *domain.Config + Config domain.Config Matcher language.Matcher Profiles profile.UseCase } - RequestHandler struct { + Handler struct { clients client.UseCase - config *domain.Config + config domain.Config matcher language.Matcher useCase auth.UseCase } ) -func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { - return &RequestHandler{ +func NewHandler(opts NewHandlerOptions) *Handler { + return &Handler{ clients: opts.Clients, config: opts.Config, matcher: opts.Matcher, @@ -128,16 +45,16 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { } } -func (h *RequestHandler) Register(r *router.Router) { +func (h *Handler) Handler() http.Handler { chain := middleware.Chain{ middleware.CSRFWithConfig(middleware.CSRFConfig{ - Skipper: func(ctx *http.RequestCtx) bool { - matched, _ := path.Match("/authorize*", string(ctx.Path())) + Skipper: func(w http.ResponseWriter, r *http.Request) bool { + head, _ := urlutil.ShiftPath(r.URL.Path) - return ctx.IsPost() && matched + return r.Method == http.MethodPost && head == "authorize" }, CookieMaxAge: 0, - CookieSameSite: http.CookieSameSiteStrictMode, + CookieSameSite: http.SameSiteStrictMode, ContextKey: "csrf", CookieDomain: h.config.Server.Domain, CookieName: "__Secure-csrf", @@ -148,14 +65,12 @@ func (h *RequestHandler) Register(r *router.Router) { CookieHTTPOnly: true, }), middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{ - Skipper: func(ctx *http.RequestCtx) bool { - matched, _ := path.Match("/api/*", string(ctx.Path())) - provider := string(ctx.QueryArgs().Peek("provider")) - providerMatched := provider != "" && provider != domain.ProviderDirect.UID + Skipper: func(w http.ResponseWriter, r *http.Request) bool { + head, _ := urlutil.ShiftPath(r.URL.Path) - return !ctx.IsPost() || !matched || providerMatched + return r.Method != http.MethodPost || head != "api" }, - Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) { + Validator: func(w http.ResponseWriter, r *http.Request, login, password string) (bool, error) { userMatch := subtle.ConstantTimeCompare([]byte(login), []byte(h.config.IndieAuth.Username)) passMatch := subtle.ConstantTimeCompare([]byte(password), @@ -165,29 +80,57 @@ func (h *RequestHandler) Register(r *router.Router) { }, Realm: "", }), - middleware.LogFmt(), } - r.GET("/authorize", chain.RequestHandler(h.handleAuthorize)) - r.POST("/api/authorize", chain.RequestHandler(h.handleVerify)) - r.POST("/authorize", chain.RequestHandler(h.handleExchange)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var head string + head, r.URL.Path = urlutil.ShiftPath(r.URL.Path) + + switch r.Method { + default: + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + case http.MethodGet, "": + if head != "" { + http.NotFound(w, r) + + return + } + + chain.Handler(h.handleAuthorize).ServeHTTP(w, r) + case http.MethodPost: + switch head { + default: + http.NotFound(w, r) + case "": + chain.Handler(h.handleExchange).ServeHTTP(w, r) + case "verify": + chain.Handler(h.handleVerify).ServeHTTP(w, r) + } + } + }) } -func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) +func (h *Handler) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != "" { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + + tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage)) tag, _, _ := h.matcher.Match(tags...) baseOf := web.BaseOf{ - Config: h.config, + Config: &h.config, Language: tag, Printer: message.NewPrinter(tag), } req := NewAuthAuthorizationRequest() - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) - web.WriteTemplate(ctx, &web.ErrorPage{ + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: err, }) @@ -195,10 +138,10 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { return } - client, err := h.clients.Discovery(ctx, req.ClientID) + client, err := h.clients.Discovery(r.Context(), req.ClientID) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) - web.WriteTemplate(ctx, &web.ErrorPage{ + w.WriteHeader(http.StatusBadRequest) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: err, }) @@ -207,8 +150,8 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { } if !client.ValidateRedirectURI(req.RedirectURI.URL) { - ctx.SetStatusCode(http.StatusBadRequest) - web.WriteTemplate(ctx, &web.ErrorPage{ + w.WriteHeader(http.StatusBadRequest) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: domain.NewError( domain.ErrorCodeInvalidClient, @@ -220,15 +163,15 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { return } - csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte) - web.WriteTemplate(ctx, &web.AuthorizePage{ + csrf, _ := r.Context().Value(middleware.DefaultCSRFConfig.ContextKey).([]byte) + web.WriteTemplate(w, &web.AuthorizePage{ BaseOf: baseOf, CSRF: csrf, Scope: req.Scope, Client: client, - Me: req.Me, - RedirectURI: req.RedirectURI, - CodeChallengeMethod: req.CodeChallengeMethod, + Me: &req.Me, + RedirectURI: &req.RedirectURI, + CodeChallengeMethod: *req.CodeChallengeMethod, ResponseType: req.ResponseType, CodeChallenge: req.CodeChallenge, State: req.State, @@ -236,15 +179,21 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { }) } -func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) { - ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain) - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) +func (h *Handler) handleVerify(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain) + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) req := NewAuthVerifyRequest() - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) @@ -254,60 +203,70 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) { if strings.EqualFold(req.Authorize, "deny") { domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State). SetReirectURI(req.RedirectURI.URL) - ctx.Redirect(req.RedirectURI.String(), http.StatusFound) + http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound) return } - code, err := h.useCase.Generate(ctx, auth.GenerateOptions{ + code, err := h.useCase.Generate(r.Context(), auth.GenerateOptions{ ClientID: req.ClientID, Me: req.Me, RedirectURI: req.RedirectURI.URL, - CodeChallengeMethod: req.CodeChallengeMethod, + CodeChallengeMethod: *req.CodeChallengeMethod, Scope: req.Scope, CodeChallenge: req.CodeChallenge, }) if err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) _ = encoder.Encode(err) return } + q := req.RedirectURI.Query() + for key, val := range map[string]string{ "code": code, "iss": h.config.Server.GetRootURL(), "state": req.State, } { - req.RedirectURI.Query().Set(key, val) + q.Set(key, val) } - ctx.Redirect(req.RedirectURI.String(), http.StatusFound) + req.RedirectURI.RawQuery = q.Encode() + + http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound) } -func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) +func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) req := new(AuthExchangeRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) return } - me, profile, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{ + me, profile, err := h.useCase.Exchange(r.Context(), auth.ExchangeOptions{ Code: req.Code, ClientID: req.ClientID, RedirectURI: req.RedirectURI.URL, CodeVerifier: req.CodeVerifier, }) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) @@ -325,109 +284,7 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { } _ = encoder.Encode(&AuthExchangeResponse{ - Me: me, + Me: *me, Profile: userInfo, }) } - -func NewAuthAuthorizationRequest() *AuthAuthorizationRequest { - return &AuthAuthorizationRequest{ - ClientID: new(domain.ClientID), - CodeChallenge: "", - CodeChallengeMethod: domain.CodeChallengeMethodUnd, - Me: new(domain.Me), - RedirectURI: new(domain.URL), - ResponseType: domain.ResponseTypeUnd, - Scope: make(domain.Scopes, 0), - State: "", - } -} - -//nolint:cyclop -func (r *AuthAuthorizationRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#authorization-request", - ) - } - - if r.ResponseType == domain.ResponseTypeID { - r.ResponseType = domain.ResponseTypeCode - } - - return nil -} - -func NewAuthVerifyRequest() *AuthVerifyRequest { - return &AuthVerifyRequest{ - Authorize: "", - ClientID: new(domain.ClientID), - CodeChallenge: "", - CodeChallengeMethod: domain.CodeChallengeMethodUnd, - Me: new(domain.Me), - Provider: "", - RedirectURI: new(domain.URL), - ResponseType: domain.ResponseTypeUnd, - Scope: make(domain.Scopes, 0), - State: "", - } -} - -//nolint:funlen,cyclop -func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - - if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#authorization-request", - ) - } - - // NOTE(toby3d): backwards-compatible support. - // See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type - if r.ResponseType == domain.ResponseTypeID { - r.ResponseType = domain.ResponseTypeCode - } - - r.Provider = strings.ToLower(r.Provider) - - if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "cannot validate verification request", - "https://indieauth.net/source/#authorization-request", - ) - } - - return nil -} - -func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "cannot validate verification request", - "https://indieauth.net/source/#redeeming-the-authorization-code", - ) - } - - return nil -} diff --git a/internal/auth/delivery/http/auth_http_schema.go b/internal/auth/delivery/http/auth_http_schema.go new file mode 100644 index 0000000..1843b3c --- /dev/null +++ b/internal/auth/delivery/http/auth_http_schema.go @@ -0,0 +1,212 @@ +package http + +import ( + "errors" + "net/http" + "strings" + + "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/form" +) + +type ( + AuthAuthorizationRequest struct { + // Indicates to the authorization server that an authorization + // code should be returned as the response. + ResponseType domain.ResponseType `form:"response_type"` // code + + // The client URL. + ClientID domain.ClientID `form:"client_id"` + + // The redirect URL indicating where the user should be + // redirected to after approving the request. + RedirectURI domain.URL `form:"redirect_uri"` + + // The URL that the user entered. + Me domain.Me `form:"me"` + + // The hashing method used to calculate the code challenge. + CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"` + + // A space-separated list of scopes the client is requesting, + // e.g. "profile", or "profile create". If the client omits this + // value, the authorization server MUST NOT issue an access + // token for this authorization code. Only the user's profile + // URL may be returned without any scope requested. + Scope domain.Scopes `form:"scope,omitempty"` + + // A parameter set by the client which will be included when the + // user is redirected back to the client. This is used to + // prevent CSRF attacks. The authorization server MUST return + // the unmodified state value back to the client. + State string `form:"state"` + + // The code challenge as previously described. + CodeChallenge string `form:"code_challenge,omitempty"` + } + + AuthVerifyRequest struct { + ClientID domain.ClientID `form:"client_id"` + Me domain.Me `form:"me"` + RedirectURI domain.URL `form:"redirect_uri"` + ResponseType domain.ResponseType `form:"response_type"` + CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"` + Scope domain.Scopes `form:"scope[],omitempty"` + Authorize string `form:"authorize"` + CodeChallenge string `form:"code_challenge,omitempty"` + State string `form:"state"` + Provider string `form:"provider"` + } + + AuthExchangeRequest struct { + GrantType domain.GrantType `form:"grant_type"` // authorization_code + + // The client's URL, which MUST match the client_id used in the + // authentication request. + ClientID domain.ClientID `form:"client_id"` + + // The client's redirect URL, which MUST match the initial + // authentication request. + RedirectURI domain.URL `form:"redirect_uri"` + + // The authorization code received from the authorization + // endpoint in the redirect. + Code string `form:"code"` + + // The original plaintext random string generated before + // starting the authorization request. + CodeVerifier string `form:"code_verifier"` + } + + AuthExchangeResponse struct { + Me domain.Me `json:"me"` + Profile *AuthProfileResponse `json:"profile,omitempty"` + } + + AuthProfileResponse struct { + Email *domain.Email `json:"email,omitempty"` + Photo *domain.URL `json:"photo,omitempty"` + URL *domain.URL `json:"url,omitempty"` + Name string `json:"name,omitempty"` + } +) + +func NewAuthAuthorizationRequest() *AuthAuthorizationRequest { + return &AuthAuthorizationRequest{ + ClientID: domain.ClientID{}, + CodeChallenge: "", + CodeChallengeMethod: &domain.CodeChallengeMethodUnd, + Me: domain.Me{}, + RedirectURI: domain.URL{}, + ResponseType: domain.ResponseTypeUnd, + Scope: make(domain.Scopes, 0), + State: "", + } +} + +//nolint:cyclop +func (r *AuthAuthorizationRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if r.ResponseType == domain.ResponseTypeID { + r.ResponseType = domain.ResponseTypeCode + } + + return nil +} + +func NewAuthVerifyRequest() *AuthVerifyRequest { + return &AuthVerifyRequest{ + Authorize: "", + ClientID: domain.ClientID{}, + CodeChallenge: "", + CodeChallengeMethod: &domain.CodeChallengeMethodUnd, + Me: domain.Me{}, + Provider: "", + RedirectURI: domain.URL{}, + ResponseType: domain.ResponseTypeUnd, + Scope: make(domain.Scopes, 0), + State: "", + } +} + +//nolint:funlen,cyclop +func (r *AuthVerifyRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + // NOTE(toby3d): backwards-compatible support. + // See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type + if r.ResponseType == domain.ResponseTypeID { + r.ResponseType = domain.ResponseTypeCode + } + + r.Provider = strings.ToLower(r.Provider) + + if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "cannot validate verification request", + "https://indieauth.net/source/#authorization-request", + ) + } + + return nil +} + +func (r *AuthExchangeRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "cannot validate verification request", + "https://indieauth.net/source/#redeeming-the-authorization-code", + ) + } + + return nil +} diff --git a/internal/auth/delivery/http/auth_http_test.go b/internal/auth/delivery/http/auth_http_test.go index 7d11b48..a0dd9d2 100644 --- a/internal/auth/delivery/http/auth_http_test.go +++ b/internal/auth/delivery/http/auth_http_test.go @@ -1,13 +1,14 @@ package http_test import ( - "path" + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" "strings" - "sync" "testing" - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" @@ -22,7 +23,7 @@ import ( profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" "source.toby3d.me/toby3d/auth/internal/session" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" + "source.toby3d.me/toby3d/auth/internal/user" userrepo "source.toby3d.me/toby3d/auth/internal/user/repository/memory" ) @@ -34,36 +35,31 @@ type Dependencies struct { matcher language.Matcher profiles profile.Repository sessions session.Repository - store *sync.Map + users user.Repository } func TestAuthorize(t *testing.T) { t.Parallel() deps := NewDependencies(t) - me := domain.TestMe(t, "https://user.example.net") + me := domain.TestMe(t, "https://user.example.net/") user := domain.TestUser(t) client := domain.TestClient(t) - deps.store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client) - deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile) - deps.store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user) + if err := deps.clients.Create(context.Background(), *client); err != nil { + t.Fatal(err) + } - r := router.New() - //nolint:exhaustivestruct - delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ - Auth: deps.authService, - Clients: deps.clientService, - Config: deps.config, - Matcher: deps.matcher, - }).Register(r) + if err := deps.users.Create(context.Background(), *user); err != nil { + t.Fatal(err) + } - httpClient, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) + if err := deps.profiles.Create(context.Background(), *me, *user.Profile); err != nil { + t.Fatal(err) + } - uri := http.AcquireURI() - defer http.ReleaseURI(uri) - uri.Update("https://example.com/authorize") + u := &url.URL{Scheme: "https", Host: "example.com", Path: "/"} + q := u.Query() for key, val := range map[string]string{ "client_id": client.ID.String(), @@ -75,26 +71,36 @@ func TestAuthorize(t *testing.T) { "scope": "profile email", "state": "1234567890", } { - uri.QueryArgs().Set(key, val) + q.Set(key, val) } - req := httptest.NewRequest(http.MethodGet, uri.String(), nil) - defer http.ReleaseRequest(req) + u.RawQuery = q.Encode() - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) + req := httptest.NewRequest(http.MethodGet, u.String(), nil) + w := httptest.NewRecorder() - if err := httpClient.Do(req, resp); err != nil { + //nolint:exhaustivestruct + delivery.NewHandler(delivery.NewHandlerOptions{ + Auth: deps.authService, + Clients: deps.clientService, + Config: *deps.config, + Matcher: deps.matcher, + }).Handler().ServeHTTP(w, req) + + resp := w.Result() + + body, err := io.ReadAll(resp.Body) + if err != nil { t.Fatal(err) } - if resp.StatusCode() != http.StatusOK { - t.Errorf("GET %s = %d, want %d", uri.String(), resp.StatusCode(), http.StatusOK) + if resp.StatusCode != http.StatusOK { + t.Errorf("%s %s = %d, want %d", req.Method, u.String(), resp.StatusCode, http.StatusOK) } const expResult = `Authorize application` - if result := string(resp.Body()); !strings.Contains(result, expResult) { - t.Errorf("GET %s = %s, want %s", uri.String(), result, expResult) + if result := string(body); !strings.Contains(result, expResult) { + t.Errorf("%s %s = %s, want %s", req.Method, u.String(), result, expResult) } } @@ -103,14 +109,15 @@ func NewDependencies(tb testing.TB) Dependencies { config := domain.TestConfig(tb) matcher := language.NewMatcher(message.DefaultCatalog.Languages()) - store := new(sync.Map) - clients := clientrepo.NewMemoryClientRepository(store) - sessions := sessionrepo.NewMemorySessionRepository(store, config) - profiles := profilerepo.NewMemoryProfileRepository(store) + clients := clientrepo.NewMemoryClientRepository() + users := userrepo.NewMemoryUserRepository() + sessions := sessionrepo.NewMemorySessionRepository(*config) + profiles := profilerepo.NewMemoryProfileRepository() authService := ucase.NewAuthUseCase(sessions, profiles, config) clientService := clientucase.NewClientUseCase(clients) return Dependencies{ + users: users, authService: authService, clients: clients, clientService: clientService, @@ -118,6 +125,5 @@ func NewDependencies(tb testing.TB) Dependencies { matcher: matcher, sessions: sessions, profiles: profiles, - store: store, } } diff --git a/internal/auth/usecase.go b/internal/auth/usecase.go index cf92b73..e6c9155 100644 --- a/internal/auth/usecase.go +++ b/internal/auth/usecase.go @@ -9,8 +9,8 @@ import ( type ( GenerateOptions struct { - ClientID *domain.ClientID - Me *domain.Me + ClientID domain.ClientID + Me domain.Me RedirectURI *url.URL CodeChallengeMethod domain.CodeChallengeMethod Scope domain.Scopes @@ -18,7 +18,7 @@ type ( } ExchangeOptions struct { - ClientID *domain.ClientID + ClientID domain.ClientID RedirectURI *url.URL Code string CodeVerifier string diff --git a/internal/auth/usecase/auth_ucase.go b/internal/auth/usecase/auth_ucase.go index fe6610b..0dd4b83 100644 --- a/internal/auth/usecase/auth_ucase.go +++ b/internal/auth/usecase/auth_ucase.go @@ -45,7 +45,7 @@ func (uc *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions) } } - if err = uc.sessions.Create(ctx, &domain.Session{ + if err = uc.sessions.Create(ctx, domain.Session{ ClientID: opts.ClientID, Code: code, CodeChallenge: opts.CodeChallenge, @@ -81,5 +81,5 @@ func (uc *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions) return nil, nil, auth.ErrMismatchPKCE } - return session.Me, session.Profile, nil + return &session.Me, session.Profile, nil } diff --git a/internal/client/delivery/http/client_http.go b/internal/client/delivery/http/client_http.go index e553d85..f325dbf 100644 --- a/internal/client/delivery/http/client_http.go +++ b/internal/client/delivery/http/client_http.go @@ -1,48 +1,37 @@ package http import ( - "errors" + "net/http" "strings" - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/token" + "source.toby3d.me/toby3d/auth/internal/urlutil" "source.toby3d.me/toby3d/auth/web" - "source.toby3d.me/toby3d/form" - "source.toby3d.me/toby3d/middleware" ) type ( - ClientCallbackRequest struct { - Error domain.ErrorCode `form:"error,omitempty"` - Iss *domain.ClientID `form:"iss"` - Code string `form:"code"` - ErrorDescription string `form:"error_description,omitempty"` - State string `form:"state"` - } - - NewRequestHandlerOptions struct { + NewHandlerOptions struct { Matcher language.Matcher Tokens token.UseCase - Client *domain.Client - Config *domain.Config + Client domain.Client + Config domain.Config } - RequestHandler struct { + Handler struct { matcher language.Matcher tokens token.UseCase - client *domain.Client - config *domain.Config + client domain.Client + config domain.Config } ) -func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { - return &RequestHandler{ +func NewHandler(opts NewHandlerOptions) *Handler { + return &Handler{ client: opts.Client, config: opts.Config, matcher: opts.Matcher, @@ -50,59 +39,82 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { } } -func (h *RequestHandler) Register(r *router.Router) { - chain := middleware.Chain{ - middleware.LogFmt(), - } +func (h *Handler) Handler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "" && r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - r.GET("/", chain.RequestHandler(h.handleRender)) - r.GET("/callback", chain.RequestHandler(h.handleCallback)) + return + } + + var head string + head, r.URL.Path = urlutil.ShiftPath(r.URL.Path) + + switch head { + default: + http.NotFound(w, r) + case "": + h.handleRender(w, r) + case "callback": + h.handleCallback(w, r) + } + }) } -func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { +func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) { + if r.Method != "" && r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + redirect := make([]string, len(h.client.RedirectURI)) for i := range h.client.RedirectURI { redirect[i] = h.client.RedirectURI[i].String() } - ctx.Response.Header.Set( - http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`, - ) + w.Header().Set(common.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`) - tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage)) tag, _, _ := h.matcher.Match(tags...) // TODO(toby3d): generate and store PKCE - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) - web.WriteTemplate(ctx, &web.HomePage{ + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + web.WriteTemplate(w, &web.HomePage{ BaseOf: web.BaseOf{ - Config: h.config, + Config: &h.config, Language: tag, Printer: message.NewPrinter(tag), }, - Client: h.client, + Client: &h.client, State: "hackme", // TODO(toby3d): generate and store state }) } //nolint:unlen -func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) +func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) { + if r.Method != "" && r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + + tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage)) tag, _, _ := h.matcher.Match(tags...) baseOf := web.BaseOf{ - Config: h.config, + Config: &h.config, Language: tag, Printer: message.NewPrinter(tag), } req := new(ClientCallbackRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) - web.WriteTemplate(ctx, &web.ErrorPage{ + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusInternalServerError) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: err, }) @@ -111,8 +123,8 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { } if req.Error != domain.ErrorCodeUnd { - ctx.SetStatusCode(http.StatusUnauthorized) - web.WriteTemplate(ctx, &web.ErrorPage{ + w.WriteHeader(http.StatusUnauthorized) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: domain.NewError( domain.ErrorCodeAccessDenied, @@ -127,9 +139,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { // TODO(toby3d): load and check state - if req.Iss == nil || req.Iss.String() != h.client.ID.String() { - ctx.SetStatusCode(http.StatusBadRequest) - web.WriteTemplate(ctx, &web.ErrorPage{ + if req.Iss.String() != h.client.ID.String() { + w.WriteHeader(http.StatusBadRequest) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: domain.NewError( domain.ErrorCodeInvalidClient, @@ -142,15 +154,15 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { return } - token, _, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ + token, _, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{ ClientID: h.client.ID, RedirectURI: h.client.RedirectURI[0], Code: req.Code, CodeVerifier: "", // TODO(toby3d): validate PKCE here }) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) - web.WriteTemplate(ctx, &web.ErrorPage{ + w.WriteHeader(http.StatusBadRequest) + web.WriteTemplate(w, &web.ErrorPage{ BaseOf: baseOf, Error: err, }) @@ -158,23 +170,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { return } - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) - web.WriteTemplate(ctx, &web.CallbackPage{ + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + web.WriteTemplate(w, &web.CallbackPage{ BaseOf: baseOf, Token: token, }) } - -func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - - if err := form.Unmarshal(ctx.QueryArgs().QueryString(), req); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "") - } - - return nil -} diff --git a/internal/client/delivery/http/client_http_schema.go b/internal/client/delivery/http/client_http_schema.go new file mode 100644 index 0000000..ccb5675 --- /dev/null +++ b/internal/client/delivery/http/client_http_schema.go @@ -0,0 +1,31 @@ +package http + +import ( + "errors" + "net/http" + + "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/form" +) + +type ClientCallbackRequest struct { + Error domain.ErrorCode `form:"error,omitempty"` + Iss domain.ClientID `form:"iss"` + Code string `form:"code"` + ErrorDescription string `form:"error_description,omitempty"` + State string `form:"state"` +} + +func (req *ClientCallbackRequest) bind(r *http.Request) error { + indieAuthError := new(domain.Error) + + if err := form.Unmarshal([]byte(r.URL.Query().Encode()), req); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "") + } + + return nil +} diff --git a/internal/client/delivery/http/client_http_test.go b/internal/client/delivery/http/client_http_test.go index b46017b..eab41f8 100644 --- a/internal/client/delivery/http/client_http_test.go +++ b/internal/client/delivery/http/client_http_test.go @@ -1,11 +1,10 @@ package http_test import ( - "sync" + "net/http" + "net/http/httptest" "testing" - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" @@ -15,7 +14,6 @@ import ( profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" "source.toby3d.me/toby3d/auth/internal/session" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" "source.toby3d.me/toby3d/auth/internal/token" tokenrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory" tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase" @@ -27,7 +25,6 @@ type Dependencies struct { config *domain.Config matcher language.Matcher sessions session.Repository - store *sync.Map tokens token.Repository tokenService token.UseCase } @@ -36,45 +33,30 @@ func TestRead(t *testing.T) { t.Parallel() deps := NewDependencies(t) + req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil) + w := httptest.NewRecorder() - r := router.New() - delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ - Client: deps.client, - Config: deps.config, + delivery.NewHandler(delivery.NewHandlerOptions{ + Client: *deps.client, + Config: *deps.config, Matcher: deps.matcher, Tokens: deps.tokenService, - }).Register(r) + }).Handler().ServeHTTP(w, req) - client, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) - - const requestURI string = "https://app.example.com/" - req, resp := httptest.NewRequest(http.MethodGet, requestURI, nil), http.AcquireResponse() - - t.Cleanup(func() { - http.ReleaseRequest(req) - http.ReleaseResponse(resp) - }) - - if err := client.Do(req, resp); err != nil { - t.Error(err) - } - - if resp.StatusCode() != http.StatusOK { - t.Errorf("GET %s = %d, want %d", requestURI, resp.StatusCode(), http.StatusOK) + if resp := w.Result(); resp.StatusCode != http.StatusOK { + t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK) } } func NewDependencies(tb testing.TB) Dependencies { tb.Helper() - store := new(sync.Map) client := domain.TestClient(tb) config := domain.TestConfig(tb) matcher := language.NewMatcher(message.DefaultCatalog.Languages()) - sessions := sessionrepo.NewMemorySessionRepository(store, config) - tokens := tokenrepo.NewMemoryTokenRepository(store) - profiles := profilerepo.NewMemoryProfileRepository(store) + sessions := sessionrepo.NewMemorySessionRepository(*config) + tokens := tokenrepo.NewMemoryTokenRepository() + profiles := profilerepo.NewMemoryProfileRepository() tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{ Config: config, Profiles: profiles, @@ -87,7 +69,6 @@ func NewDependencies(tb testing.TB) Dependencies { config: config, matcher: matcher, sessions: sessions, - store: store, profiles: profiles, tokens: tokens, tokenService: tokenService, diff --git a/internal/client/repository.go b/internal/client/repository.go index 6aef5ad..0389498 100644 --- a/internal/client/repository.go +++ b/internal/client/repository.go @@ -7,7 +7,8 @@ import ( ) type Repository interface { - Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) + Create(ctx context.Context, client domain.Client) error + Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) } var ErrNotExist error = domain.NewError( diff --git a/internal/client/repository/http/http_client.go b/internal/client/repository/http/http_client.go index 81e6a21..ff10730 100644 --- a/internal/client/repository/http/http_client.go +++ b/internal/client/repository/http/http_client.go @@ -1,14 +1,17 @@ package http import ( + "bytes" "context" "fmt" - "net" + "io" + "net/http" "net/url" - http "github.com/valyala/fasthttp" + "golang.org/x/exp/slices" "source.toby3d.me/toby3d/auth/internal/client" + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/httputil" ) @@ -34,33 +37,18 @@ func NewHTTPClientRepository(c *http.Client) client.Repository { } } -func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) { - ips, err := net.LookupIP(cid.URL().Hostname()) +// WARN(toby3d): not implemented. +func (httpClientRepository) Create(_ context.Context, _ domain.Client) error { + return nil +} + +func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) { + resp, err := repo.client.Get(cid.String()) if err != nil { - return nil, fmt.Errorf("cannot resolve client IP by id: %w", err) - } - - for _, ip := range ips { - if !ip.IsLoopback() { - continue - } - - return nil, client.ErrNotExist - } - - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.SetRequestURI(cid.String()) - req.Header.SetMethod(http.MethodGet) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { return nil, fmt.Errorf("failed to make a request to the client: %w", err) } - if resp.StatusCode() == http.StatusNotFound { + if resp.StatusCode == http.StatusNotFound { return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist) } @@ -72,74 +60,62 @@ func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) Name: make([]string, 0), } - extract(client, resp) + extract(resp.Body, resp.Request.URL, client, resp.Header.Get(common.HeaderLink)) return client, nil } //nolint:gocognit,cyclop -func extract(dst *domain.Client, src *http.Response) { - for _, endpoint := range httputil.ExtractEndpoints(src, relRedirectURI) { - if !containsURL(dst.RedirectURI, endpoint) { +func extract(r io.Reader, u *url.URL, dst *domain.Client, header string) { + body, _ := io.ReadAll(r) + + for _, endpoint := range httputil.ExtractEndpoints(bytes.NewReader(body), u, header, relRedirectURI) { + if !containsUrl(dst.RedirectURI, endpoint) { dst.RedirectURI = append(dst.RedirectURI, endpoint) } } - for _, itemType := range []string{hXApp, hApp} { - for _, name := range httputil.ExtractProperty(src, itemType, propertyName) { - if n, ok := name.(string); ok && !containsString(dst.Name, n) { + for _, itemType := range []string{hApp, hXApp} { + for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyName) { + if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) { dst.Name = append(dst.Name, n) } } - for _, logo := range httputil.ExtractProperty(src, itemType, propertyLogo) { - var ( - u *url.URL - err error - ) + for _, logo := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyLogo) { + var logoURL *url.URL + var err error switch l := logo.(type) { case string: - u, err = url.Parse(l) + logoURL, err = url.Parse(l) case map[string]string: if value, ok := l["value"]; ok { - u, err = url.Parse(value) + logoURL, err = url.Parse(value) } } - if err != nil || containsURL(dst.Logo, u) { + if err != nil || containsUrl(dst.Logo, logoURL) { continue } - dst.Logo = append(dst.Logo, u) + dst.Logo = append(dst.Logo, logoURL) } - for _, property := range httputil.ExtractProperty(src, itemType, propertyURL) { + for _, property := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyURL) { prop, ok := property.(string) if !ok { continue } - if u, err := url.Parse(prop); err == nil || !containsURL(dst.URL, u) { + if u, err := url.Parse(prop); err == nil && !containsUrl(dst.URL, u) { dst.URL = append(dst.URL, u) } } } } -func containsString(src []string, find string) bool { - for i := range src { - if src[i] != find { - continue - } - - return true - } - - return false -} - -func containsURL(src []*url.URL, find *url.URL) bool { +func containsUrl(src []*url.URL, find *url.URL) bool { for i := range src { if src[i].String() != find.String() { continue diff --git a/internal/client/repository/http/http_client_test.go b/internal/client/repository/http/http_client_test.go index 5594d7e..c5eb739 100644 --- a/internal/client/repository/http/http_client_test.go +++ b/internal/client/repository/http/http_client_test.go @@ -3,22 +3,21 @@ package http_test import ( "context" "fmt" + "net/http" + "net/http/httptest" "testing" - "github.com/stretchr/testify/assert" - http "github.com/valyala/fasthttp" + "github.com/google/go-cmp/cmp" repository "source.toby3d.me/toby3d/auth/internal/client/repository/http" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" ) -const testBody string = ` - +const testBody string = ` - + %[1]s @@ -36,38 +35,47 @@ func TestGet(t *testing.T) { t.Parallel() client := domain.TestClient(t) - httpClient, _, cleanup := httptest.New(t, testHandler(t, client)) - t.Cleanup(cleanup) + srv := httptest.NewUnstartedServer(testHandler(t, *client)) + srv.EnableHTTP2 = true - result, err := repository.NewHTTPClientRepository(httpClient).Get(context.Background(), client.ID) + srv.StartTLS() + t.Cleanup(srv.Close) + + client.ID = *domain.TestClientID(t, srv.URL+"/") + clients := repository.NewHTTPClientRepository(srv.Client()) + + result, err := clients.Get(context.Background(), client.ID) if err != nil { t.Fatal(err) } - assert.Equal(t, client.Name, result.Name) - assert.Equal(t, client.ID.String(), result.ID.String()) - - for i := range client.URL { - assert.Equal(t, client.URL[i].String(), result.URL[i].String()) + if out := client.ID; !result.ID.IsEqual(out) { + t.Errorf("GET %s = %s, want %s", client.ID, out, result.ID) } - for i := range client.Logo { - assert.Equal(t, client.Logo[i].String(), result.Logo[i].String()) + if !cmp.Equal(result.Name, client.Name) { + t.Errorf("GET %s = %+s, want %+s", client.ID, result.Name, client.Name) } - for i := range client.RedirectURI { - assert.Equal(t, client.RedirectURI[i].String(), result.RedirectURI[i].String()) + if !cmp.Equal(result.URL, client.URL) { + t.Errorf("GET %s = %+s, want %+s", client.ID, result.URL, client.URL) + } + + if !cmp.Equal(result.Logo, client.Logo) { + t.Errorf("GET %s = %+s, want %+s", client.ID, result.Logo, client.Logo) + } + + if !cmp.Equal(result.RedirectURI, client.RedirectURI) { + t.Errorf("GET %s = %+s, want %+s", client.ID, result.RedirectURI, client.RedirectURI) } } -func testHandler(tb testing.TB, client *domain.Client) http.RequestHandler { +func testHandler(tb testing.TB, client domain.Client) http.Handler { tb.Helper() - return func(ctx *http.RequestCtx) { - ctx.Response.Header.Set(http.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`) - ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf( - testBody, client.Name[0], client.URL[0].String(), client.Logo[0].String(), - client.RedirectURI[1].String(), - )) - } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + w.Header().Set(common.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`) + fmt.Fprintf(w, testBody, client.Name[0], client.URL[0], client.Logo[0], client.RedirectURI[1]) + }) } diff --git a/internal/client/repository/memory/memory_client.go b/internal/client/repository/memory/memory_client.go index e4d433f..cfdd04e 100644 --- a/internal/client/repository/memory/memory_client.go +++ b/internal/client/repository/memory/memory_client.go @@ -2,9 +2,6 @@ package memory import ( "context" - "fmt" - "net" - "path" "sync" "source.toby3d.me/toby3d/auth/internal/client" @@ -12,45 +9,33 @@ import ( ) type memoryClientRepository struct { - store *sync.Map + mutex *sync.RWMutex + clients map[string]domain.Client } -const DefaultPathPrefix string = "clients" - -func NewMemoryClientRepository(store *sync.Map) client.Repository { +func NewMemoryClientRepository() client.Repository { return &memoryClientRepository{ - store: store, + mutex: new(sync.RWMutex), + clients: make(map[string]domain.Client), } } -func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.Client) error { - repo.store.Store(path.Join(DefaultPathPrefix, client.ID.String()), client) +func (repo memoryClientRepository) Create(ctx context.Context, client domain.Client) error { + repo.mutex.RLock() + defer repo.mutex.RUnlock() + + repo.clients[client.ID.String()] = client return nil } -func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { - // WARN(toby3d): more often than not, we will work from tests with - // non-existent clients, almost guaranteed to cause a resolution error. - ips, _ := net.LookupIP(id.URL().Hostname()) +func (repo memoryClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) { + repo.mutex.RLock() + defer repo.mutex.RUnlock() - for _, ip := range ips { - if !ip.IsLoopback() { - continue - } - - return nil, client.ErrNotExist + if c, ok := repo.clients[cid.String()]; ok { + return &c, nil } - src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String())) - if !ok { - return nil, fmt.Errorf("cannot find client in store: %w", client.ErrNotExist) - } - - c, ok := src.(*domain.Client) - if !ok { - return nil, fmt.Errorf("cannot decode client from store: %w", client.ErrNotExist) - } - - return c, nil + return nil, client.ErrNotExist } diff --git a/internal/client/repository/memory/memory_client_test.go b/internal/client/repository/memory/memory_client_test.go deleted file mode 100644 index 698904e..0000000 --- a/internal/client/repository/memory/memory_client_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package memory_test - -import ( - "context" - "path" - "reflect" - "sync" - "testing" - - repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory" - "source.toby3d.me/toby3d/auth/internal/domain" -) - -func TestGet(t *testing.T) { - t.Parallel() - - client := domain.TestClient(t) - - store := new(sync.Map) - store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client) - - result, err := repository.NewMemoryClientRepository(store). - Get(context.Background(), client.ID) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(result, client) { - t.Errorf("Get(%s) = %+v, want %+v", client.ID, result, client) - } -} diff --git a/internal/client/usecase.go b/internal/client/usecase.go index ca3f655..7a2a706 100644 --- a/internal/client/usecase.go +++ b/internal/client/usecase.go @@ -8,7 +8,7 @@ import ( type UseCase interface { // Discovery returns client public information bu ClientID URL. - Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) + Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error) } var ErrInvalidMe error = domain.NewError( diff --git a/internal/client/usecase/client_ucase.go b/internal/client/usecase/client_ucase.go index 863b056..0a4b47e 100644 --- a/internal/client/usecase/client_ucase.go +++ b/internal/client/usecase/client_ucase.go @@ -18,7 +18,7 @@ func NewClientUseCase(repo client.Repository) client.UseCase { } } -func (useCase *clientUseCase) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { +func (useCase *clientUseCase) Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error) { c, err := useCase.repo.Get(ctx, id) if err != nil { return nil, fmt.Errorf("cannot discovery client by id: %w", err) diff --git a/internal/client/usecase/client_ucase_test.go b/internal/client/usecase/client_ucase_test.go index ae263f7..6674b8d 100644 --- a/internal/client/usecase/client_ucase_test.go +++ b/internal/client/usecase/client_ucase_test.go @@ -3,12 +3,9 @@ package usecase_test import ( "context" "errors" - "path" "reflect" - "sync" "testing" - "source.toby3d.me/toby3d/auth/internal/client" repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory" "source.toby3d.me/toby3d/auth/internal/client/usecase" "source.toby3d.me/toby3d/auth/internal/domain" @@ -17,12 +14,11 @@ import ( func TestDiscovery(t *testing.T) { t.Parallel() - store := new(sync.Map) - testClient, localhostClient := domain.TestClient(t), domain.TestClient(t) - localhostClient.ID, _ = domain.ParseClientID("http://localhost/") + testClient := domain.TestClient(t) + clients := repository.NewMemoryClientRepository() - for _, client := range []*domain.Client{testClient, localhostClient} { - store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client) + if err := clients.Create(context.Background(), *testClient); err != nil { + t.Fatal(err) } for _, tc := range []struct { @@ -34,17 +30,13 @@ func TestDiscovery(t *testing.T) { name: "default", in: testClient, out: testClient, - }, { - name: "localhost", - in: localhostClient, - expError: client.ErrNotExist, }} { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)). + result, err := usecase.NewClientUseCase(clients). Discovery(context.Background(), tc.in.ID) if tc.expError != nil && !errors.Is(err, tc.expError) { t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError) diff --git a/internal/domain/client.go b/internal/domain/client.go index a2fb8d1..44131a3 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -9,7 +9,7 @@ import ( // Client describes the client requesting data about the user. type Client struct { - ID *ClientID + ID ClientID Logo []*url.URL RedirectURI []*url.URL URL []*url.URL @@ -17,7 +17,7 @@ type Client struct { } // NewClient creates a new empty Client with provided ClientID, if any. -func NewClient(cid *ClientID) *Client { +func NewClient(cid ClientID) *Client { return &Client{ ID: cid, Logo: make([]*url.URL, 0), @@ -32,7 +32,7 @@ func TestClient(tb testing.TB) *Client { tb.Helper() return &Client{ - ID: TestClientID(tb), + ID: *TestClientID(tb), Name: []string{"Example App"}, URL: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/"}}, Logo: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/logo.png"}}, diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go index f972db5..206d129 100644 --- a/internal/domain/client_id.go +++ b/internal/domain/client_id.go @@ -8,6 +8,8 @@ import ( "testing" "inet.af/netaddr" + + "source.toby3d.me/toby3d/auth/internal/common" ) // ClientID is a URL client identifier. @@ -37,16 +39,20 @@ func ParseClientID(src string) (*ClientID, error) { if cid.Scheme != "http" && cid.Scheme != "https" { return nil, NewError( ErrorCodeInvalidRequest, - "client identifier URL MUST have either an https or http scheme", + "client identifier URL MUST have either an https or http scheme, got '"+cid.Scheme+"'", "https://indieauth.net/source/#client-identifier", ) } - if cid.Path == "" || strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") { + if cid.Path == "" { + cid.Path = "/" + } + + if strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") { return nil, NewError( ErrorCodeInvalidRequest, "client identifier URL MUST contain a path component and MUST NOT contain "+ - "single-dot or double-dot path segments", + "single-dot or double-dot path segments, got '"+cid.Path+"'", "https://indieauth.net/source/#client-identifier", ) } @@ -54,7 +60,7 @@ func ParseClientID(src string) (*ClientID, error) { if cid.Fragment != "" { return nil, NewError( ErrorCodeInvalidRequest, - "client identifier URL MUST NOT contain a fragment component", + "client identifier URL MUST NOT contain a fragment component, got '"+cid.Fragment+"'", "https://indieauth.net/source/#client-identifier", ) } @@ -62,7 +68,8 @@ func ParseClientID(src string) (*ClientID, error) { if cid.User != nil { return nil, NewError( ErrorCodeInvalidRequest, - "client identifier URL MUST NOT contain a username or password component", + "client identifier URL MUST NOT contain a username or password component, got '"+ + cid.User.String()+"'", "https://indieauth.net/source/#client-identifier", ) } @@ -71,7 +78,7 @@ func ParseClientID(src string) (*ClientID, error) { if domain == "" { return nil, NewError( ErrorCodeInvalidRequest, - "client host name MUST be domain name or a loopback interface", + "client host name MUST be domain name or a loopback interface, got '"+domain+"'", "https://indieauth.net/source/#client-identifier", ) } @@ -102,10 +109,15 @@ func ParseClientID(src string) (*ClientID, error) { } // TestClientID returns valid random generated ClientID for tests. -func TestClientID(tb testing.TB) *ClientID { +func TestClientID(tb testing.TB, forceURL ...string) *ClientID { tb.Helper() - clientID, err := ParseClientID("https://example.com/") + in := "https://app.example.com/" + if len(forceURL) > 0 { + in = forceURL[0] + } + + clientID, err := ParseClientID(in) if err != nil { tb.Fatal(err) } @@ -147,6 +159,11 @@ func (cid ClientID) MarshalJSON() ([]byte, error) { return []byte(strconv.Quote(cid.String())), nil } +// IsEqual checks what cid is equal to provided v. +func (cid ClientID) IsEqual(v ClientID) bool { + return cid.clientID.String() == v.clientID.String() +} + // URL returns url.URL representation of client ID. func (cid ClientID) URL() *url.URL { out, _ := url.Parse(cid.clientID.String()) @@ -156,5 +173,17 @@ func (cid ClientID) URL() *url.URL { // String returns string representation of client ID. func (cid ClientID) String() string { + if cid.clientID == nil { + return "" + } + return cid.clientID.String() } + +func (cid ClientID) GoString() string { + if cid.clientID == nil { + return "domain.ClientID(" + common.Und + ")" + } + + return "domain.ClientID(" + cid.clientID.String() + ")" +} diff --git a/internal/domain/code_challenge_method_test.go b/internal/domain/code_challenge_method_test.go index 71d6dfa..e196766 100644 --- a/internal/domain/code_challenge_method_test.go +++ b/internal/domain/code_challenge_method_test.go @@ -114,7 +114,7 @@ func TestCodeChallengeMethod_String(t *testing.T) { func TestCodeChallengeMethod_Validate(t *testing.T) { t.Parallel() - verifier, err := random.String(gofakeit.Number(43, 128)) + verifier, err := random.String(uint8(gofakeit.Number(43, 128))) if err != nil { t.Fatalf("%+v", err) } diff --git a/internal/domain/config.go b/internal/domain/config.go index 8ff61d9..492b1ad 100644 --- a/internal/domain/config.go +++ b/internal/domain/config.go @@ -29,7 +29,6 @@ type ( Port string `yaml:"port"` Protocol string `yaml:"protocol"` RootURL string `yaml:"rootUrl"` - StaticRootPath string `yaml:"staticRootPath"` StaticURLPrefix string `yaml:"staticUrlPrefix"` EnablePprof bool `yaml:"enablePprof"` } @@ -44,14 +43,14 @@ type ( // exchange it for a token or user information. ConfigCode struct { Expiry time.Duration `yaml:"expiry"` // 10m - Length int `yaml:"length"` // 32 + Length uint8 `yaml:"length"` // 32 } ConfigJWT struct { Expiry time.Duration `yaml:"expiry"` // 1h Algorithm string `yaml:"algorithm"` // HS256 Secret string `yaml:"secret"` - NonceLength int `yaml:"nonceLength"` // 22 + NonceLength uint8 `yaml:"nonceLength"` // 22 } ConfigIndieAuth struct { @@ -62,7 +61,7 @@ type ( ConfigTicketAuth struct { Expiry time.Duration `yaml:"expiry"` // 1m - Length int `yaml:"length"` // 24 + Length uint8 `yaml:"length"` // 24 } ConfigRelMeAuth struct { @@ -95,7 +94,6 @@ func TestConfig(tb testing.TB) *Config { Port: "3000", Protocol: "http", RootURL: "{{protocol}}://{{domain}}:{{port}}/", - StaticRootPath: "/", StaticURLPrefix: "/static", }, Database: ConfigDatabase{ @@ -136,7 +134,6 @@ func (cs ConfigServer) GetRootURL() string { "host": cs.Host, "port": cs.Port, "protocol": cs.Protocol, - "staticRootPath": cs.StaticRootPath, "staticUrlPrefix": cs.StaticURLPrefix, }) } diff --git a/internal/domain/me.go b/internal/domain/me.go index 9a6e6a3..cb6f146 100644 --- a/internal/domain/me.go +++ b/internal/domain/me.go @@ -31,7 +31,7 @@ func ParseMe(raw string) (*Me, error) { if id.Scheme != "http" && id.Scheme != "https" { return nil, NewError( ErrorCodeInvalidRequest, - "profile URL MUST have either an https or http scheme", + "profile URL MUST have either an https or http scheme, got '"+id.Scheme+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -45,7 +45,7 @@ func ParseMe(raw string) (*Me, error) { return nil, NewError( ErrorCodeInvalidRequest, "profile URL MUST contain a path component (/ is a valid path), MUST NOT contain single-dot "+ - "or double-dot path segments", + "or double-dot path segments, got '"+id.Path+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -54,7 +54,7 @@ func ParseMe(raw string) (*Me, error) { if id.Fragment != "" { return nil, NewError( ErrorCodeInvalidRequest, - "profile URL MUST NOT contain a fragment component", + "profile URL MUST NOT contain a fragment component, got '"+id.Fragment+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -63,7 +63,7 @@ func ParseMe(raw string) (*Me, error) { if id.User != nil { return nil, NewError( ErrorCodeInvalidRequest, - "profile URL MUST NOT contain a username or password component", + "profile URL MUST NOT contain a username or password component, got '"+id.User.String()+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -72,7 +72,7 @@ func ParseMe(raw string) (*Me, error) { if id.Host == "" { return nil, NewError( ErrorCodeInvalidRequest, - "profile host name MUST be a domain name", + "profile host name MUST be a domain name, got '"+id.Host+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -81,16 +81,16 @@ func ParseMe(raw string) (*Me, error) { if _, port, _ := net.SplitHostPort(id.Host); port != "" { return nil, NewError( ErrorCodeInvalidRequest, - "profile MUST NOT contain a port", + "profile MUST NOT contain a port, got '"+port+"'", "https://indieauth.net/source/#user-profile-url", "", ) } - if net.ParseIP(id.Host) != nil { + if out := net.ParseIP(id.Host); out != nil { return nil, NewError( ErrorCodeInvalidRequest, - "profile MUST NOT be ipv4 or ipv6 addresses", + "profile MUST NOT be ipv4 or ipv6 addresses, got '"+out.String()+"'", "https://indieauth.net/source/#user-profile-url", "", ) @@ -103,12 +103,12 @@ func ParseMe(raw string) (*Me, error) { func TestMe(tb testing.TB, src string) *Me { tb.Helper() - me, err := ParseMe(src) + u, err := url.Parse(src) if err != nil { tb.Fatal(err) } - return me + return &Me{id: u} } // UnmarshalForm implements custom unmarshler for form values. diff --git a/internal/domain/metadata.go b/internal/domain/metadata.go index f078453..e1fac9f 100644 --- a/internal/domain/metadata.go +++ b/internal/domain/metadata.go @@ -14,7 +14,7 @@ type Metadata struct { // issuer URL could be https://example.com/, or for a metadata URL of // https://example.com/wp-json/indieauth/1.0/metadata, the issuer URL // could be https://example.com/wp-json/indieauth/1.0 - Issuer *ClientID + Issuer *url.URL // The Authorization Endpoint. AuthorizationEndpoint *url.URL @@ -81,7 +81,11 @@ func TestMetadata(tb testing.TB) *Metadata { tb.Helper() return &Metadata{ - Issuer: TestClientID(tb), + Issuer: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/.well-known/oauth-authorization-server", + }, AuthorizationEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/auth"}, TokenEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/token"}, TicketEndpoint: &url.URL{Scheme: "https", Host: "auth.example.org", Path: "/ticket"}, diff --git a/internal/domain/provider.go b/internal/domain/provider.go index 377bac1..bb37d9b 100644 --- a/internal/domain/provider.go +++ b/internal/domain/provider.go @@ -1,10 +1,9 @@ package domain import ( + "net/url" "path" "strings" - - http "github.com/valyala/fasthttp" ) // Provider represent 3rd party RelMeAuth provider. @@ -91,9 +90,10 @@ var ( // AuthCodeURL returns URL for authorize user in RelMeAuth client. func (p Provider) AuthCodeURL(state string) string { - uri := http.AcquireURI() - defer http.ReleaseURI(uri) - uri.Update(p.AuthURL) + u, err := url.Parse(p.AuthURL) + if err != nil { + return "" + } for key, val := range map[string]string{ "client_id": p.ClientID, @@ -102,8 +102,8 @@ func (p Provider) AuthCodeURL(state string) string { "scope": strings.Join(p.Scopes, " "), "state": state, } { - uri.QueryArgs().Set(key, val) + u.Query().Set(key, val) } - return uri.String() + return u.String() } diff --git a/internal/domain/scope.go b/internal/domain/scope.go index 7367b32..126a52d 100644 --- a/internal/domain/scope.go +++ b/internal/domain/scope.go @@ -80,6 +80,22 @@ func ParseScope(uid string) (Scope, error) { return ScopeUnd, fmt.Errorf("%w: %s", ErrScopeUnknown, uid) } +func (s *Scope) UnmarshalJSON(v []byte) error { + src, err := strconv.Unquote(string(v)) + if err != nil { + return fmt.Errorf("Scope: UnmarshalJSON: cannot unquote string: %w", err) + } + + out, err := ParseScope(src) + if err != nil { + return fmt.Errorf("Scopes: UnmarshalJSON: cannot parse scope: %w", err) + } + + *s = out + + return nil +} + func (s Scope) MarshalJSON() ([]byte, error) { return []byte(strconv.Quote(s.uid)), nil } diff --git a/internal/domain/session.go b/internal/domain/session.go index bda7542..a8f7e37 100644 --- a/internal/domain/session.go +++ b/internal/domain/session.go @@ -9,9 +9,9 @@ import ( //nolint:tagliatelle type Session struct { - ClientID *ClientID `json:"client_id"` + ClientID ClientID `json:"client_id"` RedirectURI *url.URL `json:"redirect_uri"` - Me *Me `json:"me"` + Me Me `json:"me"` Profile *Profile `json:"profile,omitempty"` Scope Scopes `json:"scope"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method,omitempty"` @@ -31,12 +31,12 @@ func TestSession(tb testing.TB) *Session { } return &Session{ - ClientID: TestClientID(tb), + ClientID: *TestClientID(tb), Code: code, CodeChallenge: "hackme", CodeChallengeMethod: CodeChallengeMethodPLAIN, Profile: TestProfile(tb), - Me: TestMe(tb, "https://user.example.net/"), + Me: *TestMe(tb, "https://user.example.net/"), RedirectURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/callback"}, Scope: Scopes{ ScopeEmail, diff --git a/internal/domain/token.go b/internal/domain/token.go index 8acd927..fcb1a5c 100644 --- a/internal/domain/token.go +++ b/internal/domain/token.go @@ -2,13 +2,14 @@ package domain import ( "fmt" + "net/http" "testing" "time" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" - http "github.com/valyala/fasthttp" + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/random" ) @@ -17,8 +18,8 @@ type ( Token struct { CreatedAt time.Time Expiry time.Time - ClientID *ClientID - Me *Me + ClientID ClientID + Me Me Scope Scopes AccessToken string RefreshToken string @@ -27,12 +28,12 @@ type ( // NewTokenOptions contains options for NewToken function. NewTokenOptions struct { Expiration time.Duration - Issuer *ClientID - Subject *Me + Issuer ClientID + Subject Me Scope Scopes Secret []byte Algorithm string - NonceLength int + NonceLength uint8 } ) @@ -42,8 +43,8 @@ type ( var DefaultNewTokenOptions = NewTokenOptions{ Expiration: 0, Scope: nil, - Issuer: nil, - Subject: nil, + Issuer: ClientID{}, + Subject: Me{}, Secret: nil, Algorithm: "HS256", NonceLength: 32, @@ -82,7 +83,7 @@ func NewToken(opts NewTokenOptions) (*Token, error) { } } - if opts.Issuer != nil { + if opts.Issuer.clientID != nil { if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil { return nil, fmt.Errorf("failed to set JWT token field: %w", err) } @@ -157,8 +158,8 @@ func TestToken(tb testing.TB) *Token { return &Token{ CreatedAt: now.Add(-1 * time.Hour), Expiry: now.Add(1 * time.Hour), - ClientID: cid, - Me: me, + ClientID: *cid, + Me: *me, Scope: scope, AccessToken: string(accessToken), RefreshToken: "", // TODO(toby3d) @@ -171,7 +172,7 @@ func (t Token) SetAuthHeader(r *http.Request) { return } - r.Header.Set(http.HeaderAuthorization, t.String()) + r.Header.Set(common.HeaderAuthorization, t.String()) } // String returns string representation of token. diff --git a/internal/domain/token_test.go b/internal/domain/token_test.go index 1efc2b7..0706075 100644 --- a/internal/domain/token_test.go +++ b/internal/domain/token_test.go @@ -1,13 +1,12 @@ package domain_test import ( - "bytes" "fmt" + "net/http" "testing" "time" - http "github.com/valyala/fasthttp" - + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" ) @@ -40,16 +39,13 @@ func TestNewToken(t *testing.T) { func TestToken_SetAuthHeader(t *testing.T) { t.Parallel() - token := domain.TestToken(t) - expResult := []byte("Bearer " + token.AccessToken) + in := domain.TestToken(t) + req, _ := http.NewRequest(http.MethodGet, "https://example.com/", nil) + in.SetAuthHeader(req) - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - token.SetAuthHeader(req) - - result := req.Header.Peek(http.HeaderAuthorization) - if result == nil || !bytes.Equal(result, expResult) { - t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, result, expResult) + exp := "Bearer " + in.AccessToken + if out := req.Header.Get(common.HeaderAuthorization); out != exp { + t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, out, exp) } } @@ -57,9 +53,9 @@ func TestToken_String(t *testing.T) { t.Parallel() token := domain.TestToken(t) - expResult := "Bearer " + token.AccessToken + exp := "Bearer " + token.AccessToken - if result := token.String(); result != expResult { - t.Errorf("String() = %s, want %s", result, expResult) + if out := token.String(); out != exp { + t.Errorf("String() = %s, want %s", out, exp) } } diff --git a/internal/domain/url.go b/internal/domain/url.go index 6a95bf7..84a9686 100644 --- a/internal/domain/url.go +++ b/internal/domain/url.go @@ -5,6 +5,8 @@ import ( "net/url" "strconv" "testing" + + "source.toby3d.me/toby3d/auth/internal/common" ) // URL describe any valid HTTP URL. @@ -75,3 +77,11 @@ func (u *URL) UnmarshalJSON(v []byte) error { func (u URL) MarshalJSON() ([]byte, error) { return []byte(strconv.Quote(u.String())), nil } + +func (u URL) GoString() string { + if u.URL == nil { + return "domain.URL(" + common.Und + ")" + } + + return "domain.URL(" + u.URL.String() + ")" +} diff --git a/internal/health/delivery/http/health_http.go b/internal/health/delivery/http/health_http.go index 263155a..30c4a89 100644 --- a/internal/health/delivery/http/health_http.go +++ b/internal/health/delivery/http/health_http.go @@ -5,7 +5,6 @@ import ( "net/http" "source.toby3d.me/toby3d/auth/internal/common" - "source.toby3d.me/toby3d/auth/internal/middleware" ) type Handler struct{} @@ -14,8 +13,8 @@ func NewHandler() *Handler { return &Handler{} } -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - http.HandlerFunc(middleware.HandlerFunc(h.handleFunc).Intercept(middleware.LogFmt())).ServeHTTP(w, r) +func (h *Handler) Handler() http.Handler { + return http.HandlerFunc(h.handleFunc) } func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { diff --git a/internal/health/delivery/http/health_http_test.go b/internal/health/delivery/http/health_http_test.go index 3e8ebd2..68cd586 100644 --- a/internal/health/delivery/http/health_http_test.go +++ b/internal/health/delivery/http/health_http_test.go @@ -2,11 +2,10 @@ package http_test import ( "io" + "net/http" "net/http/httptest" "testing" - http "github.com/valyala/fasthttp" - delivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http" ) @@ -15,7 +14,10 @@ func TestRequestHandler(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "https://example.com/health", nil) w := httptest.NewRecorder() - delivery.NewHandler().ServeHTTP(w, req) + + delivery.NewHandler(). + Handler(). + ServeHTTP(w, req) resp := w.Result() diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go index d51b62a..1975704 100644 --- a/internal/httputil/httputil.go +++ b/internal/httputil/httputil.go @@ -2,33 +2,74 @@ package httputil import ( "bytes" - "encoding/json" "fmt" + "io" + "io/ioutil" + "net/http" "net/url" "strings" + "github.com/goccy/go-json" "github.com/tomnomnom/linkheader" - http "github.com/valyala/fasthttp" + "golang.org/x/exp/slices" "willnorris.com/go/microformats" + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" ) +const RelIndieauthMetadata = "indieauth-metadata" + var ErrEndpointNotExist = domain.NewError( domain.ErrorCodeServerError, "cannot found any endpoints", "https://indieauth.net/source/#discovery-0", ) -func ExtractEndpoints(resp *http.Response, rel string) []*url.URL { +func ExtractFromMetadata(client *http.Client, u string) (*domain.Metadata, error) { + req, err := http.NewRequest(http.MethodGet, u, nil) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + buf := bytes.NewBuffer(body) + + endpoints := ExtractEndpoints(buf, resp.Request.URL, resp.Header.Get(common.HeaderLink), RelIndieauthMetadata) + if len(endpoints) == 0 { + return nil, ErrEndpointNotExist + } + + if resp, err = client.Get(endpoints[len(endpoints)-1].String()); err != nil { + return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err) + } + + result := new(domain.Metadata) + if err = json.NewDecoder(resp.Body).Decode(result); err != nil { + return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err) + } + + return result, nil +} + +func ExtractEndpoints(body io.Reader, u *url.URL, linkHeader, rel string) []*url.URL { results := make([]*url.URL, 0) - urls, err := ExtractEndpointsFromHeader(resp, rel) + urls, err := ExtractEndpointsFromHeader(linkHeader, rel) if err == nil { results = append(results, urls...) } - urls, err = ExtractEndpointsFromBody(resp, rel) + urls, err = ExtractEndpointsFromBody(body, u, rel) if err == nil { results = append(results, urls...) } @@ -36,15 +77,15 @@ func ExtractEndpoints(resp *http.Response, rel string) []*url.URL { return results } -func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, error) { +func ExtractEndpointsFromHeader(linkHeader, rel string) ([]*url.URL, error) { results := make([]*url.URL, 0) - for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) { + for _, link := range linkheader.Parse(linkHeader) { if !strings.EqualFold(link.Rel, rel) { continue } - u, err := url.ParseRequestURI(link.URL) + u, err := url.Parse(link.URL) if err != nil { return nil, fmt.Errorf("cannot parse header endpoint: %w", err) } @@ -55,8 +96,8 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, er return results, nil } -func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, error) { - endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel] +func ExtractEndpointsFromBody(body io.Reader, u *url.URL, rel string) ([]*url.URL, error) { + endpoints, ok := microformats.Parse(body, u).Rels[rel] if !ok || len(endpoints) == 0 { return nil, ErrEndpointNotExist } @@ -75,58 +116,23 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, erro return results, nil } -func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) { - endpoints := ExtractEndpoints(resp, "indieauth-metadata") - if len(endpoints) == 0 { - return nil, ErrEndpointNotExist - } - - _, body, err := client.Get(nil, endpoints[len(endpoints)-1].String()) - if err != nil { - return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err) - } - - result := new(domain.Metadata) - if err = json.Unmarshal(body, result); err != nil { - return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err) - } - - return result, nil -} - -func ExtractProperty(resp *http.Response, itemType, key string) []interface{} { - //nolint:exhaustivestruct // only Host part in url.URL is needed - data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{ - Host: string(resp.Header.Peek(http.HeaderHost)), - }) - - return findProperty(data.Items, itemType, key) -} - -func contains(src []string, find string) bool { - for i := range src { - if !strings.EqualFold(src[i], find) { - continue - } - - return true - } - - return false -} - -func findProperty(src []*microformats.Microformat, itemType, key string) []interface{} { - for _, item := range src { - if contains(item.Type, itemType) { - return item.Properties[key] - } - - result := findProperty(item.Children, itemType, key) - if result == nil { - continue - } - - return result +func ExtractProperty(body io.Reader, u *url.URL, itemType, key string) []any { + if data := microformats.Parse(body, u); data != nil { + return FindProperty(data.Items, itemType, key) + } + + return nil +} + +func FindProperty(src []*microformats.Microformat, itemType, key string) []any { + for _, item := range src { + if slices.Contains(item.Type, itemType) { + return item.Properties[key] + } + + if result := FindProperty(item.Children, itemType, key); result != nil { + return result + } } return nil diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go index 20a47ab..065e285 100644 --- a/internal/httputil/httputil_test.go +++ b/internal/httputil/httputil_test.go @@ -1,30 +1,72 @@ package httputil_test import ( + "io/ioutil" + "net/http" + "net/url" + "strings" "testing" - http "github.com/valyala/fasthttp" + "github.com/google/go-cmp/cmp" "source.toby3d.me/toby3d/auth/internal/httputil" ) const testBody = ` + + + + -
+

Sample Name

` +func TestExtractEndpointsFromBody(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil) + if err != nil { + t.Fatal(err) + } + + in := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(testBody)), + Request: req, + } + + out, err := httputil.ExtractEndpointsFromBody(in.Body, req.URL, "lipsum") + if err != nil { + t.Fatal(err) + } + + exp := []*url.URL{ + {Scheme: "https", Host: "example.com", Path: "/"}, + {Scheme: "https", Host: "example.net", Path: "/"}, + } + + if !cmp.Equal(out, exp) { + t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, out, exp) + } +} + func TestExtractProperty(t *testing.T) { t.Parallel() - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - resp.SetBodyString(testBody) + req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil) + if err != nil { + t.Fatal(err) + } - results := httputil.ExtractProperty(resp, "h-card", "name") - if results == nil || results[0] != "Sample Name" { - t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, results, []string{"Sample Name"}) + in := &http.Response{ + Body: ioutil.NopCloser(strings.NewReader(testBody)), + Request: req, + } + + if out := httputil.ExtractProperty(in.Body, req.URL, "h-app", "name"); out == nil || out[0] != "Sample Name" { + t.Errorf(`ExtractProperty(%s, %s, %s) = %+s, want %+s`, req.URL, "h-app", "name", out, + []string{"Sample Name"}) } } diff --git a/internal/metadata/delivery/http/metadata_http.go b/internal/metadata/delivery/http/metadata_http.go index d65d553..ea61746 100644 --- a/internal/metadata/delivery/http/metadata_http.go +++ b/internal/metadata/delivery/http/metadata_http.go @@ -1,13 +1,12 @@ package http import ( - "github.com/fasthttp/router" + "net/http" + "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/middleware" ) type ( @@ -60,28 +59,29 @@ type ( UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` } - RequestHandler struct { + Handler struct { metadata *domain.Metadata } ) -func NewRequestHandler(metadata *domain.Metadata) *RequestHandler { - return &RequestHandler{ +func NewHandler(metadata *domain.Metadata) *Handler { + return &Handler{ metadata: metadata, } } -func (h *RequestHandler) Register(r *router.Router) { - chain := middleware.Chain{ - middleware.LogFmt(), - } - - r.GET("/.well-known/oauth-authorization-server", chain.RequestHandler(h.read)) +func (h *Handler) Handler() http.Handler { + return http.HandlerFunc(h.handleFunc) } -func (h *RequestHandler) read(ctx *http.RequestCtx) { - ctx.SetStatusCode(http.StatusOK) - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) +func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { + if r.Method != "" && r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) scopes, responseTypes, grantTypes, codeChallengeMethods := make([]string, 0), make([]string, 0), make([]string, 0), make([]string, 0) @@ -103,7 +103,7 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) { h.metadata.CodeChallengeMethodsSupported[i].String()) } - _ = json.NewEncoder(ctx).Encode(&MetadataResponse{ + _ = json.NewEncoder(w).Encode(&MetadataResponse{ AuthorizationEndpoint: h.metadata.AuthorizationEndpoint.String(), IntrospectionEndpoint: h.metadata.IntrospectionEndpoint.String(), Issuer: h.metadata.Issuer.String(), @@ -123,4 +123,6 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) { // client_secret_basic according to RFC8414. RevocationEndpointAuthMethodsSupported: h.metadata.RevocationEndpointAuthMethodsSupported, }) + + w.WriteHeader(http.StatusOK) } diff --git a/internal/metadata/delivery/http/metadata_http_test.go b/internal/metadata/delivery/http/metadata_http_test.go index c8e545f..8c2dde1 100644 --- a/internal/metadata/delivery/http/metadata_http_test.go +++ b/internal/metadata/delivery/http/metadata_http_test.go @@ -1,40 +1,36 @@ package http_test import ( + "net/http" + "net/http/httptest" "testing" - "github.com/fasthttp/router" "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/domain" delivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" ) func TestMetadata(t *testing.T) { t.Parallel() - r := router.New() metadata := domain.TestMetadata(t) - delivery.NewRequestHandler(metadata).Register(r) - client, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) + req := httptest.NewRequest(http.MethodGet, "https://example.com/.well-known/oauth-authorization-server", nil) - const requestURL string = "https://example.com/.well-known/oauth-authorization-server" + w := httptest.NewRecorder() + delivery.NewHandler(metadata). + Handler(). + ServeHTTP(w, req) - status, body, err := client.Get(nil, requestURL) - if err != nil { - t.Fatal(err) + resp := w.Result() + + if resp.StatusCode != http.StatusOK { + t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK) } - if status != http.StatusOK { - t.Errorf("GET %s = %d, want %d", requestURL, status, http.StatusOK) - } - - result := new(delivery.MetadataResponse) - if err = json.Unmarshal(body, result); err != nil { + out := new(delivery.MetadataResponse) + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { t.Fatal(err) } } diff --git a/internal/metadata/repository.go b/internal/metadata/repository.go index 0672df4..0b4cd4c 100644 --- a/internal/metadata/repository.go +++ b/internal/metadata/repository.go @@ -2,12 +2,14 @@ package metadata import ( "context" + "net/url" "source.toby3d.me/toby3d/auth/internal/domain" ) type Repository interface { - Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) + Create(_ context.Context, _ *url.URL, _ domain.Metadata) error + Get(_ context.Context, u *url.URL) (*domain.Metadata, error) } var ErrNotExist error = domain.NewError( diff --git a/internal/metadata/repository/http/http_metadata.go b/internal/metadata/repository/http/http_metadata.go index 618432a..bea3b6c 100644 --- a/internal/metadata/repository/http/http_metadata.go +++ b/internal/metadata/repository/http/http_metadata.go @@ -2,26 +2,29 @@ package http import ( "context" - "encoding/json" "fmt" + "net/http" + "net/url" - http "github.com/valyala/fasthttp" + "github.com/goccy/go-json" + "github.com/tomnomnom/linkheader" + "willnorris.com/go/microformats" + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/metadata" ) type ( //nolint:tagliatelle,lll - Metadata struct { - Issuer *domain.ClientID `json:"issuer"` - AuthorizationEndpoint *domain.URL `json:"authorization_endpoint"` - IntrospectionEndpoint *domain.URL `json:"introspection_endpoint"` - RevocationEndpoint *domain.URL `json:"revocation_endpoint,omitempty"` - ServiceDocumentation *domain.URL `json:"service_documentation,omitempty"` - TokenEndpoint *domain.URL `json:"token_endpoint"` - UserinfoEndpoint *domain.URL `json:"userinfo_endpoint,omitempty"` + Response struct { + Issuer domain.URL `json:"issuer"` + AuthorizationEndpoint domain.URL `json:"authorization_endpoint"` + IntrospectionEndpoint domain.URL `json:"introspection_endpoint"` + RevocationEndpoint domain.URL `json:"revocation_endpoint,omitempty"` + ServiceDocumentation domain.URL `json:"service_documentation,omitempty"` + TokenEndpoint domain.URL `json:"token_endpoint"` + UserinfoEndpoint domain.URL `json:"userinfo_endpoint,omitempty"` CodeChallengeMethodsSupported []domain.CodeChallengeMethod `json:"code_challenge_methods_supported"` GrantTypesSupported []domain.GrantType `json:"grant_types_supported,omitempty"` ResponseTypesSupported []domain.ResponseType `json:"response_types_supported,omitempty"` @@ -29,6 +32,11 @@ type ( IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"` + + // Extensions + TicketEndpoint domain.URL `json:"ticket_endpoint"` + Micropub domain.URL `json:"micropub"` + Microsub domain.URL `json:"microsub"` } httpMetadataRepository struct { @@ -36,7 +44,7 @@ type ( } ) -const DefaultMaxRedirectsCount int = 10 +const relIndieauthMetadata = "indieauth-metadata" func NewHTTPMetadataRepository(client *http.Client) metadata.Repository { return &httpMetadataRepository{ @@ -44,48 +52,127 @@ func NewHTTPMetadataRepository(client *http.Client) metadata.Repository { } } -func (repo *httpMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) { - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.SetRequestURI(me.String()) - req.Header.SetMethod(http.MethodGet) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { - return nil, fmt.Errorf("failed to make a request to the client: %w", err) - } - - endpoints := httputil.ExtractEndpoints(resp, "indieauth-metadata") - if len(endpoints) == 0 { - return nil, metadata.ErrNotExist - } - - _, body, err := repo.client.Get(nil, endpoints[len(endpoints)-1].String()) - if err != nil { - return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err) - } - - data := new(Metadata) - if err = json.Unmarshal(body, data); err != nil { - return nil, fmt.Errorf("cannot unmarshal metadata configuration: %w", err) - } - - //nolint:exhaustivestruct // TODO(toby3d) - return &domain.Metadata{ - AuthorizationEndpoint: data.AuthorizationEndpoint.URL, - AuthorizationResponseIssParameterSupported: data.AuthorizationResponseIssParameterSupported, - CodeChallengeMethodsSupported: data.CodeChallengeMethodsSupported, - GrantTypesSupported: data.GrantTypesSupported, - Issuer: data.Issuer, - ResponseTypesSupported: data.ResponseTypesSupported, - ScopesSupported: data.ScopesSupported, - ServiceDocumentation: data.ServiceDocumentation.URL, - TokenEndpoint: data.TokenEndpoint.URL, - // TODO(toby3d): support extensions? - // Micropub: data.Micropub, - // Microsub: data.Microsub, - // TicketEndpoint: data.TicketEndpoint, - }, nil +// WARN(toby3d): not implemented. +func (httpMetadataRepository) Create(_ context.Context, _ *url.URL, _ domain.Metadata) error { + return nil +} + +func (repo *httpMetadataRepository) Get(_ context.Context, u *url.URL) (*domain.Metadata, error) { + resp, err := repo.client.Get(u.String()) + if err != nil { + return nil, fmt.Errorf("cannot make request to provided Me: %w", err) + } + + relVals := make(map[string][]string) + for _, link := range linkheader.Parse(resp.Header.Get(common.HeaderLink)) { + populateBuffer(relVals, link.Rel, link.URL) + } + + if mf2 := microformats.Parse(resp.Body, resp.Request.URL); mf2 != nil { + for rel, vals := range mf2.Rels { + if len(vals) > 0 { + populateBuffer(relVals, rel, vals[0]) + } + } + } + + out := new(domain.Metadata) + // NOTE(toby3d): fetch all from metadata endpoint if exists + if endpoints, ok := relVals["indieauth-metadata"]; ok { + if resp, err = repo.client.Get(endpoints[0]); err != nil { + return nil, fmt.Errorf("cannot fetch indieauth-metadata endpoint: %w", err) + } + + in := NewResponse() + if err = in.bind(resp); err != nil { + return nil, err + } + + in.populate(out) + + return out, nil + } + + // NOTE(toby3d): metadata not exists, fallback for old clients + for key, dst := range map[string]**url.URL{ + "authorization_endpoint": &out.AuthorizationEndpoint, + "micropub": &out.MicropubEndpoint, + "microsub": &out.MicrosubEndpoint, + "ticket_endpoint": &out.TicketEndpoint, + "token_endpoint": &out.TokenEndpoint, + } { + if values, ok := relVals[key]; ok && len(values) > 0 { + if u, err := url.Parse(values[0]); err == nil { + *dst = resp.Request.URL.ResolveReference(u) + } + } + } + + return out, nil +} + +func populateBuffer(dst map[string][]string, rel, u string) { + if _, ok := dst[rel]; !ok { + dst[rel] = make([]string, 0) + } + + dst[rel] = append(dst[rel], u) +} + +func NewResponse() *Response { + return &Response{ + CodeChallengeMethodsSupported: make([]domain.CodeChallengeMethod, 0), + GrantTypesSupported: make([]domain.GrantType, 0), + ResponseTypesSupported: make([]domain.ResponseType, 0), + ScopesSupported: make([]domain.Scope, 0), + IntrospectionEndpointAuthMethodsSupported: make([]string, 0), + RevocationEndpointAuthMethodsSupported: make([]string, 0), + } +} + +func (r *Response) bind(resp *http.Response) error { + if err := json.NewDecoder(resp.Body).Decode(r); err != nil { + return fmt.Errorf("cannot unmarshal metadata configuration: %w", err) + } + + return nil +} + +func (r Response) populate(dst *domain.Metadata) { + dst.AuthorizationEndpoint = r.AuthorizationEndpoint.URL + dst.AuthorizationResponseIssParameterSupported = r.AuthorizationResponseIssParameterSupported + dst.IntrospectionEndpoint = r.IntrospectionEndpoint.URL + dst.Issuer = r.Issuer.URL + dst.MicropubEndpoint = r.Micropub.URL + dst.MicrosubEndpoint = r.Microsub.URL + dst.RevocationEndpoint = r.RevocationEndpoint.URL + dst.ServiceDocumentation = r.ServiceDocumentation.URL + dst.TicketEndpoint = r.TicketEndpoint.URL + dst.TokenEndpoint = r.TokenEndpoint.URL + dst.UserinfoEndpoint = r.UserinfoEndpoint.URL + + for _, scope := range r.ScopesSupported { + dst.ScopesSupported = append(dst.ScopesSupported, scope) + } + + for _, method := range r.RevocationEndpointAuthMethodsSupported { + dst.RevocationEndpointAuthMethodsSupported = append(dst.RevocationEndpointAuthMethodsSupported, method) + } + + for _, responseType := range r.ResponseTypesSupported { + dst.ResponseTypesSupported = append(dst.ResponseTypesSupported, responseType) + } + + for _, method := range r.IntrospectionEndpointAuthMethodsSupported { + dst.IntrospectionEndpointAuthMethodsSupported = append(dst.IntrospectionEndpointAuthMethodsSupported, + method) + } + + for _, grantType := range r.GrantTypesSupported { + dst.GrantTypesSupported = append(dst.GrantTypesSupported, grantType) + } + + for _, method := range r.CodeChallengeMethodsSupported { + dst.CodeChallengeMethodsSupported = append(dst.CodeChallengeMethodsSupported, method) + } } diff --git a/internal/metadata/repository/http/http_metadata_test.go b/internal/metadata/repository/http/http_metadata_test.go new file mode 100644 index 0000000..2b4a791 --- /dev/null +++ b/internal/metadata/repository/http/http_metadata_test.go @@ -0,0 +1,183 @@ +package http_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/goccy/go-json" + "github.com/google/go-cmp/cmp" + + "source.toby3d.me/toby3d/auth/internal/common" + "source.toby3d.me/toby3d/auth/internal/domain" + repository "source.toby3d.me/toby3d/auth/internal/metadata/repository/http" +) + +//nolint:lll,tagliatelle +type Response struct { + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + IntrospectionEndpoint string `json:"introspection_endpoint"` + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + ServiceDocumentation string `json:"service_documentation,omitempty"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` + TicketEndpoint string `json:"ticket_endpoint"` + Micropub string `json:"micropub"` + Microsub string `json:"microsub"` + AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"` +} + +const testBody string = ` + + + + Testing + %s + + +` + +//nolint:funlen +func TestGet(t *testing.T) { + t.Parallel() + + testMetadata := domain.TestMetadata(t) + + for _, tc := range []struct { + name string + header map[string]string + body map[string]string + out *domain.Metadata + }{ + { + name: "header", + header: map[string]string{ + "indieauth-metadata": "/metadata", + "authorization_endpoint": "http://example.net/authorization", + "token_endpoint": "http://example.net/tkn", + }, + out: testMetadata, + }, /*{ + name: "body", + body: map[string]string{ + "indieauth-metadata": "/metadata", + "authorization_endpoint": "http://example.net/authorization", + "token_endpoint": "http://example.net/tkn", + }, + out: &testMetadata, + }*/} { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/metadata", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + _ = json.NewEncoder(w).Encode(NewResponse(t, *testMetadata)) + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + links := make([]string, 0) + for k, v := range tc.header { + links = append(links, `<`+v+`>; rel="`+k+`"`) + } + + w.Header().Set(common.HeaderLink, strings.Join(links, ", ")) + + links = make([]string, 0) + for k, v := range tc.body { + links = append(links, ``) + } + + fmt.Fprintf(w, testBody, strings.Join(links, "\n")) + }) + + srv := httptest.NewUnstartedServer(mux) + srv.EnableHTTP2 = true + srv.Start() + t.Cleanup(srv.Close) + + tc.header["indieauth-metadata"] = srv.URL + tc.header["indieauth-metadata"] + + u, _ := url.Parse(srv.URL + "/") + out, err := repository.NewHTTPMetadataRepository(srv.Client()). + Get(context.Background(), u) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(tc.out, out, cmp.AllowUnexported( + domain.ClientID{}, + domain.CodeChallengeMethod{}, + domain.GrantType{}, + domain.ResponseType{}, + domain.Scope{}, + url.URL{}, + )); diff != "" { + t.Errorf("%+s", diff) + } + }) + } +} + +func NewResponse(tb testing.TB, src domain.Metadata) *Response { + tb.Helper() + + out := &Response{ + CodeChallengeMethodsSupported: make([]string, 0), + GrantTypesSupported: make([]string, 0), + ResponseTypesSupported: make([]string, 0), + ScopesSupported: make([]string, 0), + IntrospectionEndpointAuthMethodsSupported: make([]string, 0), + RevocationEndpointAuthMethodsSupported: make([]string, 0), + Issuer: src.Issuer.String(), + AuthorizationEndpoint: src.AuthorizationEndpoint.String(), + IntrospectionEndpoint: src.IntrospectionEndpoint.String(), + RevocationEndpoint: src.RevocationEndpoint.String(), + ServiceDocumentation: src.ServiceDocumentation.String(), + TokenEndpoint: src.TokenEndpoint.String(), + UserinfoEndpoint: src.UserinfoEndpoint.String(), + TicketEndpoint: src.TicketEndpoint.String(), + Micropub: src.MicropubEndpoint.String(), + Microsub: src.MicrosubEndpoint.String(), + AuthorizationResponseIssParameterSupported: src.AuthorizationResponseIssParameterSupported, + } + + for _, method := range src.CodeChallengeMethodsSupported { + out.CodeChallengeMethodsSupported = append(out.CodeChallengeMethodsSupported, method.String()) + } + + for _, grantType := range src.GrantTypesSupported { + out.GrantTypesSupported = append(out.GrantTypesSupported, grantType.String()) + } + + for _, responseType := range src.ResponseTypesSupported { + out.ResponseTypesSupported = append(out.ResponseTypesSupported, responseType.String()) + } + + for _, scope := range src.ScopesSupported { + out.ScopesSupported = append(out.ScopesSupported, scope.String()) + } + + for _, method := range src.IntrospectionEndpointAuthMethodsSupported { + out.IntrospectionEndpointAuthMethodsSupported = append(out.IntrospectionEndpointAuthMethodsSupported, + method) + } + + for _, method := range src.RevocationEndpointAuthMethodsSupported { + out.RevocationEndpointAuthMethodsSupported = append(out.RevocationEndpointAuthMethodsSupported, method) + } + + return out +} diff --git a/internal/metadata/repository/memory/memory_metadata.go b/internal/metadata/repository/memory/memory_metadata.go index cecb30a..cfa1cea 100644 --- a/internal/metadata/repository/memory/memory_metadata.go +++ b/internal/metadata/repository/memory/memory_metadata.go @@ -2,7 +2,7 @@ package memory import ( "context" - "path" + "net/url" "sync" "source.toby3d.me/toby3d/auth/internal/domain" @@ -10,27 +10,35 @@ import ( ) type memoryMetadataRepository struct { - store *sync.Map + mutex *sync.RWMutex + metadata map[string]domain.Metadata } const DefaultPathPrefix = "metadata" -func NewMemoryMetadataRepository(store *sync.Map) metadata.Repository { +func NewMemoryMetadataRepository() metadata.Repository { return &memoryMetadataRepository{ - store: store, + mutex: new(sync.RWMutex), + metadata: make(map[string]domain.Metadata), } } -func (repo *memoryMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) { - src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) - if !ok { - return nil, metadata.ErrNotExist - } +func (repo *memoryMetadataRepository) Create(ctx context.Context, u *url.URL, metadata domain.Metadata) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() - result, ok := src.(*domain.Metadata) - if !ok { - return nil, metadata.ErrNotExist - } + repo.metadata[u.String()] = metadata - return result, nil + return nil +} + +func (repo *memoryMetadataRepository) Get(ctx context.Context, u *url.URL) (*domain.Metadata, error) { + repo.mutex.RLock() + defer repo.mutex.RUnlock() + + if out, ok := repo.metadata[u.String()]; ok { + return &out, nil + } + + return nil, metadata.ErrNotExist } diff --git a/internal/middleware/extractor.go b/internal/middleware/extractor.go index a5e35b7..a9c2d83 100644 --- a/internal/middleware/extractor.go +++ b/internal/middleware/extractor.go @@ -23,7 +23,6 @@ var ( errHeaderExtractorValueMissing = errors.New("missing value in request header") errHeaderExtractorValueInvalid = errors.New("invalid value in request header") errQueryExtractorValueMissing = errors.New("missing value in the query string") - errParamExtractorValueMissing = errors.New("missing value in path params") errCookieExtractorValueMissing = errors.New("missing value in cookies") errFormExtractorValueMissing = errors.New("missing value in the form") ) @@ -67,8 +66,6 @@ func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) { switch parts[0] { case "query": extractors = append(extractors, valuesFromQuery(parts[1])) - // case "param": - // extractors = append(extractors, valuesFromParam(parts[1])) case "cookie": extractors = append(extractors, valuesFromCookie(parts[1])) case "form": @@ -163,31 +160,6 @@ func valuesFromQuery(param string) ValuesExtractor { } } -// valuesFromParam returns a function that extracts values from the url param string. -/* -func valuesFromParam(param string) ValuesExtractor { - return func(w http.ResponseWriter, r *http.Request) ([]string, error) { - result := make([]string, 0) - paramVales := r.ParamValues() - - for i, p := range r.ParamNames() { - if param == p { - result = append(result, paramVales[i]) - if i >= extractorLimit-1 { - break - } - } - } - - if len(result) == 0 { - return nil, errParamExtractorValueMissing - } - - return result, nil - } -} -*/ - // valuesFromCookie returns a function that extracts values from the named cookie. func valuesFromCookie(name string) ValuesExtractor { return func(w http.ResponseWriter, r *http.Request) ([]string, error) { diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go index 2d0520e..27fe40d 100644 --- a/internal/middleware/jwt.go +++ b/internal/middleware/jwt.go @@ -77,7 +77,6 @@ type ( // Possible values: // - "header:" // - "query:" - // - "param:" // - "cookie:" // - "form:" // Multiply sources example: diff --git a/internal/profile/repository.go b/internal/profile/repository.go index 2070c68..becc840 100644 --- a/internal/profile/repository.go +++ b/internal/profile/repository.go @@ -7,7 +7,8 @@ import ( ) type Repository interface { - Get(ctx context.Context, me *domain.Me) (*domain.Profile, error) + Create(ctx context.Context, me domain.Me, profile domain.Profile) error + Get(ctx context.Context, me domain.Me) (*domain.Profile, error) } var ErrNotExist error = domain.NewError( diff --git a/internal/profile/repository/http/http_profile.go b/internal/profile/repository/http/http_profile.go index ffd4ef1..f8dadab 100644 --- a/internal/profile/repository/http/http_profile.go +++ b/internal/profile/repository/http/http_profile.go @@ -1,12 +1,13 @@ package http import ( + "bytes" "context" "fmt" + "io" + "net/http" "net/url" - http "github.com/valyala/fasthttp" - "source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/profile" @@ -33,29 +34,33 @@ func NewHTPPClientRepository(client *http.Client) profile.Repository { } } +// WARN(toby3d): not implemented. +func (repo *httpProfileRepository) Create(_ context.Context, _ domain.Me, _ domain.Profile) error { + return nil +} + //nolint:cyclop -func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*domain.Profile, error) { - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.Header.SetMethod(http.MethodGet) - req.SetRequestURI(me.String()) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { +func (repo *httpProfileRepository) Get(ctx context.Context, me domain.Me) (*domain.Profile, error) { + resp, err := repo.client.Get(me.String()) + if err != nil { return nil, fmt.Errorf("%s: cannot fetch user by me: %w", ErrPrefix, err) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("cannot read response body: %w", err) + } + + buf := bytes.NewReader(body) result := domain.NewProfile() - for _, name := range httputil.ExtractProperty(resp, hCard, propertyName) { + for _, name := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyName) { if n, ok := name.(string); ok { result.Name = append(result.Name, n) } } - for _, rawEmail := range httputil.ExtractProperty(resp, hCard, propertyEmail) { + for _, rawEmail := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyEmail) { email, ok := rawEmail.(string) if !ok { continue @@ -66,7 +71,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom } } - for _, rawURL := range httputil.ExtractProperty(resp, hCard, propertyURL) { + for _, rawURL := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyURL) { rawURL, ok := rawURL.(string) if !ok { continue @@ -77,7 +82,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom } } - for _, rawPhoto := range httputil.ExtractProperty(resp, hCard, propertyPhoto) { + for _, rawPhoto := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyPhoto) { photo, ok := rawPhoto.(string) if !ok { continue @@ -88,8 +93,8 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom } } - if result.GetName() == "" && result.GetURL() == nil && - result.GetPhoto() == nil && result.GetEmail() == nil { + // TODO(toby3d): create method like result.Empty()? + if result.GetName() == "" && result.GetURL() == nil && result.GetPhoto() == nil && result.GetEmail() == nil { return nil, profile.ErrNotExist } diff --git a/internal/profile/repository/memory/memory_profile.go b/internal/profile/repository/memory/memory_profile.go index 0b77ba0..bcd3479 100644 --- a/internal/profile/repository/memory/memory_profile.go +++ b/internal/profile/repository/memory/memory_profile.go @@ -2,8 +2,6 @@ package memory import ( "context" - "fmt" - "path" "sync" "source.toby3d.me/toby3d/auth/internal/domain" @@ -11,30 +9,33 @@ import ( ) type memoryProfileRepository struct { - store *sync.Map + mutex *sync.RWMutex + profiles map[string]domain.Profile } -const ( - ErrPrefix string = "memory" - DefaultPathPrefix string = "profiles" -) - -func NewMemoryProfileRepository(store *sync.Map) profile.Repository { +func NewMemoryProfileRepository() profile.Repository { return &memoryProfileRepository{ - store: store, + mutex: new(sync.RWMutex), + profiles: make(map[string]domain.Profile), } } -func (repo *memoryProfileRepository) Get(_ context.Context, me *domain.Me) (*domain.Profile, error) { - src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) - if !ok { - return nil, fmt.Errorf("%s: cannot find profile in store: %w", ErrPrefix, profile.ErrNotExist) - } +func (repo *memoryProfileRepository) Create(_ context.Context, me domain.Me, p domain.Profile) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() - result, ok := src.(*domain.Profile) - if !ok { - return nil, fmt.Errorf("%s: cannot decode profile from store: %w", ErrPrefix, profile.ErrNotExist) - } + repo.profiles[me.String()] = p - return result, nil + return nil +} + +func (repo *memoryProfileRepository) Get(_ context.Context, me domain.Me) (*domain.Profile, error) { + repo.mutex.RLock() + defer repo.mutex.RUnlock() + + if p, ok := repo.profiles[me.String()]; ok { + return &p, nil + } + + return nil, profile.ErrNotExist } diff --git a/internal/profile/usecase.go b/internal/profile/usecase.go index 539fd40..6f55c30 100644 --- a/internal/profile/usecase.go +++ b/internal/profile/usecase.go @@ -7,7 +7,7 @@ import ( ) type UseCase interface { - Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error) + Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error) } var ErrScopeRequired error = domain.NewError( diff --git a/internal/profile/usecase/profile_ucase.go b/internal/profile/usecase/profile_ucase.go index 0b71ca4..d799215 100644 --- a/internal/profile/usecase/profile_ucase.go +++ b/internal/profile/usecase/profile_ucase.go @@ -18,7 +18,7 @@ func NewProfileUseCase(profiles profile.Repository) profile.UseCase { } } -func (uc *profileUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error) { +func (uc *profileUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error) { result, err := uc.profiles.Get(ctx, me) if err != nil { return nil, fmt.Errorf("cannot fetch profile info: %w", err) diff --git a/internal/random/random.go b/internal/random/random.go index 80bffd3..c6ed4cb 100644 --- a/internal/random/random.go +++ b/internal/random/random.go @@ -17,7 +17,7 @@ const ( Hex = Numeric + "abcdef" ) -func Bytes(length int) ([]byte, error) { +func Bytes(length uint8) ([]byte, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { @@ -27,7 +27,7 @@ func Bytes(length int) ([]byte, error) { return bytes, nil } -func String(length int, charsets ...string) (string, error) { +func String(length uint8, charsets ...string) (string, error) { charset := strings.Join(charsets, "") if charset == "" { charset = Alphabetic diff --git a/internal/session/repository.go b/internal/session/repository.go index 3b5a10f..b855dbc 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -8,7 +8,7 @@ import ( type Repository interface { Get(ctx context.Context, code string) (*domain.Session, error) - Create(ctx context.Context, session *domain.Session) error + Create(ctx context.Context, session domain.Session) error GetAndDelete(ctx context.Context, code string) (*domain.Session, error) GC() } diff --git a/internal/session/repository/memory/memory_session.go b/internal/session/repository/memory/memory_session.go index dde76c3..e2bdb50 100644 --- a/internal/session/repository/memory/memory_session.go +++ b/internal/session/repository/memory/memory_session.go @@ -3,7 +3,6 @@ package memory import ( "context" "fmt" - "path" "sync" "time" @@ -14,59 +13,59 @@ import ( type ( Session struct { CreatedAt time.Time - *domain.Session + domain.Session } memorySessionRepository struct { - store *sync.Map - config *domain.Config + config domain.Config + mutex *sync.RWMutex + sessions map[string]Session } ) -const DefaultPathPrefix string = "sessions" - -func NewMemorySessionRepository(store *sync.Map, config *domain.Config) session.Repository { +func NewMemorySessionRepository(config domain.Config) session.Repository { return &memorySessionRepository{ - config: config, - store: store, + config: config, + mutex: new(sync.RWMutex), + sessions: make(map[string]Session), } } -func (repo *memorySessionRepository) Create(_ context.Context, state *domain.Session) error { - repo.store.Store(path.Join(DefaultPathPrefix, state.Code), &Session{ +func (repo *memorySessionRepository) Create(_ context.Context, s domain.Session) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() + + repo.sessions[s.Code] = Session{ CreatedAt: time.Now().UTC(), - Session: state, - }) + Session: s, + } return nil } func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) { - src, ok := repo.store.Load(path.Join(DefaultPathPrefix, code)) - if !ok { - return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist) + repo.mutex.Lock() + defer repo.mutex.Unlock() + + if s, ok := repo.sessions[code]; ok { + return &s.Session, nil } - result, ok := src.(*Session) - if !ok { - return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist) - } - - return result.Session, nil + return nil, session.ErrNotExist } -func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) { - src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code)) - if !ok { - return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist) +func (repo *memorySessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) { + s, err := repo.Get(ctx, code) + if err != nil { + return nil, fmt.Errorf("cannot get and delete session: %w", err) } - result, ok := src.(*Session) - if !ok { - return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist) - } + repo.mutex.Lock() + defer repo.mutex.Unlock() - return result.Session, nil + delete(repo.sessions, s.Code) + + return s, nil } func (repo *memorySessionRepository) GC() { @@ -76,29 +75,20 @@ func (repo *memorySessionRepository) GC() { for ts := range ticker.C { ts := ts - repo.store.Range(func(key, value interface{}) bool { - k, ok := key.(string) - if !ok { - return false + repo.mutex.RLock() + + for code, s := range repo.sessions { + if s.CreatedAt.Add(repo.config.Code.Expiry).After(ts) { + continue } - matched, err := path.Match(DefaultPathPrefix+"/*", k) - if err != nil || !matched { - return false - } + repo.mutex.RUnlock() + repo.mutex.Lock() + delete(repo.sessions, code) + repo.mutex.Unlock() + repo.mutex.RLock() + } - val, ok := value.(*Session) - if !ok { - return false - } - - if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) { - return false - } - - repo.store.Delete(key) - - return false - }) + repo.mutex.RUnlock() } } diff --git a/internal/session/repository/sqlite3/sqlite3_session.go b/internal/session/repository/sqlite3/sqlite3_session.go index 3aadfac..ea07e66 100644 --- a/internal/session/repository/sqlite3/sqlite3_session.go +++ b/internal/session/repository/sqlite3/sqlite3_session.go @@ -4,11 +4,11 @@ import ( "context" "database/sql" "encoding/base64" - "encoding/json" "errors" "fmt" "time" + "github.com/goccy/go-json" "github.com/jmoiron/sqlx" "source.toby3d.me/toby3d/auth/internal/domain" @@ -53,8 +53,8 @@ func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository { } } -func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domain.Session) error { - src, err := NewSession(session) +func (repo *sqlite3SessionRepository) Create(ctx context.Context, session domain.Session) error { + src, err := NewSession(&session) if err != nil { return fmt.Errorf("cannot encode session data for store: %w", err) } diff --git a/internal/session/repository/sqlite3/sqlite3_session_test.go b/internal/session/repository/sqlite3/sqlite3_session_test.go index b2f2604..8ed11a1 100644 --- a/internal/session/repository/sqlite3/sqlite3_session_test.go +++ b/internal/session/repository/sqlite3/sqlite3_session_test.go @@ -12,7 +12,7 @@ import ( "source.toby3d.me/toby3d/auth/internal/testing/sqltest" ) -//nolint: gochecknoglobals // slices cannot be contants +// nolint: gochecknoglobals // slices cannot be contants var tableColumns = []string{"created_at", "code", "data"} func TestCreate(t *testing.T) { @@ -39,7 +39,7 @@ func TestCreate(t *testing.T) { WillReturnResult(sqlmock.NewResult(1, 1)) if err := repository.NewSQLite3SessionRepository(db). - Create(context.Background(), session); err != nil { + Create(context.Background(), *session); err != nil { t.Error(err) } } diff --git a/internal/testing/httptest/.gitignore b/internal/testing/httptest/.gitignore deleted file mode 100644 index 612424a..0000000 --- a/internal/testing/httptest/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.pem \ No newline at end of file diff --git a/internal/testing/httptest/httptest.go b/internal/testing/httptest/httptest.go deleted file mode 100644 index ca22ce1..0000000 --- a/internal/testing/httptest/httptest.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:generate go run "$GOROOT/src/crypto/tls/generate_cert.go" --host 127.0.0.1,::1,localhost --start-date "Jan 1 00:00:00 1970" --duration=1000000h --ca --rsa-bits 1024 --ecdsa-curve P256 -package httptest - -import ( - "crypto/tls" - _ "embed" // used for running tests without same import in "god object" - "net" - "testing" - "time" - - http "github.com/valyala/fasthttp" - httputil "github.com/valyala/fasthttp/fasthttputil" -) - -var ( - //go:embed cert.pem - certData []byte - //go:embed key.pem - keyData []byte -) - -// New returns the InMemory Server and the Client connected to it with the -// specified handler. -func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server, func()) { - tb.Helper() - - //nolint:exhaustivestruct - server := &http.Server{ - CloseOnShutdown: true, - DisableKeepalive: true, - ReduceMemoryUsage: true, - Handler: http.TimeoutHandler(handler, 1*time.Second, "handler performance is too slow"), - } - - ln := httputil.NewInmemoryListener() - - //nolint:errcheck - go server.ServeTLSEmbed(ln, certData, keyData) - - //nolint:exhaustivestruct - client := &http.Client{ - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, //nolint:gosec - }, - Dial: func(addr string) (net.Conn, error) { - return ln.Dial() //nolint:wrapcheck - }, - } - - return client, server, func() { - _ = server.Shutdown() - } -} - -// NewRequest returns a new incoming server Request and cleanup function. -func NewRequest(method, target string, body []byte) *http.Request { - req := http.AcquireRequest() - req.Header.SetMethod(method) - req.SetRequestURI(target) - req.SetBody(body) - - return req -} diff --git a/internal/ticket/delivery/http/ticket_http.go b/internal/ticket/delivery/http/ticket_http.go index 67a6d89..e20cc41 100644 --- a/internal/ticket/delivery/http/ticket_http.go +++ b/internal/ticket/delivery/http/ticket_http.go @@ -1,72 +1,48 @@ package http import ( - "errors" "fmt" - "path" + "net/http" - "github.com/fasthttp/router" "github.com/goccy/go-json" "github.com/lestrrat-go/jwx/v2/jwa" - http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/auth/internal/middleware" "source.toby3d.me/toby3d/auth/internal/random" "source.toby3d.me/toby3d/auth/internal/ticket" + "source.toby3d.me/toby3d/auth/internal/urlutil" "source.toby3d.me/toby3d/auth/web" - "source.toby3d.me/toby3d/form" - "source.toby3d.me/toby3d/middleware" ) -type ( - TicketGenerateRequest struct { - // The access token should be used when acting on behalf of this URL. - Subject *domain.Me `form:"subject"` +type Handler struct { + config domain.Config + matcher language.Matcher + tickets ticket.UseCase +} - // The access token will work at this URL. - Resource *domain.URL `form:"resource"` - } - - TicketExchangeRequest struct { - // A random string that can be redeemed for an access token. - Ticket string `form:"ticket"` - - // The access token should be used when acting on behalf of this URL. - Subject *domain.Me `form:"subject"` - - // The access token will work at this URL. - Resource *domain.URL `form:"resource"` - } - - RequestHandler struct { - config *domain.Config - matcher language.Matcher - tickets ticket.UseCase - } -) - -func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config *domain.Config) *RequestHandler { - return &RequestHandler{ +func NewHandler(tickets ticket.UseCase, matcher language.Matcher, config domain.Config) *Handler { + return &Handler{ config: config, matcher: matcher, tickets: tickets, } } -func (h *RequestHandler) Register(r *router.Router) { +func (h *Handler) Handler() http.Handler { //nolint:exhaustivestruct chain := middleware.Chain{ middleware.CSRFWithConfig(middleware.CSRFConfig{ - Skipper: func(ctx *http.RequestCtx) bool { - matched, _ := path.Match("/ticket*", string(ctx.Path())) + Skipper: func(w http.ResponseWriter, r *http.Request) bool { + head, _ := urlutil.ShiftPath(r.URL.Path) - return ctx.IsPost() && matched + return r.Method == http.MethodPost && head == "ticket" }, CookieMaxAge: 0, - CookieSameSite: http.CookieSameSiteStrictMode, + CookieSameSite: http.SameSiteStrictMode, ContextKey: "csrf", CookieDomain: h.config.Server.Domain, CookieName: "__Secure-csrf", @@ -89,45 +65,69 @@ func (h *RequestHandler) Register(r *router.Router) { SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm), Skipper: middleware.DefaultSkipper, SuccessHandler: nil, - TokenLookup: "header:" + http.HeaderAuthorization + - "," + "cookie:" + "__Secure-auth-token", + TokenLookup: "header:" + common.HeaderAuthorization + + ",cookie:__Secure-auth-token", }), - middleware.LogFmt(), } - r.GET("/ticket", chain.RequestHandler(h.handleRender)) - r.POST("/api/ticket", chain.RequestHandler(h.handleSend)) - r.POST("/ticket", chain.RequestHandler(h.handleRedeem)) + return chain.Handler(h.handleFunc) } -func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) +func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { + var head string + head, r.URL.Path = urlutil.ShiftPath(r.URL.Path) - tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + switch r.Method { + default: + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + case "", http.MethodGet: + if head != "" { + http.NotFound(w, r) + + return + } + + h.handleRender(w, r) + case http.MethodPost: + + switch head { + default: + http.NotFound(w, r) + case "": + h.handleRedeem(w, r) + case "send": + h.handleSend(w, r) + } + } +} + +func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + + tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage)) tag, _, _ := h.matcher.Match(tags...) baseOf := web.BaseOf{ - Config: h.config, + Config: &h.config, Language: tag, Printer: message.NewPrinter(tag), } - csrf, _ := ctx.UserValue("csrf").([]byte) - web.WriteTemplate(ctx, &web.TicketPage{ + csrf, _ := r.Context().Value("csrf").([]byte) + web.WriteTemplate(w, &web.TicketPage{ BaseOf: baseOf, CSRF: csrf, }) } -func (h *RequestHandler) handleSend(ctx *http.RequestCtx) { - ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain) - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) - ctx.SetStatusCode(http.StatusOK) +func (h *Handler) handleSend(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain) + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) - encoder := json.NewEncoder(ctx) + encoder := json.NewEncoder(w) req := new(TicketGenerateRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) @@ -137,51 +137,50 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) { ticket := &domain.Ticket{ Ticket: "", Resource: req.Resource.URL, - Subject: req.Subject, + Subject: &req.Subject, } var err error if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) return } - if err = h.tickets.Generate(ctx, ticket); err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) + if err = h.tickets.Generate(r.Context(), *ticket); err != nil { + w.WriteHeader(http.StatusInternalServerError) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) return } - ctx.SetStatusCode(http.StatusOK) + w.WriteHeader(http.StatusOK) } -func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) - ctx.SetStatusCode(http.StatusOK) +func (h *Handler) handleRedeem(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) - encoder := json.NewEncoder(ctx) + encoder := json.NewEncoder(w) req := new(TicketExchangeRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) return } - token, err := h.tickets.Redeem(ctx, &domain.Ticket{ + token, err := h.tickets.Redeem(r.Context(), domain.Ticket{ Ticket: req.Ticket, Resource: req.Resource.URL, - Subject: req.Subject, + Subject: &req.Subject, }) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) @@ -190,84 +189,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) { // TODO(toby3d): print the result as part of the debugging. Instead, we // need to send or save the token to the recipient for later use. - ctx.SetBodyString(fmt.Sprintf(`{ + fmt.Fprintf(w, `{ "access_token": "%s", "token_type": "Bearer", "scope": "%s", "me": "%s" - }`, token.AccessToken, token.Scope.String(), token.Me.String())) -} - -func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) { - indieAuthError := new(domain.Error) - if err = form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - if req.Resource == nil { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "resource value MUST be set", - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - if req.Subject == nil { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "subject value MUST be set", - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - return nil -} - -func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - if req.Ticket == "" { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "ticket parameter is required", - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - if req.Resource == nil { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "resource parameter is required", - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - if req.Subject == nil { - return domain.NewError( - domain.ErrorCodeInvalidRequest, - "subject parameter is required", - "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", - ) - } - - return nil + }`, token.AccessToken, token.Scope.String(), token.Me.String()) + w.WriteHeader(http.StatusOK) } diff --git a/internal/ticket/delivery/http/ticket_http_schema.go b/internal/ticket/delivery/http/ticket_http_schema.go new file mode 100644 index 0000000..eb58b29 --- /dev/null +++ b/internal/ticket/delivery/http/ticket_http_schema.go @@ -0,0 +1,90 @@ +package http + +import ( + "errors" + "net/http" + + "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/form" +) + +type ( + TicketGenerateRequest struct { + // The access token should be used when acting on behalf of this URL. + Subject domain.Me `form:"subject"` + + // The access token will work at this URL. + Resource domain.URL `form:"resource"` + } + + TicketExchangeRequest struct { + // The access token should be used when acting on behalf of this URL. + Subject domain.Me `form:"subject"` + + // The access token will work at this URL. + Resource domain.URL `form:"resource"` + + // A random string that can be redeemed for an access token. + Ticket string `form:"ticket"` + } +) + +func (r *TicketGenerateRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) + } + + return nil +} + +func (r *TicketExchangeRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) + } + + if r.Ticket == "" { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + "ticket parameter is required", + "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket", + ) + } + + return nil +} diff --git a/internal/ticket/delivery/http/ticket_http_test.go b/internal/ticket/delivery/http/ticket_http_test.go index 9eec61b..718966d 100644 --- a/internal/ticket/delivery/http/ticket_http_test.go +++ b/internal/ticket/delivery/http/ticket_http_test.go @@ -1,17 +1,20 @@ package http_test +/* TODO(toby3d): move CSRF middleware into main import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" "sync" "testing" - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" "golang.org/x/text/language" "golang.org/x/text/message" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" "source.toby3d.me/toby3d/auth/internal/ticket" delivery "source.toby3d.me/toby3d/auth/internal/ticket/delivery/http" ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory" @@ -19,6 +22,7 @@ import ( ) type Dependencies struct { + server *httptest.Server client *http.Client config *domain.Config matcher language.Matcher @@ -33,40 +37,35 @@ func TestUpdate(t *testing.T) { t.Parallel() deps := NewDependencies(t) + t.Cleanup(deps.server.Close) - r := router.New() - delivery.NewRequestHandler(deps.ticketService, deps.matcher, deps.config).Register(r) - - client, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) - - const requestURI string = "https://example.com/ticket" - - req := httptest.NewRequest(http.MethodPost, requestURI, []byte( + req := httptest.NewRequest(http.MethodPost, "https://example.com/", strings.NewReader( `ticket=`+deps.ticket.Ticket+ `&resource=`+deps.ticket.Resource.String()+ `&subject=`+deps.ticket.Subject.String(), )) - defer http.ReleaseRequest(req) - req.Header.SetContentType(common.MIMEApplicationForm) + req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm) + deps.token.SetAuthHeader(req) + + w := httptest.NewRecorder() + delivery.NewHandler(deps.ticketService, deps.matcher, *deps.config). + Handler(). + ServeHTTP(w, req) + domain.TestToken(t).SetAuthHeader(req) - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) + resp := w.Result() - if err := client.Do(req, resp); err != nil { - t.Fatal(err) - } - - if resp.StatusCode() != http.StatusOK && resp.StatusCode() != http.StatusAccepted { - t.Errorf("POST %s = %d, want %d or %d", requestURI, resp.StatusCode(), http.StatusOK, + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + t.Errorf("%s %s = %d, want %d or %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK, http.StatusAccepted) } // TODO(toby3d): print the result as part of the debugging. Instead, you // need to send or save the token to the recipient for later use. - if resp.Body() == nil { - t.Errorf("POST %s = nil, want something", requestURI) + if resp.Body == nil { + t.Errorf("%s %s = nil, want not nil", req.Method, req.RequestURI) } } @@ -79,29 +78,36 @@ func NewDependencies(tb testing.TB) Dependencies { ticket := domain.TestTicket(tb) token := domain.TestToken(tb) - r := router.New() + mux := http.NewServeMux() // NOTE(toby3d): private resource - r.GET(ticket.Resource.Path, func(ctx *http.RequestCtx) { - ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, - ``) + mux.HandleFunc(ticket.Resource.Path, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + fmt.Fprintf(w, ``) }) // NOTE(toby3d): token endpoint - r.POST("/token", func(ctx *http.RequestCtx) { - ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{ - "access_token": "`+token.AccessToken+`", - "me": "`+token.Me.String()+`", - "scope": "`+token.Scope.String()+`", - "token_type": "Bearer" - }`) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + fmt.Fprintf(w, `{ + "access_token": "`+token.AccessToken+`", + "me": "`+token.Me.String()+`", + "scope": "`+token.Scope.String()+`", + "token_type": "Bearer" + }`) }) - client, _, cleanup := httptest.New(tb, r.Handler) - tb.Cleanup(cleanup) - + server := httptest.NewServer(mux) + client := server.Client() tickets := ticketrepo.NewMemoryTicketRepository(store, config) ticketService := ucase.NewTicketUseCase(tickets, client, config) return Dependencies{ + server: server, client: client, config: config, matcher: matcher, @@ -112,3 +118,4 @@ func NewDependencies(tb testing.TB) Dependencies { token: token, } } +*/ diff --git a/internal/ticket/repository.go b/internal/ticket/repository.go index 06e7eab..5e85b52 100644 --- a/internal/ticket/repository.go +++ b/internal/ticket/repository.go @@ -7,7 +7,7 @@ import ( ) type Repository interface { - Create(ctx context.Context, ticket *domain.Ticket) error + Create(ctx context.Context, ticket domain.Ticket) error GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error) GC() } diff --git a/internal/ticket/repository/memory/memory_ticket.go b/internal/ticket/repository/memory/memory_ticket.go index a7ad4c5..08e12d8 100644 --- a/internal/ticket/repository/memory/memory_ticket.go +++ b/internal/ticket/repository/memory/memory_ticket.go @@ -2,8 +2,6 @@ package memory import ( "context" - "fmt" - "path" "sync" "time" @@ -14,77 +12,75 @@ import ( type ( Ticket struct { CreatedAt time.Time - *domain.Ticket + domain.Ticket } memoryTicketRepository struct { - config *domain.Config - store *sync.Map + config domain.Config + mutex *sync.RWMutex + tickets map[string]Ticket } ) -const DefaultPathPrefix string = "tickets" - -func NewMemoryTicketRepository(store *sync.Map, config *domain.Config) ticket.Repository { +func NewMemoryTicketRepository(config domain.Config) ticket.Repository { return &memoryTicketRepository{ - config: config, - store: store, + config: config, + mutex: new(sync.RWMutex), + tickets: make(map[string]Ticket), } } -func (repo *memoryTicketRepository) Create(_ context.Context, t *domain.Ticket) error { - repo.store.Store(path.Join(DefaultPathPrefix, t.Ticket), &Ticket{ +func (repo *memoryTicketRepository) Create(_ context.Context, t domain.Ticket) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() + + repo.tickets[t.Ticket] = Ticket{ CreatedAt: time.Now().UTC(), Ticket: t, - }) + } return nil } func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) { - src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t)) + repo.mutex.RLock() + + out, ok := repo.tickets[t] if !ok { - return nil, fmt.Errorf("cannot find ticket in store: %w", ticket.ErrNotExist) + repo.mutex.RUnlock() + + return nil, ticket.ErrNotExist } - result, ok := src.(*Ticket) - if !ok { - return nil, fmt.Errorf("cannot decode ticket in store: %w", ticket.ErrNotExist) - } + repo.mutex.RUnlock() + repo.mutex.Lock() + delete(repo.tickets, t) + repo.mutex.Unlock() - return result.Ticket, nil + return &out.Ticket, nil } func (repo *memoryTicketRepository) GC() { ticker := time.NewTicker(time.Second) defer ticker.Stop() - for timeStamp := range ticker.C { - timeStamp := timeStamp.UTC() + for ts := range ticker.C { + ts := ts.UTC() - repo.store.Range(func(key, value interface{}) bool { - k, ok := key.(string) - if !ok { - return false + repo.mutex.RLock() + + for _, t := range repo.tickets { + if t.CreatedAt.Add(repo.config.Code.Expiry).After(ts) { + continue } - matched, err := path.Match(DefaultPathPrefix+"/*", k) - if err != nil || !matched { - return false - } + repo.mutex.RUnlock() + repo.mutex.Lock() + delete(repo.tickets, t.Ticket.Ticket) + repo.mutex.Unlock() + repo.mutex.RLock() + } - val, ok := value.(*Ticket) - if !ok { - return false - } - - if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) { - return false - } - - repo.store.Delete(key) - - return false - }) + repo.mutex.RUnlock() } } diff --git a/internal/ticket/repository/memory/memory_ticket_test.go b/internal/ticket/repository/memory/memory_ticket_test.go deleted file mode 100644 index 2f017f0..0000000 --- a/internal/ticket/repository/memory/memory_ticket_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package memory_test - -import ( - "context" - "path" - "reflect" - "sync" - "testing" - "time" - - "source.toby3d.me/toby3d/auth/internal/domain" - repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory" -) - -func TestCreate(t *testing.T) { - t.Parallel() - - store := new(sync.Map) - ticket := domain.TestTicket(t) - - if err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)). - Create(context.Background(), ticket); err != nil { - t.Fatal(err) - } - - storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket) - - src, ok := store.Load(storePath) - if !ok { - t.Fatalf("Load(%s) = %t, want %t", storePath, ok, true) - } - - if result, _ := src.(*repository.Ticket); !reflect.DeepEqual(result.Ticket, ticket) { - t.Errorf("Create(%+v) = %+v, want %+v", ticket, result.Ticket, ticket) - } -} - -func TestGetAndDelete(t *testing.T) { - t.Parallel() - - ticket := domain.TestTicket(t) - - store := new(sync.Map) - store.Store(path.Join(repository.DefaultPathPrefix, ticket.Ticket), &repository.Ticket{ - CreatedAt: time.Now().UTC(), - Ticket: ticket, - }) - - result, err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)). - GetAndDelete(context.Background(), ticket.Ticket) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(result, ticket) { - t.Errorf("GetAndDelete(%s) = %+v, want %+v", ticket.Ticket, result, ticket) - } - - storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket) - if src, _ := store.Load(storePath); src != nil { - t.Errorf("Load(%s) = %+v, want %+v", storePath, src, nil) - } -} diff --git a/internal/ticket/repository/sqlite3/sqlite3_ticket.go b/internal/ticket/repository/sqlite3/sqlite3_ticket.go index 1c6b615..7e823b4 100644 --- a/internal/ticket/repository/sqlite3/sqlite3_ticket.go +++ b/internal/ticket/repository/sqlite3/sqlite3_ticket.go @@ -56,8 +56,8 @@ func NewSQLite3TicketRepository(db *sqlx.DB, config *domain.Config) ticket.Repos } } -func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticket) error { - if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(t)); err != nil { +func (repo *sqlite3TicketRepository) Create(ctx context.Context, t domain.Ticket) error { + if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(&t)); err != nil { return fmt.Errorf("cannot create token record in db: %w", err) } diff --git a/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go b/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go index 55307ea..7dee99e 100644 --- a/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go +++ b/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go @@ -12,7 +12,7 @@ import ( repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/sqlite3" ) -// nolint: gochecknoglobals // slices cannot be contants +//nolint: gochecknoglobals // slices cannot be contants var tableColumns = []string{"created_at", "resource", "subject", "ticket"} func TestCreate(t *testing.T) { @@ -34,7 +34,7 @@ func TestCreate(t *testing.T) { WillReturnResult(sqlmock.NewResult(1, 1)) if err := repository.NewSQLite3TicketRepository(db, domain.TestConfig(t)). - Create(context.Background(), ticket); err != nil { + Create(context.Background(), *ticket); err != nil { t.Error(err) } } diff --git a/internal/ticket/usecase.go b/internal/ticket/usecase.go index c49f807..7687bae 100644 --- a/internal/ticket/usecase.go +++ b/internal/ticket/usecase.go @@ -7,10 +7,10 @@ import ( ) type UseCase interface { - Generate(ctx context.Context, ticket *domain.Ticket) error + Generate(ctx context.Context, ticket domain.Ticket) error // Redeem transform received ticket into access token. - Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) + Redeem(ctx context.Context, ticket domain.Ticket) (*domain.Token, error) Exchange(ctx context.Context, ticket string) (*domain.Token, error) } diff --git a/internal/ticket/usecase/ticket_ucase.go b/internal/ticket/usecase/ticket_ucase.go index b285525..7439369 100644 --- a/internal/ticket/usecase/ticket_ucase.go +++ b/internal/ticket/usecase/ticket_ucase.go @@ -1,13 +1,15 @@ package usecase import ( + "bytes" "context" "fmt" + "io" + "net/http" "net/url" "time" json "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" @@ -47,26 +49,28 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do } } -func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error { - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.Header.SetMethod(http.MethodGet) - req.SetRequestURI(tkt.Subject.String()) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := useCase.client.Do(req, resp); err != nil { +func (useCase *ticketUseCase) Generate(ctx context.Context, tkt domain.Ticket) error { + resp, err := useCase.client.Get(tkt.Subject.String()) + if err != nil { return fmt.Errorf("cannot discovery ticket subject: %w", err) } - var ticketEndpoint *url.URL + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("cannot read response body: %w", err) + } + + buf := bytes.NewReader(body) + ticketEndpoint := new(url.URL) // NOTE(toby3d): find metadata first - if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { + metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Subject.String()) + if err == nil && metadata != nil { ticketEndpoint = metadata.TicketEndpoint } else { // NOTE(toby3d): fallback to old links searching - if endpoints := httputil.ExtractEndpoints(resp, "ticket_endpoint"); len(endpoints) > 0 { + endpoints := httputil.ExtractEndpoints(buf, tkt.Subject.URL(), resp.Header.Get(common.HeaderLink), + "ticket_endpoint") + if len(endpoints) > 0 { ticketEndpoint = endpoints[len(endpoints)-1] } } @@ -79,65 +83,59 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) return fmt.Errorf("cannot save ticket in store: %w", err) } - req.Reset() - req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(ticketEndpoint.String()) - req.Header.SetContentType(common.MIMEApplicationForm) - req.PostArgs().Set("ticket", tkt.Ticket) - req.PostArgs().Set("subject", tkt.Subject.String()) - req.PostArgs().Set("resource", tkt.Resource.String()) - resp.Reset() + payload := make(url.Values) + payload.Set("ticket", tkt.Ticket) + payload.Set("subject", tkt.Subject.String()) + payload.Set("resource", tkt.Resource.String()) - if err := useCase.client.Do(req, resp); err != nil { + if _, err = useCase.client.PostForm(ticketEndpoint.String(), payload); err != nil { return fmt.Errorf("cannot send ticket to subject ticket_endpoint: %w", err) } return nil } -func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) { - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.SetRequestURI(tkt.Resource.String()) - req.Header.SetMethod(http.MethodGet) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := useCase.client.Do(req, resp); err != nil { +func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt domain.Ticket) (*domain.Token, error) { + resp, err := useCase.client.Get(tkt.Resource.String()) + if err != nil { return nil, fmt.Errorf("cannot discovery ticket resource: %w", err) } - var tokenEndpoint *url.URL + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("cannot read response body: %w", err) + } + + buf := bytes.NewReader(body) + tokenEndpoint := new(url.URL) // NOTE(toby3d): find metadata first - if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { + metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Resource.String()) + if err == nil && metadata != nil { tokenEndpoint = metadata.TokenEndpoint } else { // NOTE(toby3d): fallback to old links searching - if endpoints := httputil.ExtractEndpoints(resp, "token_endpoint"); len(endpoints) > 0 { + endpoints := httputil.ExtractEndpoints(buf, tkt.Resource, resp.Header.Get(common.HeaderLink), + "token_endpoint") + if len(endpoints) > 0 { tokenEndpoint = endpoints[len(endpoints)-1] } } - if tokenEndpoint == nil { + if tokenEndpoint == nil || tokenEndpoint.String() == "" { return nil, ticket.ErrTokenEndpointNotExist } - req.Reset() - req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(tokenEndpoint.String()) - req.Header.SetContentType(common.MIMEApplicationForm) - req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) - req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String()) - req.PostArgs().Set("ticket", tkt.Ticket) - resp.Reset() + payload := make(url.Values) + payload.Set("grant_type", domain.GrantTypeTicket.String()) + payload.Set("ticket", tkt.Ticket) - if err := useCase.client.Do(req, resp); err != nil { + resp, err = useCase.client.PostForm(tokenEndpoint.String(), payload) + if err != nil { return nil, fmt.Errorf("cannot exchange ticket on token_endpoint: %w", err) } data := new(AccessToken) - if err := json.Unmarshal(resp.Body(), data); err != nil { + if err := json.NewDecoder(resp.Body).Decode(data); err != nil { return nil, fmt.Errorf("cannot unmarshal access token response: %w", err) } @@ -147,8 +145,8 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (* Scope: nil, // TODO(toby3d) // TODO(toby3d): should this also include client_id? // https://github.com/indieweb/indieauth/issues/85 - ClientID: nil, - Me: data.Me, + ClientID: domain.ClientID{}, + Me: *data.Me, AccessToken: data.AccessToken, RefreshToken: "", // TODO(toby3d) }, nil @@ -163,8 +161,8 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*dom token, err := domain.NewToken(domain.NewTokenOptions{ Expiration: useCase.config.JWT.Expiry, Scope: domain.Scopes{domain.ScopeRead}, - Issuer: nil, - Subject: tkt.Subject, + Issuer: domain.ClientID{}, + Subject: *tkt.Subject, Secret: []byte(useCase.config.JWT.Secret), Algorithm: useCase.config.JWT.Algorithm, NonceLength: useCase.config.JWT.NonceLength, diff --git a/internal/ticket/usecase/ticket_ucase_test.go b/internal/ticket/usecase/ticket_ucase_test.go index 73199a8..a67bc9b 100644 --- a/internal/ticket/usecase/ticket_ucase_test.go +++ b/internal/ticket/usecase/ticket_ucase_test.go @@ -3,14 +3,13 @@ package usecase_test import ( "context" "fmt" + "net/http" + "net/http/httptest" + "net/url" "testing" - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" - "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" ucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase" ) @@ -20,25 +19,33 @@ func TestRedeem(t *testing.T) { token := domain.TestToken(t) ticket := domain.TestTicket(t) - router := router.New() - router.GET(string(ticket.Resource.Path), func(ctx *http.RequestCtx) { - ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, ``) - }) - router.POST("/token", func(ctx *http.RequestCtx) { - ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{ + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + fmt.Fprintf(w, `{ "token_type": "Bearer", "access_token": "%s", "scope": "%s", "me": "%s" - }`, token.AccessToken, token.Scope.String(), token.Me.String())) - }) + }`, token.AccessToken, token.Scope.String(), token.Me.String()) + })) + t.Cleanup(tokenServer.Close) - client, _, cleanup := httptest.New(t, router.Handler) - t.Cleanup(cleanup) + subjectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + fmt.Fprint(w, ``) + })) + t.Cleanup(subjectServer.Close) - result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)). - Redeem(context.Background(), ticket) + ticket.Resource, _ = url.Parse(subjectServer.URL + "/") + + result, err := ucase.NewTicketUseCase(nil, subjectServer.Client(), domain.TestConfig(t)). + Redeem(context.Background(), *ticket) if err != nil { t.Fatal(err) } diff --git a/internal/token/delivery/http/token_http.go b/internal/token/delivery/http/token_http.go index 6b94433..1936f78 100644 --- a/internal/token/delivery/http/token_http.go +++ b/internal/token/delivery/http/token_http.go @@ -1,187 +1,100 @@ package http import ( - "errors" - "path" + "net/http" - "github.com/fasthttp/router" - json "github.com/goccy/go-json" + "github.com/goccy/go-json" "github.com/lestrrat-go/jwx/v2/jwa" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/auth/internal/middleware" "source.toby3d.me/toby3d/auth/internal/ticket" "source.toby3d.me/toby3d/auth/internal/token" - "source.toby3d.me/toby3d/form" - "source.toby3d.me/toby3d/middleware" + "source.toby3d.me/toby3d/auth/internal/urlutil" ) -type ( - TokenExchangeRequest struct { - ClientID *domain.ClientID `form:"client_id"` - RedirectURI *domain.URL `form:"redirect_uri"` - GrantType domain.GrantType `form:"grant_type"` - Code string `form:"code"` - CodeVerifier string `form:"code_verifier"` - } +type Handler struct { + config *domain.Config + tokens token.UseCase + tickets ticket.UseCase +} - TokenRefreshRequest struct { - GrantType domain.GrantType `form:"grant_type"` // refresh_token - - // The refresh token previously offered to the client. - RefreshToken string `form:"refresh_token"` - - // The client ID that was used when the refresh token was issued. - ClientID *domain.ClientID `form:"client_id"` - - // The client may request a token with the same or fewer scopes - // than the original access token. If omitted, is treated as - // equal to the original scopes granted. - Scope domain.Scopes `form:"scope"` - } - - TokenRevocationRequest struct { - Action domain.Action `form:"action,omitempty"` - Token string `form:"token"` - } - - TokenTicketRequest struct { - Action domain.Action `form:"action"` - Ticket string `form:"ticket"` - } - - TokenIntrospectRequest struct { - Token string `form:"token"` - } - - //nolint:tagliatelle // https://indieauth.net/source/#access-token-response - TokenExchangeResponse struct { - // The OAuth 2.0 Bearer Token RFC6750. - AccessToken string `json:"access_token"` - - // The canonical user profile URL for the user this access token - // corresponds to. - Me string `json:"me"` - - // The user's profile information. - Profile *TokenProfileResponse `json:"profile,omitempty"` - - // The lifetime in seconds of the access token. - ExpiresIn int64 `json:"expires_in,omitempty"` - - // The refresh token, which can be used to obtain new access - // tokens. - RefreshToken string `json:"refresh_token"` - } - - TokenProfileResponse struct { - // Name the user wishes to provide to the client. - Name string `json:"name,omitempty"` - - // URL of the user's website. - URL string `json:"url,omitempty"` - - // A photo or image that the user wishes clients to use as a - // profile image. - Photo string `json:"photo,omitempty"` - - // The email address a user wishes to provide to the client. - Email string `json:"email,omitempty"` - } - - //nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response - TokenIntrospectResponse struct { - // Boolean indicator of whether or not the presented token is - // currently active. - Active bool `json:"active"` - - // The profile URL of the user corresponding to this token. - Me string `json:"me"` - - // The client ID associated with this token. - ClientID string `json:"client_id"` - - // A space-separated list of scopes associated with this token. - Scope string `json:"scope"` - - // Integer timestamp, measured in the number of seconds since - // January 1 1970 UTC, indicating when this token will expire. - Exp int64 `json:"exp,omitempty"` - - // Integer timestamp, measured in the number of seconds since - // January 1 1970 UTC, indicating when this token was originally - // issued. - Iat int64 `json:"iat,omitempty"` - } - - TokenInvalidIntrospectResponse struct { - Active bool `json:"active"` - } - - TokenRevocationResponse struct{} - - RequestHandler struct { - config *domain.Config - tokens token.UseCase - tickets ticket.UseCase - } -) - -func NewRequestHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *RequestHandler { - return &RequestHandler{ +func NewHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *Handler { + return &Handler{ config: config, tokens: tokens, tickets: tickets, } } -func (h *RequestHandler) Register(r *router.Router) { +func (h *Handler) Handler() http.Handler { chain := middleware.Chain{ //nolint:exhaustivestruct middleware.JWTWithConfig(middleware.JWTConfig{ - AuthScheme: "Bearer", - ContextKey: "token", + Skipper: func(_ http.ResponseWriter, r *http.Request) bool { + head, _ := urlutil.ShiftPath(r.URL.Path) + + return head == "token" + }, SigningKey: []byte(h.config.JWT.Secret), SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm), - Skipper: func(ctx *http.RequestCtx) bool { - matched, _ := path.Match("/token*", string(ctx.Path())) - - return matched - }, - SuccessHandler: nil, - TokenLookup: "param:token,header:" + http.HeaderAuthorization + ":Bearer ", + ContextKey: "token", + TokenLookup: "form:token," + "header:" + common.HeaderAuthorization + ":Bearer ", + AuthScheme: "Bearer", }), - middleware.LogFmt(), } - r.POST("/token", chain.RequestHandler(h.handleAction)) - r.POST("/introspect", chain.RequestHandler(h.handleIntrospect)) - r.POST("/revocation", chain.RequestHandler(h.handleRevokation)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + + var head string + head, r.URL.Path = urlutil.ShiftPath(r.URL.Path) + + switch head { + default: + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + case "token": + chain.Handler(h.handleAction).ServeHTTP(w, r) + case "introspect": + chain.Handler(h.handleIntrospect).ServeHTTP(w, r) + case "revocation": + chain.Handler(h.handleRevokation).ServeHTTP(w, r) + } + }) } -func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) - ctx.SetStatusCode(http.StatusOK) - - encoder := json.NewEncoder(ctx) - - req := new(TokenIntrospectRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) - - _ = encoder.Encode(err) +func (h *Handler) handleIntrospect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } - tkn, _, err := h.tokens.Verify(ctx, req.Token) + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) + + req := new(TokenIntrospectRequest) + if err := req.bind(r); err != nil { + _ = encoder.Encode(err) + + w.WriteHeader(http.StatusBadRequest) + + return + } + + tkn, _, err := h.tokens.Verify(r.Context(), req.Token) if err != nil || tkn == nil { // WARN(toby3d): If the token is not valid, the endpoint still // MUST return a 200 Response. - _ = encoder.Encode(&TokenInvalidIntrospectResponse{ - Active: false, - }) + _ = encoder.Encode(&TokenInvalidIntrospectResponse{Active: false}) + + w.WriteHeader(http.StatusOK) return } @@ -194,68 +107,83 @@ func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) { Me: tkn.Me.String(), Scope: tkn.Scope.String(), }) + + w.WriteHeader(http.StatusOK) } -func (h *RequestHandler) handleAction(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) +func (h *Handler) handleAction(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) switch { - case ctx.PostArgs().Has("grant_type"): - h.handleExchange(ctx) - case ctx.PostArgs().Has("action"): - action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action"))) - if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + case r.PostForm.Has("grant_type"): + h.handleExchange(w, r) + case r.PostForm.Has("action"): + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) - _ = encoder.Encode(domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "", - )) + _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")) + + return + } + + action, err := domain.ParseAction(r.PostForm.Get("action")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + + _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")) return } switch action { case domain.ActionRevoke: - h.handleRevokation(ctx) + h.handleRevokation(w, r) case domain.ActionTicket: - h.handleTicket(ctx) + h.handleTicket(w, r) } } } //nolint:funlen -func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) +func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) req := new(TokenExchangeRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) return } - token, profile, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ + token, profile, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{ ClientID: req.ClientID, RedirectURI: req.RedirectURI.URL, Code: req.Code, CodeVerifier: req.CodeVerifier, }) if err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + w.WriteHeader(http.StatusBadRequest) - _ = encoder.Encode(domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#request", - )) + _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), + "https://indieauth.net/source/#request")) return } @@ -294,62 +222,69 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { } _ = encoder.Encode(resp) + + w.WriteHeader(http.StatusOK) } -func (h *RequestHandler) handleRevokation(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) - ctx.SetStatusCode(http.StatusOK) +func (h *Handler) handleRevokation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) req := NewTokenRevocationRequest() - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) return } - if err := h.tokens.Revoke(ctx, req.Token); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := h.tokens.Revoke(r.Context(), req.Token); err != nil { + w.WriteHeader(http.StatusBadRequest) - _ = encoder.Encode(domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "", - )) + _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")) return } _ = encoder.Encode(&TokenRevocationResponse{}) + + w.WriteHeader(http.StatusOK) } -func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) - ctx.SetStatusCode(http.StatusOK) +func (h *Handler) handleTicket(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - encoder := json.NewEncoder(ctx) + return + } + + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + + encoder := json.NewEncoder(w) req := new(TokenTicketRequest) - if err := req.bind(ctx); err != nil { - ctx.SetStatusCode(http.StatusBadRequest) + if err := req.bind(r); err != nil { + w.WriteHeader(http.StatusBadRequest) _ = encoder.Encode(err) return } - tkn, err := h.tickets.Exchange(ctx, req.Ticket) + tkn, err := h.tickets.Exchange(r.Context(), req.Ticket) if err != nil { - ctx.SetStatusCode(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) - _ = encoder.Encode(domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#request", - )) + _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), + "https://indieauth.net/source/#request")) return } @@ -361,81 +296,6 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) { ExpiresIn: tkn.Expiry.Unix(), RefreshToken: "", // TODO(toby3d) }) -} - -func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#request", - ) - } - - return nil -} - -func NewTokenRevocationRequest() *TokenRevocationRequest { - return &TokenRevocationRequest{ - Action: domain.ActionRevoke, - Token: "", - } -} - -func (r *TokenRevocationRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - - err := form.Unmarshal(ctx.PostArgs().QueryString(), r) - if err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#request", - ) - } - - return nil -} - -func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#request", - ) - } - - return nil -} - -func (r *TokenIntrospectRequest) bind(ctx *http.RequestCtx) error { - indieAuthError := new(domain.Error) - if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil { - if errors.As(err, indieAuthError) { - return indieAuthError - } - - return domain.NewError( - domain.ErrorCodeInvalidRequest, - err.Error(), - "https://indieauth.net/source/#access-token-verification-request", - ) - } - - return nil + + w.WriteHeader(http.StatusOK) } diff --git a/internal/token/delivery/http/token_http_schema.go b/internal/token/delivery/http/token_http_schema.go new file mode 100644 index 0000000..da7b872 --- /dev/null +++ b/internal/token/delivery/http/token_http_schema.go @@ -0,0 +1,210 @@ +package http + +import ( + "errors" + "net/http" + + "source.toby3d.me/toby3d/auth/internal/domain" + "source.toby3d.me/toby3d/form" +) + +type ( + TokenExchangeRequest struct { + ClientID domain.ClientID `form:"client_id"` + RedirectURI domain.URL `form:"redirect_uri"` + GrantType domain.GrantType `form:"grant_type"` + Code string `form:"code"` + CodeVerifier string `form:"code_verifier"` + } + + TokenRefreshRequest struct { + GrantType domain.GrantType `form:"grant_type"` // refresh_token + + // The client ID that was used when the refresh token was issued. + ClientID domain.ClientID `form:"client_id"` + + // The client may request a token with the same or fewer scopes + // than the original access token. If omitted, is treated as + // equal to the original scopes granted. + Scope domain.Scopes `form:"scope"` + + // The refresh token previously offered to the client. + RefreshToken string `form:"refresh_token"` + } + + TokenRevocationRequest struct { + Action domain.Action `form:"action,omitempty"` + Token string `form:"token"` + } + + TokenTicketRequest struct { + Action domain.Action `form:"action"` + Ticket string `form:"ticket"` + } + + TokenIntrospectRequest struct { + Token string `form:"token"` + } + + //nolint:tagliatelle // https://indieauth.net/source/#access-token-response + TokenExchangeResponse struct { + // The user's profile information. + Profile *TokenProfileResponse `json:"profile,omitempty"` + + // The OAuth 2.0 Bearer Token RFC6750. + AccessToken string `json:"access_token"` + + // The canonical user profile URL for the user this access token + // corresponds to. + Me string `json:"me"` + + // The refresh token, which can be used to obtain new access + // tokens. + RefreshToken string `json:"refresh_token"` + + // The lifetime in seconds of the access token. + ExpiresIn int64 `json:"expires_in,omitempty"` + } + + TokenProfileResponse struct { + // Name the user wishes to provide to the client. + Name string `json:"name,omitempty"` + + // URL of the user's website. + URL string `json:"url,omitempty"` + + // A photo or image that the user wishes clients to use as a + // profile image. + Photo string `json:"photo,omitempty"` + + // The email address a user wishes to provide to the client. + Email string `json:"email,omitempty"` + } + + //nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response + TokenIntrospectResponse struct { + // The profile URL of the user corresponding to this token. + Me string `json:"me"` + + // The client ID associated with this token. + ClientID string `json:"client_id"` + + // A space-separated list of scopes associated with this token. + Scope string `json:"scope"` + + // Integer timestamp, measured in the number of seconds since + // January 1 1970 UTC, indicating when this token will expire. + Exp int64 `json:"exp,omitempty"` + + // Integer timestamp, measured in the number of seconds since + // January 1 1970 UTC, indicating when this token was originally + // issued. + Iat int64 `json:"iat,omitempty"` + + // Boolean indicator of whether or not the presented token is + // currently active. + Active bool `json:"active"` + } + + TokenInvalidIntrospectResponse struct { + Active bool `json:"active"` + } + + TokenRevocationResponse struct{} +) + +func (r *TokenExchangeRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#request", + ) + } + + return nil +} + +func NewTokenRevocationRequest() *TokenRevocationRequest { + return &TokenRevocationRequest{ + Action: domain.ActionRevoke, + Token: "", + } +} + +func (r *TokenRevocationRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + err := form.Unmarshal([]byte(req.PostForm.Encode()), r) + if err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#request", + ) + } + + return nil +} + +func (r *TokenTicketRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#request", + ) + } + + return nil +} + +func (r *TokenIntrospectRequest) bind(req *http.Request) error { + indieAuthError := new(domain.Error) + + if err := req.ParseForm(); err != nil { + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#authorization-request", + ) + } + + if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil { + if errors.As(err, indieAuthError) { + return indieAuthError + } + + return domain.NewError( + domain.ErrorCodeInvalidRequest, + err.Error(), + "https://indieauth.net/source/#access-token-verification-request", + ) + } + + return nil +} diff --git a/internal/token/delivery/http/token_http_test.go b/internal/token/delivery/http/token_http_test.go index e32b5f3..d31a8c6 100644 --- a/internal/token/delivery/http/token_http_test.go +++ b/internal/token/delivery/http/token_http_test.go @@ -3,12 +3,13 @@ package http_test import ( "bytes" "context" - "sync" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" - "github.com/fasthttp/router" - json "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" + "github.com/goccy/go-json" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" @@ -16,7 +17,6 @@ import ( profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" "source.toby3d.me/toby3d/auth/internal/session" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" "source.toby3d.me/toby3d/auth/internal/ticket" ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory" ticketucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase" @@ -31,7 +31,6 @@ type Dependencies struct { config *domain.Config profiles profile.Repository sessions session.Repository - store *sync.Map tickets ticket.Repository ticketService ticket.UseCase token *domain.Token @@ -50,32 +49,24 @@ func TestIntrospection(t *testing.T) { deps := NewDependencies(t) - r := router.New() - delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r) + req := httptest.NewRequest(http.MethodPost, "https://app.example.com/introspect", + strings.NewReader("token="+deps.token.AccessToken)) + req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON) + req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm) - client, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) + w := httptest.NewRecorder() + delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config). + Handler(). + ServeHTTP(w, req) - const requestURL = "https://app.example.com/introspect" + resp := w.Result() - req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken)) - defer http.ReleaseRequest(req) - req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) - req.Header.SetContentType(common.MIMEApplicationForm) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := client.Do(req, resp); err != nil { - t.Fatal(err) - } - - if result := resp.StatusCode(); result != http.StatusOK { - t.Errorf("GET %s = %d, want %d", requestURL, result, http.StatusOK) + if result := resp.StatusCode; result != http.StatusOK { + t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, result, http.StatusOK) } result := new(delivery.TokenIntrospectResponse) - if err := json.Unmarshal(resp.Body(), result); err != nil { + if err := json.NewDecoder(resp.Body).Decode(result); err != nil { t.Fatal(err) } @@ -84,7 +75,7 @@ func TestIntrospection(t *testing.T) { if result.ClientID != deps.token.ClientID.String() || result.Me != deps.token.Me.String() || result.Scope != deps.token.Scope.String() { - t.Errorf("GET %s = %+v, want %+v", requestURL, result, deps.token) + t.Errorf("%s %s = %+v, want %+v", req.Method, req.RequestURI, result, deps.token) } } @@ -93,33 +84,30 @@ func TestRevocation(t *testing.T) { deps := NewDependencies(t) - r := router.New() - delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r) + req := httptest.NewRequest(http.MethodPost, "https://app.example.com/revocation", + strings.NewReader(`token=`+deps.token.AccessToken)) + req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm) + req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON) - client, _, cleanup := httptest.New(t, r.Handler) - t.Cleanup(cleanup) + w := httptest.NewRecorder() + delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config). + Handler(). + ServeHTTP(w, req) - const requestURL = "https://app.example.com/revocation" + resp := w.Result() - req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken)) - defer http.ReleaseRequest(req) - req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) - req.Header.SetContentType(common.MIMEApplicationForm) - - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := client.Do(req, resp); err != nil { + body, err := io.ReadAll(resp.Body) + if err != nil { t.Fatal(err) } - if result := resp.StatusCode(); result != http.StatusOK { - t.Errorf("POST %s = %d, want %d", requestURL, result, http.StatusOK) + if resp.StatusCode != http.StatusOK { + t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK) } expBody := []byte("{}") //nolint:ifshort - if result := bytes.TrimSpace(resp.Body()); !bytes.Equal(result, expBody) { - t.Errorf("POST %s = %s, want %s", requestURL, result, expBody) + if result := bytes.TrimSpace(body); !bytes.Equal(result, expBody) { + t.Errorf("%s %s = %s, want %s", req.Method, req.RequestURI, result, expBody) } result, err := deps.tokens.Get(context.Background(), deps.token.AccessToken) @@ -135,14 +123,13 @@ func TestRevocation(t *testing.T) { func NewDependencies(tb testing.TB) Dependencies { tb.Helper() - store := new(sync.Map) client := new(http.Client) config := domain.TestConfig(tb) token := domain.TestToken(tb) - profiles := profilerepo.NewMemoryProfileRepository(store) - sessions := sessionrepo.NewMemorySessionRepository(store, config) - tickets := ticketrepo.NewMemoryTicketRepository(store, config) - tokens := tokenrepo.NewMemoryTokenRepository(store) + profiles := profilerepo.NewMemoryProfileRepository() + sessions := sessionrepo.NewMemorySessionRepository(*config) + tickets := ticketrepo.NewMemoryTicketRepository(*config) + tokens := tokenrepo.NewMemoryTokenRepository() ticketService := ticketucase.NewTicketUseCase(tickets, client, config) tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{ Config: config, @@ -156,7 +143,6 @@ func NewDependencies(tb testing.TB) Dependencies { config: config, profiles: profiles, sessions: sessions, - store: store, tickets: tickets, ticketService: ticketService, token: token, diff --git a/internal/token/repository.go b/internal/token/repository.go index 65b6896..1240bce 100644 --- a/internal/token/repository.go +++ b/internal/token/repository.go @@ -7,8 +7,8 @@ import ( ) type Repository interface { + Create(ctx context.Context, accessToken domain.Token) error Get(ctx context.Context, accessToken string) (*domain.Token, error) - Create(ctx context.Context, accessToken *domain.Token) error } var ( diff --git a/internal/token/repository/memory/memory_token.go b/internal/token/repository/memory/memory_token.go index c858074..49aa1d8 100644 --- a/internal/token/repository/memory/memory_token.go +++ b/internal/token/repository/memory/memory_token.go @@ -2,8 +2,6 @@ package memory import ( "context" - "errors" - "path" "sync" "source.toby3d.me/toby3d/auth/internal/domain" @@ -11,42 +9,33 @@ import ( ) type memoryTokenRepository struct { - store *sync.Map + mutex *sync.RWMutex + tokens map[string]domain.Token } -const DefaultPathPrefix string = "tokens" - -func NewMemoryTokenRepository(store *sync.Map) token.Repository { +func NewMemoryTokenRepository() token.Repository { return &memoryTokenRepository{ - store: store, + mutex: new(sync.RWMutex), + tokens: make(map[string]domain.Token), } } -func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - t, err := repo.Get(ctx, accessToken.AccessToken) - if err != nil && !errors.Is(err, token.ErrNotExist) { - return err - } +func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken domain.Token) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() - if t != nil { - return token.ErrExist - } - - repo.store.Store(path.Join(DefaultPathPrefix, accessToken.AccessToken), accessToken) + repo.tokens[accessToken.AccessToken] = accessToken return nil } func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { - t, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken)) - if !ok { - return nil, token.ErrNotExist + repo.mutex.RLock() + defer repo.mutex.RUnlock() + + if t, ok := repo.tokens[accessToken]; ok { + return &t, nil } - result, ok := t.(*domain.Token) - if !ok { - return nil, token.ErrNotExist - } - - return result, nil + return nil, token.ErrNotExist } diff --git a/internal/token/repository/memory/memory_token_test.go b/internal/token/repository/memory/memory_token_test.go deleted file mode 100644 index 991f993..0000000 --- a/internal/token/repository/memory/memory_token_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package memory_test - -import ( - "context" - "path" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - - "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/token" - repository "source.toby3d.me/toby3d/auth/internal/token/repository/memory" -) - -func TestCreate(t *testing.T) { - t.Parallel() - - store := new(sync.Map) - accessToken := domain.TestToken(t) - - repo := repository.NewMemoryTokenRepository(store) - if err := repo.Create(context.Background(), accessToken); err != nil { - t.Fatal(err) - } - - result, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken)) - assert.True(t, ok) - assert.Equal(t, accessToken, result) - - assert.ErrorIs(t, repo.Create(context.Background(), accessToken), token.ErrExist) -} - -func TestGet(t *testing.T) { - t.Parallel() - - store := new(sync.Map) - accessToken := domain.TestToken(t) - - store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken) - - result, err := repository.NewMemoryTokenRepository(store).Get(context.Background(), accessToken.AccessToken) - assert.NoError(t, err) - assert.Equal(t, accessToken, result) -} diff --git a/internal/token/repository/sqlite3/sqlite3_token.go b/internal/token/repository/sqlite3/sqlite3_token.go index 28d7c5b..25b268b 100644 --- a/internal/token/repository/sqlite3/sqlite3_token.go +++ b/internal/token/repository/sqlite3/sqlite3_token.go @@ -53,8 +53,8 @@ func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository { } } -func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { - if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(accessToken)); err != nil { +func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken domain.Token) error { + if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(&accessToken)); err != nil { return fmt.Errorf("cannot create token record in db: %w", err) } @@ -91,9 +91,11 @@ func NewToken(src *domain.Token) *Token { } func (t *Token) Populate(dst *domain.Token) { + cid, _ := domain.ParseClientID(t.ClientID) + me, _ := domain.ParseMe(t.Me) dst.AccessToken = t.AccessToken - dst.ClientID, _ = domain.ParseClientID(t.ClientID) - dst.Me, _ = domain.ParseMe(t.Me) + dst.ClientID = *cid + dst.Me = *me dst.Scope = make(domain.Scopes, 0) for _, scope := range strings.Fields(t.Scope) { diff --git a/internal/token/repository/sqlite3/sqlite3_token_test.go b/internal/token/repository/sqlite3/sqlite3_token_test.go index 04c216d..1f181cc 100644 --- a/internal/token/repository/sqlite3/sqlite3_token_test.go +++ b/internal/token/repository/sqlite3/sqlite3_token_test.go @@ -35,7 +35,7 @@ func TestCreate(t *testing.T) { ). WillReturnResult(sqlmock.NewResult(1, 1)) - if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), token); err != nil { + if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), *token); err != nil { t.Error(err) } } diff --git a/internal/token/usecase.go b/internal/token/usecase.go index 818514f..e6c5cea 100644 --- a/internal/token/usecase.go +++ b/internal/token/usecase.go @@ -9,7 +9,7 @@ import ( type ( ExchangeOptions struct { - ClientID *domain.ClientID + ClientID domain.ClientID RedirectURI *url.URL Code string CodeVerifier string diff --git a/internal/token/usecase/token_ucase.go b/internal/token/usecase/token_ucase.go index 1b49423..e73f087 100644 --- a/internal/token/usecase/token_ucase.go +++ b/internal/token/usecase/token_ucase.go @@ -107,17 +107,17 @@ func (uc *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain return nil, nil, fmt.Errorf("cannot validate JWT token: %w", err) } + cid, _ := domain.ParseClientID(tkn.Issuer()) + me, _ := domain.ParseMe(tkn.Subject()) result := &domain.Token{ CreatedAt: tkn.IssuedAt(), Expiry: tkn.Expiration(), - ClientID: nil, - Me: nil, + ClientID: *cid, + Me: *me, Scope: nil, AccessToken: accessToken, RefreshToken: "", // TODO(toby3d) } - result.ClientID, _ = domain.ParseClientID(tkn.Issuer()) - result.Me, _ = domain.ParseMe(tkn.Subject()) if scope, ok := tkn.Get("scope"); ok { result.Scope, _ = scope.(domain.Scopes) @@ -149,7 +149,7 @@ func (uc *tokenUseCase) Revoke(ctx context.Context, accessToken string) error { return fmt.Errorf("cannot verify token: %w", err) } - if err = uc.tokens.Create(ctx, tkn); err != nil && !errors.Is(err, token.ErrExist) { + if err = uc.tokens.Create(ctx, *tkn); err != nil && !errors.Is(err, token.ErrExist) { return fmt.Errorf("cannot save token in database: %w", err) } diff --git a/internal/token/usecase/token_ucase_test.go b/internal/token/usecase/token_ucase_test.go index 47e8b4e..63bfb2c 100644 --- a/internal/token/usecase/token_ucase_test.go +++ b/internal/token/usecase/token_ucase_test.go @@ -2,8 +2,6 @@ package usecase_test import ( "context" - "path" - "sync" "testing" "source.toby3d.me/toby3d/auth/internal/domain" @@ -22,7 +20,6 @@ type Dependencies struct { profiles profile.Repository session *domain.Session sessions session.Repository - store *sync.Map token *domain.Token tokens token.Repository } @@ -31,9 +28,12 @@ func TestExchange(t *testing.T) { t.Parallel() deps := NewDependencies(t) - deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.session.Me.String()), deps.profile) - if err := deps.sessions.Create(context.Background(), deps.session); err != nil { + if err := deps.profiles.Create(context.Background(), deps.session.Me, *deps.profile); err != nil { + t.Fatal(err) + } + + if err := deps.sessions.Create(context.Background(), *deps.session); err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ func TestVerify(t *testing.T) { t.Parallel() testToken := domain.TestToken(t) - if err := deps.tokens.Create(context.Background(), testToken); err != nil { + if err := deps.tokens.Create(context.Background(), *testToken); err != nil { t.Fatal(err) } @@ -136,17 +136,15 @@ func TestRevoke(t *testing.T) { func NewDependencies(tb testing.TB) Dependencies { tb.Helper() - store := new(sync.Map) config := domain.TestConfig(tb) return Dependencies{ config: config, profile: domain.TestProfile(tb), - profiles: profilerepo.NewMemoryProfileRepository(store), + profiles: profilerepo.NewMemoryProfileRepository(), session: domain.TestSession(tb), - sessions: sessionrepo.NewMemorySessionRepository(store, config), - store: store, + sessions: sessionrepo.NewMemorySessionRepository(*config), token: domain.TestToken(tb), - tokens: tokenrepo.NewMemoryTokenRepository(store), + tokens: tokenrepo.NewMemoryTokenRepository(), } } diff --git a/internal/user/delivery/http/user_http.go b/internal/user/delivery/http/user_http.go index 88548fb..9e13d2c 100644 --- a/internal/user/delivery/http/user_http.go +++ b/internal/user/delivery/http/user_http.go @@ -1,10 +1,10 @@ package http import ( - "encoding/json" "net/http" "strings" + "github.com/goccy/go-json" "github.com/lestrrat-go/jwx/v2/jwa" "source.toby3d.me/toby3d/auth/internal/common" @@ -13,19 +13,10 @@ import ( "source.toby3d.me/toby3d/auth/internal/token" ) -type ( - UserInformationResponse struct { - Name string `json:"name,omitempty"` - URL string `json:"url,omitempty"` - Photo string `json:"photo,omitempty"` - Email string `json:"email,omitempty"` - } - - Handler struct { - config *domain.Config - tokens token.UseCase - } -) +type Handler struct { + config *domain.Config + tokens token.UseCase +} func NewHandler(tokens token.UseCase, config *domain.Config) *Handler { return &Handler{ @@ -34,7 +25,7 @@ func NewHandler(tokens token.UseCase, config *domain.Config) *Handler { } } -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *Handler) Handler() http.Handler { chain := middleware.Chain{ //nolint:exhaustivestruct middleware.JWTWithConfig(middleware.JWTConfig{ @@ -45,13 +36,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Skipper: middleware.DefaultSkipper, TokenLookup: "header:" + common.HeaderAuthorization + ":Bearer ", }), - middleware.LogFmt(), } - chain.Handler(h.handleFunc).ServeHTTP(w, r) + return chain.Handler(h.handleFunc) } func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { + if r.Method != "" && r.Method != http.MethodGet { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + + return + } + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) encoder := json.NewEncoder(w) diff --git a/internal/user/delivery/http/user_http_schema.go b/internal/user/delivery/http/user_http_schema.go new file mode 100644 index 0000000..28c8668 --- /dev/null +++ b/internal/user/delivery/http/user_http_schema.go @@ -0,0 +1,8 @@ +package http + +type UserInformationResponse struct { + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` + Photo string `json:"photo,omitempty"` + Email string `json:"email,omitempty"` +} diff --git a/internal/user/delivery/http/user_http_test.go b/internal/user/delivery/http/user_http_test.go index 08821b5..fe8b3ee 100644 --- a/internal/user/delivery/http/user_http_test.go +++ b/internal/user/delivery/http/user_http_test.go @@ -1,13 +1,12 @@ package http_test import ( + "context" + "net/http" "net/http/httptest" - "path" - "sync" "testing" "github.com/goccy/go-json" - http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" @@ -26,7 +25,6 @@ type Dependencies struct { profile *domain.Profile profiles profile.Repository sessions session.Repository - store *sync.Map token *domain.Token tokens token.Repository tokenService token.UseCase @@ -36,13 +34,17 @@ func TestUserInfo(t *testing.T) { t.Parallel() deps := NewDependencies(t) - deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.token.Me.String()), deps.profile) + if err := deps.profiles.Create(context.Background(), deps.token.Me, *deps.profile); err != nil { + t.Fatal(err) + } req := httptest.NewRequest(http.MethodGet, "https://example.com/userinfo", nil) req.Header.Set(common.HeaderAuthorization, "Bearer "+deps.token.AccessToken) w := httptest.NewRecorder() - delivery.NewHandler(deps.tokenService, deps.config).ServeHTTP(w, req) + delivery.NewHandler(deps.tokenService, deps.config). + Handler(). + ServeHTTP(w, req) resp := w.Result() @@ -69,22 +71,23 @@ func TestUserInfo(t *testing.T) { func NewDependencies(tb testing.TB) Dependencies { tb.Helper() - store := new(sync.Map) config := domain.TestConfig(tb) + sessions := sessionrepo.NewMemorySessionRepository(*config) + tokens := tokenrepo.NewMemoryTokenRepository() + profiles := profilerepo.NewMemoryProfileRepository() return Dependencies{ config: config, profile: domain.TestProfile(tb), - profiles: profilerepo.NewMemoryProfileRepository(store), - sessions: sessionrepo.NewMemorySessionRepository(store, config), - store: store, + profiles: profiles, + sessions: sessions, token: domain.TestToken(tb), - tokens: tokenrepo.NewMemoryTokenRepository(store), + tokens: tokens, tokenService: tokenucase.NewTokenUseCase(tokenucase.Config{ Config: config, - Profiles: profilerepo.NewMemoryProfileRepository(store), - Sessions: sessionrepo.NewMemorySessionRepository(store, config), - Tokens: tokenrepo.NewMemoryTokenRepository(store), + Profiles: profiles, + Sessions: sessions, + Tokens: tokens, }), } } diff --git a/internal/user/repository.go b/internal/user/repository.go index ec7edc1..b9dbd2c 100644 --- a/internal/user/repository.go +++ b/internal/user/repository.go @@ -7,7 +7,8 @@ import ( ) type Repository interface { - Get(ctx context.Context, me *domain.Me) (*domain.User, error) + Create(ctx context.Context, user domain.User) error + Get(ctx context.Context, me domain.Me) (*domain.User, error) } var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "user not exist", "") diff --git a/internal/user/repository/http/http_user.go b/internal/user/repository/http/http_user.go index bd3252f..9901119 100644 --- a/internal/user/repository/http/http_user.go +++ b/internal/user/repository/http/http_user.go @@ -1,12 +1,16 @@ package http import ( + "bytes" "context" "fmt" + "io" + "net/http" "net/url" - http "github.com/valyala/fasthttp" + "golang.org/x/exp/slices" + "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/user" @@ -38,26 +42,21 @@ func NewHTTPUserRepository(client *http.Client) user.Repository { } } -func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) { - req := http.AcquireRequest() - defer http.ReleaseRequest(req) - req.Header.SetMethod(http.MethodGet) - req.SetRequestURI(me.String()) +// WARN(toby3d): not implemented. +func (httpUserRepository) Create(_ context.Context, _ domain.User) error { + return nil +} - resp := http.AcquireResponse() - defer http.ReleaseResponse(resp) - - if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { +func (repo *httpUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) { + resp, err := repo.client.Get(me.String()) + if err != nil { return nil, fmt.Errorf("cannot fetch user by me: %w", err) } - // TODO(toby3d): handle error here? - resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation))) - user := &domain.User{ AuthorizationEndpoint: nil, IndieAuthMetadata: nil, - Me: resolvedMe, + Me: &me, Micropub: nil, Microsub: nil, Profile: domain.NewProfile(), @@ -65,7 +64,7 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain TokenEndpoint: nil, } - if metadata, err := httputil.ExtractMetadata(resp, repo.client); err == nil { + if metadata, err := httputil.ExtractFromMetadata(repo.client, me.String()); err == nil { user.AuthorizationEndpoint = metadata.AuthorizationEndpoint user.Micropub = metadata.MicropubEndpoint user.Microsub = metadata.MicrosubEndpoint @@ -73,89 +72,87 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain user.TokenEndpoint = metadata.TokenEndpoint } - extractUser(user, resp) - extractProfile(user.Profile, resp) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("cannot read response body: %w", err) + } + + extractUser(me.URL(), user, body, resp.Header.Get(common.HeaderLink)) + extractProfile(me.URL(), user.Profile, body) return user, nil } //nolint:cyclop -func extractUser(dst *domain.User, src *http.Response) { - if dst.IndieAuthMetadata != nil { - if endpoints := httputil.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 { - dst.IndieAuthMetadata = endpoints[len(endpoints)-1] +func extractUser(u *url.URL, dst *domain.User, body []byte, header string) { + for key, target := range map[string]**url.URL{ + relAuthorizationEndpoint: &dst.AuthorizationEndpoint, + relIndieAuthMetadata: &dst.IndieAuthMetadata, + relMicropub: &dst.Micropub, + relMicrosub: &dst.Microsub, + relTicketEndpoint: &dst.TicketEndpoint, + relTokenEndpoint: &dst.TokenEndpoint, + } { + if target == nil { + continue } - } - if dst.AuthorizationEndpoint == nil { - if endpoints := httputil.ExtractEndpoints(src, relAuthorizationEndpoint); len(endpoints) > 0 { - dst.AuthorizationEndpoint = endpoints[len(endpoints)-1] - } - } - - if dst.Micropub == nil { - if endpoints := httputil.ExtractEndpoints(src, relMicropub); len(endpoints) > 0 { - dst.Micropub = endpoints[len(endpoints)-1] - } - } - - if dst.Microsub == nil { - if endpoints := httputil.ExtractEndpoints(src, relMicrosub); len(endpoints) > 0 { - dst.Microsub = endpoints[len(endpoints)-1] - } - } - - if dst.TicketEndpoint == nil { - if endpoints := httputil.ExtractEndpoints(src, relTicketEndpoint); len(endpoints) > 0 { - dst.TicketEndpoint = endpoints[len(endpoints)-1] - } - } - - if dst.TokenEndpoint == nil { - if endpoints := httputil.ExtractEndpoints(src, relTokenEndpoint); len(endpoints) > 0 { - dst.TokenEndpoint = endpoints[len(endpoints)-1] + if endpoints := httputil.ExtractEndpoints(bytes.NewReader(body), u, header, key); len(endpoints) > 0 { + *target = endpoints[len(endpoints)-1] } } } //nolint:cyclop -func extractProfile(dst *domain.Profile, src *http.Response) { - for _, name := range httputil.ExtractProperty(src, hCard, propertyName) { - if n, ok := name.(string); ok { +func extractProfile(u *url.URL, dst *domain.Profile, body []byte) { + for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyName) { + if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) { dst.Name = append(dst.Name, n) } } - for _, rawEmail := range httputil.ExtractProperty(src, hCard, propertyEmail) { + for _, rawEmail := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyEmail) { email, ok := rawEmail.(string) if !ok { continue } - if e, err := domain.ParseEmail(email); err == nil { + if e, err := domain.ParseEmail(email); err == nil && !slices.Contains(dst.Email, e) { dst.Email = append(dst.Email, e) } } - for _, rawURL := range httputil.ExtractProperty(src, hCard, propertyURL) { + for _, rawURL := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyURL) { rawURL, ok := rawURL.(string) if !ok { continue } - if u, err := url.Parse(rawURL); err == nil { + if u, err := url.Parse(rawURL); err == nil && !containsUrl(dst.URL, u) { dst.URL = append(dst.URL, u) } } - for _, rawPhoto := range httputil.ExtractProperty(src, hCard, propertyPhoto) { + for _, rawPhoto := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyPhoto) { photo, ok := rawPhoto.(string) if !ok { continue } - if p, err := url.Parse(photo); err == nil { + if p, err := url.Parse(photo); err == nil && !containsUrl(dst.Photo, p) { dst.Photo = append(dst.Photo, p) } } } + +func containsUrl(src []*url.URL, find *url.URL) bool { + for i := range src { + if src[i].String() != find.String() { + continue + } + + return true + } + + return false +} diff --git a/internal/user/repository/http/http_user_test.go b/internal/user/repository/http/http_user_test.go index 083ad30..9a269d5 100644 --- a/internal/user/repository/http/http_user_test.go +++ b/internal/user/repository/http/http_user_test.go @@ -3,16 +3,15 @@ package http_test import ( "context" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" - "github.com/fasthttp/router" - "github.com/stretchr/testify/assert" - http "github.com/valyala/fasthttp" + "github.com/google/go-cmp/cmp" "source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/domain" - "source.toby3d.me/toby3d/auth/internal/testing/httptest" repository "source.toby3d.me/toby3d/auth/internal/user/repository/http" ) @@ -40,39 +39,29 @@ func TestGet(t *testing.T) { t.Parallel() user := domain.TestUser(t) - client, _, cleanup := httptest.New(t, testHandler(t, user)) - t.Cleanup(cleanup) - result, err := repository.NewHTTPUserRepository(client).Get(context.Background(), user.Me) + srv := httptest.NewServer(testHandler(t, user)) + t.Cleanup(srv.Close) + + user.Me = domain.TestMe(t, srv.URL+"/") + + result, err := repository.NewHTTPUserRepository(srv.Client()). + Get(context.Background(), *user.Me) if err != nil { t.Fatal(err) } - // NOTE(toby3d): endpoints - assert.Equal(t, user.AuthorizationEndpoint.String(), result.AuthorizationEndpoint.String()) - assert.Equal(t, user.TokenEndpoint.String(), result.TokenEndpoint.String()) - assert.Equal(t, user.Micropub.String(), result.Micropub.String()) - assert.Equal(t, user.Microsub.String(), result.Microsub.String()) - - // NOTE(toby3d): profile - assert.Equal(t, user.Profile.Name, result.Profile.Name) - assert.Equal(t, user.Profile.Email, result.Profile.Email) - - for i := range user.Profile.URL { - assert.Equal(t, user.Profile.URL[i].String(), result.Profile.URL[i].String()) - } - - for i := range user.Profile.Photo { - assert.Equal(t, user.Profile.Photo[i].String(), result.Profile.Photo[i].String()) + if diff := cmp.Diff(user, result, cmp.AllowUnexported(domain.Me{}, domain.Email{})); diff != "" { + t.Errorf("%+s", diff) } } -func testHandler(tb testing.TB, user *domain.User) http.RequestHandler { +func testHandler(tb testing.TB, user *domain.User) http.Handler { tb.Helper() - router := router.New() - router.GET("/", func(ctx *http.RequestCtx) { - ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{ + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderLink, strings.Join([]string{ `<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`, `<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`, `<` + user.Micropub.String() + `>; rel="micropub"`, @@ -80,17 +69,17 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler { `<` + user.TicketEndpoint.String() + `>; rel="ticket_endpoint"`, `<` + user.TokenEndpoint.String() + `>; rel="token_endpoint"`, }, ", ")) - ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf( - testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0], - )) + w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8) + fmt.Fprintf(w, testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0]) }) - router.GET(user.IndieAuthMetadata.Path, func(ctx *http.RequestCtx) { - ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{ + mux.HandleFunc(user.IndieAuthMetadata.Path, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) + fmt.Fprint(w, `{ "issuer": "`+user.Me.String()+`", "authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`", "token_endpoint": "`+user.TokenEndpoint.String()+`" }`) }) - return router.Handler + return mux } diff --git a/internal/user/repository/memory/memory_user.go b/internal/user/repository/memory/memory_user.go index 18d5385..5478861 100644 --- a/internal/user/repository/memory/memory_user.go +++ b/internal/user/repository/memory/memory_user.go @@ -2,7 +2,6 @@ package memory import ( "context" - "path" "sync" "source.toby3d.me/toby3d/auth/internal/domain" @@ -10,27 +9,33 @@ import ( ) type memoryUserRepository struct { - store *sync.Map + mutex *sync.RWMutex + users map[string]domain.User } -const DefaultPathPrefix string = "users" - -func NewMemoryUserRepository(store *sync.Map) user.Repository { +func NewMemoryUserRepository() user.Repository { return &memoryUserRepository{ - store: store, + mutex: new(sync.RWMutex), + users: make(map[string]domain.User), } } -func (repo *memoryUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) { - p, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) - if !ok { - return nil, user.ErrNotExist - } +func (repo *memoryUserRepository) Create(ctx context.Context, user domain.User) error { + repo.mutex.Lock() + defer repo.mutex.Unlock() - result, ok := p.(*domain.User) - if !ok { - return nil, user.ErrNotExist - } + repo.users[user.Me.String()] = user - return result, nil + return nil +} + +func (repo *memoryUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) { + repo.mutex.RLock() + defer repo.mutex.RUnlock() + + if u, ok := repo.users[me.String()]; ok { + return &u, nil + } + + return nil, user.ErrNotExist } diff --git a/internal/user/repository/memory/memory_user_test.go b/internal/user/repository/memory/memory_user_test.go deleted file mode 100644 index 397a464..0000000 --- a/internal/user/repository/memory/memory_user_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package memory_test - -import ( - "context" - "path" - "reflect" - "sync" - "testing" - - "source.toby3d.me/toby3d/auth/internal/domain" - repository "source.toby3d.me/toby3d/auth/internal/user/repository/memory" -) - -func TestGet(t *testing.T) { - t.Parallel() - - user := domain.TestUser(t) - - store := new(sync.Map) - store.Store(path.Join(repository.DefaultPathPrefix, user.Me.String()), user) - - result, err := repository.NewMemoryUserRepository(store).Get(context.Background(), user.Me) - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(result, user) { - t.Errorf("Get(%s) = %+v, want %+v", user.Me, result, user) - } -} diff --git a/internal/user/usecase.go b/internal/user/usecase.go index 13829d1..fc71209 100644 --- a/internal/user/usecase.go +++ b/internal/user/usecase.go @@ -8,5 +8,5 @@ import ( type UseCase interface { // Fetch discovery all available endpoints and Profile info on Me URL. - Fetch(ctx context.Context, me *domain.Me) (*domain.User, error) + Fetch(ctx context.Context, me domain.Me) (*domain.User, error) } diff --git a/internal/user/usecase/user_ucase.go b/internal/user/usecase/user_ucase.go index 6fffb84..b6a9fd9 100644 --- a/internal/user/usecase/user_ucase.go +++ b/internal/user/usecase/user_ucase.go @@ -18,7 +18,7 @@ func NewUserUseCase(repo user.Repository) user.UseCase { } } -func (useCase *userUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.User, error) { +func (useCase *userUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.User, error) { user, err := useCase.repo.Get(ctx, me) if err != nil { return nil, fmt.Errorf("cannot find user by me: %w", err) diff --git a/internal/user/usecase/user_ucase_test.go b/internal/user/usecase/user_ucase_test.go index 836d50c..183878b 100644 --- a/internal/user/usecase/user_ucase_test.go +++ b/internal/user/usecase/user_ucase_test.go @@ -2,9 +2,7 @@ package usecase_test import ( "context" - "path" "reflect" - "sync" "testing" "source.toby3d.me/toby3d/auth/internal/domain" @@ -15,19 +13,20 @@ import ( func TestFetch(t *testing.T) { t.Parallel() - me := domain.TestMe(t, "https://user.example.net") user := domain.TestUser(t) + user.Me = domain.TestMe(t, "https://user.example.net") + users := repository.NewMemoryUserRepository() - store := new(sync.Map) - store.Store(path.Join(repository.DefaultPathPrefix, me.String()), user) + if err := users.Create(context.Background(), *user); err != nil { + t.Fatal(err) + } - result, err := ucase.NewUserUseCase(repository.NewMemoryUserRepository(store)). - Fetch(context.Background(), me) + result, err := ucase.NewUserUseCase(users).Fetch(context.Background(), *user.Me) if err != nil { t.Error(err) } if !reflect.DeepEqual(result, user) { - t.Errorf("Fetch(%s) = %+v, want %+v", me, result, user) + t.Errorf("Fetch(%s) = %+v, want %+v", user.Me, result, user) } } diff --git a/main.go b/main.go index 0f0cdbf..c24fdde 100644 --- a/main.go +++ b/main.go @@ -5,27 +5,27 @@ package main import ( + "context" + "embed" _ "embed" "errors" "flag" - "fmt" + "io/fs" "log" + "net/http" + _ "net/http/pprof" + "net/url" "os" "os/signal" - "path" "path/filepath" "runtime" "runtime/pprof" "strings" - "sync" "syscall" "time" - "github.com/fasthttp/router" "github.com/jmoiron/sqlx" "github.com/spf13/viper" - http "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/pprofhandler" "golang.org/x/text/language" "golang.org/x/text/message" _ "modernc.org/sqlite" @@ -40,6 +40,7 @@ import ( "source.toby3d.me/toby3d/auth/internal/domain" healthhttpdelivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http" metadatahttpdelivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http" + "source.toby3d.me/toby3d/auth/internal/middleware" "source.toby3d.me/toby3d/auth/internal/profile" profilehttprepo "source.toby3d.me/toby3d/auth/internal/profile/repository/http" profileucase "source.toby3d.me/toby3d/auth/internal/profile/usecase" @@ -57,6 +58,7 @@ import ( tokenmemoryrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory" tokensqlite3repo "source.toby3d.me/toby3d/auth/internal/token/repository/sqlite3" tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase" + "source.toby3d.me/toby3d/auth/internal/urlutil" userhttpdelivery "source.toby3d.me/toby3d/auth/internal/user/delivery/http" ) @@ -69,6 +71,7 @@ type ( tickets ticket.UseCase profiles profile.UseCase tokens token.UseCase + static fs.FS } NewAppOptions struct { @@ -78,6 +81,7 @@ type ( Tickets ticket.Repository Tokens token.Repository Profiles profile.Repository + Static fs.FS } ) @@ -93,13 +97,16 @@ var ( logger = log.New(os.Stdout, "IndieAuth\t", log.Lmsgprefix|log.LstdFlags|log.LUTC) config = new(domain.Config) indieAuthClient = new(domain.Client) - - configPath string - cpuProfilePath string - memProfilePath string - enablePprof bool ) +var ( + configPath, cpuProfilePath, memProfilePath string + enablePprof bool +) + +//go:embed assets/* +var staticFS embed.FS + //nolint:gochecknoinits func init() { flag.StringVar(&configPath, "config", filepath.Join(".", "config.yml"), "load specific config") @@ -133,34 +140,44 @@ func init() { rootURL := config.Server.GetRootURL() indieAuthClient.Name = []string{config.Name} - if indieAuthClient.ID, err = domain.ParseClientID(rootURL); err != nil { + cid, err := domain.ParseClientID(rootURL) + if err != nil { logger.Fatalln("fail to read config:", err) } - url, err := domain.ParseURL(rootURL) + indieAuthClient.ID = *cid + + u, err := url.Parse(rootURL) if err != nil { logger.Fatalln("cannot parse root URL as client URL:", err) } - logo, err := domain.ParseURL(rootURL + config.Server.StaticURLPrefix + "/icon.svg") + logo, err := url.Parse(rootURL + config.Server.StaticURLPrefix + "/icon.svg") if err != nil { logger.Fatalln("cannot parse root URL as client URL:", err) } - redirectURI, err := domain.ParseURL(rootURL + "/callback") + redirectURI, err := url.Parse(rootURL + "callback") if err != nil { logger.Fatalln("cannot parse root URL as client URL:", err) } - indieAuthClient.URL = []*domain.URL{url} - indieAuthClient.Logo = []*domain.URL{logo} - indieAuthClient.RedirectURI = []*domain.URL{redirectURI} + indieAuthClient.URL = []*url.URL{u} + indieAuthClient.Logo = []*url.URL{logo} + indieAuthClient.RedirectURI = []*url.URL{redirectURI} } //nolint:funlen,cyclop // "god object" and the entry point of all modules func main() { + ctx := context.Background() + var opts NewAppOptions + var err error + if opts.Static, err = fs.Sub(staticFS, "assets"); err != nil { + logger.Fatalln(err) + } + switch strings.ToLower(config.Database.Type) { case "sqlite3": store, err := sqlx.Open("sqlite", config.Database.Path) @@ -176,51 +193,27 @@ func main() { opts.Sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store) opts.Tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config) case "memory": - store := new(sync.Map) - opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository(store) - opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(store, config) - opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(store, config) + opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository() + opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(*config) + opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(*config) default: log.Fatalln("unsupported database type, use 'memory' or 'sqlite3'") } go opts.Sessions.GC() - //nolint:exhaustivestruct // too many options - opts.Client = &http.Client{ - Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()), - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, - } + opts.Client = new(http.Client) opts.Clients = clienthttprepo.NewHTTPClientRepository(opts.Client) opts.Profiles = profilehttprepo.NewHTPPClientRepository(opts.Client) - r := router.New() - NewApp(opts).Register(r) - //nolint:exhaustivestruct // too many options - r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{ - Root: config.Server.StaticRootPath, - CacheDuration: DefaultCacheDuration, - AcceptByteRange: true, - Compress: true, - CompressBrotli: true, - GenerateIndexPages: true, - }) - - if enablePprof { - r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler) - } + app := NewApp(opts) //nolint:exhaustivestruct server := &http.Server{ - Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()), - Handler: r.Handler, - ReadTimeout: DefaultReadTimeout, - WriteTimeout: DefaultWriteTimeout, - DisableKeepalive: true, - ReduceMemoryUsage: true, - SecureErrorLogMessage: true, - CloseOnShutdown: true, + Addr: config.Server.GetAddress(), + Handler: app.Handler(), + ReadTimeout: DefaultReadTimeout, + WriteTimeout: DefaultWriteTimeout, } done := make(chan os.Signal, 1) @@ -243,15 +236,15 @@ func main() { logger.Printf("started at %s, available at %s", config.Server.GetAddress(), config.Server.GetRootURL()) - err := server.ListenAndServe(config.Server.GetAddress()) - if err != nil && !errors.Is(err, http.ErrConnectionClosed) { + err := server.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Fatalln("cannot listen and serve:", err) } }() <-done - if err := server.Shutdown(); err != nil { + if err := server.Shutdown(ctx); err != nil { logger.Fatalln("failed shutdown of server:", err) } @@ -274,6 +267,7 @@ func main() { func NewApp(opts NewAppOptions) *App { return &App{ + static: opts.Static, auth: authucase.NewAuthUseCase(opts.Sessions, opts.Profiles, config), clients: clientucase.NewClientUseCase(opts.Clients), matcher: language.NewMatcher(message.DefaultCatalog.Languages()), @@ -289,20 +283,19 @@ func NewApp(opts NewAppOptions) *App { } } -func (app *App) Register(r *router.Router) { - tickethttpdelivery.NewRequestHandler(app.tickets, app.matcher, config).Register(r) - healthhttpdelivery.NewRequestHandler().Register(r) - metadatahttpdelivery.NewRequestHandler(&domain.Metadata{ - Issuer: indieAuthClient.ID, - AuthorizationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "authorize"), - TokenEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "token"), - TicketEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "ticket"), +// TODO(toby3d): move module middlewares to here. +func (app *App) Handler() http.Handler { + metadata := metadatahttpdelivery.NewHandler(&domain.Metadata{ + Issuer: indieAuthClient.ID.URL(), + AuthorizationEndpoint: indieAuthClient.ID.URL().JoinPath("authorize"), + TokenEndpoint: indieAuthClient.ID.URL().JoinPath("token"), + TicketEndpoint: indieAuthClient.ID.URL().JoinPath("ticket"), MicropubEndpoint: nil, MicrosubEndpoint: nil, - IntrospectionEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "introspect"), - RevocationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "revocation"), - UserinfoEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "userinfo"), - ServiceDocumentation: domain.MustParseURL("https://indieauth.net/source/"), + IntrospectionEndpoint: indieAuthClient.ID.URL().JoinPath("introspect"), + RevocationEndpoint: indieAuthClient.ID.URL().JoinPath("revocation"), + UserinfoEndpoint: indieAuthClient.ID.URL().JoinPath("userinfo"), + ServiceDocumentation: &url.URL{Scheme: "https", Host: "indieauth.net", Path: "/source/"}, IntrospectionEndpointAuthMethodsSupported: []string{"Bearer"}, RevocationEndpointAuthMethodsSupported: []string{"none"}, ScopesSupported: domain.Scopes{ @@ -319,8 +312,14 @@ func (app *App) Register(r *router.Router) { domain.ScopeRead, domain.ScopeUpdate, }, - ResponseTypesSupported: []domain.ResponseType{domain.ResponseTypeCode, domain.ResponseTypeID}, - GrantTypesSupported: []domain.GrantType{domain.GrantTypeAuthorizationCode, domain.GrantTypeTicket}, + ResponseTypesSupported: []domain.ResponseType{ + domain.ResponseTypeCode, + domain.ResponseTypeID, + }, + GrantTypesSupported: []domain.GrantType{ + domain.GrantTypeAuthorizationCode, + domain.GrantTypeTicket, + }, CodeChallengeMethodsSupported: []domain.CodeChallengeMethod{ domain.CodeChallengeMethodMD5, domain.CodeChallengeMethodPLAIN, @@ -329,20 +328,57 @@ func (app *App) Register(r *router.Router) { domain.CodeChallengeMethodS512, }, AuthorizationResponseIssParameterSupported: true, - }).Register(r) - tokenhttpdelivery.NewRequestHandler(app.tokens, app.tickets, config).Register(r) - clienthttpdelivery.NewRequestHandler(clienthttpdelivery.NewRequestHandlerOptions{ - Client: indieAuthClient, - Config: config, - Matcher: app.matcher, - Tokens: app.tokens, - }).Register(r) - authhttpdelivery.NewRequestHandler(authhttpdelivery.NewRequestHandlerOptions{ + }).Handler() + health := healthhttpdelivery.NewHandler().Handler() + auth := authhttpdelivery.NewHandler(authhttpdelivery.NewHandlerOptions{ Auth: app.auth, Clients: app.clients, - Config: config, + Config: *config, Matcher: app.matcher, Profiles: app.profiles, - }).Register(r) - userhttpdelivery.NewRequestHandler(app.tokens, config).Register(r) + }).Handler() + token := tokenhttpdelivery.NewHandler(app.tokens, app.tickets, config).Handler() + client := clienthttpdelivery.NewHandler(clienthttpdelivery.NewHandlerOptions{ + Client: *indieAuthClient, + Config: *config, + Matcher: app.matcher, + Tokens: app.tokens, + }).Handler() + user := userhttpdelivery.NewHandler(app.tokens, config).Handler() + ticket := tickethttpdelivery.NewHandler(app.tickets, app.matcher, *config).Handler() + static := http.FileServer(http.FS(app.static)) + + return http.HandlerFunc(middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var head string + head, r.URL.Path = urlutil.ShiftPath(r.URL.Path) + + switch head { + default: + r.URL = r.URL.JoinPath(head, r.URL.Path) + + static.ServeHTTP(w, r) + case "", "callback": + r.URL = r.URL.JoinPath(head, r.URL.Path) + + client.ServeHTTP(w, r) + case "token", "introspect", "revocation": + r.URL = r.URL.JoinPath(head, r.URL.Path) + + token.ServeHTTP(w, r) + case ".well-known": + if head, _ = urlutil.ShiftPath(r.URL.Path); head == "oauth-authorization-server" { + metadata.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } + case "authorize": + auth.ServeHTTP(w, r) + case "health": + health.ServeHTTP(w, r) + case "userinfo": + user.ServeHTTP(w, r) + case "ticket": + ticket.ServeHTTP(w, r) + } + }).Intercept(middleware.LogFmt())) } diff --git a/web/authorize.qtpl b/web/authorize.qtpl index 80df168..05b6495 100644 --- a/web/authorize.qtpl +++ b/web/authorize.qtpl @@ -17,137 +17,137 @@ } %} {% func (p *AuthorizePage) Title() %} - {% if p.Client.GetName() == "" %} - {%= p.T("Authorize %s", p.Client.GetName()) %} - {% else %} - {%= p.T("Authorize application") %} - {% endif %} +{% if p.Client.GetName() == "" %} +{%= p.T("Authorize %s", p.Client.GetName()) %} +{% else %} +{%= p.T("Authorize application") %} +{% endif %} {% endfunc %} {% func (p *AuthorizePage) Body() %} -
- {% if p.Client.GetLogo() != nil %} - {%s p.Client.GetName() %} - {% endif %} +
+ {% if p.Client.GetLogo() != nil %} + {%s p.Client.GetName() %} + {% endif %} -

- {% if p.Client.GetURL() != nil %} - +

+ {% if p.Client.GetURL() != nil %} + {% endif %} {% if p.Client.GetName() != "" %} - {%s p.Client.GetName() %} + {%s p.Client.GetName() %} {% else %} - {%s p.Client.ID.String() %} + {%s p.Client.ID.String() %} {% endif %} {% if p.Client.GetURL() != nil %} - - {% endif %} -

-

+ + {% endif %} + +
-
-
+
+ - {% if p.CSRF != nil %} - - {% endif %} + {% if p.CSRF != nil %} + + {% endif %} - {% for key, val := range map[string]string{ + {% for key, val := range map[string]string{ "client_id": p.Client.ID.String(), "redirect_uri": p.RedirectURI.String(), "response_type": p.ResponseType.String(), "state": p.State, } %} - + + {% endfor %} + + {% if len(p.Scope) > 0 %} +
+ {%= p.T("Choose your scopes") %} + + {% for _, scope := range p.Scope %} +
+ +
{% endfor %} +
+ {% endif %} - {% if len(p.Scope) > 0 %} -
- {%= p.T("Choose your scopes") %} + {% if p.CodeChallenge != "" %} + - {% for _, scope := range p.Scope %} -
- -
- {% endfor %} -
- {% endif %} + {% if p.Me != nil %} + + {% endif %} - {% if p.CodeChallenge != "" %} - + {% if len(p.Providers) > 0 %} + - {% endif %} - - {% if p.Me != nil %} - - {% endif %} - - {% if len(p.Providers) > 0 %} - - {% else %} - - {% endif %} + {%s provider.Name %} + + {% endfor %} + + {% else %} + + {% endif %} - + {%= p.T("Deny") %} + - - -
-{% endfunc %} \ No newline at end of file + {%= p.T("Allow") %} + + +
+{% endfunc %} diff --git a/web/authorize.qtpl.go b/web/authorize.qtpl.go index b9afc9b..56c0bf0 100644 --- a/web/authorize.qtpl.go +++ b/web/authorize.qtpl.go @@ -41,27 +41,27 @@ type AuthorizePage struct { func (p *AuthorizePage) StreamTitle(qw422016 *qt422016.Writer) { //line web/authorize.qtpl:19 qw422016.N().S(` - `) +`) //line web/authorize.qtpl:20 if p.Client.GetName() == "" { //line web/authorize.qtpl:20 qw422016.N().S(` - `) +`) //line web/authorize.qtpl:21 p.StreamT(qw422016, "Authorize %s", p.Client.GetName()) //line web/authorize.qtpl:21 qw422016.N().S(` - `) +`) //line web/authorize.qtpl:22 } else { //line web/authorize.qtpl:22 qw422016.N().S(` - `) +`) //line web/authorize.qtpl:23 p.StreamT(qw422016, "Authorize application") //line web/authorize.qtpl:23 qw422016.N().S(` - `) +`) //line web/authorize.qtpl:24 } //line web/authorize.qtpl:24 @@ -100,43 +100,43 @@ func (p *AuthorizePage) Title() string { func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) { //line web/authorize.qtpl:27 qw422016.N().S(` -
- `) +
+ `) //line web/authorize.qtpl:29 if p.Client.GetLogo() != nil { //line web/authorize.qtpl:29 qw422016.N().S(` - `)
+       alt= - `) + width="140"> + `) //line web/authorize.qtpl:40 } //line web/authorize.qtpl:40 qw422016.N().S(` -

- `) +

+ `) //line web/authorize.qtpl:43 if p.Client.GetURL() != nil { //line web/authorize.qtpl:43 qw422016.N().S(` - - `) + + `) //line web/authorize.qtpl:53 } //line web/authorize.qtpl:53 qw422016.N().S(` -

-
+ +
-
-
+
+ - `) + `) //line web/authorize.qtpl:67 if p.CSRF != nil { //line web/authorize.qtpl:67 qw422016.N().S(` - - `) + `) //line web/authorize.qtpl:71 } //line web/authorize.qtpl:71 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:73 for key, val := range map[string]string{ "client_id": p.Client.ID.String(), @@ -223,129 +223,129 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) { } { //line web/authorize.qtpl:78 qw422016.N().S(` - - `) + `) //line web/authorize.qtpl:82 } //line web/authorize.qtpl:82 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:84 if len(p.Scope) > 0 { //line web/authorize.qtpl:84 qw422016.N().S(` -
- `) +
+ `) //line web/authorize.qtpl:86 p.StreamT(qw422016, "Choose your scopes") //line web/authorize.qtpl:86 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:88 for _, scope := range p.Scope { //line web/authorize.qtpl:88 qw422016.N().S(` -
- -
- `) + + + `) //line web/authorize.qtpl:99 } //line web/authorize.qtpl:99 qw422016.N().S(` -
- `) +
+ `) //line web/authorize.qtpl:101 } //line web/authorize.qtpl:101 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:103 if p.CodeChallenge != "" { //line web/authorize.qtpl:103 qw422016.N().S(` - - - `) + `) //line web/authorize.qtpl:111 } //line web/authorize.qtpl:111 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:113 if p.Me != nil { //line web/authorize.qtpl:113 qw422016.N().S(` - - `) + `) //line web/authorize.qtpl:117 } //line web/authorize.qtpl:117 qw422016.N().S(` - `) + `) //line web/authorize.qtpl:119 if len(p.Providers) > 0 { //line web/authorize.qtpl:119 qw422016.N().S(` - - `) + `) //line web/authorize.qtpl:124 for _, provider := range p.Providers { //line web/authorize.qtpl:124 qw422016.N().S(` - - `) + + `) //line web/authorize.qtpl:130 } //line web/authorize.qtpl:130 qw422016.N().S(` - - `) + + `) //line web/authorize.qtpl:132 } else { //line web/authorize.qtpl:132 qw422016.N().S(` - - `) + + `) //line web/authorize.qtpl:136 } //line web/authorize.qtpl:136 qw422016.N().S(` - + - - -
+ + +
`) //line web/authorize.qtpl:153 } diff --git a/web/ticket.qtpl b/web/ticket.qtpl index 7f2cab1..96dcf3f 100644 --- a/web/ticket.qtpl +++ b/web/ticket.qtpl @@ -5,47 +5,47 @@ {% collapsespace %} {% func (p *TicketPage) Body() %} -
-

{%= p.T("TicketAuth") %}

-
+
+

{%= p.T("TicketAuth") %}

+
-
-
+
+ - {% if p.CSRF != nil %} - - {% endif %} + {% if p.CSRF != nil %} + + {% endif %} -
- - -
+
+ + +
-
- - -
+
+ + +
- - -
+ + +
{% endfunc %} {% endcollapsespace %} diff --git a/web/ticket.qtpl.go b/web/ticket.qtpl.go index 53ac936..d579a88 100644 --- a/web/ticket.qtpl.go +++ b/web/ticket.qtpl.go @@ -30,7 +30,7 @@ func (p *TicketPage) StreamBody(qw422016 *qt422016.Writer) { //line web/ticket.qtpl:9 p.StreamT(qw422016, "TicketAuth") //line web/ticket.qtpl:9 - qw422016.N().S(`
`) + qw422016.N().S(`
`) //line web/ticket.qtpl:21 if p.CSRF != nil { //line web/ticket.qtpl:21