♻️ Drop HTTP delivery framework

replaced 'valyala/fasthttp' to native 'net/http' package, close #4
This commit is contained in:
Maxim Lebedev 2023-01-15 03:27:37 +06:00
parent 95cb5a2950
commit c7bd73c63a
Signed by: toby3d
GPG key ID: 1F14E25B7C119FC5
90 changed files with 2551 additions and 2286 deletions

View file

@ -2,13 +2,10 @@ package http
import (
"crypto/subtle"
"errors"
"path"
"net/http"
"strings"
"github.com/fasthttp/router"
json "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"github.com/goccy/go-json"
"golang.org/x/text/language"
"golang.org/x/text/message"
@ -16,111 +13,31 @@ import (
"source.toby3d.me/toby3d/auth/internal/client"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/middleware"
"source.toby3d.me/toby3d/auth/internal/profile"
"source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
)
type (
AuthAuthorizationRequest struct {
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
// The client URL.
ClientID *domain.ClientID `form:"client_id"`
// The redirect URL indicating where the user should be
// redirected to after approving the request.
RedirectURI *domain.URL `form:"redirect_uri"`
// A parameter set by the client which will be included when the
// user is redirected back to the client. This is used to
// prevent CSRF attacks. The authorization server MUST return
// the unmodified state value back to the client.
State string `form:"state"`
// The code challenge as previously described.
CodeChallenge string `form:"code_challenge"`
// The hashing method used to calculate the code challenge.
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
// A space-separated list of scopes the client is requesting,
// e.g. "profile", or "profile create". If the client omits this
// value, the authorization server MUST NOT issue an access
// token for this authorization code. Only the user's profile
// URL may be returned without any scope requested.
Scope domain.Scopes `form:"scope,omitempty"`
// The URL that the user entered.
Me *domain.Me `form:"me"`
}
AuthVerifyRequest struct {
ClientID *domain.ClientID `form:"client_id"`
Me *domain.Me `form:"me"`
RedirectURI *domain.URL `form:"redirect_uri"`
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
ResponseType domain.ResponseType `form:"response_type"`
Scope domain.Scopes `form:"scope[],omitempty"`
Authorize string `form:"authorize"`
CodeChallenge string `form:"code_challenge"`
State string `form:"state"`
Provider string `form:"provider"`
}
AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The authorization code received from the authorization
// endpoint in the redirect.
Code string `form:"code"`
// The client's URL, which MUST match the client_id used in the
// authentication request.
ClientID *domain.ClientID `form:"client_id"`
// The client's redirect URL, which MUST match the initial
// authentication request.
RedirectURI *domain.URL `form:"redirect_uri"`
// The original plaintext random string generated before
// starting the authorization request.
CodeVerifier string `form:"code_verifier"`
}
AuthExchangeResponse struct {
Me *domain.Me `json:"me"`
Profile *AuthProfileResponse `json:"profile,omitempty"`
}
AuthProfileResponse struct {
Email *domain.Email `json:"email,omitempty"`
Photo *domain.URL `json:"photo,omitempty"`
URL *domain.URL `json:"url,omitempty"`
Name string `json:"name,omitempty"`
}
NewRequestHandlerOptions struct {
NewHandlerOptions struct {
Auth auth.UseCase
Clients client.UseCase
Config *domain.Config
Config domain.Config
Matcher language.Matcher
Profiles profile.UseCase
}
RequestHandler struct {
Handler struct {
clients client.UseCase
config *domain.Config
config domain.Config
matcher language.Matcher
useCase auth.UseCase
}
)
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{
func NewHandler(opts NewHandlerOptions) *Handler {
return &Handler{
clients: opts.Clients,
config: opts.Config,
matcher: opts.Matcher,
@ -128,16 +45,16 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
}
}
func (h *RequestHandler) Register(r *router.Router) {
func (h *Handler) Handler() http.Handler {
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/authorize*", string(ctx.Path()))
Skipper: func(w http.ResponseWriter, r *http.Request) bool {
head, _ := urlutil.ShiftPath(r.URL.Path)
return ctx.IsPost() && matched
return r.Method == http.MethodPost && head == "authorize"
},
CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteStrictMode,
CookieSameSite: http.SameSiteStrictMode,
ContextKey: "csrf",
CookieDomain: h.config.Server.Domain,
CookieName: "__Secure-csrf",
@ -148,14 +65,12 @@ func (h *RequestHandler) Register(r *router.Router) {
CookieHTTPOnly: true,
}),
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/api/*", string(ctx.Path()))
provider := string(ctx.QueryArgs().Peek("provider"))
providerMatched := provider != "" && provider != domain.ProviderDirect.UID
Skipper: func(w http.ResponseWriter, r *http.Request) bool {
head, _ := urlutil.ShiftPath(r.URL.Path)
return !ctx.IsPost() || !matched || providerMatched
return r.Method != http.MethodPost || head != "api"
},
Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) {
Validator: func(w http.ResponseWriter, r *http.Request, login, password string) (bool, error) {
userMatch := subtle.ConstantTimeCompare([]byte(login),
[]byte(h.config.IndieAuth.Username))
passMatch := subtle.ConstantTimeCompare([]byte(password),
@ -165,29 +80,57 @@ func (h *RequestHandler) Register(r *router.Router) {
},
Realm: "",
}),
middleware.LogFmt(),
}
r.GET("/authorize", chain.RequestHandler(h.handleAuthorize))
r.POST("/api/authorize", chain.RequestHandler(h.handleVerify))
r.POST("/authorize", chain.RequestHandler(h.handleExchange))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch r.Method {
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
case http.MethodGet, "":
if head != "" {
http.NotFound(w, r)
return
}
chain.Handler(h.handleAuthorize).ServeHTTP(w, r)
case http.MethodPost:
switch head {
default:
http.NotFound(w, r)
case "":
chain.Handler(h.handleExchange).ServeHTTP(w, r)
case "verify":
chain.Handler(h.handleVerify).ServeHTTP(w, r)
}
}
})
}
func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
func (h *Handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet && r.Method != "" {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
return
}
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{
Config: h.config,
Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
}
req := NewAuthAuthorizationRequest()
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
if err := req.bind(r); err != nil {
w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@ -195,10 +138,10 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return
}
client, err := h.clients.Discovery(ctx, req.ClientID)
client, err := h.clients.Discovery(r.Context(), req.ClientID)
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@ -207,8 +150,8 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
}
if !client.ValidateRedirectURI(req.RedirectURI.URL) {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeInvalidClient,
@ -220,15 +163,15 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
return
}
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
web.WriteTemplate(ctx, &web.AuthorizePage{
csrf, _ := r.Context().Value(middleware.DefaultCSRFConfig.ContextKey).([]byte)
web.WriteTemplate(w, &web.AuthorizePage{
BaseOf: baseOf,
CSRF: csrf,
Scope: req.Scope,
Client: client,
Me: req.Me,
RedirectURI: req.RedirectURI,
CodeChallengeMethod: req.CodeChallengeMethod,
Me: &req.Me,
RedirectURI: &req.RedirectURI,
CodeChallengeMethod: *req.CodeChallengeMethod,
ResponseType: req.ResponseType,
CodeChallenge: req.CodeChallenge,
State: req.State,
@ -236,15 +179,21 @@ func (h *RequestHandler) handleAuthorize(ctx *http.RequestCtx) {
})
}
func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
ctx.Response.Header.Set(http.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
func (h *Handler) handleVerify(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx)
return
}
w.Header().Set(common.HeaderAccessControlAllowOrigin, h.config.Server.Domain)
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := NewAuthVerifyRequest()
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
if err := req.bind(r); err != nil {
w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
@ -254,60 +203,70 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
if strings.EqualFold(req.Authorize, "deny") {
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
SetReirectURI(req.RedirectURI.URL)
ctx.Redirect(req.RedirectURI.String(), http.StatusFound)
http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
return
}
code, err := h.useCase.Generate(ctx, auth.GenerateOptions{
code, err := h.useCase.Generate(r.Context(), auth.GenerateOptions{
ClientID: req.ClientID,
Me: req.Me,
RedirectURI: req.RedirectURI.URL,
CodeChallengeMethod: req.CodeChallengeMethod,
CodeChallengeMethod: *req.CodeChallengeMethod,
Scope: req.Scope,
CodeChallenge: req.CodeChallenge,
})
if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
w.WriteHeader(http.StatusInternalServerError)
_ = encoder.Encode(err)
return
}
q := req.RedirectURI.Query()
for key, val := range map[string]string{
"code": code,
"iss": h.config.Server.GetRootURL(),
"state": req.State,
} {
req.RedirectURI.Query().Set(key, val)
q.Set(key, val)
}
ctx.Redirect(req.RedirectURI.String(), http.StatusFound)
req.RedirectURI.RawQuery = q.Encode()
http.Redirect(w, r, req.RedirectURI.String(), http.StatusFound)
}
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
func (h *Handler) handleExchange(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
encoder := json.NewEncoder(ctx)
return
}
w.Header().Set(common.HeaderContentType, common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(w)
req := new(AuthExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
if err := req.bind(r); err != nil {
w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
return
}
me, profile, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{
me, profile, err := h.useCase.Exchange(r.Context(), auth.ExchangeOptions{
Code: req.Code,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI.URL,
CodeVerifier: req.CodeVerifier,
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
w.WriteHeader(http.StatusBadRequest)
_ = encoder.Encode(err)
@ -325,109 +284,7 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
}
_ = encoder.Encode(&AuthExchangeResponse{
Me: me,
Me: *me,
Profile: userInfo,
})
}
func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
return &AuthAuthorizationRequest{
ClientID: new(domain.ClientID),
CodeChallenge: "",
CodeChallengeMethod: domain.CodeChallengeMethodUnd,
Me: new(domain.Me),
RedirectURI: new(domain.URL),
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:cyclop
func (r *AuthAuthorizationRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
return nil
}
func NewAuthVerifyRequest() *AuthVerifyRequest {
return &AuthVerifyRequest{
Authorize: "",
ClientID: new(domain.ClientID),
CodeChallenge: "",
CodeChallengeMethod: domain.CodeChallengeMethodUnd,
Me: new(domain.Me),
Provider: "",
RedirectURI: new(domain.URL),
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:funlen,cyclop
func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
r.Provider = strings.ToLower(r.Provider)
if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#authorization-request",
)
}
return nil
}
func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs().QueryString(), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#redeeming-the-authorization-code",
)
}
return nil
}

View file

@ -0,0 +1,212 @@
package http
import (
"errors"
"net/http"
"strings"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type (
AuthAuthorizationRequest struct {
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
// The client URL.
ClientID domain.ClientID `form:"client_id"`
// The redirect URL indicating where the user should be
// redirected to after approving the request.
RedirectURI domain.URL `form:"redirect_uri"`
// The URL that the user entered.
Me domain.Me `form:"me"`
// The hashing method used to calculate the code challenge.
CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
// A space-separated list of scopes the client is requesting,
// e.g. "profile", or "profile create". If the client omits this
// value, the authorization server MUST NOT issue an access
// token for this authorization code. Only the user's profile
// URL may be returned without any scope requested.
Scope domain.Scopes `form:"scope,omitempty"`
// A parameter set by the client which will be included when the
// user is redirected back to the client. This is used to
// prevent CSRF attacks. The authorization server MUST return
// the unmodified state value back to the client.
State string `form:"state"`
// The code challenge as previously described.
CodeChallenge string `form:"code_challenge,omitempty"`
}
AuthVerifyRequest struct {
ClientID domain.ClientID `form:"client_id"`
Me domain.Me `form:"me"`
RedirectURI domain.URL `form:"redirect_uri"`
ResponseType domain.ResponseType `form:"response_type"`
CodeChallengeMethod *domain.CodeChallengeMethod `form:"code_challenge_method,omitempty"`
Scope domain.Scopes `form:"scope[],omitempty"`
Authorize string `form:"authorize"`
CodeChallenge string `form:"code_challenge,omitempty"`
State string `form:"state"`
Provider string `form:"provider"`
}
AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The client's URL, which MUST match the client_id used in the
// authentication request.
ClientID domain.ClientID `form:"client_id"`
// The client's redirect URL, which MUST match the initial
// authentication request.
RedirectURI domain.URL `form:"redirect_uri"`
// The authorization code received from the authorization
// endpoint in the redirect.
Code string `form:"code"`
// The original plaintext random string generated before
// starting the authorization request.
CodeVerifier string `form:"code_verifier"`
}
AuthExchangeResponse struct {
Me domain.Me `json:"me"`
Profile *AuthProfileResponse `json:"profile,omitempty"`
}
AuthProfileResponse struct {
Email *domain.Email `json:"email,omitempty"`
Photo *domain.URL `json:"photo,omitempty"`
URL *domain.URL `json:"url,omitempty"`
Name string `json:"name,omitempty"`
}
)
func NewAuthAuthorizationRequest() *AuthAuthorizationRequest {
return &AuthAuthorizationRequest{
ClientID: domain.ClientID{},
CodeChallenge: "",
CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
Me: domain.Me{},
RedirectURI: domain.URL{},
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:cyclop
func (r *AuthAuthorizationRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(req.URL.Query().Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
return nil
}
func NewAuthVerifyRequest() *AuthVerifyRequest {
return &AuthVerifyRequest{
Authorize: "",
ClientID: domain.ClientID{},
CodeChallenge: "",
CodeChallengeMethod: &domain.CodeChallengeMethodUnd,
Me: domain.Me{},
Provider: "",
RedirectURI: domain.URL{},
ResponseType: domain.ResponseTypeUnd,
Scope: make(domain.Scopes, 0),
State: "",
}
}
//nolint:funlen,cyclop
func (r *AuthVerifyRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
r.Provider = strings.ToLower(r.Provider)
if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#authorization-request",
)
}
return nil
}
func (r *AuthExchangeRequest) bind(req *http.Request) error {
indieAuthError := new(domain.Error)
if err := req.ParseForm(); err != nil {
return domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#authorization-request",
)
}
if err := form.Unmarshal([]byte(req.PostForm.Encode()), r); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(
domain.ErrorCodeInvalidRequest,
"cannot validate verification request",
"https://indieauth.net/source/#redeeming-the-authorization-code",
)
}
return nil
}

View file

@ -1,13 +1,14 @@
package http_test
import (
"path"
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
@ -22,7 +23,7 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/user"
userrepo "source.toby3d.me/toby3d/auth/internal/user/repository/memory"
)
@ -34,36 +35,31 @@ type Dependencies struct {
matcher language.Matcher
profiles profile.Repository
sessions session.Repository
store *sync.Map
users user.Repository
}
func TestAuthorize(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
me := domain.TestMe(t, "https://user.example.net")
me := domain.TestMe(t, "https://user.example.net/")
user := domain.TestUser(t)
client := domain.TestClient(t)
deps.store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client)
deps.store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile)
deps.store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user)
if err := deps.clients.Create(context.Background(), *client); err != nil {
t.Fatal(err)
}
r := router.New()
//nolint:exhaustivestruct
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Auth: deps.authService,
Clients: deps.clientService,
Config: deps.config,
Matcher: deps.matcher,
}).Register(r)
if err := deps.users.Create(context.Background(), *user); err != nil {
t.Fatal(err)
}
httpClient, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
if err := deps.profiles.Create(context.Background(), *me, *user.Profile); err != nil {
t.Fatal(err)
}
uri := http.AcquireURI()
defer http.ReleaseURI(uri)
uri.Update("https://example.com/authorize")
u := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
q := u.Query()
for key, val := range map[string]string{
"client_id": client.ID.String(),
@ -75,26 +71,36 @@ func TestAuthorize(t *testing.T) {
"scope": "profile email",
"state": "1234567890",
} {
uri.QueryArgs().Set(key, val)
q.Set(key, val)
}
req := httptest.NewRequest(http.MethodGet, uri.String(), nil)
defer http.ReleaseRequest(req)
u.RawQuery = q.Encode()
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
req := httptest.NewRequest(http.MethodGet, u.String(), nil)
w := httptest.NewRecorder()
if err := httpClient.Do(req, resp); err != nil {
//nolint:exhaustivestruct
delivery.NewHandler(delivery.NewHandlerOptions{
Auth: deps.authService,
Clients: deps.clientService,
Config: *deps.config,
Matcher: deps.matcher,
}).Handler().ServeHTTP(w, req)
resp := w.Result()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode() != http.StatusOK {
t.Errorf("GET %s = %d, want %d", uri.String(), resp.StatusCode(), http.StatusOK)
if resp.StatusCode != http.StatusOK {
t.Errorf("%s %s = %d, want %d", req.Method, u.String(), resp.StatusCode, http.StatusOK)
}
const expResult = `Authorize application`
if result := string(resp.Body()); !strings.Contains(result, expResult) {
t.Errorf("GET %s = %s, want %s", uri.String(), result, expResult)
if result := string(body); !strings.Contains(result, expResult) {
t.Errorf("%s %s = %s, want %s", req.Method, u.String(), result, expResult)
}
}
@ -103,14 +109,15 @@ func NewDependencies(tb testing.TB) Dependencies {
config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
store := new(sync.Map)
clients := clientrepo.NewMemoryClientRepository(store)
sessions := sessionrepo.NewMemorySessionRepository(store, config)
profiles := profilerepo.NewMemoryProfileRepository(store)
clients := clientrepo.NewMemoryClientRepository()
users := userrepo.NewMemoryUserRepository()
sessions := sessionrepo.NewMemorySessionRepository(*config)
profiles := profilerepo.NewMemoryProfileRepository()
authService := ucase.NewAuthUseCase(sessions, profiles, config)
clientService := clientucase.NewClientUseCase(clients)
return Dependencies{
users: users,
authService: authService,
clients: clients,
clientService: clientService,
@ -118,6 +125,5 @@ func NewDependencies(tb testing.TB) Dependencies {
matcher: matcher,
sessions: sessions,
profiles: profiles,
store: store,
}
}

View file

@ -9,8 +9,8 @@ import (
type (
GenerateOptions struct {
ClientID *domain.ClientID
Me *domain.Me
ClientID domain.ClientID
Me domain.Me
RedirectURI *url.URL
CodeChallengeMethod domain.CodeChallengeMethod
Scope domain.Scopes
@ -18,7 +18,7 @@ type (
}
ExchangeOptions struct {
ClientID *domain.ClientID
ClientID domain.ClientID
RedirectURI *url.URL
Code string
CodeVerifier string

View file

@ -45,7 +45,7 @@ func (uc *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions)
}
}
if err = uc.sessions.Create(ctx, &domain.Session{
if err = uc.sessions.Create(ctx, domain.Session{
ClientID: opts.ClientID,
Code: code,
CodeChallenge: opts.CodeChallenge,
@ -81,5 +81,5 @@ func (uc *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions)
return nil, nil, auth.ErrMismatchPKCE
}
return session.Me, session.Profile, nil
return &session.Me, session.Profile, nil
}

View file

@ -1,48 +1,37 @@
package http
import (
"errors"
"net/http"
"strings"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/token"
"source.toby3d.me/toby3d/auth/internal/urlutil"
"source.toby3d.me/toby3d/auth/web"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
)
type (
ClientCallbackRequest struct {
Error domain.ErrorCode `form:"error,omitempty"`
Iss *domain.ClientID `form:"iss"`
Code string `form:"code"`
ErrorDescription string `form:"error_description,omitempty"`
State string `form:"state"`
}
NewRequestHandlerOptions struct {
NewHandlerOptions struct {
Matcher language.Matcher
Tokens token.UseCase
Client *domain.Client
Config *domain.Config
Client domain.Client
Config domain.Config
}
RequestHandler struct {
Handler struct {
matcher language.Matcher
tokens token.UseCase
client *domain.Client
config *domain.Config
client domain.Client
config domain.Config
}
)
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{
func NewHandler(opts NewHandlerOptions) *Handler {
return &Handler{
client: opts.Client,
config: opts.Config,
matcher: opts.Matcher,
@ -50,59 +39,82 @@ func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
}
}
func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{
middleware.LogFmt(),
}
func (h *Handler) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
r.GET("/", chain.RequestHandler(h.handleRender))
r.GET("/callback", chain.RequestHandler(h.handleCallback))
return
}
var head string
head, r.URL.Path = urlutil.ShiftPath(r.URL.Path)
switch head {
default:
http.NotFound(w, r)
case "":
h.handleRender(w, r)
case "callback":
h.handleCallback(w, r)
}
})
}
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
func (h *Handler) handleRender(w http.ResponseWriter, r *http.Request) {
if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
redirect := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI {
redirect[i] = h.client.RedirectURI[i].String()
}
ctx.Response.Header.Set(
http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
)
w.Header().Set(common.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
// TODO(toby3d): generate and store PKCE
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.HomePage{
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(w, &web.HomePage{
BaseOf: web.BaseOf{
Config: h.config,
Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
Client: h.client,
Client: &h.client,
State: "hackme", // TODO(toby3d): generate and store state
})
}
//nolint:unlen
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
if r.Method != "" && r.Method != http.MethodGet {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
return
}
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
tag, _, _ := h.matcher.Match(tags...)
baseOf := web.BaseOf{
Config: h.config,
Config: &h.config,
Language: tag,
Printer: message.NewPrinter(tag),
}
req := new(ClientCallbackRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
web.WriteTemplate(ctx, &web.ErrorPage{
if err := req.bind(r); err != nil {
w.WriteHeader(http.StatusInternalServerError)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@ -111,8 +123,8 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
}
if req.Error != domain.ErrorCodeUnd {
ctx.SetStatusCode(http.StatusUnauthorized)
web.WriteTemplate(ctx, &web.ErrorPage{
w.WriteHeader(http.StatusUnauthorized)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeAccessDenied,
@ -127,9 +139,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
// TODO(toby3d): load and check state
if req.Iss == nil || req.Iss.String() != h.client.ID.String() {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
if req.Iss.String() != h.client.ID.String() {
w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: domain.NewError(
domain.ErrorCodeInvalidClient,
@ -142,15 +154,15 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return
}
token, _, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
token, _, err := h.tokens.Exchange(r.Context(), token.ExchangeOptions{
ClientID: h.client.ID,
RedirectURI: h.client.RedirectURI[0],
Code: req.Code,
CodeVerifier: "", // TODO(toby3d): validate PKCE here
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
w.WriteHeader(http.StatusBadRequest)
web.WriteTemplate(w, &web.ErrorPage{
BaseOf: baseOf,
Error: err,
})
@ -158,23 +170,9 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
return
}
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.CallbackPage{
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(w, &web.CallbackPage{
BaseOf: baseOf,
Token: token,
})
}
func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs().QueryString(), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
}
return nil
}

View file

@ -0,0 +1,31 @@
package http
import (
"errors"
"net/http"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/form"
)
type ClientCallbackRequest struct {
Error domain.ErrorCode `form:"error,omitempty"`
Iss domain.ClientID `form:"iss"`
Code string `form:"code"`
ErrorDescription string `form:"error_description,omitempty"`
State string `form:"state"`
}
func (req *ClientCallbackRequest) bind(r *http.Request) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal([]byte(r.URL.Query().Encode()), req); err != nil {
if errors.As(err, indieAuthError) {
return indieAuthError
}
return domain.NewError(domain.ErrorCodeInvalidRequest, err.Error(), "")
}
return nil
}

View file

@ -1,11 +1,10 @@
package http_test
import (
"sync"
"net/http"
"net/http/httptest"
"testing"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
@ -15,7 +14,6 @@ import (
profilerepo "source.toby3d.me/toby3d/auth/internal/profile/repository/memory"
"source.toby3d.me/toby3d/auth/internal/session"
sessionrepo "source.toby3d.me/toby3d/auth/internal/session/repository/memory"
"source.toby3d.me/toby3d/auth/internal/testing/httptest"
"source.toby3d.me/toby3d/auth/internal/token"
tokenrepo "source.toby3d.me/toby3d/auth/internal/token/repository/memory"
tokenucase "source.toby3d.me/toby3d/auth/internal/token/usecase"
@ -27,7 +25,6 @@ type Dependencies struct {
config *domain.Config
matcher language.Matcher
sessions session.Repository
store *sync.Map
tokens token.Repository
tokenService token.UseCase
}
@ -36,45 +33,30 @@ func TestRead(t *testing.T) {
t.Parallel()
deps := NewDependencies(t)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)
w := httptest.NewRecorder()
r := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Client: deps.client,
Config: deps.config,
delivery.NewHandler(delivery.NewHandlerOptions{
Client: *deps.client,
Config: *deps.config,
Matcher: deps.matcher,
Tokens: deps.tokenService,
}).Register(r)
}).Handler().ServeHTTP(w, req)
client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
const requestURI string = "https://app.example.com/"
req, resp := httptest.NewRequest(http.MethodGet, requestURI, nil), http.AcquireResponse()
t.Cleanup(func() {
http.ReleaseRequest(req)
http.ReleaseResponse(resp)
})
if err := client.Do(req, resp); err != nil {
t.Error(err)
}
if resp.StatusCode() != http.StatusOK {
t.Errorf("GET %s = %d, want %d", requestURI, resp.StatusCode(), http.StatusOK)
if resp := w.Result(); resp.StatusCode != http.StatusOK {
t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, http.StatusOK)
}
}
func NewDependencies(tb testing.TB) Dependencies {
tb.Helper()
store := new(sync.Map)
client := domain.TestClient(tb)
config := domain.TestConfig(tb)
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
sessions := sessionrepo.NewMemorySessionRepository(store, config)
tokens := tokenrepo.NewMemoryTokenRepository(store)
profiles := profilerepo.NewMemoryProfileRepository(store)
sessions := sessionrepo.NewMemorySessionRepository(*config)
tokens := tokenrepo.NewMemoryTokenRepository()
profiles := profilerepo.NewMemoryProfileRepository()
tokenService := tokenucase.NewTokenUseCase(tokenucase.Config{
Config: config,
Profiles: profiles,
@ -87,7 +69,6 @@ func NewDependencies(tb testing.TB) Dependencies {
config: config,
matcher: matcher,
sessions: sessions,
store: store,
profiles: profiles,
tokens: tokens,
tokenService: tokenService,

View file

@ -7,7 +7,8 @@ import (
)
type Repository interface {
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
Create(ctx context.Context, client domain.Client) error
Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error)
}
var ErrNotExist error = domain.NewError(

View file

@ -1,14 +1,17 @@
package http
import (
"bytes"
"context"
"fmt"
"net"
"io"
"net/http"
"net/url"
http "github.com/valyala/fasthttp"
"golang.org/x/exp/slices"
"source.toby3d.me/toby3d/auth/internal/client"
"source.toby3d.me/toby3d/auth/internal/common"
"source.toby3d.me/toby3d/auth/internal/domain"
"source.toby3d.me/toby3d/auth/internal/httputil"
)
@ -34,33 +37,18 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
}
}
func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
ips, err := net.LookupIP(cid.URL().Hostname())
// WARN(toby3d): not implemented.
func (httpClientRepository) Create(_ context.Context, _ domain.Client) error {
return nil
}
func (repo httpClientRepository) Get(ctx context.Context, cid domain.ClientID) (*domain.Client, error) {
resp, err := repo.client.Get(cid.String())
if err != nil {
return nil, fmt.Errorf("cannot resolve client IP by id: %w", err)
}
for _, ip := range ips {
if !ip.IsLoopback() {
continue
}
return nil, client.ErrNotExist
}
req := http.AcquireRequest()
defer http.ReleaseRequest(req)
req.SetRequestURI(cid.String())
req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
}
if resp.StatusCode() == http.StatusNotFound {
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("%w: status on client page is not 200", client.ErrNotExist)
}
@ -72,74 +60,62 @@ func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID)
Name: make([]string, 0),
}
extract(client, resp)
extract(resp.Body, resp.Request.URL, client, resp.Header.Get(common.HeaderLink))
return client, nil
}
//nolint:gocognit,cyclop
func extract(dst *domain.Client, src *http.Response) {
for _, endpoint := range httputil.ExtractEndpoints(src, relRedirectURI) {
if !containsURL(dst.RedirectURI, endpoint) {
func extract(r io.Reader, u *url.URL, dst *domain.Client, header string) {
body, _ := io.ReadAll(r)
for _, endpoint := range httputil