♻️ Drop HTTP delivery framework

replaced 'valyala/fasthttp' to native 'net/http' package, close #4
This commit is contained in:
Maxim Lebedev 2023-01-15 03:27:37 +06:00
parent 95cb5a2950
commit c7bd73c63a
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
90 changed files with 2551 additions and 2286 deletions

View File

@ -2,13 +2,10 @@ package http
import ( import (
"crypto/subtle" "crypto/subtle"
"errors" "net/http"
"path"
"strings" "strings"
"github.com/fasthttp/router" "github.com/goccy/go-json"
json "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
@ -16,111 +13,31 @@ import (
"source.toby3d.me/toby3d/auth/internal/client" "source.toby3d.me/toby3d/auth/internal/client"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/profile" "source.toby3d.me/toby3d/auth/internal/profile"
"source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web" "source.toby3d.me/toby3d/auth/web"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
) )
type ( type (
AuthAuthorizationRequest struct { NewHandlerOptions struct {
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
// The client URL.
ClientID *domain.ClientID `form:"client_id"`
// The redirect URL indicating where the user should be
// redirected to after approving the request.
RedirectURI *domain.URL `form:"redirect_uri"`
// A parameter set by the client which will be included when the
// user is redirected back to the client. This is used to
// prevent CSRF attacks. The authorization server MUST return
// the unmodified state value back to the client.
State string `form:"state"`
// The code challenge as previously described.
CodeChallenge string `form:"code_challenge"`
// The hashing method used to calculate the code challenge.
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
// A space-separated list of scopes the client is requesting,
// e.g. "profile", or "profile create". If the client omits this
// value, the authorization server MUST NOT issue an access
// token for this authorization code. Only the user's profile
// URL may be returned without any scope requested.
Scope domain.Scopes `form:"scope,omitempty"`
// The URL that the user entered.
Me *domain.Me `form:"me"`
}
AuthVerifyRequest struct {
ClientID *domain.ClientID `form:"client_id"`
Me *domain.Me `form:"me"`
RedirectURI *domain.URL `form:"redirect_uri"`
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
ResponseType domain.ResponseType `form:"response_type"`
Scope domain.Scopes `form:"scope[],omitempty"`
Authorize string `form:"authorize"`
CodeChallenge string `form:"code_challenge"`
State string `form:"state"`
Provider string `form:"provider"`
}
AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The authorization code received from the authorization
// endpoint in the redirect.
Code string `form:"code"`
// The client's URL, which MUST match the client_id used in the
// authentication request.
ClientID *domain.ClientID `form:"client_id"`
// The client's redirect URL, which MUST match the initial
// authentication request.
RedirectURI *domain.URL `form:"redirect_uri"`
// The original plaintext random string generated before
// starting the authorization request.
CodeVerifier string `form:"code_verifier"`
}
AuthExchangeResponse struct {
Me *domain.Me `json:"me"`
Profile *AuthProfileResponse `json:"profile,omitempty"`
}
AuthProfileResponse struct {
Email *domain.Email `json:"email,omitempty"`
Photo *domain.URL `json:"photo,omitempty"`
URL *domain.URL `json:"url,omitempty"`
Name string `json:"name,omitempty"`
}
NewRequestHandlerOptions struct {
Auth auth.UseCase Auth auth.UseCase
Clients client.UseCase Clients client.UseCase
Config *domain.Config Config domain.Config
Matcher language.Matcher Matcher language.Matcher
Profiles profile.UseCase Profiles profile.UseCase
} }
RequestHandler struct { Handler struct {
clients client.UseCase clients client.UseCase
config *domain.Config config domain.Config
matcher language.Matcher matcher language.Matcher
useCase auth.UseCase useCase auth.UseCase
} }
) )
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { func NewHandler(opts NewHandlerOptions) *Handler {
return &RequestHandler{ return &Handler{
clients: opts.Clients, clients: opts.Clients,
config: opts.Config, config: opts.Config,
matcher: opts.Matcher, matcher: opts.Matcher,
@ -128,16 +45,16 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{ chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{ middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(w http.ResponseWriter, r *http.Request) bool {
matched, _ := path.Match("/authorize*", string(ctx.Path())) head, _ := urlutil.ShiftPath(r.URL.Path)
return ctx.IsPost() && matched return r.Method == http.MethodPost && head == "authorize"
}, },
CookieMaxAge: 0, CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteStrictMode, CookieSameSite: http.SameSiteStrictMode,
ContextKey: "csrf", ContextKey: "csrf",
CookieDomain: h.config.Server.Domain, CookieDomain: h.config.Server.Domain,
CookieName: "__Secure-csrf", CookieName: "__Secure-csrf",
@ -148,14 +65,12 @@ func (h *RequestHandler) Register(r *router.Router) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
}), }),
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{ middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(w http.ResponseWriter, r *http.Request) bool {
matched, _ := path.Match("/api/*", string(ctx.Path())) head, _ := urlutil.ShiftPath(r.URL.Path)
provider := string(ctx.QueryArgs().Peek("provider"))
providerMatched := provider != "" && provider != domain.ProviderDirect.UID
return !ctx.IsPost() || !matched || providerMatched return r.Method != http.MethodPost || head != "api"
}, },
Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) { Validator: func(w http.ResponseWriter, r *http.Request, login, password string) (bool, error) {
userMatch := subtle.ConstantTimeCompare([]byte(login), userMatch := subtle.ConstantTimeCompare([]byte(login),
[]byte(h.config.IndieAuth.Username)) []byte(h.config.IndieAuth.Username))
passMatch := subtle.ConstantTimeCompare([]byte(password), passMatch := subtle.ConstantTimeCompare([]byte(password),
@ -165,29 +80,57 @@ func (h *RequestHandler) Register(r *router.Router) {
}, },
Realm: "", Realm: "",
}), }),
middleware.LogFmt(),
} }
r.GET("/authorize", chain.RequestHandler(h.handleAuthorize)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.POST("/api/authorize", chain.RequestHandler(h.handleVerify)) var head string
r.POST("/authorize", chain.RequestHandler(h.handleExchange)) head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch r.Method {
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
case http.MethodGet, "":
if head != "" {
http.NotFound(w, r)
return
}
chain.Handler(h.handleAuthorize).ServeHTTP(w, r)
case http.MethodPost:
switch head {
default:
http.NotFound(w, r)
case "":
chain.Handler(h.handleExchange).ServeHTTP(w, r)
case "verify":
chain.Handler(h.handleVerify).ServeHTTP(w, r)
}
}
})
} }
func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) { func (h *Handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) if r.Method != http.MethodGet && r.Method != "" {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) return
}
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...) tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{ baseOf := web.BaseOf{
Config: h.config, Config: &h.config,
Language: tag, Language: tag,
Printer: message.NewPrinter(tag), Printer: message.NewPrinter(tag),
} }
req := NewAuthAuthorizationRequest() req := NewAuthAuthorizationRequest()
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: err, Error: err,
}) })
@ -195,10 +138,10 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return return
} }
client, err := h.clients.Discovery(ctx, req.ClientID) client, err := h.clients.Discovery(r.Context(), req.ClientID)
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: err, Error: err,
}) })
@ -207,8 +150,8 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
} }
if !client.ValidateRedirectURI(req.RedirectURI.URL) { if !client.ValidateRedirectURI(req.RedirectURI.URL) {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: domain.NewError( Error: domain.NewError(
domain.ErrorCodeInvalidClient, domain.ErrorCodeInvalidClient,
@ -220,15 +163,15 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return return
} }
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte) csrf, _ := r.Context().Value(middleware.DefaultCSRFConfig.ContextKey).([]byte)
web.WriteTemplate(ctx, &web.AuthorizePage{ web.WriteTemplate(w, &web.AuthorizePage{
BaseOf: baseOf, BaseOf: baseOf,
CSRF: csrf, CSRF: csrf,
Scope: req.Scope, Scope: req.Scope,
Client: client, Client: client,
Me: req.Me, Me: &req.Me,
RedirectURI: req.RedirectURI, RedirectURI: &req.RedirectURI,
CodeChallengeMethod: req.CodeChallengeMethod, CodeChallengeMethod: *req.CodeChallengeMethod,
ResponseType: req.ResponseType, ResponseType: req.ResponseType,
CodeChallenge: req.CodeChallenge, CodeChallenge: req.CodeChallenge,
State: req.State, State: req.State,
@ -236,15 +179,21 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
}) })
} }
func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) { func (h *Handler) handleVerify(w http.ResponseWriter, r *http.Request) {
ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain) if r.Method != http.MethodPost {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := NewAuthVerifyRequest() req := NewAuthVerifyRequest()
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
@ -254,60 +203,70 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
if strings.EqualFold(req.Authorize, "deny") { if strings.EqualFold(req.Authorize, "deny") {
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State). domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
SetReirectURI(req.RedirectURI.URL) SetReirectURI(req.RedirectURI.URL)
ctx.Redirect(req.RedirectURI.String(), http.StatusFound) http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
return return
} }
code, err := h.useCase.Generate(ctx, auth.GenerateOptions{ code, err := h.useCase.Generate(r.Context(), auth.GenerateOptions{
ClientID: req.ClientID, ClientID: req.ClientID,
Me: req.Me, Me: req.Me,
RedirectURI: req.RedirectURI.URL, RedirectURI: req.RedirectURI.URL,
CodeChallengeMethod: req.CodeChallengeMethod, CodeChallengeMethod: *req.CodeChallengeMethod,
Scope: req.Scope, Scope: req.Scope,
CodeChallenge: req.CodeChallenge, CodeChallenge: req.CodeChallenge,
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
q := req.RedirectURI.Query()
for key, val := range map[string]string{ for key, val := range map[string]string{
"code": code, "code": code,
"iss": h.config.Server.GetRootURL(), "iss": h.config.Server.GetRootURL(),
"state": req.State, "state": req.State,
} { } {
req.RedirectURI.Query().Set(key, val) q.Set(key, val)
} }
ctx.Redirect(req.RedirectURI.String(), http.StatusFound) req.RedirectURI.RawQuery = q.Encode()
http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
} }
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := new(AuthExchangeRequest) req := new(AuthExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
me, profile, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{ me, profile, err := h.useCase.Exchange(r.Context(), auth.ExchangeOptions{
Code: req.Code, Code: req.Code,
ClientID: req.ClientID, ClientID: req.ClientID,
RedirectURI: req.RedirectURI.URL, RedirectURI: req.RedirectURI.URL,
CodeVerifier: req.CodeVerifier, CodeVerifier: req.CodeVerifier,
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
@ -325,109 +284,7 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
} }
_ = encoder.Encode(&AuthExchangeResponse{ _ = encoder.Encode(&AuthExchangeResponse{
Me: me, Me: *me,
Profile: userInfo, Profile: userInfo,
}) })
} }
func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
return &AuthAuthorizationRequest{
ClientID: new(domain.ClientID),
CodeChallenge: "",
CodeChallengeMethod: domain.CodeChallengeMethodUnd,
Me: new(domain.Me),
RedirectURI: new(domain.URL),
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:cyclop
func (r *AuthAuthorizationRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
return nil
}
func NewAuthVerifyRequest() *AuthVerifyRequest {
return &AuthVerifyRequest{
Authorize: "",
ClientID: new(domain.ClientID),
CodeChallenge: "",
CodeChallengeMethod: domain.CodeChallengeMethodUnd,
Me: new(domain.Me),
Provider: "",
RedirectURI: new(domain.URL),
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:funlen,cyclop
func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
r.Provider = strings.ToLower(r.Provider)
if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#authorization-request",
)
}
return nil
}
func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#redeeming-the-authorization-code",
)
}
return nil
}

View File

@ -0,0 +1,212 @@
package http
import (
"errors"
"net/http"
"strings"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type (
AuthAuthorizationRequest struct {
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
// The client URL.
ClientID domain.ClientID `form:"client_id"`
// The redirect URL indicating where the user should be
// redirected to after approving the request.
RedirectURI domain.URL `form:"redirect_uri"`
// The URL that the user entered.
Me domain.Me `form:"me"`
// The hashing method used to calculate the code challenge.
CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
// A space-separated list of scopes the client is requesting,
// e.g. "profile", or "profile create". If the client omits this
// value, the authorization server MUST NOT issue an access
// token for this authorization code. Only the user's profile
// URL may be returned without any scope requested.
Scope domain.Scopes `form:"scope,omitempty"`
// A parameter set by the client which will be included when the
// user is redirected back to the client. This is used to
// prevent CSRF attacks. The authorization server MUST return
// the unmodified state value back to the client.
State string `form:"state"`
// The code challenge as previously described.
CodeChallenge string `form:"code_challenge,omitempty"`
}
AuthVerifyRequest struct {
ClientID domain.ClientID `form:"client_id"`
Me domain.Me `form:"me"`
RedirectURI domain.URL `form:"redirect_uri"`
ResponseType domain.ResponseType `form:"response_type"`
CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
Scope domain.Scopes `form:"scope[],omitempty"`
Authorize string `form:"authorize"`
CodeChallenge string `form:"code_challenge,omitempty"`
State string `form:"state"`
Provider string `form:"provider"`
}
AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The client's URL, which MUST match the client_id used in the
// authentication request.
ClientID domain.ClientID `form:"client_id"`
// The client's redirect URL, which MUST match the initial
// authentication request.
RedirectURI domain.URL `form:"redirect_uri"`
// The authorization code received from the authorization
// endpoint in the redirect.
Code string `form:"code"`
// The original plaintext random string generated before
// starting the authorization request.
CodeVerifier string `form:"code_verifier"`
}
AuthExchangeResponse struct {
Me domain.Me `json:"me"`
Profile *AuthProfileResponse `json:"profile,omitempty"`
}
AuthProfileResponse struct {
Email *domain.Email `json:"email,omitempty"`
Photo *domain.URL `json:"photo,omitempty"`
URL *domain.URL `json:"url,omitempty"`
Name string `json:"name,omitempty"`
}
)
func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
return &AuthAuthorizationRequest{
ClientID: domain.ClientID{},
CodeChallenge: "",
CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
Me: domain.Me{},
RedirectURI: domain.URL{},
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:cyclop
func (r *AuthAuthorizationRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
return nil
}
func NewAuthVerifyRequest() *AuthVerifyRequest {
return &AuthVerifyRequest{
Authorize: "",
ClientID: domain.ClientID{},
CodeChallenge: "",
CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
Me: domain.Me{},
Provider: "",
RedirectURI: domain.URL{},
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:funlen,cyclop
func (r *AuthVerifyRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
r.Provider = strings.ToLower(r.Provider)
if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#authorization-request",
)
}
return nil
}
func (r *AuthExchangeRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#redeeming-the-authorization-code",
)
}
return nil
}

View File

@ -1,13 +1,14 @@
package http_test package http_test
import ( import (
"path" "context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
@ -22,7 +23,7 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session" "source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
"source.toby3d.me/toby3d/auth/internal/testing/httptest" "source.toby3d.me/toby3d/auth/internal/user"
userrepo "source.toby3d.me/toby3d/auth/internal/user/repository/memory" userrepo "source.toby3d.me/toby3d/auth/internal/user/repository/memory"
) )
@ -34,36 +35,31 @@ type Dependencies struct {
matcher language.Matcher matcher language.Matcher
profiles profile.Repository profiles profile.Repository
sessions session.Repository sessions session.Repository
store *sync.Map users user.Repository
} }
func TestAuthorize(t *testing.T) { func TestAuthorize(t *testing.T) {
t.Parallel() t.Parallel()
deps := NewDependencies(t) deps := NewDependencies(t)
me := domain.TestMe(t, "https://user.example.net") me := domain.TestMe(t, "https://user.example.net/")
user := domain.TestUser(t) user := domain.TestUser(t)
client := domain.TestClient(t) client := domain.TestClient(t)
deps.store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client) if err := deps.clients.Create(context.Background(), *client); err != nil {
deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile) t.Fatal(err)
deps.store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user) }
r := router.New() if err := deps.users.Create(context.Background(), *user); err != nil {
//nolint:exhaustivestruct t.Fatal(err)
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ }
Auth: deps.authService,
Clients: deps.clientService,
Config: deps.config,
Matcher: deps.matcher,
}).Register(r)
httpClient, _, cleanup := httptest.New(t, r.Handler) if err := deps.profiles.Create(context.Background(), *me, *user.Profile); err != nil {
t.Cleanup(cleanup) t.Fatal(err)
}
uri := http.AcquireURI() u := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
defer http.ReleaseURI(uri) q := u.Query()
uri.Update("https://example.com/authorize")
for key, val := range map[string]string{ for key, val := range map[string]string{
"client_id": client.ID.String(), "client_id": client.ID.String(),
@ -75,26 +71,36 @@ func TestAuthorize(t *testing.T) {
"scope": "profile email", "scope": "profile email",
"state": "1234567890", "state": "1234567890",
} { } {
uri.QueryArgs().Set(key, val) q.Set(key, val)
} }
req := httptest.NewRequest(http.MethodGet, uri.String(), nil) u.RawQuery = q.Encode()
defer http.ReleaseRequest(req)
resp := http.AcquireResponse() req := httptest.NewRequest(http.MethodGet, u.String(), nil)
defer http.ReleaseResponse(resp) w := httptest.NewRecorder()
if err := httpClient.Do(req, resp); err != nil { //nolint:exhaustivestruct
delivery.NewHandler(delivery.NewHandlerOptions{
Auth: deps.authService,
Clients: deps.clientService,
Config: *deps.config,
Matcher: deps.matcher,
}).Handler().ServeHTTP(w, req)
resp := w.Result()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if resp.StatusCode() != http.StatusOK { if resp.StatusCode != http.StatusOK {
t.Errorf("GET %s = %d, want %d", uri.String(), resp.StatusCode(), http.StatusOK) t.Errorf("%s %s = %d, want %d", req.Method, u.String(), resp.StatusCode, http.StatusOK)
} }
const expResult = `Authorize application` const expResult = `Authorize application`
if result := string(resp.Body()); !strings.Contains(result, expResult) { if result := string(body); !strings.Contains(result, expResult) {
t.Errorf("GET %s = %s, want %s", uri.String(), result, expResult) t.Errorf("%s %s = %s, want %s", req.Method, u.String(), result, expResult)
} }
} }
@ -103,14 +109,15 @@ func NewDependencies(tb testing.TB) Dependencies {
config := domain.TestConfig(tb) config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages()) matcher := language.NewMatcher(message.DefaultCatalog.Languages())
store := new(sync.Map) clients := clientrepo.NewMemoryClientRepository()
clients := clientrepo.NewMemoryClientRepository(store) users := userrepo.NewMemoryUserRepository()
sessions := sessionrepo.NewMemorySessionRepository(store, config) sessions := sessionrepo.NewMemorySessionRepository(*config)
profiles := profilerepo.NewMemoryProfileRepository(store) profiles := profilerepo.NewMemoryProfileRepository()
authService := ucase.NewAuthUseCase(sessions, profiles, config) authService := ucase.NewAuthUseCase(sessions, profiles, config)
clientService := clientucase.NewClientUseCase(clients) clientService := clientucase.NewClientUseCase(clients)
return Dependencies{ return Dependencies{
users: users,
authService: authService, authService: authService,
clients: clients, clients: clients,
clientService: clientService, clientService: clientService,
@ -118,6 +125,5 @@ func NewDependencies(tb testing.TB) Dependencies {
matcher: matcher, matcher: matcher,
sessions: sessions, sessions: sessions,
profiles: profiles, profiles: profiles,
store: store,
} }
} }

View File

@ -9,8 +9,8 @@ import (
type ( type (
GenerateOptions struct { GenerateOptions struct {
ClientID *domain.ClientID ClientID domain.ClientID
Me *domain.Me Me domain.Me
RedirectURI *url.URL RedirectURI *url.URL
CodeChallengeMethod domain.CodeChallengeMethod CodeChallengeMethod domain.CodeChallengeMethod
Scope domain.Scopes Scope domain.Scopes
@ -18,7 +18,7 @@ type (
} }
ExchangeOptions struct { ExchangeOptions struct {
ClientID *domain.ClientID ClientID domain.ClientID
RedirectURI *url.URL RedirectURI *url.URL
Code string Code string
CodeVerifier string CodeVerifier string

View File

@ -45,7 +45,7 @@ func (uc *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions)
} }
} }
if err = uc.sessions.Create(ctx, &domain.Session{ if err = uc.sessions.Create(ctx, domain.Session{
ClientID: opts.ClientID, ClientID: opts.ClientID,
Code: code, Code: code,
CodeChallenge: opts.CodeChallenge, CodeChallenge: opts.CodeChallenge,
@ -81,5 +81,5 @@ func (uc *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions)
return nil, nil, auth.ErrMismatchPKCE return nil, nil, auth.ErrMismatchPKCE
} }
return session.Me, session.Profile, nil return &session.Me, session.Profile, nil
} }

View File

@ -1,48 +1,37 @@
package http package http
import ( import (
"errors" "net/http"
"strings" "strings"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/token" "source.toby3d.me/toby3d/auth/internal/token"
"source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web" "source.toby3d.me/toby3d/auth/web"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
) )
type ( type (
ClientCallbackRequest struct { NewHandlerOptions struct {
Error domain.ErrorCode `form:"error,omitempty"`
Iss *domain.ClientID `form:"iss"`
Code string `form:"code"`
ErrorDescription string `form:"error_description,omitempty"`
State string `form:"state"`
}
NewRequestHandlerOptions struct {
Matcher language.Matcher Matcher language.Matcher
Tokens token.UseCase Tokens token.UseCase
Client *domain.Client Client domain.Client
Config *domain.Config Config domain.Config
} }
RequestHandler struct { Handler struct {
matcher language.Matcher matcher language.Matcher
tokens token.UseCase tokens token.UseCase
client *domain.Client client domain.Client
config *domain.Config config domain.Config
} }
) )
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { func NewHandler(opts NewHandlerOptions) *Handler {
return &RequestHandler{ return &Handler{
client: opts.Client, client: opts.Client,
config: opts.Config, config: opts.Config,
matcher: opts.Matcher, matcher: opts.Matcher,
@ -50,59 +39,82 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware.LogFmt(), if r.Method != "" && r.Method != http.MethodGet {
} http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
r.GET("/", chain.RequestHandler(h.handleRender)) return
r.GET("/callback", chain.RequestHandler(h.handleCallback)) }
var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch head {
default:
http.NotFound(w, r)
case "":
h.handleRender(w, r)
case "callback":
h.handleCallback(w, r)
}
})
} }
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) {
if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
redirect := make([]string, len(h.client.RedirectURI)) redirect := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI { for i := range h.client.RedirectURI {
redirect[i] = h.client.RedirectURI[i].String() redirect[i] = h.client.RedirectURI[i].String()
} }
ctx.Response.Header.Set( w.Header().Set(common.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`)
http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...) tag, _, _ := h.matcher.Match(tags...)
// TODO(toby3d): generate and store PKCE // TODO(toby3d): generate and store PKCE
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.HomePage{ web.WriteTemplate(w, &web.HomePage{
BaseOf: web.BaseOf{ BaseOf: web.BaseOf{
Config: h.config, Config: &h.config,
Language: tag, Language: tag,
Printer: message.NewPrinter(tag), Printer: message.NewPrinter(tag),
}, },
Client: h.client, Client: &h.client,
State: "hackme", // TODO(toby3d): generate and store state State: "hackme", // TODO(toby3d): generate and store state
}) })
} }
//nolint:unlen //nolint:unlen
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) return
}
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...) tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{ baseOf := web.BaseOf{
Config: h.config, Config: &h.config,
Language: tag, Language: tag,
Printer: message.NewPrinter(tag), Printer: message.NewPrinter(tag),
} }
req := new(ClientCallbackRequest) req := new(ClientCallbackRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: err, Error: err,
}) })
@ -111,8 +123,8 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
} }
if req.Error != domain.ErrorCodeUnd { if req.Error != domain.ErrorCodeUnd {
ctx.SetStatusCode(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: domain.NewError( Error: domain.NewError(
domain.ErrorCodeAccessDenied, domain.ErrorCodeAccessDenied,
@ -127,9 +139,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
// TODO(toby3d): load and check state // TODO(toby3d): load and check state
if req.Iss == nil || req.Iss.String() != h.client.ID.String() { if req.Iss.String() != h.client.ID.String() {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: domain.NewError( Error: domain.NewError(
domain.ErrorCodeInvalidClient, domain.ErrorCodeInvalidClient,
@ -142,15 +154,15 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return return
} }
token, _, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ token, _, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{
ClientID: h.client.ID, ClientID: h.client.ID,
RedirectURI: h.client.RedirectURI[0], RedirectURI: h.client.RedirectURI[0],
Code: req.Code, Code: req.Code,
CodeVerifier: "", // TODO(toby3d): validate PKCE here CodeVerifier: "", // TODO(toby3d): validate PKCE here
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf, BaseOf: baseOf,
Error: err, Error: err,
}) })
@ -158,23 +170,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return return
} }
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.CallbackPage{ web.WriteTemplate(w, &web.CallbackPage{
BaseOf: baseOf, BaseOf: baseOf,
Token: token, Token: token,
}) })
} }
func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
}
return nil
}

View File

@ -0,0 +1,31 @@
package http
import (
"errors"
"net/http"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type ClientCallbackRequest struct {
Error domain.ErrorCode `form:"error,omitempty"`
Iss domain.ClientID `form:"iss"`
Code string `form:"code"`
ErrorDescription string `form:"error_description,omitempty"`
State string `form:"state"`
}
func (req *ClientCallbackRequest) bind(r *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(r.URL.Query().Encode()), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
}
return nil
}

View File

@ -1,11 +1,10 @@
package http_test package http_test
import ( import (
"sync" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
@ -15,7 +14,6 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session" "source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/token" "source.toby3d.me/toby3d/auth/internal/token"
tokenrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory" tokenrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase" tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase"
@ -27,7 +25,6 @@ type Dependencies struct {
config *domain.Config config *domain.Config
matcher language.Matcher matcher language.Matcher
sessions session.Repository sessions session.Repository
store *sync.Map
tokens token.Repository tokens token.Repository
tokenService token.UseCase tokenService token.UseCase
} }
@ -36,45 +33,30 @@ func TestRead(t *testing.T) {
t.Parallel() t.Parallel()
deps := NewDependencies(t) deps := NewDependencies(t)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)
w := httptest.NewRecorder()
r := router.New() delivery.NewHandler(delivery.NewHandlerOptions{
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ Client: *deps.client,
Client: deps.client, Config: *deps.config,
Config: deps.config,
Matcher: deps.matcher, Matcher: deps.matcher,
Tokens: deps.tokenService, Tokens: deps.tokenService,
}).Register(r) }).Handler().ServeHTTP(w, req)
client, _, cleanup := httptest.New(t, r.Handler) if resp := w.Result(); resp.StatusCode != http.StatusOK {
t.Cleanup(cleanup) t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
const requestURI string = "https://app.example.com/"
req, resp := httptest.NewRequest(http.MethodGet, requestURI, nil), http.AcquireResponse()
t.Cleanup(func() {
http.ReleaseRequest(req)
http.ReleaseResponse(resp)
})
if err := client.Do(req, resp); err != nil {
t.Error(err)
}
if resp.StatusCode() != http.StatusOK {
t.Errorf("GET %s = %d, want %d", requestURI, resp.StatusCode(), http.StatusOK)
} }
} }
func NewDependencies(tb testing.TB) Dependencies { func NewDependencies(tb testing.TB) Dependencies {
tb.Helper() tb.Helper()
store := new(sync.Map)
client := domain.TestClient(tb) client := domain.TestClient(tb)
config := domain.TestConfig(tb) config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages()) matcher := language.NewMatcher(message.DefaultCatalog.Languages())
sessions := sessionrepo.NewMemorySessionRepository(store, config) sessions := sessionrepo.NewMemorySessionRepository(*config)
tokens := tokenrepo.NewMemoryTokenRepository(store) tokens := tokenrepo.NewMemoryTokenRepository()
profiles := profilerepo.NewMemoryProfileRepository(store) profiles := profilerepo.NewMemoryProfileRepository()
tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{ tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config, Config: config,
Profiles: profiles, Profiles: profiles,
@ -87,7 +69,6 @@ func NewDependencies(tb testing.TB) Dependencies {
config: config, config: config,
matcher: matcher, matcher: matcher,
sessions: sessions, sessions: sessions,
store: store,
profiles: profiles, profiles: profiles,
tokens: tokens, tokens: tokens,
tokenService: tokenService, tokenService: tokenService,

View File

@ -7,7 +7,8 @@ import (
) )
type Repository interface { type Repository interface {
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) Create(ctx context.Context, client domain.Client) error
Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error)
} }
var ErrNotExist error = domain.NewError( var ErrNotExist error = domain.NewError(

View File

@ -1,14 +1,17 @@
package http package http
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net" "io"
"net/http"
"net/url" "net/url"
http "github.com/valyala/fasthttp" "golang.org/x/exp/slices"
"source.toby3d.me/toby3d/auth/internal/client" "source.toby3d.me/toby3d/auth/internal/client"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/httputil"
) )
@ -34,33 +37,18 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
} }
} }
func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) { // WARN(toby3d): not implemented.
ips, err := net.LookupIP(cid.URL().Hostname()) func (httpClientRepository) Create(_ context.Context, _ domain.Client) error {
return nil
}
func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
resp, err := repo.client.Get(cid.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot resolve client IP by id: %w", err)
}
for _, ip := range ips {
if !ip.IsLoopback() {
continue
}
return nil, client.ErrNotExist
}
req := http.AcquireRequest()
defer http.ReleaseRequest(req)
req.SetRequestURI(cid.String())
req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("failed to make a request to the client: %w", err) return nil, fmt.Errorf("failed to make a request to the client: %w", err)
} }
if resp.StatusCode() == http.StatusNotFound { if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist) return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist)
} }
@ -72,74 +60,62 @@ func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID)
Name: make([]string, 0), Name: make([]string, 0),
} }
extract(client, resp) extract(resp.Body, resp.Request.URL, client, resp.Header.Get(common.HeaderLink))
return client, nil return client, nil
} }
//nolint:gocognit,cyclop //nolint:gocognit,cyclop
func extract(dst *domain.Client, src *http.Response) { func extract(r io.Reader, u *url.URL, dst *domain.Client, header string) {
for _, endpoint := range httputil.ExtractEndpoints(src, relRedirectURI) { body, _ := io.ReadAll(r)
if !containsURL(dst.RedirectURI, endpoint) {
for _, endpoint := range httputil.ExtractEndpoints(bytes.NewReader(body), u, header, relRedirectURI) {
if !containsUrl(dst.RedirectURI, endpoint) {
dst.RedirectURI = append(dst.RedirectURI, endpoint) dst.RedirectURI = append(dst.RedirectURI, endpoint)
} }
} }
for _, itemType := range []string{hXApp, hApp} { for _, itemType := range []string{hApp, hXApp} {
for _, name := range httputil.ExtractProperty(src, itemType, propertyName) { for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyName) {
if n, ok := name.(string); ok && !containsString(dst.Name, n) { if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) {
dst.Name = append(dst.Name, n) dst.Name = append(dst.Name, n)
} }
} }
for _, logo := range httputil.ExtractProperty(src, itemType, propertyLogo) { for _, logo := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyLogo) {
var ( var logoURL *url.URL
u *url.URL var err error
err error
)
switch l := logo.(type) { switch l := logo.(type) {
case string: case string:
u, err = url.Parse(l) logoURL, err = url.Parse(l)
case map[string]string: case map[string]string:
if value, ok := l["value"]; ok { if value, ok := l["value"]; ok {
u, err = url.Parse(value) logoURL, err = url.Parse(value)
} }
} }
if err != nil || containsURL(dst.Logo, u) { if err != nil || containsUrl(dst.Logo, logoURL) {
continue continue
} }
dst.Logo = append(dst.Logo, u) dst.Logo = append(dst.Logo, logoURL)
} }
for _, property := range httputil.ExtractProperty(src, itemType, propertyURL) { for _, property := range httputil.ExtractProperty(bytes.NewReader(body), u, itemType, propertyURL) {
prop, ok := property.(string) prop, ok := property.(string)
if !ok { if !ok {
continue continue
} }
if u, err := url.Parse(prop); err == nil || !containsURL(dst.URL, u) { if u, err := url.Parse(prop); err == nil && !containsUrl(dst.URL, u) {
dst.URL = append(dst.URL, u) dst.URL = append(dst.URL, u)
} }
} }
} }
} }
func containsString(src []string, find string) bool { func containsUrl(src []*url.URL, find *url.URL) bool {
for i := range src {
if src[i] != find {
continue
}
return true
}
return false
}
func containsURL(src []*url.URL, find *url.URL) bool {
for i := range src { for i := range src {
if src[i].String() != find.String() { if src[i].String() != find.String() {
continue continue

View File

@ -3,22 +3,21 @@ package http_test
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/google/go-cmp/cmp"
http "github.com/valyala/fasthttp"
repository "source.toby3d.me/toby3d/auth/internal/client/repository/http" repository "source.toby3d.me/toby3d/auth/internal/client/repository/http"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
) )
const testBody string = ` const testBody string = `<!DOCTYPE html>
<!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>%[1]s</title> <title>%[1]s</title>
<link rel="redirect_uri" href="%[4]s"> <link rel="redirect_uri" href="%[4]s">
@ -36,38 +35,47 @@ func TestGet(t *testing.T) {
t.Parallel() t.Parallel()
client := domain.TestClient(t) client := domain.TestClient(t)
httpClient, _, cleanup := httptest.New(t, testHandler(t, client)) srv := httptest.NewUnstartedServer(testHandler(t, *client))
t.Cleanup(cleanup) srv.EnableHTTP2 = true
result, err := repository.NewHTTPClientRepository(httpClient).Get(context.Background(), client.ID) srv.StartTLS()
t.Cleanup(srv.Close)
client.ID = *domain.TestClientID(t, srv.URL+"/")
clients := repository.NewHTTPClientRepository(srv.Client())
result, err := clients.Get(context.Background(), client.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
assert.Equal(t, client.Name, result.Name) if out := client.ID; !result.ID.IsEqual(out) {
assert.Equal(t, client.ID.String(), result.ID.String()) t.Errorf("GET %s = %s, want %s", client.ID, out, result.ID)
for i := range client.URL {
assert.Equal(t, client.URL[i].String(), result.URL[i].String())
} }
for i := range client.Logo { if !cmp.Equal(result.Name, client.Name) {
assert.Equal(t, client.Logo[i].String(), result.Logo[i].String()) t.Errorf("GET %s = %+s, want %+s", client.ID, result.Name, client.Name)
} }
for i := range client.RedirectURI { if !cmp.Equal(result.URL, client.URL) {
assert.Equal(t, client.RedirectURI[i].String(), result.RedirectURI[i].String()) t.Errorf("GET %s = %+s, want %+s", client.ID, result.URL, client.URL)
}
if !cmp.Equal(result.Logo, client.Logo) {
t.Errorf("GET %s = %+s, want %+s", client.ID, result.Logo, client.Logo)
}
if !cmp.Equal(result.RedirectURI, client.RedirectURI) {
t.Errorf("GET %s = %+s, want %+s", client.ID, result.RedirectURI, client.RedirectURI)
} }
} }
func testHandler(tb testing.TB, client *domain.Client) http.RequestHandler { func testHandler(tb testing.TB, client domain.Client) http.Handler {
tb.Helper() tb.Helper()
return func(ctx *http.RequestCtx) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx.Response.Header.Set(http.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`) w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf( w.Header().Set(common.HeaderLink, `<`+client.RedirectURI[0].String()+`>; rel="redirect_uri"`)
testBody, client.Name[0], client.URL[0].String(), client.Logo[0].String(), fmt.Fprintf(w, testBody, client.Name[0], client.URL[0], client.Logo[0], client.RedirectURI[1])
client.RedirectURI[1].String(), })
))
}
} }

View File

@ -2,9 +2,6 @@ package memory
import ( import (
"context" "context"
"fmt"
"net"
"path"
"sync" "sync"
"source.toby3d.me/toby3d/auth/internal/client" "source.toby3d.me/toby3d/auth/internal/client"
@ -12,45 +9,33 @@ import (
) )
type memoryClientRepository struct { type memoryClientRepository struct {
store *sync.Map mutex *sync.RWMutex
clients map[string]domain.Client
} }
const DefaultPathPrefix string = "clients" func NewMemoryClientRepository() client.Repository {
func NewMemoryClientRepository(store *sync.Map) client.Repository {
return &memoryClientRepository{ return &memoryClientRepository{
store: store, mutex: new(sync.RWMutex),
clients: make(map[string]domain.Client),
} }
} }
func (repo *memoryClientRepository) Create(ctx context.Context, client *domain.Client) error { func (repo memoryClientRepository) Create(ctx context.Context, client domain.Client) error {
repo.store.Store(path.Join(DefaultPathPrefix, client.ID.String()), client) repo.mutex.RLock()
defer repo.mutex.RUnlock()
repo.clients[client.ID.String()] = client
return nil return nil
} }
func (repo *memoryClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { func (repo memoryClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
// WARN(toby3d): more often than not, we will work from tests with repo.mutex.RLock()
// non-existent clients, almost guaranteed to cause a resolution error. defer repo.mutex.RUnlock()
ips, _ := net.LookupIP(id.URL().Hostname())
for _, ip := range ips { if c, ok := repo.clients[cid.String()]; ok {
if !ip.IsLoopback() { return &c, nil
continue
}
return nil, client.ErrNotExist
} }
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, id.String())) return nil, client.ErrNotExist
if !ok {
return nil, fmt.Errorf("cannot find client in store: %w", client.ErrNotExist)
}
c, ok := src.(*domain.Client)
if !ok {
return nil, fmt.Errorf("cannot decode client from store: %w", client.ErrNotExist)
}
return c, nil
} }

View File

@ -1,31 +0,0 @@
package memory_test
import (
"context"
"path"
"reflect"
"sync"
"testing"
repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory"
"source.toby3d.me/toby3d/auth/internal/domain"
)
func TestGet(t *testing.T) {
t.Parallel()
client := domain.TestClient(t)
store := new(sync.Map)
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client)
result, err := repository.NewMemoryClientRepository(store).
Get(context.Background(), client.ID)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(result, client) {
t.Errorf("Get(%s) = %+v, want %+v", client.ID, result, client)
}
}

View File

@ -8,7 +8,7 @@ import (
type UseCase interface { type UseCase interface {
// Discovery returns client public information bu ClientID URL. // Discovery returns client public information bu ClientID URL.
Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error)
} }
var ErrInvalidMe error = domain.NewError( var ErrInvalidMe error = domain.NewError(

View File

@ -18,7 +18,7 @@ func NewClientUseCase(repo client.Repository) client.UseCase {
} }
} }
func (useCase *clientUseCase) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { func (useCase *clientUseCase) Discovery(ctx context.Context, id domain.ClientID) (*domain.Client, error) {
c, err := useCase.repo.Get(ctx, id) c, err := useCase.repo.Get(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot discovery client by id: %w", err) return nil, fmt.Errorf("cannot discovery client by id: %w", err)

View File

@ -3,12 +3,9 @@ package usecase_test
import ( import (
"context" "context"
"errors" "errors"
"path"
"reflect" "reflect"
"sync"
"testing" "testing"
"source.toby3d.me/toby3d/auth/internal/client"
repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory" repository "source.toby3d.me/toby3d/auth/internal/client/repository/memory"
"source.toby3d.me/toby3d/auth/internal/client/usecase" "source.toby3d.me/toby3d/auth/internal/client/usecase"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -17,12 +14,11 @@ import (
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
t.Parallel() t.Parallel()
store := new(sync.Map) testClient := domain.TestClient(t)
testClient, localhostClient := domain.TestClient(t), domain.TestClient(t) clients := repository.NewMemoryClientRepository()
localhostClient.ID, _ = domain.ParseClientID("http://localhost/")
for _, client := range []*domain.Client{testClient, localhostClient} { if err := clients.Create(context.Background(), *testClient); err != nil {
store.Store(path.Join(repository.DefaultPathPrefix, client.ID.String()), client) t.Fatal(err)
} }
for _, tc := range []struct { for _, tc := range []struct {
@ -34,17 +30,13 @@ func TestDiscovery(t *testing.T) {
name: "default", name: "default",
in: testClient, in: testClient,
out: testClient, out: testClient,
}, {
name: "localhost",
in: localhostClient,
expError: client.ErrNotExist,
}} { }} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
result, err := usecase.NewClientUseCase(repository.NewMemoryClientRepository(store)). result, err := usecase.NewClientUseCase(clients).
Discovery(context.Background(), tc.in.ID) Discovery(context.Background(), tc.in.ID)
if tc.expError != nil && !errors.Is(err, tc.expError) { if tc.expError != nil && !errors.Is(err, tc.expError) {
t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError) t.Errorf("Discovery(%s) = %+v, want %+v", tc.in.ID, err, tc.expError)

View File

@ -9,7 +9,7 @@ import (
// Client describes the client requesting data about the user. // Client describes the client requesting data about the user.
type Client struct { type Client struct {
ID *ClientID ID ClientID
Logo []*url.URL Logo []*url.URL
RedirectURI []*url.URL RedirectURI []*url.URL
URL []*url.URL URL []*url.URL
@ -17,7 +17,7 @@ type Client struct {
} }
// NewClient creates a new empty Client with provided ClientID, if any. // NewClient creates a new empty Client with provided ClientID, if any.
func NewClient(cid *ClientID) *Client { func NewClient(cid ClientID) *Client {
return &Client{ return &Client{
ID: cid, ID: cid,
Logo: make([]*url.URL, 0), Logo: make([]*url.URL, 0),
@ -32,7 +32,7 @@ func TestClient(tb testing.TB) *Client {
tb.Helper() tb.Helper()
return &Client{ return &Client{
ID: TestClientID(tb), ID: *TestClientID(tb),
Name: []string{"Example App"}, Name: []string{"Example App"},
URL: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/"}}, URL: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/"}},
Logo: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/logo.png"}}, Logo: []*url.URL{{Scheme: "https", Host: "app.example.com", Path: "/logo.png"}},

View File

@ -8,6 +8,8 @@ import (
"testing" "testing"
"inet.af/netaddr" "inet.af/netaddr"
"source.toby3d.me/toby3d/auth/internal/common"
) )
// ClientID is a URL client identifier. // ClientID is a URL client identifier.
@ -37,16 +39,20 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.Scheme != "http" && cid.Scheme != "https" { if cid.Scheme != "http" && cid.Scheme != "https" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST have either an https or http scheme", "client identifier URL MUST have either an https or http scheme, got '"+cid.Scheme+"'",
"https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
) )
} }
if cid.Path == "" || strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") { if cid.Path == "" {
cid.Path = "/"
}
if strings.Contains(cid.Path, "/.") || strings.Contains(cid.Path, "/..") {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST contain a path component and MUST NOT contain "+ "client identifier URL MUST contain a path component and MUST NOT contain "+
"single-dot or double-dot path segments", "single-dot or double-dot path segments, got '"+cid.Path+"'",
"https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
) )
} }
@ -54,7 +60,7 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.Fragment != "" { if cid.Fragment != "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a fragment component", "client identifier URL MUST NOT contain a fragment component, got '"+cid.Fragment+"'",
"https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
) )
} }
@ -62,7 +68,8 @@ func ParseClientID(src string) (*ClientID, error) {
if cid.User != nil { if cid.User != nil {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a username or password component", "client identifier URL MUST NOT contain a username or password component, got '"+
cid.User.String()+"'",
"https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
) )
} }
@ -71,7 +78,7 @@ func ParseClientID(src string) (*ClientID, error) {
if domain == "" { if domain == "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"client host name MUST be domain name or a loopback interface", "client host name MUST be domain name or a loopback interface, got '"+domain+"'",
"https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
) )
} }
@ -102,10 +109,15 @@ func ParseClientID(src string) (*ClientID, error) {
} }
// TestClientID returns valid random generated ClientID for tests. // TestClientID returns valid random generated ClientID for tests.
func TestClientID(tb testing.TB) *ClientID { func TestClientID(tb testing.TB, forceURL ...string) *ClientID {
tb.Helper() tb.Helper()
clientID, err := ParseClientID("https://example.com/") in := "https://app.example.com/"
if len(forceURL) > 0 {
in = forceURL[0]
}
clientID, err := ParseClientID(in)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
@ -147,6 +159,11 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(cid.String())), nil return []byte(strconv.Quote(cid.String())), nil
} }
// IsEqual checks what cid is equal to provided v.
func (cid ClientID) IsEqual(v ClientID) bool {
return cid.clientID.String() == v.clientID.String()
}
// URL returns url.URL representation of client ID. // URL returns url.URL representation of client ID.
func (cid ClientID) URL() *url.URL { func (cid ClientID) URL() *url.URL {
out, _ := url.Parse(cid.clientID.String()) out, _ := url.Parse(cid.clientID.String())
@ -156,5 +173,17 @@ func (cid ClientID) URL() *url.URL {
// String returns string representation of client ID. // String returns string representation of client ID.
func (cid ClientID) String() string { func (cid ClientID) String() string {
if cid.clientID == nil {
return ""
}
return cid.clientID.String() return cid.clientID.String()
} }
func (cid ClientID) GoString() string {
if cid.clientID == nil {
return "domain.ClientID(" + common.Und + ")"
}
return "domain.ClientID(" + cid.clientID.String() + ")"
}

View File

@ -114,7 +114,7 @@ func TestCodeChallengeMethod_String(t *testing.T) {
func TestCodeChallengeMethod_Validate(t *testing.T) { func TestCodeChallengeMethod_Validate(t *testing.T) {
t.Parallel() t.Parallel()
verifier, err := random.String(gofakeit.Number(43, 128)) verifier, err := random.String(uint8(gofakeit.Number(43, 128)))
if err != nil { if err != nil {
t.Fatalf("%+v", err) t.Fatalf("%+v", err)
} }

View File

@ -29,7 +29,6 @@ type (
Port string `yaml:"port"` Port string `yaml:"port"`
Protocol string `yaml:"protocol"` Protocol string `yaml:"protocol"`
RootURL string `yaml:"rootUrl"` RootURL string `yaml:"rootUrl"`
StaticRootPath string `yaml:"staticRootPath"`
StaticURLPrefix string `yaml:"staticUrlPrefix"` StaticURLPrefix string `yaml:"staticUrlPrefix"`
EnablePprof bool `yaml:"enablePprof"` EnablePprof bool `yaml:"enablePprof"`
} }
@ -44,14 +43,14 @@ type (
// exchange it for a token or user information. // exchange it for a token or user information.
ConfigCode struct { ConfigCode struct {
Expiry time.Duration `yaml:"expiry"` // 10m Expiry time.Duration `yaml:"expiry"` // 10m
Length int `yaml:"length"` // 32 Length uint8 `yaml:"length"` // 32
} }
ConfigJWT struct { ConfigJWT struct {
Expiry time.Duration `yaml:"expiry"` // 1h Expiry time.Duration `yaml:"expiry"` // 1h
Algorithm string `yaml:"algorithm"` // HS256 Algorithm string `yaml:"algorithm"` // HS256
Secret string `yaml:"secret"` Secret string `yaml:"secret"`
NonceLength int `yaml:"nonceLength"` // 22 NonceLength uint8 `yaml:"nonceLength"` // 22
} }
ConfigIndieAuth struct { ConfigIndieAuth struct {
@ -62,7 +61,7 @@ type (
ConfigTicketAuth struct { ConfigTicketAuth struct {
Expiry time.Duration `yaml:"expiry"` // 1m Expiry time.Duration `yaml:"expiry"` // 1m
Length int `yaml:"length"` // 24 Length uint8 `yaml:"length"` // 24
} }
ConfigRelMeAuth struct { ConfigRelMeAuth struct {
@ -95,7 +94,6 @@ func TestConfig(tb testing.TB) *Config {
Port: "3000", Port: "3000",
Protocol: "http", Protocol: "http",
RootURL: "{{protocol}}://{{domain}}:{{port}}/", RootURL: "{{protocol}}://{{domain}}:{{port}}/",
StaticRootPath: "/",
StaticURLPrefix: "/static", StaticURLPrefix: "/static",
}, },
Database: ConfigDatabase{ Database: ConfigDatabase{
@ -136,7 +134,6 @@ func (cs ConfigServer) GetRootURL() string {
"host": cs.Host, "host": cs.Host,
"port": cs.Port, "port": cs.Port,
"protocol": cs.Protocol, "protocol": cs.Protocol,
"staticRootPath": cs.StaticRootPath,
"staticUrlPrefix": cs.StaticURLPrefix, "staticUrlPrefix": cs.StaticURLPrefix,
}) })
} }

View File

@ -31,7 +31,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Scheme != "http" && id.Scheme != "https" { if id.Scheme != "http" && id.Scheme != "https" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile URL MUST have either an https or http scheme", "profile URL MUST have either an https or http scheme, got '"+id.Scheme+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -45,7 +45,7 @@ func ParseMe(raw string) (*Me, error) {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile URL MUST contain a path component (/ is a valid path), MUST NOT contain single-dot "+ "profile URL MUST contain a path component (/ is a valid path), MUST NOT contain single-dot "+
"or double-dot path segments", "or double-dot path segments, got '"+id.Path+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -54,7 +54,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Fragment != "" { if id.Fragment != "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile URL MUST NOT contain a fragment component", "profile URL MUST NOT contain a fragment component, got '"+id.Fragment+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -63,7 +63,7 @@ func ParseMe(raw string) (*Me, error) {
if id.User != nil { if id.User != nil {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile URL MUST NOT contain a username or password component", "profile URL MUST NOT contain a username or password component, got '"+id.User.String()+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -72,7 +72,7 @@ func ParseMe(raw string) (*Me, error) {
if id.Host == "" { if id.Host == "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile host name MUST be a domain name", "profile host name MUST be a domain name, got '"+id.Host+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -81,16 +81,16 @@ func ParseMe(raw string) (*Me, error) {
if _, port, _ := net.SplitHostPort(id.Host); port != "" { if _, port, _ := net.SplitHostPort(id.Host); port != "" {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile MUST NOT contain a port", "profile MUST NOT contain a port, got '"+port+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
} }
if net.ParseIP(id.Host) != nil { if out := net.ParseIP(id.Host); out != nil {
return nil, NewError( return nil, NewError(
ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
"profile MUST NOT be ipv4 or ipv6 addresses", "profile MUST NOT be ipv4 or ipv6 addresses, got '"+out.String()+"'",
"https://indieauth.net/source/#user-profile-url", "https://indieauth.net/source/#user-profile-url",
"", "",
) )
@ -103,12 +103,12 @@ func ParseMe(raw string) (*Me, error) {
func TestMe(tb testing.TB, src string) *Me { func TestMe(tb testing.TB, src string) *Me {
tb.Helper() tb.Helper()
me, err := ParseMe(src) u, err := url.Parse(src)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
return me return &Me{id: u}
} }
// UnmarshalForm implements custom unmarshler for form values. // UnmarshalForm implements custom unmarshler for form values.

View File

@ -14,7 +14,7 @@ type Metadata struct {
// issuer URL could be https://example.com/, or for a metadata URL of // issuer URL could be https://example.com/, or for a metadata URL of
// https://example.com/wp-json/indieauth/1.0/metadata, the issuer URL // https://example.com/wp-json/indieauth/1.0/metadata, the issuer URL
// could be https://example.com/wp-json/indieauth/1.0 // could be https://example.com/wp-json/indieauth/1.0
Issuer *ClientID Issuer *url.URL
// The Authorization Endpoint. // The Authorization Endpoint.
AuthorizationEndpoint *url.URL AuthorizationEndpoint *url.URL
@ -81,7 +81,11 @@ func TestMetadata(tb testing.TB) *Metadata {
tb.Helper() tb.Helper()
return &Metadata{ return &Metadata{
Issuer: TestClientID(tb), Issuer: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/.well-known/oauth-authorization-server",
},
AuthorizationEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/auth"}, AuthorizationEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/auth"},
TokenEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/token"}, TokenEndpoint: &url.URL{Scheme: "https", Host: "indieauth.example.com", Path: "/token"},
TicketEndpoint: &url.URL{Scheme: "https", Host: "auth.example.org", Path: "/ticket"}, TicketEndpoint: &url.URL{Scheme: "https", Host: "auth.example.org", Path: "/ticket"},

View File

@ -1,10 +1,9 @@
package domain package domain
import ( import (
"net/url"
"path" "path"
"strings" "strings"
http "github.com/valyala/fasthttp"
) )
// Provider represent 3rd party RelMeAuth provider. // Provider represent 3rd party RelMeAuth provider.
@ -91,9 +90,10 @@ var (
// AuthCodeURL returns URL for authorize user in RelMeAuth client. // AuthCodeURL returns URL for authorize user in RelMeAuth client.
func (p Provider) AuthCodeURL(state string) string { func (p Provider) AuthCodeURL(state string) string {
uri := http.AcquireURI() u, err := url.Parse(p.AuthURL)
defer http.ReleaseURI(uri) if err != nil {
uri.Update(p.AuthURL) return ""
}
for key, val := range map[string]string{ for key, val := range map[string]string{
"client_id": p.ClientID, "client_id": p.ClientID,
@ -102,8 +102,8 @@ func (p Provider) AuthCodeURL(state string) string {
"scope": strings.Join(p.Scopes, " "), "scope": strings.Join(p.Scopes, " "),
"state": state, "state": state,
} { } {
uri.QueryArgs().Set(key, val) u.Query().Set(key, val)
} }
return uri.String() return u.String()
} }

View File

@ -80,6 +80,22 @@ func ParseScope(uid string) (Scope, error) {
return ScopeUnd, fmt.Errorf("%w: %s", ErrScopeUnknown, uid) return ScopeUnd, fmt.Errorf("%w: %s", ErrScopeUnknown, uid)
} }
func (s *Scope) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return fmt.Errorf("Scope: UnmarshalJSON: cannot unquote string: %w", err)
}
out, err := ParseScope(src)
if err != nil {
return fmt.Errorf("Scopes: UnmarshalJSON: cannot parse scope: %w", err)
}
*s = out
return nil
}
func (s Scope) MarshalJSON() ([]byte, error) { func (s Scope) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(s.uid)), nil return []byte(strconv.Quote(s.uid)), nil
} }

View File

@ -9,9 +9,9 @@ import (
//nolint:tagliatelle //nolint:tagliatelle
type Session struct { type Session struct {
ClientID *ClientID `json:"client_id"` ClientID ClientID `json:"client_id"`
RedirectURI *url.URL `json:"redirect_uri"` RedirectURI *url.URL `json:"redirect_uri"`
Me *Me `json:"me"` Me Me `json:"me"`
Profile *Profile `json:"profile,omitempty"` Profile *Profile `json:"profile,omitempty"`
Scope Scopes `json:"scope"` Scope Scopes `json:"scope"`
CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method,omitempty"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method,omitempty"`
@ -31,12 +31,12 @@ func TestSession(tb testing.TB) *Session {
} }
return &Session{ return &Session{
ClientID: TestClientID(tb), ClientID: *TestClientID(tb),
Code: code, Code: code,
CodeChallenge: "hackme", CodeChallenge: "hackme",
CodeChallengeMethod: CodeChallengeMethodPLAIN, CodeChallengeMethod: CodeChallengeMethodPLAIN,
Profile: TestProfile(tb), Profile: TestProfile(tb),
Me: TestMe(tb, "https://user.example.net/"), Me: *TestMe(tb, "https://user.example.net/"),
RedirectURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/callback"}, RedirectURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/callback"},
Scope: Scopes{ Scope: Scopes{
ScopeEmail, ScopeEmail,

View File

@ -2,13 +2,14 @@ package domain
import ( import (
"fmt" "fmt"
"net/http"
"testing" "testing"
"time" "time"
"github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt" "github.com/lestrrat-go/jwx/v2/jwt"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/random" "source.toby3d.me/toby3d/auth/internal/random"
) )
@ -17,8 +18,8 @@ type (
Token struct { Token struct {
CreatedAt time.Time CreatedAt time.Time
Expiry time.Time Expiry time.Time
ClientID *ClientID ClientID ClientID
Me *Me Me Me
Scope Scopes Scope Scopes
AccessToken string AccessToken string
RefreshToken string RefreshToken string
@ -27,12 +28,12 @@ type (
// NewTokenOptions contains options for NewToken function. // NewTokenOptions contains options for NewToken function.
NewTokenOptions struct { NewTokenOptions struct {
Expiration time.Duration Expiration time.Duration
Issuer *ClientID Issuer ClientID
Subject *Me Subject Me
Scope Scopes Scope Scopes
Secret []byte Secret []byte
Algorithm string Algorithm string
NonceLength int NonceLength uint8
} }
) )
@ -42,8 +43,8 @@ type (
var DefaultNewTokenOptions = NewTokenOptions{ var DefaultNewTokenOptions = NewTokenOptions{
Expiration: 0, Expiration: 0,
Scope: nil, Scope: nil,
Issuer: nil, Issuer: ClientID{},
Subject: nil, Subject: Me{},
Secret: nil, Secret: nil,
Algorithm: "HS256", Algorithm: "HS256",
NonceLength: 32, NonceLength: 32,
@ -82,7 +83,7 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
} }
} }
if opts.Issuer != nil { if opts.Issuer.clientID != nil {
if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil { if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err) return nil, fmt.Errorf("failed to set JWT token field: %w", err)
} }
@ -157,8 +158,8 @@ func TestToken(tb testing.TB) *Token {
return &Token{ return &Token{
CreatedAt: now.Add(-1 * time.Hour), CreatedAt: now.Add(-1 * time.Hour),
Expiry: now.Add(1 * time.Hour), Expiry: now.Add(1 * time.Hour),
ClientID: cid, ClientID: *cid,
Me: me, Me: *me,
Scope: scope, Scope: scope,
AccessToken: string(accessToken), AccessToken: string(accessToken),
RefreshToken: "", // TODO(toby3d) RefreshToken: "", // TODO(toby3d)
@ -171,7 +172,7 @@ func (t Token) SetAuthHeader(r *http.Request) {
return return
} }
r.Header.Set(http.HeaderAuthorization, t.String()) r.Header.Set(common.HeaderAuthorization, t.String())
} }
// String returns string representation of token. // String returns string representation of token.

View File

@ -1,13 +1,12 @@
package domain_test package domain_test
import ( import (
"bytes"
"fmt" "fmt"
"net/http"
"testing" "testing"
"time" "time"
http "github.com/valyala/fasthttp" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
) )
@ -40,16 +39,13 @@ func TestNewToken(t *testing.T) {
func TestToken_SetAuthHeader(t *testing.T) { func TestToken_SetAuthHeader(t *testing.T) {
t.Parallel() t.Parallel()
token := domain.TestToken(t) in := domain.TestToken(t)
expResult := []byte("Bearer " + token.AccessToken) req, _ := http.NewRequest(http.MethodGet, "https://example.com/", nil)
in.SetAuthHeader(req)
req := http.AcquireRequest() exp := "Bearer " + in.AccessToken
defer http.ReleaseRequest(req) if out := req.Header.Get(common.HeaderAuthorization); out != exp {
token.SetAuthHeader(req) t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, out, exp)
result := req.Header.Peek(http.HeaderAuthorization)
if result == nil || !bytes.Equal(result, expResult) {
t.Errorf("SetAuthHeader(%+v) = %s, want %s", req, result, expResult)
} }
} }
@ -57,9 +53,9 @@ func TestToken_String(t *testing.T) {
t.Parallel() t.Parallel()
token := domain.TestToken(t) token := domain.TestToken(t)
expResult := "Bearer " + token.AccessToken exp := "Bearer " + token.AccessToken
if result := token.String(); result != expResult { if out := token.String(); out != exp {
t.Errorf("String() = %s, want %s", result, expResult) t.Errorf("String() = %s, want %s", out, exp)
} }
} }

View File

@ -5,6 +5,8 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"testing" "testing"
"source.toby3d.me/toby3d/auth/internal/common"
) )
// URL describe any valid HTTP URL. // URL describe any valid HTTP URL.
@ -75,3 +77,11 @@ func (u *URL) UnmarshalJSON(v []byte) error {
func (u URL) MarshalJSON() ([]byte, error) { func (u URL) MarshalJSON() ([]byte, error) {
return []byte(strconv.Quote(u.String())), nil return []byte(strconv.Quote(u.String())), nil
} }
func (u URL) GoString() string {
if u.URL == nil {
return "domain.URL(" + common.Und + ")"
}
return "domain.URL(" + u.URL.String() + ")"
}

View File

@ -5,7 +5,6 @@ import (
"net/http" "net/http"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/middleware"
) )
type Handler struct{} type Handler struct{}
@ -14,8 +13,8 @@ func NewHandler() *Handler {
return &Handler{} return &Handler{}
} }
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) Handler() http.Handler {
http.HandlerFunc(middleware.HandlerFunc(h.handleFunc).Intercept(middleware.LogFmt())).ServeHTTP(w, r) return http.HandlerFunc(h.handleFunc)
} }
func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {

View File

@ -2,11 +2,10 @@ package http_test
import ( import (
"io" "io"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
http "github.com/valyala/fasthttp"
delivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http" delivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http"
) )
@ -15,7 +14,10 @@ func TestRequestHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://example.com/health", nil) req := httptest.NewRequest(http.MethodGet, "https://example.com/health", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
delivery.NewHandler().ServeHTTP(w, req)
delivery.NewHandler().
Handler().
ServeHTTP(w, req)
resp := w.Result() resp := w.Result()

View File

@ -2,33 +2,74 @@ package httputil
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/goccy/go-json"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
http "github.com/valyala/fasthttp" "golang.org/x/exp/slices"
"willnorris.com/go/microformats" "willnorris.com/go/microformats"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
) )
const RelIndieauthMetadata = "indieauth-metadata"
var ErrEndpointNotExist = domain.NewError( var ErrEndpointNotExist = domain.NewError(
domain.ErrorCodeServerError, domain.ErrorCodeServerError,
"cannot found any endpoints", "cannot found any endpoints",
"https://indieauth.net/source/#discovery-0", "https://indieauth.net/source/#discovery-0",
) )
func ExtractEndpoints(resp *http.Response, rel string) []*url.URL { func ExtractFromMetadata(client *http.Client, u string) (*domain.Metadata, error) {
req, err := http.NewRequest(http.MethodGet, u, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(body)
endpoints := ExtractEndpoints(buf, resp.Request.URL, resp.Header.Get(common.HeaderLink), RelIndieauthMetadata)
if len(endpoints) == 0 {
return nil, ErrEndpointNotExist
}
if resp, err = client.Get(endpoints[len(endpoints)-1].String()); err != nil {
return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
}
result := new(domain.Metadata)
if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
}
return result, nil
}
func ExtractEndpoints(body io.Reader, u *url.URL, linkHeader, rel string) []*url.URL {
results := make([]*url.URL, 0) results := make([]*url.URL, 0)
urls, err := ExtractEndpointsFromHeader(resp, rel) urls, err := ExtractEndpointsFromHeader(linkHeader, rel)
if err == nil { if err == nil {
results = append(results, urls...) results = append(results, urls...)
} }
urls, err = ExtractEndpointsFromBody(resp, rel) urls, err = ExtractEndpointsFromBody(body, u, rel)
if err == nil { if err == nil {
results = append(results, urls...) results = append(results, urls...)
} }
@ -36,15 +77,15 @@ func ExtractEndpoints(resp *http.Response, rel string) []*url.URL {
return results return results
} }
func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, error) { func ExtractEndpointsFromHeader(linkHeader, rel string) ([]*url.URL, error) {
results := make([]*url.URL, 0) results := make([]*url.URL, 0)
for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) { for _, link := range linkheader.Parse(linkHeader) {
if !strings.EqualFold(link.Rel, rel) { if !strings.EqualFold(link.Rel, rel) {
continue continue
} }
u, err := url.ParseRequestURI(link.URL) u, err := url.Parse(link.URL)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot parse header endpoint: %w", err) return nil, fmt.Errorf("cannot parse header endpoint: %w", err)
} }
@ -55,8 +96,8 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*url.URL, er
return results, nil return results, nil
} }
func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, error) { func ExtractEndpointsFromBody(body io.Reader, u *url.URL, rel string) ([]*url.URL, error) {
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel] endpoints, ok := microformats.Parse(body, u).Rels[rel]
if !ok || len(endpoints) == 0 { if !ok || len(endpoints) == 0 {
return nil, ErrEndpointNotExist return nil, ErrEndpointNotExist
} }
@ -75,58 +116,23 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*url.URL, erro
return results, nil return results, nil
} }
func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) { func ExtractProperty(body io.Reader, u *url.URL, itemType, key string) []any {
endpoints := ExtractEndpoints(resp, "indieauth-metadata") if data := microformats.Parse(body, u); data != nil {
if len(endpoints) == 0 { return FindProperty(data.Items, itemType, key)
return nil, ErrEndpointNotExist }
}
return nil
_, body, err := client.Get(nil, endpoints[len(endpoints)-1].String()) }
if err != nil {
return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err) func FindProperty(src []*microformats.Microformat, itemType, key string) []any {
} for _, item := range src {
if slices.Contains(item.Type, itemType) {
result := new(domain.Metadata) return item.Properties[key]
if err = json.Unmarshal(body, result); err != nil { }
return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
} if result := FindProperty(item.Children, itemType, key); result != nil {
return result
return result, nil }
}
func ExtractProperty(resp *http.Response, itemType, key string) []interface{} {
//nolint:exhaustivestruct // only Host part in url.URL is needed
data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{
Host: string(resp.Header.Peek(http.HeaderHost)),
})
return findProperty(data.Items, itemType, key)
}
func contains(src []string, find string) bool {
for i := range src {
if !strings.EqualFold(src[i], find) {
continue
}
return true
}
return false
}
func findProperty(src []*microformats.Microformat, itemType, key string) []interface{} {
for _, item := range src {
if contains(item.Type, itemType) {
return item.Properties[key]
}
result := findProperty(item.Children, itemType, key)
if result == nil {
continue
}
return result
} }
return nil return nil

View File

@ -1,30 +1,72 @@
package httputil_test package httputil_test
import ( import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing" "testing"
http "github.com/valyala/fasthttp" "github.com/google/go-cmp/cmp"
"source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/httputil"
) )
const testBody = `<html> const testBody = `<html>
<head>
<link rel="lipsum" href="https://example.com/">
<link rel="lipsum" href="https://example.net/">
</head>
<body class="h-page"> <body class="h-page">
<main class="h-card"> <main class="h-app">
<h1 class="p-name">Sample Name</h1> <h1 class="p-name">Sample Name</h1>
</main> </main>
</body> </body>
</html>` </html>`
func TestExtractEndpointsFromBody(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
if err != nil {
t.Fatal(err)
}
in := &http.Response{
Body: ioutil.NopCloser(strings.NewReader(testBody)),
Request: req,
}
out, err := httputil.ExtractEndpointsFromBody(in.Body, req.URL, "lipsum")
if err != nil {
t.Fatal(err)
}
exp := []*url.URL{
{Scheme: "https", Host: "example.com", Path: "/"},
{Scheme: "https", Host: "example.net", Path: "/"},
}
if !cmp.Equal(out, exp) {
t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, out, exp)
}
}
func TestExtractProperty(t *testing.T) { func TestExtractProperty(t *testing.T) {
t.Parallel() t.Parallel()
resp := http.AcquireResponse() req, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
defer http.ReleaseResponse(resp) if err != nil {
resp.SetBodyString(testBody) t.Fatal(err)
}
results := httputil.ExtractProperty(resp, "h-card", "name") in := &http.Response{
if results == nil || results[0] != "Sample Name" { Body: ioutil.NopCloser(strings.NewReader(testBody)),
t.Errorf(`ExtractProperty(resp, "h-card", "name") = %+s, want %+s`, results, []string{"Sample Name"}) Request: req,
}
if out := httputil.ExtractProperty(in.Body, req.URL, "h-app", "name"); out == nil || out[0] != "Sample Name" {
t.Errorf(`ExtractProperty(%s, %s, %s) = %+s, want %+s`, req.URL, "h-app", "name", out,
[]string{"Sample Name"})
} }
} }

View File

@ -1,13 +1,12 @@
package http package http
import ( import (
"github.com/fasthttp/router" "net/http"
"github.com/goccy/go-json" "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/middleware"
) )
type ( type (
@ -60,28 +59,29 @@ type (
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
} }
RequestHandler struct { Handler struct {
metadata *domain.Metadata metadata *domain.Metadata
} }
) )
func NewRequestHandler(metadata *domain.Metadata) *RequestHandler { func NewHandler(metadata *domain.Metadata) *Handler {
return &RequestHandler{ return &Handler{
metadata: metadata, metadata: metadata,
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{ return http.HandlerFunc(h.handleFunc)
middleware.LogFmt(),
}
r.GET("/.well-known/oauth-authorization-server", chain.RequestHandler(h.read))
} }
func (h *RequestHandler) read(ctx *http.RequestCtx) { func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
ctx.SetStatusCode(http.StatusOK) if r.Method != "" && r.Method != http.MethodGet {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
scopes, responseTypes, grantTypes, codeChallengeMethods := make([]string, 0), make([]string, 0), scopes, responseTypes, grantTypes, codeChallengeMethods := make([]string, 0), make([]string, 0),
make([]string, 0), make([]string, 0) make([]string, 0), make([]string, 0)
@ -103,7 +103,7 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
h.metadata.CodeChallengeMethodsSupported[i].String()) h.metadata.CodeChallengeMethodsSupported[i].String())
} }
_ = json.NewEncoder(ctx).Encode(&MetadataResponse{ _ = json.NewEncoder(w).Encode(&MetadataResponse{
AuthorizationEndpoint: h.metadata.AuthorizationEndpoint.String(), AuthorizationEndpoint: h.metadata.AuthorizationEndpoint.String(),
IntrospectionEndpoint: h.metadata.IntrospectionEndpoint.String(), IntrospectionEndpoint: h.metadata.IntrospectionEndpoint.String(),
Issuer: h.metadata.Issuer.String(), Issuer: h.metadata.Issuer.String(),
@ -123,4 +123,6 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
// client_secret_basic according to RFC8414. // client_secret_basic according to RFC8414.
RevocationEndpointAuthMethodsSupported: h.metadata.RevocationEndpointAuthMethodsSupported, RevocationEndpointAuthMethodsSupported: h.metadata.RevocationEndpointAuthMethodsSupported,
}) })
w.WriteHeader(http.StatusOK)
} }

View File

@ -1,40 +1,36 @@
package http_test package http_test
import ( import (
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/fasthttp/router"
"github.com/goccy/go-json" "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
delivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http" delivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
) )
func TestMetadata(t *testing.T) { func TestMetadata(t *testing.T) {
t.Parallel() t.Parallel()
r := router.New()
metadata := domain.TestMetadata(t) metadata := domain.TestMetadata(t)
delivery.NewRequestHandler(metadata).Register(r)
client, _, cleanup := httptest.New(t, r.Handler) req := httptest.NewRequest(http.MethodGet, "https://example.com/.well-known/oauth-authorization-server", nil)
t.Cleanup(cleanup)
const requestURL string = "https://example.com/.well-known/oauth-authorization-server" w := httptest.NewRecorder()
delivery.NewHandler(metadata).
Handler().
ServeHTTP(w, req)
status, body, err := client.Get(nil, requestURL) resp := w.Result()
if err != nil {
t.Fatal(err) if resp.StatusCode != http.StatusOK {
t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
} }
if status != http.StatusOK { out := new(delivery.MetadataResponse)
t.Errorf("GET %s = %d, want %d", requestURL, status, http.StatusOK) if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
}
result := new(delivery.MetadataResponse)
if err = json.Unmarshal(body, result); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -2,12 +2,14 @@ package metadata
import ( import (
"context" "context"
"net/url"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
) )
type Repository interface { type Repository interface {
Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) Create(_ context.Context, _ *url.URL, _ domain.Metadata) error
Get(_ context.Context, u *url.URL) (*domain.Metadata, error)
} }
var ErrNotExist error = domain.NewError( var ErrNotExist error = domain.NewError(

View File

@ -2,26 +2,29 @@ package http
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http"
"net/url"
http "github.com/valyala/fasthttp" "github.com/goccy/go-json"
"github.com/tomnomnom/linkheader"
"willnorris.com/go/microformats"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/metadata" "source.toby3d.me/toby3d/auth/internal/metadata"
) )
type ( type (
//nolint:tagliatelle,lll //nolint:tagliatelle,lll
Metadata struct { Response struct {
Issuer *domain.ClientID `json:"issuer"` Issuer domain.URL `json:"issuer"`
AuthorizationEndpoint *domain.URL `json:"authorization_endpoint"` AuthorizationEndpoint domain.URL `json:"authorization_endpoint"`
IntrospectionEndpoint *domain.URL `json:"introspection_endpoint"` IntrospectionEndpoint domain.URL `json:"introspection_endpoint"`
RevocationEndpoint *domain.URL `json:"revocation_endpoint,omitempty"` RevocationEndpoint domain.URL `json:"revocation_endpoint,omitempty"`
ServiceDocumentation *domain.URL `json:"service_documentation,omitempty"` ServiceDocumentation domain.URL `json:"service_documentation,omitempty"`
TokenEndpoint *domain.URL `json:"token_endpoint"` TokenEndpoint domain.URL `json:"token_endpoint"`
UserinfoEndpoint *domain.URL `json:"userinfo_endpoint,omitempty"` UserinfoEndpoint domain.URL `json:"userinfo_endpoint,omitempty"`
CodeChallengeMethodsSupported []domain.CodeChallengeMethod `json:"code_challenge_methods_supported"` CodeChallengeMethodsSupported []domain.CodeChallengeMethod `json:"code_challenge_methods_supported"`
GrantTypesSupported []domain.GrantType `json:"grant_types_supported,omitempty"` GrantTypesSupported []domain.GrantType `json:"grant_types_supported,omitempty"`
ResponseTypesSupported []domain.ResponseType `json:"response_types_supported,omitempty"` ResponseTypesSupported []domain.ResponseType `json:"response_types_supported,omitempty"`
@ -29,6 +32,11 @@ type (
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"`
RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"`
AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"` AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"`
// Extensions
TicketEndpoint domain.URL `json:"ticket_endpoint"`
Micropub domain.URL `json:"micropub"`
Microsub domain.URL `json:"microsub"`
} }
httpMetadataRepository struct { httpMetadataRepository struct {
@ -36,7 +44,7 @@ type (
} }
) )
const DefaultMaxRedirectsCount int = 10 const relIndieauthMetadata = "indieauth-metadata"
func NewHTTPMetadataRepository(client *http.Client) metadata.Repository { func NewHTTPMetadataRepository(client *http.Client) metadata.Repository {
return &httpMetadataRepository{ return &httpMetadataRepository{
@ -44,48 +52,127 @@ func NewHTTPMetadataRepository(client *http.Client) metadata.Repository {
} }
} }
func (repo *httpMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) { // WARN(toby3d): not implemented.
req := http.AcquireRequest() func (httpMetadataRepository) Create(_ context.Context, _ *url.URL, _ domain.Metadata) error {
defer http.ReleaseRequest(req) return nil
req.SetRequestURI(me.String()) }
req.Header.SetMethod(http.MethodGet)
func (repo *httpMetadataRepository) Get(_ context.Context, u *url.URL) (*domain.Metadata, error) {
resp := http.AcquireResponse() resp, err := repo.client.Get(u.String())
defer http.ReleaseResponse(resp) if err != nil {
return nil, fmt.Errorf("cannot make request to provided Me: %w", err)
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { }
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
} relVals := make(map[string][]string)
for _, link := range linkheader.Parse(resp.Header.Get(common.HeaderLink)) {
endpoints := httputil.ExtractEndpoints(resp, "indieauth-metadata") populateBuffer(relVals, link.Rel, link.URL)
if len(endpoints) == 0 { }
return nil, metadata.ErrNotExist
} if mf2 := microformats.Parse(resp.Body, resp.Request.URL); mf2 != nil {
for rel, vals := range mf2.Rels {
_, body, err := repo.client.Get(nil, endpoints[len(endpoints)-1].String()) if len(vals) > 0 {
if err != nil { populateBuffer(relVals, rel, vals[0])
return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err) }
} }
}
data := new(Metadata)
if err = json.Unmarshal(body, data); err != nil { out := new(domain.Metadata)
return nil, fmt.Errorf("cannot unmarshal metadata configuration: %w", err) // NOTE(toby3d): fetch all from metadata endpoint if exists
} if endpoints, ok := relVals["indieauth-metadata"]; ok {
if resp, err = repo.client.Get(endpoints[0]); err != nil {
//nolint:exhaustivestruct // TODO(toby3d) return nil, fmt.Errorf("cannot fetch indieauth-metadata endpoint: %w", err)
return &domain.Metadata{ }
AuthorizationEndpoint: data.AuthorizationEndpoint.URL,
AuthorizationResponseIssParameterSupported: data.AuthorizationResponseIssParameterSupported, in := NewResponse()
CodeChallengeMethodsSupported: data.CodeChallengeMethodsSupported, if err = in.bind(resp); err != nil {
GrantTypesSupported: data.GrantTypesSupported, return nil, err
Issuer: data.Issuer, }
ResponseTypesSupported: data.ResponseTypesSupported,
ScopesSupported: data.ScopesSupported, in.populate(out)
ServiceDocumentation: data.ServiceDocumentation.URL,
TokenEndpoint: data.TokenEndpoint.URL, return out, nil
// TODO(toby3d): support extensions? }
// Micropub: data.Micropub,
// Microsub: data.Microsub, // NOTE(toby3d): metadata not exists, fallback for old clients
// TicketEndpoint: data.TicketEndpoint, for key, dst := range map[string]**url.URL{
}, nil "authorization_endpoint": &out.AuthorizationEndpoint,
"micropub": &out.MicropubEndpoint,
"microsub": &out.MicrosubEndpoint,
"ticket_endpoint": &out.TicketEndpoint,
"token_endpoint": &out.TokenEndpoint,
} {
if values, ok := relVals[key]; ok && len(values) > 0 {
if u, err := url.Parse(values[0]); err == nil {
*dst = resp.Request.URL.ResolveReference(u)
}
}
}
return out, nil
}
func populateBuffer(dst map[string][]string, rel, u string) {
if _, ok := dst[rel]; !ok {
dst[rel] = make([]string, 0)
}
dst[rel] = append(dst[rel], u)
}
func NewResponse() *Response {
return &Response{
CodeChallengeMethodsSupported: make([]domain.CodeChallengeMethod, 0),
GrantTypesSupported: make([]domain.GrantType, 0),
ResponseTypesSupported: make([]domain.ResponseType, 0),
ScopesSupported: make([]domain.Scope, 0),
IntrospectionEndpointAuthMethodsSupported: make([]string, 0),
RevocationEndpointAuthMethodsSupported: make([]string, 0),
}
}
func (r *Response) bind(resp *http.Response) error {
if err := json.NewDecoder(resp.Body).Decode(r); err != nil {
return fmt.Errorf("cannot unmarshal metadata configuration: %w", err)
}
return nil
}
func (r Response) populate(dst *domain.Metadata) {
dst.AuthorizationEndpoint = r.AuthorizationEndpoint.URL
dst.AuthorizationResponseIssParameterSupported = r.AuthorizationResponseIssParameterSupported
dst.IntrospectionEndpoint = r.IntrospectionEndpoint.URL
dst.Issuer = r.Issuer.URL
dst.MicropubEndpoint = r.Micropub.URL
dst.MicrosubEndpoint = r.Microsub.URL
dst.RevocationEndpoint = r.RevocationEndpoint.URL
dst.ServiceDocumentation = r.ServiceDocumentation.URL
dst.TicketEndpoint = r.TicketEndpoint.URL
dst.TokenEndpoint = r.TokenEndpoint.URL
dst.UserinfoEndpoint = r.UserinfoEndpoint.URL
for _, scope := range r.ScopesSupported {
dst.ScopesSupported = append(dst.ScopesSupported, scope)
}
for _, method := range r.RevocationEndpointAuthMethodsSupported {
dst.RevocationEndpointAuthMethodsSupported = append(dst.RevocationEndpointAuthMethodsSupported, method)
}
for _, responseType := range r.ResponseTypesSupported {
dst.ResponseTypesSupported = append(dst.ResponseTypesSupported, responseType)
}
for _, method := range r.IntrospectionEndpointAuthMethodsSupported {
dst.IntrospectionEndpointAuthMethodsSupported = append(dst.IntrospectionEndpointAuthMethodsSupported,
method)
}
for _, grantType := range r.GrantTypesSupported {
dst.GrantTypesSupported = append(dst.GrantTypesSupported, grantType)
}
for _, method := range r.CodeChallengeMethodsSupported {
dst.CodeChallengeMethodsSupported = append(dst.CodeChallengeMethodsSupported, method)
}
} }

View File

@ -0,0 +1,183 @@
package http_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/goccy/go-json"
"github.com/google/go-cmp/cmp"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
repository "source.toby3d.me/toby3d/auth/internal/metadata/repository/http"
)
//nolint:lll,tagliatelle
type Response struct {
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
ScopesSupported []string `json:"scopes_supported,omitempty"`
IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"`
RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"`
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
IntrospectionEndpoint string `json:"introspection_endpoint"`
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
ServiceDocumentation string `json:"service_documentation,omitempty"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
TicketEndpoint string `json:"ticket_endpoint"`
Micropub string `json:"micropub"`
Microsub string `json:"microsub"`
AuthorizationResponseIssParameterSupported bool `json:"authorization_response_iss_parameter_supported,omitempty"`
}
const testBody string = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Testing</title>
%s
</head>
<body></body>
</html>`
//nolint:funlen
func TestGet(t *testing.T) {
t.Parallel()
testMetadata := domain.TestMetadata(t)
for _, tc := range []struct {
name string
header map[string]string
body map[string]string
out *domain.Metadata
}{
{
name: "header",
header: map[string]string{
"indieauth-metadata": "/metadata",
"authorization_endpoint": "http://example.net/authorization",
"token_endpoint": "http://example.net/tkn",
},
out: testMetadata,
}, /*{
name: "body",
body: map[string]string{
"indieauth-metadata": "/metadata",
"authorization_endpoint": "http://example.net/authorization",
"token_endpoint": "http://example.net/tkn",
},
out: &testMetadata,
}*/} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/metadata", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
_ = json.NewEncoder(w).Encode(NewResponse(t, *testMetadata))
})
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
links := make([]string, 0)
for k, v := range tc.header {
links = append(links, `<`+v+`>; rel="`+k+`"`)
}
w.Header().Set(common.HeaderLink, strings.Join(links, ", "))
links = make([]string, 0)
for k, v := range tc.body {
links = append(links, `<link rel="`+k+`" href="`+v+`">`)
}
fmt.Fprintf(w, testBody, strings.Join(links, "\n"))
})
srv := httptest.NewUnstartedServer(mux)
srv.EnableHTTP2 = true
srv.Start()
t.Cleanup(srv.Close)
tc.header["indieauth-metadata"] = srv.URL + tc.header["indieauth-metadata"]
u, _ := url.Parse(srv.URL + "/")
out, err := repository.NewHTTPMetadataRepository(srv.Client()).
Get(context.Background(), u)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tc.out, out, cmp.AllowUnexported(
domain.ClientID{},
domain.CodeChallengeMethod{},
domain.GrantType{},
domain.ResponseType{},
domain.Scope{},
url.URL{},
)); diff != "" {
t.Errorf("%+s", diff)
}
})
}
}
func NewResponse(tb testing.TB, src domain.Metadata) *Response {
tb.Helper()
out := &Response{
CodeChallengeMethodsSupported: make([]string, 0),
GrantTypesSupported: make([]string, 0),
ResponseTypesSupported: make([]string, 0),
ScopesSupported: make([]string, 0),
IntrospectionEndpointAuthMethodsSupported: make([]string, 0),
RevocationEndpointAuthMethodsSupported: make([]string, 0),
Issuer: src.Issuer.String(),
AuthorizationEndpoint: src.AuthorizationEndpoint.String(),
IntrospectionEndpoint: src.IntrospectionEndpoint.String(),
RevocationEndpoint: src.RevocationEndpoint.String(),
ServiceDocumentation: src.ServiceDocumentation.String(),
TokenEndpoint: src.TokenEndpoint.String(),
UserinfoEndpoint: src.UserinfoEndpoint.String(),
TicketEndpoint: src.TicketEndpoint.String(),
Micropub: src.MicropubEndpoint.String(),
Microsub: src.MicrosubEndpoint.String(),
AuthorizationResponseIssParameterSupported: src.AuthorizationResponseIssParameterSupported,
}
for _, method := range src.CodeChallengeMethodsSupported {
out.CodeChallengeMethodsSupported = append(out.CodeChallengeMethodsSupported, method.String())
}
for _, grantType := range src.GrantTypesSupported {
out.GrantTypesSupported = append(out.GrantTypesSupported, grantType.String())
}
for _, responseType := range src.ResponseTypesSupported {
out.ResponseTypesSupported = append(out.ResponseTypesSupported, responseType.String())
}
for _, scope := range src.ScopesSupported {
out.ScopesSupported = append(out.ScopesSupported, scope.String())
}
for _, method := range src.IntrospectionEndpointAuthMethodsSupported {
out.IntrospectionEndpointAuthMethodsSupported = append(out.IntrospectionEndpointAuthMethodsSupported,
method)
}
for _, method := range src.RevocationEndpointAuthMethodsSupported {
out.RevocationEndpointAuthMethodsSupported = append(out.RevocationEndpointAuthMethodsSupported, method)
}
return out
}

View File

@ -2,7 +2,7 @@ package memory
import ( import (
"context" "context"
"path" "net/url"
"sync" "sync"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -10,27 +10,35 @@ import (
) )
type memoryMetadataRepository struct { type memoryMetadataRepository struct {
store *sync.Map mutex *sync.RWMutex
metadata map[string]domain.Metadata
} }
const DefaultPathPrefix = "metadata" const DefaultPathPrefix = "metadata"
func NewMemoryMetadataRepository(store *sync.Map) metadata.Repository { func NewMemoryMetadataRepository() metadata.Repository {
return &memoryMetadataRepository{ return &memoryMetadataRepository{
store: store, mutex: new(sync.RWMutex),
metadata: make(map[string]domain.Metadata),
} }
} }
func (repo *memoryMetadataRepository) Get(ctx context.Context, me *domain.Me) (*domain.Metadata, error) { func (repo *memoryMetadataRepository) Create(ctx context.Context, u *url.URL, metadata domain.Metadata) error {
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) repo.mutex.Lock()
if !ok { defer repo.mutex.Unlock()
return nil, metadata.ErrNotExist
}
result, ok := src.(*domain.Metadata) repo.metadata[u.String()] = metadata
if !ok {
return nil, metadata.ErrNotExist
}
return result, nil return nil
}
func (repo *memoryMetadataRepository) Get(ctx context.Context, u *url.URL) (*domain.Metadata, error) {
repo.mutex.RLock()
defer repo.mutex.RUnlock()
if out, ok := repo.metadata[u.String()]; ok {
return &out, nil
}
return nil, metadata.ErrNotExist
} }

View File

@ -23,7 +23,6 @@ var (
errHeaderExtractorValueMissing = errors.New("missing value in request header") errHeaderExtractorValueMissing = errors.New("missing value in request header")
errHeaderExtractorValueInvalid = errors.New("invalid value in request header") errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
errQueryExtractorValueMissing = errors.New("missing value in the query string") errQueryExtractorValueMissing = errors.New("missing value in the query string")
errParamExtractorValueMissing = errors.New("missing value in path params")
errCookieExtractorValueMissing = errors.New("missing value in cookies") errCookieExtractorValueMissing = errors.New("missing value in cookies")
errFormExtractorValueMissing = errors.New("missing value in the form") errFormExtractorValueMissing = errors.New("missing value in the form")
) )
@ -67,8 +66,6 @@ func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) {
switch parts[0] { switch parts[0] {
case "query": case "query":
extractors = append(extractors, valuesFromQuery(parts[1])) extractors = append(extractors, valuesFromQuery(parts[1]))
// case "param":
// extractors = append(extractors, valuesFromParam(parts[1]))
case "cookie": case "cookie":
extractors = append(extractors, valuesFromCookie(parts[1])) extractors = append(extractors, valuesFromCookie(parts[1]))
case "form": case "form":
@ -163,31 +160,6 @@ func valuesFromQuery(param string) ValuesExtractor {
} }
} }
// valuesFromParam returns a function that extracts values from the url param string.
/*
func valuesFromParam(param string) ValuesExtractor {
return func(w http.ResponseWriter, r *http.Request) ([]string, error) {
result := make([]string, 0)
paramVales := r.ParamValues()
for i, p := range r.ParamNames() {
if param == p {
result = append(result, paramVales[i])
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
return nil, errParamExtractorValueMissing
}
return result, nil
}
}
*/
// valuesFromCookie returns a function that extracts values from the named cookie. // valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor { func valuesFromCookie(name string) ValuesExtractor {
return func(w http.ResponseWriter, r *http.Request) ([]string, error) { return func(w http.ResponseWriter, r *http.Request) ([]string, error) {

View File

@ -77,7 +77,6 @@ type (
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>"
// - "query:<name>" // - "query:<name>"
// - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>" // - "form:<name>"
// Multiply sources example: // Multiply sources example:

View File

@ -7,7 +7,8 @@ import (
) )
type Repository interface { type Repository interface {
Get(ctx context.Context, me *domain.Me) (*domain.Profile, error) Create(ctx context.Context, me domain.Me, profile domain.Profile) error
Get(ctx context.Context, me domain.Me) (*domain.Profile, error)
} }
var ErrNotExist error = domain.NewError( var ErrNotExist error = domain.NewError(

View File

@ -1,12 +1,13 @@
package http package http
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"net/url" "net/url"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/profile" "source.toby3d.me/toby3d/auth/internal/profile"
@ -33,29 +34,33 @@ func NewHTPPClientRepository(client *http.Client) profile.Repository {
} }
} }
// WARN(toby3d): not implemented.
func (repo *httpProfileRepository) Create(_ context.Context, _ domain.Me, _ domain.Profile) error {
return nil
}
//nolint:cyclop //nolint:cyclop
func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*domain.Profile, error) { func (repo *httpProfileRepository) Get(ctx context.Context, me domain.Me) (*domain.Profile, error) {
req := http.AcquireRequest() resp, err := repo.client.Get(me.String())
defer http.ReleaseRequest(req) if err != nil {
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(me.String())
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("%s: cannot fetch user by me: %w", ErrPrefix, err) return nil, fmt.Errorf("%s: cannot fetch user by me: %w", ErrPrefix, err)
} }
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("cannot read response body: %w", err)
}
buf := bytes.NewReader(body)
result := domain.NewProfile() result := domain.NewProfile()
for _, name := range httputil.ExtractProperty(resp, hCard, propertyName) { for _, name := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyName) {
if n, ok := name.(string); ok { if n, ok := name.(string); ok {
result.Name = append(result.Name, n) result.Name = append(result.Name, n)
} }
} }
for _, rawEmail := range httputil.ExtractProperty(resp, hCard, propertyEmail) { for _, rawEmail := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyEmail) {
email, ok := rawEmail.(string) email, ok := rawEmail.(string)
if !ok { if !ok {
continue continue
@ -66,7 +71,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
} }
} }
for _, rawURL := range httputil.ExtractProperty(resp, hCard, propertyURL) { for _, rawURL := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyURL) {
rawURL, ok := rawURL.(string) rawURL, ok := rawURL.(string)
if !ok { if !ok {
continue continue
@ -77,7 +82,7 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
} }
} }
for _, rawPhoto := range httputil.ExtractProperty(resp, hCard, propertyPhoto) { for _, rawPhoto := range httputil.ExtractProperty(buf, me.URL(), hCard, propertyPhoto) {
photo, ok := rawPhoto.(string) photo, ok := rawPhoto.(string)
if !ok { if !ok {
continue continue
@ -88,8 +93,8 @@ func (repo *httpProfileRepository) Get(ctx context.Context, me *domain.Me) (*dom
} }
} }
if result.GetName() == "" && result.GetURL() == nil && // TODO(toby3d): create method like result.Empty()?
result.GetPhoto() == nil && result.GetEmail() == nil { if result.GetName() == "" && result.GetURL() == nil && result.GetPhoto() == nil && result.GetEmail() == nil {
return nil, profile.ErrNotExist return nil, profile.ErrNotExist
} }

View File

@ -2,8 +2,6 @@ package memory
import ( import (
"context" "context"
"fmt"
"path"
"sync" "sync"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -11,30 +9,33 @@ import (
) )
type memoryProfileRepository struct { type memoryProfileRepository struct {
store *sync.Map mutex *sync.RWMutex
profiles map[string]domain.Profile
} }
const ( func NewMemoryProfileRepository() profile.Repository {
ErrPrefix string = "memory"
DefaultPathPrefix string = "profiles"
)
func NewMemoryProfileRepository(store *sync.Map) profile.Repository {
return &memoryProfileRepository{ return &memoryProfileRepository{
store: store, mutex: new(sync.RWMutex),
profiles: make(map[string]domain.Profile),
} }
} }
func (repo *memoryProfileRepository) Get(_ context.Context, me *domain.Me) (*domain.Profile, error) { func (repo *memoryProfileRepository) Create(_ context.Context, me domain.Me, p domain.Profile) error {
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) repo.mutex.Lock()
if !ok { defer repo.mutex.Unlock()
return nil, fmt.Errorf("%s: cannot find profile in store: %w", ErrPrefix, profile.ErrNotExist)
}
result, ok := src.(*domain.Profile) repo.profiles[me.String()] = p
if !ok {
return nil, fmt.Errorf("%s: cannot decode profile from store: %w", ErrPrefix, profile.ErrNotExist)
}
return result, nil return nil
}
func (repo *memoryProfileRepository) Get(_ context.Context, me domain.Me) (*domain.Profile, error) {
repo.mutex.RLock()
defer repo.mutex.RUnlock()
if p, ok := repo.profiles[me.String()]; ok {
return &p, nil
}
return nil, profile.ErrNotExist
} }

View File

@ -7,7 +7,7 @@ import (
) )
type UseCase interface { type UseCase interface {
Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error) Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error)
} }
var ErrScopeRequired error = domain.NewError( var ErrScopeRequired error = domain.NewError(

View File

@ -18,7 +18,7 @@ func NewProfileUseCase(profiles profile.Repository) profile.UseCase {
} }
} }
func (uc *profileUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.Profile, error) { func (uc *profileUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.Profile, error) {
result, err := uc.profiles.Get(ctx, me) result, err := uc.profiles.Get(ctx, me)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot fetch profile info: %w", err) return nil, fmt.Errorf("cannot fetch profile info: %w", err)

View File

@ -17,7 +17,7 @@ const (
Hex = Numeric + "abcdef" Hex = Numeric + "abcdef"
) )
func Bytes(length int) ([]byte, error) { func Bytes(length uint8) ([]byte, error) {
bytes := make([]byte, length) bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil { if _, err := rand.Read(bytes); err != nil {
@ -27,7 +27,7 @@ func Bytes(length int) ([]byte, error) {
return bytes, nil return bytes, nil
} }
func String(length int, charsets ...string) (string, error) { func String(length uint8, charsets ...string) (string, error) {
charset := strings.Join(charsets, "") charset := strings.Join(charsets, "")
if charset == "" { if charset == "" {
charset = Alphabetic charset = Alphabetic

View File

@ -8,7 +8,7 @@ import (
type Repository interface { type Repository interface {
Get(ctx context.Context, code string) (*domain.Session, error) Get(ctx context.Context, code string) (*domain.Session, error)
Create(ctx context.Context, session *domain.Session) error Create(ctx context.Context, session domain.Session) error
GetAndDelete(ctx context.Context, code string) (*domain.Session, error) GetAndDelete(ctx context.Context, code string) (*domain.Session, error)
GC() GC()
} }

View File

@ -3,7 +3,6 @@ package memory
import ( import (
"context" "context"
"fmt" "fmt"
"path"
"sync" "sync"
"time" "time"
@ -14,59 +13,59 @@ import (
type ( type (
Session struct { Session struct {
CreatedAt time.Time CreatedAt time.Time
*domain.Session domain.Session
} }
memorySessionRepository struct { memorySessionRepository struct {
store *sync.Map config domain.Config
config *domain.Config mutex *sync.RWMutex
sessions map[string]Session
} }
) )
const DefaultPathPrefix string = "sessions" func NewMemorySessionRepository(config domain.Config) session.Repository {
func NewMemorySessionRepository(store *sync.Map, config *domain.Config) session.Repository {
return &memorySessionRepository{ return &memorySessionRepository{
config: config, config: config,
store: store, mutex: new(sync.RWMutex),
sessions: make(map[string]Session),
} }
} }
func (repo *memorySessionRepository) Create(_ context.Context, state *domain.Session) error { func (repo *memorySessionRepository) Create(_ context.Context, s domain.Session) error {
repo.store.Store(path.Join(DefaultPathPrefix, state.Code), &Session{ repo.mutex.Lock()
defer repo.mutex.Unlock()
repo.sessions[s.Code] = Session{
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Session: state, Session: s,
}) }
return nil return nil
} }
func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) { func (repo *memorySessionRepository) Get(_ context.Context, code string) (*domain.Session, error) {
src, ok := repo.store.Load(path.Join(DefaultPathPrefix, code)) repo.mutex.Lock()
if !ok { defer repo.mutex.Unlock()
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist)
if s, ok := repo.sessions[code]; ok {
return &s.Session, nil
} }
result, ok := src.(*Session) return nil, session.ErrNotExist
if !ok {
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
}
return result.Session, nil
} }
func (repo *memorySessionRepository) GetAndDelete(_ context.Context, code string) (*domain.Session, error) { func (repo *memorySessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, code)) s, err := repo.Get(ctx, code)
if !ok { if err != nil {
return nil, fmt.Errorf("cannot find session in store: %w", session.ErrNotExist) return nil, fmt.Errorf("cannot get and delete session: %w", err)
} }
result, ok := src.(*Session) repo.mutex.Lock()
if !ok { defer repo.mutex.Unlock()
return nil, fmt.Errorf("cannot decode session in store: %w", session.ErrNotExist)
}
return result.Session, nil delete(repo.sessions, s.Code)
return s, nil
} }
func (repo *memorySessionRepository) GC() { func (repo *memorySessionRepository) GC() {
@ -76,29 +75,20 @@ func (repo *memorySessionRepository) GC() {
for ts := range ticker.C { for ts := range ticker.C {
ts := ts ts := ts
repo.store.Range(func(key, value interface{}) bool { repo.mutex.RLock()
k, ok := key.(string)
if !ok { for code, s := range repo.sessions {
return false if s.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
continue
} }
matched, err := path.Match(DefaultPathPrefix+"/*", k) repo.mutex.RUnlock()
if err != nil || !matched { repo.mutex.Lock()
return false delete(repo.sessions, code)
} repo.mutex.Unlock()
repo.mutex.RLock()
}
val, ok := value.(*Session) repo.mutex.RUnlock()
if !ok {
return false
}
if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
return false
}
repo.store.Delete(key)
return false
})
} }
} }

View File

@ -4,11 +4,11 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
"github.com/goccy/go-json"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -53,8 +53,8 @@ func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository {
} }
} }
func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domain.Session) error { func (repo *sqlite3SessionRepository) Create(ctx context.Context, session domain.Session) error {
src, err := NewSession(session) src, err := NewSession(&session)
if err != nil { if err != nil {
return fmt.Errorf("cannot encode session data for store: %w", err) return fmt.Errorf("cannot encode session data for store: %w", err)
} }

View File

@ -12,7 +12,7 @@ import (
"source.toby3d.me/toby3d/auth/internal/testing/sqltest" "source.toby3d.me/toby3d/auth/internal/testing/sqltest"
) )
//nolint: gochecknoglobals // slices cannot be contants // nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "code", "data"} var tableColumns = []string{"created_at", "code", "data"}
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
@ -39,7 +39,7 @@ func TestCreate(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3SessionRepository(db). if err := repository.NewSQLite3SessionRepository(db).
Create(context.Background(), session); err != nil { Create(context.Background(), *session); err != nil {
t.Error(err) t.Error(err)
} }
} }

View File

@ -1 +0,0 @@
*.pem

View File

@ -1,63 +0,0 @@
//go:generate go run "$GOROOT/src/crypto/tls/generate_cert.go" --host 127.0.0.1,::1,localhost --start-date "Jan 1 00:00:00 1970" --duration=1000000h --ca --rsa-bits 1024 --ecdsa-curve P256
package httptest
import (
"crypto/tls"
_ "embed" // used for running tests without same import in "god object"
"net"
"testing"
"time"
http "github.com/valyala/fasthttp"
httputil "github.com/valyala/fasthttp/fasthttputil"
)
var (
//go:embed cert.pem
certData []byte
//go:embed key.pem
keyData []byte
)
// New returns the InMemory Server and the Client connected to it with the
// specified handler.
func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server, func()) {
tb.Helper()
//nolint:exhaustivestruct
server := &http.Server{
CloseOnShutdown: true,
DisableKeepalive: true,
ReduceMemoryUsage: true,
Handler: http.TimeoutHandler(handler, 1*time.Second, "handler performance is too slow"),
}
ln := httputil.NewInmemoryListener()
//nolint:errcheck
go server.ServeTLSEmbed(ln, certData, keyData)
//nolint:exhaustivestruct
client := &http.Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true, //nolint:gosec
},
Dial: func(addr string) (net.Conn, error) {
return ln.Dial() //nolint:wrapcheck
},
}
return client, server, func() {
_ = server.Shutdown()
}
}
// NewRequest returns a new incoming server Request and cleanup function.
func NewRequest(method, target string, body []byte) *http.Request {
req := http.AcquireRequest()
req.Header.SetMethod(method)
req.SetRequestURI(target)
req.SetBody(body)
return req
}

View File

@ -1,72 +1,48 @@
package http package http
import ( import (
"errors"
"fmt" "fmt"
"path" "net/http"
"github.com/fasthttp/router"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwa"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/random" "source.toby3d.me/toby3d/auth/internal/random"
"source.toby3d.me/toby3d/auth/internal/ticket" "source.toby3d.me/toby3d/auth/internal/ticket"
"source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web" "source.toby3d.me/toby3d/auth/web"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
) )
type ( type Handler struct {
TicketGenerateRequest struct { config domain.Config
// The access token should be used when acting on behalf of this URL. matcher language.Matcher
Subject *domain.Me `form:"subject"` tickets ticket.UseCase
}
// The access token will work at this URL. func NewHandler(tickets ticket.UseCase, matcher language.Matcher, config domain.Config) *Handler {
Resource *domain.URL `form:"resource"` return &Handler{
}
TicketExchangeRequest struct {
// A random string that can be redeemed for an access token.
Ticket string `form:"ticket"`
// The access token should be used when acting on behalf of this URL.
Subject *domain.Me `form:"subject"`
// The access token will work at this URL.
Resource *domain.URL `form:"resource"`
}
RequestHandler struct {
config *domain.Config
matcher language.Matcher
tickets ticket.UseCase
}
)
func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config *domain.Config) *RequestHandler {
return &RequestHandler{
config: config, config: config,
matcher: matcher, matcher: matcher,
tickets: tickets, tickets: tickets,
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *Handler) Handler() http.Handler {
//nolint:exhaustivestruct //nolint:exhaustivestruct
chain := middleware.Chain{ chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{ middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(w http.ResponseWriter, r *http.Request) bool {
matched, _ := path.Match("/ticket*", string(ctx.Path())) head, _ := urlutil.ShiftPath(r.URL.Path)
return ctx.IsPost() && matched return r.Method == http.MethodPost && head == "ticket"
}, },
CookieMaxAge: 0, CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteStrictMode, CookieSameSite: http.SameSiteStrictMode,
ContextKey: "csrf", ContextKey: "csrf",
CookieDomain: h.config.Server.Domain, CookieDomain: h.config.Server.Domain,
CookieName: "__Secure-csrf", CookieName: "__Secure-csrf",
@ -89,45 +65,69 @@ func (h *RequestHandler) Register(r *router.Router) {
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm), SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
Skipper: middleware.DefaultSkipper, Skipper: middleware.DefaultSkipper,
SuccessHandler: nil, SuccessHandler: nil,
TokenLookup: "header:" + http.HeaderAuthorization + TokenLookup: "header:" + common.HeaderAuthorization +
"," + "cookie:" + "__Secure-auth-token", ",cookie:__Secure-auth-token",
}), }),
middleware.LogFmt(),
} }
r.GET("/ticket", chain.RequestHandler(h.handleRender)) return chain.Handler(h.handleFunc)
r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
} }
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) switch r.Method {
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
case "", http.MethodGet:
if head != "" {
http.NotFound(w, r)
return
}
h.handleRender(w, r)
case http.MethodPost:
switch head {
default:
http.NotFound(w, r)
case "":
h.handleRedeem(w, r)
case "send":
h.handleSend(w, r)
}
}
}
func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) {
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...) tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{ baseOf := web.BaseOf{
Config: h.config, Config: &h.config,
Language: tag, Language: tag,
Printer: message.NewPrinter(tag), Printer: message.NewPrinter(tag),
} }
csrf, _ := ctx.UserValue("csrf").([]byte) csrf, _ := r.Context().Value("csrf").([]byte)
web.WriteTemplate(ctx, &web.TicketPage{ web.WriteTemplate(w, &web.TicketPage{
BaseOf: baseOf, BaseOf: baseOf,
CSRF: csrf, CSRF: csrf,
}) })
} }
func (h *RequestHandler) handleSend(ctx *http.RequestCtx) { func (h *Handler) handleSend(w http.ResponseWriter, r *http.Request) {
ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain) w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(w)
req := new(TicketGenerateRequest) req := new(TicketGenerateRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
@ -137,51 +137,50 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
ticket := &domain.Ticket{ ticket := &domain.Ticket{
Ticket: "", Ticket: "",
Resource: req.Resource.URL, Resource: req.Resource.URL,
Subject: req.Subject, Subject: &req.Subject,
} }
var err error var err error
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil { if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return return
} }
if err = h.tickets.Generate(ctx, ticket); err != nil { if err = h.tickets.Generate(r.Context(), *ticket); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return return
} }
ctx.SetStatusCode(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) { func (h *Handler) handleRedeem(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
ctx.SetStatusCode(http.StatusOK)
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(w)
req := new(TicketExchangeRequest) req := new(TicketExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
token, err := h.tickets.Redeem(ctx, &domain.Ticket{ token, err := h.tickets.Redeem(r.Context(), domain.Ticket{
Ticket: req.Ticket, Ticket: req.Ticket,
Resource: req.Resource.URL, Resource: req.Resource.URL,
Subject: req.Subject, Subject: &req.Subject,
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), "")) _ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
@ -190,84 +189,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
// TODO(toby3d): print the result as part of the debugging. Instead, we // TODO(toby3d): print the result as part of the debugging. Instead, we
// need to send or save the token to the recipient for later use. // need to send or save the token to the recipient for later use.
ctx.SetBodyString(fmt.Sprintf(`{ fmt.Fprintf(w, `{
"access_token": "%s", "access_token": "%s",
"token_type": "Bearer", "token_type": "Bearer",
"scope": "%s", "scope": "%s",
"me": "%s" "me": "%s"
}`, token.AccessToken, token.Scope.String(), token.Me.String())) }`, token.AccessToken, token.Scope.String(), token.Me.String())
} w.WriteHeader(http.StatusOK)
func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if req.Resource == nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"resource value MUST be set",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if req.Subject == nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"subject value MUST be set",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
return nil
}
func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.Request.PostArgs().QueryString(), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if req.Ticket == "" {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"ticket parameter is required",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if req.Resource == nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"resource parameter is required",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if req.Subject == nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"subject parameter is required",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
return nil
} }

View File

@ -0,0 +1,90 @@
package http
import (
"errors"
"net/http"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type (
TicketGenerateRequest struct {
// The access token should be used when acting on behalf of this URL.
Subject domain.Me `form:"subject"`
// The access token will work at this URL.
Resource domain.URL `form:"resource"`
}
TicketExchangeRequest struct {
// The access token should be used when acting on behalf of this URL.
Subject domain.Me `form:"subject"`
// The access token will work at this URL.
Resource domain.URL `form:"resource"`
// A random string that can be redeemed for an access token.
Ticket string `form:"ticket"`
}
)
func (r *TicketGenerateRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
return nil
}
func (r *TicketExchangeRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
if r.Ticket == "" {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"ticket parameter is required",
"https://indieweb.org/IndieAuth_Ticket_Auth#Create_the_IndieAuth_ticket",
)
}
return nil
}

View File

@ -1,17 +1,20 @@
package http_test package http_test
/* TODO(toby3d): move CSRF middleware into main
import ( import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync" "sync"
"testing" "testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/ticket" "source.toby3d.me/toby3d/auth/internal/ticket"
delivery "source.toby3d.me/toby3d/auth/internal/ticket/delivery/http" delivery "source.toby3d.me/toby3d/auth/internal/ticket/delivery/http"
ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory" ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
@ -19,6 +22,7 @@ import (
) )
type Dependencies struct { type Dependencies struct {
server *httptest.Server
client *http.Client client *http.Client
config *domain.Config config *domain.Config
matcher language.Matcher matcher language.Matcher
@ -33,40 +37,35 @@ func TestUpdate(t *testing.T) {
t.Parallel() t.Parallel()
deps := NewDependencies(t) deps := NewDependencies(t)
t.Cleanup(deps.server.Close)
r := router.New() req := httptest.NewRequest(http.MethodPost, "https://example.com/", strings.NewReader(
delivery.NewRequestHandler(deps.ticketService, deps.matcher, deps.config).Register(r)
client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
const requestURI string = "https://example.com/ticket"
req := httptest.NewRequest(http.MethodPost, requestURI, []byte(
`ticket=`+deps.ticket.Ticket+ `ticket=`+deps.ticket.Ticket+
`&resource=`+deps.ticket.Resource.String()+ `&resource=`+deps.ticket.Resource.String()+
`&subject=`+deps.ticket.Subject.String(), `&subject=`+deps.ticket.Subject.String(),
)) ))
defer http.ReleaseRequest(req) req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
req.Header.SetContentType(common.MIMEApplicationForm) deps.token.SetAuthHeader(req)
w := httptest.NewRecorder()
delivery.NewHandler(deps.ticketService, deps.matcher, *deps.config).
Handler().
ServeHTTP(w, req)
domain.TestToken(t).SetAuthHeader(req) domain.TestToken(t).SetAuthHeader(req)
resp := http.AcquireResponse() resp := w.Result()
defer http.ReleaseResponse(resp)
if err := client.Do(req, resp); err != nil { if resp.StatusCode != http.StatusOK &&
t.Fatal(err) resp.StatusCode != http.StatusAccepted {
} t.Errorf("%s %s = %d, want %d or %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK,
if resp.StatusCode() != http.StatusOK && resp.StatusCode() != http.StatusAccepted {
t.Errorf("POST %s = %d, want %d or %d", requestURI, resp.StatusCode(), http.StatusOK,
http.StatusAccepted) http.StatusAccepted)
} }
// TODO(toby3d): print the result as part of the debugging. Instead, you // TODO(toby3d): print the result as part of the debugging. Instead, you
// need to send or save the token to the recipient for later use. // need to send or save the token to the recipient for later use.
if resp.Body() == nil { if resp.Body == nil {
t.Errorf("POST %s = nil, want something", requestURI) t.Errorf("%s %s = nil, want not nil", req.Method, req.RequestURI)
} }
} }
@ -79,29 +78,36 @@ func NewDependencies(tb testing.TB) Dependencies {
ticket := domain.TestTicket(tb) ticket := domain.TestTicket(tb)
token := domain.TestToken(tb) token := domain.TestToken(tb)
r := router.New() mux := http.NewServeMux()
// NOTE(toby3d): private resource // NOTE(toby3d): private resource
r.GET(ticket.Resource.Path, func(ctx *http.RequestCtx) { mux.HandleFunc(ticket.Resource.Path, func(w http.ResponseWriter, r *http.Request) {
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
`<link rel="token_endpoint" href="https://auth.example.org/token">`) fmt.Fprintf(w, `<link rel="token_endpoint" href="https://auth.example.org/token">`)
}) })
// NOTE(toby3d): token endpoint // NOTE(toby3d): token endpoint
r.POST("/token", func(ctx *http.RequestCtx) { mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{ if r.Method != http.MethodPost {
"access_token": "`+token.AccessToken+`", http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
"me": "`+token.Me.String()+`",
"scope": "`+token.Scope.String()+`", return
"token_type": "Bearer" }
}`)
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
fmt.Fprintf(w, `{
"access_token": "`+token.AccessToken+`",
"me": "`+token.Me.String()+`",
"scope": "`+token.Scope.String()+`",
"token_type": "Bearer"
}`)
}) })
client, _, cleanup := httptest.New(tb, r.Handler) server := httptest.NewServer(mux)
tb.Cleanup(cleanup) client := server.Client()
tickets := ticketrepo.NewMemoryTicketRepository(store, config) tickets := ticketrepo.NewMemoryTicketRepository(store, config)
ticketService := ucase.NewTicketUseCase(tickets, client, config) ticketService := ucase.NewTicketUseCase(tickets, client, config)
return Dependencies{ return Dependencies{
server: server,
client: client, client: client,
config: config, config: config,
matcher: matcher, matcher: matcher,
@ -112,3 +118,4 @@ func NewDependencies(tb testing.TB) Dependencies {
token: token, token: token,
} }
} }
*/

View File

@ -7,7 +7,7 @@ import (
) )
type Repository interface { type Repository interface {
Create(ctx context.Context, ticket *domain.Ticket) error Create(ctx context.Context, ticket domain.Ticket) error
GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error) GetAndDelete(ctx context.Context, ticket string) (*domain.Ticket, error)
GC() GC()
} }

View File

@ -2,8 +2,6 @@ package memory
import ( import (
"context" "context"
"fmt"
"path"
"sync" "sync"
"time" "time"
@ -14,77 +12,75 @@ import (
type ( type (
Ticket struct { Ticket struct {
CreatedAt time.Time CreatedAt time.Time
*domain.Ticket domain.Ticket
} }
memoryTicketRepository struct { memoryTicketRepository struct {
config *domain.Config config domain.Config
store *sync.Map mutex *sync.RWMutex
tickets map[string]Ticket
} }
) )
const DefaultPathPrefix string = "tickets" func NewMemoryTicketRepository(config domain.Config) ticket.Repository {
func NewMemoryTicketRepository(store *sync.Map, config *domain.Config) ticket.Repository {
return &memoryTicketRepository{ return &memoryTicketRepository{
config: config, config: config,
store: store, mutex: new(sync.RWMutex),
tickets: make(map[string]Ticket),
} }
} }
func (repo *memoryTicketRepository) Create(_ context.Context, t *domain.Ticket) error { func (repo *memoryTicketRepository) Create(_ context.Context, t domain.Ticket) error {
repo.store.Store(path.Join(DefaultPathPrefix, t.Ticket), &Ticket{ repo.mutex.Lock()
defer repo.mutex.Unlock()
repo.tickets[t.Ticket] = Ticket{
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Ticket: t, Ticket: t,
}) }
return nil return nil
} }
func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) { func (repo *memoryTicketRepository) GetAndDelete(_ context.Context, t string) (*domain.Ticket, error) {
src, ok := repo.store.LoadAndDelete(path.Join(DefaultPathPrefix, t)) repo.mutex.RLock()
out, ok := repo.tickets[t]
if !ok { if !ok {
return nil, fmt.Errorf("cannot find ticket in store: %w", ticket.ErrNotExist) repo.mutex.RUnlock()
return nil, ticket.ErrNotExist
} }
result, ok := src.(*Ticket) repo.mutex.RUnlock()
if !ok { repo.mutex.Lock()
return nil, fmt.Errorf("cannot decode ticket in store: %w", ticket.ErrNotExist) delete(repo.tickets, t)
} repo.mutex.Unlock()
return result.Ticket, nil return &out.Ticket, nil
} }
func (repo *memoryTicketRepository) GC() { func (repo *memoryTicketRepository) GC() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
for timeStamp := range ticker.C { for ts := range ticker.C {
timeStamp := timeStamp.UTC() ts := ts.UTC()
repo.store.Range(func(key, value interface{}) bool { repo.mutex.RLock()
k, ok := key.(string)
if !ok { for _, t := range repo.tickets {
return false if t.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
continue
} }
matched, err := path.Match(DefaultPathPrefix+"/*", k) repo.mutex.RUnlock()
if err != nil || !matched { repo.mutex.Lock()
return false delete(repo.tickets, t.Ticket.Ticket)
} repo.mutex.Unlock()
repo.mutex.RLock()
}
val, ok := value.(*Ticket) repo.mutex.RUnlock()
if !ok {
return false
}
if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) {
return false
}
repo.store.Delete(key)
return false
})
} }
} }

View File

@ -1,63 +0,0 @@
package memory_test
import (
"context"
"path"
"reflect"
"sync"
"testing"
"time"
"source.toby3d.me/toby3d/auth/internal/domain"
repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
)
func TestCreate(t *testing.T) {
t.Parallel()
store := new(sync.Map)
ticket := domain.TestTicket(t)
if err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)).
Create(context.Background(), ticket); err != nil {
t.Fatal(err)
}
storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket)
src, ok := store.Load(storePath)
if !ok {
t.Fatalf("Load(%s) = %t, want %t", storePath, ok, true)
}
if result, _ := src.(*repository.Ticket); !reflect.DeepEqual(result.Ticket, ticket) {
t.Errorf("Create(%+v) = %+v, want %+v", ticket, result.Ticket, ticket)
}
}
func TestGetAndDelete(t *testing.T) {
t.Parallel()
ticket := domain.TestTicket(t)
store := new(sync.Map)
store.Store(path.Join(repository.DefaultPathPrefix, ticket.Ticket), &repository.Ticket{
CreatedAt: time.Now().UTC(),
Ticket: ticket,
})
result, err := repository.NewMemoryTicketRepository(store, domain.TestConfig(t)).
GetAndDelete(context.Background(), ticket.Ticket)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(result, ticket) {
t.Errorf("GetAndDelete(%s) = %+v, want %+v", ticket.Ticket, result, ticket)
}
storePath := path.Join(repository.DefaultPathPrefix, ticket.Ticket)
if src, _ := store.Load(storePath); src != nil {
t.Errorf("Load(%s) = %+v, want %+v", storePath, src, nil)
}
}

View File

@ -56,8 +56,8 @@ func NewSQLite3TicketRepository(db *sqlx.DB, config *domain.Config) ticket.Repos
} }
} }
func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticket) error { func (repo *sqlite3TicketRepository) Create(ctx context.Context, t domain.Ticket) error {
if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(t)); err != nil { if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewTicket(&t)); err != nil {
return fmt.Errorf("cannot create token record in db: %w", err) return fmt.Errorf("cannot create token record in db: %w", err)
} }

View File

@ -12,7 +12,7 @@ import (
repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/sqlite3" repository "source.toby3d.me/toby3d/auth/internal/ticket/repository/sqlite3"
) )
// nolint: gochecknoglobals // slices cannot be contants //nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "resource", "subject", "ticket"} var tableColumns = []string{"created_at", "resource", "subject", "ticket"}
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
@ -34,7 +34,7 @@ func TestCreate(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3TicketRepository(db, domain.TestConfig(t)). if err := repository.NewSQLite3TicketRepository(db, domain.TestConfig(t)).
Create(context.Background(), ticket); err != nil { Create(context.Background(), *ticket); err != nil {
t.Error(err) t.Error(err)
} }
} }

View File

@ -7,10 +7,10 @@ import (
) )
type UseCase interface { type UseCase interface {
Generate(ctx context.Context, ticket *domain.Ticket) error Generate(ctx context.Context, ticket domain.Ticket) error
// Redeem transform received ticket into access token. // Redeem transform received ticket into access token.
Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) Redeem(ctx context.Context, ticket domain.Ticket) (*domain.Token, error)
Exchange(ctx context.Context, ticket string) (*domain.Token, error) Exchange(ctx context.Context, ticket string) (*domain.Token, error)
} }

View File

@ -1,13 +1,15 @@
package usecase package usecase
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"net/url" "net/url"
"time" "time"
json "github.com/goccy/go-json" json "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -47,26 +49,28 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do
} }
} }
func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error { func (useCase *ticketUseCase) Generate(ctx context.Context, tkt domain.Ticket) error {
req := http.AcquireRequest() resp, err := useCase.client.Get(tkt.Subject.String())
defer http.ReleaseRequest(req) if err != nil {
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(tkt.Subject.String())
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := useCase.client.Do(req, resp); err != nil {
return fmt.Errorf("cannot discovery ticket subject: %w", err) return fmt.Errorf("cannot discovery ticket subject: %w", err)
} }
var ticketEndpoint *url.URL body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("cannot read response body: %w", err)
}
buf := bytes.NewReader(body)
ticketEndpoint := new(url.URL)
// NOTE(toby3d): find metadata first // NOTE(toby3d): find metadata first
if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Subject.String())
if err == nil && metadata != nil {
ticketEndpoint = metadata.TicketEndpoint ticketEndpoint = metadata.TicketEndpoint
} else { // NOTE(toby3d): fallback to old links searching } else { // NOTE(toby3d): fallback to old links searching
if endpoints := httputil.ExtractEndpoints(resp, "ticket_endpoint"); len(endpoints) > 0 { endpoints := httputil.ExtractEndpoints(buf, tkt.Subject.URL(), resp.Header.Get(common.HeaderLink),
"ticket_endpoint")
if len(endpoints) > 0 {
ticketEndpoint = endpoints[len(endpoints)-1] ticketEndpoint = endpoints[len(endpoints)-1]
} }
} }
@ -79,65 +83,59 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket)
return fmt.Errorf("cannot save ticket in store: %w", err) return fmt.Errorf("cannot save ticket in store: %w", err)
} }
req.Reset() payload := make(url.Values)
req.Header.SetMethod(http.MethodPost) payload.Set("ticket", tkt.Ticket)
req.SetRequestURI(ticketEndpoint.String()) payload.Set("subject", tkt.Subject.String())
req.Header.SetContentType(common.MIMEApplicationForm) payload.Set("resource", tkt.Resource.String())
req.PostArgs().Set("ticket", tkt.Ticket)
req.PostArgs().Set("subject", tkt.Subject.String())
req.PostArgs().Set("resource", tkt.Resource.String())
resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { if _, err = useCase.client.PostForm(ticketEndpoint.String(), payload); err != nil {
return fmt.Errorf("cannot send ticket to subject ticket_endpoint: %w", err) return fmt.Errorf("cannot send ticket to subject ticket_endpoint: %w", err)
} }
return nil return nil
} }
func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) { func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt domain.Ticket) (*domain.Token, error) {
req := http.AcquireRequest() resp, err := useCase.client.Get(tkt.Resource.String())
defer http.ReleaseRequest(req) if err != nil {
req.SetRequestURI(tkt.Resource.String())
req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := useCase.client.Do(req, resp); err != nil {
return nil, fmt.Errorf("cannot discovery ticket resource: %w", err) return nil, fmt.Errorf("cannot discovery ticket resource: %w", err)
} }
var tokenEndpoint *url.URL body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("cannot read response body: %w", err)
}
buf := bytes.NewReader(body)
tokenEndpoint := new(url.URL)
// NOTE(toby3d): find metadata first // NOTE(toby3d): find metadata first
if metadata, err := httputil.ExtractMetadata(resp, useCase.client); err == nil && metadata != nil { metadata, err := httputil.ExtractFromMetadata(useCase.client, tkt.Resource.String())
if err == nil && metadata != nil {
tokenEndpoint = metadata.TokenEndpoint tokenEndpoint = metadata.TokenEndpoint
} else { // NOTE(toby3d): fallback to old links searching } else { // NOTE(toby3d): fallback to old links searching
if endpoints := httputil.ExtractEndpoints(resp, "token_endpoint"); len(endpoints) > 0 { endpoints := httputil.ExtractEndpoints(buf, tkt.Resource, resp.Header.Get(common.HeaderLink),
"token_endpoint")
if len(endpoints) > 0 {
tokenEndpoint = endpoints[len(endpoints)-1] tokenEndpoint = endpoints[len(endpoints)-1]
} }
} }
if tokenEndpoint == nil { if tokenEndpoint == nil || tokenEndpoint.String() == "" {
return nil, ticket.ErrTokenEndpointNotExist return nil, ticket.ErrTokenEndpointNotExist
} }
req.Reset() payload := make(url.Values)
req.Header.SetMethod(http.MethodPost) payload.Set("grant_type", domain.GrantTypeTicket.String())
req.SetRequestURI(tokenEndpoint.String()) payload.Set("ticket", tkt.Ticket)
req.Header.SetContentType(common.MIMEApplicationForm)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
req.PostArgs().Set("ticket", tkt.Ticket)
resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { resp, err = useCase.client.PostForm(tokenEndpoint.String(), payload)
if err != nil {
return nil, fmt.Errorf("cannot exchange ticket on token_endpoint: %w", err) return nil, fmt.Errorf("cannot exchange ticket on token_endpoint: %w", err)
} }
data := new(AccessToken) data := new(AccessToken)
if err := json.Unmarshal(resp.Body(), data); err != nil { if err := json.NewDecoder(resp.Body).Decode(data); err != nil {
return nil, fmt.Errorf("cannot unmarshal access token response: %w", err) return nil, fmt.Errorf("cannot unmarshal access token response: %w", err)
} }
@ -147,8 +145,8 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*
Scope: nil, // TODO(toby3d) Scope: nil, // TODO(toby3d)
// TODO(toby3d): should this also include client_id? // TODO(toby3d): should this also include client_id?
// https://github.com/indieweb/indieauth/issues/85 // https://github.com/indieweb/indieauth/issues/85
ClientID: nil, ClientID: domain.ClientID{},
Me: data.Me, Me: *data.Me,
AccessToken: data.AccessToken, AccessToken: data.AccessToken,
RefreshToken: "", // TODO(toby3d) RefreshToken: "", // TODO(toby3d)
}, nil }, nil
@ -163,8 +161,8 @@ func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*dom
token, err := domain.NewToken(domain.NewTokenOptions{ token, err := domain.NewToken(domain.NewTokenOptions{
Expiration: useCase.config.JWT.Expiry, Expiration: useCase.config.JWT.Expiry,
Scope: domain.Scopes{domain.ScopeRead}, Scope: domain.Scopes{domain.ScopeRead},
Issuer: nil, Issuer: domain.ClientID{},
Subject: tkt.Subject, Subject: *tkt.Subject,
Secret: []byte(useCase.config.JWT.Secret), Secret: []byte(useCase.config.JWT.Secret),
Algorithm: useCase.config.JWT.Algorithm, Algorithm: useCase.config.JWT.Algorithm,
NonceLength: useCase.config.JWT.NonceLength, NonceLength: useCase.config.JWT.NonceLength,

View File

@ -3,14 +3,13 @@ package usecase_test
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing" "testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
ucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase" ucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase"
) )
@ -20,25 +19,33 @@ func TestRedeem(t *testing.T) {
token := domain.TestToken(t) token := domain.TestToken(t)
ticket := domain.TestTicket(t) ticket := domain.TestTicket(t)
router := router.New() tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
router.GET(string(ticket.Resource.Path), func(ctx *http.RequestCtx) { if r.Method != http.MethodPost {
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, `<link rel="token_endpoint" href="`+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
ticket.Subject.String()+`token">`)
}) return
router.POST("/token", func(ctx *http.RequestCtx) { }
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
fmt.Fprintf(w, `{
"token_type": "Bearer", "token_type": "Bearer",
"access_token": "%s", "access_token": "%s",
"scope": "%s", "scope": "%s",
"me": "%s" "me": "%s"
}`, token.AccessToken, token.Scope.String(), token.Me.String())) }`, token.AccessToken, token.Scope.String(), token.Me.String())
}) }))
t.Cleanup(tokenServer.Close)
client, _, cleanup := httptest.New(t, router.Handler) subjectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Cleanup(cleanup) w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
fmt.Fprint(w, `<link rel="token_endpoint" href="`+tokenServer.URL+`/token">`)
}))
t.Cleanup(subjectServer.Close)
result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)). ticket.Resource, _ = url.Parse(subjectServer.URL + "/")
Redeem(context.Background(), ticket)
result, err := ucase.NewTicketUseCase(nil, subjectServer.Client(), domain.TestConfig(t)).
Redeem(context.Background(), *ticket)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,187 +1,100 @@
package http package http
import ( import (
"errors" "net/http"
"path"
"github.com/fasthttp/router" "github.com/goccy/go-json"
json "github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwa"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/ticket" "source.toby3d.me/toby3d/auth/internal/ticket"
"source.toby3d.me/toby3d/auth/internal/token" "source.toby3d.me/toby3d/auth/internal/token"
"source.toby3d.me/toby3d/form" "source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/middleware"
) )
type ( type Handler struct {
TokenExchangeRequest struct { config *domain.Config
ClientID *domain.ClientID `form:"client_id"` tokens token.UseCase
RedirectURI *domain.URL `form:"redirect_uri"` tickets ticket.UseCase
GrantType domain.GrantType `form:"grant_type"` }
Code string `form:"code"`
CodeVerifier string `form:"code_verifier"`
}
TokenRefreshRequest struct { func NewHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *Handler {
GrantType domain.GrantType `form:"grant_type"` // refresh_token return &Handler{
// The refresh token previously offered to the client.
RefreshToken string `form:"refresh_token"`
// The client ID that was used when the refresh token was issued.
ClientID *domain.ClientID `form:"client_id"`
// The client may request a token with the same or fewer scopes
// than the original access token. If omitted, is treated as
// equal to the original scopes granted.
Scope domain.Scopes `form:"scope"`
}
TokenRevocationRequest struct {
Action domain.Action `form:"action,omitempty"`
Token string `form:"token"`
}
TokenTicketRequest struct {
Action domain.Action `form:"action"`
Ticket string `form:"ticket"`
}
TokenIntrospectRequest struct {
Token string `form:"token"`
}
//nolint:tagliatelle // https://indieauth.net/source/#access-token-response
TokenExchangeResponse struct {
// The OAuth 2.0 Bearer Token RFC6750.
AccessToken string `json:"access_token"`
// The canonical user profile URL for the user this access token
// corresponds to.
Me string `json:"me"`
// The user's profile information.
Profile *TokenProfileResponse `json:"profile,omitempty"`
// The lifetime in seconds of the access token.
ExpiresIn int64 `json:"expires_in,omitempty"`
// The refresh token, which can be used to obtain new access
// tokens.
RefreshToken string `json:"refresh_token"`
}
TokenProfileResponse struct {
// Name the user wishes to provide to the client.
Name string `json:"name,omitempty"`
// URL of the user's website.
URL string `json:"url,omitempty"`
// A photo or image that the user wishes clients to use as a
// profile image.
Photo string `json:"photo,omitempty"`
// The email address a user wishes to provide to the client.
Email string `json:"email,omitempty"`
}
//nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response
TokenIntrospectResponse struct {
// Boolean indicator of whether or not the presented token is
// currently active.
Active bool `json:"active"`
// The profile URL of the user corresponding to this token.
Me string `json:"me"`
// The client ID associated with this token.
ClientID string `json:"client_id"`
// A space-separated list of scopes associated with this token.
Scope string `json:"scope"`
// Integer timestamp, measured in the number of seconds since
// January 1 1970 UTC, indicating when this token will expire.
Exp int64 `json:"exp,omitempty"`
// Integer timestamp, measured in the number of seconds since
// January 1 1970 UTC, indicating when this token was originally
// issued.
Iat int64 `json:"iat,omitempty"`
}
TokenInvalidIntrospectResponse struct {
Active bool `json:"active"`
}
TokenRevocationResponse struct{}
RequestHandler struct {
config *domain.Config
tokens token.UseCase
tickets ticket.UseCase
}
)
func NewRequestHandler(tokens token.UseCase, tickets ticket.UseCase, config *domain.Config) *RequestHandler {
return &RequestHandler{
config: config, config: config,
tokens: tokens, tokens: tokens,
tickets: tickets, tickets: tickets,
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{ chain := middleware.Chain{
//nolint:exhaustivestruct //nolint:exhaustivestruct
middleware.JWTWithConfig(middleware.JWTConfig{ middleware.JWTWithConfig(middleware.JWTConfig{
AuthScheme: "Bearer", Skipper: func(_ http.ResponseWriter, r *http.Request) bool {
ContextKey: "token", head, _ := urlutil.ShiftPath(r.URL.Path)
return head == "token"
},
SigningKey: []byte(h.config.JWT.Secret), SigningKey: []byte(h.config.JWT.Secret),
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm), SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
Skipper: func(ctx *http.RequestCtx) bool { ContextKey: "token",
matched, _ := path.Match("/token*", string(ctx.Path())) TokenLookup: "form:token," + "header:" + common.HeaderAuthorization + ":Bearer ",
AuthScheme: "Bearer",
return matched
},
SuccessHandler: nil,
TokenLookup: "param:token,header:" + http.HeaderAuthorization + ":Bearer ",
}), }),
middleware.LogFmt(),
} }
r.POST("/token", chain.RequestHandler(h.handleAction)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.POST("/introspect", chain.RequestHandler(h.handleIntrospect)) if r.Method != http.MethodPost {
r.POST("/revocation", chain.RequestHandler(h.handleRevokation)) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch head {
default:
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
case "token":
chain.Handler(h.handleAction).ServeHTTP(w, r)
case "introspect":
chain.Handler(h.handleIntrospect).ServeHTTP(w, r)
case "revocation":
chain.Handler(h.handleRevokation).ServeHTTP(w, r)
}
})
} }
func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) { func (h *Handler) handleIntrospect(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
ctx.SetStatusCode(http.StatusOK) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx)
req := new(TokenIntrospectRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
_ = encoder.Encode(err)
return return
} }
tkn, _, err := h.tokens.Verify(ctx, req.Token) w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := new(TokenIntrospectRequest)
if err := req.bind(r); err != nil {
_ = encoder.Encode(err)
w.WriteHeader(http.StatusBadRequest)
return
}
tkn, _, err := h.tokens.Verify(r.Context(), req.Token)
if err != nil || tkn == nil { if err != nil || tkn == nil {
// WARN(toby3d): If the token is not valid, the endpoint still // WARN(toby3d): If the token is not valid, the endpoint still
// MUST return a 200 Response. // MUST return a 200 Response.
_ = encoder.Encode(&TokenInvalidIntrospectResponse{ _ = encoder.Encode(&TokenInvalidIntrospectResponse{Active: false})
Active: false,
}) w.WriteHeader(http.StatusOK)
return return
} }
@ -194,68 +107,83 @@ func (h *RequestHandler) handleIntrospect(ctx *http.RequestCtx) {
Me: tkn.Me.String(), Me: tkn.Me.String(),
Scope: tkn.Scope.String(), Scope: tkn.Scope.String(),
}) })
w.WriteHeader(http.StatusOK)
} }
func (h *RequestHandler) handleAction(ctx *http.RequestCtx) { func (h *Handler) handleAction(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
switch { switch {
case ctx.PostArgs().Has("grant_type"): case r.PostForm.Has("grant_type"):
h.handleExchange(ctx) h.handleExchange(w, r)
case ctx.PostArgs().Has("action"): case r.PostForm.Has("action"):
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action"))) if err := r.ParseForm(); err != nil {
if err != nil { w.WriteHeader(http.StatusBadRequest)
ctx.SetStatusCode(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError( _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
domain.ErrorCodeInvalidRequest,
err.Error(), return
"", }
))
action, err := domain.ParseAction(r.PostForm.Get("action"))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
return return
} }
switch action { switch action {
case domain.ActionRevoke: case domain.ActionRevoke:
h.handleRevokation(ctx) h.handleRevokation(w, r)
case domain.ActionTicket: case domain.ActionTicket:
h.handleTicket(ctx) h.handleTicket(w, r)
} }
} }
} }
//nolint:funlen //nolint:funlen
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := new(TokenExchangeRequest) req := new(TokenExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
token, profile, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ token, profile, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{
ClientID: req.ClientID, ClientID: req.ClientID,
RedirectURI: req.RedirectURI.URL, RedirectURI: req.RedirectURI.URL,
Code: req.Code, Code: req.Code,
CodeVerifier: req.CodeVerifier, CodeVerifier: req.CodeVerifier,
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError( _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(),
domain.ErrorCodeInvalidRequest, "https://indieauth.net/source/#request"))
err.Error(),
"https://indieauth.net/source/#request",
))
return return
} }
@ -294,62 +222,69 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
} }
_ = encoder.Encode(resp) _ = encoder.Encode(resp)
w.WriteHeader(http.StatusOK)
} }
func (h *RequestHandler) handleRevokation(ctx *http.RequestCtx) { func (h *Handler) handleRevokation(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
ctx.SetStatusCode(http.StatusOK) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := NewTokenRevocationRequest() req := NewTokenRevocationRequest()
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
if err := h.tokens.Revoke(ctx, req.Token); err != nil { if err := h.tokens.Revoke(r.Context(), req.Token); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(domain.NewError( _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), ""))
domain.ErrorCodeInvalidRequest,
err.Error(),
"",
))
return return
} }
_ = encoder.Encode(&TokenRevocationResponse{}) _ = encoder.Encode(&TokenRevocationResponse{})
w.WriteHeader(http.StatusOK)
} }
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) { func (h *Handler) handleTicket(w http.ResponseWriter, r *http.Request) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) if r.Method != http.MethodPost {
ctx.SetStatusCode(http.StatusOK) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx) return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := new(TokenTicketRequest) req := new(TokenTicketRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(r); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err) _ = encoder.Encode(err)
return return
} }
tkn, err := h.tickets.Exchange(ctx, req.Ticket) tkn, err := h.tickets.Exchange(r.Context(), req.Ticket)
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(domain.NewError( _ = encoder.Encode(domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(),
domain.ErrorCodeInvalidRequest, "https://indieauth.net/source/#request"))
err.Error(),
"https://indieauth.net/source/#request",
))
return return
} }
@ -361,81 +296,6 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
ExpiresIn: tkn.Expiry.Unix(), ExpiresIn: tkn.Expiry.Unix(),
RefreshToken: "", // TODO(toby3d) RefreshToken: "", // TODO(toby3d)
}) })
}
w.WriteHeader(http.StatusOK)
func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func NewTokenRevocationRequest() *TokenRevocationRequest {
return &TokenRevocationRequest{
Action: domain.ActionRevoke,
Token: "",
}
}
func (r *TokenRevocationRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
err := form.Unmarshal(ctx.PostArgs().QueryString(), r)
if err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func (r *TokenIntrospectRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#access-token-verification-request",
)
}
return nil
} }

View File

@ -0,0 +1,210 @@
package http
import (
"errors"
"net/http"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type (
TokenExchangeRequest struct {
ClientID domain.ClientID `form:"client_id"`
RedirectURI domain.URL `form:"redirect_uri"`
GrantType domain.GrantType `form:"grant_type"`
Code string `form:"code"`
CodeVerifier string `form:"code_verifier"`
}
TokenRefreshRequest struct {
GrantType domain.GrantType `form:"grant_type"` // refresh_token
// The client ID that was used when the refresh token was issued.
ClientID domain.ClientID `form:"client_id"`
// The client may request a token with the same or fewer scopes
// than the original access token. If omitted, is treated as
// equal to the original scopes granted.
Scope domain.Scopes `form:"scope"`
// The refresh token previously offered to the client.
RefreshToken string `form:"refresh_token"`
}
TokenRevocationRequest struct {
Action domain.Action `form:"action,omitempty"`
Token string `form:"token"`
}
TokenTicketRequest struct {
Action domain.Action `form:"action"`
Ticket string `form:"ticket"`
}
TokenIntrospectRequest struct {
Token string `form:"token"`
}
//nolint:tagliatelle // https://indieauth.net/source/#access-token-response
TokenExchangeResponse struct {
// The user's profile information.
Profile *TokenProfileResponse `json:"profile,omitempty"`
// The OAuth 2.0 Bearer Token RFC6750.
AccessToken string `json:"access_token"`
// The canonical user profile URL for the user this access token
// corresponds to.
Me string `json:"me"`
// The refresh token, which can be used to obtain new access
// tokens.
RefreshToken string `json:"refresh_token"`
// The lifetime in seconds of the access token.
ExpiresIn int64 `json:"expires_in,omitempty"`
}
TokenProfileResponse struct {
// Name the user wishes to provide to the client.
Name string `json:"name,omitempty"`
// URL of the user's website.
URL string `json:"url,omitempty"`
// A photo or image that the user wishes clients to use as a
// profile image.
Photo string `json:"photo,omitempty"`
// The email address a user wishes to provide to the client.
Email string `json:"email,omitempty"`
}
//nolint:tagliatelle // https://indieauth.net/source/#access-token-verification-response
TokenIntrospectResponse struct {
// The profile URL of the user corresponding to this token.
Me string `json:"me"`
// The client ID associated with this token.
ClientID string `json:"client_id"`
// A space-separated list of scopes associated with this token.
Scope string `json:"scope"`
// Integer timestamp, measured in the number of seconds since
// January 1 1970 UTC, indicating when this token will expire.
Exp int64 `json:"exp,omitempty"`
// Integer timestamp, measured in the number of seconds since
// January 1 1970 UTC, indicating when this token was originally
// issued.
Iat int64 `json:"iat,omitempty"`
// Boolean indicator of whether or not the presented token is
// currently active.
Active bool `json:"active"`
}
TokenInvalidIntrospectResponse struct {
Active bool `json:"active"`
}
TokenRevocationResponse struct{}
)
func (r *TokenExchangeRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func NewTokenRevocationRequest() *TokenRevocationRequest {
return &TokenRevocationRequest{
Action: domain.ActionRevoke,
Token: "",
}
}
func (r *TokenRevocationRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
err := form.Unmarshal([]byte(req.PostForm.Encode()), r)
if err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func (r *TokenTicketRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
)
}
return nil
}
func (r *TokenIntrospectRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#access-token-verification-request",
)
}
return nil
}

View File

@ -3,12 +3,13 @@ package http_test
import ( import (
"bytes" "bytes"
"context" "context"
"sync" "io"
"net/http"
"net/http/httptest"
"strings"
"testing" "testing"
"github.com/fasthttp/router" "github.com/goccy/go-json"
json "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -16,7 +17,6 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory" profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session" "source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory" sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/ticket" "source.toby3d.me/toby3d/auth/internal/ticket"
ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory" ticketrepo "source.toby3d.me/toby3d/auth/internal/ticket/repository/memory"
ticketucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase" ticketucase "source.toby3d.me/toby3d/auth/internal/ticket/usecase"
@ -31,7 +31,6 @@ type Dependencies struct {
config *domain.Config config *domain.Config
profiles profile.Repository profiles profile.Repository
sessions session.Repository sessions session.Repository
store *sync.Map
tickets ticket.Repository tickets ticket.Repository
ticketService ticket.UseCase ticketService ticket.UseCase
token *domain.Token token *domain.Token
@ -50,32 +49,24 @@ func TestIntrospection(t *testing.T) {
deps := NewDependencies(t) deps := NewDependencies(t)
r := router.New() req := httptest.NewRequest(http.MethodPost, "https://app.example.com/introspect",
delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r) strings.NewReader("token="+deps.token.AccessToken))
req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON)
req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
client, _, cleanup := httptest.New(t, r.Handler) w := httptest.NewRecorder()
t.Cleanup(cleanup) delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config).
Handler().
ServeHTTP(w, req)
const requestURL = "https://app.example.com/introspect" resp := w.Result()
req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken)) if result := resp.StatusCode; result != http.StatusOK {
defer http.ReleaseRequest(req) t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, result, http.StatusOK)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.Header.SetContentType(common.MIMEApplicationForm)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := client.Do(req, resp); err != nil {
t.Fatal(err)
}
if result := resp.StatusCode(); result != http.StatusOK {
t.Errorf("GET %s = %d, want %d", requestURL, result, http.StatusOK)
} }
result := new(delivery.TokenIntrospectResponse) result := new(delivery.TokenIntrospectResponse)
if err := json.Unmarshal(resp.Body(), result); err != nil { if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -84,7 +75,7 @@ func TestIntrospection(t *testing.T) {
if result.ClientID != deps.token.ClientID.String() || if result.ClientID != deps.token.ClientID.String() ||
result.Me != deps.token.Me.String() || result.Me != deps.token.Me.String() ||
result.Scope != deps.token.Scope.String() { result.Scope != deps.token.Scope.String() {
t.Errorf("GET %s = %+v, want %+v", requestURL, result, deps.token) t.Errorf("%s %s = %+v, want %+v", req.Method, req.RequestURI, result, deps.token)
} }
} }
@ -93,33 +84,30 @@ func TestRevocation(t *testing.T) {
deps := NewDependencies(t) deps := NewDependencies(t)
r := router.New() req := httptest.NewRequest(http.MethodPost, "https://app.example.com/revocation",
delivery.NewRequestHandler(deps.tokenService, deps.ticketService, deps.config).Register(r) strings.NewReader(`token=`+deps.token.AccessToken))
req.Header.Set(common.HeaderContentType, common.MIMEApplicationForm)
req.Header.Set(common.HeaderAccept, common.MIMEApplicationJSON)
client, _, cleanup := httptest.New(t, r.Handler) w := httptest.NewRecorder()
t.Cleanup(cleanup) delivery.NewHandler(deps.tokenService, deps.ticketService, deps.config).
Handler().
ServeHTTP(w, req)
const requestURL = "https://app.example.com/revocation" resp := w.Result()
req := httptest.NewRequest(http.MethodPost, requestURL, []byte("token="+deps.token.AccessToken)) body, err := io.ReadAll(resp.Body)
defer http.ReleaseRequest(req) if err != nil {
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.Header.SetContentType(common.MIMEApplicationForm)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := client.Do(req, resp); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if result := resp.StatusCode(); result != http.StatusOK { if resp.StatusCode != http.StatusOK {
t.Errorf("POST %s = %d, want %d", requestURL, result, http.StatusOK) t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
} }
expBody := []byte("{}") //nolint:ifshort expBody := []byte("{}") //nolint:ifshort
if result := bytes.TrimSpace(resp.Body()); !bytes.Equal(result, expBody) { if result := bytes.TrimSpace(body); !bytes.Equal(result, expBody) {
t.Errorf("POST %s = %s, want %s", requestURL, result, expBody) t.Errorf("%s %s = %s, want %s", req.Method, req.RequestURI, result, expBody)
} }
result, err := deps.tokens.Get(context.Background(), deps.token.AccessToken) result, err := deps.tokens.Get(context.Background(), deps.token.AccessToken)
@ -135,14 +123,13 @@ func TestRevocation(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies { func NewDependencies(tb testing.TB) Dependencies {
tb.Helper() tb.Helper()
store := new(sync.Map)
client := new(http.Client) client := new(http.Client)
config := domain.TestConfig(tb) config := domain.TestConfig(tb)
token := domain.TestToken(tb) token := domain.TestToken(tb)
profiles := profilerepo.NewMemoryProfileRepository(store) profiles := profilerepo.NewMemoryProfileRepository()
sessions := sessionrepo.NewMemorySessionRepository(store, config) sessions := sessionrepo.NewMemorySessionRepository(*config)
tickets := ticketrepo.NewMemoryTicketRepository(store, config) tickets := ticketrepo.NewMemoryTicketRepository(*config)
tokens := tokenrepo.NewMemoryTokenRepository(store) tokens := tokenrepo.NewMemoryTokenRepository()
ticketService := ticketucase.NewTicketUseCase(tickets, client, config) ticketService := ticketucase.NewTicketUseCase(tickets, client, config)
tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{ tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config, Config: config,
@ -156,7 +143,6 @@ func NewDependencies(tb testing.TB) Dependencies {
config: config, config: config,
profiles: profiles, profiles: profiles,
sessions: sessions, sessions: sessions,
store: store,
tickets: tickets, tickets: tickets,
ticketService: ticketService, ticketService: ticketService,
token: token, token: token,

View File

@ -7,8 +7,8 @@ import (
) )
type Repository interface { type Repository interface {
Create(ctx context.Context, accessToken domain.Token) error
Get(ctx context.Context, accessToken string) (*domain.Token, error) Get(ctx context.Context, accessToken string) (*domain.Token, error)
Create(ctx context.Context, accessToken *domain.Token) error
} }
var ( var (

View File

@ -2,8 +2,6 @@ package memory
import ( import (
"context" "context"
"errors"
"path"
"sync" "sync"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -11,42 +9,33 @@ import (
) )
type memoryTokenRepository struct { type memoryTokenRepository struct {
store *sync.Map mutex *sync.RWMutex
tokens map[string]domain.Token
} }
const DefaultPathPrefix string = "tokens" func NewMemoryTokenRepository() token.Repository {
func NewMemoryTokenRepository(store *sync.Map) token.Repository {
return &memoryTokenRepository{ return &memoryTokenRepository{
store: store, mutex: new(sync.RWMutex),
tokens: make(map[string]domain.Token),
} }
} }
func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { func (repo *memoryTokenRepository) Create(ctx context.Context, accessToken domain.Token) error {
t, err := repo.Get(ctx, accessToken.AccessToken) repo.mutex.Lock()
if err != nil && !errors.Is(err, token.ErrNotExist) { defer repo.mutex.Unlock()
return err
}
if t != nil { repo.tokens[accessToken.AccessToken] = accessToken
return token.ErrExist
}
repo.store.Store(path.Join(DefaultPathPrefix, accessToken.AccessToken), accessToken)
return nil return nil
} }
func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { func (repo *memoryTokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
t, ok := repo.store.Load(path.Join(DefaultPathPrefix, accessToken)) repo.mutex.RLock()
if !ok { defer repo.mutex.RUnlock()
return nil, token.ErrNotExist
if t, ok := repo.tokens[accessToken]; ok {
return &t, nil
} }
result, ok := t.(*domain.Token) return nil, token.ErrNotExist
if !ok {
return nil, token.ErrNotExist
}
return result, nil
} }

View File

@ -1,45 +0,0 @@
package memory_test
import (
"context"
"path"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/token"
repository "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
)
func TestCreate(t *testing.T) {
t.Parallel()
store := new(sync.Map)
accessToken := domain.TestToken(t)
repo := repository.NewMemoryTokenRepository(store)
if err := repo.Create(context.Background(), accessToken); err != nil {
t.Fatal(err)
}
result, ok := store.Load(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken))
assert.True(t, ok)
assert.Equal(t, accessToken, result)
assert.ErrorIs(t, repo.Create(context.Background(), accessToken), token.ErrExist)
}
func TestGet(t *testing.T) {
t.Parallel()
store := new(sync.Map)
accessToken := domain.TestToken(t)
store.Store(path.Join(repository.DefaultPathPrefix, accessToken.AccessToken), accessToken)
result, err := repository.NewMemoryTokenRepository(store).Get(context.Background(), accessToken.AccessToken)
assert.NoError(t, err)
assert.Equal(t, accessToken, result)
}

View File

@ -53,8 +53,8 @@ func NewSQLite3TokenRepository(db *sqlx.DB) token.Repository {
} }
} }
func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *domain.Token) error { func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken domain.Token) error {
if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(accessToken)); err != nil { if _, err := repo.db.NamedExecContext(ctx, QueryCreate, NewToken(&accessToken)); err != nil {
return fmt.Errorf("cannot create token record in db: %w", err) return fmt.Errorf("cannot create token record in db: %w", err)
} }
@ -91,9 +91,11 @@ func NewToken(src *domain.Token) *Token {
} }
func (t *Token) Populate(dst *domain.Token) { func (t *Token) Populate(dst *domain.Token) {
cid, _ := domain.ParseClientID(t.ClientID)
me, _ := domain.ParseMe(t.Me)
dst.AccessToken = t.AccessToken dst.AccessToken = t.AccessToken
dst.ClientID, _ = domain.ParseClientID(t.ClientID) dst.ClientID = *cid
dst.Me, _ = domain.ParseMe(t.Me) dst.Me = *me
dst.Scope = make(domain.Scopes, 0) dst.Scope = make(domain.Scopes, 0)
for _, scope := range strings.Fields(t.Scope) { for _, scope := range strings.Fields(t.Scope) {

View File

@ -35,7 +35,7 @@ func TestCreate(t *testing.T) {
). ).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), token); err != nil { if err := repository.NewSQLite3TokenRepository(db).Create(context.Background(), *token); err != nil {
t.Error(err) t.Error(err)
} }
} }

View File

@ -9,7 +9,7 @@ import (
type ( type (
ExchangeOptions struct { ExchangeOptions struct {
ClientID *domain.ClientID ClientID domain.ClientID
RedirectURI *url.URL RedirectURI *url.URL
Code string Code string
CodeVerifier string CodeVerifier string

View File

@ -107,17 +107,17 @@ func (uc *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain
return nil, nil, fmt.Errorf("cannot validate JWT token: %w", err) return nil, nil, fmt.Errorf("cannot validate JWT token: %w", err)
} }
cid, _ := domain.ParseClientID(tkn.Issuer())
me, _ := domain.ParseMe(tkn.Subject())
result := &domain.Token{ result := &domain.Token{
CreatedAt: tkn.IssuedAt(), CreatedAt: tkn.IssuedAt(),
Expiry: tkn.Expiration(), Expiry: tkn.Expiration(),
ClientID: nil, ClientID: *cid,
Me: nil, Me: *me,
Scope: nil, Scope: nil,
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: "", // TODO(toby3d) RefreshToken: "", // TODO(toby3d)
} }
result.ClientID, _ = domain.ParseClientID(tkn.Issuer())
result.Me, _ = domain.ParseMe(tkn.Subject())
if scope, ok := tkn.Get("scope"); ok { if scope, ok := tkn.Get("scope"); ok {
result.Scope, _ = scope.(domain.Scopes) result.Scope, _ = scope.(domain.Scopes)
@ -149,7 +149,7 @@ func (uc *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
return fmt.Errorf("cannot verify token: %w", err) return fmt.Errorf("cannot verify token: %w", err)
} }
if err = uc.tokens.Create(ctx, tkn); err != nil && !errors.Is(err, token.ErrExist) { if err = uc.tokens.Create(ctx, *tkn); err != nil && !errors.Is(err, token.ErrExist) {
return fmt.Errorf("cannot save token in database: %w", err) return fmt.Errorf("cannot save token in database: %w", err)
} }

View File

@ -2,8 +2,6 @@ package usecase_test
import ( import (
"context" "context"
"path"
"sync"
"testing" "testing"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -22,7 +20,6 @@ type Dependencies struct {
profiles profile.Repository profiles profile.Repository
session *domain.Session session *domain.Session
sessions session.Repository sessions session.Repository
store *sync.Map
token *domain.Token token *domain.Token
tokens token.Repository tokens token.Repository
} }
@ -31,9 +28,12 @@ func TestExchange(t *testing.T) {
t.Parallel() t.Parallel()
deps := NewDependencies(t) deps := NewDependencies(t)
deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.session.Me.String()), deps.profile)
if err := deps.sessions.Create(context.Background(), deps.session); err != nil { if err := deps.profiles.Create(context.Background(), deps.session.Me, *deps.profile); err != nil {
t.Fatal(err)
}
if err := deps.sessions.Create(context.Background(), *deps.session); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -95,7 +95,7 @@ func TestVerify(t *testing.T) {
t.Parallel() t.Parallel()
testToken := domain.TestToken(t) testToken := domain.TestToken(t)
if err := deps.tokens.Create(context.Background(), testToken); err != nil { if err := deps.tokens.Create(context.Background(), *testToken); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -136,17 +136,15 @@ func TestRevoke(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies { func NewDependencies(tb testing.TB) Dependencies {
tb.Helper() tb.Helper()
store := new(sync.Map)
config := domain.TestConfig(tb) config := domain.TestConfig(tb)
return Dependencies{ return Dependencies{
config: config, config: config,
profile: domain.TestProfile(tb), profile: domain.TestProfile(tb),
profiles: profilerepo.NewMemoryProfileRepository(store), profiles: profilerepo.NewMemoryProfileRepository(),
session: domain.TestSession(tb), session: domain.TestSession(tb),
sessions: sessionrepo.NewMemorySessionRepository(store, config), sessions: sessionrepo.NewMemorySessionRepository(*config),
store: store,
token: domain.TestToken(tb), token: domain.TestToken(tb),
tokens: tokenrepo.NewMemoryTokenRepository(store), tokens: tokenrepo.NewMemoryTokenRepository(),
} }
} }

View File

@ -1,10 +1,10 @@
package http package http
import ( import (
"encoding/json"
"net/http" "net/http"
"strings" "strings"
"github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwa"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
@ -13,19 +13,10 @@ import (
"source.toby3d.me/toby3d/auth/internal/token" "source.toby3d.me/toby3d/auth/internal/token"
) )
type ( type Handler struct {
UserInformationResponse struct { config *domain.Config
Name string `json:"name,omitempty"` tokens token.UseCase
URL string `json:"url,omitempty"` }
Photo string `json:"photo,omitempty"`
Email string `json:"email,omitempty"`
}
Handler struct {
config *domain.Config
tokens token.UseCase
}
)
func NewHandler(tokens token.UseCase, config *domain.Config) *Handler { func NewHandler(tokens token.UseCase, config *domain.Config) *Handler {
return &Handler{ return &Handler{
@ -34,7 +25,7 @@ func NewHandler(tokens token.UseCase, config *domain.Config) *Handler {
} }
} }
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{ chain := middleware.Chain{
//nolint:exhaustivestruct //nolint:exhaustivestruct
middleware.JWTWithConfig(middleware.JWTConfig{ middleware.JWTWithConfig(middleware.JWTConfig{
@ -45,13 +36,18 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Skipper: middleware.DefaultSkipper, Skipper: middleware.DefaultSkipper,
TokenLookup: "header:" + common.HeaderAuthorization + ":Bearer ", TokenLookup: "header:" + common.HeaderAuthorization + ":Bearer ",
}), }),
middleware.LogFmt(),
} }
chain.Handler(h.handleFunc).ServeHTTP(w, r) return chain.Handler(h.handleFunc)
} }
func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleFunc(w http.ResponseWriter, r *http.Request) {
if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8) w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w) encoder := json.NewEncoder(w)

View File

@ -0,0 +1,8 @@
package http
type UserInformationResponse struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
Photo string `json:"photo,omitempty"`
Email string `json:"email,omitempty"`
}

View File

@ -1,13 +1,12 @@
package http_test package http_test
import ( import (
"context"
"net/http"
"net/http/httptest" "net/http/httptest"
"path"
"sync"
"testing" "testing"
"github.com/goccy/go-json" "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -26,7 +25,6 @@ type Dependencies struct {
profile *domain.Profile profile *domain.Profile
profiles profile.Repository profiles profile.Repository
sessions session.Repository sessions session.Repository
store *sync.Map
token *domain.Token token *domain.Token
tokens token.Repository tokens token.Repository
tokenService token.UseCase tokenService token.UseCase
@ -36,13 +34,17 @@ func TestUserInfo(t *testing.T) {
t.Parallel() t.Parallel()
deps := NewDependencies(t) deps := NewDependencies(t)
deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, deps.token.Me.String()), deps.profile) if err := deps.profiles.Create(context.Background(), deps.token.Me, *deps.profile); err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodGet, "https://example.com/userinfo", nil) req := httptest.NewRequest(http.MethodGet, "https://example.com/userinfo", nil)
req.Header.Set(common.HeaderAuthorization, "Bearer "+deps.token.AccessToken) req.Header.Set(common.HeaderAuthorization, "Bearer "+deps.token.AccessToken)
w := httptest.NewRecorder() w := httptest.NewRecorder()
delivery.NewHandler(deps.tokenService, deps.config).ServeHTTP(w, req) delivery.NewHandler(deps.tokenService, deps.config).
Handler().
ServeHTTP(w, req)
resp := w.Result() resp := w.Result()
@ -69,22 +71,23 @@ func TestUserInfo(t *testing.T) {
func NewDependencies(tb testing.TB) Dependencies { func NewDependencies(tb testing.TB) Dependencies {
tb.Helper() tb.Helper()
store := new(sync.Map)
config := domain.TestConfig(tb) config := domain.TestConfig(tb)
sessions := sessionrepo.NewMemorySessionRepository(*config)
tokens := tokenrepo.NewMemoryTokenRepository()
profiles := profilerepo.NewMemoryProfileRepository()
return Dependencies{ return Dependencies{
config: config, config: config,
profile: domain.TestProfile(tb), profile: domain.TestProfile(tb),
profiles: profilerepo.NewMemoryProfileRepository(store), profiles: profiles,
sessions: sessionrepo.NewMemorySessionRepository(store, config), sessions: sessions,
store: store,
token: domain.TestToken(tb), token: domain.TestToken(tb),
tokens: tokenrepo.NewMemoryTokenRepository(store), tokens: tokens,
tokenService: tokenucase.NewTokenUseCase(tokenucase.Config{ tokenService: tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config, Config: config,
Profiles: profilerepo.NewMemoryProfileRepository(store), Profiles: profiles,
Sessions: sessionrepo.NewMemorySessionRepository(store, config), Sessions: sessions,
Tokens: tokenrepo.NewMemoryTokenRepository(store), Tokens: tokens,
}), }),
} }
} }

View File

@ -7,7 +7,8 @@ import (
) )
type Repository interface { type Repository interface {
Get(ctx context.Context, me *domain.Me) (*domain.User, error) Create(ctx context.Context, user domain.User) error
Get(ctx context.Context, me domain.Me) (*domain.User, error)
} }
var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "user not exist", "") var ErrNotExist error = domain.NewError(domain.ErrorCodeServerError, "user not exist", "")

View File

@ -1,12 +1,16 @@
package http package http
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"net/url" "net/url"
http "github.com/valyala/fasthttp" "golang.org/x/exp/slices"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil" "source.toby3d.me/toby3d/auth/internal/httputil"
"source.toby3d.me/toby3d/auth/internal/user" "source.toby3d.me/toby3d/auth/internal/user"
@ -38,26 +42,21 @@ func NewHTTPUserRepository(client *http.Client) user.Repository {
} }
} }
func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) { // WARN(toby3d): not implemented.
req := http.AcquireRequest() func (httpUserRepository) Create(_ context.Context, _ domain.User) error {
defer http.ReleaseRequest(req) return nil
req.Header.SetMethod(http.MethodGet) }
req.SetRequestURI(me.String())
resp := http.AcquireResponse() func (repo *httpUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) {
defer http.ReleaseResponse(resp) resp, err := repo.client.Get(me.String())
if err != nil {
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("cannot fetch user by me: %w", err) return nil, fmt.Errorf("cannot fetch user by me: %w", err)
} }
// TODO(toby3d): handle error here?
resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation)))
user := &domain.User{ user := &domain.User{
AuthorizationEndpoint: nil, AuthorizationEndpoint: nil,
IndieAuthMetadata: nil, IndieAuthMetadata: nil,
Me: resolvedMe, Me: &me,
Micropub: nil, Micropub: nil,
Microsub: nil, Microsub: nil,
Profile: domain.NewProfile(), Profile: domain.NewProfile(),
@ -65,7 +64,7 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
TokenEndpoint: nil, TokenEndpoint: nil,
} }
if metadata, err := httputil.ExtractMetadata(resp, repo.client); err == nil { if metadata, err := httputil.ExtractFromMetadata(repo.client, me.String()); err == nil {
user.AuthorizationEndpoint = metadata.AuthorizationEndpoint user.AuthorizationEndpoint = metadata.AuthorizationEndpoint
user.Micropub = metadata.MicropubEndpoint user.Micropub = metadata.MicropubEndpoint
user.Microsub = metadata.MicrosubEndpoint user.Microsub = metadata.MicrosubEndpoint
@ -73,89 +72,87 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
user.TokenEndpoint = metadata.TokenEndpoint user.TokenEndpoint = metadata.TokenEndpoint
} }
extractUser(user, resp) body, err := io.ReadAll(resp.Body)
extractProfile(user.Profile, resp) if err != nil {
return nil, fmt.Errorf("cannot read response body: %w", err)
}
extractUser(me.URL(), user, body, resp.Header.Get(common.HeaderLink))
extractProfile(me.URL(), user.Profile, body)
return user, nil return user, nil
} }
//nolint:cyclop //nolint:cyclop
func extractUser(dst *domain.User, src *http.Response) { func extractUser(u *url.URL, dst *domain.User, body []byte, header string) {
if dst.IndieAuthMetadata != nil { for key, target := range map[string]**url.URL{
if endpoints := httputil.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 { relAuthorizationEndpoint: &dst.AuthorizationEndpoint,
dst.IndieAuthMetadata = endpoints[len(endpoints)-1] relIndieAuthMetadata: &dst.IndieAuthMetadata,
relMicropub: &dst.Micropub,
relMicrosub: &dst.Microsub,
relTicketEndpoint: &dst.TicketEndpoint,
relTokenEndpoint: &dst.TokenEndpoint,
} {
if target == nil {
continue
} }
}
if dst.AuthorizationEndpoint == nil { if endpoints := httputil.ExtractEndpoints(bytes.NewReader(body), u, header, key); len(endpoints) > 0 {
if endpoints := httputil.ExtractEndpoints(src, relAuthorizationEndpoint); len(endpoints) > 0 { *target = endpoints[len(endpoints)-1]
dst.AuthorizationEndpoint = endpoints[len(endpoints)-1]
}
}
if dst.Micropub == nil {
if endpoints := httputil.ExtractEndpoints(src, relMicropub); len(endpoints) > 0 {
dst.Micropub = endpoints[len(endpoints)-1]
}
}
if dst.Microsub == nil {
if endpoints := httputil.ExtractEndpoints(src, relMicrosub); len(endpoints) > 0 {
dst.Microsub = endpoints[len(endpoints)-1]
}
}
if dst.TicketEndpoint == nil {
if endpoints := httputil.ExtractEndpoints(src, relTicketEndpoint); len(endpoints) > 0 {
dst.TicketEndpoint = endpoints[len(endpoints)-1]
}
}
if dst.TokenEndpoint == nil {
if endpoints := httputil.ExtractEndpoints(src, relTokenEndpoint); len(endpoints) > 0 {
dst.TokenEndpoint = endpoints[len(endpoints)-1]
} }
} }
} }
//nolint:cyclop //nolint:cyclop
func extractProfile(dst *domain.Profile, src *http.Response) { func extractProfile(u *url.URL, dst *domain.Profile, body []byte) {
for _, name := range httputil.ExtractProperty(src, hCard, propertyName) { for _, name := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyName) {
if n, ok := name.(string); ok { if n, ok := name.(string); ok && !slices.Contains(dst.Name, n) {
dst.Name = append(dst.Name, n) dst.Name = append(dst.Name, n)
} }
} }
for _, rawEmail := range httputil.ExtractProperty(src, hCard, propertyEmail) { for _, rawEmail := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyEmail) {
email, ok := rawEmail.(string) email, ok := rawEmail.(string)
if !ok { if !ok {
continue continue
} }
if e, err := domain.ParseEmail(email); err == nil { if e, err := domain.ParseEmail(email); err == nil && !slices.Contains(dst.Email, e) {
dst.Email = append(dst.Email, e) dst.Email = append(dst.Email, e)
} }
} }
for _, rawURL := range httputil.ExtractProperty(src, hCard, propertyURL) { for _, rawURL := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyURL) {
rawURL, ok := rawURL.(string) rawURL, ok := rawURL.(string)
if !ok { if !ok {
continue continue
} }
if u, err := url.Parse(rawURL); err == nil { if u, err := url.Parse(rawURL); err == nil && !containsUrl(dst.URL, u) {
dst.URL = append(dst.URL, u) dst.URL = append(dst.URL, u)
} }
} }
for _, rawPhoto := range httputil.ExtractProperty(src, hCard, propertyPhoto) { for _, rawPhoto := range httputil.ExtractProperty(bytes.NewReader(body), u, hCard, propertyPhoto) {
photo, ok := rawPhoto.(string) photo, ok := rawPhoto.(string)
if !ok { if !ok {
continue continue
} }
if p, err := url.Parse(photo); err == nil { if p, err := url.Parse(photo); err == nil && !containsUrl(dst.Photo, p) {
dst.Photo = append(dst.Photo, p) dst.Photo = append(dst.Photo, p)
} }
} }
} }
func containsUrl(src []*url.URL, find *url.URL) bool {
for i := range src {
if src[i].String() != find.String() {
continue
}
return true
}
return false
}

View File

@ -3,16 +3,15 @@ package http_test
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/fasthttp/router" "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
http "github.com/valyala/fasthttp"
"source.toby3d.me/toby3d/auth/internal/common" "source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
repository "source.toby3d.me/toby3d/auth/internal/user/repository/http" repository "source.toby3d.me/toby3d/auth/internal/user/repository/http"
) )
@ -40,39 +39,29 @@ func TestGet(t *testing.T) {
t.Parallel() t.Parallel()
user := domain.TestUser(t) user := domain.TestUser(t)
client, _, cleanup := httptest.New(t, testHandler(t, user))
t.Cleanup(cleanup)
result, err := repository.NewHTTPUserRepository(client).Get(context.Background(), user.Me) srv := httptest.NewServer(testHandler(t, user))
t.Cleanup(srv.Close)
user.Me = domain.TestMe(t, srv.URL+"/")
result, err := repository.NewHTTPUserRepository(srv.Client()).
Get(context.Background(), *user.Me)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// NOTE(toby3d): endpoints if diff := cmp.Diff(user, result, cmp.AllowUnexported(domain.Me{}, domain.Email{})); diff != "" {
assert.Equal(t, user.AuthorizationEndpoint.String(), result.AuthorizationEndpoint.String()) t.Errorf("%+s", diff)
assert.Equal(t, user.TokenEndpoint.String(), result.TokenEndpoint.String())
assert.Equal(t, user.Micropub.String(), result.Micropub.String())
assert.Equal(t, user.Microsub.String(), result.Microsub.String())
// NOTE(toby3d): profile
assert.Equal(t, user.Profile.Name, result.Profile.Name)
assert.Equal(t, user.Profile.Email, result.Profile.Email)
for i := range user.Profile.URL {
assert.Equal(t, user.Profile.URL[i].String(), result.Profile.URL[i].String())
}
for i := range user.Profile.Photo {
assert.Equal(t, user.Profile.Photo[i].String(), result.Profile.Photo[i].String())
} }
} }
func testHandler(tb testing.TB, user *domain.User) http.RequestHandler { func testHandler(tb testing.TB, user *domain.User) http.Handler {
tb.Helper() tb.Helper()
router := router.New() mux := http.NewServeMux()
router.GET("/", func(ctx *http.RequestCtx) { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{ w.Header().Set(common.HeaderLink, strings.Join([]string{
`<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`, `<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`,
`<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`, `<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`,
`<` + user.Micropub.String() + `>; rel="micropub"`, `<` + user.Micropub.String() + `>; rel="micropub"`,
@ -80,17 +69,17 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
`<` + user.TicketEndpoint.String() + `>; rel="ticket_endpoint"`, `<` + user.TicketEndpoint.String() + `>; rel="ticket_endpoint"`,
`<` + user.TokenEndpoint.String() + `>; rel="token_endpoint"`, `<` + user.TokenEndpoint.String() + `>; rel="token_endpoint"`,
}, ", ")) }, ", "))
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, fmt.Sprintf( w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0], fmt.Fprintf(w, testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0])
))
}) })
router.GET(user.IndieAuthMetadata.Path, func(ctx *http.RequestCtx) { mux.HandleFunc(user.IndieAuthMetadata.Path, func(w http.ResponseWriter, r *http.Request) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{ w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
fmt.Fprint(w, `{
"issuer": "`+user.Me.String()+`", "issuer": "`+user.Me.String()+`",
"authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`", "authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`",
"token_endpoint": "`+user.TokenEndpoint.String()+`" "token_endpoint": "`+user.TokenEndpoint.String()+`"
}`) }`)
}) })
return router.Handler return mux
} }

View File

@ -2,7 +2,6 @@ package memory
import ( import (
"context" "context"
"path"
"sync" "sync"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -10,27 +9,33 @@ import (
) )
type memoryUserRepository struct { type memoryUserRepository struct {
store *sync.Map mutex *sync.RWMutex
users map[string]domain.User
} }
const DefaultPathPrefix string = "users" func NewMemoryUserRepository() user.Repository {
func NewMemoryUserRepository(store *sync.Map) user.Repository {
return &memoryUserRepository{ return &memoryUserRepository{
store: store, mutex: new(sync.RWMutex),
users: make(map[string]domain.User),
} }
} }
func (repo *memoryUserRepository) Get(ctx context.Context, me *domain.Me) (*domain.User, error) { func (repo *memoryUserRepository) Create(ctx context.Context, user domain.User) error {
p, ok := repo.store.Load(path.Join(DefaultPathPrefix, me.String())) repo.mutex.Lock()
if !ok { defer repo.mutex.Unlock()
return nil, user.ErrNotExist
}
result, ok := p.(*domain.User) repo.users[user.Me.String()] = user
if !ok {
return nil, user.ErrNotExist
}
return result, nil return nil
}
func (repo *memoryUserRepository) Get(ctx context.Context, me domain.Me) (*domain.User, error) {
repo.mutex.RLock()
defer repo.mutex.RUnlock()
if u, ok := repo.users[me.String()]; ok {
return &u, nil
}
return nil, user.ErrNotExist
} }

View File

@ -1,30 +0,0 @@
package memory_test
import (
"context"
"path"
"reflect"
"sync"
"testing"
"source.toby3d.me/toby3d/auth/internal/domain"
repository "source.toby3d.me/toby3d/auth/internal/user/repository/memory"
)
func TestGet(t *testing.T) {
t.Parallel()
user := domain.TestUser(t)
store := new(sync.Map)
store.Store(path.Join(repository.DefaultPathPrefix, user.Me.String()), user)
result, err := repository.NewMemoryUserRepository(store).Get(context.Background(), user.Me)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(result, user) {
t.Errorf("Get(%s) = %+v, want %+v", user.Me, result, user)
}
}

View File

@ -8,5 +8,5 @@ import (
type UseCase interface { type UseCase interface {
// Fetch discovery all available endpoints and Profile info on Me URL. // Fetch discovery all available endpoints and Profile info on Me URL.
Fetch(ctx context.Context, me *domain.Me) (*domain.User, error) Fetch(ctx context.Context, me domain.Me) (*domain.User, error)
} }

View File

@ -18,7 +18,7 @@ func NewUserUseCase(repo user.Repository) user.UseCase {
} }
} }
func (useCase *userUseCase) Fetch(ctx context.Context, me *domain.Me) (*domain.User, error) { func (useCase *userUseCase) Fetch(ctx context.Context, me domain.Me) (*domain.User, error) {
user, err := useCase.repo.Get(ctx, me) user, err := useCase.repo.Get(ctx, me)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot find user by me: %w", err) return nil, fmt.Errorf("cannot find user by me: %w", err)

View File

@ -2,9 +2,7 @@ package usecase_test
import ( import (
"context" "context"
"path"
"reflect" "reflect"
"sync"
"testing" "testing"
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
@ -15,19 +13,20 @@ import (
func TestFetch(t *testing.T) { func TestFetch(t *testing.T) {
t.Parallel() t.Parallel()
me := domain.TestMe(t, "https://user.example.net")
user := domain.TestUser(t) user := domain.TestUser(t)
user.Me = domain.TestMe(t, "https://user.example.net")
users := repository.NewMemoryUserRepository()
store := new(sync.Map) if err := users.Create(context.Background(), *user); err != nil {
store.Store(path.Join(repository.DefaultPathPrefix, me.String()), user) t.Fatal(err)
}
result, err := ucase.NewUserUseCase(repository.NewMemoryUserRepository(store)). result, err := ucase.NewUserUseCase(users).Fetch(context.Background(), *user.Me)
Fetch(context.Background(), me)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if !reflect.DeepEqual(result, user) { if !reflect.DeepEqual(result, user) {
t.Errorf("Fetch(%s) = %+v, want %+v", me, result, user) t.Errorf("Fetch(%s) = %+v, want %+v", user.Me, result, user)
} }
} }

196
main.go
View File

@ -5,27 +5,27 @@
package main package main
import ( import (
"context"
"embed"
_ "embed" _ "embed"
"errors" "errors"
"flag" "flag"
"fmt" "io/fs"
"log" "log"
"net/http"
_ "net/http/pprof"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
"github.com/fasthttp/router"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/spf13/viper" "github.com/spf13/viper"
http "github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/pprofhandler"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
@ -40,6 +40,7 @@ import (
"source.toby3d.me/toby3d/auth/internal/domain" "source.toby3d.me/toby3d/auth/internal/domain"
healthhttpdelivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http" healthhttpdelivery "source.toby3d.me/toby3d/auth/internal/health/delivery/http"
metadatahttpdelivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http" metadatahttpdelivery "source.toby3d.me/toby3d/auth/internal/metadata/delivery/http"
"source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/profile" "source.toby3d.me/toby3d/auth/internal/profile"
profilehttprepo "source.toby3d.me/toby3d/auth/internal/profile/repository/http" profilehttprepo "source.toby3d.me/toby3d/auth/internal/profile/repository/http"
profileucase "source.toby3d.me/toby3d/auth/internal/profile/usecase" profileucase "source.toby3d.me/toby3d/auth/internal/profile/usecase"
@ -57,6 +58,7 @@ import (
tokenmemoryrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory" tokenmemoryrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
tokensqlite3repo "source.toby3d.me/toby3d/auth/internal/token/repository/sqlite3" tokensqlite3repo "source.toby3d.me/toby3d/auth/internal/token/repository/sqlite3"
tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase" tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase"
"source.toby3d.me/toby3d/auth/internal/urlutil"
userhttpdelivery "source.toby3d.me/toby3d/auth/internal/user/delivery/http" userhttpdelivery "source.toby3d.me/toby3d/auth/internal/user/delivery/http"
) )
@ -69,6 +71,7 @@ type (
tickets ticket.UseCase tickets ticket.UseCase
profiles profile.UseCase profiles profile.UseCase
tokens token.UseCase tokens token.UseCase
static fs.FS
} }
NewAppOptions struct { NewAppOptions struct {
@ -78,6 +81,7 @@ type (
Tickets ticket.Repository Tickets ticket.Repository
Tokens token.Repository Tokens token.Repository
Profiles profile.Repository Profiles profile.Repository
Static fs.FS
} }
) )
@ -93,13 +97,16 @@ var (
logger = log.New(os.Stdout, "IndieAuth\t", log.Lmsgprefix|log.LstdFlags|log.LUTC) logger = log.New(os.Stdout, "IndieAuth\t", log.Lmsgprefix|log.LstdFlags|log.LUTC)
config = new(domain.Config) config = new(domain.Config)
indieAuthClient = new(domain.Client) indieAuthClient = new(domain.Client)
configPath string
cpuProfilePath string
memProfilePath string
enablePprof bool
) )
var (
configPath, cpuProfilePath, memProfilePath string
enablePprof bool
)
//go:embed assets/*
var staticFS embed.FS
//nolint:gochecknoinits //nolint:gochecknoinits
func init() { func init() {
flag.StringVar(&configPath, "config", filepath.Join(".", "config.yml"), "load specific config") flag.StringVar(&configPath, "config", filepath.Join(".", "config.yml"), "load specific config")
@ -133,34 +140,44 @@ func init() {
rootURL := config.Server.GetRootURL() rootURL := config.Server.GetRootURL()
indieAuthClient.Name = []string{config.Name} indieAuthClient.Name = []string{config.Name}
if indieAuthClient.ID, err = domain.ParseClientID(rootURL); err != nil { cid, err := domain.ParseClientID(rootURL)
if err != nil {
logger.Fatalln("fail to read config:", err) logger.Fatalln("fail to read config:", err)
} }
url, err := domain.ParseURL(rootURL) indieAuthClient.ID = *cid
u, err := url.Parse(rootURL)
if err != nil { if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err) logger.Fatalln("cannot parse root URL as client URL:", err)
} }
logo, err := domain.ParseURL(rootURL + config.Server.StaticURLPrefix + "/icon.svg") logo, err := url.Parse(rootURL + config.Server.StaticURLPrefix + "/icon.svg")
if err != nil { if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err) logger.Fatalln("cannot parse root URL as client URL:", err)
} }
redirectURI, err := domain.ParseURL(rootURL + "/callback") redirectURI, err := url.Parse(rootURL + "callback")
if err != nil { if err != nil {
logger.Fatalln("cannot parse root URL as client URL:", err) logger.Fatalln("cannot parse root URL as client URL:", err)
} }
indieAuthClient.URL = []*domain.URL{url} indieAuthClient.URL = []*url.URL{u}
indieAuthClient.Logo = []*domain.URL{logo} indieAuthClient.Logo = []*url.URL{logo}
indieAuthClient.RedirectURI = []*domain.URL{redirectURI} indieAuthClient.RedirectURI = []*url.URL{redirectURI}
} }
//nolint:funlen,cyclop // "god object" and the entry point of all modules //nolint:funlen,cyclop // "god object" and the entry point of all modules
func main() { func main() {
ctx := context.Background()
var opts NewAppOptions var opts NewAppOptions
var err error
if opts.Static, err = fs.Sub(staticFS, "assets"); err != nil {
logger.Fatalln(err)
}
switch strings.ToLower(config.Database.Type) { switch strings.ToLower(config.Database.Type) {
case "sqlite3": case "sqlite3":
store, err := sqlx.Open("sqlite", config.Database.Path) store, err := sqlx.Open("sqlite", config.Database.Path)
@ -176,51 +193,27 @@ func main() {
opts.Sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store) opts.Sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store)
opts.Tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config) opts.Tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config)
case "memory": case "memory":
store := new(sync.Map) opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository()
opts.Tokens = tokenmemoryrepo.NewMemoryTokenRepository(store) opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(*config)
opts.Sessions = sessionmemoryrepo.NewMemorySessionRepository(store, config) opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(*config)
opts.Tickets = ticketmemoryrepo.NewMemoryTicketRepository(store, config)
default: default:
log.Fatalln("unsupported database type, use 'memory' or 'sqlite3'") log.Fatalln("unsupported database type, use 'memory' or 'sqlite3'")
} }
go opts.Sessions.GC() go opts.Sessions.GC()
//nolint:exhaustivestruct // too many options opts.Client = new(http.Client)
opts.Client = &http.Client{
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
ReadTimeout: DefaultReadTimeout,
WriteTimeout: DefaultWriteTimeout,
}
opts.Clients = clienthttprepo.NewHTTPClientRepository(opts.Client) opts.Clients = clienthttprepo.NewHTTPClientRepository(opts.Client)
opts.Profiles = profilehttprepo.NewHTPPClientRepository(opts.Client) opts.Profiles = profilehttprepo.NewHTPPClientRepository(opts.Client)
r := router.New() app := NewApp(opts)
NewApp(opts).Register(r)
//nolint:exhaustivestruct // too many options
r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{
Root: config.Server.StaticRootPath,
CacheDuration: DefaultCacheDuration,
AcceptByteRange: true,
Compress: true,
CompressBrotli: true,
GenerateIndexPages: true,
})
if enablePprof {
r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler)
}
//nolint:exhaustivestruct //nolint:exhaustivestruct
server := &http.Server{ server := &http.Server{
Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()), Addr: config.Server.GetAddress(),
Handler: r.Handler, Handler: app.Handler(),
ReadTimeout: DefaultReadTimeout, ReadTimeout: DefaultReadTimeout,
WriteTimeout: DefaultWriteTimeout, WriteTimeout: DefaultWriteTimeout,
DisableKeepalive: true,
ReduceMemoryUsage: true,
SecureErrorLogMessage: true,
CloseOnShutdown: true,
} }
done := make(chan os.Signal, 1) done := make(chan os.Signal, 1)
@ -243,15 +236,15 @@ func main() {
logger.Printf("started at %s, available at %s", config.Server.GetAddress(), logger.Printf("started at %s, available at %s", config.Server.GetAddress(),
config.Server.GetRootURL()) config.Server.GetRootURL())
err := server.ListenAndServe(config.Server.GetAddress()) err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrConnectionClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Fatalln("cannot listen and serve:", err) logger.Fatalln("cannot listen and serve:", err)
} }
}() }()
<-done <-done
if err := server.Shutdown(); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Fatalln("failed shutdown of server:", err) logger.Fatalln("failed shutdown of server:", err)
} }
@ -274,6 +267,7 @@ func main() {
func NewApp(opts NewAppOptions) *App { func NewApp(opts NewAppOptions) *App {
return &App{ return &App{
static: opts.Static,
auth: authucase.NewAuthUseCase(opts.Sessions, opts.Profiles, config), auth: authucase.NewAuthUseCase(opts.Sessions, opts.Profiles, config),
clients: clientucase.NewClientUseCase(opts.Clients), clients: clientucase.NewClientUseCase(opts.Clients),
matcher: language.NewMatcher(message.DefaultCatalog.Languages()), matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
@ -289,20 +283,19 @@ func NewApp(opts NewAppOptions) *App {
} }
} }
func (app *App) Register(r *router.Router) { // TODO(toby3d): move module middlewares to here.
tickethttpdelivery.NewRequestHandler(app.tickets, app.matcher, config).Register(r) func (app *App) Handler() http.Handler {
healthhttpdelivery.NewRequestHandler().Register(r) metadata := metadatahttpdelivery.NewHandler(&domain.Metadata{
metadatahttpdelivery.NewRequestHandler(&domain.Metadata{ Issuer: indieAuthClient.ID.URL(),
Issuer: indieAuthClient.ID, AuthorizationEndpoint: indieAuthClient.ID.URL().JoinPath("authorize"),
AuthorizationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "authorize"), TokenEndpoint: indieAuthClient.ID.URL().JoinPath("token"),
TokenEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "token"), TicketEndpoint: indieAuthClient.ID.URL().JoinPath("ticket"),
TicketEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "ticket"),
MicropubEndpoint: nil, MicropubEndpoint: nil,
MicrosubEndpoint: nil, MicrosubEndpoint: nil,
IntrospectionEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "introspect"), IntrospectionEndpoint: indieAuthClient.ID.URL().JoinPath("introspect"),
RevocationEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "revocation"), RevocationEndpoint: indieAuthClient.ID.URL().JoinPath("revocation"),
UserinfoEndpoint: domain.MustParseURL(indieAuthClient.ID.String() + "userinfo"), UserinfoEndpoint: indieAuthClient.ID.URL().JoinPath("userinfo"),
ServiceDocumentation: domain.MustParseURL("https://indieauth.net/source/"), ServiceDocumentation: &url.URL{Scheme: "https", Host: "indieauth.net", Path: "/source/"},
IntrospectionEndpointAuthMethodsSupported: []string{"Bearer"}, IntrospectionEndpointAuthMethodsSupported: []string{"Bearer"},
RevocationEndpointAuthMethodsSupported: []string{"none"}, RevocationEndpointAuthMethodsSupported: []string{"none"},
ScopesSupported: domain.Scopes{ ScopesSupported: domain.Scopes{
@ -319,8 +312,14 @@ func (app *App) Register(r *router.Router) {
domain.ScopeRead, domain.ScopeRead,
domain.ScopeUpdate, domain.ScopeUpdate,
}, },
ResponseTypesSupported: []domain.ResponseType{domain.ResponseTypeCode, domain.ResponseTypeID}, ResponseTypesSupported: []domain.ResponseType{
GrantTypesSupported: []domain.GrantType{domain.GrantTypeAuthorizationCode, domain.GrantTypeTicket}, domain.ResponseTypeCode,
domain.ResponseTypeID,
},
GrantTypesSupported: []domain.GrantType{
domain.GrantTypeAuthorizationCode,
domain.GrantTypeTicket,
},
CodeChallengeMethodsSupported: []domain.CodeChallengeMethod{ CodeChallengeMethodsSupported: []domain.CodeChallengeMethod{
domain.CodeChallengeMethodMD5, domain.CodeChallengeMethodMD5,
domain.CodeChallengeMethodPLAIN, domain.CodeChallengeMethodPLAIN,
@ -329,20 +328,57 @@ func (app *App) Register(r *router.Router) {
domain.CodeChallengeMethodS512, domain.CodeChallengeMethodS512,
}, },
AuthorizationResponseIssParameterSupported: true, AuthorizationResponseIssParameterSupported: true,
}).Register(r) }).Handler()
tokenhttpdelivery.NewRequestHandler(app.tokens, app.tickets, config).Register(r) health := healthhttpdelivery.NewHandler().Handler()
clienthttpdelivery.NewRequestHandler(clienthttpdelivery.NewRequestHandlerOptions{ auth := authhttpdelivery.NewHandler(authhttpdelivery.NewHandlerOptions{
Client: indieAuthClient,
Config: config,
Matcher: app.matcher,
Tokens: app.tokens,
}).Register(r)
authhttpdelivery.NewRequestHandler(authhttpdelivery.NewRequestHandlerOptions{
Auth: app.auth, Auth: app.auth,
Clients: app.clients, Clients: app.clients,
Config: config, Config: *config,
Matcher: app.matcher, Matcher: app.matcher,
Profiles: app.profiles, Profiles: app.profiles,
}).Register(r) }).Handler()
userhttpdelivery.NewRequestHandler(app.tokens, config).Register(r) token := tokenhttpdelivery.NewHandler(app.tokens, app.tickets, config).Handler()
client := clienthttpdelivery.NewHandler(clienthttpdelivery.NewHandlerOptions{
Client: *indieAuthClient,
Config: *config,
Matcher: app.matcher,
Tokens: app.tokens,
}).Handler()
user := userhttpdelivery.NewHandler(app.tokens, config).Handler()
ticket := tickethttpdelivery.NewHandler(app.tickets, app.matcher, *config).Handler()
static := http.FileServer(http.FS(app.static))
return http.HandlerFunc(middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch head {
default:
r.URL = r.URL.JoinPath(head, r.URL.Path)
static.ServeHTTP(w, r)
case "", "callback":
r.URL = r.URL.JoinPath(head, r.URL.Path)
client.ServeHTTP(w, r)
case "token", "introspect", "revocation":
r.URL = r.URL.JoinPath(head, r.URL.Path)
token.ServeHTTP(w, r)
case ".well-known":
if head, _ = urlutil.ShiftPath(r.URL.Path); head == "oauth-authorization-server" {
metadata.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
case "authorize":
auth.ServeHTTP(w, r)
case "health":
health.ServeHTTP(w, r)
case "userinfo":
user.ServeHTTP(w, r)
case "ticket":
ticket.ServeHTTP(w, r)
}
}).Intercept(middleware.LogFmt()))
} }

View File

@ -17,137 +17,137 @@
} %} } %}
{% func (p *AuthorizePage) Title() %} {% func (p *AuthorizePage) Title() %}
{% if p.Client.GetName() == "" %} {% if p.Client.GetName() == "" %}
{%= p.T("Authorize %s", p.Client.GetName()) %} {%= p.T("Authorize %s", p.Client.GetName()) %}
{% else %} {% else %}
{%= p.T("Authorize application") %} {%= p.T("Authorize application") %}
{% endif %} {% endif %}
{% endfunc %} {% endfunc %}
{% func (p *AuthorizePage) Body() %} {% func (p *AuthorizePage) Body() %}
<header> <header>
{% if p.Client.GetLogo() != nil %} {% if p.Client.GetLogo() != nil %}
<img class="" <img class=""
crossorigin="anonymous" crossorigin="anonymous"
decoding="async" decoding="async"
height="140" height="140"
importance="high" importance="high"
loading="lazy" loading="lazy"
referrerpolicy="no-referrer-when-downgrade" referrerpolicy="no-referrer-when-downgrade"
src="{%s p.Client.GetLogo().String() %}" src="{%s p.Client.GetLogo().String() %}"
alt="{%s p.Client.GetName() %}" alt="{%s p.Client.GetName() %}"
width="140"> width="140">
{% endif %} {% endif %}
<h2> <h2>
{% if p.Client.GetURL() != nil %} {% if p.Client.GetURL() != nil %}
<a href="{%s p.Client.GetURL().String() %}"> <a href="{%s p.Client.GetURL().String() %}">
{% endif %} {% endif %}
{% if p.Client.GetName() != "" %} {% if p.Client.GetName() != "" %}
{%s p.Client.GetName() %} {%s p.Client.GetName() %}
{% else %} {% else %}
{%s p.Client.ID.String() %} {%s p.Client.ID.String() %}
{% endif %} {% endif %}
{% if p.Client.GetURL() != nil %} {% if p.Client.GetURL() != nil %}
</a> </a>
{% endif %} {% endif %}
</h2> </h2>
</header> </header>
<main> <main>
<form class="" <form class=""
accept-charset="utf-8" accept-charset="utf-8"
action="/api/authorize" action="/authorize/verify"
autocomplete="off" autocomplete="off"
enctype="application/x-www-form-urlencoded" enctype="application/x-www-form-urlencoded"
method="post" method="post"
novalidate="true" novalidate="true"
target="_self"> target="_self">
{% if p.CSRF != nil %} {% if p.CSRF != nil %}
<input type="hidden" <input type="hidden"
name="_csrf" name="_csrf"
value="{%z p.CSRF %}"> value="{%z p.CSRF %}">
{% endif %} {% endif %}
{% for key, val := range map[string]string{ {% for key, val := range map[string]string{
"client_id": p.Client.ID.String(), "client_id": p.Client.ID.String(),
"redirect_uri": p.RedirectURI.String(), "redirect_uri": p.RedirectURI.String(),
"response_type": p.ResponseType.String(), "response_type": p.ResponseType.String(),
"state": p.State, "state": p.State,
} %} } %}
<input type="hidden" <input type="hidden"
name="{%s key %}" name="{%s key %}"
value="{%s val %}"> value="{%s val %}">
{% endfor %}
{% if len(p.Scope) > 0 %}
<fieldset>
<legend>{%= p.T("Choose your scopes") %}</legend>
{% for _, scope := range p.Scope %}
<div>
<label>
<input type="checkbox"
name="scope[]"
value="{%s scope.String() %}"
checked>
{%s scope.String() %}
</label>
</div>
{% endfor %} {% endfor %}
</fieldset>
{% endif %}
{% if len(p.Scope) > 0 %} {% if p.CodeChallenge != "" %}
<fieldset> <input type="hidden"
<legend>{%= p.T("Choose your scopes") %}</legend> name="code_challenge"
value="{%s p.CodeChallenge %}">
{% for _, scope := range p.Scope %} <input type="hidden"
<div> name="code_challenge_method"
<label> value="{%s p.CodeChallengeMethod.String() %}">
<input type="checkbox" {% endif %}
name="scope[]"
value="{%s scope.String() %}"
checked>
{%s scope.String() %} {% if p.Me != nil %}
</label> <input type="hidden"
</div> name="me"
{% endfor %} value="{%s p.Me.String() %}">
</fieldset> {% endif %}
{% endif %}
{% if p.CodeChallenge != "" %} {% if len(p.Providers) > 0 %}
<input type="hidden" <select name="provider"
name="code_challenge" autocomplete
value="{%s p.CodeChallenge %}"> required>
<input type="hidden" {% for _, provider := range p.Providers %}
name="code_challenge_method" <option value="{%s provider.UID %}"
value="{%s p.CodeChallengeMethod.String() %}">
{% endif %}
{% if p.Me != nil %}
<input type="hidden"
name="me"
value="{%s p.Me.String() %}">
{% endif %}
{% if len(p.Providers) > 0 %}
<select name="provider"
autocomplete
required>
{% for _, provider := range p.Providers %}
<option value="{%s provider.UID %}"
{% if provider.UID == "mastodon" %}selected{% endif %}> {% if provider.UID == "mastodon" %}selected{% endif %}>
{%s provider.Name %} {%s provider.Name %}
</option> </option>
{% endfor %} {% endfor %}
</select> </select>
{% else %} {% else %}
<input type="hidden" <input type="hidden"
name="provider" name="provider"
value="direct"> value="direct">
{% endif %} {% endif %}
<button type="submit" <button type="submit"
name="authorize" name="authorize"
value="deny"> value="deny">
{%= p.T("Deny") %} {%= p.T("Deny") %}
</button> </button>
<button type="submit" <button type="submit"
name="authorize" name="authorize"
value="allow"> value="allow">
{%= p.T("Allow") %} {%= p.T("Allow") %}
</button> </button>
</form> </form>
</main> </main>
{% endfunc %} {% endfunc %}

View File

@ -41,27 +41,27 @@ type AuthorizePage struct {
func (p *AuthorizePage) StreamTitle(qw422016 *qt422016.Writer) { func (p *AuthorizePage) StreamTitle(qw422016 *qt422016.Writer) {
//line web/authorize.qtpl:19 //line web/authorize.qtpl:19
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:20 //line web/authorize.qtpl:20
if p.Client.GetName() == "" { if p.Client.GetName() == "" {
//line web/authorize.qtpl:20 //line web/authorize.qtpl:20
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:21 //line web/authorize.qtpl:21
p.StreamT(qw422016, "Authorize %s", p.Client.GetName()) p.StreamT(qw422016, "Authorize %s", p.Client.GetName())
//line web/authorize.qtpl:21 //line web/authorize.qtpl:21
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:22 //line web/authorize.qtpl:22
} else { } else {
//line web/authorize.qtpl:22 //line web/authorize.qtpl:22
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:23 //line web/authorize.qtpl:23
p.StreamT(qw422016, "Authorize application") p.StreamT(qw422016, "Authorize application")
//line web/authorize.qtpl:23 //line web/authorize.qtpl:23
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:24 //line web/authorize.qtpl:24
} }
//line web/authorize.qtpl:24 //line web/authorize.qtpl:24
@ -100,43 +100,43 @@ func (p *AuthorizePage) Title() string {
func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) { func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
//line web/authorize.qtpl:27 //line web/authorize.qtpl:27
qw422016.N().S(` qw422016.N().S(`
<header> <header>
`) `)
//line web/authorize.qtpl:29 //line web/authorize.qtpl:29
if p.Client.GetLogo() != nil { if p.Client.GetLogo() != nil {
//line web/authorize.qtpl:29 //line web/authorize.qtpl:29
qw422016.N().S(` qw422016.N().S(`
<img class="" <img class=""
crossorigin="anonymous" crossorigin="anonymous"
decoding="async" decoding="async"
height="140" height="140"
importance="high" importance="high"
loading="lazy" loading="lazy"
referrerpolicy="no-referrer-when-downgrade" referrerpolicy="no-referrer-when-downgrade"
src="`) src="`)
//line web/authorize.qtpl:37 //line web/authorize.qtpl:37
qw422016.E().S(p.Client.GetLogo().String()) qw422016.E().S(p.Client.GetLogo().String())
//line web/authorize.qtpl:37 //line web/authorize.qtpl:37
qw422016.N().S(`" qw422016.N().S(`"
alt="`) alt="`)
//line web/authorize.qtpl:38 //line web/authorize.qtpl:38
qw422016.E().S(p.Client.GetName()) qw422016.E().S(p.Client.GetName())
//line web/authorize.qtpl:38 //line web/authorize.qtpl:38
qw422016.N().S(`" qw422016.N().S(`"
width="140"> width="140">
`) `)
//line web/authorize.qtpl:40 //line web/authorize.qtpl:40
} }
//line web/authorize.qtpl:40 //line web/authorize.qtpl:40
qw422016.N().S(` qw422016.N().S(`
<h2> <h2>
`) `)
//line web/authorize.qtpl:43 //line web/authorize.qtpl:43
if p.Client.GetURL() != nil { if p.Client.GetURL() != nil {
//line web/authorize.qtpl:43 //line web/authorize.qtpl:43
qw422016.N().S(` qw422016.N().S(`
<a href="`) <a href="`)
//line web/authorize.qtpl:44 //line web/authorize.qtpl:44
qw422016.E().S(p.Client.GetURL().String()) qw422016.E().S(p.Client.GetURL().String())
//line web/authorize.qtpl:44 //line web/authorize.qtpl:44
@ -151,7 +151,7 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
if p.Client.GetName() != "" { if p.Client.GetName() != "" {
//line web/authorize.qtpl:46 //line web/authorize.qtpl:46
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:47 //line web/authorize.qtpl:47
qw422016.E().S(p.Client.GetName()) qw422016.E().S(p.Client.GetName())
//line web/authorize.qtpl:47 //line web/authorize.qtpl:47
@ -161,7 +161,7 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
} else { } else {
//line web/authorize.qtpl:48 //line web/authorize.qtpl:48
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:49 //line web/authorize.qtpl:49
qw422016.E().S(p.Client.ID.String()) qw422016.E().S(p.Client.ID.String())
//line web/authorize.qtpl:49 //line web/authorize.qtpl:49
@ -176,44 +176,44 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
if p.Client.GetURL() != nil { if p.Client.GetURL() != nil {
//line web/authorize.qtpl:51 //line web/authorize.qtpl:51
qw422016.N().S(` qw422016.N().S(`
</a> </a>
`) `)
//line web/authorize.qtpl:53 //line web/authorize.qtpl:53
} }
//line web/authorize.qtpl:53 //line web/authorize.qtpl:53
qw422016.N().S(` qw422016.N().S(`
</h2> </h2>
</header> </header>
<main> <main>
<form class="" <form class=""
accept-charset="utf-8" accept-charset="utf-8"
action="/api/authorize" action="/authorize/verify"
autocomplete="off" autocomplete="off"
enctype="application/x-www-form-urlencoded" enctype="application/x-www-form-urlencoded"
method="post" method="post"
novalidate="true" novalidate="true"
target="_self"> target="_self">
`) `)
//line web/authorize.qtpl:67 //line web/authorize.qtpl:67
if p.CSRF != nil { if p.CSRF != nil {
//line web/authorize.qtpl:67 //line web/authorize.qtpl:67
qw422016.N().S(` qw422016.N().S(`
<input type="hidden" <input type="hidden"
name="_csrf" name="_csrf"
value="`) value="`)
//line web/authorize.qtpl:70 //line web/authorize.qtpl:70
qw422016.E().Z(p.CSRF) qw422016.E().Z(p.CSRF)
//line web/authorize.qtpl:70 //line web/authorize.qtpl:70
qw422016.N().S(`"> qw422016.N().S(`">
`) `)
//line web/authorize.qtpl:71 //line web/authorize.qtpl:71
} }
//line web/authorize.qtpl:71 //line web/authorize.qtpl:71
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:73 //line web/authorize.qtpl:73
for key, val := range map[string]string{ for key, val := range map[string]string{
"client_id": p.Client.ID.String(), "client_id": p.Client.ID.String(),
@ -223,129 +223,129 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
} { } {
//line web/authorize.qtpl:78 //line web/authorize.qtpl:78
qw422016.N().S(` qw422016.N().S(`
<input type="hidden" <input type="hidden"
name="`) name="`)
//line web/authorize.qtpl:80 //line web/authorize.qtpl:80
qw422016.E().S(key) qw422016.E().S(key)
//line web/authorize.qtpl:80 //line web/authorize.qtpl:80
qw422016.N().S(`" qw422016.N().S(`"
value="`) value="`)
//line web/authorize.qtpl:81 //line web/authorize.qtpl:81
qw422016.E().S(val) qw422016.E().S(val)
//line web/authorize.qtpl:81 //line web/authorize.qtpl:81
qw422016.N().S(`"> qw422016.N().S(`">
`) `)
//line web/authorize.qtpl:82 //line web/authorize.qtpl:82
} }
//line web/authorize.qtpl:82 //line web/authorize.qtpl:82
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:84 //line web/authorize.qtpl:84
if len(p.Scope) > 0 { if len(p.Scope) > 0 {
//line web/authorize.qtpl:84 //line web/authorize.qtpl:84
qw422016.N().S(` qw422016.N().S(`
<fieldset> <fieldset>
<legend>`) <legend>`)
//line web/authorize.qtpl:86 //line web/authorize.qtpl:86
p.StreamT(qw422016, "Choose your scopes") p.StreamT(qw422016, "Choose your scopes")
//line web/authorize.qtpl:86 //line web/authorize.qtpl:86
qw422016.N().S(`</legend> qw422016.N().S(`</legend>
`) `)
//line web/authorize.qtpl:88 //line web/authorize.qtpl:88
for _, scope := range p.Scope { for _, scope := range p.Scope {
//line web/authorize.qtpl:88 //line web/authorize.qtpl:88
qw422016.N().S(` qw422016.N().S(`
<div> <div>
<label> <label>
<input type="checkbox" <input type="checkbox"
name="scope[]" name="scope[]"
value="`) value="`)
//line web/authorize.qtpl:93 //line web/authorize.qtpl:93
qw422016.E().S(scope.String()) qw422016.E().S(scope.String())
//line web/authorize.qtpl:93 //line web/authorize.qtpl:93
qw422016.N().S(`" qw422016.N().S(`"
checked> checked>
`) `)
//line web/authorize.qtpl:96 //line web/authorize.qtpl:96
qw422016.E().S(scope.String()) qw422016.E().S(scope.String())
//line web/authorize.qtpl:96 //line web/authorize.qtpl:96
qw422016.N().S(` qw422016.N().S(`
</label> </label>
</div> </div>
`) `)
//line web/authorize.qtpl:99 //line web/authorize.qtpl:99
} }
//line web/authorize.qtpl:99 //line web/authorize.qtpl:99
qw422016.N().S(` qw422016.N().S(`
</fieldset> </fieldset>
`) `)
//line web/authorize.qtpl:101 //line web/authorize.qtpl:101
} }
//line web/authorize.qtpl:101 //line web/authorize.qtpl:101
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:103 //line web/authorize.qtpl:103
if p.CodeChallenge != "" { if p.CodeChallenge != "" {
//line web/authorize.qtpl:103 //line web/authorize.qtpl:103
qw422016.N().S(` qw422016.N().S(`
<input type="hidden" <input type="hidden"
name="code_challenge" name="code_challenge"
value="`) value="`)
//line web/authorize.qtpl:106 //line web/authorize.qtpl:106
qw422016.E().S(p.CodeChallenge) qw422016.E().S(p.CodeChallenge)
//line web/authorize.qtpl:106 //line web/authorize.qtpl:106
qw422016.N().S(`"> qw422016.N().S(`">
<input type="hidden" <input type="hidden"
name="code_challenge_method" name="code_challenge_method"
value="`) value="`)
//line web/authorize.qtpl:110 //line web/authorize.qtpl:110
qw422016.E().S(p.CodeChallengeMethod.String()) qw422016.E().S(p.CodeChallengeMethod.String())
//line web/authorize.qtpl:110 //line web/authorize.qtpl:110
qw422016.N().S(`"> qw422016.N().S(`">
`) `)
//line web/authorize.qtpl:111 //line web/authorize.qtpl:111
} }
//line web/authorize.qtpl:111 //line web/authorize.qtpl:111
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:113 //line web/authorize.qtpl:113
if p.Me != nil { if p.Me != nil {
//line web/authorize.qtpl:113 //line web/authorize.qtpl:113
qw422016.N().S(` qw422016.N().S(`
<input type="hidden" <input type="hidden"
name="me" name="me"
value="`) value="`)
//line web/authorize.qtpl:116 //line web/authorize.qtpl:116
qw422016.E().S(p.Me.String()) qw422016.E().S(p.Me.String())
//line web/authorize.qtpl:116 //line web/authorize.qtpl:116
qw422016.N().S(`"> qw422016.N().S(`">
`) `)
//line web/authorize.qtpl:117 //line web/authorize.qtpl:117
} }
//line web/authorize.qtpl:117 //line web/authorize.qtpl:117
qw422016.N().S(` qw422016.N().S(`
`) `)
//line web/authorize.qtpl:119 //line web/authorize.qtpl:119
if len(p.Providers) > 0 { if len(p.Providers) > 0 {
//line web/authorize.qtpl:119 //line web/authorize.qtpl:119
qw422016.N().S(` qw422016.N().S(`
<select name="provider" <select name="provider"
autocomplete autocomplete
required> required>
`) `)
//line web/authorize.qtpl:124 //line web/authorize.qtpl:124
for _, provider := range p.Providers { for _, provider := range p.Providers {
//line web/authorize.qtpl:124 //line web/authorize.qtpl:124
qw422016.N().S(` qw422016.N().S(`
<option value="`) <option value="`)
//line web/authorize.qtpl:125 //line web/authorize.qtpl:125
qw422016.E().S(provider.UID) qw422016.E().S(provider.UID)
//line web/authorize.qtpl:125 //line web/authorize.qtpl:125
@ -360,55 +360,55 @@ func (p *AuthorizePage) StreamBody(qw422016 *qt422016.Writer) {
//line web/authorize.qtpl:126 //line web/authorize.qtpl:126
qw422016.N().S(`> qw422016.N().S(`>
`) `)
//line web/authorize.qtpl:128 //line web/authorize.qtpl:128
qw422016.E().S(provider.Name) qw422016.E().S(provider.Name)
//line web/authorize.qtpl:128 //line web/authorize.qtpl:128
qw422016.N().S(` qw422016.N().S(`
</option> </option>
`) `)
//line web/authorize.qtpl:130 //line web/authorize.qtpl:130
} }
//line web/authorize.qtpl:130 //line web/authorize.qtpl:130
qw422016.N().S(` qw422016.N().S(`
</select> </select>
`) `)
//line web/authorize.qtpl:132 //line web/authorize.qtpl:132
} else { } else {
//line web/authorize.qtpl:132 //line web/authorize.qtpl:132
qw422016.N().S(` qw422016.N().S(`
<input type="hidden" <input type="hidden"
name="provider" name="provider"
value="direct"> value="direct">
`) `)
//line web/authorize.qtpl:136 //line web/authorize.qtpl:136
} }
//line web/authorize.qtpl:136 //line web/authorize.qtpl:136
qw422016.N().S(` qw422016.N().S(`
<button type="submit" <button type="submit"
name="authorize" name="authorize"
value="deny"> value="deny">
`) `)
//line web/authorize.qtpl:142 //line web/authorize.qtpl:142
p.StreamT(qw422016, "Deny") p.StreamT(qw422016, "Deny")
//line web/authorize.qtpl:142 //line web/authorize.qtpl:142
qw422016.N().S(` qw422016.N().S(`
</button> </button>
<button type="submit" <button type="submit"
name="authorize" name="authorize"
value="allow"> value="allow">
`) `)
//line web/authorize.qtpl:149 //line web/authorize.qtpl:149
p.StreamT(qw422016, "Allow") p.StreamT(qw422016, "Allow")
//line web/authorize.qtpl:149 //line web/authorize.qtpl:149
qw422016.N().S(` qw422016.N().S(`
</button> </button>
</form> </form>
</main> </main>
`) `)
//line web/authorize.qtpl:153 //line web/authorize.qtpl:153
} }

View File

@ -5,47 +5,47 @@
{% collapsespace %} {% collapsespace %}
{% func (p *TicketPage) Body() %} {% func (p *TicketPage) Body() %}
<header> <header>
<h1>{%= p.T("TicketAuth") %}</h1> <h1>{%= p.T("TicketAuth") %}</h1>
</header> </header>
<main> <main>
<form class="" <form class=""
accept-charset="utf-8" accept-charset="utf-8"
action="/api/ticket" action="/ticket/send"
autocomplete="off" autocomplete="off"
enctype="application/x-www-form-urlencoded" enctype="application/x-www-form-urlencoded"
method="post" method="post"
target="_self"> target="_self">
{% if p.CSRF != nil %} {% if p.CSRF != nil %}
<input type="hidden" <input type="hidden"
name="_csrf" name="_csrf"
value="{%z p.CSRF %}"> value="{%z p.CSRF %}">
{% endif %} {% endif %}
<div> <div>
<label for="subject">{%= p.T("Recipient") %}</label> <label for="subject">{%= p.T("Recipient") %}</label>
<input id="subject" <input id="subject"
type="url" type="url"
name="subject" name="subject"
inputmode="url" inputmode="url"
placeholder="https://bob.example.org" placeholder="https://bob.example.org"
required> required>
</div> </div>
<div> <div>
<label for="resource">{%= p.T("Resource") %}</label> <label for="resource">{%= p.T("Resource") %}</label>
<input id="resource" <input id="resource"
type="url" type="url"
name="resource" name="resource"
inputmode="url" inputmode="url"
placeholder="https://alice.example.com/private/" placeholder="https://alice.example.com/private/"
required> required>
</div> </div>
<button type="submit">{%= p.T("Send") %}</button> <button type="submit">{%= p.T("Send") %}</button>
</form> </form>
</main> </main>
{% endfunc %} {% endfunc %}
{% endcollapsespace %} {% endcollapsespace %}

View File

@ -30,7 +30,7 @@ func (p *TicketPage) StreamBody(qw422016 *qt422016.Writer) {
//line web/ticket.qtpl:9 //line web/ticket.qtpl:9
p.StreamT(qw422016, "TicketAuth") p.StreamT(qw422016, "TicketAuth")
//line web/ticket.qtpl:9 //line web/ticket.qtpl:9
qw422016.N().S(`</h1> </header> <main> <form class="" accept-charset="utf-8" action="/api/ticket" autocomplete="off" enctype="application/x-www-form-urlencoded" method="post" target="_self"> `) qw422016.N().S(`</h1> </header> <main> <form class="" accept-charset="utf-8" action="/ticket/send" autocomplete="off" enctype="application/x-www-form-urlencoded" method="post" target="_self"> `)
//line web/ticket.qtpl:21 //line web/ticket.qtpl:21
if p.CSRF != nil { if p.CSRF != nil {
//line web/ticket.qtpl:21 //line web/ticket.qtpl:21