diff --git a/internal/auth/delivery/http/auth_http.go b/internal/auth/delivery/http/auth_http.go
index c760960..caa61bb 100644
--- a/internal/auth/delivery/http/auth_http.go
+++ b/internal/auth/delivery/http/auth_http.go
@@ -2,13 +2,10 @@ package http
import (
"crypto/subtle"
- "errors"
- "path"
+ "net/http"
"strings"
- "github.com/fasthttp/router"
- json "github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
+ "github.com/goccy/go-json"
"golang.org/x/text/language"
"golang.org/x/text/message"
@@ -16,111 +13,31 @@ import (
"source.toby3d.me/toby3d/auth/internal/client"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/profile"
+ "source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web"
- "source.toby3d.me/toby3d/form"
- "source.toby3d.me/toby3d/middleware"
)
type (
- AuthAuthorizationRequest struct {
- // Indicates to the authorization server that an authorization
- // code should be returned as the response.
- ResponseType domain.ResponseType `form:"response_type"` // code
-
- // The client URL.
- ClientID *domain.ClientID `form:"client_id"`
-
- // The redirect URL indicating where the user should be
- // redirected to after approving the request.
- RedirectURI *domain.URL `form:"redirect_uri"`
-
- // A parameter set by the client which will be included when the
- // user is redirected back to the client. This is used to
- // prevent CSRF attacks. The authorization server MUST return
- // the unmodified state value back to the client.
- State string `form:"state"`
-
- // The code challenge as previously described.
- CodeChallenge string `form:"code_challenge"`
-
- // The hashing method used to calculate the code challenge.
- CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
-
- // A space-separated list of scopes the client is requesting,
- // e.g. "profile", or "profile create". If the client omits this
- // value, the authorization server MUST NOT issue an access
- // token for this authorization code. Only the user's profile
- // URL may be returned without any scope requested.
- Scope domain.Scopes `form:"scope,omitempty"`
-
- // The URL that the user entered.
- Me *domain.Me `form:"me"`
- }
-
- AuthVerifyRequest struct {
- ClientID *domain.ClientID `form:"client_id"`
- Me *domain.Me `form:"me"`
- RedirectURI *domain.URL `form:"redirect_uri"`
- CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
- ResponseType domain.ResponseType `form:"response_type"`
- Scope domain.Scopes `form:"scope[],omitempty"`
- Authorize string `form:"authorize"`
- CodeChallenge string `form:"code_challenge"`
- State string `form:"state"`
- Provider string `form:"provider"`
- }
-
- AuthExchangeRequest struct {
- GrantType domain.GrantType `form:"grant_type"` // authorization_code
-
- // The authorization code received from the authorization
- // endpoint in the redirect.
- Code string `form:"code"`
-
- // The client's URL, which MUST match the client_id used in the
- // authentication request.
- ClientID *domain.ClientID `form:"client_id"`
-
- // The client's redirect URL, which MUST match the initial
- // authentication request.
- RedirectURI *domain.URL `form:"redirect_uri"`
-
- // The original plaintext random string generated before
- // starting the authorization request.
- CodeVerifier string `form:"code_verifier"`
- }
-
- AuthExchangeResponse struct {
- Me *domain.Me `json:"me"`
- Profile *AuthProfileResponse `json:"profile,omitempty"`
- }
-
- AuthProfileResponse struct {
- Email *domain.Email `json:"email,omitempty"`
- Photo *domain.URL `json:"photo,omitempty"`
- URL *domain.URL `json:"url,omitempty"`
- Name string `json:"name,omitempty"`
- }
-
- NewRequestHandlerOptions struct {
+ NewHandlerOptions struct {
Auth auth.UseCase
Clients client.UseCase
- Config *domain.Config
+ Config domain.Config
Matcher language.Matcher
Profiles profile.UseCase
}
- RequestHandler struct {
+ Handler struct {
clients client.UseCase
- config *domain.Config
+ config domain.Config
matcher language.Matcher
useCase auth.UseCase
}
)
-func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
- return &RequestHandler{
+func NewHandler(opts NewHandlerOptions) *Handler {
+ return &Handler{
clients: opts.Clients,
config: opts.Config,
matcher: opts.Matcher,
@@ -128,16 +45,16 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
}
}
-func (h *RequestHandler) Register(r *router.Router) {
+func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
- Skipper: func(ctx *http.RequestCtx) bool {
- matched, _ := path.Match("/authorize*", string(ctx.Path()))
+ Skipper: func(w http.ResponseWriter, r *http.Request) bool {
+ head, _ := urlutil.ShiftPath(r.URL.Path)
- return ctx.IsPost() && matched
+ return r.Method == http.MethodPost && head == "authorize"
},
CookieMaxAge: 0,
- CookieSameSite: http.CookieSameSiteStrictMode,
+ CookieSameSite: http.SameSiteStrictMode,
ContextKey: "csrf",
CookieDomain: h.config.Server.Domain,
CookieName: "__Secure-csrf",
@@ -148,14 +65,12 @@ func (h *RequestHandler) Register(r *router.Router) {
CookieHTTPOnly: true,
}),
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
- Skipper: func(ctx *http.RequestCtx) bool {
- matched, _ := path.Match("/api/*", string(ctx.Path()))
- provider := string(ctx.QueryArgs().Peek("provider"))
- providerMatched := provider != "" && provider != domain.ProviderDirect.UID
+ Skipper: func(w http.ResponseWriter, r *http.Request) bool {
+ head, _ := urlutil.ShiftPath(r.URL.Path)
- return !ctx.IsPost() || !matched || providerMatched
+ return r.Method != http.MethodPost || head != "api"
},
- Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) {
+ Validator: func(w http.ResponseWriter, r *http.Request, login, password string) (bool, error) {
userMatch := subtle.ConstantTimeCompare([]byte(login),
[]byte(h.config.IndieAuth.Username))
passMatch := subtle.ConstantTimeCompare([]byte(password),
@@ -165,29 +80,57 @@ func (h *RequestHandler) Register(r *router.Router) {
},
Realm: "",
}),
- middleware.LogFmt(),
}
- r.GET("/authorize", chain.RequestHandler(h.handleAuthorize))
- r.POST("/api/authorize", chain.RequestHandler(h.handleVerify))
- r.POST("/authorize", chain.RequestHandler(h.handleExchange))
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var head string
+ head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
+
+ switch r.Method {
+ default:
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+ case http.MethodGet, "":
+ if head != "" {
+ http.NotFound(w, r)
+
+ return
+ }
+
+ chain.Handler(h.handleAuthorize).ServeHTTP(w, r)
+ case http.MethodPost:
+ switch head {
+ default:
+ http.NotFound(w, r)
+ case "":
+ chain.Handler(h.handleExchange).ServeHTTP(w, r)
+ case "verify":
+ chain.Handler(h.handleVerify).ServeHTTP(w, r)
+ }
+ }
+ })
}
-func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
+func (h *Handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet && r.Method != "" {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+
+ tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{
- Config: h.config,
+ Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
}
req := NewAuthAuthorizationRequest()
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@@ -195,10 +138,10 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return
}
- client, err := h.clients.Discovery(ctx, req.ClientID)
+ client, err := h.clients.Discovery(r.Context(), req.ClientID)
if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ w.WriteHeader(http.StatusBadRequest)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@@ -207,8 +150,8 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
}
if !client.ValidateRedirectURI(req.RedirectURI.URL) {
- ctx.SetStatusCode(http.StatusBadRequest)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ w.WriteHeader(http.StatusBadRequest)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeInvalidClient,
@@ -220,15 +163,15 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return
}
- csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
- web.WriteTemplate(ctx, &web.AuthorizePage{
+ csrf, _ := r.Context().Value(middleware.DefaultCSRFConfig.ContextKey).([]byte)
+ web.WriteTemplate(w, &web.AuthorizePage{
BaseOf: baseOf,
CSRF: csrf,
Scope: req.Scope,
Client: client,
- Me: req.Me,
- RedirectURI: req.RedirectURI,
- CodeChallengeMethod: req.CodeChallengeMethod,
+ Me: &req.Me,
+ RedirectURI: &req.RedirectURI,
+ CodeChallengeMethod: *req.CodeChallengeMethod,
ResponseType: req.ResponseType,
CodeChallenge: req.CodeChallenge,
State: req.State,
@@ -236,15 +179,21 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
})
}
-func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
- ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
+func (h *Handler) handleVerify(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
req := NewAuthVerifyRequest()
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
@@ -254,60 +203,70 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
if strings.EqualFold(req.Authorize, "deny") {
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
SetReirectURI(req.RedirectURI.URL)
- ctx.Redirect(req.RedirectURI.String(), http.StatusFound)
+ http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
return
}
- code, err := h.useCase.Generate(ctx, auth.GenerateOptions{
+ code, err := h.useCase.Generate(r.Context(), auth.GenerateOptions{
ClientID: req.ClientID,
Me: req.Me,
RedirectURI: req.RedirectURI.URL,
- CodeChallengeMethod: req.CodeChallengeMethod,
+ CodeChallengeMethod: *req.CodeChallengeMethod,
Scope: req.Scope,
CodeChallenge: req.CodeChallenge,
})
if err != nil {
- ctx.SetStatusCode(http.StatusInternalServerError)
+ w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(err)
return
}
+ q := req.RedirectURI.Query()
+
for key, val := range map[string]string{
"code": code,
"iss": h.config.Server.GetRootURL(),
"state": req.State,
} {
- req.RedirectURI.Query().Set(key, val)
+ q.Set(key, val)
}
- ctx.Redirect(req.RedirectURI.String(), http.StatusFound)
+ req.RedirectURI.RawQuery = q.Encode()
+
+ http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
}
-func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
+func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
req := new(AuthExchangeRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
- me, profile, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{
+ me, profile, err := h.useCase.Exchange(r.Context(), auth.ExchangeOptions{
Code: req.Code,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI.URL,
CodeVerifier: req.CodeVerifier,
})
if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
@@ -325,109 +284,7 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
}
_ = encoder.Encode(&AuthExchangeResponse{
- Me: me,
+ Me: *me,
Profile: userInfo,
})
}
-
-func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
- return &AuthAuthorizationRequest{
- ClientID: new(domain.ClientID),
- CodeChallenge: "",
- CodeChallengeMethod: domain.CodeChallengeMethodUnd,
- Me: new(domain.Me),
- RedirectURI: new(domain.URL),
- ResponseType: domain.ResponseTypeUnd,
- Scope: make(domain.Scopes, 0),
- State: "",
- }
-}
-
-//nolint:cyclop
-func (r *AuthAuthorizationRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#authorization-request",
- )
- }
-
- if r.ResponseType == domain.ResponseTypeID {
- r.ResponseType = domain.ResponseTypeCode
- }
-
- return nil
-}
-
-func NewAuthVerifyRequest() *AuthVerifyRequest {
- return &AuthVerifyRequest{
- Authorize: "",
- ClientID: new(domain.ClientID),
- CodeChallenge: "",
- CodeChallengeMethod: domain.CodeChallengeMethodUnd,
- Me: new(domain.Me),
- Provider: "",
- RedirectURI: new(domain.URL),
- ResponseType: domain.ResponseTypeUnd,
- Scope: make(domain.Scopes, 0),
- State: "",
- }
-}
-
-//nolint:funlen,cyclop
-func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
-
- if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#authorization-request",
- )
- }
-
- // NOTE(toby3d): backwards-compatible support.
- // See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
- if r.ResponseType == domain.ResponseTypeID {
- r.ResponseType = domain.ResponseTypeCode
- }
-
- r.Provider = strings.ToLower(r.Provider)
-
- if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "cannot validate verification request",
- "https://indieauth.net/source/#authorization-request",
- )
- }
-
- return nil
-}
-
-func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "cannot validate verification request",
- "https://indieauth.net/source/#redeeming-the-authorization-code",
- )
- }
-
- return nil
-}
diff --git a/internal/auth/delivery/http/auth_http_schema.go b/internal/auth/delivery/http/auth_http_schema.go
new file mode 100644
index 0000000..1843b3c
--- /dev/null
+++ b/internal/auth/delivery/http/auth_http_schema.go
@@ -0,0 +1,212 @@
+package http
+
+import (
+ "errors"
+ "net/http"
+ "strings"
+
+ "source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/form"
+)
+
+type (
+ AuthAuthorizationRequest struct {
+ // Indicates to the authorization server that an authorization
+ // code should be returned as the response.
+ ResponseType domain.ResponseType `form:"response_type"` // code
+
+ // The client URL.
+ ClientID domain.ClientID `form:"client_id"`
+
+ // The redirect URL indicating where the user should be
+ // redirected to after approving the request.
+ RedirectURI domain.URL `form:"redirect_uri"`
+
+ // The URL that the user entered.
+ Me domain.Me `form:"me"`
+
+ // The hashing method used to calculate the code challenge.
+ CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
+
+ // A space-separated list of scopes the client is requesting,
+ // e.g. "profile", or "profile create". If the client omits this
+ // value, the authorization server MUST NOT issue an access
+ // token for this authorization code. Only the user's profile
+ // URL may be returned without any scope requested.
+ Scope domain.Scopes `form:"scope,omitempty"`
+
+ // A parameter set by the client which will be included when the
+ // user is redirected back to the client. This is used to
+ // prevent CSRF attacks. The authorization server MUST return
+ // the unmodified state value back to the client.
+ State string `form:"state"`
+
+ // The code challenge as previously described.
+ CodeChallenge string `form:"code_challenge,omitempty"`
+ }
+
+ AuthVerifyRequest struct {
+ ClientID domain.ClientID `form:"client_id"`
+ Me domain.Me `form:"me"`
+ RedirectURI domain.URL `form:"redirect_uri"`
+ ResponseType domain.ResponseType `form:"response_type"`
+ CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
+ Scope domain.Scopes `form:"scope[],omitempty"`
+ Authorize string `form:"authorize"`
+ CodeChallenge string `form:"code_challenge,omitempty"`
+ State string `form:"state"`
+ Provider string `form:"provider"`
+ }
+
+ AuthExchangeRequest struct {
+ GrantType domain.GrantType `form:"grant_type"` // authorization_code
+
+ // The client's URL, which MUST match the client_id used in the
+ // authentication request.
+ ClientID domain.ClientID `form:"client_id"`
+
+ // The client's redirect URL, which MUST match the initial
+ // authentication request.
+ RedirectURI domain.URL `form:"redirect_uri"`
+
+ // The authorization code received from the authorization
+ // endpoint in the redirect.
+ Code string `form:"code"`
+
+ // The original plaintext random string generated before
+ // starting the authorization request.
+ CodeVerifier string `form:"code_verifier"`
+ }
+
+ AuthExchangeResponse struct {
+ Me domain.Me `json:"me"`
+ Profile *AuthProfileResponse `json:"profile,omitempty"`
+ }
+
+ AuthProfileResponse struct {
+ Email *domain.Email `json:"email,omitempty"`
+ Photo *domain.URL `json:"photo,omitempty"`
+ URL *domain.URL `json:"url,omitempty"`
+ Name string `json:"name,omitempty"`
+ }
+)
+
+func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
+ return &AuthAuthorizationRequest{
+ ClientID: domain.ClientID{},
+ CodeChallenge: "",
+ CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
+ Me: domain.Me{},
+ RedirectURI: domain.URL{},
+ ResponseType: domain.ResponseTypeUnd,
+ Scope: make(domain.Scopes, 0),
+ State: "",
+ }
+}
+
+//nolint:cyclop
+func (r *AuthAuthorizationRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if r.ResponseType == domain.ResponseTypeID {
+ r.ResponseType = domain.ResponseTypeCode
+ }
+
+ return nil
+}
+
+func NewAuthVerifyRequest() *AuthVerifyRequest {
+ return &AuthVerifyRequest{
+ Authorize: "",
+ ClientID: domain.ClientID{},
+ CodeChallenge: "",
+ CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
+ Me: domain.Me{},
+ Provider: "",
+ RedirectURI: domain.URL{},
+ ResponseType: domain.ResponseTypeUnd,
+ Scope: make(domain.Scopes, 0),
+ State: "",
+ }
+}
+
+//nolint:funlen,cyclop
+func (r *AuthVerifyRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ // NOTE(toby3d): backwards-compatible support.
+ // See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
+ if r.ResponseType == domain.ResponseTypeID {
+ r.ResponseType = domain.ResponseTypeCode
+ }
+
+ r.Provider = strings.ToLower(r.Provider)
+
+ if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ "cannot validate verification request",
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ return nil
+}
+
+func (r *AuthExchangeRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ "cannot validate verification request",
+ "https://indieauth.net/source/#redeeming-the-authorization-code",
+ )
+ }
+
+ return nil
+}
diff --git a/internal/auth/delivery/http/auth_http_test.go b/internal/auth/delivery/http/auth_http_test.go
index 7d11b48..a0dd9d2 100644
--- a/internal/auth/delivery/http/auth_http_test.go
+++ b/internal/auth/delivery/http/auth_http_test.go
@@ -1,13 +1,14 @@
package http_test
import (
- "path"
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
"strings"
- "sync"
"testing"
- "github.com/fasthttp/router"
- http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
@@ -22,7 +23,7 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
+ "source.toby3d.me/toby3d/auth/internal/user"
userrepo "source.toby3d.me/toby3d/auth/internal/user/repository/memory"
)
@@ -34,36 +35,31 @@ type Dependencies struct {
matcher language.Matcher
profiles profile.Repository
sessions session.Repository
- store *sync.Map
+ users user.Repository
}
func TestAuthorize(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
- me := domain.TestMe(t, "https://user.example.net")
+ me := domain.TestMe(t, "https://user.example.net/")
user := domain.TestUser(t)
client := domain.TestClient(t)
- deps.store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client)
- deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile)
- deps.store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user)
+ if err := deps.clients.Create(context.Background(), *client); err != nil {
+ t.Fatal(err)
+ }
- r := router.New()
- //nolint:exhaustivestruct
- delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
- Auth: deps.authService,
- Clients: deps.clientService,
- Config: deps.config,
- Matcher: deps.matcher,
- }).Register(r)
+ if err := deps.users.Create(context.Background(), *user); err != nil {
+ t.Fatal(err)
+ }
- httpClient, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
+ if err := deps.profiles.Create(context.Background(), *me, *user.Profile); err != nil {
+ t.Fatal(err)
+ }
- uri := http.AcquireURI()
- defer http.ReleaseURI(uri)
- uri.Update("https://example.com/authorize")
+ u := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
+ q := u.Query()
for key, val := range map[string]string{
"client_id": client.ID.String(),
@@ -75,26 +71,36 @@ func TestAuthorize(t *testing.T) {
"scope": "profile email",
"state": "1234567890",
} {
- uri.QueryArgs().Set(key, val)
+ q.Set(key, val)
}
- req := httptest.NewRequest(http.MethodGet, uri.String(), nil)
- defer http.ReleaseRequest(req)
+ u.RawQuery = q.Encode()
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
+ req := httptest.NewRequest(http.MethodGet, u.String(), nil)
+ w := httptest.NewRecorder()
- if err := httpClient.Do(req, resp); err != nil {
+ //nolint:exhaustivestruct
+ delivery.NewHandler(delivery.NewHandlerOptions{
+ Auth: deps.authService,
+ Clients: deps.clientService,
+ Config: *deps.config,
+ Matcher: deps.matcher,
+ }).Handler().ServeHTTP(w, req)
+
+ resp := w.Result()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
t.Fatal(err)
}
- if resp.StatusCode() != http.StatusOK {
- t.Errorf("GET %s = %d, want %d", uri.String(), resp.StatusCode(), http.StatusOK)
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("%s %s = %d, want %d", req.Method, u.String(), resp.StatusCode, http.StatusOK)
}
const expResult = `Authorize application`
- if result := string(resp.Body()); !strings.Contains(result, expResult) {
- t.Errorf("GET %s = %s, want %s", uri.String(), result, expResult)
+ if result := string(body); !strings.Contains(result, expResult) {
+ t.Errorf("%s %s = %s, want %s", req.Method, u.String(), result, expResult)
}
}
@@ -103,14 +109,15 @@ func NewDependencies(tb testing.TB) Dependencies {
config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
- store := new(sync.Map)
- clients := clientrepo.NewMemoryClientRepository(store)
- sessions := sessionrepo.NewMemorySessionRepository(store, config)
- profiles := profilerepo.NewMemoryProfileRepository(store)
+ clients := clientrepo.NewMemoryClientRepository()
+ users := userrepo.NewMemoryUserRepository()
+ sessions := sessionrepo.NewMemorySessionRepository(*config)
+ profiles := profilerepo.NewMemoryProfileRepository()
authService := ucase.NewAuthUseCase(sessions, profiles, config)
clientService := clientucase.NewClientUseCase(clients)
return Dependencies{
+ users: users,
authService: authService,
clients: clients,
clientService: clientService,
@@ -118,6 +125,5 @@ func NewDependencies(tb testing.TB) Dependencies {
matcher: matcher,
sessions: sessions,
profiles: profiles,
- store: store,
}
}
diff --git a/internal/auth/usecase.go b/internal/auth/usecase.go
index cf92b73..e6c9155 100644
--- a/internal/auth/usecase.go
+++ b/internal/auth/usecase.go
@@ -9,8 +9,8 @@ import (
type (
GenerateOptions struct {
- ClientID *domain.ClientID
- Me *domain.Me
+ ClientID domain.ClientID
+ Me domain.Me
RedirectURI *url.URL
CodeChallengeMethod domain.CodeChallengeMethod
Scope domain.Scopes
@@ -18,7 +18,7 @@ type (
}
ExchangeOptions struct {
- ClientID *domain.ClientID
+ ClientID domain.ClientID
RedirectURI *url.URL
Code string
CodeVerifier string
diff --git a/internal/auth/usecase/auth_ucase.go b/internal/auth/usecase/auth_ucase.go
index fe6610b..0dd4b83 100644
--- a/internal/auth/usecase/auth_ucase.go
+++ b/internal/auth/usecase/auth_ucase.go
@@ -45,7 +45,7 @@ func (uc *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions)
}
}
- if err = uc.sessions.Create(ctx, &domain.Session{
+ if err = uc.sessions.Create(ctx, domain.Session{
ClientID: opts.ClientID,
Code: code,
CodeChallenge: opts.CodeChallenge,
@@ -81,5 +81,5 @@ func (uc *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions)
return nil, nil, auth.ErrMismatchPKCE
}
- return session.Me, session.Profile, nil
+ return &session.Me, session.Profile, nil
}
diff --git a/internal/client/delivery/http/client_http.go b/internal/client/delivery/http/client_http.go
index e553d85..f325dbf 100644
--- a/internal/client/delivery/http/client_http.go
+++ b/internal/client/delivery/http/client_http.go
@@ -1,48 +1,37 @@
package http
import (
- "errors"
+ "net/http"
"strings"
- "github.com/fasthttp/router"
- http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/token"
+ "source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web"
- "source.toby3d.me/toby3d/form"
- "source.toby3d.me/toby3d/middleware"
)
type (
- ClientCallbackRequest struct {
- Error domain.ErrorCode `form:"error,omitempty"`
- Iss *domain.ClientID `form:"iss"`
- Code string `form:"code"`
- ErrorDescription string `form:"error_description,omitempty"`
- State string `form:"state"`
- }
-
- NewRequestHandlerOptions struct {
+ NewHandlerOptions struct {
Matcher language.Matcher
Tokens token.UseCase
- Client *domain.Client
- Config *domain.Config
+ Client domain.Client
+ Config domain.Config
}
- RequestHandler struct {
+ Handler struct {
matcher language.Matcher
tokens token.UseCase
- client *domain.Client
- config *domain.Config
+ client domain.Client
+ config domain.Config
}
)
-func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
- return &RequestHandler{
+func NewHandler(opts NewHandlerOptions) *Handler {
+ return &Handler{
client: opts.Client,
config: opts.Config,
matcher: opts.Matcher,
@@ -50,59 +39,82 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
}
}
-func (h *RequestHandler) Register(r *router.Router) {
- chain := middleware.Chain{
- middleware.LogFmt(),
- }
+func (h *Handler) Handler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "" && r.Method != http.MethodGet {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- r.GET("/", chain.RequestHandler(h.handleRender))
- r.GET("/callback", chain.RequestHandler(h.handleCallback))
+ return
+ }
+
+ var head string
+ head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
+
+ switch head {
+ default:
+ http.NotFound(w, r)
+ case "":
+ h.handleRender(w, r)
+ case "callback":
+ h.handleCallback(w, r)
+ }
+ })
}
-func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
+func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "" && r.Method != http.MethodGet {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
redirect := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI {
redirect[i] = h.client.RedirectURI[i].String()
}
- ctx.Response.Header.Set(
- http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
- )
+ w.Header().Set(common.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`)
- tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
+ tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
// TODO(toby3d): generate and store PKCE
- ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
- web.WriteTemplate(ctx, &web.HomePage{
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ web.WriteTemplate(w, &web.HomePage{
BaseOf: web.BaseOf{
- Config: h.config,
+ Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
- Client: h.client,
+ Client: &h.client,
State: "hackme", // TODO(toby3d): generate and store state
})
}
//nolint:unlen
-func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
+func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "" && r.Method != http.MethodGet {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+
+ tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{
- Config: h.config,
+ Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
}
req := new(ClientCallbackRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusInternalServerError)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@@ -111,8 +123,8 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
}
if req.Error != domain.ErrorCodeUnd {
- ctx.SetStatusCode(http.StatusUnauthorized)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ w.WriteHeader(http.StatusUnauthorized)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeAccessDenied,
@@ -127,9 +139,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
// TODO(toby3d): load and check state
- if req.Iss == nil || req.Iss.String() != h.client.ID.String() {
- ctx.SetStatusCode(http.StatusBadRequest)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ if req.Iss.String() != h.client.ID.String() {
+ w.WriteHeader(http.StatusBadRequest)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeInvalidClient,
@@ -142,15 +154,15 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return
}
- token, _, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
+ token, _, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{
ClientID: h.client.ID,
RedirectURI: h.client.RedirectURI[0],
Code: req.Code,
CodeVerifier: "", // TODO(toby3d): validate PKCE here
})
if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
- web.WriteTemplate(ctx, &web.ErrorPage{
+ w.WriteHeader(http.StatusBadRequest)
+ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@@ -158,23 +170,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return
}
- ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
- web.WriteTemplate(ctx, &web.CallbackPage{
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ web.WriteTemplate(w, &web.CallbackPage{
BaseOf: baseOf,
Token: token,
})
}
-
-func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
-
- if err := form.Unmarshal(ctx.QueryArgs().QueryString(), req); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
- }
-
- return nil
-}
diff --git a/internal/client/delivery/http/client_http_schema.go b/internal/client/delivery/http/client_http_schema.go
new file mode 100644
index 0000000..ccb5675
--- /dev/null
+++ b/internal/client/delivery/http/client_http_schema.go
@@ -0,0 +1,31 @@
+package http
+
+import (
+ "errors"
+ "net/http"
+
+ "source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/form"
+)
+
+type ClientCallbackRequest struct {
+ Error domain.ErrorCode `form:"error,omitempty"`
+ Iss domain.ClientID `form:"iss"`
+ Code string `form:"code"`
+ ErrorDescription string `form:"error_description,omitempty"`
+ State string `form:"state"`
+}
+
+func (req *ClientCallbackRequest) bind(r *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := form.Unmarshal([]byte(r.URL.Query().Encode()), req); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
+ }
+
+ return nil
+}
diff --git a/internal/client/delivery/http/client_http_test.go b/internal/client/delivery/http/client_http_test.go
index b46017b..eab41f8 100644
--- a/internal/client/delivery/http/client_http_test.go
+++ b/internal/client/delivery/http/client_http_test.go
@@ -1,11 +1,10 @@
package http_test
import (
- "sync"
+ "net/http"
+ "net/http/httptest"
"testing"
- "github.com/fasthttp/router"
- http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
@@ -15,7 +14,6 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/token"
tokenrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase"
@@ -27,7 +25,6 @@ type Dependencies struct {
config *domain.Config
matcher language.Matcher
sessions session.Repository
- store *sync.Map
tokens token.Repository
tokenService token.UseCase
}
@@ -36,45 +33,30 @@ func TestRead(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
+ req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)
+ w := httptest.NewRecorder()
- r := router.New()
- delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
- Client: deps.client,
- Config: deps.config,
+ delivery.NewHandler(delivery.NewHandlerOptions{
+ Client: *deps.client,
+ Config: *deps.config,
Matcher: deps.matcher,
Tokens: deps.tokenService,
- }).Register(r)
+ }).Handler().ServeHTTP(w, req)
- client, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
-
- const requestURI string = "https://app.example.com/"
- req, resp := httptest.NewRequest(http.MethodGet, requestURI, nil), http.AcquireResponse()
-
- t.Cleanup(func() {
- http.ReleaseRequest(req)
- http.ReleaseResponse(resp)
- })
-
- if err := client.Do(req, resp); err != nil {
- t.Error(err)
- }
-
- if resp.StatusCode() != http.StatusOK {
- t.Errorf("GET %s = %d, want %d", requestURI, resp.StatusCode(), http.StatusOK)
+ if resp := w.Result(); resp.StatusCode != http.StatusOK {
+ t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
}
}
func NewDependencies(tb testing.TB) Dependencies {
tb.Helper()
- store := new(sync.Map)
client := domain.TestClient(tb)
config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
- sessions := sessionrepo.NewMemorySessionRepository(store, config)
- tokens := tokenrepo.NewMemoryTokenRepository(store)
- profiles := profilerepo.NewMemoryProfileRepository(store)
+ sessions := sessionrepo.NewMemorySessionRepository(*config)
+ tokens := tokenrepo.NewMemoryTokenRepository()
+ profiles := profilerepo.NewMemoryProfileRepository()
tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config,
Profiles: profiles,
@@ -87,7 +69,6 @@ func NewDependencies(tb testing.TB) Dependencies {
config: config,
matcher: matcher,
sessions: sessions,
- store: store,
profiles: profiles,
tokens: tokens,
tokenService: tokenService,
diff --git a/internal/client/repository.go b/internal/client/repository.go
index 6aef5ad..0389498 100644
--- a/internal/client/repository.go
+++ b/internal/client/repository.go
@@ -7,7 +7,8 @@ import (
)
type Repository interface {
- Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
+ Create(ctx context.Context, client domain.Client) error
+ Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error)
}
var ErrNotExist error = domain.NewError(
diff --git a/internal/client/repository/http/http_client.go b/internal/client/repository/http/http_client.go
index 81e6a21..ff10730 100644
--- a/internal/client/repository/http/http_client.go
+++ b/internal/client/repository/http/http_client.go
@@ -1,14 +1,17 @@
package http
import (
+ "bytes"
"context"
"fmt"
- "net"
+ "io"
+ "net/http"
"net/url"
- http "github.com/valyala/fasthttp"
+ "golang.org/x/exp/slices"
"source.toby3d.me/toby3d/auth/internal/client"
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil"
)
@@ -34,33 +37,18 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
}
}
-func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
- ips, err := net.LookupIP(cid.URL().Hostname())
+// WARN(toby3d): not implemented.
+func (httpClientRepository) Create(_ context.Context, _ domain.Client) error {
+ return nil
+}
+
+func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
+ resp, err := repo.client.Get(cid.String())
if err != nil {
- return nil, fmt.Errorf("cannot resolve client IP by id: %w", err)
- }
-
- for _, ip := range ips {
- if !ip.IsLoopback() {
- continue
- }
-
- return nil, client.ErrNotExist
- }
-
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.SetRequestURI(cid.String())
- req.Header.SetMethod(http.MethodGet)
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
}
- if resp.StatusCode() == http.StatusNotFound {
+ if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist)
}
@@ -72,74 +60,62 @@ func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID)
Name: make([]string, 0),
}
- extract(client, resp)
+ extract(resp.Body, resp.Request.URL, client, resp.Header.Get(common.HeaderLink))
return client, nil
}
//nolint:gocognit,cyclop
-func extract(dst *domain.Client, src *http.Response) {
- for _, endpoint := range httputil.ExtractEndpoints(src, relRedirectURI) {
- if !containsURL(dst.RedirectURI, endpoint) {
+func extract(r io.Reader, u *url.URL, dst *domain.Client, header string) {
+ body, _ := io.ReadAll(r)
+
+ for _, endpoint := range httputil.ExtractEndpoints(bytes.NewReader(body), u, header, relRedirectURI) {
+ if !containsUrl(dst.RedirectURI, endpoint) {
dst.RedirectURI = append(dst.RedirectURI, endpoint)
}
}
- for _, itemType := range []string{hXApp, hApp} {
- for _, name := range httputil.ExtractProperty(src, itemType, propertyName) {
- if n, ok := name.(string); ok && !containsString(dst.Name, n) {
+ for _, itemType := range []string{hApp, hXApp} {
+ for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyName) {
+ if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) {
dst.Name = append(dst.Name, n)
}
}
- for _, logo := range httputil.ExtractProperty(src, itemType, propertyLogo) {
- var (
- u *url.URL
- err error
- )
+ for _, logo := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyLogo) {
+ var logoURL *url.URL
+ var err error
switch l := logo.(type) {
case string:
- u, err = url.Parse(l)
+ logoURL, err = url.Parse(l)
case map[string]string:
if value, ok := l["value"]; ok {
- u, err = url.Parse(value)
+ logoURL, err = url.Parse(value)
}
}
- if err != nil || containsURL(dst.Logo, u) {
+ if err != nil || containsUrl(dst.Logo, logoURL) {
continue
}
- dst.Logo = append(dst.Logo, u)
+ dst.Logo = append(dst.Logo, logoURL)
}
- for _, property := range httputil.ExtractProperty(src, itemType, propertyURL) {
+ for _, property := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyURL) {
prop, ok := property.(string)
if !ok {
continue
}
- if u, err := url.Parse(prop); err == nil || !containsURL(dst.URL, u) {
+ if u, err := url.Parse(prop); err == nil && !containsUrl(dst.URL, u) {
dst.URL = append(dst.URL, u)
}
}
}
}
-func containsString(src []string, find string) bool {
- for i := range src {
- if src[i] != find {
- continue
- }
-
- return true
- }
-
- return false
-}
-
-func containsURL(src []*url.URL, find *url.URL) bool {
+func containsUrl(src []*url.URL, find *url.URL) bool {
for i := range src {
if src[i].String() != find.String() {
continue
diff --git a/internal/client/repository/http/http_client_test.go b/internal/client/repository/http/http_client_test.go
index 5594d7e..c5eb739 100644
--- a/internal/client/repository/http/http_client_test.go
+++ b/internal/client/repository/http/http_client_test.go
@@ -3,22 +3,21 @@ package http_test
import (
"context"
"fmt"
+ "net/http"
+ "net/http/httptest"
"testing"
- "github.com/stretchr/testify/assert"
- http "github.com/valyala/fasthttp"
+ "github.com/google/go-cmp/cmp"
repository "source.toby3d.me/toby3d/auth/internal/client/repository/http"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
)
-const testBody string = `
-
+const testBody string = `
-
+
%[1]s
@@ -36,38 +35,47 @@ func TestGet(t *testing.T) {
t.Parallel()
client := domain.TestClient(t)
- httpClient, _, cleanup := httptest.New(t, testHandler(t, client))
- t.Cleanup(cleanup)
+ srv := httptest.NewUnstartedServer(testHandler(t, *client))
+ srv.EnableHTTP2 = true
- result, err := repository.NewHTTPClientRepository(httpClient).Get(context.Background(), client.ID)
+ srv.StartTLS()
+ t.Cleanup(srv.Close)
+
+ client.ID = *domain.TestClientID(t, srv.URL+"/")
+ clients := repository.NewHTTPClientRepository(srv.Client())
+
+ result, err := clients.Get(context.Background(), client.ID)
if err != nil {
t.Fatal(err)
}
- assert.Equal(t, client.Name, result.Name)
- assert.Equal(t, client.ID.String(), result.ID.String())
-
- for i := range client.URL {
- assert.Equal(t, client.URL[i].String(), result.URL[i].String())
+ if out := client.ID; !result.ID.IsEqual(out) {
+ t.Errorf("GET %s = %s, want %s", client.ID, out, result.ID)
}
- for i := range client.Logo {
- assert.Equal(t, client.Logo[i].String(), result.Logo[i].String())
+ if !cmp.Equal(result.Name, client.Name) {
+ t.Errorf("GET %s = %+s, want %+s", client.ID, result.Name, client.Name)
}
- for i := range client.RedirectURI {
- assert.Equal(t, client.RedirectURI[i].String(), result.RedirectURI[i].String())
+ if !cmp.Equal(result.URL, client.URL) {
+ t.Errorf("GET %s = %+s, want %+s", client.ID, result.URL, client.URL)
+ }
+
+ if !cmp.Equal(result.Logo, client.Logo) {
+ t.Errorf("GET %s = %+s, want %+s", client.ID, result.Logo, client.Logo)
+ }
+
+ if !cmp.Equal(result.RedirectURI, client.RedirectURI) {
+ t.Errorf("GET %s = %+s, want %+s", client.ID, result.RedirectURI, client.RedirectURI)
}
}
-func testHandler(tb testing.TB, client *domain.Client) http.RequestHandler {
+func testHandler(tb testing.TB, client domain.Client) http.Handler {
tb.Helper()
- return func(ctx *http.RequestCtx) {
- ctx.Response.Header.Set(http.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`)
- ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf(
- testBody, client.Name[0], client.URL[0].String(), client.Logo[0].String(),
- client.RedirectURI[1].String(),
- ))
- }
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ w.Header().Set(common.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`)
+ fmt.Fprintf(w, testBody, client.Name[0], client.URL[0], client.Logo[0], client.RedirectURI[1])
+ })
}
diff --git a/internal/client/repository/memory/memory_client.go b/internal/client/repository/memory/memory_client.go
index e4d433f..cfdd04e 100644
--- a/internal/client/repository/memory/memory_client.go
+++ b/internal/client/repository/memory/memory_client.go
@@ -2,9 +2,6 @@ package memory
import (
"context"
- "fmt"
- "net"
- "path"
"sync"
"source.toby3d.me/toby3d/auth/internal/client"
@@ -12,45 +9,33 @@ import (
)
type memoryClientRepository struct {
- store *sync.Map
+ mutex *sync.RWMutex
+ clients map[string]domain.Client
}
-const DefaultPathPrefix string = "clients"
-
-func NewMemoryClientRepository(store *sync.Map) client.Repository {
+func NewMemoryClientRepository() client.Repository {
return &memoryClientRepository{
- store: store,
+ mutex: new(sync.RWMutex),
+ clients: make(map[string]domain.Client),
}
}
-func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.Client) error {
- repo.store.Store(path.Join(DefaultPathPrefix, client.ID.String()), client)
+func (repo memoryClientRepository) Create(ctx context.Context, client domain.Client) error {
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
+
+ repo.clients[client.ID.String()] = client
return nil
}
-func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
- // WARN(toby3d): more often than not, we will work from tests with
- // non-existent clients, almost guaranteed to cause a resolution error.
- ips, _ := net.LookupIP(id.URL().Hostname())
+func (repo memoryClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
- for _, ip := range ips {
- if !ip.IsLoopback() {
- continue
- }
-
- return nil, client.ErrNotExist
+ if c, ok := repo.clients[cid.String()]; ok {
+ return &c, nil
}
- src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String()))
- if !ok {
- return nil, fmt.Errorf("cannot find client in store: %w", client.ErrNotExist)
- }
-
- c, ok := src.(*domain.Client)
- if !ok {
- return nil, fmt.Errorf("cannot decode client from store: %w", client.ErrNotExist)
- }
-
- return c, nil
+ return nil, client.ErrNotExist
}
diff --git a/internal/client/repository/memory/memory_client_test.go b/internal/client/repository/memory/memory_client_test.go
deleted file mode 100644
index 698904e..0000000
--- a/internal/client/repository/memory/memory_client_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package memory_test
-
-import (
- "context"
- "path"
- "reflect"
- "sync"
- "testing"
-
- repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory"
- "source.toby3d.me/toby3d/auth/internal/domain"
-)
-
-func TestGet(t *testing.T) {
- t.Parallel()
-
- client := domain.TestClient(t)
-
- store := new(sync.Map)
- store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
-
- result, err := repository.NewMemoryClientRepository(store).
- Get(context.Background(), client.ID)
- if err != nil {
- t.Fatal(err)
- }
-
- if !reflect.DeepEqual(result, client) {
- t.Errorf("Get(%s) = %+v, want %+v", client.ID, result, client)
- }
-}
diff --git a/internal/client/usecase.go b/internal/client/usecase.go
index ca3f655..7a2a706 100644
--- a/internal/client/usecase.go
+++ b/internal/client/usecase.go
@@ -8,7 +8,7 @@ import (
type UseCase interface {
// Discovery returns client public information bu ClientID URL.
- Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
+ Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error)
}
var ErrInvalidMe error = domain.NewError(
diff --git a/internal/client/usecase/client_ucase.go b/internal/client/usecase/client_ucase.go
index 863b056..0a4b47e 100644
--- a/internal/client/usecase/client_ucase.go
+++ b/internal/client/usecase/client_ucase.go
@@ -18,7 +18,7 @@ func NewClientUseCase(repo client.Repository) client.UseCase {
}
}
-func (useCase *clientUseCase) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
+func (useCase *clientUseCase) Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error) {
c, err := useCase.repo.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("cannot discovery client by id: %w", err)
diff --git a/internal/client/usecase/client_ucase_test.go b/internal/client/usecase/client_ucase_test.go
index ae263f7..6674b8d 100644
--- a/internal/client/usecase/client_ucase_test.go
+++ b/internal/client/usecase/client_ucase_test.go
@@ -3,12 +3,9 @@ package usecase_test
import (
"context"
"errors"
- "path"
"reflect"
- "sync"
"testing"
- "source.toby3d.me/toby3d/auth/internal/client"
repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory"
"source.toby3d.me/toby3d/auth/internal/client/usecase"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -17,12 +14,11 @@ import (
func TestDiscovery(t *testing.T) {
t.Parallel()
- store := new(sync.Map)
- testClient, localhostClient := domain.TestClient(t), domain.TestClient(t)
- localhostClient.ID, _ = domain.ParseClientID("http://localhost/")
+ testClient := domain.TestClient(t)
+ clients := repository.NewMemoryClientRepository()
- for _, client := range []*domain.Client{testClient, localhostClient} {
- store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
+ if err := clients.Create(context.Background(), *testClient); err != nil {
+ t.Fatal(err)
}
for _, tc := range []struct {
@@ -34,17 +30,13 @@ func TestDiscovery(t *testing.T) {
name: "default",
in: testClient,
out: testClient,
- }, {
- name: "localhost",
- in: localhostClient,
- expError: client.ErrNotExist,
}} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
- result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)).
+ result, err := usecase.NewClientUseCase(clients).
Discovery(context.Background(), tc.in.ID)
if tc.expError != nil && !errors.Is(err, tc.expError) {
t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError)
diff --git a/internal/domain/client.go b/internal/domain/client.go
index a2fb8d1..44131a3 100644
--- a/internal/domain/client.go
+++ b/internal/domain/client.go
@@ -9,7 +9,7 @@ import (
// Client describes the client requesting data about the user.
type Client struct {
- ID *ClientID
+ ID ClientID
Logo []*url.URL
RedirectURI []*url.URL
URL []*url.URL
@@ -17,7 +17,7 @@ type Client struct {
}
// NewClient creates a new empty Client with provided ClientID, if any.
-func NewClient(cid *ClientID) *Client {
+func NewClient(cid ClientID) *Client {
return &Client{
ID: cid,
Logo: make([]*url.URL, 0),
@@ -32,7 +32,7 @@ func TestClient(tb testing.TB) *Client {
tb.Helper()
return &Client{
- ID: TestClientID(tb),
+ ID: *TestClientID(tb),
Name: []string{"Example App"},
URL: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/"}},
Logo: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/logo.png"}},
diff --git a/internal/domain/client_id.go b/internal/domain/client_id.go
index f972db5..206d129 100644
--- a/internal/domain/client_id.go
+++ b/internal/domain/client_id.go
@@ -8,6 +8,8 @@ import (
"testing"
"inet.af/netaddr"
+
+ "source.toby3d.me/toby3d/auth/internal/common"
)
// ClientID is a URL client identifier.
@@ -37,16 +39,20 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.Scheme != "http" && cid.Scheme != "https" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "client identifier URL MUST have either an https or http scheme",
+ "client identifier URL MUST have either an https or http scheme, got '"+cid.Scheme+"'",
"https://indieauth.net/source/#client-identifier",
)
}
- if cid.Path == "" || strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") {
+ if cid.Path == "" {
+ cid.Path = "/"
+ }
+
+ if strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") {
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST contain a path component and MUST NOT contain "+
- "single-dot or double-dot path segments",
+ "single-dot or double-dot path segments, got '"+cid.Path+"'",
"https://indieauth.net/source/#client-identifier",
)
}
@@ -54,7 +60,7 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.Fragment != "" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "client identifier URL MUST NOT contain a fragment component",
+ "client identifier URL MUST NOT contain a fragment component, got '"+cid.Fragment+"'",
"https://indieauth.net/source/#client-identifier",
)
}
@@ -62,7 +68,8 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.User != nil {
return nil, NewError(
ErrorCodeInvalidRequest,
- "client identifier URL MUST NOT contain a username or password component",
+ "client identifier URL MUST NOT contain a username or password component, got '"+
+ cid.User.String()+"'",
"https://indieauth.net/source/#client-identifier",
)
}
@@ -71,7 +78,7 @@ func ParseClientID(src string) (*ClientID, error) {
if domain == "" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "client host name MUST be domain name or a loopback interface",
+ "client host name MUST be domain name or a loopback interface, got '"+domain+"'",
"https://indieauth.net/source/#client-identifier",
)
}
@@ -102,10 +109,15 @@ func ParseClientID(src string) (*ClientID, error) {
}
// TestClientID returns valid random generated ClientID for tests.
-func TestClientID(tb testing.TB) *ClientID {
+func TestClientID(tb testing.TB, forceURL ...string) *ClientID {
tb.Helper()
- clientID, err := ParseClientID("https://example.com/")
+ in := "https://app.example.com/"
+ if len(forceURL) > 0 {
+ in = forceURL[0]
+ }
+
+ clientID, err := ParseClientID(in)
if err != nil {
tb.Fatal(err)
}
@@ -147,6 +159,11 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(cid.String())), nil
}
+// IsEqual checks what cid is equal to provided v.
+func (cid ClientID) IsEqual(v ClientID) bool {
+ return cid.clientID.String() == v.clientID.String()
+}
+
// URL returns url.URL representation of client ID.
func (cid ClientID) URL() *url.URL {
out, _ := url.Parse(cid.clientID.String())
@@ -156,5 +173,17 @@ func (cid ClientID) URL() *url.URL {
// String returns string representation of client ID.
func (cid ClientID) String() string {
+ if cid.clientID == nil {
+ return ""
+ }
+
return cid.clientID.String()
}
+
+func (cid ClientID) GoString() string {
+ if cid.clientID == nil {
+ return "domain.ClientID(" + common.Und + ")"
+ }
+
+ return "domain.ClientID(" + cid.clientID.String() + ")"
+}
diff --git a/internal/domain/code_challenge_method_test.go b/internal/domain/code_challenge_method_test.go
index 71d6dfa..e196766 100644
--- a/internal/domain/code_challenge_method_test.go
+++ b/internal/domain/code_challenge_method_test.go
@@ -114,7 +114,7 @@ func TestCodeChallengeMethod_String(t *testing.T) {
func TestCodeChallengeMethod_Validate(t *testing.T) {
t.Parallel()
- verifier, err := random.String(gofakeit.Number(43, 128))
+ verifier, err := random.String(uint8(gofakeit.Number(43, 128)))
if err != nil {
t.Fatalf("%+v", err)
}
diff --git a/internal/domain/config.go b/internal/domain/config.go
index 8ff61d9..492b1ad 100644
--- a/internal/domain/config.go
+++ b/internal/domain/config.go
@@ -29,7 +29,6 @@ type (
Port string `yaml:"port"`
Protocol string `yaml:"protocol"`
RootURL string `yaml:"rootUrl"`
- StaticRootPath string `yaml:"staticRootPath"`
StaticURLPrefix string `yaml:"staticUrlPrefix"`
EnablePprof bool `yaml:"enablePprof"`
}
@@ -44,14 +43,14 @@ type (
// exchange it for a token or user information.
ConfigCode struct {
Expiry time.Duration `yaml:"expiry"` // 10m
- Length int `yaml:"length"` // 32
+ Length uint8 `yaml:"length"` // 32
}
ConfigJWT struct {
Expiry time.Duration `yaml:"expiry"` // 1h
Algorithm string `yaml:"algorithm"` // HS256
Secret string `yaml:"secret"`
- NonceLength int `yaml:"nonceLength"` // 22
+ NonceLength uint8 `yaml:"nonceLength"` // 22
}
ConfigIndieAuth struct {
@@ -62,7 +61,7 @@ type (
ConfigTicketAuth struct {
Expiry time.Duration `yaml:"expiry"` // 1m
- Length int `yaml:"length"` // 24
+ Length uint8 `yaml:"length"` // 24
}
ConfigRelMeAuth struct {
@@ -95,7 +94,6 @@ func TestConfig(tb testing.TB) *Config {
Port: "3000",
Protocol: "http",
RootURL: "{{protocol}}://{{domain}}:{{port}}/",
- StaticRootPath: "/",
StaticURLPrefix: "/static",
},
Database: ConfigDatabase{
@@ -136,7 +134,6 @@ func (cs ConfigServer) GetRootURL() string {
"host": cs.Host,
"port": cs.Port,
"protocol": cs.Protocol,
- "staticRootPath": cs.StaticRootPath,
"staticUrlPrefix": cs.StaticURLPrefix,
})
}
diff --git a/internal/domain/me.go b/internal/domain/me.go
index 9a6e6a3..cb6f146 100644
--- a/internal/domain/me.go
+++ b/internal/domain/me.go
@@ -31,7 +31,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Scheme != "http" && id.Scheme != "https" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile URL MUST have either an https or http scheme",
+ "profile URL MUST have either an https or http scheme, got '"+id.Scheme+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -45,7 +45,7 @@ func ParseMe(raw string) (*Me, error) {
return nil, NewError(
ErrorCodeInvalidRequest,
"profile URL MUST contain a path component (/ is a valid path), MUST NOT contain single-dot "+
- "or double-dot path segments",
+ "or double-dot path segments, got '"+id.Path+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -54,7 +54,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Fragment != "" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile URL MUST NOT contain a fragment component",
+ "profile URL MUST NOT contain a fragment component, got '"+id.Fragment+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -63,7 +63,7 @@ func ParseMe(raw string) (*Me, error) {
if id.User != nil {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile URL MUST NOT contain a username or password component",
+ "profile URL MUST NOT contain a username or password component, got '"+id.User.String()+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -72,7 +72,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Host == "" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile host name MUST be a domain name",
+ "profile host name MUST be a domain name, got '"+id.Host+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -81,16 +81,16 @@ func ParseMe(raw string) (*Me, error) {
if _, port, _ := net.SplitHostPort(id.Host); port != "" {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile MUST NOT contain a port",
+ "profile MUST NOT contain a port, got '"+port+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
}
- if net.ParseIP(id.Host) != nil {
+ if out := net.ParseIP(id.Host); out != nil {
return nil, NewError(
ErrorCodeInvalidRequest,
- "profile MUST NOT be ipv4 or ipv6 addresses",
+ "profile MUST NOT be ipv4 or ipv6 addresses, got '"+out.String()+"'",
"https://indieauth.net/source/#user-profile-url",
"",
)
@@ -103,12 +103,12 @@ func ParseMe(raw string) (*Me, error) {
func TestMe(tb testing.TB, src string) *Me {
tb.Helper()
- me, err := ParseMe(src)
+ u, err := url.Parse(src)
if err != nil {
tb.Fatal(err)
}
- return me
+ return &Me{id: u}
}
// UnmarshalForm implements custom unmarshler for form values.
diff --git a/internal/domain/metadata.go b/internal/domain/metadata.go
index f078453..e1fac9f 100644
--- a/internal/domain/metadata.go
+++ b/internal/domain/metadata.go
@@ -14,7 +14,7 @@ type Metadata struct {
// issuer URL could be https://example.com/, or for a metadata URL of
// https://example.com/wp-json/indieauth/1.0/metadata, the issuer URL
// could be https://example.com/wp-json/indieauth/1.0
- Issuer *ClientID
+ Issuer *url.URL
// The Authorization Endpoint.
AuthorizationEndpoint *url.URL
@@ -81,7 +81,11 @@ func TestMetadata(tb testing.TB) *Metadata {
tb.Helper()
return &Metadata{
- Issuer: TestClientID(tb),
+ Issuer: &url.URL{
+ Scheme: "https",
+ Host: "example.com",
+ Path: "/.well-known/oauth-authorization-server",
+ },
AuthorizationEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/auth"},
TokenEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/token"},
TicketEndpoint: &url.URL{Scheme: "https", Host: "auth.example.org", Path: "/ticket"},
diff --git a/internal/domain/provider.go b/internal/domain/provider.go
index 377bac1..bb37d9b 100644
--- a/internal/domain/provider.go
+++ b/internal/domain/provider.go
@@ -1,10 +1,9 @@
package domain
import (
+ "net/url"
"path"
"strings"
-
- http "github.com/valyala/fasthttp"
)
// Provider represent 3rd party RelMeAuth provider.
@@ -91,9 +90,10 @@ var (
// AuthCodeURL returns URL for authorize user in RelMeAuth client.
func (p Provider) AuthCodeURL(state string) string {
- uri := http.AcquireURI()
- defer http.ReleaseURI(uri)
- uri.Update(p.AuthURL)
+ u, err := url.Parse(p.AuthURL)
+ if err != nil {
+ return ""
+ }
for key, val := range map[string]string{
"client_id": p.ClientID,
@@ -102,8 +102,8 @@ func (p Provider) AuthCodeURL(state string) string {
"scope": strings.Join(p.Scopes, " "),
"state": state,
} {
- uri.QueryArgs().Set(key, val)
+ u.Query().Set(key, val)
}
- return uri.String()
+ return u.String()
}
diff --git a/internal/domain/scope.go b/internal/domain/scope.go
index 7367b32..126a52d 100644
--- a/internal/domain/scope.go
+++ b/internal/domain/scope.go
@@ -80,6 +80,22 @@ func ParseScope(uid string) (Scope, error) {
return ScopeUnd, fmt.Errorf("%w: %s", ErrScopeUnknown, uid)
}
+func (s *Scope) UnmarshalJSON(v []byte) error {
+ src, err := strconv.Unquote(string(v))
+ if err != nil {
+ return fmt.Errorf("Scope: UnmarshalJSON: cannot unquote string: %w", err)
+ }
+
+ out, err := ParseScope(src)
+ if err != nil {
+ return fmt.Errorf("Scopes: UnmarshalJSON: cannot parse scope: %w", err)
+ }
+
+ *s = out
+
+ return nil
+}
+
func (s Scope) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(s.uid)), nil
}
diff --git a/internal/domain/session.go b/internal/domain/session.go
index bda7542..a8f7e37 100644
--- a/internal/domain/session.go
+++ b/internal/domain/session.go
@@ -9,9 +9,9 @@ import (
//nolint:tagliatelle
type Session struct {
- ClientID *ClientID `json:"client_id"`
+ ClientID ClientID `json:"client_id"`
RedirectURI *url.URL `json:"redirect_uri"`
- Me *Me `json:"me"`
+ Me Me `json:"me"`
Profile *Profile `json:"profile,omitempty"`
Scope Scopes `json:"scope"`
CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method,omitempty"`
@@ -31,12 +31,12 @@ func TestSession(tb testing.TB) *Session {
}
return &Session{
- ClientID: TestClientID(tb),
+ ClientID: *TestClientID(tb),
Code: code,
CodeChallenge: "hackme",
CodeChallengeMethod: CodeChallengeMethodPLAIN,
Profile: TestProfile(tb),
- Me: TestMe(tb, "https://user.example.net/"),
+ Me: *TestMe(tb, "https://user.example.net/"),
RedirectURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/callback"},
Scope: Scopes{
ScopeEmail,
diff --git a/internal/domain/token.go b/internal/domain/token.go
index 8acd927..fcb1a5c 100644
--- a/internal/domain/token.go
+++ b/internal/domain/token.go
@@ -2,13 +2,14 @@ package domain
import (
"fmt"
+ "net/http"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
- http "github.com/valyala/fasthttp"
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/random"
)
@@ -17,8 +18,8 @@ type (
Token struct {
CreatedAt time.Time
Expiry time.Time
- ClientID *ClientID
- Me *Me
+ ClientID ClientID
+ Me Me
Scope Scopes
AccessToken string
RefreshToken string
@@ -27,12 +28,12 @@ type (
// NewTokenOptions contains options for NewToken function.
NewTokenOptions struct {
Expiration time.Duration
- Issuer *ClientID
- Subject *Me
+ Issuer ClientID
+ Subject Me
Scope Scopes
Secret []byte
Algorithm string
- NonceLength int
+ NonceLength uint8
}
)
@@ -42,8 +43,8 @@ type (
var DefaultNewTokenOptions = NewTokenOptions{
Expiration: 0,
Scope: nil,
- Issuer: nil,
- Subject: nil,
+ Issuer: ClientID{},
+ Subject: Me{},
Secret: nil,
Algorithm: "HS256",
NonceLength: 32,
@@ -82,7 +83,7 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
}
}
- if opts.Issuer != nil {
+ if opts.Issuer.clientID != nil {
if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
@@ -157,8 +158,8 @@ func TestToken(tb testing.TB) *Token {
return &Token{
CreatedAt: now.Add(-1 * time.Hour),
Expiry: now.Add(1 * time.Hour),
- ClientID: cid,
- Me: me,
+ ClientID: *cid,
+ Me: *me,
Scope: scope,
AccessToken: string(accessToken),
RefreshToken: "", // TODO(toby3d)
@@ -171,7 +172,7 @@ func (t Token) SetAuthHeader(r *http.Request) {
return
}
- r.Header.Set(http.HeaderAuthorization, t.String())
+ r.Header.Set(common.HeaderAuthorization, t.String())
}
// String returns string representation of token.
diff --git a/internal/domain/token_test.go b/internal/domain/token_test.go
index 1efc2b7..0706075 100644
--- a/internal/domain/token_test.go
+++ b/internal/domain/token_test.go
@@ -1,13 +1,12 @@
package domain_test
import (
- "bytes"
"fmt"
+ "net/http"
"testing"
"time"
- http "github.com/valyala/fasthttp"
-
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
)
@@ -40,16 +39,13 @@ func TestNewToken(t *testing.T) {
func TestToken_SetAuthHeader(t *testing.T) {
t.Parallel()
- token := domain.TestToken(t)
- expResult := []byte("Bearer " + token.AccessToken)
+ in := domain.TestToken(t)
+ req, _ := http.NewRequest(http.MethodGet, "https://example.com/", nil)
+ in.SetAuthHeader(req)
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- token.SetAuthHeader(req)
-
- result := req.Header.Peek(http.HeaderAuthorization)
- if result == nil || !bytes.Equal(result, expResult) {
- t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, result, expResult)
+ exp := "Bearer " + in.AccessToken
+ if out := req.Header.Get(common.HeaderAuthorization); out != exp {
+ t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, out, exp)
}
}
@@ -57,9 +53,9 @@ func TestToken_String(t *testing.T) {
t.Parallel()
token := domain.TestToken(t)
- expResult := "Bearer " + token.AccessToken
+ exp := "Bearer " + token.AccessToken
- if result := token.String(); result != expResult {
- t.Errorf("String() = %s, want %s", result, expResult)
+ if out := token.String(); out != exp {
+ t.Errorf("String() = %s, want %s", out, exp)
}
}
diff --git a/internal/domain/url.go b/internal/domain/url.go
index 6a95bf7..84a9686 100644
--- a/internal/domain/url.go
+++ b/internal/domain/url.go
@@ -5,6 +5,8 @@ import (
"net/url"
"strconv"
"testing"
+
+ "source.toby3d.me/toby3d/auth/internal/common"
)
// URL describe any valid HTTP URL.
@@ -75,3 +77,11 @@ func (u *URL) UnmarshalJSON(v []byte) error {
func (u URL) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(u.String())), nil
}
+
+func (u URL) GoString() string {
+ if u.URL == nil {
+ return "domain.URL(" + common.Und + ")"
+ }
+
+ return "domain.URL(" + u.URL.String() + ")"
+}
diff --git a/internal/health/delivery/http/health_http.go b/internal/health/delivery/http/health_http.go
index 263155a..30c4a89 100644
--- a/internal/health/delivery/http/health_http.go
+++ b/internal/health/delivery/http/health_http.go
@@ -5,7 +5,6 @@ import (
"net/http"
"source.toby3d.me/toby3d/auth/internal/common"
- "source.toby3d.me/toby3d/auth/internal/middleware"
)
type Handler struct{}
@@ -14,8 +13,8 @@ func NewHandler() *Handler {
return &Handler{}
}
-func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- http.HandlerFunc(middleware.HandlerFunc(h.handleFunc).Intercept(middleware.LogFmt())).ServeHTTP(w, r)
+func (h *Handler) Handler() http.Handler {
+ return http.HandlerFunc(h.handleFunc)
}
func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/health/delivery/http/health_http_test.go b/internal/health/delivery/http/health_http_test.go
index 3e8ebd2..68cd586 100644
--- a/internal/health/delivery/http/health_http_test.go
+++ b/internal/health/delivery/http/health_http_test.go
@@ -2,11 +2,10 @@ package http_test
import (
"io"
+ "net/http"
"net/http/httptest"
"testing"
- http "github.com/valyala/fasthttp"
-
delivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http"
)
@@ -15,7 +14,10 @@ func TestRequestHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://example.com/health", nil)
w := httptest.NewRecorder()
- delivery.NewHandler().ServeHTTP(w, req)
+
+ delivery.NewHandler().
+ Handler().
+ ServeHTTP(w, req)
resp := w.Result()
diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go
index d51b62a..1975704 100644
--- a/internal/httputil/httputil.go
+++ b/internal/httputil/httputil.go
@@ -2,33 +2,74 @@ package httputil
import (
"bytes"
- "encoding/json"
"fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
"net/url"
"strings"
+ "github.com/goccy/go-json"
"github.com/tomnomnom/linkheader"
- http "github.com/valyala/fasthttp"
+ "golang.org/x/exp/slices"
"willnorris.com/go/microformats"
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
)
+const RelIndieauthMetadata = "indieauth-metadata"
+
var ErrEndpointNotExist = domain.NewError(
domain.ErrorCodeServerError,
"cannot found any endpoints",
"https://indieauth.net/source/#discovery-0",
)
-func ExtractEndpoints(resp *http.Response, rel string) []*url.URL {
+func ExtractFromMetadata(client *http.Client, u string) (*domain.Metadata, error) {
+ req, err := http.NewRequest(http.MethodGet, u, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ buf := bytes.NewBuffer(body)
+
+ endpoints := ExtractEndpoints(buf, resp.Request.URL, resp.Header.Get(common.HeaderLink), RelIndieauthMetadata)
+ if len(endpoints) == 0 {
+ return nil, ErrEndpointNotExist
+ }
+
+ if resp, err = client.Get(endpoints[len(endpoints)-1].String()); err != nil {
+ return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
+ }
+
+ result := new(domain.Metadata)
+ if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
+ return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
+ }
+
+ return result, nil
+}
+
+func ExtractEndpoints(body io.Reader, u *url.URL, linkHeader, rel string) []*url.URL {
results := make([]*url.URL, 0)
- urls, err := ExtractEndpointsFromHeader(resp, rel)
+ urls, err := ExtractEndpointsFromHeader(linkHeader, rel)
if err == nil {
results = append(results, urls...)
}
- urls, err = ExtractEndpointsFromBody(resp, rel)
+ urls, err = ExtractEndpointsFromBody(body, u, rel)
if err == nil {
results = append(results, urls...)
}
@@ -36,15 +77,15 @@ func ExtractEndpoints(resp *http.Response, rel string) []*url.URL {
return results
}
-func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, error) {
+func ExtractEndpointsFromHeader(linkHeader, rel string) ([]*url.URL, error) {
results := make([]*url.URL, 0)
- for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
+ for _, link := range linkheader.Parse(linkHeader) {
if !strings.EqualFold(link.Rel, rel) {
continue
}
- u, err := url.ParseRequestURI(link.URL)
+ u, err := url.Parse(link.URL)
if err != nil {
return nil, fmt.Errorf("cannot parse header endpoint: %w", err)
}
@@ -55,8 +96,8 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, er
return results, nil
}
-func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, error) {
- endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel]
+func ExtractEndpointsFromBody(body io.Reader, u *url.URL, rel string) ([]*url.URL, error) {
+ endpoints, ok := microformats.Parse(body, u).Rels[rel]
if !ok || len(endpoints) == 0 {
return nil, ErrEndpointNotExist
}
@@ -75,58 +116,23 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, erro
return results, nil
}
-func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) {
- endpoints := ExtractEndpoints(resp, "indieauth-metadata")
- if len(endpoints) == 0 {
- return nil, ErrEndpointNotExist
- }
-
- _, body, err := client.Get(nil, endpoints[len(endpoints)-1].String())
- if err != nil {
- return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
- }
-
- result := new(domain.Metadata)
- if err = json.Unmarshal(body, result); err != nil {
- return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
- }
-
- return result, nil
-}
-
-func ExtractProperty(resp *http.Response, itemType, key string) []interface{} {
- //nolint:exhaustivestruct // only Host part in url.URL is needed
- data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{
- Host: string(resp.Header.Peek(http.HeaderHost)),
- })
-
- return findProperty(data.Items, itemType, key)
-}
-
-func contains(src []string, find string) bool {
- for i := range src {
- if !strings.EqualFold(src[i], find) {
- continue
- }
-
- return true
- }
-
- return false
-}
-
-func findProperty(src []*microformats.Microformat, itemType, key string) []interface{} {
- for _, item := range src {
- if contains(item.Type, itemType) {
- return item.Properties[key]
- }
-
- result := findProperty(item.Children, itemType, key)
- if result == nil {
- continue
- }
-
- return result
+func ExtractProperty(body io.Reader, u *url.URL, itemType, key string) []any {
+ if data := microformats.Parse(body, u); data != nil {
+ return FindProperty(data.Items, itemType, key)
+ }
+
+ return nil
+}
+
+func FindProperty(src []*microformats.Microformat, itemType, key string) []any {
+ for _, item := range src {
+ if slices.Contains(item.Type, itemType) {
+ return item.Properties[key]
+ }
+
+ if result := FindProperty(item.Children, itemType, key); result != nil {
+ return result
+ }
}
return nil
diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go
index 20a47ab..065e285 100644
--- a/internal/httputil/httputil_test.go
+++ b/internal/httputil/httputil_test.go
@@ -1,30 +1,72 @@
package httputil_test
import (
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
"testing"
- http "github.com/valyala/fasthttp"
+ "github.com/google/go-cmp/cmp"
"source.toby3d.me/toby3d/auth/internal/httputil"
)
const testBody = `
+
+
+
+
-
+
Sample Name
`
+func TestExtractEndpointsFromBody(t *testing.T) {
+ t.Parallel()
+
+ req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ in := &http.Response{
+ Body: ioutil.NopCloser(strings.NewReader(testBody)),
+ Request: req,
+ }
+
+ out, err := httputil.ExtractEndpointsFromBody(in.Body, req.URL, "lipsum")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ exp := []*url.URL{
+ {Scheme: "https", Host: "example.com", Path: "/"},
+ {Scheme: "https", Host: "example.net", Path: "/"},
+ }
+
+ if !cmp.Equal(out, exp) {
+ t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, out, exp)
+ }
+}
+
func TestExtractProperty(t *testing.T) {
t.Parallel()
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
- resp.SetBodyString(testBody)
+ req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
- results := httputil.ExtractProperty(resp, "h-card", "name")
- if results == nil || results[0] != "Sample Name" {
- t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, results, []string{"Sample Name"})
+ in := &http.Response{
+ Body: ioutil.NopCloser(strings.NewReader(testBody)),
+ Request: req,
+ }
+
+ if out := httputil.ExtractProperty(in.Body, req.URL, "h-app", "name"); out == nil || out[0] != "Sample Name" {
+ t.Errorf(`ExtractProperty(%s, %s, %s) = %+s, want %+s`, req.URL, "h-app", "name", out,
+ []string{"Sample Name"})
}
}
diff --git a/internal/metadata/delivery/http/metadata_http.go b/internal/metadata/delivery/http/metadata_http.go
index d65d553..ea61746 100644
--- a/internal/metadata/delivery/http/metadata_http.go
+++ b/internal/metadata/delivery/http/metadata_http.go
@@ -1,13 +1,12 @@
package http
import (
- "github.com/fasthttp/router"
+ "net/http"
+
"github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/middleware"
)
type (
@@ -60,28 +59,29 @@ type (
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
}
- RequestHandler struct {
+ Handler struct {
metadata *domain.Metadata
}
)
-func NewRequestHandler(metadata *domain.Metadata) *RequestHandler {
- return &RequestHandler{
+func NewHandler(metadata *domain.Metadata) *Handler {
+ return &Handler{
metadata: metadata,
}
}
-func (h *RequestHandler) Register(r *router.Router) {
- chain := middleware.Chain{
- middleware.LogFmt(),
- }
-
- r.GET("/.well-known/oauth-authorization-server", chain.RequestHandler(h.read))
+func (h *Handler) Handler() http.Handler {
+ return http.HandlerFunc(h.handleFunc)
}
-func (h *RequestHandler) read(ctx *http.RequestCtx) {
- ctx.SetStatusCode(http.StatusOK)
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
+func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "" && r.Method != http.MethodGet {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
scopes, responseTypes, grantTypes, codeChallengeMethods := make([]string, 0), make([]string, 0),
make([]string, 0), make([]string, 0)
@@ -103,7 +103,7 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
h.metadata.CodeChallengeMethodsSupported[i].String())
}
- _ = json.NewEncoder(ctx).Encode(&MetadataResponse{
+ _ = json.NewEncoder(w).Encode(&MetadataResponse{
AuthorizationEndpoint: h.metadata.AuthorizationEndpoint.String(),
IntrospectionEndpoint: h.metadata.IntrospectionEndpoint.String(),
Issuer: h.metadata.Issuer.String(),
@@ -123,4 +123,6 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
// client_secret_basic according to RFC8414.
RevocationEndpointAuthMethodsSupported: h.metadata.RevocationEndpointAuthMethodsSupported,
})
+
+ w.WriteHeader(http.StatusOK)
}
diff --git a/internal/metadata/delivery/http/metadata_http_test.go b/internal/metadata/delivery/http/metadata_http_test.go
index c8e545f..8c2dde1 100644
--- a/internal/metadata/delivery/http/metadata_http_test.go
+++ b/internal/metadata/delivery/http/metadata_http_test.go
@@ -1,40 +1,36 @@
package http_test
import (
+ "net/http"
+ "net/http/httptest"
"testing"
- "github.com/fasthttp/router"
"github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/domain"
delivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
)
func TestMetadata(t *testing.T) {
t.Parallel()
- r := router.New()
metadata := domain.TestMetadata(t)
- delivery.NewRequestHandler(metadata).Register(r)
- client, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
+ req := httptest.NewRequest(http.MethodGet, "https://example.com/.well-known/oauth-authorization-server", nil)
- const requestURL string = "https://example.com/.well-known/oauth-authorization-server"
+ w := httptest.NewRecorder()
+ delivery.NewHandler(metadata).
+ Handler().
+ ServeHTTP(w, req)
- status, body, err := client.Get(nil, requestURL)
- if err != nil {
- t.Fatal(err)
+ resp := w.Result()
+
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
}
- if status != http.StatusOK {
- t.Errorf("GET %s = %d, want %d", requestURL, status, http.StatusOK)
- }
-
- result := new(delivery.MetadataResponse)
- if err = json.Unmarshal(body, result); err != nil {
+ out := new(delivery.MetadataResponse)
+ if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
t.Fatal(err)
}
}
diff --git a/internal/metadata/repository.go b/internal/metadata/repository.go
index 0672df4..0b4cd4c 100644
--- a/internal/metadata/repository.go
+++ b/internal/metadata/repository.go
@@ -2,12 +2,14 @@ package metadata
import (
"context"
+ "net/url"
"source.toby3d.me/toby3d/auth/internal/domain"
)
type Repository interface {
- Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error)
+ Create(_ context.Context, _ *url.URL, _ domain.Metadata) error
+ Get(_ context.Context, u *url.URL) (*domain.Metadata, error)
}
var ErrNotExist error = domain.NewError(
diff --git a/internal/metadata/repository/http/http_metadata.go b/internal/metadata/repository/http/http_metadata.go
index 618432a..bea3b6c 100644
--- a/internal/metadata/repository/http/http_metadata.go
+++ b/internal/metadata/repository/http/http_metadata.go
@@ -2,26 +2,29 @@ package http
import (
"context"
- "encoding/json"
"fmt"
+ "net/http"
+ "net/url"
- http "github.com/valyala/fasthttp"
+ "github.com/goccy/go-json"
+ "github.com/tomnomnom/linkheader"
+ "willnorris.com/go/microformats"
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/metadata"
)
type (
//nolint:tagliatelle,lll
- Metadata struct {
- Issuer *domain.ClientID `json:"issuer"`
- AuthorizationEndpoint *domain.URL `json:"authorization_endpoint"`
- IntrospectionEndpoint *domain.URL `json:"introspection_endpoint"`
- RevocationEndpoint *domain.URL `json:"revocation_endpoint,omitempty"`
- ServiceDocumentation *domain.URL `json:"service_documentation,omitempty"`
- TokenEndpoint *domain.URL `json:"token_endpoint"`
- UserinfoEndpoint *domain.URL `json:"userinfo_endpoint,omitempty"`
+ Response struct {
+ Issuer domain.URL `json:"issuer"`
+ AuthorizationEndpoint domain.URL `json:"authorization_endpoint"`
+ IntrospectionEndpoint domain.URL `json:"introspection_endpoint"`
+ RevocationEndpoint domain.URL `json:"revocation_endpoint,omitempty"`
+ ServiceDocumentation domain.URL `json:"service_documentation,omitempty"`
+ TokenEndpoint domain.URL `json:"token_endpoint"`
+ UserinfoEndpoint domain.URL `json:"userinfo_endpoint,omitempty"`
CodeChallengeMethodsSupported []domain.CodeChallengeMethod `json:"code_challenge_methods_supported"`
GrantTypesSupported []domain.GrantType `json:"grant_types_supported,omitempty"`
ResponseTypesSupported []domain.ResponseType `json:"response_types_supported,omitempty"`
@@ -29,6 +32,11 @@ type (
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"`
RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"`
AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"`
+
+ // Extensions
+ TicketEndpoint domain.URL `json:"ticket_endpoint"`
+ Micropub domain.URL `json:"micropub"`
+ Microsub domain.URL `json:"microsub"`
}
httpMetadataRepository struct {
@@ -36,7 +44,7 @@ type (
}
)
-const DefaultMaxRedirectsCount int = 10
+const relIndieauthMetadata = "indieauth-metadata"
func NewHTTPMetadataRepository(client *http.Client) metadata.Repository {
return &httpMetadataRepository{
@@ -44,48 +52,127 @@ func NewHTTPMetadataRepository(client *http.Client) metadata.Repository {
}
}
-func (repo *httpMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) {
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.SetRequestURI(me.String())
- req.Header.SetMethod(http.MethodGet)
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
- return nil, fmt.Errorf("failed to make a request to the client: %w", err)
- }
-
- endpoints := httputil.ExtractEndpoints(resp, "indieauth-metadata")
- if len(endpoints) == 0 {
- return nil, metadata.ErrNotExist
- }
-
- _, body, err := repo.client.Get(nil, endpoints[len(endpoints)-1].String())
- if err != nil {
- return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
- }
-
- data := new(Metadata)
- if err = json.Unmarshal(body, data); err != nil {
- return nil, fmt.Errorf("cannot unmarshal metadata configuration: %w", err)
- }
-
- //nolint:exhaustivestruct // TODO(toby3d)
- return &domain.Metadata{
- AuthorizationEndpoint: data.AuthorizationEndpoint.URL,
- AuthorizationResponseIssParameterSupported: data.AuthorizationResponseIssParameterSupported,
- CodeChallengeMethodsSupported: data.CodeChallengeMethodsSupported,
- GrantTypesSupported: data.GrantTypesSupported,
- Issuer: data.Issuer,
- ResponseTypesSupported: data.ResponseTypesSupported,
- ScopesSupported: data.ScopesSupported,
- ServiceDocumentation: data.ServiceDocumentation.URL,
- TokenEndpoint: data.TokenEndpoint.URL,
- // TODO(toby3d): support extensions?
- // Micropub: data.Micropub,
- // Microsub: data.Microsub,
- // TicketEndpoint: data.TicketEndpoint,
- }, nil
+// WARN(toby3d): not implemented.
+func (httpMetadataRepository) Create(_ context.Context, _ *url.URL, _ domain.Metadata) error {
+ return nil
+}
+
+func (repo *httpMetadataRepository) Get(_ context.Context, u *url.URL) (*domain.Metadata, error) {
+ resp, err := repo.client.Get(u.String())
+ if err != nil {
+ return nil, fmt.Errorf("cannot make request to provided Me: %w", err)
+ }
+
+ relVals := make(map[string][]string)
+ for _, link := range linkheader.Parse(resp.Header.Get(common.HeaderLink)) {
+ populateBuffer(relVals, link.Rel, link.URL)
+ }
+
+ if mf2 := microformats.Parse(resp.Body, resp.Request.URL); mf2 != nil {
+ for rel, vals := range mf2.Rels {
+ if len(vals) > 0 {
+ populateBuffer(relVals, rel, vals[0])
+ }
+ }
+ }
+
+ out := new(domain.Metadata)
+ // NOTE(toby3d): fetch all from metadata endpoint if exists
+ if endpoints, ok := relVals["indieauth-metadata"]; ok {
+ if resp, err = repo.client.Get(endpoints[0]); err != nil {
+ return nil, fmt.Errorf("cannot fetch indieauth-metadata endpoint: %w", err)
+ }
+
+ in := NewResponse()
+ if err = in.bind(resp); err != nil {
+ return nil, err
+ }
+
+ in.populate(out)
+
+ return out, nil
+ }
+
+ // NOTE(toby3d): metadata not exists, fallback for old clients
+ for key, dst := range map[string]**url.URL{
+ "authorization_endpoint": &out.AuthorizationEndpoint,
+ "micropub": &out.MicropubEndpoint,
+ "microsub": &out.MicrosubEndpoint,
+ "ticket_endpoint": &out.TicketEndpoint,
+ "token_endpoint": &out.TokenEndpoint,
+ } {
+ if values, ok := relVals[key]; ok && len(values) > 0 {
+ if u, err := url.Parse(values[0]); err == nil {
+ *dst = resp.Request.URL.ResolveReference(u)
+ }
+ }
+ }
+
+ return out, nil
+}
+
+func populateBuffer(dst map[string][]string, rel, u string) {
+ if _, ok := dst[rel]; !ok {
+ dst[rel] = make([]string, 0)
+ }
+
+ dst[rel] = append(dst[rel], u)
+}
+
+func NewResponse() *Response {
+ return &Response{
+ CodeChallengeMethodsSupported: make([]domain.CodeChallengeMethod, 0),
+ GrantTypesSupported: make([]domain.GrantType, 0),
+ ResponseTypesSupported: make([]domain.ResponseType, 0),
+ ScopesSupported: make([]domain.Scope, 0),
+ IntrospectionEndpointAuthMethodsSupported: make([]string, 0),
+ RevocationEndpointAuthMethodsSupported: make([]string, 0),
+ }
+}
+
+func (r *Response) bind(resp *http.Response) error {
+ if err := json.NewDecoder(resp.Body).Decode(r); err != nil {
+ return fmt.Errorf("cannot unmarshal metadata configuration: %w", err)
+ }
+
+ return nil
+}
+
+func (r Response) populate(dst *domain.Metadata) {
+ dst.AuthorizationEndpoint = r.AuthorizationEndpoint.URL
+ dst.AuthorizationResponseIssParameterSupported = r.AuthorizationResponseIssParameterSupported
+ dst.IntrospectionEndpoint = r.IntrospectionEndpoint.URL
+ dst.Issuer = r.Issuer.URL
+ dst.MicropubEndpoint = r.Micropub.URL
+ dst.MicrosubEndpoint = r.Microsub.URL
+ dst.RevocationEndpoint = r.RevocationEndpoint.URL
+ dst.ServiceDocumentation = r.ServiceDocumentation.URL
+ dst.TicketEndpoint = r.TicketEndpoint.URL
+ dst.TokenEndpoint = r.TokenEndpoint.URL
+ dst.UserinfoEndpoint = r.UserinfoEndpoint.URL
+
+ for _, scope := range r.ScopesSupported {
+ dst.ScopesSupported = append(dst.ScopesSupported, scope)
+ }
+
+ for _, method := range r.RevocationEndpointAuthMethodsSupported {
+ dst.RevocationEndpointAuthMethodsSupported = append(dst.RevocationEndpointAuthMethodsSupported, method)
+ }
+
+ for _, responseType := range r.ResponseTypesSupported {
+ dst.ResponseTypesSupported = append(dst.ResponseTypesSupported, responseType)
+ }
+
+ for _, method := range r.IntrospectionEndpointAuthMethodsSupported {
+ dst.IntrospectionEndpointAuthMethodsSupported = append(dst.IntrospectionEndpointAuthMethodsSupported,
+ method)
+ }
+
+ for _, grantType := range r.GrantTypesSupported {
+ dst.GrantTypesSupported = append(dst.GrantTypesSupported, grantType)
+ }
+
+ for _, method := range r.CodeChallengeMethodsSupported {
+ dst.CodeChallengeMethodsSupported = append(dst.CodeChallengeMethodsSupported, method)
+ }
}
diff --git a/internal/metadata/repository/http/http_metadata_test.go b/internal/metadata/repository/http/http_metadata_test.go
new file mode 100644
index 0000000..2b4a791
--- /dev/null
+++ b/internal/metadata/repository/http/http_metadata_test.go
@@ -0,0 +1,183 @@
+package http_test
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/goccy/go-json"
+ "github.com/google/go-cmp/cmp"
+
+ "source.toby3d.me/toby3d/auth/internal/common"
+ "source.toby3d.me/toby3d/auth/internal/domain"
+ repository "source.toby3d.me/toby3d/auth/internal/metadata/repository/http"
+)
+
+//nolint:lll,tagliatelle
+type Response struct {
+ CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
+ GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
+ ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
+ ScopesSupported []string `json:"scopes_supported,omitempty"`
+ IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"`
+ RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"`
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ IntrospectionEndpoint string `json:"introspection_endpoint"`
+ RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
+ ServiceDocumentation string `json:"service_documentation,omitempty"`
+ TokenEndpoint string `json:"token_endpoint"`
+ UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
+ TicketEndpoint string `json:"ticket_endpoint"`
+ Micropub string `json:"micropub"`
+ Microsub string `json:"microsub"`
+ AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"`
+}
+
+const testBody string = `
+
+
+
+ Testing
+ %s
+
+
+`
+
+//nolint:funlen
+func TestGet(t *testing.T) {
+ t.Parallel()
+
+ testMetadata := domain.TestMetadata(t)
+
+ for _, tc := range []struct {
+ name string
+ header map[string]string
+ body map[string]string
+ out *domain.Metadata
+ }{
+ {
+ name: "header",
+ header: map[string]string{
+ "indieauth-metadata": "/metadata",
+ "authorization_endpoint": "http://example.net/authorization",
+ "token_endpoint": "http://example.net/tkn",
+ },
+ out: testMetadata,
+ }, /*{
+ name: "body",
+ body: map[string]string{
+ "indieauth-metadata": "/metadata",
+ "authorization_endpoint": "http://example.net/authorization",
+ "token_endpoint": "http://example.net/tkn",
+ },
+ out: &testMetadata,
+ }*/} {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/metadata", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+ _ = json.NewEncoder(w).Encode(NewResponse(t, *testMetadata))
+ })
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ links := make([]string, 0)
+ for k, v := range tc.header {
+ links = append(links, `<`+v+`>; rel="`+k+`"`)
+ }
+
+ w.Header().Set(common.HeaderLink, strings.Join(links, ", "))
+
+ links = make([]string, 0)
+ for k, v := range tc.body {
+ links = append(links, ``)
+ }
+
+ fmt.Fprintf(w, testBody, strings.Join(links, "\n"))
+ })
+
+ srv := httptest.NewUnstartedServer(mux)
+ srv.EnableHTTP2 = true
+ srv.Start()
+ t.Cleanup(srv.Close)
+
+ tc.header["indieauth-metadata"] = srv.URL + tc.header["indieauth-metadata"]
+
+ u, _ := url.Parse(srv.URL + "/")
+ out, err := repository.NewHTTPMetadataRepository(srv.Client()).
+ Get(context.Background(), u)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff(tc.out, out, cmp.AllowUnexported(
+ domain.ClientID{},
+ domain.CodeChallengeMethod{},
+ domain.GrantType{},
+ domain.ResponseType{},
+ domain.Scope{},
+ url.URL{},
+ )); diff != "" {
+ t.Errorf("%+s", diff)
+ }
+ })
+ }
+}
+
+func NewResponse(tb testing.TB, src domain.Metadata) *Response {
+ tb.Helper()
+
+ out := &Response{
+ CodeChallengeMethodsSupported: make([]string, 0),
+ GrantTypesSupported: make([]string, 0),
+ ResponseTypesSupported: make([]string, 0),
+ ScopesSupported: make([]string, 0),
+ IntrospectionEndpointAuthMethodsSupported: make([]string, 0),
+ RevocationEndpointAuthMethodsSupported: make([]string, 0),
+ Issuer: src.Issuer.String(),
+ AuthorizationEndpoint: src.AuthorizationEndpoint.String(),
+ IntrospectionEndpoint: src.IntrospectionEndpoint.String(),
+ RevocationEndpoint: src.RevocationEndpoint.String(),
+ ServiceDocumentation: src.ServiceDocumentation.String(),
+ TokenEndpoint: src.TokenEndpoint.String(),
+ UserinfoEndpoint: src.UserinfoEndpoint.String(),
+ TicketEndpoint: src.TicketEndpoint.String(),
+ Micropub: src.MicropubEndpoint.String(),
+ Microsub: src.MicrosubEndpoint.String(),
+ AuthorizationResponseIssParameterSupported: src.AuthorizationResponseIssParameterSupported,
+ }
+
+ for _, method := range src.CodeChallengeMethodsSupported {
+ out.CodeChallengeMethodsSupported = append(out.CodeChallengeMethodsSupported, method.String())
+ }
+
+ for _, grantType := range src.GrantTypesSupported {
+ out.GrantTypesSupported = append(out.GrantTypesSupported, grantType.String())
+ }
+
+ for _, responseType := range src.ResponseTypesSupported {
+ out.ResponseTypesSupported = append(out.ResponseTypesSupported, responseType.String())
+ }
+
+ for _, scope := range src.ScopesSupported {
+ out.ScopesSupported = append(out.ScopesSupported, scope.String())
+ }
+
+ for _, method := range src.IntrospectionEndpointAuthMethodsSupported {
+ out.IntrospectionEndpointAuthMethodsSupported = append(out.IntrospectionEndpointAuthMethodsSupported,
+ method)
+ }
+
+ for _, method := range src.RevocationEndpointAuthMethodsSupported {
+ out.RevocationEndpointAuthMethodsSupported = append(out.RevocationEndpointAuthMethodsSupported, method)
+ }
+
+ return out
+}
diff --git a/internal/metadata/repository/memory/memory_metadata.go b/internal/metadata/repository/memory/memory_metadata.go
index cecb30a..cfa1cea 100644
--- a/internal/metadata/repository/memory/memory_metadata.go
+++ b/internal/metadata/repository/memory/memory_metadata.go
@@ -2,7 +2,7 @@ package memory
import (
"context"
- "path"
+ "net/url"
"sync"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -10,27 +10,35 @@ import (
)
type memoryMetadataRepository struct {
- store *sync.Map
+ mutex *sync.RWMutex
+ metadata map[string]domain.Metadata
}
const DefaultPathPrefix = "metadata"
-func NewMemoryMetadataRepository(store *sync.Map) metadata.Repository {
+func NewMemoryMetadataRepository() metadata.Repository {
return &memoryMetadataRepository{
- store: store,
+ mutex: new(sync.RWMutex),
+ metadata: make(map[string]domain.Metadata),
}
}
-func (repo *memoryMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) {
- src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String()))
- if !ok {
- return nil, metadata.ErrNotExist
- }
+func (repo *memoryMetadataRepository) Create(ctx context.Context, u *url.URL, metadata domain.Metadata) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
- result, ok := src.(*domain.Metadata)
- if !ok {
- return nil, metadata.ErrNotExist
- }
+ repo.metadata[u.String()] = metadata
- return result, nil
+ return nil
+}
+
+func (repo *memoryMetadataRepository) Get(ctx context.Context, u *url.URL) (*domain.Metadata, error) {
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
+
+ if out, ok := repo.metadata[u.String()]; ok {
+ return &out, nil
+ }
+
+ return nil, metadata.ErrNotExist
}
diff --git a/internal/middleware/extractor.go b/internal/middleware/extractor.go
index a5e35b7..a9c2d83 100644
--- a/internal/middleware/extractor.go
+++ b/internal/middleware/extractor.go
@@ -23,7 +23,6 @@ var (
errHeaderExtractorValueMissing = errors.New("missing value in request header")
errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
errQueryExtractorValueMissing = errors.New("missing value in the query string")
- errParamExtractorValueMissing = errors.New("missing value in path params")
errCookieExtractorValueMissing = errors.New("missing value in cookies")
errFormExtractorValueMissing = errors.New("missing value in the form")
)
@@ -67,8 +66,6 @@ func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) {
switch parts[0] {
case "query":
extractors = append(extractors, valuesFromQuery(parts[1]))
- // case "param":
- // extractors = append(extractors, valuesFromParam(parts[1]))
case "cookie":
extractors = append(extractors, valuesFromCookie(parts[1]))
case "form":
@@ -163,31 +160,6 @@ func valuesFromQuery(param string) ValuesExtractor {
}
}
-// valuesFromParam returns a function that extracts values from the url param string.
-/*
-func valuesFromParam(param string) ValuesExtractor {
- return func(w http.ResponseWriter, r *http.Request) ([]string, error) {
- result := make([]string, 0)
- paramVales := r.ParamValues()
-
- for i, p := range r.ParamNames() {
- if param == p {
- result = append(result, paramVales[i])
- if i >= extractorLimit-1 {
- break
- }
- }
- }
-
- if len(result) == 0 {
- return nil, errParamExtractorValueMissing
- }
-
- return result, nil
- }
-}
-*/
-
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(w http.ResponseWriter, r *http.Request) ([]string, error) {
diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go
index 2d0520e..27fe40d 100644
--- a/internal/middleware/jwt.go
+++ b/internal/middleware/jwt.go
@@ -77,7 +77,6 @@ type (
// Possible values:
// - "header:"
// - "query:"
- // - "param:"
// - "cookie:"
// - "form:"
// Multiply sources example:
diff --git a/internal/profile/repository.go b/internal/profile/repository.go
index 2070c68..becc840 100644
--- a/internal/profile/repository.go
+++ b/internal/profile/repository.go
@@ -7,7 +7,8 @@ import (
)
type Repository interface {
- Get(ctx context.Context, me *domain.Me) (*domain.Profile, error)
+ Create(ctx context.Context, me domain.Me, profile domain.Profile) error
+ Get(ctx context.Context, me domain.Me) (*domain.Profile, error)
}
var ErrNotExist error = domain.NewError(
diff --git a/internal/profile/repository/http/http_profile.go b/internal/profile/repository/http/http_profile.go
index ffd4ef1..f8dadab 100644
--- a/internal/profile/repository/http/http_profile.go
+++ b/internal/profile/repository/http/http_profile.go
@@ -1,12 +1,13 @@
package http
import (
+ "bytes"
"context"
"fmt"
+ "io"
+ "net/http"
"net/url"
- http "github.com/valyala/fasthttp"
-
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/profile"
@@ -33,29 +34,33 @@ func NewHTPPClientRepository(client *http.Client) profile.Repository {
}
}
+// WARN(toby3d): not implemented.
+func (repo *httpProfileRepository) Create(_ context.Context, _ domain.Me, _ domain.Profile) error {
+ return nil
+}
+
//nolint:cyclop
-func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*domain.Profile, error) {
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.Header.SetMethod(http.MethodGet)
- req.SetRequestURI(me.String())
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
+func (repo *httpProfileRepository) Get(ctx context.Context, me domain.Me) (*domain.Profile, error) {
+ resp, err := repo.client.Get(me.String())
+ if err != nil {
return nil, fmt.Errorf("%s: cannot fetch user by me: %w", ErrPrefix, err)
}
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("cannot read response body: %w", err)
+ }
+
+ buf := bytes.NewReader(body)
result := domain.NewProfile()
- for _, name := range httputil.ExtractProperty(resp, hCard, propertyName) {
+ for _, name := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyName) {
if n, ok := name.(string); ok {
result.Name = append(result.Name, n)
}
}
- for _, rawEmail := range httputil.ExtractProperty(resp, hCard, propertyEmail) {
+ for _, rawEmail := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyEmail) {
email, ok := rawEmail.(string)
if !ok {
continue
@@ -66,7 +71,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
}
}
- for _, rawURL := range httputil.ExtractProperty(resp, hCard, propertyURL) {
+ for _, rawURL := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyURL) {
rawURL, ok := rawURL.(string)
if !ok {
continue
@@ -77,7 +82,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
}
}
- for _, rawPhoto := range httputil.ExtractProperty(resp, hCard, propertyPhoto) {
+ for _, rawPhoto := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyPhoto) {
photo, ok := rawPhoto.(string)
if !ok {
continue
@@ -88,8 +93,8 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
}
}
- if result.GetName() == "" && result.GetURL() == nil &&
- result.GetPhoto() == nil && result.GetEmail() == nil {
+ // TODO(toby3d): create method like result.Empty()?
+ if result.GetName() == "" && result.GetURL() == nil && result.GetPhoto() == nil && result.GetEmail() == nil {
return nil, profile.ErrNotExist
}
diff --git a/internal/profile/repository/memory/memory_profile.go b/internal/profile/repository/memory/memory_profile.go
index 0b77ba0..bcd3479 100644
--- a/internal/profile/repository/memory/memory_profile.go
+++ b/internal/profile/repository/memory/memory_profile.go
@@ -2,8 +2,6 @@ package memory
import (
"context"
- "fmt"
- "path"
"sync"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -11,30 +9,33 @@ import (
)
type memoryProfileRepository struct {
- store *sync.Map
+ mutex *sync.RWMutex
+ profiles map[string]domain.Profile
}
-const (
- ErrPrefix string = "memory"
- DefaultPathPrefix string = "profiles"
-)
-
-func NewMemoryProfileRepository(store *sync.Map) profile.Repository {
+func NewMemoryProfileRepository() profile.Repository {
return &memoryProfileRepository{
- store: store,
+ mutex: new(sync.RWMutex),
+ profiles: make(map[string]domain.Profile),
}
}
-func (repo *memoryProfileRepository) Get(_ context.Context, me *domain.Me) (*domain.Profile, error) {
- src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String()))
- if !ok {
- return nil, fmt.Errorf("%s: cannot find profile in store: %w", ErrPrefix, profile.ErrNotExist)
- }
+func (repo *memoryProfileRepository) Create(_ context.Context, me domain.Me, p domain.Profile) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
- result, ok := src.(*domain.Profile)
- if !ok {
- return nil, fmt.Errorf("%s: cannot decode profile from store: %w", ErrPrefix, profile.ErrNotExist)
- }
+ repo.profiles[me.String()] = p
- return result, nil
+ return nil
+}
+
+func (repo *memoryProfileRepository) Get(_ context.Context, me domain.Me) (*domain.Profile, error) {
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
+
+ if p, ok := repo.profiles[me.String()]; ok {
+ return &p, nil
+ }
+
+ return nil, profile.ErrNotExist
}
diff --git a/internal/profile/usecase.go b/internal/profile/usecase.go
index 539fd40..6f55c30 100644
--- a/internal/profile/usecase.go
+++ b/internal/profile/usecase.go
@@ -7,7 +7,7 @@ import (
)
type UseCase interface {
- Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error)
+ Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error)
}
var ErrScopeRequired error = domain.NewError(
diff --git a/internal/profile/usecase/profile_ucase.go b/internal/profile/usecase/profile_ucase.go
index 0b71ca4..d799215 100644
--- a/internal/profile/usecase/profile_ucase.go
+++ b/internal/profile/usecase/profile_ucase.go
@@ -18,7 +18,7 @@ func NewProfileUseCase(profiles profile.Repository) profile.UseCase {
}
}
-func (uc *profileUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error) {
+func (uc *profileUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error) {
result, err := uc.profiles.Get(ctx, me)
if err != nil {
return nil, fmt.Errorf("cannot fetch profile info: %w", err)
diff --git a/internal/random/random.go b/internal/random/random.go
index 80bffd3..c6ed4cb 100644
--- a/internal/random/random.go
+++ b/internal/random/random.go
@@ -17,7 +17,7 @@ const (
Hex = Numeric + "abcdef"
)
-func Bytes(length int) ([]byte, error) {
+func Bytes(length uint8) ([]byte, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
@@ -27,7 +27,7 @@ func Bytes(length int) ([]byte, error) {
return bytes, nil
}
-func String(length int, charsets ...string) (string, error) {
+func String(length uint8, charsets ...string) (string, error) {
charset := strings.Join(charsets, "")
if charset == "" {
charset = Alphabetic
diff --git a/internal/session/repository.go b/internal/session/repository.go
index 3b5a10f..b855dbc 100644
--- a/internal/session/repository.go
+++ b/internal/session/repository.go
@@ -8,7 +8,7 @@ import (
type Repository interface {
Get(ctx context.Context, code string) (*domain.Session, error)
- Create(ctx context.Context, session *domain.Session) error
+ Create(ctx context.Context, session domain.Session) error
GetAndDelete(ctx context.Context, code string) (*domain.Session, error)
GC()
}
diff --git a/internal/session/repository/memory/memory_session.go b/internal/session/repository/memory/memory_session.go
index dde76c3..e2bdb50 100644
--- a/internal/session/repository/memory/memory_session.go
+++ b/internal/session/repository/memory/memory_session.go
@@ -3,7 +3,6 @@ package memory
import (
"context"
"fmt"
- "path"
"sync"
"time"
@@ -14,59 +13,59 @@ import (
type (
Session struct {
CreatedAt time.Time
- *domain.Session
+ domain.Session
}
memorySessionRepository struct {
- store *sync.Map
- config *domain.Config
+ config domain.Config
+ mutex *sync.RWMutex
+ sessions map[string]Session
}
)
-const DefaultPathPrefix string = "sessions"
-
-func NewMemorySessionRepository(store *sync.Map, config *domain.Config) session.Repository {
+func NewMemorySessionRepository(config domain.Config) session.Repository {
return &memorySessionRepository{
- config: config,
- store: store,
+ config: config,
+ mutex: new(sync.RWMutex),
+ sessions: make(map[string]Session),
}
}
-func (repo *memorySessionRepository) Create(_ context.Context, state *domain.Session) error {
- repo.store.Store(path.Join(DefaultPathPrefix, state.Code), &Session{
+func (repo *memorySessionRepository) Create(_ context.Context, s domain.Session) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
+
+ repo.sessions[s.Code] = Session{
CreatedAt: time.Now().UTC(),
- Session: state,
- })
+ Session: s,
+ }
return nil
}
func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) {
- src, ok := repo.store.Load(path.Join(DefaultPathPrefix, code))
- if !ok {
- return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
+
+ if s, ok := repo.sessions[code]; ok {
+ return &s.Session, nil
}
- result, ok := src.(*Session)
- if !ok {
- return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
- }
-
- return result.Session, nil
+ return nil, session.ErrNotExist
}
-func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) {
- src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code))
- if !ok {
- return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
+func (repo *memorySessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
+ s, err := repo.Get(ctx, code)
+ if err != nil {
+ return nil, fmt.Errorf("cannot get and delete session: %w", err)
}
- result, ok := src.(*Session)
- if !ok {
- return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
- }
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
- return result.Session, nil
+ delete(repo.sessions, s.Code)
+
+ return s, nil
}
func (repo *memorySessionRepository) GC() {
@@ -76,29 +75,20 @@ func (repo *memorySessionRepository) GC() {
for ts := range ticker.C {
ts := ts
- repo.store.Range(func(key, value interface{}) bool {
- k, ok := key.(string)
- if !ok {
- return false
+ repo.mutex.RLock()
+
+ for code, s := range repo.sessions {
+ if s.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
+ continue
}
- matched, err := path.Match(DefaultPathPrefix+"/*", k)
- if err != nil || !matched {
- return false
- }
+ repo.mutex.RUnlock()
+ repo.mutex.Lock()
+ delete(repo.sessions, code)
+ repo.mutex.Unlock()
+ repo.mutex.RLock()
+ }
- val, ok := value.(*Session)
- if !ok {
- return false
- }
-
- if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
- return false
- }
-
- repo.store.Delete(key)
-
- return false
- })
+ repo.mutex.RUnlock()
}
}
diff --git a/internal/session/repository/sqlite3/sqlite3_session.go b/internal/session/repository/sqlite3/sqlite3_session.go
index 3aadfac..ea07e66 100644
--- a/internal/session/repository/sqlite3/sqlite3_session.go
+++ b/internal/session/repository/sqlite3/sqlite3_session.go
@@ -4,11 +4,11 @@ import (
"context"
"database/sql"
"encoding/base64"
- "encoding/json"
"errors"
"fmt"
"time"
+ "github.com/goccy/go-json"
"github.com/jmoiron/sqlx"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -53,8 +53,8 @@ func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository {
}
}
-func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domain.Session) error {
- src, err := NewSession(session)
+func (repo *sqlite3SessionRepository) Create(ctx context.Context, session domain.Session) error {
+ src, err := NewSession(&session)
if err != nil {
return fmt.Errorf("cannot encode session data for store: %w", err)
}
diff --git a/internal/session/repository/sqlite3/sqlite3_session_test.go b/internal/session/repository/sqlite3/sqlite3_session_test.go
index b2f2604..8ed11a1 100644
--- a/internal/session/repository/sqlite3/sqlite3_session_test.go
+++ b/internal/session/repository/sqlite3/sqlite3_session_test.go
@@ -12,7 +12,7 @@ import (
"source.toby3d.me/toby3d/auth/internal/testing/sqltest"
)
-//nolint: gochecknoglobals // slices cannot be contants
+// nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "code", "data"}
func TestCreate(t *testing.T) {
@@ -39,7 +39,7 @@ func TestCreate(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3SessionRepository(db).
- Create(context.Background(), session); err != nil {
+ Create(context.Background(), *session); err != nil {
t.Error(err)
}
}
diff --git a/internal/testing/httptest/.gitignore b/internal/testing/httptest/.gitignore
deleted file mode 100644
index 612424a..0000000
--- a/internal/testing/httptest/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-*.pem
\ No newline at end of file
diff --git a/internal/testing/httptest/httptest.go b/internal/testing/httptest/httptest.go
deleted file mode 100644
index ca22ce1..0000000
--- a/internal/testing/httptest/httptest.go
+++ /dev/null
@@ -1,63 +0,0 @@
-//go:generate go run "$GOROOT/src/crypto/tls/generate_cert.go" --host 127.0.0.1,::1,localhost --start-date "Jan 1 00:00:00 1970" --duration=1000000h --ca --rsa-bits 1024 --ecdsa-curve P256
-package httptest
-
-import (
- "crypto/tls"
- _ "embed" // used for running tests without same import in "god object"
- "net"
- "testing"
- "time"
-
- http "github.com/valyala/fasthttp"
- httputil "github.com/valyala/fasthttp/fasthttputil"
-)
-
-var (
- //go:embed cert.pem
- certData []byte
- //go:embed key.pem
- keyData []byte
-)
-
-// New returns the InMemory Server and the Client connected to it with the
-// specified handler.
-func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server, func()) {
- tb.Helper()
-
- //nolint:exhaustivestruct
- server := &http.Server{
- CloseOnShutdown: true,
- DisableKeepalive: true,
- ReduceMemoryUsage: true,
- Handler: http.TimeoutHandler(handler, 1*time.Second, "handler performance is too slow"),
- }
-
- ln := httputil.NewInmemoryListener()
-
- //nolint:errcheck
- go server.ServeTLSEmbed(ln, certData, keyData)
-
- //nolint:exhaustivestruct
- client := &http.Client{
- TLSConfig: &tls.Config{
- InsecureSkipVerify: true, //nolint:gosec
- },
- Dial: func(addr string) (net.Conn, error) {
- return ln.Dial() //nolint:wrapcheck
- },
- }
-
- return client, server, func() {
- _ = server.Shutdown()
- }
-}
-
-// NewRequest returns a new incoming server Request and cleanup function.
-func NewRequest(method, target string, body []byte) *http.Request {
- req := http.AcquireRequest()
- req.Header.SetMethod(method)
- req.SetRequestURI(target)
- req.SetBody(body)
-
- return req
-}
diff --git a/internal/ticket/delivery/http/ticket_http.go b/internal/ticket/delivery/http/ticket_http.go
index 67a6d89..e20cc41 100644
--- a/internal/ticket/delivery/http/ticket_http.go
+++ b/internal/ticket/delivery/http/ticket_http.go
@@ -1,72 +1,48 @@
package http
import (
- "errors"
"fmt"
- "path"
+ "net/http"
- "github.com/fasthttp/router"
"github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa"
- http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/random"
"source.toby3d.me/toby3d/auth/internal/ticket"
+ "source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web"
- "source.toby3d.me/toby3d/form"
- "source.toby3d.me/toby3d/middleware"
)
-type (
- TicketGenerateRequest struct {
- // The access token should be used when acting on behalf of this URL.
- Subject *domain.Me `form:"subject"`
+type Handler struct {
+ config domain.Config
+ matcher language.Matcher
+ tickets ticket.UseCase
+}
- // The access token will work at this URL.
- Resource *domain.URL `form:"resource"`
- }
-
- TicketExchangeRequest struct {
- // A random string that can be redeemed for an access token.
- Ticket string `form:"ticket"`
-
- // The access token should be used when acting on behalf of this URL.
- Subject *domain.Me `form:"subject"`
-
- // The access token will work at this URL.
- Resource *domain.URL `form:"resource"`
- }
-
- RequestHandler struct {
- config *domain.Config
- matcher language.Matcher
- tickets ticket.UseCase
- }
-)
-
-func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config *domain.Config) *RequestHandler {
- return &RequestHandler{
+func NewHandler(tickets ticket.UseCase, matcher language.Matcher, config domain.Config) *Handler {
+ return &Handler{
config: config,
matcher: matcher,
tickets: tickets,
}
}
-func (h *RequestHandler) Register(r *router.Router) {
+func (h *Handler) Handler() http.Handler {
//nolint:exhaustivestruct
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
- Skipper: func(ctx *http.RequestCtx) bool {
- matched, _ := path.Match("/ticket*", string(ctx.Path()))
+ Skipper: func(w http.ResponseWriter, r *http.Request) bool {
+ head, _ := urlutil.ShiftPath(r.URL.Path)
- return ctx.IsPost() && matched
+ return r.Method == http.MethodPost && head == "ticket"
},
CookieMaxAge: 0,
- CookieSameSite: http.CookieSameSiteStrictMode,
+ CookieSameSite: http.SameSiteStrictMode,
ContextKey: "csrf",
CookieDomain: h.config.Server.Domain,
CookieName: "__Secure-csrf",
@@ -89,45 +65,69 @@ func (h *RequestHandler) Register(r *router.Router) {
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
Skipper: middleware.DefaultSkipper,
SuccessHandler: nil,
- TokenLookup: "header:" + http.HeaderAuthorization +
- "," + "cookie:" + "__Secure-auth-token",
+ TokenLookup: "header:" + common.HeaderAuthorization +
+ ",cookie:__Secure-auth-token",
}),
- middleware.LogFmt(),
}
- r.GET("/ticket", chain.RequestHandler(h.handleRender))
- r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
- r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
+ return chain.Handler(h.handleFunc)
}
-func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
+func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
+ var head string
+ head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
- tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
+ switch r.Method {
+ default:
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+ case "", http.MethodGet:
+ if head != "" {
+ http.NotFound(w, r)
+
+ return
+ }
+
+ h.handleRender(w, r)
+ case http.MethodPost:
+
+ switch head {
+ default:
+ http.NotFound(w, r)
+ case "":
+ h.handleRedeem(w, r)
+ case "send":
+ h.handleSend(w, r)
+ }
+ }
+}
+
+func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+
+ tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{
- Config: h.config,
+ Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
}
- csrf, _ := ctx.UserValue("csrf").([]byte)
- web.WriteTemplate(ctx, &web.TicketPage{
+ csrf, _ := r.Context().Value("csrf").([]byte)
+ web.WriteTemplate(w, &web.TicketPage{
BaseOf: baseOf,
CSRF: csrf,
})
}
-func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
- ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
- ctx.SetStatusCode(http.StatusOK)
+func (h *Handler) handleSend(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
- encoder := json.NewEncoder(ctx)
+ encoder := json.NewEncoder(w)
req := new(TicketGenerateRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
@@ -137,51 +137,50 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
ticket := &domain.Ticket{
Ticket: "",
Resource: req.Resource.URL,
- Subject: req.Subject,
+ Subject: &req.Subject,
}
var err error
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
- ctx.SetStatusCode(http.StatusInternalServerError)
+ w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return
}
- if err = h.tickets.Generate(ctx, ticket); err != nil {
- ctx.SetStatusCode(http.StatusInternalServerError)
+ if err = h.tickets.Generate(r.Context(), *ticket); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return
}
- ctx.SetStatusCode(http.StatusOK)
+ w.WriteHeader(http.StatusOK)
}
-func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
- ctx.SetStatusCode(http.StatusOK)
+func (h *Handler) handleRedeem(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
- encoder := json.NewEncoder(ctx)
+ encoder := json.NewEncoder(w)
req := new(TicketExchangeRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
- token, err := h.tickets.Redeem(ctx, &domain.Ticket{
+ token, err := h.tickets.Redeem(r.Context(), domain.Ticket{
Ticket: req.Ticket,
Resource: req.Resource.URL,
- Subject: req.Subject,
+ Subject: &req.Subject,
})
if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
@@ -190,84 +189,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
// TODO(toby3d): print the result as part of the debugging. Instead, we
// need to send or save the token to the recipient for later use.
- ctx.SetBodyString(fmt.Sprintf(`{
+ fmt.Fprintf(w, `{
"access_token": "%s",
"token_type": "Bearer",
"scope": "%s",
"me": "%s"
- }`, token.AccessToken, token.Scope.String(), token.Me.String()))
-}
-
-func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) {
- indieAuthError := new(domain.Error)
- if err = form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- if req.Resource == nil {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "resource value MUST be set",
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- if req.Subject == nil {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "subject value MUST be set",
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- return nil
-}
-
-func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- if req.Ticket == "" {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "ticket parameter is required",
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- if req.Resource == nil {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "resource parameter is required",
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- if req.Subject == nil {
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- "subject parameter is required",
- "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
- )
- }
-
- return nil
+ }`, token.AccessToken, token.Scope.String(), token.Me.String())
+ w.WriteHeader(http.StatusOK)
}
diff --git a/internal/ticket/delivery/http/ticket_http_schema.go b/internal/ticket/delivery/http/ticket_http_schema.go
new file mode 100644
index 0000000..eb58b29
--- /dev/null
+++ b/internal/ticket/delivery/http/ticket_http_schema.go
@@ -0,0 +1,90 @@
+package http
+
+import (
+ "errors"
+ "net/http"
+
+ "source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/form"
+)
+
+type (
+ TicketGenerateRequest struct {
+ // The access token should be used when acting on behalf of this URL.
+ Subject domain.Me `form:"subject"`
+
+ // The access token will work at this URL.
+ Resource domain.URL `form:"resource"`
+ }
+
+ TicketExchangeRequest struct {
+ // The access token should be used when acting on behalf of this URL.
+ Subject domain.Me `form:"subject"`
+
+ // The access token will work at this URL.
+ Resource domain.URL `form:"resource"`
+
+ // A random string that can be redeemed for an access token.
+ Ticket string `form:"ticket"`
+ }
+)
+
+func (r *TicketGenerateRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
+ )
+ }
+
+ return nil
+}
+
+func (r *TicketExchangeRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
+ )
+ }
+
+ if r.Ticket == "" {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ "ticket parameter is required",
+ "https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
+ )
+ }
+
+ return nil
+}
diff --git a/internal/ticket/delivery/http/ticket_http_test.go b/internal/ticket/delivery/http/ticket_http_test.go
index 9eec61b..718966d 100644
--- a/internal/ticket/delivery/http/ticket_http_test.go
+++ b/internal/ticket/delivery/http/ticket_http_test.go
@@ -1,17 +1,20 @@
package http_test
+/* TODO(toby3d): move CSRF middleware into main
import (
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"sync"
"testing"
- "github.com/fasthttp/router"
- http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/ticket"
delivery "source.toby3d.me/toby3d/auth/internal/ticket/delivery/http"
ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
@@ -19,6 +22,7 @@ import (
)
type Dependencies struct {
+ server *httptest.Server
client *http.Client
config *domain.Config
matcher language.Matcher
@@ -33,40 +37,35 @@ func TestUpdate(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
+ t.Cleanup(deps.server.Close)
- r := router.New()
- delivery.NewRequestHandler(deps.ticketService, deps.matcher, deps.config).Register(r)
-
- client, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
-
- const requestURI string = "https://example.com/ticket"
-
- req := httptest.NewRequest(http.MethodPost, requestURI, []byte(
+ req := httptest.NewRequest(http.MethodPost, "https://example.com/", strings.NewReader(
`ticket=`+deps.ticket.Ticket+
`&resource=`+deps.ticket.Resource.String()+
`&subject=`+deps.ticket.Subject.String(),
))
- defer http.ReleaseRequest(req)
- req.Header.SetContentType(common.MIMEApplicationForm)
+ req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
+ deps.token.SetAuthHeader(req)
+
+ w := httptest.NewRecorder()
+ delivery.NewHandler(deps.ticketService, deps.matcher, *deps.config).
+ Handler().
+ ServeHTTP(w, req)
+
domain.TestToken(t).SetAuthHeader(req)
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
+ resp := w.Result()
- if err := client.Do(req, resp); err != nil {
- t.Fatal(err)
- }
-
- if resp.StatusCode() != http.StatusOK && resp.StatusCode() != http.StatusAccepted {
- t.Errorf("POST %s = %d, want %d or %d", requestURI, resp.StatusCode(), http.StatusOK,
+ if resp.StatusCode != http.StatusOK &&
+ resp.StatusCode != http.StatusAccepted {
+ t.Errorf("%s %s = %d, want %d or %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK,
http.StatusAccepted)
}
// TODO(toby3d): print the result as part of the debugging. Instead, you
// need to send or save the token to the recipient for later use.
- if resp.Body() == nil {
- t.Errorf("POST %s = nil, want something", requestURI)
+ if resp.Body == nil {
+ t.Errorf("%s %s = nil, want not nil", req.Method, req.RequestURI)
}
}
@@ -79,29 +78,36 @@ func NewDependencies(tb testing.TB) Dependencies {
ticket := domain.TestTicket(tb)
token := domain.TestToken(tb)
- r := router.New()
+ mux := http.NewServeMux()
// NOTE(toby3d): private resource
- r.GET(ticket.Resource.Path, func(ctx *http.RequestCtx) {
- ctx.SuccessString(common.MIMETextHTMLCharsetUTF8,
- ``)
+ mux.HandleFunc(ticket.Resource.Path, func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ fmt.Fprintf(w, ``)
})
// NOTE(toby3d): token endpoint
- r.POST("/token", func(ctx *http.RequestCtx) {
- ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{
- "access_token": "`+token.AccessToken+`",
- "me": "`+token.Me.String()+`",
- "scope": "`+token.Scope.String()+`",
- "token_type": "Bearer"
- }`)
+ mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+ fmt.Fprintf(w, `{
+ "access_token": "`+token.AccessToken+`",
+ "me": "`+token.Me.String()+`",
+ "scope": "`+token.Scope.String()+`",
+ "token_type": "Bearer"
+ }`)
})
- client, _, cleanup := httptest.New(tb, r.Handler)
- tb.Cleanup(cleanup)
-
+ server := httptest.NewServer(mux)
+ client := server.Client()
tickets := ticketrepo.NewMemoryTicketRepository(store, config)
ticketService := ucase.NewTicketUseCase(tickets, client, config)
return Dependencies{
+ server: server,
client: client,
config: config,
matcher: matcher,
@@ -112,3 +118,4 @@ func NewDependencies(tb testing.TB) Dependencies {
token: token,
}
}
+*/
diff --git a/internal/ticket/repository.go b/internal/ticket/repository.go
index 06e7eab..5e85b52 100644
--- a/internal/ticket/repository.go
+++ b/internal/ticket/repository.go
@@ -7,7 +7,7 @@ import (
)
type Repository interface {
- Create(ctx context.Context, ticket *domain.Ticket) error
+ Create(ctx context.Context, ticket domain.Ticket) error
GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error)
GC()
}
diff --git a/internal/ticket/repository/memory/memory_ticket.go b/internal/ticket/repository/memory/memory_ticket.go
index a7ad4c5..08e12d8 100644
--- a/internal/ticket/repository/memory/memory_ticket.go
+++ b/internal/ticket/repository/memory/memory_ticket.go
@@ -2,8 +2,6 @@ package memory
import (
"context"
- "fmt"
- "path"
"sync"
"time"
@@ -14,77 +12,75 @@ import (
type (
Ticket struct {
CreatedAt time.Time
- *domain.Ticket
+ domain.Ticket
}
memoryTicketRepository struct {
- config *domain.Config
- store *sync.Map
+ config domain.Config
+ mutex *sync.RWMutex
+ tickets map[string]Ticket
}
)
-const DefaultPathPrefix string = "tickets"
-
-func NewMemoryTicketRepository(store *sync.Map, config *domain.Config) ticket.Repository {
+func NewMemoryTicketRepository(config domain.Config) ticket.Repository {
return &memoryTicketRepository{
- config: config,
- store: store,
+ config: config,
+ mutex: new(sync.RWMutex),
+ tickets: make(map[string]Ticket),
}
}
-func (repo *memoryTicketRepository) Create(_ context.Context, t *domain.Ticket) error {
- repo.store.Store(path.Join(DefaultPathPrefix, t.Ticket), &Ticket{
+func (repo *memoryTicketRepository) Create(_ context.Context, t domain.Ticket) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
+
+ repo.tickets[t.Ticket] = Ticket{
CreatedAt: time.Now().UTC(),
Ticket: t,
- })
+ }
return nil
}
func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) {
- src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t))
+ repo.mutex.RLock()
+
+ out, ok := repo.tickets[t]
if !ok {
- return nil, fmt.Errorf("cannot find ticket in store: %w", ticket.ErrNotExist)
+ repo.mutex.RUnlock()
+
+ return nil, ticket.ErrNotExist
}
- result, ok := src.(*Ticket)
- if !ok {
- return nil, fmt.Errorf("cannot decode ticket in store: %w", ticket.ErrNotExist)
- }
+ repo.mutex.RUnlock()
+ repo.mutex.Lock()
+ delete(repo.tickets, t)
+ repo.mutex.Unlock()
- return result.Ticket, nil
+ return &out.Ticket, nil
}
func (repo *memoryTicketRepository) GC() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
- for timeStamp := range ticker.C {
- timeStamp := timeStamp.UTC()
+ for ts := range ticker.C {
+ ts := ts.UTC()
- repo.store.Range(func(key, value interface{}) bool {
- k, ok := key.(string)
- if !ok {
- return false
+ repo.mutex.RLock()
+
+ for _, t := range repo.tickets {
+ if t.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
+ continue
}
- matched, err := path.Match(DefaultPathPrefix+"/*", k)
- if err != nil || !matched {
- return false
- }
+ repo.mutex.RUnlock()
+ repo.mutex.Lock()
+ delete(repo.tickets, t.Ticket.Ticket)
+ repo.mutex.Unlock()
+ repo.mutex.RLock()
+ }
- val, ok := value.(*Ticket)
- if !ok {
- return false
- }
-
- if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) {
- return false
- }
-
- repo.store.Delete(key)
-
- return false
- })
+ repo.mutex.RUnlock()
}
}
diff --git a/internal/ticket/repository/memory/memory_ticket_test.go b/internal/ticket/repository/memory/memory_ticket_test.go
deleted file mode 100644
index 2f017f0..0000000
--- a/internal/ticket/repository/memory/memory_ticket_test.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package memory_test
-
-import (
- "context"
- "path"
- "reflect"
- "sync"
- "testing"
- "time"
-
- "source.toby3d.me/toby3d/auth/internal/domain"
- repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
-)
-
-func TestCreate(t *testing.T) {
- t.Parallel()
-
- store := new(sync.Map)
- ticket := domain.TestTicket(t)
-
- if err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)).
- Create(context.Background(), ticket); err != nil {
- t.Fatal(err)
- }
-
- storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket)
-
- src, ok := store.Load(storePath)
- if !ok {
- t.Fatalf("Load(%s) = %t, want %t", storePath, ok, true)
- }
-
- if result, _ := src.(*repository.Ticket); !reflect.DeepEqual(result.Ticket, ticket) {
- t.Errorf("Create(%+v) = %+v, want %+v", ticket, result.Ticket, ticket)
- }
-}
-
-func TestGetAndDelete(t *testing.T) {
- t.Parallel()
-
- ticket := domain.TestTicket(t)
-
- store := new(sync.Map)
- store.Store(path.Join(repository.DefaultPathPrefix, ticket.Ticket), &repository.Ticket{
- CreatedAt: time.Now().UTC(),
- Ticket: ticket,
- })
-
- result, err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)).
- GetAndDelete(context.Background(), ticket.Ticket)
- if err != nil {
- t.Fatal(err)
- }
-
- if !reflect.DeepEqual(result, ticket) {
- t.Errorf("GetAndDelete(%s) = %+v, want %+v", ticket.Ticket, result, ticket)
- }
-
- storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket)
- if src, _ := store.Load(storePath); src != nil {
- t.Errorf("Load(%s) = %+v, want %+v", storePath, src, nil)
- }
-}
diff --git a/internal/ticket/repository/sqlite3/sqlite3_ticket.go b/internal/ticket/repository/sqlite3/sqlite3_ticket.go
index 1c6b615..7e823b4 100644
--- a/internal/ticket/repository/sqlite3/sqlite3_ticket.go
+++ b/internal/ticket/repository/sqlite3/sqlite3_ticket.go
@@ -56,8 +56,8 @@ func NewSQLite3TicketRepository(db *sqlx.DB, config *domain.Config) ticket.Repos
}
}
-func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticket) error {
- if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(t)); err != nil {
+func (repo *sqlite3TicketRepository) Create(ctx context.Context, t domain.Ticket) error {
+ if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(&t)); err != nil {
return fmt.Errorf("cannot create token record in db: %w", err)
}
diff --git a/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go b/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go
index 55307ea..7dee99e 100644
--- a/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go
+++ b/internal/ticket/repository/sqlite3/sqlite3_ticket_test.go
@@ -12,7 +12,7 @@ import (
repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/sqlite3"
)
-// nolint: gochecknoglobals // slices cannot be contants
+//nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "resource", "subject", "ticket"}
func TestCreate(t *testing.T) {
@@ -34,7 +34,7 @@ func TestCreate(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3TicketRepository(db, domain.TestConfig(t)).
- Create(context.Background(), ticket); err != nil {
+ Create(context.Background(), *ticket); err != nil {
t.Error(err)
}
}
diff --git a/internal/ticket/usecase.go b/internal/ticket/usecase.go
index c49f807..7687bae 100644
--- a/internal/ticket/usecase.go
+++ b/internal/ticket/usecase.go
@@ -7,10 +7,10 @@ import (
)
type UseCase interface {
- Generate(ctx context.Context, ticket *domain.Ticket) error
+ Generate(ctx context.Context, ticket domain.Ticket) error
// Redeem transform received ticket into access token.
- Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error)
+ Redeem(ctx context.Context, ticket domain.Ticket) (*domain.Token, error)
Exchange(ctx context.Context, ticket string) (*domain.Token, error)
}
diff --git a/internal/ticket/usecase/ticket_ucase.go b/internal/ticket/usecase/ticket_ucase.go
index b285525..7439369 100644
--- a/internal/ticket/usecase/ticket_ucase.go
+++ b/internal/ticket/usecase/ticket_ucase.go
@@ -1,13 +1,15 @@
package usecase
import (
+ "bytes"
"context"
"fmt"
+ "io"
+ "net/http"
"net/url"
"time"
json "github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -47,26 +49,28 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do
}
}
-func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error {
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.Header.SetMethod(http.MethodGet)
- req.SetRequestURI(tkt.Subject.String())
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := useCase.client.Do(req, resp); err != nil {
+func (useCase *ticketUseCase) Generate(ctx context.Context, tkt domain.Ticket) error {
+ resp, err := useCase.client.Get(tkt.Subject.String())
+ if err != nil {
return fmt.Errorf("cannot discovery ticket subject: %w", err)
}
- var ticketEndpoint *url.URL
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("cannot read response body: %w", err)
+ }
+
+ buf := bytes.NewReader(body)
+ ticketEndpoint := new(url.URL)
// NOTE(toby3d): find metadata first
- if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil {
+ metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Subject.String())
+ if err == nil && metadata != nil {
ticketEndpoint = metadata.TicketEndpoint
} else { // NOTE(toby3d): fallback to old links searching
- if endpoints := httputil.ExtractEndpoints(resp, "ticket_endpoint"); len(endpoints) > 0 {
+ endpoints := httputil.ExtractEndpoints(buf, tkt.Subject.URL(), resp.Header.Get(common.HeaderLink),
+ "ticket_endpoint")
+ if len(endpoints) > 0 {
ticketEndpoint = endpoints[len(endpoints)-1]
}
}
@@ -79,65 +83,59 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket)
return fmt.Errorf("cannot save ticket in store: %w", err)
}
- req.Reset()
- req.Header.SetMethod(http.MethodPost)
- req.SetRequestURI(ticketEndpoint.String())
- req.Header.SetContentType(common.MIMEApplicationForm)
- req.PostArgs().Set("ticket", tkt.Ticket)
- req.PostArgs().Set("subject", tkt.Subject.String())
- req.PostArgs().Set("resource", tkt.Resource.String())
- resp.Reset()
+ payload := make(url.Values)
+ payload.Set("ticket", tkt.Ticket)
+ payload.Set("subject", tkt.Subject.String())
+ payload.Set("resource", tkt.Resource.String())
- if err := useCase.client.Do(req, resp); err != nil {
+ if _, err = useCase.client.PostForm(ticketEndpoint.String(), payload); err != nil {
return fmt.Errorf("cannot send ticket to subject ticket_endpoint: %w", err)
}
return nil
}
-func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) {
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.SetRequestURI(tkt.Resource.String())
- req.Header.SetMethod(http.MethodGet)
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := useCase.client.Do(req, resp); err != nil {
+func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt domain.Ticket) (*domain.Token, error) {
+ resp, err := useCase.client.Get(tkt.Resource.String())
+ if err != nil {
return nil, fmt.Errorf("cannot discovery ticket resource: %w", err)
}
- var tokenEndpoint *url.URL
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("cannot read response body: %w", err)
+ }
+
+ buf := bytes.NewReader(body)
+ tokenEndpoint := new(url.URL)
// NOTE(toby3d): find metadata first
- if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil {
+ metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Resource.String())
+ if err == nil && metadata != nil {
tokenEndpoint = metadata.TokenEndpoint
} else { // NOTE(toby3d): fallback to old links searching
- if endpoints := httputil.ExtractEndpoints(resp, "token_endpoint"); len(endpoints) > 0 {
+ endpoints := httputil.ExtractEndpoints(buf, tkt.Resource, resp.Header.Get(common.HeaderLink),
+ "token_endpoint")
+ if len(endpoints) > 0 {
tokenEndpoint = endpoints[len(endpoints)-1]
}
}
- if tokenEndpoint == nil {
+ if tokenEndpoint == nil || tokenEndpoint.String() == "" {
return nil, ticket.ErrTokenEndpointNotExist
}
- req.Reset()
- req.Header.SetMethod(http.MethodPost)
- req.SetRequestURI(tokenEndpoint.String())
- req.Header.SetContentType(common.MIMEApplicationForm)
- req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
- req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
- req.PostArgs().Set("ticket", tkt.Ticket)
- resp.Reset()
+ payload := make(url.Values)
+ payload.Set("grant_type", domain.GrantTypeTicket.String())
+ payload.Set("ticket", tkt.Ticket)
- if err := useCase.client.Do(req, resp); err != nil {
+ resp, err = useCase.client.PostForm(tokenEndpoint.String(), payload)
+ if err != nil {
return nil, fmt.Errorf("cannot exchange ticket on token_endpoint: %w", err)
}
data := new(AccessToken)
- if err := json.Unmarshal(resp.Body(), data); err != nil {
+ if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, fmt.Errorf("cannot unmarshal access token response: %w", err)
}
@@ -147,8 +145,8 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*
Scope: nil, // TODO(toby3d)
// TODO(toby3d): should this also include client_id?
// https://github.com/indieweb/indieauth/issues/85
- ClientID: nil,
- Me: data.Me,
+ ClientID: domain.ClientID{},
+ Me: *data.Me,
AccessToken: data.AccessToken,
RefreshToken: "", // TODO(toby3d)
}, nil
@@ -163,8 +161,8 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*dom
token, err := domain.NewToken(domain.NewTokenOptions{
Expiration: useCase.config.JWT.Expiry,
Scope: domain.Scopes{domain.ScopeRead},
- Issuer: nil,
- Subject: tkt.Subject,
+ Issuer: domain.ClientID{},
+ Subject: *tkt.Subject,
Secret: []byte(useCase.config.JWT.Secret),
Algorithm: useCase.config.JWT.Algorithm,
NonceLength: useCase.config.JWT.NonceLength,
diff --git a/internal/ticket/usecase/ticket_ucase_test.go b/internal/ticket/usecase/ticket_ucase_test.go
index 73199a8..a67bc9b 100644
--- a/internal/ticket/usecase/ticket_ucase_test.go
+++ b/internal/ticket/usecase/ticket_ucase_test.go
@@ -3,14 +3,13 @@ package usecase_test
import (
"context"
"fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
"testing"
- "github.com/fasthttp/router"
- http "github.com/valyala/fasthttp"
-
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
ucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase"
)
@@ -20,25 +19,33 @@ func TestRedeem(t *testing.T) {
token := domain.TestToken(t)
ticket := domain.TestTicket(t)
- router := router.New()
- router.GET(string(ticket.Resource.Path), func(ctx *http.RequestCtx) {
- ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, ``)
- })
- router.POST("/token", func(ctx *http.RequestCtx) {
- ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{
+ tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+ fmt.Fprintf(w, `{
"token_type": "Bearer",
"access_token": "%s",
"scope": "%s",
"me": "%s"
- }`, token.AccessToken, token.Scope.String(), token.Me.String()))
- })
+ }`, token.AccessToken, token.Scope.String(), token.Me.String())
+ }))
+ t.Cleanup(tokenServer.Close)
- client, _, cleanup := httptest.New(t, router.Handler)
- t.Cleanup(cleanup)
+ subjectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ fmt.Fprint(w, ``)
+ }))
+ t.Cleanup(subjectServer.Close)
- result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)).
- Redeem(context.Background(), ticket)
+ ticket.Resource, _ = url.Parse(subjectServer.URL + "/")
+
+ result, err := ucase.NewTicketUseCase(nil, subjectServer.Client(), domain.TestConfig(t)).
+ Redeem(context.Background(), *ticket)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/token/delivery/http/token_http.go b/internal/token/delivery/http/token_http.go
index 6b94433..1936f78 100644
--- a/internal/token/delivery/http/token_http.go
+++ b/internal/token/delivery/http/token_http.go
@@ -1,187 +1,100 @@
package http
import (
- "errors"
- "path"
+ "net/http"
- "github.com/fasthttp/router"
- json "github.com/goccy/go-json"
+ "github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa"
- http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/ticket"
"source.toby3d.me/toby3d/auth/internal/token"
- "source.toby3d.me/toby3d/form"
- "source.toby3d.me/toby3d/middleware"
+ "source.toby3d.me/toby3d/auth/internal/urlutil"
)
-type (
- TokenExchangeRequest struct {
- ClientID *domain.ClientID `form:"client_id"`
- RedirectURI *domain.URL `form:"redirect_uri"`
- GrantType domain.GrantType `form:"grant_type"`
- Code string `form:"code"`
- CodeVerifier string `form:"code_verifier"`
- }
+type Handler struct {
+ config *domain.Config
+ tokens token.UseCase
+ tickets ticket.UseCase
+}
- TokenRefreshRequest struct {
- GrantType domain.GrantType `form:"grant_type"` // refresh_token
-
- // The refresh token previously offered to the client.
- RefreshToken string `form:"refresh_token"`
-
- // The client ID that was used when the refresh token was issued.
- ClientID *domain.ClientID `form:"client_id"`
-
- // The client may request a token with the same or fewer scopes
- // than the original access token. If omitted, is treated as
- // equal to the original scopes granted.
- Scope domain.Scopes `form:"scope"`
- }
-
- TokenRevocationRequest struct {
- Action domain.Action `form:"action,omitempty"`
- Token string `form:"token"`
- }
-
- TokenTicketRequest struct {
- Action domain.Action `form:"action"`
- Ticket string `form:"ticket"`
- }
-
- TokenIntrospectRequest struct {
- Token string `form:"token"`
- }
-
- //nolint:tagliatelle // https://indieauth.net/source/#access-token-response
- TokenExchangeResponse struct {
- // The OAuth 2.0 Bearer Token RFC6750.
- AccessToken string `json:"access_token"`
-
- // The canonical user profile URL for the user this access token
- // corresponds to.
- Me string `json:"me"`
-
- // The user's profile information.
- Profile *TokenProfileResponse `json:"profile,omitempty"`
-
- // The lifetime in seconds of the access token.
- ExpiresIn int64 `json:"expires_in,omitempty"`
-
- // The refresh token, which can be used to obtain new access
- // tokens.
- RefreshToken string `json:"refresh_token"`
- }
-
- TokenProfileResponse struct {
- // Name the user wishes to provide to the client.
- Name string `json:"name,omitempty"`
-
- // URL of the user's website.
- URL string `json:"url,omitempty"`
-
- // A photo or image that the user wishes clients to use as a
- // profile image.
- Photo string `json:"photo,omitempty"`
-
- // The email address a user wishes to provide to the client.
- Email string `json:"email,omitempty"`
- }
-
- //nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response
- TokenIntrospectResponse struct {
- // Boolean indicator of whether or not the presented token is
- // currently active.
- Active bool `json:"active"`
-
- // The profile URL of the user corresponding to this token.
- Me string `json:"me"`
-
- // The client ID associated with this token.
- ClientID string `json:"client_id"`
-
- // A space-separated list of scopes associated with this token.
- Scope string `json:"scope"`
-
- // Integer timestamp, measured in the number of seconds since
- // January 1 1970 UTC, indicating when this token will expire.
- Exp int64 `json:"exp,omitempty"`
-
- // Integer timestamp, measured in the number of seconds since
- // January 1 1970 UTC, indicating when this token was originally
- // issued.
- Iat int64 `json:"iat,omitempty"`
- }
-
- TokenInvalidIntrospectResponse struct {
- Active bool `json:"active"`
- }
-
- TokenRevocationResponse struct{}
-
- RequestHandler struct {
- config *domain.Config
- tokens token.UseCase
- tickets ticket.UseCase
- }
-)
-
-func NewRequestHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *RequestHandler {
- return &RequestHandler{
+func NewHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *Handler {
+ return &Handler{
config: config,
tokens: tokens,
tickets: tickets,
}
}
-func (h *RequestHandler) Register(r *router.Router) {
+func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{
//nolint:exhaustivestruct
middleware.JWTWithConfig(middleware.JWTConfig{
- AuthScheme: "Bearer",
- ContextKey: "token",
+ Skipper: func(_ http.ResponseWriter, r *http.Request) bool {
+ head, _ := urlutil.ShiftPath(r.URL.Path)
+
+ return head == "token"
+ },
SigningKey: []byte(h.config.JWT.Secret),
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
- Skipper: func(ctx *http.RequestCtx) bool {
- matched, _ := path.Match("/token*", string(ctx.Path()))
-
- return matched
- },
- SuccessHandler: nil,
- TokenLookup: "param:token,header:" + http.HeaderAuthorization + ":Bearer ",
+ ContextKey: "token",
+ TokenLookup: "form:token," + "header:" + common.HeaderAuthorization + ":Bearer ",
+ AuthScheme: "Bearer",
}),
- middleware.LogFmt(),
}
- r.POST("/token", chain.RequestHandler(h.handleAction))
- r.POST("/introspect", chain.RequestHandler(h.handleIntrospect))
- r.POST("/revocation", chain.RequestHandler(h.handleRevokation))
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
+ var head string
+ head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
+
+ switch head {
+ default:
+ http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+ case "token":
+ chain.Handler(h.handleAction).ServeHTTP(w, r)
+ case "introspect":
+ chain.Handler(h.handleIntrospect).ServeHTTP(w, r)
+ case "revocation":
+ chain.Handler(h.handleRevokation).ServeHTTP(w, r)
+ }
+ })
}
-func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
- ctx.SetStatusCode(http.StatusOK)
-
- encoder := json.NewEncoder(ctx)
-
- req := new(TokenIntrospectRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
-
- _ = encoder.Encode(err)
+func (h *Handler) handleIntrospect(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
- tkn, _, err := h.tokens.Verify(ctx, req.Token)
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
+
+ req := new(TokenIntrospectRequest)
+ if err := req.bind(r); err != nil {
+ _ = encoder.Encode(err)
+
+ w.WriteHeader(http.StatusBadRequest)
+
+ return
+ }
+
+ tkn, _, err := h.tokens.Verify(r.Context(), req.Token)
if err != nil || tkn == nil {
// WARN(toby3d): If the token is not valid, the endpoint still
// MUST return a 200 Response.
- _ = encoder.Encode(&TokenInvalidIntrospectResponse{
- Active: false,
- })
+ _ = encoder.Encode(&TokenInvalidIntrospectResponse{Active: false})
+
+ w.WriteHeader(http.StatusOK)
return
}
@@ -194,68 +107,83 @@ func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) {
Me: tkn.Me.String(),
Scope: tkn.Scope.String(),
})
+
+ w.WriteHeader(http.StatusOK)
}
-func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
+func (h *Handler) handleAction(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
switch {
- case ctx.PostArgs().Has("grant_type"):
- h.handleExchange(ctx)
- case ctx.PostArgs().Has("action"):
- action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
- if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ case r.PostForm.Has("grant_type"):
+ h.handleExchange(w, r)
+ case r.PostForm.Has("action"):
+ if err := r.ParseForm(); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
- _ = encoder.Encode(domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "",
- ))
+ _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
+
+ return
+ }
+
+ action, err := domain.ParseAction(r.PostForm.Get("action"))
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+
+ _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
return
}
switch action {
case domain.ActionRevoke:
- h.handleRevokation(ctx)
+ h.handleRevokation(w, r)
case domain.ActionTicket:
- h.handleTicket(ctx)
+ h.handleTicket(w, r)
}
}
}
//nolint:funlen
-func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
+func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
req := new(TokenExchangeRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
- token, profile, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
+ token, profile, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{
ClientID: req.ClientID,
RedirectURI: req.RedirectURI.URL,
Code: req.Code,
CodeVerifier: req.CodeVerifier,
})
if err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ w.WriteHeader(http.StatusBadRequest)
- _ = encoder.Encode(domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#request",
- ))
+ _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(),
+ "https://indieauth.net/source/#request"))
return
}
@@ -294,62 +222,69 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
}
_ = encoder.Encode(resp)
+
+ w.WriteHeader(http.StatusOK)
}
-func (h *RequestHandler) handleRevokation(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
- ctx.SetStatusCode(http.StatusOK)
+func (h *Handler) handleRevokation(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
req := NewTokenRevocationRequest()
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
- if err := h.tokens.Revoke(ctx, req.Token); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := h.tokens.Revoke(r.Context(), req.Token); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
- _ = encoder.Encode(domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "",
- ))
+ _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
return
}
_ = encoder.Encode(&TokenRevocationResponse{})
+
+ w.WriteHeader(http.StatusOK)
}
-func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
- ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
- ctx.SetStatusCode(http.StatusOK)
+func (h *Handler) handleTicket(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
- encoder := json.NewEncoder(ctx)
+ return
+ }
+
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+
+ encoder := json.NewEncoder(w)
req := new(TokenTicketRequest)
- if err := req.bind(ctx); err != nil {
- ctx.SetStatusCode(http.StatusBadRequest)
+ if err := req.bind(r); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
- tkn, err := h.tickets.Exchange(ctx, req.Ticket)
+ tkn, err := h.tickets.Exchange(r.Context(), req.Ticket)
if err != nil {
- ctx.SetStatusCode(http.StatusInternalServerError)
+ w.WriteHeader(http.StatusInternalServerError)
- _ = encoder.Encode(domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#request",
- ))
+ _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(),
+ "https://indieauth.net/source/#request"))
return
}
@@ -361,81 +296,6 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
ExpiresIn: tkn.Expiry.Unix(),
RefreshToken: "", // TODO(toby3d)
})
-}
-
-func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#request",
- )
- }
-
- return nil
-}
-
-func NewTokenRevocationRequest() *TokenRevocationRequest {
- return &TokenRevocationRequest{
- Action: domain.ActionRevoke,
- Token: "",
- }
-}
-
-func (r *TokenRevocationRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
-
- err := form.Unmarshal(ctx.PostArgs().QueryString(), r)
- if err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#request",
- )
- }
-
- return nil
-}
-
-func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#request",
- )
- }
-
- return nil
-}
-
-func (r *TokenIntrospectRequest) bind(ctx *http.RequestCtx) error {
- indieAuthError := new(domain.Error)
- if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
- if errors.As(err, indieAuthError) {
- return indieAuthError
- }
-
- return domain.NewError(
- domain.ErrorCodeInvalidRequest,
- err.Error(),
- "https://indieauth.net/source/#access-token-verification-request",
- )
- }
-
- return nil
+
+ w.WriteHeader(http.StatusOK)
}
diff --git a/internal/token/delivery/http/token_http_schema.go b/internal/token/delivery/http/token_http_schema.go
new file mode 100644
index 0000000..da7b872
--- /dev/null
+++ b/internal/token/delivery/http/token_http_schema.go
@@ -0,0 +1,210 @@
+package http
+
+import (
+ "errors"
+ "net/http"
+
+ "source.toby3d.me/toby3d/auth/internal/domain"
+ "source.toby3d.me/toby3d/form"
+)
+
+type (
+ TokenExchangeRequest struct {
+ ClientID domain.ClientID `form:"client_id"`
+ RedirectURI domain.URL `form:"redirect_uri"`
+ GrantType domain.GrantType `form:"grant_type"`
+ Code string `form:"code"`
+ CodeVerifier string `form:"code_verifier"`
+ }
+
+ TokenRefreshRequest struct {
+ GrantType domain.GrantType `form:"grant_type"` // refresh_token
+
+ // The client ID that was used when the refresh token was issued.
+ ClientID domain.ClientID `form:"client_id"`
+
+ // The client may request a token with the same or fewer scopes
+ // than the original access token. If omitted, is treated as
+ // equal to the original scopes granted.
+ Scope domain.Scopes `form:"scope"`
+
+ // The refresh token previously offered to the client.
+ RefreshToken string `form:"refresh_token"`
+ }
+
+ TokenRevocationRequest struct {
+ Action domain.Action `form:"action,omitempty"`
+ Token string `form:"token"`
+ }
+
+ TokenTicketRequest struct {
+ Action domain.Action `form:"action"`
+ Ticket string `form:"ticket"`
+ }
+
+ TokenIntrospectRequest struct {
+ Token string `form:"token"`
+ }
+
+ //nolint:tagliatelle // https://indieauth.net/source/#access-token-response
+ TokenExchangeResponse struct {
+ // The user's profile information.
+ Profile *TokenProfileResponse `json:"profile,omitempty"`
+
+ // The OAuth 2.0 Bearer Token RFC6750.
+ AccessToken string `json:"access_token"`
+
+ // The canonical user profile URL for the user this access token
+ // corresponds to.
+ Me string `json:"me"`
+
+ // The refresh token, which can be used to obtain new access
+ // tokens.
+ RefreshToken string `json:"refresh_token"`
+
+ // The lifetime in seconds of the access token.
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ }
+
+ TokenProfileResponse struct {
+ // Name the user wishes to provide to the client.
+ Name string `json:"name,omitempty"`
+
+ // URL of the user's website.
+ URL string `json:"url,omitempty"`
+
+ // A photo or image that the user wishes clients to use as a
+ // profile image.
+ Photo string `json:"photo,omitempty"`
+
+ // The email address a user wishes to provide to the client.
+ Email string `json:"email,omitempty"`
+ }
+
+ //nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response
+ TokenIntrospectResponse struct {
+ // The profile URL of the user corresponding to this token.
+ Me string `json:"me"`
+
+ // The client ID associated with this token.
+ ClientID string `json:"client_id"`
+
+ // A space-separated list of scopes associated with this token.
+ Scope string `json:"scope"`
+
+ // Integer timestamp, measured in the number of seconds since
+ // January 1 1970 UTC, indicating when this token will expire.
+ Exp int64 `json:"exp,omitempty"`
+
+ // Integer timestamp, measured in the number of seconds since
+ // January 1 1970 UTC, indicating when this token was originally
+ // issued.
+ Iat int64 `json:"iat,omitempty"`
+
+ // Boolean indicator of whether or not the presented token is
+ // currently active.
+ Active bool `json:"active"`
+ }
+
+ TokenInvalidIntrospectResponse struct {
+ Active bool `json:"active"`
+ }
+
+ TokenRevocationResponse struct{}
+)
+
+func (r *TokenExchangeRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#request",
+ )
+ }
+
+ return nil
+}
+
+func NewTokenRevocationRequest() *TokenRevocationRequest {
+ return &TokenRevocationRequest{
+ Action: domain.ActionRevoke,
+ Token: "",
+ }
+}
+
+func (r *TokenRevocationRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ err := form.Unmarshal([]byte(req.PostForm.Encode()), r)
+ if err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#request",
+ )
+ }
+
+ return nil
+}
+
+func (r *TokenTicketRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#request",
+ )
+ }
+
+ return nil
+}
+
+func (r *TokenIntrospectRequest) bind(req *http.Request) error {
+ indieAuthError := new(domain.Error)
+
+ if err := req.ParseForm(); err != nil {
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#authorization-request",
+ )
+ }
+
+ if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
+ if errors.As(err, indieAuthError) {
+ return indieAuthError
+ }
+
+ return domain.NewError(
+ domain.ErrorCodeInvalidRequest,
+ err.Error(),
+ "https://indieauth.net/source/#access-token-verification-request",
+ )
+ }
+
+ return nil
+}
diff --git a/internal/token/delivery/http/token_http_test.go b/internal/token/delivery/http/token_http_test.go
index e32b5f3..d31a8c6 100644
--- a/internal/token/delivery/http/token_http_test.go
+++ b/internal/token/delivery/http/token_http_test.go
@@ -3,12 +3,13 @@ package http_test
import (
"bytes"
"context"
- "sync"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
- "github.com/fasthttp/router"
- json "github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
+ "github.com/goccy/go-json"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -16,7 +17,6 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/ticket"
ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
ticketucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase"
@@ -31,7 +31,6 @@ type Dependencies struct {
config *domain.Config
profiles profile.Repository
sessions session.Repository
- store *sync.Map
tickets ticket.Repository
ticketService ticket.UseCase
token *domain.Token
@@ -50,32 +49,24 @@ func TestIntrospection(t *testing.T) {
deps := NewDependencies(t)
- r := router.New()
- delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r)
+ req := httptest.NewRequest(http.MethodPost, "https://app.example.com/introspect",
+ strings.NewReader("token="+deps.token.AccessToken))
+ req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON)
+ req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
- client, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
+ w := httptest.NewRecorder()
+ delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config).
+ Handler().
+ ServeHTTP(w, req)
- const requestURL = "https://app.example.com/introspect"
+ resp := w.Result()
- req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken))
- defer http.ReleaseRequest(req)
- req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
- req.Header.SetContentType(common.MIMEApplicationForm)
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := client.Do(req, resp); err != nil {
- t.Fatal(err)
- }
-
- if result := resp.StatusCode(); result != http.StatusOK {
- t.Errorf("GET %s = %d, want %d", requestURL, result, http.StatusOK)
+ if result := resp.StatusCode; result != http.StatusOK {
+ t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, result, http.StatusOK)
}
result := new(delivery.TokenIntrospectResponse)
- if err := json.Unmarshal(resp.Body(), result); err != nil {
+ if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
t.Fatal(err)
}
@@ -84,7 +75,7 @@ func TestIntrospection(t *testing.T) {
if result.ClientID != deps.token.ClientID.String() ||
result.Me != deps.token.Me.String() ||
result.Scope != deps.token.Scope.String() {
- t.Errorf("GET %s = %+v, want %+v", requestURL, result, deps.token)
+ t.Errorf("%s %s = %+v, want %+v", req.Method, req.RequestURI, result, deps.token)
}
}
@@ -93,33 +84,30 @@ func TestRevocation(t *testing.T) {
deps := NewDependencies(t)
- r := router.New()
- delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r)
+ req := httptest.NewRequest(http.MethodPost, "https://app.example.com/revocation",
+ strings.NewReader(`token=`+deps.token.AccessToken))
+ req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
+ req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON)
- client, _, cleanup := httptest.New(t, r.Handler)
- t.Cleanup(cleanup)
+ w := httptest.NewRecorder()
+ delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config).
+ Handler().
+ ServeHTTP(w, req)
- const requestURL = "https://app.example.com/revocation"
+ resp := w.Result()
- req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken))
- defer http.ReleaseRequest(req)
- req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
- req.Header.SetContentType(common.MIMEApplicationForm)
-
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := client.Do(req, resp); err != nil {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
t.Fatal(err)
}
- if result := resp.StatusCode(); result != http.StatusOK {
- t.Errorf("POST %s = %d, want %d", requestURL, result, http.StatusOK)
+ if resp.StatusCode != http.StatusOK {
+ t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
}
expBody := []byte("{}") //nolint:ifshort
- if result := bytes.TrimSpace(resp.Body()); !bytes.Equal(result, expBody) {
- t.Errorf("POST %s = %s, want %s", requestURL, result, expBody)
+ if result := bytes.TrimSpace(body); !bytes.Equal(result, expBody) {
+ t.Errorf("%s %s = %s, want %s", req.Method, req.RequestURI, result, expBody)
}
result, err := deps.tokens.Get(context.Background(), deps.token.AccessToken)
@@ -135,14 +123,13 @@ func TestRevocation(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies {
tb.Helper()
- store := new(sync.Map)
client := new(http.Client)
config := domain.TestConfig(tb)
token := domain.TestToken(tb)
- profiles := profilerepo.NewMemoryProfileRepository(store)
- sessions := sessionrepo.NewMemorySessionRepository(store, config)
- tickets := ticketrepo.NewMemoryTicketRepository(store, config)
- tokens := tokenrepo.NewMemoryTokenRepository(store)
+ profiles := profilerepo.NewMemoryProfileRepository()
+ sessions := sessionrepo.NewMemorySessionRepository(*config)
+ tickets := ticketrepo.NewMemoryTicketRepository(*config)
+ tokens := tokenrepo.NewMemoryTokenRepository()
ticketService := ticketucase.NewTicketUseCase(tickets, client, config)
tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config,
@@ -156,7 +143,6 @@ func NewDependencies(tb testing.TB) Dependencies {
config: config,
profiles: profiles,
sessions: sessions,
- store: store,
tickets: tickets,
ticketService: ticketService,
token: token,
diff --git a/internal/token/repository.go b/internal/token/repository.go
index 65b6896..1240bce 100644
--- a/internal/token/repository.go
+++ b/internal/token/repository.go
@@ -7,8 +7,8 @@ import (
)
type Repository interface {
+ Create(ctx context.Context, accessToken domain.Token) error
Get(ctx context.Context, accessToken string) (*domain.Token, error)
- Create(ctx context.Context, accessToken *domain.Token) error
}
var (
diff --git a/internal/token/repository/memory/memory_token.go b/internal/token/repository/memory/memory_token.go
index c858074..49aa1d8 100644
--- a/internal/token/repository/memory/memory_token.go
+++ b/internal/token/repository/memory/memory_token.go
@@ -2,8 +2,6 @@ package memory
import (
"context"
- "errors"
- "path"
"sync"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -11,42 +9,33 @@ import (
)
type memoryTokenRepository struct {
- store *sync.Map
+ mutex *sync.RWMutex
+ tokens map[string]domain.Token
}
-const DefaultPathPrefix string = "tokens"
-
-func NewMemoryTokenRepository(store *sync.Map) token.Repository {
+func NewMemoryTokenRepository() token.Repository {
return &memoryTokenRepository{
- store: store,
+ mutex: new(sync.RWMutex),
+ tokens: make(map[string]domain.Token),
}
}
-func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
- t, err := repo.Get(ctx, accessToken.AccessToken)
- if err != nil && !errors.Is(err, token.ErrNotExist) {
- return err
- }
+func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken domain.Token) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
- if t != nil {
- return token.ErrExist
- }
-
- repo.store.Store(path.Join(DefaultPathPrefix, accessToken.AccessToken), accessToken)
+ repo.tokens[accessToken.AccessToken] = accessToken
return nil
}
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
- t, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken))
- if !ok {
- return nil, token.ErrNotExist
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
+
+ if t, ok := repo.tokens[accessToken]; ok {
+ return &t, nil
}
- result, ok := t.(*domain.Token)
- if !ok {
- return nil, token.ErrNotExist
- }
-
- return result, nil
+ return nil, token.ErrNotExist
}
diff --git a/internal/token/repository/memory/memory_token_test.go b/internal/token/repository/memory/memory_token_test.go
deleted file mode 100644
index 991f993..0000000
--- a/internal/token/repository/memory/memory_token_test.go
+++ /dev/null
@@ -1,45 +0,0 @@
-package memory_test
-
-import (
- "context"
- "path"
- "sync"
- "testing"
-
- "github.com/stretchr/testify/assert"
-
- "source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/token"
- repository "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
-)
-
-func TestCreate(t *testing.T) {
- t.Parallel()
-
- store := new(sync.Map)
- accessToken := domain.TestToken(t)
-
- repo := repository.NewMemoryTokenRepository(store)
- if err := repo.Create(context.Background(), accessToken); err != nil {
- t.Fatal(err)
- }
-
- result, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken))
- assert.True(t, ok)
- assert.Equal(t, accessToken, result)
-
- assert.ErrorIs(t, repo.Create(context.Background(), accessToken), token.ErrExist)
-}
-
-func TestGet(t *testing.T) {
- t.Parallel()
-
- store := new(sync.Map)
- accessToken := domain.TestToken(t)
-
- store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken)
-
- result, err := repository.NewMemoryTokenRepository(store).Get(context.Background(), accessToken.AccessToken)
- assert.NoError(t, err)
- assert.Equal(t, accessToken, result)
-}
diff --git a/internal/token/repository/sqlite3/sqlite3_token.go b/internal/token/repository/sqlite3/sqlite3_token.go
index 28d7c5b..25b268b 100644
--- a/internal/token/repository/sqlite3/sqlite3_token.go
+++ b/internal/token/repository/sqlite3/sqlite3_token.go
@@ -53,8 +53,8 @@ func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository {
}
}
-func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error {
- if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(accessToken)); err != nil {
+func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken domain.Token) error {
+ if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(&accessToken)); err != nil {
return fmt.Errorf("cannot create token record in db: %w", err)
}
@@ -91,9 +91,11 @@ func NewToken(src *domain.Token) *Token {
}
func (t *Token) Populate(dst *domain.Token) {
+ cid, _ := domain.ParseClientID(t.ClientID)
+ me, _ := domain.ParseMe(t.Me)
dst.AccessToken = t.AccessToken
- dst.ClientID, _ = domain.ParseClientID(t.ClientID)
- dst.Me, _ = domain.ParseMe(t.Me)
+ dst.ClientID = *cid
+ dst.Me = *me
dst.Scope = make(domain.Scopes, 0)
for _, scope := range strings.Fields(t.Scope) {
diff --git a/internal/token/repository/sqlite3/sqlite3_token_test.go b/internal/token/repository/sqlite3/sqlite3_token_test.go
index 04c216d..1f181cc 100644
--- a/internal/token/repository/sqlite3/sqlite3_token_test.go
+++ b/internal/token/repository/sqlite3/sqlite3_token_test.go
@@ -35,7 +35,7 @@ func TestCreate(t *testing.T) {
).
WillReturnResult(sqlmock.NewResult(1, 1))
- if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), token); err != nil {
+ if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), *token); err != nil {
t.Error(err)
}
}
diff --git a/internal/token/usecase.go b/internal/token/usecase.go
index 818514f..e6c5cea 100644
--- a/internal/token/usecase.go
+++ b/internal/token/usecase.go
@@ -9,7 +9,7 @@ import (
type (
ExchangeOptions struct {
- ClientID *domain.ClientID
+ ClientID domain.ClientID
RedirectURI *url.URL
Code string
CodeVerifier string
diff --git a/internal/token/usecase/token_ucase.go b/internal/token/usecase/token_ucase.go
index 1b49423..e73f087 100644
--- a/internal/token/usecase/token_ucase.go
+++ b/internal/token/usecase/token_ucase.go
@@ -107,17 +107,17 @@ func (uc *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain
return nil, nil, fmt.Errorf("cannot validate JWT token: %w", err)
}
+ cid, _ := domain.ParseClientID(tkn.Issuer())
+ me, _ := domain.ParseMe(tkn.Subject())
result := &domain.Token{
CreatedAt: tkn.IssuedAt(),
Expiry: tkn.Expiration(),
- ClientID: nil,
- Me: nil,
+ ClientID: *cid,
+ Me: *me,
Scope: nil,
AccessToken: accessToken,
RefreshToken: "", // TODO(toby3d)
}
- result.ClientID, _ = domain.ParseClientID(tkn.Issuer())
- result.Me, _ = domain.ParseMe(tkn.Subject())
if scope, ok := tkn.Get("scope"); ok {
result.Scope, _ = scope.(domain.Scopes)
@@ -149,7 +149,7 @@ func (uc *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
return fmt.Errorf("cannot verify token: %w", err)
}
- if err = uc.tokens.Create(ctx, tkn); err != nil && !errors.Is(err, token.ErrExist) {
+ if err = uc.tokens.Create(ctx, *tkn); err != nil && !errors.Is(err, token.ErrExist) {
return fmt.Errorf("cannot save token in database: %w", err)
}
diff --git a/internal/token/usecase/token_ucase_test.go b/internal/token/usecase/token_ucase_test.go
index 47e8b4e..63bfb2c 100644
--- a/internal/token/usecase/token_ucase_test.go
+++ b/internal/token/usecase/token_ucase_test.go
@@ -2,8 +2,6 @@ package usecase_test
import (
"context"
- "path"
- "sync"
"testing"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -22,7 +20,6 @@ type Dependencies struct {
profiles profile.Repository
session *domain.Session
sessions session.Repository
- store *sync.Map
token *domain.Token
tokens token.Repository
}
@@ -31,9 +28,12 @@ func TestExchange(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
- deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.session.Me.String()), deps.profile)
- if err := deps.sessions.Create(context.Background(), deps.session); err != nil {
+ if err := deps.profiles.Create(context.Background(), deps.session.Me, *deps.profile); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := deps.sessions.Create(context.Background(), *deps.session); err != nil {
t.Fatal(err)
}
@@ -95,7 +95,7 @@ func TestVerify(t *testing.T) {
t.Parallel()
testToken := domain.TestToken(t)
- if err := deps.tokens.Create(context.Background(), testToken); err != nil {
+ if err := deps.tokens.Create(context.Background(), *testToken); err != nil {
t.Fatal(err)
}
@@ -136,17 +136,15 @@ func TestRevoke(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies {
tb.Helper()
- store := new(sync.Map)
config := domain.TestConfig(tb)
return Dependencies{
config: config,
profile: domain.TestProfile(tb),
- profiles: profilerepo.NewMemoryProfileRepository(store),
+ profiles: profilerepo.NewMemoryProfileRepository(),
session: domain.TestSession(tb),
- sessions: sessionrepo.NewMemorySessionRepository(store, config),
- store: store,
+ sessions: sessionrepo.NewMemorySessionRepository(*config),
token: domain.TestToken(tb),
- tokens: tokenrepo.NewMemoryTokenRepository(store),
+ tokens: tokenrepo.NewMemoryTokenRepository(),
}
}
diff --git a/internal/user/delivery/http/user_http.go b/internal/user/delivery/http/user_http.go
index 88548fb..9e13d2c 100644
--- a/internal/user/delivery/http/user_http.go
+++ b/internal/user/delivery/http/user_http.go
@@ -1,10 +1,10 @@
package http
import (
- "encoding/json"
"net/http"
"strings"
+ "github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa"
"source.toby3d.me/toby3d/auth/internal/common"
@@ -13,19 +13,10 @@ import (
"source.toby3d.me/toby3d/auth/internal/token"
)
-type (
- UserInformationResponse struct {
- Name string `json:"name,omitempty"`
- URL string `json:"url,omitempty"`
- Photo string `json:"photo,omitempty"`
- Email string `json:"email,omitempty"`
- }
-
- Handler struct {
- config *domain.Config
- tokens token.UseCase
- }
-)
+type Handler struct {
+ config *domain.Config
+ tokens token.UseCase
+}
func NewHandler(tokens token.UseCase, config *domain.Config) *Handler {
return &Handler{
@@ -34,7 +25,7 @@ func NewHandler(tokens token.UseCase, config *domain.Config) *Handler {
}
}
-func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{
//nolint:exhaustivestruct
middleware.JWTWithConfig(middleware.JWTConfig{
@@ -45,13 +36,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Skipper: middleware.DefaultSkipper,
TokenLookup: "header:" + common.HeaderAuthorization + ":Bearer ",
}),
- middleware.LogFmt(),
}
- chain.Handler(h.handleFunc).ServeHTTP(w, r)
+ return chain.Handler(h.handleFunc)
}
func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "" && r.Method != http.MethodGet {
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+
+ return
+ }
+
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
diff --git a/internal/user/delivery/http/user_http_schema.go b/internal/user/delivery/http/user_http_schema.go
new file mode 100644
index 0000000..28c8668
--- /dev/null
+++ b/internal/user/delivery/http/user_http_schema.go
@@ -0,0 +1,8 @@
+package http
+
+type UserInformationResponse struct {
+ Name string `json:"name,omitempty"`
+ URL string `json:"url,omitempty"`
+ Photo string `json:"photo,omitempty"`
+ Email string `json:"email,omitempty"`
+}
diff --git a/internal/user/delivery/http/user_http_test.go b/internal/user/delivery/http/user_http_test.go
index 08821b5..fe8b3ee 100644
--- a/internal/user/delivery/http/user_http_test.go
+++ b/internal/user/delivery/http/user_http_test.go
@@ -1,13 +1,12 @@
package http_test
import (
+ "context"
+ "net/http"
"net/http/httptest"
- "path"
- "sync"
"testing"
"github.com/goccy/go-json"
- http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -26,7 +25,6 @@ type Dependencies struct {
profile *domain.Profile
profiles profile.Repository
sessions session.Repository
- store *sync.Map
token *domain.Token
tokens token.Repository
tokenService token.UseCase
@@ -36,13 +34,17 @@ func TestUserInfo(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
- deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.token.Me.String()), deps.profile)
+ if err := deps.profiles.Create(context.Background(), deps.token.Me, *deps.profile); err != nil {
+ t.Fatal(err)
+ }
req := httptest.NewRequest(http.MethodGet, "https://example.com/userinfo", nil)
req.Header.Set(common.HeaderAuthorization, "Bearer "+deps.token.AccessToken)
w := httptest.NewRecorder()
- delivery.NewHandler(deps.tokenService, deps.config).ServeHTTP(w, req)
+ delivery.NewHandler(deps.tokenService, deps.config).
+ Handler().
+ ServeHTTP(w, req)
resp := w.Result()
@@ -69,22 +71,23 @@ func TestUserInfo(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies {
tb.Helper()
- store := new(sync.Map)
config := domain.TestConfig(tb)
+ sessions := sessionrepo.NewMemorySessionRepository(*config)
+ tokens := tokenrepo.NewMemoryTokenRepository()
+ profiles := profilerepo.NewMemoryProfileRepository()
return Dependencies{
config: config,
profile: domain.TestProfile(tb),
- profiles: profilerepo.NewMemoryProfileRepository(store),
- sessions: sessionrepo.NewMemorySessionRepository(store, config),
- store: store,
+ profiles: profiles,
+ sessions: sessions,
token: domain.TestToken(tb),
- tokens: tokenrepo.NewMemoryTokenRepository(store),
+ tokens: tokens,
tokenService: tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config,
- Profiles: profilerepo.NewMemoryProfileRepository(store),
- Sessions: sessionrepo.NewMemorySessionRepository(store, config),
- Tokens: tokenrepo.NewMemoryTokenRepository(store),
+ Profiles: profiles,
+ Sessions: sessions,
+ Tokens: tokens,
}),
}
}
diff --git a/internal/user/repository.go b/internal/user/repository.go
index ec7edc1..b9dbd2c 100644
--- a/internal/user/repository.go
+++ b/internal/user/repository.go
@@ -7,7 +7,8 @@ import (
)
type Repository interface {
- Get(ctx context.Context, me *domain.Me) (*domain.User, error)
+ Create(ctx context.Context, user domain.User) error
+ Get(ctx context.Context, me domain.Me) (*domain.User, error)
}
var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "user not exist", "")
diff --git a/internal/user/repository/http/http_user.go b/internal/user/repository/http/http_user.go
index bd3252f..9901119 100644
--- a/internal/user/repository/http/http_user.go
+++ b/internal/user/repository/http/http_user.go
@@ -1,12 +1,16 @@
package http
import (
+ "bytes"
"context"
"fmt"
+ "io"
+ "net/http"
"net/url"
- http "github.com/valyala/fasthttp"
+ "golang.org/x/exp/slices"
+ "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/user"
@@ -38,26 +42,21 @@ func NewHTTPUserRepository(client *http.Client) user.Repository {
}
}
-func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) {
- req := http.AcquireRequest()
- defer http.ReleaseRequest(req)
- req.Header.SetMethod(http.MethodGet)
- req.SetRequestURI(me.String())
+// WARN(toby3d): not implemented.
+func (httpUserRepository) Create(_ context.Context, _ domain.User) error {
+ return nil
+}
- resp := http.AcquireResponse()
- defer http.ReleaseResponse(resp)
-
- if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
+func (repo *httpUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) {
+ resp, err := repo.client.Get(me.String())
+ if err != nil {
return nil, fmt.Errorf("cannot fetch user by me: %w", err)
}
- // TODO(toby3d): handle error here?
- resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation)))
-
user := &domain.User{
AuthorizationEndpoint: nil,
IndieAuthMetadata: nil,
- Me: resolvedMe,
+ Me: &me,
Micropub: nil,
Microsub: nil,
Profile: domain.NewProfile(),
@@ -65,7 +64,7 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
TokenEndpoint: nil,
}
- if metadata, err := httputil.ExtractMetadata(resp, repo.client); err == nil {
+ if metadata, err := httputil.ExtractFromMetadata(repo.client, me.String()); err == nil {
user.AuthorizationEndpoint = metadata.AuthorizationEndpoint
user.Micropub = metadata.MicropubEndpoint
user.Microsub = metadata.MicrosubEndpoint
@@ -73,89 +72,87 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
user.TokenEndpoint = metadata.TokenEndpoint
}
- extractUser(user, resp)
- extractProfile(user.Profile, resp)
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("cannot read response body: %w", err)
+ }
+
+ extractUser(me.URL(), user, body, resp.Header.Get(common.HeaderLink))
+ extractProfile(me.URL(), user.Profile, body)
return user, nil
}
//nolint:cyclop
-func extractUser(dst *domain.User, src *http.Response) {
- if dst.IndieAuthMetadata != nil {
- if endpoints := httputil.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 {
- dst.IndieAuthMetadata = endpoints[len(endpoints)-1]
+func extractUser(u *url.URL, dst *domain.User, body []byte, header string) {
+ for key, target := range map[string]**url.URL{
+ relAuthorizationEndpoint: &dst.AuthorizationEndpoint,
+ relIndieAuthMetadata: &dst.IndieAuthMetadata,
+ relMicropub: &dst.Micropub,
+ relMicrosub: &dst.Microsub,
+ relTicketEndpoint: &dst.TicketEndpoint,
+ relTokenEndpoint: &dst.TokenEndpoint,
+ } {
+ if target == nil {
+ continue
}
- }
- if dst.AuthorizationEndpoint == nil {
- if endpoints := httputil.ExtractEndpoints(src, relAuthorizationEndpoint); len(endpoints) > 0 {
- dst.AuthorizationEndpoint = endpoints[len(endpoints)-1]
- }
- }
-
- if dst.Micropub == nil {
- if endpoints := httputil.ExtractEndpoints(src, relMicropub); len(endpoints) > 0 {
- dst.Micropub = endpoints[len(endpoints)-1]
- }
- }
-
- if dst.Microsub == nil {
- if endpoints := httputil.ExtractEndpoints(src, relMicrosub); len(endpoints) > 0 {
- dst.Microsub = endpoints[len(endpoints)-1]
- }
- }
-
- if dst.TicketEndpoint == nil {
- if endpoints := httputil.ExtractEndpoints(src, relTicketEndpoint); len(endpoints) > 0 {
- dst.TicketEndpoint = endpoints[len(endpoints)-1]
- }
- }
-
- if dst.TokenEndpoint == nil {
- if endpoints := httputil.ExtractEndpoints(src, relTokenEndpoint); len(endpoints) > 0 {
- dst.TokenEndpoint = endpoints[len(endpoints)-1]
+ if endpoints := httputil.ExtractEndpoints(bytes.NewReader(body), u, header, key); len(endpoints) > 0 {
+ *target = endpoints[len(endpoints)-1]
}
}
}
//nolint:cyclop
-func extractProfile(dst *domain.Profile, src *http.Response) {
- for _, name := range httputil.ExtractProperty(src, hCard, propertyName) {
- if n, ok := name.(string); ok {
+func extractProfile(u *url.URL, dst *domain.Profile, body []byte) {
+ for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyName) {
+ if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) {
dst.Name = append(dst.Name, n)
}
}
- for _, rawEmail := range httputil.ExtractProperty(src, hCard, propertyEmail) {
+ for _, rawEmail := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyEmail) {
email, ok := rawEmail.(string)
if !ok {
continue
}
- if e, err := domain.ParseEmail(email); err == nil {
+ if e, err := domain.ParseEmail(email); err == nil && !slices.Contains(dst.Email, e) {
dst.Email = append(dst.Email, e)
}
}
- for _, rawURL := range httputil.ExtractProperty(src, hCard, propertyURL) {
+ for _, rawURL := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyURL) {
rawURL, ok := rawURL.(string)
if !ok {
continue
}
- if u, err := url.Parse(rawURL); err == nil {
+ if u, err := url.Parse(rawURL); err == nil && !containsUrl(dst.URL, u) {
dst.URL = append(dst.URL, u)
}
}
- for _, rawPhoto := range httputil.ExtractProperty(src, hCard, propertyPhoto) {
+ for _, rawPhoto := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyPhoto) {
photo, ok := rawPhoto.(string)
if !ok {
continue
}
- if p, err := url.Parse(photo); err == nil {
+ if p, err := url.Parse(photo); err == nil && !containsUrl(dst.Photo, p) {
dst.Photo = append(dst.Photo, p)
}
}
}
+
+func containsUrl(src []*url.URL, find *url.URL) bool {
+ for i := range src {
+ if src[i].String() != find.String() {
+ continue
+ }
+
+ return true
+ }
+
+ return false
+}
diff --git a/internal/user/repository/http/http_user_test.go b/internal/user/repository/http/http_user_test.go
index 083ad30..9a269d5 100644
--- a/internal/user/repository/http/http_user_test.go
+++ b/internal/user/repository/http/http_user_test.go
@@ -3,16 +3,15 @@ package http_test
import (
"context"
"fmt"
+ "net/http"
+ "net/http/httptest"
"strings"
"testing"
- "github.com/fasthttp/router"
- "github.com/stretchr/testify/assert"
- http "github.com/valyala/fasthttp"
+ "github.com/google/go-cmp/cmp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
- "source.toby3d.me/toby3d/auth/internal/testing/httptest"
repository "source.toby3d.me/toby3d/auth/internal/user/repository/http"
)
@@ -40,39 +39,29 @@ func TestGet(t *testing.T) {
t.Parallel()
user := domain.TestUser(t)
- client, _, cleanup := httptest.New(t, testHandler(t, user))
- t.Cleanup(cleanup)
- result, err := repository.NewHTTPUserRepository(client).Get(context.Background(), user.Me)
+ srv := httptest.NewServer(testHandler(t, user))
+ t.Cleanup(srv.Close)
+
+ user.Me = domain.TestMe(t, srv.URL+"/")
+
+ result, err := repository.NewHTTPUserRepository(srv.Client()).
+ Get(context.Background(), *user.Me)
if err != nil {
t.Fatal(err)
}
- // NOTE(toby3d): endpoints
- assert.Equal(t, user.AuthorizationEndpoint.String(), result.AuthorizationEndpoint.String())
- assert.Equal(t, user.TokenEndpoint.String(), result.TokenEndpoint.String())
- assert.Equal(t, user.Micropub.String(), result.Micropub.String())
- assert.Equal(t, user.Microsub.String(), result.Microsub.String())
-
- // NOTE(toby3d): profile
- assert.Equal(t, user.Profile.Name, result.Profile.Name)
- assert.Equal(t, user.Profile.Email, result.Profile.Email)
-
- for i := range user.Profile.URL {
- assert.Equal(t, user.Profile.URL[i].String(), result.Profile.URL[i].String())
- }
-
- for i := range user.Profile.Photo {
- assert.Equal(t, user.Profile.Photo[i].String(), result.Profile.Photo[i].String())
+ if diff := cmp.Diff(user, result, cmp.AllowUnexported(domain.Me{}, domain.Email{})); diff != "" {
+ t.Errorf("%+s", diff)
}
}
-func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
+func testHandler(tb testing.TB, user *domain.User) http.Handler {
tb.Helper()
- router := router.New()
- router.GET("/", func(ctx *http.RequestCtx) {
- ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderLink, strings.Join([]string{
`<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`,
`<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`,
`<` + user.Micropub.String() + `>; rel="micropub"`,
@@ -80,17 +69,17 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
`<` + user.TicketEndpoint.String() + `>; rel="ticket_endpoint"`,
`<` + user.TokenEndpoint.String() + `>; rel="token_endpoint"`,
}, ", "))
- ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf(
- testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0],
- ))
+ w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
+ fmt.Fprintf(w, testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0])
})
- router.GET(user.IndieAuthMetadata.Path, func(ctx *http.RequestCtx) {
- ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{
+ mux.HandleFunc(user.IndieAuthMetadata.Path, func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
+ fmt.Fprint(w, `{
"issuer": "`+user.Me.String()+`",
"authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`",
"token_endpoint": "`+user.TokenEndpoint.String()+`"
}`)
})
- return router.Handler
+ return mux
}
diff --git a/internal/user/repository/memory/memory_user.go b/internal/user/repository/memory/memory_user.go
index 18d5385..5478861 100644
--- a/internal/user/repository/memory/memory_user.go
+++ b/internal/user/repository/memory/memory_user.go
@@ -2,7 +2,6 @@ package memory
import (
"context"
- "path"
"sync"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -10,27 +9,33 @@ import (
)
type memoryUserRepository struct {
- store *sync.Map
+ mutex *sync.RWMutex
+ users map[string]domain.User
}
-const DefaultPathPrefix string = "users"
-
-func NewMemoryUserRepository(store *sync.Map) user.Repository {
+func NewMemoryUserRepository() user.Repository {
return &memoryUserRepository{
- store: store,
+ mutex: new(sync.RWMutex),
+ users: make(map[string]domain.User),
}
}
-func (repo *memoryUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) {
- p, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String()))
- if !ok {
- return nil, user.ErrNotExist
- }
+func (repo *memoryUserRepository) Create(ctx context.Context, user domain.User) error {
+ repo.mutex.Lock()
+ defer repo.mutex.Unlock()
- result, ok := p.(*domain.User)
- if !ok {
- return nil, user.ErrNotExist
- }
+ repo.users[user.Me.String()] = user
- return result, nil
+ return nil
+}
+
+func (repo *memoryUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) {
+ repo.mutex.RLock()
+ defer repo.mutex.RUnlock()
+
+ if u, ok := repo.users[me.String()]; ok {
+ return &u, nil
+ }
+
+ return nil, user.ErrNotExist
}
diff --git a/internal/user/repository/memory/memory_user_test.go b/internal/user/repository/memory/memory_user_test.go
deleted file mode 100644
index 397a464..0000000
--- a/internal/user/repository/memory/memory_user_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package memory_test
-
-import (
- "context"
- "path"
- "reflect"
- "sync"
- "testing"
-
- "source.toby3d.me/toby3d/auth/internal/domain"
- repository "source.toby3d.me/toby3d/auth/internal/user/repository/memory"
-)
-
-func TestGet(t *testing.T) {
- t.Parallel()
-
- user := domain.TestUser(t)
-
- store := new(sync.Map)
- store.Store(path.Join(repository.DefaultPathPrefix, user.Me.String()), user)
-
- result, err := repository.NewMemoryUserRepository(store).Get(context.Background(), user.Me)
- if err != nil {
- t.Error(err)
- }
-
- if !reflect.DeepEqual(result, user) {
- t.Errorf("Get(%s) = %+v, want %+v", user.Me, result, user)
- }
-}
diff --git a/internal/user/usecase.go b/internal/user/usecase.go
index 13829d1..fc71209 100644
--- a/internal/user/usecase.go
+++ b/internal/user/usecase.go
@@ -8,5 +8,5 @@ import (
type UseCase interface {
// Fetch discovery all available endpoints and Profile info on Me URL.
- Fetch(ctx context.Context, me *domain.Me) (*domain.User, error)
+ Fetch(ctx context.Context, me domain.Me) (*domain.User, error)
}
diff --git a/internal/user/usecase/user_ucase.go b/internal/user/usecase/user_ucase.go
index 6fffb84..b6a9fd9 100644
--- a/internal/user/usecase/user_ucase.go
+++ b/internal/user/usecase/user_ucase.go
@@ -18,7 +18,7 @@ func NewUserUseCase(repo user.Repository) user.UseCase {
}
}
-func (useCase *userUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.User, error) {
+func (useCase *userUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.User, error) {
user, err := useCase.repo.Get(ctx, me)
if err != nil {
return nil, fmt.Errorf("cannot find user by me: %w", err)
diff --git a/internal/user/usecase/user_ucase_test.go b/internal/user/usecase/user_ucase_test.go
index 836d50c..183878b 100644
--- a/internal/user/usecase/user_ucase_test.go
+++ b/internal/user/usecase/user_ucase_test.go
@@ -2,9 +2,7 @@ package usecase_test
import (
"context"
- "path"
"reflect"
- "sync"
"testing"
"source.toby3d.me/toby3d/auth/internal/domain"
@@ -15,19 +13,20 @@ import (
func TestFetch(t *testing.T) {
t.Parallel()
- me := domain.TestMe(t, "https://user.example.net")
user := domain.TestUser(t)
+ user.Me = domain.TestMe(t, "https://user.example.net")
+ users := repository.NewMemoryUserRepository()
- store := new(sync.Map)
- store.Store(path.Join(repository.DefaultPathPrefix, me.String()), user)
+ if err := users.Create(context.Background(), *user); err != nil {
+ t.Fatal(err)
+ }
- result, err := ucase.NewUserUseCase(repository.NewMemoryUserRepository(store)).
- Fetch(context.Background(), me)
+ result, err := ucase.NewUserUseCase(users).Fetch(context.Background(), *user.Me)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(result, user) {
- t.Errorf("Fetch(%s) = %+v, want %+v", me, result, user)
+ t.Errorf("Fetch(%s) = %+v, want %+v", user.Me, result, user)
}
}
diff --git a/main.go b/main.go
index 0f0cdbf..c24fdde 100644
--- a/main.go
+++ b/main.go
@@ -5,27 +5,27 @@
package main
import (
+ "context"
+ "embed"
_ "embed"
"errors"
"flag"
- "fmt"
+ "io/fs"
"log"
+ "net/http"
+ _ "net/http/pprof"
+ "net/url"
"os"
"os/signal"
- "path"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
- "sync"
"syscall"
"time"
- "github.com/fasthttp/router"
"github.com/jmoiron/sqlx"
"github.com/spf13/viper"
- http "github.com/valyala/fasthttp"
- "github.com/valyala/fasthttp/pprofhandler"
"golang.org/x/text/language"
"golang.org/x/text/message"
_ "modernc.org/sqlite"
@@ -40,6 +40,7 @@ import (
"source.toby3d.me/toby3d/auth/internal/domain"
healthhttpdelivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http"
metadatahttpdelivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http"
+ "source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/profile"
profilehttprepo "source.toby3d.me/toby3d/auth/internal/profile/repository/http"
profileucase "source.toby3d.me/toby3d/auth/internal/profile/usecase"
@@ -57,6 +58,7 @@ import (
tokenmemoryrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
tokensqlite3repo "source.toby3d.me/toby3d/auth/internal/token/repository/sqlite3"
tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase"
+ "source.toby3d.me/toby3d/auth/internal/urlutil"
userhttpdelivery "source.toby3d.me/toby3d/auth/internal/user/delivery/http"
)
@@ -69,6 +71,7 @@ type (
tickets ticket.UseCase
profiles profile.UseCase
tokens token.UseCase
+ static fs.FS
}
NewAppOptions struct {
@@ -78,6 +81,7 @@ type (
Tickets ticket.Repository
Tokens token.Repository
Profiles profile.Repository
+ Static fs.FS
}
)
@@ -93,13 +97,16 @@ var (
logger = log.New(os.Stdout, "IndieAuth\t", log.Lmsgprefix|log.LstdFlags|log.LUTC)
config = new(domain.Config)
indieAuthClient = new(domain.Client)
-
- configPath string
- cpuProfilePath string
- memProfilePath string
- enablePprof bool
)
+var (
+ configPath, cpuProfilePath, memProfilePath string
+ enablePprof bool
+)
+
+//go:embed assets/*
+var staticFS embed.FS
+
//nolint:gochecknoinits
func init() {
flag.StringVar(&configPath, "config", filepath.Join(".", "config.yml"), "load specific config")
@@ -133,34 +140,44 @@ func init() {
rootURL := config.Server.GetRootURL()
indieAuthClient.Name = []string{config.Name}
- if indieAuthClient.ID, err = domain.ParseClientID(rootURL); err != nil {
+ cid, err := domain.ParseClientID(rootURL)
+ if err != nil {
logger.Fatalln("fail to read config:", err)
}
- url, err := domain.ParseURL(rootURL)
+ indieAuthClient.ID = *cid
+
+ u, err := url.Parse(rootURL)
if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err)
}
- logo, err := domain.ParseURL(rootURL + config.Server.StaticURLPrefix + "/icon.svg")
+ logo, err := url.Parse(rootURL + config.Server.StaticURLPrefix + "/icon.svg")
if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err)
}
- redirectURI, err := domain.ParseURL(rootURL + "/callback")
+ redirectURI, err := url.Parse(rootURL + "callback")
if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err)
}
- indieAuthClient.URL = []*domain.URL{url}
- indieAuthClient.Logo = []*domain.URL{logo}
- indieAuthClient.RedirectURI = []*domain.URL{redirectURI}
+ indieAuthClient.URL = []*url.URL{u}
+ indieAuthClient.Logo = []*url.URL{logo}
+ indieAuthClient.RedirectURI = []*url.URL{redirectURI}
}
//nolint:funlen,cyclop // "god object" and the entry point of all modules
func main() {
+ ctx := context.Background()
+
var opts NewAppOptions
+ var err error
+ if opts.Static, err = fs.Sub(staticFS, "assets"); err != nil {
+ logger.Fatalln(err)
+ }
+
switch strings.ToLower(config.Database.Type) {
case "sqlite3":
store, err := sqlx.Open("sqlite", config.Database.Path)
@@ -176,51 +193,27 @@ func main() {
opts.Sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store)
opts.Tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config)
case "memory":
- store := new(sync.Map)
- opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository(store)
- opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(store, config)
- opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(store, config)
+ opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository()
+ opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(*config)
+ opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(*config)
default:
log.Fatalln("unsupported database type, use 'memory' or 'sqlite3'")
}
go opts.Sessions.GC()
- //nolint:exhaustivestruct // too many options
- opts.Client = &http.Client{
- Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
- ReadTimeout: DefaultReadTimeout,
- WriteTimeout: DefaultWriteTimeout,
- }
+ opts.Client = new(http.Client)
opts.Clients = clienthttprepo.NewHTTPClientRepository(opts.Client)
opts.Profiles = profilehttprepo.NewHTPPClientRepository(opts.Client)
- r := router.New()
- NewApp(opts).Register(r)
- //nolint:exhaustivestruct // too many options
- r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{
- Root: config.Server.StaticRootPath,
- CacheDuration: DefaultCacheDuration,
- AcceptByteRange: true,
- Compress: true,
- CompressBrotli: true,
- GenerateIndexPages: true,
- })
-
- if enablePprof {
- r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler)
- }
+ app := NewApp(opts)
//nolint:exhaustivestruct
server := &http.Server{
- Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()),
- Handler: r.Handler,
- ReadTimeout: DefaultReadTimeout,
- WriteTimeout: DefaultWriteTimeout,
- DisableKeepalive: true,
- ReduceMemoryUsage: true,
- SecureErrorLogMessage: true,
- CloseOnShutdown: true,
+ Addr: config.Server.GetAddress(),
+ Handler: app.Handler(),
+ ReadTimeout: DefaultReadTimeout,
+ WriteTimeout: DefaultWriteTimeout,
}
done := make(chan os.Signal, 1)
@@ -243,15 +236,15 @@ func main() {
logger.Printf("started at %s, available at %s", config.Server.GetAddress(),
config.Server.GetRootURL())
- err := server.ListenAndServe(config.Server.GetAddress())
- if err != nil && !errors.Is(err, http.ErrConnectionClosed) {
+ err := server.ListenAndServe()
+ if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatalln("cannot listen and serve:", err)
}
}()
<-done
- if err := server.Shutdown(); err != nil {
+ if err := server.Shutdown(ctx); err != nil {
logger.Fatalln("failed shutdown of server:", err)
}
@@ -274,6 +267,7 @@ func main() {
func NewApp(opts NewAppOptions) *App {
return &App{
+ static: opts.Static,
auth: authucase.NewAuthUseCase(opts.Sessions, opts.Profiles, config),
clients: clientucase.NewClientUseCase(opts.Clients),
matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
@@ -289,20 +283,19 @@ func NewApp(opts NewAppOptions) *App {
}
}
-func (app *App) Register(r *router.Router) {
- tickethttpdelivery.NewRequestHandler(app.tickets, app.matcher, config).Register(r)
- healthhttpdelivery.NewRequestHandler().Register(r)
- metadatahttpdelivery.NewRequestHandler(&domain.Metadata{
- Issuer: indieAuthClient.ID,
- AuthorizationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "authorize"),
- TokenEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "token"),
- TicketEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "ticket"),
+// TODO(toby3d): move module middlewares to here.
+func (app *App) Handler() http.Handler {
+ metadata := metadatahttpdelivery.NewHandler(&domain.Metadata{
+ Issuer: indieAuthClient.ID.URL(),
+ AuthorizationEndpoint: indieAuthClient.ID.URL().JoinPath("authorize"),
+ TokenEndpoint: indieAuthClient.ID.URL().JoinPath("token"),
+ TicketEndpoint: indieAuthClient.ID.URL().JoinPath("ticket"),
MicropubEndpoint: nil,
MicrosubEndpoint: nil,
- IntrospectionEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "introspect"),
- RevocationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "revocation"),
- UserinfoEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "userinfo"),
- ServiceDocumentation: domain.MustParseURL("https://indieauth.net/source/"),
+ IntrospectionEndpoint: indieAuthClient.ID.URL().JoinPath("introspect"),
+ RevocationEndpoint: indieAuthClient.ID.URL().JoinPath("revocation"),
+ UserinfoEndpoint: indieAuthClient.ID.URL().JoinPath("userinfo"),
+ ServiceDocumentation: &url.URL{Scheme: "https", Host: "indieauth.net", Path: "/source/"},
IntrospectionEndpointAuthMethodsSupported: []string{"Bearer"},
RevocationEndpointAuthMethodsSupported: []string{"none"},
ScopesSupported: domain.Scopes{
@@ -319,8 +312,14 @@ func (app *App) Register(r *router.Router) {
domain.ScopeRead,
domain.ScopeUpdate,
},
- ResponseTypesSupported: []domain.ResponseType{domain.ResponseTypeCode, domain.ResponseTypeID},
- GrantTypesSupported: []domain.GrantType{domain.GrantTypeAuthorizationCode, domain.GrantTypeTicket},
+ ResponseTypesSupported: []domain.ResponseType{
+ domain.ResponseTypeCode,
+ domain.ResponseTypeID,
+ },
+ GrantTypesSupported: []domain.GrantType{
+ domain.GrantTypeAuthorizationCode,
+ domain.GrantTypeTicket,
+ },
CodeChallengeMethodsSupported: []domain.CodeChallengeMethod{
domain.CodeChallengeMethodMD5,
domain.CodeChallengeMethodPLAIN,
@@ -329,20 +328,57 @@ func (app *App) Register(r *router.Router) {
domain.CodeChallengeMethodS512,
},
AuthorizationResponseIssParameterSupported: true,
- }).Register(r)
- tokenhttpdelivery.NewRequestHandler(app.tokens, app.tickets, config).Register(r)
- clienthttpdelivery.NewRequestHandler(clienthttpdelivery.NewRequestHandlerOptions{
- Client: indieAuthClient,
- Config: config,
- Matcher: app.matcher,
- Tokens: app.tokens,
- }).Register(r)
- authhttpdelivery.NewRequestHandler(authhttpdelivery.NewRequestHandlerOptions{
+ }).Handler()
+ health := healthhttpdelivery.NewHandler().Handler()
+ auth := authhttpdelivery.NewHandler(authhttpdelivery.NewHandlerOptions{
Auth: app.auth,
Clients: app.clients,
- Config: config,
+ Config: *config,
Matcher: app.matcher,
Profiles: app.profiles,
- }).Register(r)
- userhttpdelivery.NewRequestHandler(app.tokens, config).Register(r)
+ }).Handler()
+ token := tokenhttpdelivery.NewHandler(app.tokens, app.tickets, config).Handler()
+ client := clienthttpdelivery.NewHandler(clienthttpdelivery.NewHandlerOptions{
+ Client: *indieAuthClient,
+ Config: *config,
+ Matcher: app.matcher,
+ Tokens: app.tokens,
+ }).Handler()
+ user := userhttpdelivery.NewHandler(app.tokens, config).Handler()
+ ticket := tickethttpdelivery.NewHandler(app.tickets, app.matcher, *config).Handler()
+ static := http.FileServer(http.FS(app.static))
+
+ return http.HandlerFunc(middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var head string
+ head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
+
+ switch head {
+ default:
+ r.URL = r.URL.JoinPath(head, r.URL.Path)
+
+ static.ServeHTTP(w, r)
+ case "", "callback":
+ r.URL = r.URL.JoinPath(head, r.URL.Path)
+
+ client.ServeHTTP(w, r)
+ case "token", "introspect", "revocation":
+ r.URL = r.URL.JoinPath(head, r.URL.Path)
+
+ token.ServeHTTP(w, r)
+ case ".well-known":
+ if head, _ = urlutil.ShiftPath(r.URL.Path); head == "oauth-authorization-server" {
+ metadata.ServeHTTP(w, r)
+ } else {
+ http.NotFound(w, r)
+ }
+ case "authorize":
+ auth.ServeHTTP(w, r)
+ case "health":
+ health.ServeHTTP(w, r)
+ case "userinfo":
+ user.ServeHTTP(w, r)
+ case "ticket":
+ ticket.ServeHTTP(w, r)
+ }
+ }).Intercept(middleware.LogFmt()))
}
diff --git a/web/authorize.qtpl b/web/authorize.qtpl
index 80df168..05b6495 100644
--- a/web/authorize.qtpl
+++ b/web/authorize.qtpl
@@ -17,137 +17,137 @@
} %}
{% func (p *AuthorizePage) Title() %}
- {% if p.Client.GetName() == "" %}
- {%= p.T("Authorize %s", p.Client.GetName()) %}
- {% else %}
- {%= p.T("Authorize application") %}
- {% endif %}
+{% if p.Client.GetName() == "" %}
+{%= p.T("Authorize %s", p.Client.GetName()) %}
+{% else %}
+{%= p.T("Authorize application") %}
+{% endif %}
{% endfunc %}
{% func (p *AuthorizePage) Body() %}
-
- {% if p.Client.GetLogo() != nil %}
-
- {% endif %}
+
+ {% if p.Client.GetLogo() != nil %}
+
+ {% endif %}
-
+
+ {% endif %}
+
+
-
-
+
+{% endfunc %}
diff --git a/web/authorize.qtpl.go b/web/authorize.qtpl.go
index b9afc9b..56c0bf0 100644
--- a/web/authorize.qtpl.go
+++ b/web/authorize.qtpl.go
@@ -41,27 +41,27 @@ type AuthorizePage struct {
func (p *AuthorizePage) StreamTitle(qw422016 *qt422016.Writer) {
//line web/authorize.qtpl:19
qw422016.N().S(`
- `)
+`)
//line web/authorize.qtpl:20
if p.Client.GetName() == "" {
//line web/authorize.qtpl:20
qw422016.N().S(`
- `)
+`)
//line web/authorize.qtpl:21
p.StreamT(qw422016, "Authorize %s", p.Client.GetName())
//line web/authorize.qtpl:21
qw422016.N().S(`
- `)
+`)
//line web/authorize.qtpl:22
} else {
//line web/authorize.qtpl:22
qw422016.N().S(`
- `)
+`)
//line web/authorize.qtpl:23
p.StreamT(qw422016, "Authorize application")
//line web/authorize.qtpl:23
qw422016.N().S(`
- `)
+`)
//line web/authorize.qtpl:24
}
//line web/authorize.qtpl:24
@@ -100,43 +100,43 @@ func (p *AuthorizePage) Title() string {
func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
//line web/authorize.qtpl:27
qw422016.N().S(`
-
- `)
+
+ `)
//line web/authorize.qtpl:29
if p.Client.GetLogo() != nil {
//line web/authorize.qtpl:29
qw422016.N().S(`
-
- `)
+ width="140">
+ `)
//line web/authorize.qtpl:40
}
//line web/authorize.qtpl:40
qw422016.N().S(`
-
- `)
+
+ `)
//line web/authorize.qtpl:43
if p.Client.GetURL() != nil {
//line web/authorize.qtpl:43
qw422016.N().S(`
-
- `)
+
+ `)
//line web/authorize.qtpl:53
}
//line web/authorize.qtpl:53
qw422016.N().S(`
-
-
+
+
-
-
+
`)
//line web/authorize.qtpl:153
}
diff --git a/web/ticket.qtpl b/web/ticket.qtpl
index 7f2cab1..96dcf3f 100644
--- a/web/ticket.qtpl
+++ b/web/ticket.qtpl
@@ -5,47 +5,47 @@
{% collapsespace %}
{% func (p *TicketPage) Body() %}
-
- {%= p.T("TicketAuth") %}
-
+
+ {%= p.T("TicketAuth") %}
+
-
-
+
+
- {% if p.CSRF != nil %}
-
- {% endif %}
+ {% if p.CSRF != nil %}
+
+ {% endif %}
-
-
-
-
+
+
+
+
-
-
-
-
+
+
+
+
-
-
-
+
+
+
{% endfunc %}
{% endcollapsespace %}
diff --git a/web/ticket.qtpl.go b/web/ticket.qtpl.go
index 53ac936..d579a88 100644
--- a/web/ticket.qtpl.go
+++ b/web/ticket.qtpl.go
@@ -30,7 +30,7 @@ func (p *TicketPage) StreamBody(qw422016 *qt422016.Writer) {
//line web/ticket.qtpl:9
p.StreamT(qw422016, "TicketAuth")
//line web/ticket.qtpl:9
- qw422016.N().S(` `)
+ qw422016.N().S(` `)
//line web/ticket.qtpl:21
if p.CSRF != nil {
//line web/ticket.qtpl:21