♻️ Refactored auth package

This commit is contained in:
Maxim Lebedev 2022-01-14 01:49:41 +05:00
parent 60da2ac25e
commit 83dc4286eb
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
7 changed files with 477 additions and 548 deletions

View File

@ -1,348 +1,359 @@
package http
import (
"net/url"
"time"
"fmt"
"path"
"strings"
"github.com/fasthttp/router"
json "github.com/goccy/go-json"
http "github.com/valyala/fasthttp"
"gitlab.com/toby3d/indieauth/internal/auth"
"gitlab.com/toby3d/indieauth/internal/domain"
"gitlab.com/toby3d/indieauth/internal/middleware"
"gitlab.com/toby3d/indieauth/internal/pkce"
"gitlab.com/toby3d/indieauth/web"
"golang.org/x/text/language"
"golang.org/x/text/message"
"golang.org/x/xerrors"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
"source.toby3d.me/website/indieauth/internal/auth"
"source.toby3d.me/website/indieauth/internal/client"
"source.toby3d.me/website/indieauth/internal/common"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/web"
)
type (
Handler struct {
useCase auth.UseCase
}
AuthorizeRequest struct {
RedirectURI string
ResponseType string
ClientID string
State []byte
Scope string
CodeChallenge string
CodeChallengeMethod string
Me string
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
// The client URL.
ClientID *domain.ClientID `form:"client_id"`
// The redirect URL indicating where the user should be
// redirected to after approving the request.
RedirectURI *domain.URL `form:"redirect_uri"`
// A parameter set by the client which will be included when the
// user is redirected back to the client. This is used to
// prevent CSRF attacks. The authorization server MUST return
// the unmodified state value back to the client.
State string `form:"state"`
// The code challenge as previously described.
CodeChallenge string `form:"code_challenge"`
// The hashing method used to calculate the code challenge.
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
// A space-separated list of scopes the client is requesting,
// e.g. "profile", or "profile create". If the client omits this
// value, the authorization server MUST NOT issue an access
// token for this authorization code. Only the user's profile
// URL may be returned without any scope requested.
Scope domain.Scopes `form:"scope"`
// The URL that the user entered.
Me *domain.Me `form:"me"`
}
RedirectRequest struct {
Authorize string
ClientID string
CodeChallenge string
CodeChallengeMethod string
Me string
RedirectURI string
ResponseType string
Scope string
State []byte
VerifyRequest struct {
ClientID *domain.ClientID `form:"client_id"`
Me *domain.Me `form:"me"`
RedirectURI *domain.URL `form:"redirect_uri"`
CodeChallengeMethod domain.CodeChallengeMethod `form:"code_challenge_method"`
ResponseType domain.ResponseType `form:"response_type"`
Scope domain.Scopes `form:"scope[]"` // TODO(toby3d): fix parsing in form pkg
Authorize string `form:"authorize"`
CodeChallenge string `form:"code_challenge"`
State string `form:"state"`
}
ExchangeRequest struct {
GrantType string
Code string
ClientID string
RedirectURI string
CodeVerifier string
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The authorization code received from the authorization
// endpoint in the redirect.
Code string `form:"code"`
// The client's URL, which MUST match the client_id used in the
// authentication request.
ClientID *domain.ClientID `form:"client_id"`
// The client's redirect URL, which MUST match the initial
// authentication request.
RedirectURI *domain.URL `form:"redirect_uri"`
// The original plaintext random string generated before
// starting the authorization request.
CodeVerifier string `form:"code_verifier"`
}
ExchangeResponse struct {
Me string `json:"me"`
Me *domain.Me `json:"me"`
}
NewRequestHandlerOptions struct {
Auth auth.UseCase
Clients client.UseCase
Config *domain.Config
Matcher language.Matcher
}
RequestHandler struct {
clients client.UseCase
config *domain.Config
matcher language.Matcher
useCase auth.UseCase
}
)
func NewAuthHandler(useCase auth.UseCase) *Handler {
return &Handler{
useCase: useCase,
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{
clients: opts.Clients,
config: opts.Config,
matcher: opts.Matcher,
useCase: opts.Auth,
}
}
func (h *Handler) Register(r *router.Router) {
chain := middleware.Chain{middleware.CSRFWithConfig(middleware.CSRFConfig{
ContextKey: "csrf",
CookieHTTPOnly: true,
CookieName: "__Host-CSRF",
CookiePath: "/",
CookieSameSite: http.CookieSameSiteLaxMode,
CookieSecure: true,
TokenLookup: "form:_csrf",
Skipper: func(ctx *http.RequestCtx) bool {
return ctx.IsPost() && ctx.PostArgs().Has("grant_type") &&
string(ctx.PostArgs().Peek("grant_type")) == "authorization_code"
},
})}
func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/api/*", string(ctx.Path()))
r.GET("/authorize", chain.RequestHandler(h.ClientInfo))
r.POST("/authorize", chain.RequestHandler(h.Update))
return ctx.IsPost() && matched
},
CookieSameSite: http.CookieSameSiteLaxMode,
CookieName: "_csrf",
TokenLookup: "form:_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
}),
middleware.LogFmt(),
}
r.GET("/authorize", chain.RequestHandler(h.handleRender))
r.POST("/api/authorize", chain.RequestHandler(h.handleVerify))
r.POST("/authorize", chain.RequestHandler(h.handleExchange))
}
func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
if r.ClientID = string(ctx.QueryArgs().Peek("client_id")); r.ClientID == "" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'client_id' query is required",
}
}
if r.ResponseType = string(ctx.QueryArgs().Peek("response_type")); r.ResponseType != "code" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'response_type' must be 'code'",
}
}
if ctx.QueryArgs().Has("code_challenge") {
r.CodeChallenge = string(ctx.QueryArgs().Peek("code_challenge"))
if len(r.CodeChallenge) < 43 || len(r.CodeChallenge) > 128 {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "length of the 'code_challenge' value must be greater than 43 and less than 128 symbols",
}
}
r.CodeChallengeMethod = pkce.DefaultMethod
if ctx.PostArgs().Has("code_challenge_method") {
r.CodeChallengeMethod = string(ctx.QueryArgs().Peek("code_challenge_method"))
}
if _, err := pkce.New(r.CodeChallengeMethod); err != nil {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: err.Error(),
}
}
}
r.RedirectURI = string(ctx.QueryArgs().Peek("redirect_uri"))
r.State = ctx.QueryArgs().Peek("state")
r.Scope = string(ctx.QueryArgs().Peek("scope"))
r.Me = string(ctx.QueryArgs().Peek("me"))
return nil
}
func (h *Handler) ClientInfo(ctx *http.RequestCtx) {
r := new(AuthorizeRequest)
r.Scope = "profile"
if err := r.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
client, err := h.useCase.Discovery(ctx, r.ClientID)
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
csrf, _ := ctx.UserValue("csrf").([]byte)
ctx.SetContentType("text/html")
web.WritePageTemplate(ctx, &web.AuthPage{
Client: client,
CodeChallenge: r.CodeChallenge,
CodeChallengeMethod: r.CodeChallengeMethod,
CSRF: csrf,
Me: r.Me,
RedirectURI: r.RedirectURI,
ResponseType: r.ResponseType,
Scope: r.Scope,
State: r.State,
})
}
func (h *Handler) Update(ctx *http.RequestCtx) {
if ctx.PostArgs().Has("response_type") && string(ctx.PostArgs().Peek("response_type")) == "code" {
h.Redirect(ctx)
return
}
if ctx.PostArgs().Has("grant_type") && string(ctx.PostArgs().Peek("grant_type")) == "authorization_code" {
h.Exchange(ctx)
return
}
ctx.Error("please, restart your authoriztion flow", http.StatusBadRequest)
}
func (r *RedirectRequest) bind(ctx *http.RequestCtx) (err error) {
r.RedirectURI = string(ctx.PostArgs().Peek("redirect_uri"))
if r.ClientID = string(ctx.PostArgs().Peek("client_id")); r.ClientID == "" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'client_id' query is required",
}
}
if r.Authorize = string(ctx.PostArgs().Peek("authorize")); r.Authorize != "allow" && r.Authorize != "deny" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "invalid prompt action, try starting the authorization flow again",
}
}
if r.ResponseType = string(ctx.PostArgs().Peek("response_type")); r.ResponseType != "code" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'response_type' must be 'code', try starting the authorization flow again",
}
}
if ctx.PostArgs().Has("code_challenge") {
r.CodeChallenge = string(ctx.PostArgs().Peek("code_challenge"))
if len(r.CodeChallenge) < 43 || len(r.CodeChallenge) > 128 {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "length of the 'code_challenge' value must be greater than 43 and less than 128 symbols, try starting the authorization flow again",
}
}
r.CodeChallengeMethod = pkce.DefaultMethod
if ctx.PostArgs().Has("code_challenge_method") {
r.CodeChallengeMethod = string(ctx.PostArgs().Peek("code_challenge_method"))
}
_, err := pkce.New(r.CodeChallengeMethod)
if err != nil {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: err.Error(),
}
}
}
r.State = ctx.PostArgs().Peek("state")
r.Scope = string(ctx.PostArgs().Peek("scope"))
r.Me = string(ctx.PostArgs().Peek("me"))
return nil
}
func (h *Handler) Redirect(ctx *http.RequestCtx) {
r := new(RedirectRequest)
if err := r.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
redirectUri, err := url.Parse(r.RedirectURI)
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
query := redirectUri.Query()
query.Set("state", string(r.State))
switch r.Authorize {
case "allow":
code, err := h.useCase.Approve(ctx, &domain.Login{
CreatedAt: time.Now().UTC().Unix(),
ClientID: r.ClientID,
CodeChallenge: r.CodeChallenge,
CodeChallengeMethod: r.CodeChallengeMethod,
Me: r.Me,
RedirectURI: r.RedirectURI,
Scope: r.Scope,
})
if err != nil {
query.Set("error", domain.ErrServerError.Code)
query.Set("error_description", err.Error())
redirectUri.RawQuery = query.Encode()
ctx.Redirect(redirectUri.String(), http.StatusFound)
return
}
query.Set("code", code)
case "deny":
query.Set("error", domain.ErrAccessDenied.Code)
}
redirectUri.RawQuery = query.Encode()
ctx.Redirect(redirectUri.String(), http.StatusFound)
}
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
if r.GrantType = string(ctx.PostArgs().Peek("grant_type")); r.GrantType != "authorization_code" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'grant_type' must be 'authorization_code'",
}
}
if r.RedirectURI = string(ctx.PostArgs().Peek("redirect_uri")); r.RedirectURI == "" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'redirect_uri' query is required",
}
}
if r.ClientID = string(ctx.PostArgs().Peek("client_id")); r.ClientID == "" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'client_id' query is required",
}
}
if r.Code = string(ctx.PostArgs().Peek("code")); r.Code == "" {
return domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'code' query is required",
}
}
r.CodeVerifier = string(ctx.PostArgs().Peek("code_verifier"))
return nil
}
func (h *Handler) Exchange(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
req := new(AuthorizeRequest)
if err := req.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
me, err := h.useCase.Exchange(ctx, &domain.ExchangeRequest{
ClientID: req.ClientID,
Code: req.Code,
CodeVerifier: req.CodeVerifier,
RedirectURI: req.RedirectURI,
})
client, err := h.clients.Discovery(ctx, req.ClientID)
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
if me == "" {
ctx.Error(domain.ErrUnauthorizedClient.Error(), http.StatusUnauthorized)
if !client.ValidateRedirectURI(req.RedirectURI) {
ctx.Error("requested redirect_uri is not registered on client_id side", http.StatusBadRequest)
return
}
ctx.SetContentType("application/json")
_ = encoder.Encode(&ExchangeResponse{
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
tag, _, _ := h.matcher.Match(tags...)
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.AuthorizePage{
BaseOf: web.BaseOf{
Config: h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
Client: client,
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
CSRF: csrf,
Me: req.Me,
RedirectURI: req.RedirectURI,
ResponseType: req.ResponseType,
Scope: req.Scope,
State: req.State,
})
}
func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx)
req := new(VerifyRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
u := http.AcquireURI()
defer http.ReleaseURI(u)
req.RedirectURI.CopyTo(u)
if strings.EqualFold(req.Authorize, "deny") {
u.QueryArgs().Set("error", "access_denied")
u.QueryArgs().Set("error_description", "user deny authorization request")
ctx.Redirect(u.String(), http.StatusFound)
return
}
code, err := h.useCase.Generate(ctx, auth.GenerateOptions{
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
Scope: req.Scope,
Me: req.Me,
})
if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.Error{
Description: err.Error(),
Frame: xerrors.Caller(1),
})
return
}
for key, val := range map[string]string{
"code": code,
"iss": h.config.Server.GetRootURL(),
"state": req.State,
} {
u.QueryArgs().Set(key, val)
}
ctx.Redirect(u.String(), http.StatusFound)
}
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
return
}
me, err := h.useCase.Exchange(ctx, auth.ExchangeOptions{
Code: req.Code,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
CodeVerifier: req.CodeVerifier,
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
return
}
encoder.Encode(&ExchangeResponse{
Me: me,
})
}
func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}
r.Scope = make(domain.Scopes, 0)
parseScope(r.Scope, ctx.QueryArgs().Peek("scope"))
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
return nil
}
func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}
r.Scope = make(domain.Scopes, 0)
parseScope(r.Scope, ctx.PostArgs().PeekMulti("scope[]")...)
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
if !strings.EqualFold(r.Authorize, "allow") && !strings.EqualFold(r.Authorize, "deny") {
return domain.Error{
Code: "invalid_request",
Description: "cannot validate verification request",
Frame: xerrors.Caller(1),
}
}
return nil
}
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
return domain.Error{
Code: "invalid_request",
Description: err.Error(),
Frame: xerrors.Caller(1),
}
}
return nil
}
// TODO(toby3d): fix this in form pkg.
func parseScope(dst domain.Scopes, src ...[]byte) error {
if len(src) == 0 {
return nil
}
var scopes []string
if len(src) == 1 {
scopes = strings.Fields(string(src[0]))
}
for _, rawScope := range scopes {
scope, err := domain.ParseScope(string(rawScope))
if err != nil {
return &domain.Error{
Code: "invalid_request",
Description: fmt.Sprintf("cannot parse scope: %v", err),
Frame: xerrors.Caller(1),
}
}
dst = append(dst, scope)
}
return nil
}

View File

@ -0,0 +1,81 @@
package http_test
import (
"path"
"sync"
"testing"
"github.com/fasthttp/router"
"github.com/fasthttp/session/v2"
"github.com/fasthttp/session/v2/providers/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
delivery "source.toby3d.me/website/indieauth/internal/auth/delivery/http"
ucase "source.toby3d.me/website/indieauth/internal/auth/usecase"
clientrepo "source.toby3d.me/website/indieauth/internal/client/repository/memory"
clientucase "source.toby3d.me/website/indieauth/internal/client/usecase"
"source.toby3d.me/website/indieauth/internal/domain"
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
"source.toby3d.me/website/indieauth/internal/testing/httptest"
userrepo "source.toby3d.me/website/indieauth/internal/user/repository/memory"
)
func TestRender(t *testing.T) {
t.Parallel()
provider, err := memory.New(memory.Config{})
require.NoError(t, err)
s := session.New(session.NewDefaultConfig())
require.NoError(t, s.SetProvider(provider))
me := domain.TestMe(t)
c := domain.TestClient(t)
config := domain.TestConfig(t)
store := new(sync.Map)
store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), domain.TestUser(t))
store.Store(path.Join(clientrepo.DefaultPathPrefix, c.ID.String()), c)
r := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)),
Config: config,
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
Auth: ucase.NewAuthUseCase(sessionrepo.NewMemorySessionRepository(config, store), config),
}).Register(r)
client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup)
u := http.AcquireURI()
defer http.ReleaseURI(u)
u.Update("https://example.com/authorize")
for k, v := range map[string]string{
"client_id": c.ID.String(),
"code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
"code_challenge_method": domain.CodeChallengeMethodS256.String(),
"me": me.String(),
"redirect_uri": c.RedirectURI[0].String(),
"response_type": domain.ResponseTypeCode.String(),
"scope": "profile email",
"state": "1234567890",
} {
u.QueryArgs().Set(k, v)
}
req := httptest.NewRequest(http.MethodGet, u.String(), nil)
defer http.ReleaseRequest(req)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
require.NoError(t, client.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode())
assert.Contains(t, string(resp.Body()), `Authorize application`)
}

View File

@ -1,13 +0,0 @@
package auth
import (
"context"
"gitlab.com/toby3d/indieauth/internal/domain"
)
type Repository interface {
Create(ctx context.Context, login *domain.Login) error
Get(ctx context.Context, code string) (*domain.Login, error)
Delete(ctx context.Context, code string) error
}

View File

@ -1,57 +0,0 @@
package bolt
import (
"context"
json "github.com/goccy/go-json"
"gitlab.com/toby3d/indieauth/internal/auth"
"gitlab.com/toby3d/indieauth/internal/domain"
bolt "go.etcd.io/bbolt"
)
type boltAuthRepository struct {
db *bolt.DB
}
func NewBoltAuthRepository(db *bolt.DB) (auth.Repository, error) {
if err := db.Update(func(tx *bolt.Tx) (err error) {
_, err = tx.CreateBucketIfNotExists(domain.Login{}.Bucket())
return err
}); err != nil {
return nil, err
}
return &boltAuthRepository{
db: db,
}, nil
}
func (repo *boltAuthRepository) Create(ctx context.Context, login *domain.Login) error {
jsonLogin, err := json.Marshal(login)
if err != nil {
return err
}
return repo.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(domain.Login{}.Bucket()).Put([]byte(login.Code), jsonLogin)
})
}
func (repo *boltAuthRepository) Get(ctx context.Context, code string) (*domain.Login, error) {
login := new(domain.Login)
if err := repo.db.View(func(tx *bolt.Tx) error {
return json.Unmarshal(tx.Bucket(domain.Login{}.Bucket()).Get([]byte(code)), login)
}); err != nil {
return nil, err
}
return login, nil
}
func (repo *boltAuthRepository) Delete(ctx context.Context, code string) error {
return repo.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(domain.Login{}.Bucket()).Delete([]byte(code))
})
}

View File

@ -1,40 +0,0 @@
package memory
import (
"context"
"sync"
"gitlab.com/toby3d/indieauth/internal/auth"
"gitlab.com/toby3d/indieauth/internal/domain"
)
type memoryAuthRepository struct {
logins *sync.Map
}
func NewMemoryAuthRepository() auth.Repository {
return &memoryAuthRepository{
logins: new(sync.Map),
}
}
func (repo *memoryAuthRepository) Create(ctx context.Context, login *domain.Login) error {
repo.logins.Store(login.Code, login)
return nil
}
func (repo *memoryAuthRepository) Get(ctx context.Context, code string) (*domain.Login, error) {
login, ok := repo.logins.LoadAndDelete(code)
if !ok {
return nil, nil
}
return login.(*domain.Login), nil
}
func (repo *memoryAuthRepository) Delete(ctx context.Context, code string) error {
repo.logins.Delete(code)
return nil
}

View File

@ -0,0 +1,84 @@
package usecase
import (
"context"
"fmt"
"golang.org/x/xerrors"
"source.toby3d.me/website/indieauth/internal/auth"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/random"
"source.toby3d.me/website/indieauth/internal/session"
)
type authUseCase struct {
config *domain.Config
sessions session.Repository
}
func NewAuthUseCase(sessions session.Repository, config *domain.Config) auth.UseCase {
return &authUseCase{
config: config,
sessions: sessions,
}
}
func (useCase *authUseCase) Generate(ctx context.Context, opts auth.GenerateOptions) (string, error) {
code, err := random.String(useCase.config.Code.Length)
if err != nil {
return "", fmt.Errorf("cannot generate random code: %w", err)
}
if err = useCase.sessions.Create(ctx, &domain.Session{
ClientID: opts.ClientID,
Code: code,
CodeChallenge: opts.CodeChallenge,
CodeChallengeMethod: opts.CodeChallengeMethod,
Me: opts.Me,
RedirectURI: opts.RedirectURI,
Scope: opts.Scope,
}); err != nil {
return "", fmt.Errorf("cannot save session in store: %w", err)
}
return code, nil
}
func (useCase *authUseCase) Exchange(ctx context.Context, opts auth.ExchangeOptions) (*domain.Me, error) {
session, err := useCase.sessions.GetAndDelete(ctx, opts.Code)
if err != nil {
return nil, err
}
if opts.ClientID.String() != session.ClientID.String() {
return nil, domain.Error{
Code: "invalid_request",
Description: "client's URL MUST match the client_id used in the authentication request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
if opts.RedirectURI.String() != session.RedirectURI.String() {
return nil, domain.Error{
Code: "invalid_request",
Description: "client's redirect URL MUST match the initial authentication request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
if session.CodeChallenge != "" &&
!session.CodeChallengeMethod.Validate(session.CodeChallenge, opts.CodeVerifier) {
return nil, domain.Error{
Code: "invalid_request",
Description: "code_verifier is not hashes to the same value as given in " +
"the code_challenge in the original authorization request",
URI: "https://indieauth.net/source/#request",
Frame: xerrors.Caller(1),
}
}
return session.Me, nil
}

View File

@ -1,137 +0,0 @@
package usecase
import (
"bytes"
"context"
"net/url"
"time"
http "github.com/valyala/fasthttp"
"gitlab.com/toby3d/indieauth/internal/auth"
"gitlab.com/toby3d/indieauth/internal/domain"
"gitlab.com/toby3d/indieauth/internal/pkce"
"gitlab.com/toby3d/indieauth/internal/random"
"willnorris.com/go/microformats"
)
type authUseCase struct {
client *http.Client
repo auth.Repository
}
func NewAuthUseCase(repo auth.Repository) auth.UseCase {
return &authUseCase{
client: new(http.Client),
repo: repo,
}
}
func (useCase *authUseCase) Discovery(ctx context.Context, clientId string) (*domain.Client, error) {
_, src, err := useCase.client.Get(nil, clientId)
if err != nil {
return nil, err
}
cid, err := url.Parse(clientId)
if err != nil {
return nil, err
}
data := microformats.Parse(bytes.NewReader(src), cid)
client := new(domain.Client)
client.RedirectURI = make([]string, 0)
for i := range data.Items {
if len(data.Items[i].Type) == 0 || data.Items[i].Type[0] != "h-app" {
continue
}
for key, values := range data.Items[i].Properties {
switch key {
case "logo":
for j := range values {
switch val := values[j].(type) {
case string:
client.Logo = val
case map[string]string:
client.Logo = val["value"]
}
}
case "name":
for j := range values {
client.Name, _ = values[j].(string)
}
case "url":
for j := range values {
client.URL, _ = values[j].(string)
}
}
}
}
for key, values := range data.Rels {
if key != "redirect_uri" {
continue
}
client.RedirectURI = append(client.RedirectURI, values...)
}
if client.URL != clientId {
return nil, domain.Error{
Code: domain.ErrInvalidRequest.Code,
Description: "'client_id' does not match the actual client URL",
}
}
return client, nil
}
func (useCase *authUseCase) Approve(ctx context.Context, login *domain.Login) (string, error) {
login.Code = random.New().String(32)
if err := useCase.repo.Create(ctx, login); err != nil {
return "", err
}
return login.Code, nil
}
func (useCase *authUseCase) Exchange(ctx context.Context, req *domain.ExchangeRequest) (string, error) {
login, err := useCase.repo.Get(ctx, req.Code)
if err != nil {
return "", err
}
if login == nil {
return "", nil
}
_ = useCase.repo.Delete(ctx, req.Code)
if time.Now().UTC().After(time.Unix(login.CreatedAt, 0).Add(10 * time.Minute)) {
return "", nil
}
if login.ClientID != req.ClientID || login.RedirectURI != req.RedirectURI {
return "", domain.ErrInvalidRequest
}
if login.CodeChallenge != "" {
codeChallenge, err := pkce.New(login.CodeChallengeMethod)
if err != nil {
return "", err
}
codeChallenge.Verifier = req.CodeVerifier
codeChallenge.Generate()
if login.CodeChallenge != codeChallenge.Challenge {
return "", domain.ErrInvalidRequest
}
}
return login.Me, nil
}