♻️ Refactored client support and handler rendering

This commit is contained in:
Maxim Lebedev 2022-01-14 01:49:08 +05:00
parent d67279438a
commit 60da2ac25e
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
7 changed files with 257 additions and 180 deletions

View File

@ -3,11 +3,28 @@ package auth
import ( import (
"context" "context"
"gitlab.com/toby3d/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
) )
type UseCase interface { type (
Discovery(ctx context.Context, clientId string) (*domain.Client, error) GenerateOptions struct {
Approve(ctx context.Context, login *domain.Login) (string, error) ClientID *domain.ClientID
Exchange(ctx context.Context, req *domain.ExchangeRequest) (string, error) RedirectURI *domain.URL
} CodeChallenge string
CodeChallengeMethod domain.CodeChallengeMethod
Scope domain.Scopes
Me *domain.Me
}
ExchangeOptions struct {
Code string
ClientID *domain.ClientID
RedirectURI *domain.URL
CodeVerifier string
}
UseCase interface {
Generate(ctx context.Context, opts GenerateOptions) (string, error)
Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Me, error)
}
)

View File

@ -0,0 +1,135 @@
package http
import (
"fmt"
"strings"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/toby3d/form"
"source.toby3d.me/toby3d/middleware"
"source.toby3d.me/website/indieauth/internal/common"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/token"
"source.toby3d.me/website/indieauth/web"
)
type (
CallbackRequest struct {
Code string `form:"code"`
State string `form:"state"`
Iss *domain.ClientID `form:"iss"`
}
NewRequestHandlerOptions struct {
Client *domain.Client
Config *domain.Config
Matcher language.Matcher
Tokens token.UseCase
}
RequestHandler struct {
client *domain.Client
config *domain.Config
matcher language.Matcher
tokens token.UseCase
}
)
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{
client: opts.Client,
config: opts.Config,
matcher: opts.Matcher,
tokens: opts.Tokens,
}
}
func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{
middleware.LogFmt(),
}
r.GET("/", chain.RequestHandler(h.handleRender))
r.GET("/callback", chain.RequestHandler(h.handleCallback))
}
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
redirectUri := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI {
redirectUri[i] = h.client.RedirectURI[i].String()
}
ctx.Response.Header.Set(
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
tag, _, _ := h.matcher.Match(tags...)
// TODO(toby3d): generate and store PKCE
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.HomePage{
BaseOf: web.BaseOf{
Config: h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
Client: h.client,
State: "hackme", // TODO(toby3d): generate and store state
})
}
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
req := new(CallbackRequest)
if err := req.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
return
}
// TODO(toby3d): load and check state
if req.Iss.String() != h.client.ID.String() {
ctx.Error("iss is not equal", http.StatusBadRequest)
return
}
token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
ClientID: h.client.ID,
RedirectURI: h.client.RedirectURI[0],
Code: req.Code,
CodeVerifier: "", // TODO(toby3d): validate PKCE here
})
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
return
}
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
tag, _, _ := h.matcher.Match(tags...)
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
web.WriteTemplate(ctx, &web.CallbackPage{
BaseOf: web.BaseOf{
Config: h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
Token: token,
})
}
func (req *CallbackRequest) bind(ctx *http.RequestCtx) error {
if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil {
return fmt.Errorf("cannot unmarshal request: %w", err)
}
return nil
}

View File

@ -1,6 +1,7 @@
package http_test package http_test
import ( import (
"sync"
"testing" "testing"
"github.com/fasthttp/router" "github.com/fasthttp/router"
@ -12,16 +13,26 @@ import (
delivery "source.toby3d.me/website/indieauth/internal/client/delivery/http" delivery "source.toby3d.me/website/indieauth/internal/client/delivery/http"
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
"source.toby3d.me/website/indieauth/internal/testing/httptest" "source.toby3d.me/website/indieauth/internal/testing/httptest"
tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory"
tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase"
) )
func TestRead(t *testing.T) { func TestRead(t *testing.T) {
t.Parallel() t.Parallel()
store := new(sync.Map)
config := domain.TestConfig(t)
r := router.New() r := router.New()
delivery.NewRequestHandler( delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
domain.TestConfig(t), domain.TestClient(t), language.NewMatcher(message.DefaultCatalog.Languages()), Client: domain.TestClient(t),
).Register(r) Config: config,
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
sessionrepo.NewMemorySessionRepository(config, store), config),
}).Register(r)
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, r.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)

View File

@ -1,71 +0,0 @@
package http
import (
"encoding/base64"
"strings"
"github.com/fasthttp/router"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
"source.toby3d.me/website/indieauth/internal/common"
"source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/random"
"source.toby3d.me/website/indieauth/web"
)
type RequestHandler struct {
client *domain.Client
config *domain.Config
matcher language.Matcher
}
const DefaultStateLength int = 64
func NewRequestHandler(config *domain.Config, client *domain.Client, matcher language.Matcher) *RequestHandler {
return &RequestHandler{
client: client,
config: config,
matcher: matcher,
}
}
func (h *RequestHandler) Register(r *router.Router) {
r.GET("/", h.read)
}
func (h *RequestHandler) read(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
// TODO(toby3d): save state for checking it at the end of flow?
state, err := random.Bytes(DefaultStateLength)
if err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
return
}
redirectUri := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI {
redirectUri[i] = h.client.RedirectURI[i].String()
}
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
tag, _, _ := h.matcher.Match(tags...)
ctx.Response.Header.Set(
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
)
web.WriteTemplate(ctx, &web.HomePage{
BaseOf: web.BaseOf{
Config: h.config,
Language: tag,
Printer: message.NewPrinter(tag),
},
RedirectURI: h.client.RedirectURI,
AuthEndpoint: "/authorize",
Client: h.client,
State: base64.RawURLEncoding.EncodeToString(state),
})
}

View File

@ -1,24 +1,22 @@
package http package http
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net/url"
"strings"
"github.com/tomnomnom/linkheader"
http "github.com/valyala/fasthttp" http "github.com/valyala/fasthttp"
"willnorris.com/go/microformats"
"source.toby3d.me/website/indieauth/internal/client" "source.toby3d.me/website/indieauth/internal/client"
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
"source.toby3d.me/website/indieauth/internal/util"
) )
type httpClientRepository struct { type httpClientRepository struct {
client *http.Client client *http.Client
} }
const DefaultMaxRedirectsCount int = 10
const ( const (
relRedirectURI string = "redirect_uri" relRedirectURI string = "redirect_uri"
@ -45,7 +43,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
resp := http.AcquireResponse() resp := http.AcquireResponse()
defer http.ReleaseResponse(resp) defer http.ReleaseResponse(resp)
if err := repo.client.Do(req, resp); err != nil { if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
return nil, fmt.Errorf("failed to make a request to the client: %w", err) return nil, fmt.Errorf("failed to make a request to the client: %w", err)
} }
@ -55,115 +53,103 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
client := &domain.Client{ client := &domain.Client{
ID: id, ID: id,
RedirectURI: make([]*domain.URL, 0),
Logo: make([]*domain.URL, 0), Logo: make([]*domain.URL, 0),
Name: extractValues(resp, propertyName),
RedirectURI: extractEndpoints(resp, relRedirectURI),
URL: make([]*domain.URL, 0), URL: make([]*domain.URL, 0),
Name: make([]string, 0),
} }
for _, v := range extractValues(resp, propertyLogo) { extract(client, resp)
u, err := domain.NewURL(v)
if err != nil {
continue
}
client.Logo = append(client.Logo, u)
}
for _, v := range extractValues(resp, propertyURL) {
u, err := domain.NewURL(v)
if err != nil {
continue
}
client.URL = append(client.URL, u)
}
return client, nil return client, nil
} }
func extractEndpoints(resp *http.Response, name string) []*domain.URL { func extract(dst *domain.Client, src *http.Response) {
results := make([]*domain.URL, 0) for _, u := range util.ExtractEndpoints(src, relRedirectURI) {
endpoints, _ := extractEndpointsFromHeader(resp, name) if containsURL(dst.RedirectURI, u) {
results = append(results, endpoints...)
endpoints, _ = extractEndpointsFromBody(resp, name)
results = append(results, endpoints...)
return results
}
func extractValues(resp *http.Response, key string) []string {
results := make([]string, 0)
for _, item := range microformats.Parse(bytes.NewReader(resp.Body()), nil).Items {
if len(item.Type) == 0 || (item.Type[0] != hApp && item.Type[0] != hXApp) {
continue continue
} }
properties, ok := item.Properties[key] dst.RedirectURI = append(dst.RedirectURI, u)
if !ok || len(properties) == 0 { }
return nil
}
for j := range properties { for _, t := range []string{hXApp, hApp} {
switch p := properties[j].(type) { for _, name := range util.ExtractProperty(src, t, propertyName) {
case string: n, ok := name.(string)
results = append(results, p) if !ok || containsString(dst.Name, n) {
case map[string][]interface{}: continue
for _, val := range p["value"] {
v, ok := val.(string)
if !ok {
continue
}
results = append(results, v)
}
} }
dst.Name = append(dst.Name, n)
} }
return results for _, logo := range util.ExtractProperty(src, t, propertyLogo) {
} var err error
return nil var u *domain.URL
switch l := logo.(type) {
case string:
u, err = domain.NewURL(l)
case map[string]string:
value, ok := l["value"]
if !ok {
continue
}
u, err = domain.NewURL(value)
}
if err != nil {
continue
}
if containsURL(dst.Logo, u) {
continue
}
dst.Logo = append(dst.Logo, u)
}
for _, url := range util.ExtractProperty(src, t, propertyURL) {
l, ok := url.(string)
if !ok {
continue
}
u, err := domain.NewURL(l)
if err != nil {
continue
}
if containsURL(dst.URL, u) {
continue
}
dst.URL = append(dst.URL, u)
}
}
} }
func extractEndpointsFromHeader(resp *http.Response, name string) ([]*domain.URL, error) { func containsString(src []string, find string) bool {
results := make([]*domain.URL, 0) for i := range src {
if src[i] != find {
for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
if !strings.EqualFold(link.Rel, name) {
continue continue
} }
u := http.AcquireURI() return true
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil { }
return nil, err
return false
}
func containsURL(src []*domain.URL, find *domain.URL) bool {
for i := range src {
if src[i].String() != find.String() {
continue
} }
results = append(results, &domain.URL{URI: u}) return true
} }
return results, nil return false
}
func extractEndpointsFromBody(resp *http.Response, name string) ([]*domain.URL, error) {
host, err := url.Parse(string(resp.Header.Peek(http.HeaderHost)))
if err != nil {
return nil, fmt.Errorf("cannot parse host header: %w", err)
}
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), host).Rels[name]
if !ok || len(endpoints) == 0 {
return nil, nil
}
results := make([]*domain.URL, 0)
for i := range endpoints {
u := http.AcquireURI()
u.Update(endpoints[i])
results = append(results, &domain.URL{URI: u})
}
return results, nil
} }

View File

@ -12,4 +12,4 @@ type UseCase interface {
Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
} }
var ErrInvalidMe = errors.New("provided me is invalid") var ErrInvalidMe = errors.New("invalid me")

View File

@ -8,16 +8,15 @@ import (
) )
type ( type (
GenerateOptions struct { ExchangeOptions struct {
ClientID string ClientID *domain.ClientID
Me string RedirectURI *domain.URL
Scopes []string Code string
NonceLength int CodeVerifier string
} }
UseCase interface { UseCase interface {
// Generate generates a new Token based on the session data. Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Token, error)
Generate(ctx context.Context, opts GenerateOptions) (*domain.Token, error)
// Verify checks the AccessToken and returns the associated information. // Verify checks the AccessToken and returns the associated information.
Verify(ctx context.Context, accessToken string) (*domain.Token, error) Verify(ctx context.Context, accessToken string) (*domain.Token, error)