diff --git a/internal/auth/usecase.go b/internal/auth/usecase.go index 9bd0bbc..7927f4b 100644 --- a/internal/auth/usecase.go +++ b/internal/auth/usecase.go @@ -3,11 +3,28 @@ package auth import ( "context" - "gitlab.com/toby3d/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/domain" ) -type UseCase interface { - Discovery(ctx context.Context, clientId string) (*domain.Client, error) - Approve(ctx context.Context, login *domain.Login) (string, error) - Exchange(ctx context.Context, req *domain.ExchangeRequest) (string, error) -} +type ( + GenerateOptions struct { + ClientID *domain.ClientID + RedirectURI *domain.URL + CodeChallenge string + CodeChallengeMethod domain.CodeChallengeMethod + Scope domain.Scopes + Me *domain.Me + } + + ExchangeOptions struct { + Code string + ClientID *domain.ClientID + RedirectURI *domain.URL + CodeVerifier string + } + + UseCase interface { + Generate(ctx context.Context, opts GenerateOptions) (string, error) + Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Me, error) + } +) diff --git a/internal/client/delivery/http/client_http.go b/internal/client/delivery/http/client_http.go new file mode 100644 index 0000000..479dd8c --- /dev/null +++ b/internal/client/delivery/http/client_http.go @@ -0,0 +1,135 @@ +package http + +import ( + "fmt" + "strings" + + "github.com/fasthttp/router" + http "github.com/valyala/fasthttp" + "golang.org/x/text/language" + "golang.org/x/text/message" + + "source.toby3d.me/toby3d/form" + "source.toby3d.me/toby3d/middleware" + "source.toby3d.me/website/indieauth/internal/common" + "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/token" + "source.toby3d.me/website/indieauth/web" +) + +type ( + CallbackRequest struct { + Code string `form:"code"` + State string `form:"state"` + Iss *domain.ClientID `form:"iss"` + } + + NewRequestHandlerOptions struct { + Client *domain.Client + Config *domain.Config + Matcher language.Matcher + Tokens token.UseCase + } + + RequestHandler struct { + client *domain.Client + config *domain.Config + matcher language.Matcher + tokens token.UseCase + } +) + +func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { + return &RequestHandler{ + client: opts.Client, + config: opts.Config, + matcher: opts.Matcher, + tokens: opts.Tokens, + } +} + +func (h *RequestHandler) Register(r *router.Router) { + chain := middleware.Chain{ + middleware.LogFmt(), + } + + r.GET("/", chain.RequestHandler(h.handleRender)) + r.GET("/callback", chain.RequestHandler(h.handleCallback)) +} + +func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { + redirectUri := make([]string, len(h.client.RedirectURI)) + for i := range h.client.RedirectURI { + redirectUri[i] = h.client.RedirectURI[i].String() + } + + ctx.Response.Header.Set( + http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`, + ) + + tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + tag, _, _ := h.matcher.Match(tags...) + + // TODO(toby3d): generate and store PKCE + + ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) + web.WriteTemplate(ctx, &web.HomePage{ + BaseOf: web.BaseOf{ + Config: h.config, + Language: tag, + Printer: message.NewPrinter(tag), + }, + Client: h.client, + State: "hackme", // TODO(toby3d): generate and store state + }) +} + +func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { + req := new(CallbackRequest) + if err := req.bind(ctx); err != nil { + ctx.Error(err.Error(), http.StatusInternalServerError) + + return + } + + // TODO(toby3d): load and check state + + if req.Iss.String() != h.client.ID.String() { + ctx.Error("iss is not equal", http.StatusBadRequest) + + return + } + + token, err := h.tokens.Exchange(ctx, token.ExchangeOptions{ + ClientID: h.client.ID, + RedirectURI: h.client.RedirectURI[0], + Code: req.Code, + CodeVerifier: "", // TODO(toby3d): validate PKCE here + }) + if err != nil { + ctx.Error(err.Error(), http.StatusBadRequest) + + return + } + + tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) + tag, _, _ := h.matcher.Match(tags...) + + ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) + web.WriteTemplate(ctx, &web.CallbackPage{ + BaseOf: web.BaseOf{ + Config: h.config, + Language: tag, + Printer: message.NewPrinter(tag), + }, + Token: token, + }) +} + +func (req *CallbackRequest) bind(ctx *http.RequestCtx) error { + if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil { + return fmt.Errorf("cannot unmarshal request: %w", err) + } + + return nil +} diff --git a/internal/client/delivery/http/home_http_test.go b/internal/client/delivery/http/client_http_test.go similarity index 56% rename from internal/client/delivery/http/home_http_test.go rename to internal/client/delivery/http/client_http_test.go index 51811c4..ca35e8a 100644 --- a/internal/client/delivery/http/home_http_test.go +++ b/internal/client/delivery/http/client_http_test.go @@ -1,6 +1,7 @@ package http_test import ( + "sync" "testing" "github.com/fasthttp/router" @@ -12,16 +13,26 @@ import ( delivery "source.toby3d.me/website/indieauth/internal/client/delivery/http" "source.toby3d.me/website/indieauth/internal/domain" + sessionrepo "source.toby3d.me/website/indieauth/internal/session/repository/memory" "source.toby3d.me/website/indieauth/internal/testing/httptest" + tokenrepo "source.toby3d.me/website/indieauth/internal/token/repository/memory" + tokenucase "source.toby3d.me/website/indieauth/internal/token/usecase" ) func TestRead(t *testing.T) { t.Parallel() + store := new(sync.Map) + config := domain.TestConfig(t) + r := router.New() - delivery.NewRequestHandler( - domain.TestConfig(t), domain.TestClient(t), language.NewMatcher(message.DefaultCatalog.Languages()), - ).Register(r) + delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ + Client: domain.TestClient(t), + Config: config, + Matcher: language.NewMatcher(message.DefaultCatalog.Languages()), + Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store), + sessionrepo.NewMemorySessionRepository(config, store), config), + }).Register(r) client, _, cleanup := httptest.New(t, r.Handler) t.Cleanup(cleanup) diff --git a/internal/client/delivery/http/home_http.go b/internal/client/delivery/http/home_http.go deleted file mode 100644 index 55be69c..0000000 --- a/internal/client/delivery/http/home_http.go +++ /dev/null @@ -1,71 +0,0 @@ -package http - -import ( - "encoding/base64" - "strings" - - "github.com/fasthttp/router" - http "github.com/valyala/fasthttp" - "golang.org/x/text/language" - "golang.org/x/text/message" - - "source.toby3d.me/website/indieauth/internal/common" - "source.toby3d.me/website/indieauth/internal/domain" - "source.toby3d.me/website/indieauth/internal/random" - "source.toby3d.me/website/indieauth/web" -) - -type RequestHandler struct { - client *domain.Client - config *domain.Config - matcher language.Matcher -} - -const DefaultStateLength int = 64 - -func NewRequestHandler(config *domain.Config, client *domain.Client, matcher language.Matcher) *RequestHandler { - return &RequestHandler{ - client: client, - config: config, - matcher: matcher, - } -} - -func (h *RequestHandler) Register(r *router.Router) { - r.GET("/", h.read) -} - -func (h *RequestHandler) read(ctx *http.RequestCtx) { - ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) - - // TODO(toby3d): save state for checking it at the end of flow? - state, err := random.Bytes(DefaultStateLength) - if err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - - return - } - - redirectUri := make([]string, len(h.client.RedirectURI)) - for i := range h.client.RedirectURI { - redirectUri[i] = h.client.RedirectURI[i].String() - } - - tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) - tag, _, _ := h.matcher.Match(tags...) - - ctx.Response.Header.Set( - http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`, - ) - web.WriteTemplate(ctx, &web.HomePage{ - BaseOf: web.BaseOf{ - Config: h.config, - Language: tag, - Printer: message.NewPrinter(tag), - }, - RedirectURI: h.client.RedirectURI, - AuthEndpoint: "/authorize", - Client: h.client, - State: base64.RawURLEncoding.EncodeToString(state), - }) -} diff --git a/internal/client/repository/http/http_client.go b/internal/client/repository/http/http_client.go index 5169a74..ba48492 100644 --- a/internal/client/repository/http/http_client.go +++ b/internal/client/repository/http/http_client.go @@ -1,24 +1,22 @@ package http import ( - "bytes" "context" "fmt" - "net/url" - "strings" - "github.com/tomnomnom/linkheader" http "github.com/valyala/fasthttp" - "willnorris.com/go/microformats" "source.toby3d.me/website/indieauth/internal/client" "source.toby3d.me/website/indieauth/internal/domain" + "source.toby3d.me/website/indieauth/internal/util" ) type httpClientRepository struct { client *http.Client } +const DefaultMaxRedirectsCount int = 10 + const ( relRedirectURI string = "redirect_uri" @@ -45,7 +43,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) resp := http.AcquireResponse() defer http.ReleaseResponse(resp) - if err := repo.client.Do(req, resp); err != nil { + if err := repo.client.DoRedirects(req, resp, DefaultMaxRedirectsCount); err != nil { return nil, fmt.Errorf("failed to make a request to the client: %w", err) } @@ -55,115 +53,103 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) client := &domain.Client{ ID: id, + RedirectURI: make([]*domain.URL, 0), Logo: make([]*domain.URL, 0), - Name: extractValues(resp, propertyName), - RedirectURI: extractEndpoints(resp, relRedirectURI), URL: make([]*domain.URL, 0), + Name: make([]string, 0), } - for _, v := range extractValues(resp, propertyLogo) { - u, err := domain.NewURL(v) - if err != nil { - continue - } - - client.Logo = append(client.Logo, u) - } - - for _, v := range extractValues(resp, propertyURL) { - u, err := domain.NewURL(v) - if err != nil { - continue - } - - client.URL = append(client.URL, u) - } + extract(client, resp) return client, nil } -func extractEndpoints(resp *http.Response, name string) []*domain.URL { - results := make([]*domain.URL, 0) - endpoints, _ := extractEndpointsFromHeader(resp, name) - results = append(results, endpoints...) - endpoints, _ = extractEndpointsFromBody(resp, name) - results = append(results, endpoints...) - - return results -} - -func extractValues(resp *http.Response, key string) []string { - results := make([]string, 0) - - for _, item := range microformats.Parse(bytes.NewReader(resp.Body()), nil).Items { - if len(item.Type) == 0 || (item.Type[0] != hApp && item.Type[0] != hXApp) { +func extract(dst *domain.Client, src *http.Response) { + for _, u := range util.ExtractEndpoints(src, relRedirectURI) { + if containsURL(dst.RedirectURI, u) { continue } - properties, ok := item.Properties[key] - if !ok || len(properties) == 0 { - return nil - } + dst.RedirectURI = append(dst.RedirectURI, u) + } - for j := range properties { - switch p := properties[j].(type) { - case string: - results = append(results, p) - case map[string][]interface{}: - for _, val := range p["value"] { - v, ok := val.(string) - if !ok { - continue - } - - results = append(results, v) - } + for _, t := range []string{hXApp, hApp} { + for _, name := range util.ExtractProperty(src, t, propertyName) { + n, ok := name.(string) + if !ok || containsString(dst.Name, n) { + continue } + + dst.Name = append(dst.Name, n) } - return results - } + for _, logo := range util.ExtractProperty(src, t, propertyLogo) { + var err error - return nil + var u *domain.URL + switch l := logo.(type) { + case string: + u, err = domain.NewURL(l) + case map[string]string: + value, ok := l["value"] + if !ok { + continue + } + + u, err = domain.NewURL(value) + } + + if err != nil { + continue + } + + if containsURL(dst.Logo, u) { + continue + } + + dst.Logo = append(dst.Logo, u) + } + + for _, url := range util.ExtractProperty(src, t, propertyURL) { + l, ok := url.(string) + if !ok { + continue + } + + u, err := domain.NewURL(l) + if err != nil { + continue + } + + if containsURL(dst.URL, u) { + continue + } + + dst.URL = append(dst.URL, u) + } + } } -func extractEndpointsFromHeader(resp *http.Response, name string) ([]*domain.URL, error) { - results := make([]*domain.URL, 0) - - for _, link := range linkheader.Parse(string(resp.Header.Peek(http.HeaderLink))) { - if !strings.EqualFold(link.Rel, name) { +func containsString(src []string, find string) bool { + for i := range src { + if src[i] != find { continue } - u := http.AcquireURI() - if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil { - return nil, err + return true + } + + return false +} + +func containsURL(src []*domain.URL, find *domain.URL) bool { + for i := range src { + if src[i].String() != find.String() { + continue } - results = append(results, &domain.URL{URI: u}) + return true } - return results, nil -} - -func extractEndpointsFromBody(resp *http.Response, name string) ([]*domain.URL, error) { - host, err := url.Parse(string(resp.Header.Peek(http.HeaderHost))) - if err != nil { - return nil, fmt.Errorf("cannot parse host header: %w", err) - } - - endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), host).Rels[name] - if !ok || len(endpoints) == 0 { - return nil, nil - } - - results := make([]*domain.URL, 0) - for i := range endpoints { - u := http.AcquireURI() - u.Update(endpoints[i]) - - results = append(results, &domain.URL{URI: u}) - } - - return results, nil + return false } diff --git a/internal/client/usecase.go b/internal/client/usecase.go index f67e586..305caa1 100644 --- a/internal/client/usecase.go +++ b/internal/client/usecase.go @@ -12,4 +12,4 @@ type UseCase interface { Discovery(ctx context.Context, id *domain.ClientID) (*domain.Client, error) } -var ErrInvalidMe = errors.New("provided me is invalid") +var ErrInvalidMe = errors.New("invalid me") diff --git a/internal/token/usecase.go b/internal/token/usecase.go index 9aea133..66e92eb 100644 --- a/internal/token/usecase.go +++ b/internal/token/usecase.go @@ -8,16 +8,15 @@ import ( ) type ( - GenerateOptions struct { - ClientID string - Me string - Scopes []string - NonceLength int + ExchangeOptions struct { + ClientID *domain.ClientID + RedirectURI *domain.URL + Code string + CodeVerifier string } UseCase interface { - // Generate generates a new Token based on the session data. - Generate(ctx context.Context, opts GenerateOptions) (*domain.Token, error) + Exchange(ctx context.Context, opts ExchangeOptions) (*domain.Token, error) // Verify checks the AccessToken and returns the associated information. Verify(ctx context.Context, accessToken string) (*domain.Token, error)