♻️ Refactored client support and handler rendering
This commit is contained in:
parent
d67279438a
commit
60da2ac25e
|
@ -3,11 +3,28 @@ package auth
|
|||
import (
|
||||
"context"
|
||||
|
||||
"gitlab.com/toby3d/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
)
|
||||
|
||||
type UseCase interface {
|
||||
Discovery(ctx context.Context, clientId string) (*domain.Client, error)
|
||||
Approve(ctx context.Context, login *domain.Login) (string, error)
|
||||
Exchange(ctx context.Context, req *domain.ExchangeRequest) (string, error)
|
||||
}
|
||||
type (
|
||||
GenerateOptions struct {
|
||||
ClientID *domain.ClientID
|
||||
RedirectURI *domain.URL
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod domain.CodeChallengeMethod
|
||||
Scope domain.Scopes
|
||||
Me *domain.Me
|
||||
}
|
||||
|
||||
ExchangeOptions struct {
|
||||
Code string
|
||||
ClientID *domain.ClientID
|
||||
RedirectURI *domain.URL
|
||||
CodeVerifier string
|
||||
}
|
||||
|
||||
UseCase interface {
|
||||
Generate(ctx context.Context, opts GenerateOptions) (string, error)
|
||||
Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Me, error)
|
||||
}
|
||||
)
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
http "github.com/valyala/fasthttp"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
|
||||
"source.toby3d.me/toby3d/form"
|
||||
"source.toby3d.me/toby3d/middleware"
|
||||
"source.toby3d.me/website/indieauth/internal/common"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/token"
|
||||
"source.toby3d.me/website/indieauth/web"
|
||||
)
|
||||
|
||||
type (
|
||||
CallbackRequest struct {
|
||||
Code string `form:"code"`
|
||||
State string `form:"state"`
|
||||
Iss *domain.ClientID `form:"iss"`
|
||||
}
|
||||
|
||||
NewRequestHandlerOptions struct {
|
||||
Client *domain.Client
|
||||
Config *domain.Config
|
||||
Matcher language.Matcher
|
||||
Tokens token.UseCase
|
||||
}
|
||||
|
||||
RequestHandler struct {
|
||||
client *domain.Client
|
||||
config *domain.Config
|
||||
matcher language.Matcher
|
||||
tokens token.UseCase
|
||||
}
|
||||
)
|
||||
|
||||
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
|
||||
return &RequestHandler{
|
||||
client: opts.Client,
|
||||
config: opts.Config,
|
||||
matcher: opts.Matcher,
|
||||
tokens: opts.Tokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Register(r *router.Router) {
|
||||
chain := middleware.Chain{
|
||||
middleware.LogFmt(),
|
||||
}
|
||||
|
||||
r.GET("/", chain.RequestHandler(h.handleRender))
|
||||
r.GET("/callback", chain.RequestHandler(h.handleCallback))
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
||||
redirectUri := make([]string, len(h.client.RedirectURI))
|
||||
for i := range h.client.RedirectURI {
|
||||
redirectUri[i] = h.client.RedirectURI[i].String()
|
||||
}
|
||||
|
||||
ctx.Response.Header.Set(
|
||||
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
|
||||
)
|
||||
|
||||
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
|
||||
tag, _, _ := h.matcher.Match(tags...)
|
||||
|
||||
// TODO(toby3d): generate and store PKCE
|
||||
|
||||
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
|
||||
web.WriteTemplate(ctx, &web.HomePage{
|
||||
BaseOf: web.BaseOf{
|
||||
Config: h.config,
|
||||
Language: tag,
|
||||
Printer: message.NewPrinter(tag),
|
||||
},
|
||||
Client: h.client,
|
||||
State: "hackme", // TODO(toby3d): generate and store state
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
|
||||
req := new(CallbackRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(toby3d): load and check state
|
||||
|
||||
if req.Iss.String() != h.client.ID.String() {
|
||||
ctx.Error("iss is not equal", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{
|
||||
ClientID: h.client.ID,
|
||||
RedirectURI: h.client.RedirectURI[0],
|
||||
Code: req.Code,
|
||||
CodeVerifier: "", // TODO(toby3d): validate PKCE here
|
||||
})
|
||||
if err != nil {
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
|
||||
tag, _, _ := h.matcher.Match(tags...)
|
||||
|
||||
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
|
||||
web.WriteTemplate(ctx, &web.CallbackPage{
|
||||
BaseOf: web.BaseOf{
|
||||
Config: h.config,
|
||||
Language: tag,
|
||||
Printer: message.NewPrinter(tag),
|
||||
},
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
func (req *CallbackRequest) bind(ctx *http.RequestCtx) error {
|
||||
if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil {
|
||||
return fmt.Errorf("cannot unmarshal request: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package http_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
|
@ -12,16 +13,26 @@ import (
|
|||
|
||||
delivery "source.toby3d.me/website/indieauth/internal/client/delivery/http"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory"
|
||||
"source.toby3d.me/website/indieauth/internal/testing/httptest"
|
||||
tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory"
|
||||
tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase"
|
||||
)
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := new(sync.Map)
|
||||
config := domain.TestConfig(t)
|
||||
|
||||
r := router.New()
|
||||
delivery.NewRequestHandler(
|
||||
domain.TestConfig(t), domain.TestClient(t), language.NewMatcher(message.DefaultCatalog.Languages()),
|
||||
).Register(r)
|
||||
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
|
||||
Client: domain.TestClient(t),
|
||||
Config: config,
|
||||
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
|
||||
Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
|
||||
sessionrepo.NewMemorySessionRepository(config, store), config),
|
||||
}).Register(r)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
t.Cleanup(cleanup)
|
|
@ -1,71 +0,0 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
|
||||
"github.com/fasthttp/router"
|
||||
http "github.com/valyala/fasthttp"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/common"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/random"
|
||||
"source.toby3d.me/website/indieauth/web"
|
||||
)
|
||||
|
||||
type RequestHandler struct {
|
||||
client *domain.Client
|
||||
config *domain.Config
|
||||
matcher language.Matcher
|
||||
}
|
||||
|
||||
const DefaultStateLength int = 64
|
||||
|
||||
func NewRequestHandler(config *domain.Config, client *domain.Client, matcher language.Matcher) *RequestHandler {
|
||||
return &RequestHandler{
|
||||
client: client,
|
||||
config: config,
|
||||
matcher: matcher,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Register(r *router.Router) {
|
||||
r.GET("/", h.read)
|
||||
}
|
||||
|
||||
func (h *RequestHandler) read(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
|
||||
|
||||
// TODO(toby3d): save state for checking it at the end of flow?
|
||||
state, err := random.Bytes(DefaultStateLength)
|
||||
if err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
redirectUri := make([]string, len(h.client.RedirectURI))
|
||||
for i := range h.client.RedirectURI {
|
||||
redirectUri[i] = h.client.RedirectURI[i].String()
|
||||
}
|
||||
|
||||
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
|
||||
tag, _, _ := h.matcher.Match(tags...)
|
||||
|
||||
ctx.Response.Header.Set(
|
||||
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
|
||||
)
|
||||
web.WriteTemplate(ctx, &web.HomePage{
|
||||
BaseOf: web.BaseOf{
|
||||
Config: h.config,
|
||||
Language: tag,
|
||||
Printer: message.NewPrinter(tag),
|
||||
},
|
||||
RedirectURI: h.client.RedirectURI,
|
||||
AuthEndpoint: "/authorize",
|
||||
Client: h.client,
|
||||
State: base64.RawURLEncoding.EncodeToString(state),
|
||||
})
|
||||
}
|
|
@ -1,24 +1,22 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/tomnomnom/linkheader"
|
||||
http "github.com/valyala/fasthttp"
|
||||
"willnorris.com/go/microformats"
|
||||
|
||||
"source.toby3d.me/website/indieauth/internal/client"
|
||||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
"source.toby3d.me/website/indieauth/internal/util"
|
||||
)
|
||||
|
||||
type httpClientRepository struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
const DefaultMaxRedirectsCount int = 10
|
||||
|
||||
const (
|
||||
relRedirectURI string = "redirect_uri"
|
||||
|
||||
|
@ -45,7 +43,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
|
|||
resp := http.AcquireResponse()
|
||||
defer http.ReleaseResponse(resp)
|
||||
|
||||
if err := repo.client.Do(req, resp); err != nil {
|
||||
if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil {
|
||||
return nil, fmt.Errorf("failed to make a request to the client: %w", err)
|
||||
}
|
||||
|
||||
|
@ -55,115 +53,103 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
|
|||
|
||||
client := &domain.Client{
|
||||
ID: id,
|
||||
RedirectURI: make([]*domain.URL, 0),
|
||||
Logo: make([]*domain.URL, 0),
|
||||
Name: extractValues(resp, propertyName),
|
||||
RedirectURI: extractEndpoints(resp, relRedirectURI),
|
||||
URL: make([]*domain.URL, 0),
|
||||
Name: make([]string, 0),
|
||||
}
|
||||
|
||||
for _, v := range extractValues(resp, propertyLogo) {
|
||||
u, err := domain.NewURL(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
client.Logo = append(client.Logo, u)
|
||||
}
|
||||
|
||||
for _, v := range extractValues(resp, propertyURL) {
|
||||
u, err := domain.NewURL(v)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
client.URL = append(client.URL, u)
|
||||
}
|
||||
extract(client, resp)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func extractEndpoints(resp *http.Response, name string) []*domain.URL {
|
||||
results := make([]*domain.URL, 0)
|
||||
endpoints, _ := extractEndpointsFromHeader(resp, name)
|
||||
results = append(results, endpoints...)
|
||||
endpoints, _ = extractEndpointsFromBody(resp, name)
|
||||
results = append(results, endpoints...)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func extractValues(resp *http.Response, key string) []string {
|
||||
results := make([]string, 0)
|
||||
|
||||
for _, item := range microformats.Parse(bytes.NewReader(resp.Body()), nil).Items {
|
||||
if len(item.Type) == 0 || (item.Type[0] != hApp && item.Type[0] != hXApp) {
|
||||
func extract(dst *domain.Client, src *http.Response) {
|
||||
for _, u := range util.ExtractEndpoints(src, relRedirectURI) {
|
||||
if containsURL(dst.RedirectURI, u) {
|
||||
continue
|
||||
}
|
||||
|
||||
properties, ok := item.Properties[key]
|
||||
if !ok || len(properties) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst.RedirectURI = append(dst.RedirectURI, u)
|
||||
}
|
||||
|
||||
for j := range properties {
|
||||
switch p := properties[j].(type) {
|
||||
case string:
|
||||
results = append(results, p)
|
||||
case map[string][]interface{}:
|
||||
for _, val := range p["value"] {
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
results = append(results, v)
|
||||
}
|
||||
for _, t := range []string{hXApp, hApp} {
|
||||
for _, name := range util.ExtractProperty(src, t, propertyName) {
|
||||
n, ok := name.(string)
|
||||
if !ok || containsString(dst.Name, n) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Name = append(dst.Name, n)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
for _, logo := range util.ExtractProperty(src, t, propertyLogo) {
|
||||
var err error
|
||||
|
||||
return nil
|
||||
var u *domain.URL
|
||||
switch l := logo.(type) {
|
||||
case string:
|
||||
u, err = domain.NewURL(l)
|
||||
case map[string]string:
|
||||
value, ok := l["value"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, err = domain.NewURL(value)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if containsURL(dst.Logo, u) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Logo = append(dst.Logo, u)
|
||||
}
|
||||
|
||||
for _, url := range util.ExtractProperty(src, t, propertyURL) {
|
||||
l, ok := url.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, err := domain.NewURL(l)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if containsURL(dst.URL, u) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.URL = append(dst.URL, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractEndpointsFromHeader(resp *http.Response, name string) ([]*domain.URL, error) {
|
||||
results := make([]*domain.URL, 0)
|
||||
|
||||
for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) {
|
||||
if !strings.EqualFold(link.Rel, name) {
|
||||
func containsString(src []string, find string) bool {
|
||||
for i := range src {
|
||||
if src[i] != find {
|
||||
continue
|
||||
}
|
||||
|
||||
u := http.AcquireURI()
|
||||
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
|
||||
return nil, err
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func containsURL(src []*domain.URL, find *domain.URL) bool {
|
||||
for i := range src {
|
||||
if src[i].String() != find.String() {
|
||||
continue
|
||||
}
|
||||
|
||||
results = append(results, &domain.URL{URI: u})
|
||||
return true
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func extractEndpointsFromBody(resp *http.Response, name string) ([]*domain.URL, error) {
|
||||
host, err := url.Parse(string(resp.Header.Peek(http.HeaderHost)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse host header: %w", err)
|
||||
}
|
||||
|
||||
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), host).Rels[name]
|
||||
if !ok || len(endpoints) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
results := make([]*domain.URL, 0)
|
||||
for i := range endpoints {
|
||||
u := http.AcquireURI()
|
||||
u.Update(endpoints[i])
|
||||
|
||||
results = append(results, &domain.URL{URI: u})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -12,4 +12,4 @@ type UseCase interface {
|
|||
Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
|
||||
}
|
||||
|
||||
var ErrInvalidMe = errors.New("provided me is invalid")
|
||||
var ErrInvalidMe = errors.New("invalid me")
|
||||
|
|
|
@ -8,16 +8,15 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
GenerateOptions struct {
|
||||
ClientID string
|
||||
Me string
|
||||
Scopes []string
|
||||
NonceLength int
|
||||
ExchangeOptions struct {
|
||||
ClientID *domain.ClientID
|
||||
RedirectURI *domain.URL
|
||||
Code string
|
||||
CodeVerifier string
|
||||
}
|
||||
|
||||
UseCase interface {
|
||||
// Generate generates a new Token based on the session data.
|
||||
Generate(ctx context.Context, opts GenerateOptions) (*domain.Token, error)
|
||||
Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Token, error)
|
||||
|
||||
// Verify checks the AccessToken and returns the associated information.
|
||||
Verify(ctx context.Context, accessToken string) (*domain.Token, error)
|
||||
|
|
Loading…
Reference in New Issue