🚨 Removed linter warnings

This commit is contained in:
Maxim Lebedev 2022-02-01 22:27:48 +05:00
parent 7680845f74
commit 59d4c4988a
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
60 changed files with 877 additions and 886 deletions

View File

@ -1,15 +1,54 @@
--- ---
run:
tests: true
skip-dirs:
- locales
- testdata
- web
skip-dirs-use-default: true
skip-files:
- ".*_gen\\.go$"
output: output:
sort-results: true sort-results: true
linters-settings: linters-settings:
lll:
tab-width: 8
gci: gci:
local-prefixes: source.toby3d.me local-prefixes: source.toby3d.me
goimports: goimports:
local-prefixes: source.toby3d.me local-prefixes: source.toby3d.me
ireturn:
allow:
- "(Repository|UseCase)$"
- error
- stdlib
lll:
tab-width: 8
varnamelen:
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-names:
- ctx # context
- db # dataBase
- err # error
- i # index
- ip
- ln # listener
- me
- ok
- tc # testCase
- ts # timeStamp
- tx # transaction
ignore-decls:
- "cid *domain.ClientID"
- "ctx *fasthttp.RequestCtx"
- "ctx context.Context"
- "i int"
- "me *domain.Me"
- "r *router.Router"
linters: linters:
enable-all: true enable-all: true
disable:
- godox
issues: issues:
exclude-rules: exclude-rules:
- source: "^//go:generate " - source: "^//go:generate "

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.0 // indirect github.com/hashicorp/go-retryablehttp v0.7.0 // indirect
github.com/jmoiron/sqlx v1.3.4 github.com/jmoiron/sqlx v1.3.4
github.com/klauspost/compress v1.14.2 // indirect github.com/klauspost/compress v1.14.2 // indirect
github.com/lestrrat-go/iter v1.0.1 // indirect
github.com/lestrrat-go/jwx v1.2.18 github.com/lestrrat-go/jwx v1.2.18
github.com/mattn/go-mastodon v0.0.4 github.com/mattn/go-mastodon v0.0.4
github.com/spf13/afero v1.8.0 // indirect github.com/spf13/afero v1.8.0 // indirect

View File

@ -23,7 +23,7 @@ import (
) )
type ( type (
AuthorizeRequest struct { AuthAuthorizeRequest struct {
// Indicates to the authorization server that an authorization // Indicates to the authorization server that an authorization
// code should be returned as the response. // code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code ResponseType domain.ResponseType `form:"response_type"` // code
@ -58,7 +58,7 @@ type (
Me *domain.Me `form:"me"` Me *domain.Me `form:"me"`
} }
VerifyRequest struct { AuthVerifyRequest struct {
ClientID *domain.ClientID `form:"client_id"` ClientID *domain.ClientID `form:"client_id"`
Me *domain.Me `form:"me"` Me *domain.Me `form:"me"`
RedirectURI *domain.URL `form:"redirect_uri"` RedirectURI *domain.URL `form:"redirect_uri"`
@ -71,7 +71,7 @@ type (
Provider string `form:"provider"` Provider string `form:"provider"`
} }
ExchangeRequest struct { AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The authorization code received from the authorization // The authorization code received from the authorization
@ -91,50 +91,52 @@ type (
CodeVerifier string `form:"code_verifier"` CodeVerifier string `form:"code_verifier"`
} }
ExchangeResponse struct { AuthExchangeResponse struct {
Me *domain.Me `json:"me"` Me *domain.Me `json:"me"`
} }
NewRequestHandlerOptions struct { NewRequestHandlerOptions struct {
Auth auth.UseCase Auth auth.UseCase
Clients client.UseCase Clients client.UseCase
Config *domain.Config Config *domain.Config
Matcher language.Matcher Matcher language.Matcher
Providers []*domain.Provider
} }
RequestHandler struct { RequestHandler struct {
clients client.UseCase clients client.UseCase
config *domain.Config config *domain.Config
matcher language.Matcher matcher language.Matcher
useCase auth.UseCase useCase auth.UseCase
providers []*domain.Provider
} }
) )
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler { func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{ return &RequestHandler{
clients: opts.Clients, clients: opts.Clients,
config: opts.Config, config: opts.Config,
matcher: opts.Matcher, matcher: opts.Matcher,
useCase: opts.Auth, useCase: opts.Auth,
providers: opts.Providers,
} }
} }
func (h *RequestHandler) Register(r *router.Router) { func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{ chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{ middleware.CSRFWithConfig(middleware.CSRFConfig{
CookieSameSite: http.CookieSameSiteStrictMode,
CookieName: "_csrf",
TokenLookup: "form:_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/api/*", string(ctx.Path())) matched, _ := path.Match("/api/*", string(ctx.Path()))
return ctx.IsPost() && matched return ctx.IsPost() && matched
}, },
CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteStrictMode,
ContextKey: "",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "form:_csrf",
TokenLength: 0,
CookieSecure: true,
CookieHTTPOnly: true,
}), }),
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{ middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(ctx *http.RequestCtx) bool {
@ -145,15 +147,14 @@ func (h *RequestHandler) Register(r *router.Router) {
return !ctx.IsPost() || !matched || providerMatched return !ctx.IsPost() || !matched || providerMatched
}, },
Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) { Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) {
userMatch := subtle.ConstantTimeCompare( userMatch := subtle.ConstantTimeCompare([]byte(login),
[]byte(login), []byte(h.config.IndieAuth.Username), []byte(h.config.IndieAuth.Username))
) passMatch := subtle.ConstantTimeCompare([]byte(password),
passMatch := subtle.ConstantTimeCompare( []byte(h.config.IndieAuth.Password))
[]byte(password), []byte(h.config.IndieAuth.Password),
)
return userMatch == 1 && passMatch == 1, nil return userMatch == 1 && passMatch == 1, nil
}, },
Realm: "",
}), }),
middleware.LogFmt(), middleware.LogFmt(),
} }
@ -164,7 +165,6 @@ func (h *RequestHandler) Register(r *router.Router) {
} }
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
req := new(AuthorizeRequest)
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8) ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
@ -175,6 +175,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
Printer: message.NewPrinter(tag), Printer: message.NewPrinter(tag),
} }
req := new(AuthAuthorizeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{ web.WriteTemplate(ctx, &web.ErrorPage{
@ -213,15 +214,16 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte) csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
web.WriteTemplate(ctx, &web.AuthorizePage{ web.WriteTemplate(ctx, &web.AuthorizePage{
BaseOf: baseOf, BaseOf: baseOf,
Client: client,
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
CSRF: csrf, CSRF: csrf,
Scope: req.Scope,
Client: client,
Me: req.Me, Me: req.Me,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
CodeChallengeMethod: req.CodeChallengeMethod,
ResponseType: req.ResponseType, ResponseType: req.ResponseType,
Scope: req.Scope, CodeChallenge: req.CodeChallenge,
State: req.State, State: req.State,
Providers: make([]*domain.Provider, 0), // TODO(toby3d)
}) })
} }
@ -230,22 +232,23 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(VerifyRequest) req := new(AuthVerifyRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
u := http.AcquireURI() redirectURL := http.AcquireURI()
defer http.ReleaseURI(u) defer http.ReleaseURI(redirectURL)
req.RedirectURI.CopyTo(u) req.RedirectURI.CopyTo(redirectURL)
if strings.EqualFold(req.Authorize, "deny") { if strings.EqualFold(req.Authorize, "deny") {
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State). domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
SetReirectURI(u) SetReirectURI(redirectURL)
ctx.Redirect(u.String(), http.StatusFound) ctx.Redirect(redirectURL.String(), http.StatusFound)
return return
} }
@ -260,7 +263,8 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
@ -270,10 +274,10 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
"iss": h.config.Server.GetRootURL(), "iss": h.config.Server.GetRootURL(),
"state": req.State, "state": req.State,
} { } {
u.QueryArgs().Set(key, val) redirectURL.QueryArgs().Set(key, val)
} }
ctx.Redirect(u.String(), http.StatusFound) ctx.Redirect(redirectURL.String(), http.StatusFound)
} }
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
@ -281,10 +285,11 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest) req := new(AuthExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
@ -297,17 +302,18 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
encoder.Encode(&ExchangeResponse{ _ = encoder.Encode(&AuthExchangeResponse{
Me: me, Me: me,
}) })
} }
func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error { func (r *AuthAuthorizeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil { if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {
@ -341,7 +347,7 @@ func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
return nil return nil
} }
func (r *VerifyRequest) bind(ctx *http.RequestCtx) error { func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
@ -369,6 +375,8 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
) )
} }
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID { if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode r.ResponseType = domain.ResponseTypeCode
} }
@ -386,7 +394,7 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
return nil return nil
} }
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error { func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {

View File

@ -35,15 +35,15 @@ func TestRender(t *testing.T) {
require.NoError(t, s.SetProvider(provider)) require.NoError(t, s.SetProvider(provider))
me := domain.TestMe(t, "https://user.example.net") me := domain.TestMe(t, "https://user.example.net")
c := domain.TestClient(t) client := domain.TestClient(t)
config := domain.TestConfig(t) config := domain.TestConfig(t)
store := new(sync.Map) store := new(sync.Map)
user := domain.TestUser(t) user := domain.TestUser(t)
store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user) store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user)
store.Store(path.Join(clientrepo.DefaultPathPrefix, c.ID.String()), c) store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client)
store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile) store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile)
r := router.New() router := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)), Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)),
Config: config, Config: config,
@ -52,35 +52,35 @@ func TestRender(t *testing.T) {
sessionrepo.NewMemorySessionRepository(config, store), sessionrepo.NewMemorySessionRepository(config, store),
config, config,
), ),
}).Register(r) }).Register(router)
client, _, cleanup := httptest.New(t, r.Handler) httpClient, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
u := http.AcquireURI() uri := http.AcquireURI()
defer http.ReleaseURI(u) defer http.ReleaseURI(uri)
u.Update("https://example.com/authorize") uri.Update("https://example.com/authorize")
for k, v := range map[string]string{ for key, val := range map[string]string{
"client_id": c.ID.String(), "client_id": client.ID.String(),
"code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo", "code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
"code_challenge_method": domain.CodeChallengeMethodS256.String(), "code_challenge_method": domain.CodeChallengeMethodS256.String(),
"me": me.String(), "me": me.String(),
"redirect_uri": c.RedirectURI[0].String(), "redirect_uri": client.RedirectURI[0].String(),
"response_type": domain.ResponseTypeCode.String(), "response_type": domain.ResponseTypeCode.String(),
"scope": "profile email", "scope": "profile email",
"state": "1234567890", "state": "1234567890",
} { } {
u.QueryArgs().Set(k, v) uri.QueryArgs().Set(key, val)
} }
req := httptest.NewRequest(http.MethodGet, u.String(), nil) req := httptest.NewRequest(http.MethodGet, uri.String(), nil)
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
resp := http.AcquireResponse() resp := http.AcquireResponse()
defer http.ReleaseResponse(resp) defer http.ReleaseResponse(resp)
require.NoError(t, client.Do(req, resp)) require.NoError(t, httpClient.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode()) assert.Equal(t, http.StatusOK, resp.StatusCode())
assert.Contains(t, string(resp.Body()), `Authorize application`) assert.Contains(t, string(resp.Body()), `Authorize application`)

View File

@ -18,7 +18,7 @@ import (
) )
type ( type (
CallbackRequest struct { ClientCallbackRequest struct {
Iss *domain.ClientID `form:"iss"` Iss *domain.ClientID `form:"iss"`
Code string `form:"code"` Code string `form:"code"`
Error string `form:"error"` Error string `form:"error"`
@ -60,13 +60,14 @@ func (h *RequestHandler) Register(r *router.Router) {
} }
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) { func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
redirectUri := make([]string, len(h.client.RedirectURI)) redirect := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI { for i := range h.client.RedirectURI {
redirectUri[i] = h.client.RedirectURI[i].String() redirect[i] = h.client.RedirectURI[i].String()
} }
ctx.Response.Header.Set( ctx.Response.Header.Set(
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`, http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
) )
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage))) tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
@ -87,7 +88,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
} }
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) { func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
req := new(CallbackRequest) req := new(ClientCallbackRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError) ctx.Error(err.Error(), http.StatusInternalServerError)
@ -134,7 +135,7 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
}) })
} }
func (req *CallbackRequest) bind(ctx *http.RequestCtx) error { func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil { if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil {

View File

@ -25,16 +25,16 @@ func TestRead(t *testing.T) {
store := new(sync.Map) store := new(sync.Map)
config := domain.TestConfig(t) config := domain.TestConfig(t)
r := router.New() router := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{ delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Client: domain.TestClient(t), Client: domain.TestClient(t),
Config: config, Config: config,
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()), Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store), Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
sessionrepo.NewMemorySessionRepository(config, store), config), sessionrepo.NewMemorySessionRepository(config, store), config),
}).Register(r) }).Register(router)
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil) req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)

View File

@ -10,4 +10,8 @@ type Repository interface {
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
} }
var ErrNotExist error = domain.NewError(domain.ErrorCodeInvalidClient, "client with the specified ID does not exist", "") var ErrNotExist error = domain.NewError(
domain.ErrorCodeInvalidClient,
"client with the specified ID does not exist",
"",
)

View File

@ -32,10 +32,10 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
} }
} }
func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) { func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.SetRequestURI(id.String()) req.SetRequestURI(cid.String())
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse() resp := http.AcquireResponse()
@ -50,7 +50,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
} }
client := &domain.Client{ client := &domain.Client{
ID: id, ID: cid,
RedirectURI: make([]*domain.URL, 0), RedirectURI: make([]*domain.URL, 0),
Logo: make([]*domain.URL, 0), Logo: make([]*domain.URL, 0),
URL: make([]*domain.URL, 0), URL: make([]*domain.URL, 0),
@ -62,68 +62,52 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
return client, nil return client, nil
} }
//nolint: gocognit, cyclop
func extract(dst *domain.Client, src *http.Response) { func extract(dst *domain.Client, src *http.Response) {
for _, u := range util.ExtractEndpoints(src, relRedirectURI) { for _, endpoint := range util.ExtractEndpoints(src, relRedirectURI) {
if containsURL(dst.RedirectURI, u) { if !containsURL(dst.RedirectURI, endpoint) {
continue dst.RedirectURI = append(dst.RedirectURI, endpoint)
} }
dst.RedirectURI = append(dst.RedirectURI, u)
} }
for _, t := range []string{hXApp, hApp} { for _, itemType := range []string{hXApp, hApp} {
for _, name := range util.ExtractProperty(src, t, propertyName) { for _, name := range util.ExtractProperty(src, itemType, propertyName) {
n, ok := name.(string) if n, ok := name.(string); ok && !containsString(dst.Name, n) {
if !ok || containsString(dst.Name, n) { dst.Name = append(dst.Name, n)
continue
} }
dst.Name = append(dst.Name, n)
} }
for _, logo := range util.ExtractProperty(src, t, propertyLogo) { for _, logo := range util.ExtractProperty(src, itemType, propertyLogo) {
var err error var (
uri *domain.URL
err error
)
var u *domain.URL
switch l := logo.(type) { switch l := logo.(type) {
case string: case string:
u, err = domain.ParseURL(l) uri, err = domain.ParseURL(l)
case map[string]string: case map[string]string:
value, ok := l["value"] if value, ok := l["value"]; ok {
if !ok { uri, err = domain.ParseURL(value)
continue
} }
u, err = domain.ParseURL(value)
} }
if err != nil { if err != nil || containsURL(dst.Logo, uri) {
continue continue
} }
if containsURL(dst.Logo, u) { dst.Logo = append(dst.Logo, uri)
continue
}
dst.Logo = append(dst.Logo, u)
} }
for _, url := range util.ExtractProperty(src, t, propertyURL) { for _, property := range util.ExtractProperty(src, itemType, propertyURL) {
l, ok := url.(string) prop, ok := property.(string)
if !ok { if !ok {
continue continue
} }
u, err := domain.ParseURL(l) if u, err := domain.ParseURL(prop); err == nil || !containsURL(dst.URL, u) {
if err != nil { dst.URL = append(dst.URL, u)
continue
} }
if containsURL(dst.URL, u) {
continue
}
dst.URL = append(dst.URL, u)
} }
} }
} }

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain package domain
import ( import (
@ -14,7 +15,7 @@ type Action struct {
uid string uid string
} }
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants //nolint: gochecknoglobals // structs cannot be constants
var ( var (
ActionUndefined = Action{uid: ""} ActionUndefined = Action{uid: ""}

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test package domain_test
import ( import (
@ -12,13 +13,10 @@ func TestParseAction(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
in string in string
out domain.Action out domain.Action
}{{ }{
in: "revoke", {in: "revoke", out: domain.ActionRevoke},
out: domain.ActionRevoke, {in: "ticket", out: domain.ActionTicket},
}, { } {
in: "ticket",
out: domain.ActionTicket,
}} {
tc := tc tc := tc
t.Run(tc.in, func(t *testing.T) { t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestAction_String(t *testing.T) {
name string name string
in domain.Action in domain.Action
out string out string
}{{ }{
name: "revoke", {name: "revoke", in: domain.ActionRevoke, out: "revoke"},
in: domain.ActionRevoke, {name: "ticket", in: domain.ActionTicket, out: "ticket"},
out: "revoke", } {
}, {
name: "ticket",
in: domain.ActionTicket,
out: "ticket",
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -8,111 +8,94 @@ import (
"testing" "testing"
http "github.com/valyala/fasthttp" http "github.com/valyala/fasthttp"
"golang.org/x/xerrors"
"inet.af/netaddr" "inet.af/netaddr"
) )
// ClientID is a URL client identifier. // ClientID is a URL client identifier.
type ClientID struct { type ClientID struct {
clientID *http.URI clientID *http.URI
valid bool
} }
//nolint: gochecknoglobals //nolint: gochecknoglobals // slices cannot be constants
var ( var (
localhostIPv4 = netaddr.MustParseIP("127.0.0.1") localhostIPv4 = netaddr.MustParseIP("127.0.0.1")
localhostIPv6 = netaddr.MustParseIP("::1") localhostIPv6 = netaddr.MustParseIP("::1")
) )
// ParseClientID parse string as client ID URL identifier. // ParseClientID parse string as client ID URL identifier.
//nolint: funlen //nolint: funlen, cyclop
func ParseClientID(src string) (*ClientID, error) { func ParseClientID(src string) (*ClientID, error) {
cid := http.AcquireURI() cid := http.AcquireURI()
if err := cid.Parse(nil, []byte(src)); err != nil { if err := cid.Parse(nil, []byte(src)); err != nil {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: err.Error(), err.Error(),
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
scheme := string(cid.Scheme()) scheme := string(cid.Scheme())
if scheme != "http" && scheme != "https" { if scheme != "http" && scheme != "https" {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client identifier URL MUST have either an https or http scheme", "client identifier URL MUST have either an https or http scheme",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
path := string(cid.PathOriginal()) path := string(cid.PathOriginal())
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") { if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client identifier URL MUST contain a path component and MUST NOT contain " + "client identifier URL MUST contain a path component and MUST NOT contain "+
"single-dot or double-dot path segments", "single-dot or double-dot path segments",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
if cid.Hash() != nil { if cid.Hash() != nil {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT contain a fragment component", "client identifier URL MUST NOT contain a fragment component",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
if cid.Username() != nil || cid.Password() != nil { if cid.Username() != nil || cid.Password() != nil {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT contain a username or password component", "client identifier URL MUST NOT contain a username or password component",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
domain := string(cid.Host()) domain := string(cid.Host())
if domain == "" { if domain == "" {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client host name MUST be domain name or a loopback interface", "client host name MUST be domain name or a loopback interface",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
ip, err := netaddr.ParseIP(domain) ip, err := netaddr.ParseIP(domain)
if err != nil { if err != nil {
ipPort, err := netaddr.ParseIPPort(domain) ipPort, err := netaddr.ParseIPPort(domain)
if err != nil { if err != nil {
return &ClientID{ //nolint: nilerr // ClientID does not contain an IP address, so it is valid
clientID: cid, return &ClientID{clientID: cid}, nil
}, nil
} }
ip = ipPort.IP() ip = ipPort.IP()
} }
if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 { if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 {
return nil, Error{ return nil, NewError(
Code: ErrorCodeInvalidRequest, ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 " + "client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 "+
"127.0.0.1 or IPv6 [::1]", "127.0.0.1 or IPv6 [::1]",
URI: "https://indieauth.net/source/#client-identifier", "https://indieauth.net/source/#client-identifier",
State: "", )
frame: xerrors.Caller(1),
}
} }
return &ClientID{ return &ClientID{
@ -126,7 +109,7 @@ func TestClientID(tb testing.TB) *ClientID {
clientID, err := ParseClientID("https://app.example.com/") clientID, err := ParseClientID("https://app.example.com/")
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
return clientID return clientID
@ -167,23 +150,28 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
} }
// URI returns copy of parsed *fasthttp.URI. // URI returns copy of parsed *fasthttp.URI.
// This copy MUST be released via fasthttp.ReleaseURI. //
// WARN(toby3d): This copy MUST be released via fasthttp.ReleaseURI.
func (cid ClientID) URI() *http.URI { func (cid ClientID) URI() *http.URI {
u := http.AcquireURI() uri := http.AcquireURI()
cid.clientID.CopyTo(u) cid.clientID.CopyTo(uri)
return u return uri
} }
// URL returns url.URL representation of client ID. // URL returns url.URL representation of client ID.
func (cid ClientID) URL() *url.URL { func (cid ClientID) URL() *url.URL {
return &url.URL{ return &url.URL{
Scheme: string(cid.clientID.Scheme()), ForceQuery: false,
Host: string(cid.clientID.Host()), Fragment: string(cid.clientID.Hash()),
Path: string(cid.clientID.Path()), Host: string(cid.clientID.Host()),
RawPath: string(cid.clientID.PathOriginal()), Opaque: "",
RawQuery: string(cid.clientID.QueryString()), Path: string(cid.clientID.Path()),
Fragment: string(cid.clientID.Hash()), RawFragment: "",
RawPath: string(cid.clientID.PathOriginal()),
RawQuery: string(cid.clientID.QueryString()),
Scheme: string(cid.clientID.Scheme()),
User: nil,
} }
} }

View File

@ -7,7 +7,6 @@ import (
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
) )
//nolint: funlen
func TestParseClientID(t *testing.T) { func TestParseClientID(t *testing.T) {
t.Parallel() t.Parallel()
@ -15,51 +14,19 @@ func TestParseClientID(t *testing.T) {
name string name string
in string in string
expError bool expError bool
}{{ }{
name: "valid", {name: "valid", in: "https://example.com/", expError: false},
in: "https://example.com/", {name: "valid path", in: "https://example.com/username", expError: false},
expError: false, {name: "valid query", in: "https://example.com/users?id=100", expError: false},
}, { {name: "valid port", in: "https://example.com:8443/", expError: false},
name: "valid path", {name: "valid loopback", in: "https://127.0.0.1:8443/", expError: false},
in: "https://example.com/username", {name: "missing scheme", in: "example.com", expError: true},
expError: false, {name: "invalid scheme", in: "mailto:user@example.com", expError: true},
}, { {name: "invalid double-dot path", in: "https://example.com/foo/../bar", expError: true},
name: "valid query", {name: "invalid fragment", in: "https://example.com/#me", expError: true},
in: "https://example.com/users?id=100", {name: "invalid user", in: "https://user:pass@example.com/", expError: true},
expError: false, {name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
}, { } {
name: "valid port",
in: "https://example.com:8443/",
expError: false,
}, {
name: "valid loopback",
in: "https://127.0.0.1:8443/",
expError: false,
}, {
name: "missing scheme",
in: "example.com",
expError: true,
}, {
name: "invalid scheme",
in: "mailto:user@example.com",
expError: true,
}, {
name: "invalid double-dot path",
in: "https://example.com/foo/../bar",
expError: true,
}, {
name: "invalid fragment",
in: "https://example.com/#me",
expError: true,
}, {
name: "invalid user",
in: "https://user:pass@example.com/",
expError: true,
}, {
name: "host is an IP address",
in: "https://172.28.92.51/",
expError: true,
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -131,8 +98,7 @@ func TestClientID_MarshalJSON(t *testing.T) {
func TestClientID_String(t *testing.T) { func TestClientID_String(t *testing.T) {
t.Parallel() t.Parallel()
cid := domain.TestClientID(t) if cid := domain.TestClientID(t); cid.String() != fmt.Sprint(cid) {
if result := cid.String(); result != fmt.Sprint(cid) { t.Errorf("String() = %s, want %s", cid.String(), fmt.Sprint(cid))
t.Errorf("Strig() = %s, want %s", result, fmt.Sprint(cid))
} }
} }

View File

@ -15,13 +15,10 @@ func TestClient_ValidateRedirectURI(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
in *domain.URL in *domain.URL
}{{ }{
name: "client_id prefix", {name: "client_id prefix", in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback"))},
in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback")), {name: "registered redirect_uri", in: client.RedirectURI[len(client.RedirectURI)-1]},
}, { } {
name: "registered redirect_uri",
in: client.RedirectURI[len(client.RedirectURI)-1],
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -1,5 +1,6 @@
package domain package domain
//nolint: gosec // support old clients
import ( import (
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
@ -21,7 +22,7 @@ type CodeChallengeMethod struct {
uid string uid string
} }
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants //nolint: gochecknoglobals // structs cannot be constants
var ( var (
CodeChallengeMethodUndefined = CodeChallengeMethod{ CodeChallengeMethodUndefined = CodeChallengeMethod{
uid: "", uid: "",
@ -34,12 +35,14 @@ var (
} }
CodeChallengeMethodMD5 = CodeChallengeMethod{ CodeChallengeMethodMD5 = CodeChallengeMethod{
uid: "MD5", uid: "MD5",
//nolint: gosec // support old clients
hash: md5.New(), hash: md5.New(),
} }
CodeChallengeMethodS1 = CodeChallengeMethod{ CodeChallengeMethodS1 = CodeChallengeMethod{
uid: "S1", uid: "S1",
//nolint: gosec // support old clients
hash: sha1.New(), hash: sha1.New(),
} }
@ -60,7 +63,7 @@ var ErrCodeChallengeMethodUnknown error = NewError(
"https://indieauth.net/source/#authorization-request", "https://indieauth.net/source/#authorization-request",
) )
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants //nolint: gochecknoglobals // maps cannot be constants
var slugsMethods = map[string]CodeChallengeMethod{ var slugsMethods = map[string]CodeChallengeMethod{
CodeChallengeMethodMD5.uid: CodeChallengeMethodMD5, CodeChallengeMethodMD5.uid: CodeChallengeMethodMD5,
CodeChallengeMethodPLAIN.uid: CodeChallengeMethodPLAIN, CodeChallengeMethodPLAIN.uid: CodeChallengeMethodPLAIN,

View File

@ -1,5 +1,6 @@
package domain_test package domain_test
//nolint: gosec // support old clients
import ( import (
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
@ -23,32 +24,14 @@ func TestParseCodeChallengeMethod(t *testing.T) {
in string in string
out domain.CodeChallengeMethod out domain.CodeChallengeMethod
expError bool expError bool
}{{ }{
expError: true, {name: "invalid", in: "und", out: domain.CodeChallengeMethodUndefined, expError: true},
name: "invalid", {name: "PLAIN", in: "plain", out: domain.CodeChallengeMethodPLAIN, expError: false},
in: "und", {name: "MD5", in: "Md5", out: domain.CodeChallengeMethodMD5, expError: false},
out: domain.CodeChallengeMethodUndefined, {name: "S1", in: "S1", out: domain.CodeChallengeMethodS1, expError: false},
}, { {name: "S256", in: "S256", out: domain.CodeChallengeMethodS256, expError: false},
name: "PLAIN", {name: "S512", in: "S512", out: domain.CodeChallengeMethodS512, expError: false},
in: "plain", } {
out: domain.CodeChallengeMethodPLAIN,
}, {
name: "MD5",
in: "Md5",
out: domain.CodeChallengeMethodMD5,
}, {
name: "S1",
in: "S1",
out: domain.CodeChallengeMethodS1,
}, {
name: "S256",
in: "S256",
out: domain.CodeChallengeMethodS256,
}, {
name: "S512",
in: "S512",
out: domain.CodeChallengeMethodS512,
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -107,27 +90,13 @@ func TestCodeChallengeMethod_String(t *testing.T) {
name string name string
in domain.CodeChallengeMethod in domain.CodeChallengeMethod
out string out string
}{{ }{
name: "plain", {name: "plain", in: domain.CodeChallengeMethodPLAIN, out: "PLAIN"},
in: domain.CodeChallengeMethodPLAIN, {name: "md5", in: domain.CodeChallengeMethodMD5, out: "MD5"},
out: "PLAIN", {name: "s1", in: domain.CodeChallengeMethodS1, out: "S1"},
}, { {name: "s256", in: domain.CodeChallengeMethodS256, out: "S256"},
name: "md5", {name: "s512", in: domain.CodeChallengeMethodS512, out: "S512"},
in: domain.CodeChallengeMethodMD5, } {
out: "MD5",
}, {
name: "s1",
in: domain.CodeChallengeMethodS1,
out: "S1",
}, {
name: "s256",
in: domain.CodeChallengeMethodS256,
out: "S256",
}, {
name: "s512",
in: domain.CodeChallengeMethodS512,
out: "S512",
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -141,7 +110,7 @@ func TestCodeChallengeMethod_String(t *testing.T) {
} }
} }
//nolint: funlen //nolint: gosec // support old clients
func TestCodeChallengeMethod_Validate(t *testing.T) { func TestCodeChallengeMethod_Validate(t *testing.T) {
t.Parallel() t.Parallel()
@ -155,42 +124,15 @@ func TestCodeChallengeMethod_Validate(t *testing.T) {
in domain.CodeChallengeMethod in domain.CodeChallengeMethod
name string name string
expError bool expError bool
}{{ }{
name: "invalid", {name: "invalid", in: domain.CodeChallengeMethodS256, hash: md5.New(), expError: true},
in: domain.CodeChallengeMethodS256, {name: "MD5", in: domain.CodeChallengeMethodMD5, hash: md5.New(), expError: false},
hash: md5.New(), {name: "plain", in: domain.CodeChallengeMethodPLAIN, hash: nil, expError: false},
expError: true, {name: "S1", in: domain.CodeChallengeMethodS1, hash: sha1.New(), expError: false},
}, { {name: "S256", in: domain.CodeChallengeMethodS256, hash: sha256.New(), expError: false},
name: "MD5", {name: "S512", in: domain.CodeChallengeMethodS512, hash: sha512.New(), expError: false},
in: domain.CodeChallengeMethodMD5, {name: "undefined", in: domain.CodeChallengeMethodUndefined, hash: nil, expError: true},
hash: md5.New(), } {
expError: false,
}, {
name: "plain",
in: domain.CodeChallengeMethodPLAIN,
hash: nil,
expError: false,
}, {
name: "S1",
in: domain.CodeChallengeMethodS1,
hash: sha1.New(),
expError: false,
}, {
name: "S256",
in: domain.CodeChallengeMethodS256,
hash: sha256.New(),
expError: false,
}, {
name: "S512",
in: domain.CodeChallengeMethodS512,
hash: sha512.New(),
expError: false,
}, {
name: "undefined",
in: domain.CodeChallengeMethodUndefined,
hash: nil,
expError: true,
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -78,6 +78,7 @@ type (
) )
// TestConfig returns a valid config for tests. // TestConfig returns a valid config for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestConfig(tb testing.TB) *Config { func TestConfig(tb testing.TB) *Config {
tb.Helper() tb.Helper()

View File

@ -12,12 +12,14 @@ type Email struct {
subAddress string subAddress string
} }
const DefaultEmailPartsLength int = 2
var ErrEmailInvalid error = NewError(ErrorCodeInvalidRequest, "cannot parse email", "") var ErrEmailInvalid error = NewError(ErrorCodeInvalidRequest, "cannot parse email", "")
// ParseEmail parse strings to email identifier. // ParseEmail parse strings to email identifier.
func ParseEmail(src string) (*Email, error) { func ParseEmail(src string) (*Email, error) {
parts := strings.Split(strings.TrimPrefix(src, "mailto:"), "@") parts := strings.Split(strings.TrimPrefix(src, "mailto:"), "@")
if len(parts) != 2 { //nolint: gomnd if len(parts) != DefaultEmailPartsLength {
return nil, ErrEmailInvalid return nil, ErrEmailInvalid
} }
@ -27,7 +29,7 @@ func ParseEmail(src string) (*Email, error) {
subAddress: "", subAddress: "",
} }
if userParts := strings.SplitN(parts[0], `+`, 2); len(userParts) > 1 { if userParts := strings.SplitN(parts[0], `+`, DefaultEmailPartsLength); len(userParts) > 1 {
result.user = userParts[0] result.user = userParts[0]
result.subAddress = userParts[1] result.subAddress = userParts[1]
} }

View File

@ -14,19 +14,11 @@ func TestParseEmail(t *testing.T) {
name string name string
in string in string
out string out string
}{{ }{
name: "simple", {name: "simple", in: "user@example.com", out: "user@example.com"},
in: "user@example.com", {name: "subAddress", in: "user+suffix@example.com", out: "user+suffix@example.com"},
out: "user@example.com", {name: "mailto prefix", in: "mailto:user@example.com", out: "user@example.com"},
}, { } {
name: "subAddress",
in: "user+suffix@example.com",
out: "user+suffix@example.com",
}, {
name: "mailto prefix",
in: "mailto:user@example.com",
out: "user@example.com",
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -47,7 +39,7 @@ func TestParseEmail(t *testing.T) {
func TestEmail_String(t *testing.T) { func TestEmail_String(t *testing.T) {
t.Parallel() t.Parallel()
email := domain.TestEmail(t) email := domain.TestEmail(t) //nolint: ifshort
if result := email.String(); result != fmt.Sprint(email) { if result := email.String(); result != fmt.Sprint(email) {
t.Errorf("String() = %v, want %v", result, email) t.Errorf("String() = %v, want %v", result, email)
} }

View File

@ -10,18 +10,19 @@ import (
type ( type (
// Error describes the format of a typical IndieAuth error. // Error describes the format of a typical IndieAuth error.
//nolint: tagliatelle // RFC 6749 section 5.2
Error struct { Error struct {
// A single error code. // A single error code.
Code ErrorCode `json:"error"` Code ErrorCode `json:"error"`
// Human-readable ASCII text providing additional information, used to // Human-readable ASCII text providing additional information, used to
// assist the client developer in understanding the error that occurred. // assist the client developer in understanding the error that occurred.
Description string `json:"error_description,omitempty"` //nolint: tagliatelle Description string `json:"error_description,omitempty"`
// A URI identifying a human-readable web page with information about // A URI identifying a human-readable web page with information about
// the error, used to provide the client developer with additional // the error, used to provide the client developer with additional
// information about the error. // information about the error.
URI string `json:"error_uri,omitempty"` //nolint: tagliatelle URI string `json:"error_uri,omitempty"`
// REQUIRED if a "state" parameter was present in the client // REQUIRED if a "state" parameter was present in the client
// authorization request. The exact value received from the // authorization request. The exact value received from the
@ -33,17 +34,29 @@ type (
// ErrorCode represent error code described in RFC 6749. // ErrorCode represent error code described in RFC 6749.
ErrorCode struct { ErrorCode struct {
uid string uid string
status int
} }
) )
var ( var (
ErrorCodeUndefined ErrorCode = ErrorCode{uid: ""} // ErrorCodeUndefined describes an unrecognized error code.
ErrorCodeUndefined = ErrorCode{
uid: "",
status: 0,
}
// ErrorCodeAccessDenied describes the access_denied error code.
//
// RFC 6749 section 4.1.2.1: The resource owner or authorization server // RFC 6749 section 4.1.2.1: The resource owner or authorization server
// denied the request. // denied the request.
ErrorCodeAccessDenied ErrorCode = ErrorCode{uid: "access_denied"} ErrorCodeAccessDenied = ErrorCode{
uid: "access_denied",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidClient describes the invalid_client error code.
//
// RFC 6749 section 5.2: Client authentication failed (e.g., unknown // RFC 6749 section 5.2: Client authentication failed (e.g., unknown
// client, no client authentication included, or unsupported // client, no client authentication included, or unsupported
// authentication method). // authentication method).
@ -56,14 +69,26 @@ var (
// HTTP 401 (Unauthorized) status code and include the // HTTP 401 (Unauthorized) status code and include the
// "WWW-Authenticate" response header field matching the authentication // "WWW-Authenticate" response header field matching the authentication
// scheme used by the client. // scheme used by the client.
ErrorCodeInvalidClient ErrorCode = ErrorCode{uid: "invalid_client"} ErrorCodeInvalidClient = ErrorCode{
uid: "invalid_client",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidGrant describes the invalid_grant error code.
//
// RFC 6749 section 5.2: The provided authorization grant (e.g., // RFC 6749 section 5.2: The provided authorization grant (e.g.,
// authorization code, resource owner credentials) or refresh token is // authorization code, resource owner credentials) or refresh token is
// invalid, expired, revoked, does not match the redirection URI used in // invalid, expired, revoked, does not match the redirection URI used in
// the authorization request, or was issued to another client. // the authorization request, or was issued to another client.
ErrorCodeInvalidGrant ErrorCode = ErrorCode{uid: "invalid_grant"} ErrorCodeInvalidGrant = ErrorCode{
uid: "invalid_grant",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidRequest describes the invalid_request error code.
//
// IndieAuth: The request is not valid.
//
// RFC 6749 section 4.1.2.1: The request is missing a required // RFC 6749 section 4.1.2.1: The request is missing a required
// parameter, includes an invalid parameter value, includes a parameter // parameter, includes an invalid parameter value, includes a parameter
// more than once, or is otherwise malformed. // more than once, or is otherwise malformed.
@ -73,42 +98,91 @@ var (
// repeats a parameter, includes multiple credentials, utilizes more // repeats a parameter, includes multiple credentials, utilizes more
// than one mechanism for authenticating the client, or is otherwise // than one mechanism for authenticating the client, or is otherwise
// malformed. // malformed.
ErrorCodeInvalidRequest ErrorCode = ErrorCode{uid: "invalid_request"} ErrorCodeInvalidRequest = ErrorCode{
uid: "invalid_request",
status: http.StatusBadRequest,
}
// ErrorCodeInvalidScope describes the invalid_scope error code.
//
// RFC 6749 section 4.1.2.1: The requested scope is invalid, unknown, or // RFC 6749 section 4.1.2.1: The requested scope is invalid, unknown, or
// malformed. // malformed.
// //
// RFC 6749 section 5.2: The requested scope is invalid, unknown, // RFC 6749 section 5.2: The requested scope is invalid, unknown,
// malformed, or exceeds the scope granted by the resource owner. // malformed, or exceeds the scope granted by the resource owner.
ErrorCodeInvalidScope ErrorCode = ErrorCode{uid: "invalid_scope"} ErrorCodeInvalidScope = ErrorCode{
uid: "invalid_scope",
status: 0, // TODO(toby3d)
}
// ErrorCodeServerError describes the server_error error code.
//
// RFC 6749 section 4.1.2.1: The authorization server encountered an // RFC 6749 section 4.1.2.1: The authorization server encountered an
// unexpected condition that prevented it from fulfilling the request. // unexpected condition that prevented it from fulfilling the request.
// (This error code is needed because a 500 Internal Server Error HTTP // (This error code is needed because a 500 Internal Server Error HTTP
// status code cannot be returned to the client via an HTTP redirect.) // status code cannot be returned to the client via an HTTP redirect.)
ErrorCodeServerError ErrorCode = ErrorCode{uid: "server_error"} ErrorCodeServerError = ErrorCode{
uid: "server_error",
status: 0, // TODO(toby3d)
}
// ErrorCodeTemporarilyUnavailable describes the temporarily_unavailable error code.
//
// RFC 6749 section 4.1.2.1: The authorization server is currently // RFC 6749 section 4.1.2.1: The authorization server is currently
// unable to handle the request due to a temporary overloading or // unable to handle the request due to a temporary overloading or
// maintenance of the server. (This error code is needed because a 503 // maintenance of the server. (This error code is needed because a 503
// Service Unavailable HTTP status code cannot be returned to the client // Service Unavailable HTTP status code cannot be returned to the client
// via an HTTP redirect.) // via an HTTP redirect.)
ErrorCodeTemporarilyUnavailable ErrorCode = ErrorCode{uid: "temporarily_unavailable"} ErrorCodeTemporarilyUnavailable = ErrorCode{
uid: "temporarily_unavailable",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnauthorizedClient describes the unauthorized_client error code.
//
// RFC 6749 section 4.1.2.1: The client is not authorized to request an // RFC 6749 section 4.1.2.1: The client is not authorized to request an
// authorization code using this method. // authorization code using this method.
// //
// RFC 6749 section 5.2: The authenticated client is not authorized to // RFC 6749 section 5.2: The authenticated client is not authorized to
// use this authorization grant type. // use this authorization grant type.
ErrorCodeUnauthorizedClient ErrorCode = ErrorCode{uid: "unauthorized_client"} ErrorCodeUnauthorizedClient = ErrorCode{
uid: "unauthorized_client",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnsupportedGrantType describes the unsupported_grant_type error code.
//
// RFC 6749 section 5.2: The authorization grant type is not supported // RFC 6749 section 5.2: The authorization grant type is not supported
// by the authorization server. // by the authorization server.
ErrorCodeUnsupportedGrantType ErrorCode = ErrorCode{uid: "unsupported_grant_type"} ErrorCodeUnsupportedGrantType = ErrorCode{
uid: "unsupported_grant_type",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnsupportedResponseType describes the unsupported_response_type error code.
//
// RFC 6749 section 4.1.2.1: The authorization server does not support // RFC 6749 section 4.1.2.1: The authorization server does not support
// obtaining an authorization code using this method. // obtaining an authorization code using this method.
ErrorCodeUnsupportedResponseType ErrorCode = ErrorCode{uid: "unsupported_response_type"} ErrorCodeUnsupportedResponseType = ErrorCode{
uid: "unsupported_response_type",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidToken describes the invalid_token error code.
//
// IndieAuth: The access token provided is expired, revoked, or invalid.
ErrorCodeInvalidToken = ErrorCode{
uid: "invalid_token",
status: http.StatusUnauthorized,
}
// ErrorCodeInsufficientScope describes the insufficient_scope error code.
//
// IndieAuth: The request requires higher privileges than provided.
ErrorCodeInsufficientScope = ErrorCode{
uid: "insufficient_scope",
status: http.StatusForbidden,
}
) )
// String returns a string representation of the error code. // String returns a string representation of the error code.
@ -128,45 +202,45 @@ func (e Error) Error() string {
} }
// Format prints the stack as error detail. // Format prints the stack as error detail.
func (e Error) Format(s fmt.State, r rune) { func (e Error) Format(state fmt.State, r rune) {
xerrors.FormatError(e, s, r) xerrors.FormatError(e, state, r)
} }
// FormatError prints the receiver's error, if any. // FormatError prints the receiver's error, if any.
func (e Error) FormatError(p xerrors.Printer) error { func (e Error) FormatError(printer xerrors.Printer) error {
p.Print(e.Code) printer.Print(e.Code)
if e.Description != "" { if e.Description != "" {
p.Print(": ", e.Description) printer.Print(": ", e.Description)
} }
if !p.Detail() { if !printer.Detail() {
return nil return nil
} }
e.frame.Format(p) e.frame.Format(printer)
return nil return nil
} }
// SetReirectURI sets fasthttp.QueryArgs with the request state, code, // SetReirectURI sets fasthttp.QueryArgs with the request state, code,
// description and error URI in the provided fasthttp.URI. // description and error URI in the provided fasthttp.URI.
func (e Error) SetReirectURI(u *http.URI) { func (e Error) SetReirectURI(uri *http.URI) {
if u == nil { if uri == nil {
return return
} }
for k, v := range map[string]string{ for key, val := range map[string]string{
"error": e.Code.String(), "error": e.Code.String(),
"error_description": e.Description, "error_description": e.Description,
"error_uri": e.URI, "error_uri": e.URI,
"state": e.State, "state": e.State,
} { } {
if v == "" { if val == "" {
continue continue
} }
u.QueryArgs().Set(k, v) uri.QueryArgs().Set(key, val)
} }
} }

View File

@ -1,19 +1,21 @@
//nolint: dupl
package domain package domain
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
) )
// GrantType represent fixed grant_type parameter.
//
// NOTE(toby3d): Encapsulate enums in structs for extra compile-time safety: // NOTE(toby3d): Encapsulate enums in structs for extra compile-time safety:
// https://threedots.tech/post/safer-enums-in-go/#struct-based-enums // https://threedots.tech/post/safer-enums-in-go/#struct-based-enums
type GrantType struct { type GrantType struct {
uid string uid string
} }
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants //nolint: gochecknoglobals // structs cannot be constants
var ( var (
GrantTypeUndefined = GrantType{uid: ""} GrantTypeUndefined = GrantType{uid: ""}
GrantTypeAuthorizationCode = GrantType{uid: "authorization_code"} GrantTypeAuthorizationCode = GrantType{uid: "authorization_code"}
@ -22,7 +24,11 @@ var (
GrantTypeTicket = GrantType{uid: "ticket"} GrantTypeTicket = GrantType{uid: "ticket"}
) )
var ErrGrantTypeUnknown error = errors.New("unknown grant type") var ErrGrantTypeUnknown error = NewError(
ErrorCodeInvalidGrant,
"unknown grant type",
"",
)
// ParseGrantType parse grant_type value as GrantType struct enum. // ParseGrantType parse grant_type value as GrantType struct enum.
func ParseGrantType(uid string) (GrantType, error) { func ParseGrantType(uid string) (GrantType, error) {
@ -40,7 +46,7 @@ func ParseGrantType(uid string) (GrantType, error) {
func (gt *GrantType) UnmarshalForm(src []byte) error { func (gt *GrantType) UnmarshalForm(src []byte) error {
responseType, err := ParseGrantType(string(src)) responseType, err := ParseGrantType(string(src))
if err != nil { if err != nil {
return fmt.Errorf("grant_type: %w", err) return fmt.Errorf("UnmarshalForm: %w", err)
} }
*gt = responseType *gt = responseType
@ -52,12 +58,12 @@ func (gt *GrantType) UnmarshalForm(src []byte) error {
func (gt *GrantType) UnmarshalJSON(v []byte) error { func (gt *GrantType) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v)) src, err := strconv.Unquote(string(v))
if err != nil { if err != nil {
return err return fmt.Errorf("UnmarshalJSON: %w", err)
} }
responseType, err := ParseGrantType(src) responseType, err := ParseGrantType(src)
if err != nil { if err != nil {
return fmt.Errorf("grant_type: %w", err) return fmt.Errorf("UnmarshalJSON: %w", err)
} }
*gt = responseType *gt = responseType

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test package domain_test
import ( import (
@ -12,13 +13,10 @@ func TestParseGrantType(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
in string in string
out domain.GrantType out domain.GrantType
}{{ }{
in: "authorization_code", {in: "authorization_code", out: domain.GrantTypeAuthorizationCode},
out: domain.GrantTypeAuthorizationCode, {in: "ticket", out: domain.GrantTypeTicket},
}, { } {
in: "ticket",
out: domain.GrantTypeTicket,
}} {
tc := tc tc := tc
t.Run(tc.in, func(t *testing.T) { t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestGrantType_String(t *testing.T) {
name string name string
in domain.GrantType in domain.GrantType
out string out string
}{{ }{
name: "authorization_code", {name: "authorization_code", in: domain.GrantTypeAuthorizationCode, out: "authorization_code"},
in: domain.GrantTypeAuthorizationCode, {name: "ticket", in: domain.GrantTypeTicket, out: "ticket"},
out: "authorization_code", } {
}, {
name: "ticket",
in: domain.GrantTypeTicket,
out: "ticket",
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -18,7 +18,7 @@ type Me struct {
} }
// ParseMe parse string as me URL identifier. // ParseMe parse string as me URL identifier.
//nolint: funlen //nolint: funlen, cyclop
func ParseMe(raw string) (*Me, error) { func ParseMe(raw string) (*Me, error) {
me := http.AcquireURI() me := http.AcquireURI()
if err := me.Parse(nil, []byte(raw)); err != nil { if err := me.Parse(nil, []byte(raw)); err != nil {
@ -114,7 +114,7 @@ func TestMe(tb testing.TB, src string) *Me {
me, err := ParseMe(src) me, err := ParseMe(src)
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
return me return me
@ -174,12 +174,16 @@ func (m Me) URL() *url.URL {
} }
return &url.URL{ return &url.URL{
Scheme: string(m.me.Scheme()), ForceQuery: false,
Host: string(m.me.Host()), Fragment: string(m.me.Hash()),
Path: string(m.me.Path()), Host: string(m.me.Host()),
RawPath: string(m.me.PathOriginal()), Opaque: "",
RawQuery: string(m.me.QueryString()), Path: string(m.me.Path()),
Fragment: string(m.me.Hash()), RawFragment: "",
RawPath: string(m.me.PathOriginal()),
RawQuery: string(m.me.QueryString()),
Scheme: string(m.me.Scheme()),
User: nil,
} }
} }

View File

@ -7,7 +7,6 @@ import (
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
) )
//nolint: funlen
func TestParseMe(t *testing.T) { func TestParseMe(t *testing.T) {
t.Parallel() t.Parallel()
@ -15,47 +14,18 @@ func TestParseMe(t *testing.T) {
name string name string
in string in string
expError bool expError bool
}{{ }{
name: "valid", {name: "valid", in: "https://example.com/", expError: false},
in: "https://example.com/", {name: "valid path", in: "https://example.com/username", expError: false},
expError: false, {name: "valid query", in: "https://example.com/users?id=100", expError: false},
}, { {name: "missing scheme", in: "example.com", expError: true},
name: "valid path", {name: "invalid scheme", in: "mailto:user@example.com", expError: true},
in: "https://example.com/username", {name: "contains double-dot path", in: "https://example.com/foo/../bar", expError: true},
expError: false, {name: "contains fragment", in: "https://example.com/#me", expError: true},
}, { {name: "contains user", in: "https://user:pass@example.com/", expError: true},
name: "valid query", {name: "contains port", in: "https://example.com:8443/", expError: true},
in: "https://example.com/users?id=100", {name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
expError: false, } {
}, {
name: "missing scheme",
in: "example.com",
expError: true,
}, {
name: "invalid scheme",
in: "mailto:user@example.com",
expError: true,
}, {
name: "contains double-dot path",
in: "https://example.com/foo/../bar",
expError: true,
}, {
name: "contains fragment",
in: "https://example.com/#me",
expError: true,
}, {
name: "contains user",
in: "https://user:pass@example.com/",
expError: true,
}, {
name: "contains port",
in: "https://example.com:8443/",
expError: true,
}, {
name: "host is an IP address",
in: "https://172.28.92.51/",
expError: true,
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -1,6 +1,6 @@
package domain package domain
//nolint: tagliatelle //nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
type Metadata struct { type Metadata struct {
// The server's issuer identifier. The issuer identifier is a URL that // The server's issuer identifier. The issuer identifier is a URL that
// uses the "https" scheme and has no query or fragment components. The // uses the "https" scheme and has no query or fragment components. The

View File

@ -21,89 +21,89 @@ type Provider struct {
URL string URL string
} }
//nolint: gochecknoglobals //nolint: gochecknoglobals // structs cannot be contants
var ( var (
ProviderDirect = Provider{ ProviderDirect = Provider{
AuthURL: "/authorize", AuthURL: "/authorize",
Name: "IndieAuth", ClientID: "",
Photo: path.Join("static", "icon.svg"), ClientSecret: "",
RedirectURL: path.Join("callback"), Name: "IndieAuth",
Scopes: []string{}, Photo: path.Join("static", "icon.svg"),
TokenURL: "/token", RedirectURL: path.Join("callback"),
UID: "direct", Scopes: []string{},
URL: "/", TokenURL: "/token",
UID: "direct",
URL: "/",
} }
ProviderTwitter = Provider{ ProviderTwitter = Provider{
AuthURL: "https://twitter.com/i/oauth2/authorize", AuthURL: "https://twitter.com/i/oauth2/authorize",
Name: "Twitter", ClientID: "",
Photo: path.Join("static", "providers", "twitter.svg"), ClientSecret: "",
RedirectURL: path.Join("callback", "twitter"), Name: "Twitter",
Scopes: []string{ Photo: path.Join("static", "providers", "twitter.svg"),
"tweet.read", RedirectURL: path.Join("callback", "twitter"),
"users.read", Scopes: []string{"tweet.read", "users.read"},
}, TokenURL: "https://api.twitter.com/2/oauth2/token",
TokenURL: "https://api.twitter.com/2/oauth2/token", UID: "twitter",
UID: "twitter", URL: "https://twitter.com/",
URL: "https://twitter.com/",
} }
ProviderGitHub = Provider{ ProviderGitHub = Provider{
AuthURL: "https://github.com/login/oauth/authorize", AuthURL: "https://github.com/login/oauth/authorize",
Name: "GitHub", ClientID: "",
Photo: path.Join("static", "providers", "github.svg"), ClientSecret: "",
RedirectURL: path.Join("callback", "github"), Name: "GitHub",
Scopes: []string{ Photo: path.Join("static", "providers", "github.svg"),
"read:user", RedirectURL: path.Join("callback", "github"),
"user:email", Scopes: []string{"read:user", "user:email"},
}, TokenURL: "https://github.com/login/oauth/access_token",
TokenURL: "https://github.com/login/oauth/access_token", UID: "github",
UID: "github", URL: "https://github.com/",
URL: "https://github.com/",
} }
ProviderGitLab = Provider{ ProviderGitLab = Provider{
AuthURL: "https://gitlab.com/oauth/authorize", AuthURL: "https://gitlab.com/oauth/authorize",
Name: "GitLab", ClientID: "",
Photo: path.Join("static", "providers", "gitlab.svg"), ClientSecret: "",
RedirectURL: path.Join("callback", "gitlab"), Name: "GitLab",
Scopes: []string{ Photo: path.Join("static", "providers", "gitlab.svg"),
"read_user", RedirectURL: path.Join("callback", "gitlab"),
}, Scopes: []string{"read_user"},
TokenURL: "https://gitlab.com/oauth/token", TokenURL: "https://gitlab.com/oauth/token",
UID: "gitlab", UID: "gitlab",
URL: "https://gitlab.com/", URL: "https://gitlab.com/",
} }
ProviderMastodon = Provider{ ProviderMastodon = Provider{
AuthURL: "https://mstdn.io/oauth/authorize", AuthURL: "https://mstdn.io/oauth/authorize",
Name: "Mastodon", ClientID: "",
Photo: path.Join("static", "providers", "mastodon.svg"), ClientSecret: "",
RedirectURL: path.Join("callback", "mastodon"), Name: "Mastodon",
Scopes: []string{ Photo: path.Join("static", "providers", "mastodon.svg"),
"read:accounts", RedirectURL: path.Join("callback", "mastodon"),
}, Scopes: []string{"read:accounts"},
TokenURL: "https://mstdn.io/oauth/token", TokenURL: "https://mstdn.io/oauth/token",
UID: "mastodon", UID: "mastodon",
URL: "https://mstdn.io/", URL: "https://mstdn.io/",
} }
) )
// AuthCodeURL returns URL for authorize user in RelMeAuth client. // AuthCodeURL returns URL for authorize user in RelMeAuth client.
func (p Provider) AuthCodeURL(state string) string { func (p Provider) AuthCodeURL(state string) string {
u := http.AcquireURI() uri := http.AcquireURI()
defer http.ReleaseURI(u) defer http.ReleaseURI(uri)
u.Update(p.AuthURL) uri.Update(p.AuthURL)
for k, v := range map[string]string{ for key, val := range map[string]string{
"client_id": p.ClientID, "client_id": p.ClientID,
"redirect_uri": p.RedirectURL, "redirect_uri": p.RedirectURL,
"response_type": "code", "response_type": "code",
"scope": strings.Join(p.Scopes, " "), "scope": strings.Join(p.Scopes, " "),
"state": state, "state": state,
} { } {
u.QueryArgs().Set(k, v) uri.QueryArgs().Set(key, val)
} }
return u.String() return uri.String()
} }

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain package domain
import ( import (
@ -12,20 +13,20 @@ type ResponseType struct {
uid string uid string
} }
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants //nolint: gochecknoglobals // structs cannot be constants
var ( var (
ResponseTypeUndefined ResponseType = ResponseType{uid: ""} ResponseTypeUndefined = ResponseType{uid: ""}
// Deprecated(toby3d): Only accept response_type=code requests, and for // Deprecated(toby3d): Only accept response_type=code requests, and for
// backwards-compatible support, treat response_type=id requests as // backwards-compatible support, treat response_type=id requests as
// response_type=code requests: // response_type=code requests:
// https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type // https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
ResponseTypeID ResponseType = ResponseType{uid: "id"} ResponseTypeID = ResponseType{uid: "id"}
// Indicates to the authorization server that an authorization code // Indicates to the authorization server that an authorization code
// should be returned as the response: // should be returned as the response:
// https://indieauth.net/source/#authorization-request-li-1 // https://indieauth.net/source/#authorization-request-li-1
ResponseTypeCode ResponseType = ResponseType{uid: "code"} ResponseTypeCode = ResponseType{uid: "code"}
) )
var ErrResponseTypeUnknown error = NewError( var ErrResponseTypeUnknown error = NewError(
@ -65,7 +66,7 @@ func (rt *ResponseType) UnmarshalJSON(v []byte) error {
return fmt.Errorf("UnmarshalJSON: %w", err) return fmt.Errorf("UnmarshalJSON: %w", err)
} }
responseType, err := ParseResponseType(string(uid)) responseType, err := ParseResponseType(uid)
if err != nil { if err != nil {
return fmt.Errorf("UnmarshalJSON: %w", err) return fmt.Errorf("UnmarshalJSON: %w", err)
} }

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test package domain_test
import ( import (
@ -12,13 +13,10 @@ func TestResponseTypeType(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
in string in string
out domain.ResponseType out domain.ResponseType
}{{ }{
in: "id", {in: "id", out: domain.ResponseTypeID},
out: domain.ResponseTypeID, {in: "code", out: domain.ResponseTypeCode},
}, { } {
in: "code",
out: domain.ResponseTypeCode,
}} {
tc := tc tc := tc
t.Run(tc.in, func(t *testing.T) { t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestResponseType_String(t *testing.T) {
name string name string
in domain.ResponseType in domain.ResponseType
out string out string
}{{ }{
name: "id", {name: "id", in: domain.ResponseTypeID, out: "id"},
in: domain.ResponseTypeID, {name: "code", in: domain.ResponseTypeCode, out: "code"},
out: "id", } {
}, {
name: "code",
in: domain.ResponseTypeCode,
out: "code",
}} {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {

View File

@ -21,7 +21,7 @@ type (
var ErrScopeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown scope", "https://indieweb.org/scope") var ErrScopeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown scope", "https://indieweb.org/scope")
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants //nolint: gochecknoglobals // structs cannot be constants
var ( var (
ScopeUndefined = Scope{uid: ""} ScopeUndefined = Scope{uid: ""}
@ -56,7 +56,7 @@ var (
ScopeEmail = Scope{uid: "email"} ScopeEmail = Scope{uid: "email"}
) )
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants //nolint: gochecknoglobals // maps cannot be constants
var uidsScopes = map[string]Scope{ var uidsScopes = map[string]Scope{
ScopeBlock.uid: ScopeBlock, ScopeBlock.uid: ScopeBlock,
ScopeChannels.uid: ScopeChannels, ScopeChannels.uid: ScopeChannels,
@ -112,17 +112,17 @@ func (s *Scopes) UnmarshalJSON(v []byte) error {
result := make(Scopes, 0) result := make(Scopes, 0)
for _, scope := range strings.Fields(src) { for _, rawScope := range strings.Fields(src) {
s, err := ParseScope(scope) scope, err := ParseScope(rawScope)
if err != nil { if err != nil {
return fmt.Errorf("UnmarshalJSON: %w", err) return fmt.Errorf("UnmarshalJSON: %w", err)
} }
if result.Has(s) { if result.Has(scope) {
continue continue
} }
result = append(result, s) result = append(result, scope)
} }
*s = result *s = result

View File

@ -14,43 +14,20 @@ func TestParseScope(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
in string in string
out domain.Scope out domain.Scope
}{{ }{
in: "create", {in: "create", out: domain.ScopeCreate},
out: domain.ScopeCreate, {in: "delete", out: domain.ScopeDelete},
}, { {in: "draft", out: domain.ScopeDraft},
in: "delete", {in: "media", out: domain.ScopeMedia},
out: domain.ScopeDelete, {in: "update", out: domain.ScopeUpdate},
}, { {in: "block", out: domain.ScopeBlock},
in: "draft", {in: "channels", out: domain.ScopeChannels},
out: domain.ScopeDraft, {in: "follow", out: domain.ScopeFollow},
}, { {in: "mute", out: domain.ScopeMute},
in: "media", {in: "read", out: domain.ScopeRead},
out: domain.ScopeMedia, {in: "profile", out: domain.ScopeProfile},
}, { {in: "email", out: domain.ScopeEmail},
in: "update", } {
out: domain.ScopeUpdate,
}, {
in: "block",
out: domain.ScopeBlock,
}, {
in: "channels",
out: domain.ScopeChannels,
}, {
in: "follow",
out: domain.ScopeFollow,
}, {
in: "mute",
out: domain.ScopeMute,
}, {
in: "read",
out: domain.ScopeRead,
}, {
in: "profile",
out: domain.ScopeProfile,
}, {
in: "email",
out: domain.ScopeEmail,
}} {
tc := tc tc := tc
t.Run(tc.in, func(t *testing.T) { t.Run(tc.in, func(t *testing.T) {
@ -118,47 +95,24 @@ func TestScopes_MarshalJSON(t *testing.T) {
func TestScope_String(t *testing.T) { func TestScope_String(t *testing.T) {
t.Parallel() t.Parallel()
//nolint: paralleltest // NOTE(toby3d): false positive, tc.in is used. //nolint: paralleltest // false positive, in is used
for _, tc := range []struct { for _, tc := range []struct {
in domain.Scope in domain.Scope
out string out string
}{{ }{
in: domain.ScopeCreate, {in: domain.ScopeCreate, out: "create"},
out: "create", {in: domain.ScopeDelete, out: "delete"},
}, { {in: domain.ScopeDraft, out: "draft"},
in: domain.ScopeDelete, {in: domain.ScopeMedia, out: "media"},
out: "delete", {in: domain.ScopeUpdate, out: "update"},
}, { {in: domain.ScopeBlock, out: "block"},
in: domain.ScopeDraft, {in: domain.ScopeChannels, out: "channels"},
out: "draft", {in: domain.ScopeFollow, out: "follow"},
}, { {in: domain.ScopeMute, out: "mute"},
in: domain.ScopeMedia, {in: domain.ScopeRead, out: "read"},
out: "media", {in: domain.ScopeProfile, out: "profile"},
}, { {in: domain.ScopeEmail, out: "email"},
in: domain.ScopeUpdate, } {
out: "update",
}, {
in: domain.ScopeBlock,
out: "block",
}, {
in: domain.ScopeChannels,
out: "channels",
}, {
in: domain.ScopeFollow,
out: "follow",
}, {
in: domain.ScopeMute,
out: "mute",
}, {
in: domain.ScopeRead,
out: "read",
}, {
in: domain.ScopeProfile,
out: "profile",
}, {
in: domain.ScopeEmail,
out: "email",
}} {
tc := tc tc := tc
t.Run(fmt.Sprint(tc.in), func(t *testing.T) { t.Run(fmt.Sprint(tc.in), func(t *testing.T) {

View File

@ -8,22 +8,22 @@ import (
type Session struct { type Session struct {
ClientID *ClientID ClientID *ClientID
Me *Me
RedirectURI *URL RedirectURI *URL
Profile *Profile Me *Me
CodeChallengeMethod CodeChallengeMethod CodeChallengeMethod CodeChallengeMethod
Scope Scopes Scope Scopes
Code string
CodeChallenge string CodeChallenge string
Code string
} }
// TestSession returns valid random generated session for tests. // TestSession returns valid random generated session for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestSession(tb testing.TB) *Session { func TestSession(tb testing.TB) *Session {
tb.Helper() tb.Helper()
code, err := random.String(24) code, err := random.String(24)
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
return &Session{ return &Session{

View File

@ -33,13 +33,20 @@ type (
} }
) )
//nolint: gochecknoglobals // DefaultNewTokenOptions describes the default settings for NewToken.
//nolint: gochecknoglobals, gomnd
var DefaultNewTokenOptions = NewTokenOptions{ var DefaultNewTokenOptions = NewTokenOptions{
Algorithm: "HS256", Algorithm: "HS256",
Expiration: 0,
Issuer: nil,
NonceLength: 32, NonceLength: 32,
Scope: nil,
Secret: nil,
Subject: nil,
} }
// NewToken create a new token by provided options. // NewToken create a new token by provided options.
//nolint: cyclop
func NewToken(opts NewTokenOptions) (*Token, error) { func NewToken(opts NewTokenOptions) (*Token, error) {
if opts.NonceLength == 0 { if opts.NonceLength == 0 {
opts.NonceLength = DefaultNewTokenOptions.NonceLength opts.NonceLength = DefaultNewTokenOptions.NonceLength
@ -56,22 +63,33 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
return nil, fmt.Errorf("cannot generate nonce: %w", err) return nil, fmt.Errorf("cannot generate nonce: %w", err)
} }
t := jwt.New() tkn := jwt.New()
t.Set(jwt.SubjectKey, opts.Subject.String())
t.Set(jwt.NotBeforeKey, now) for key, val := range map[string]interface{}{
t.Set(jwt.IssuedAtKey, now) "nonce": nonce,
t.Set("scope", opts.Scope) "scope": opts.Scope,
t.Set("nonce", nonce) jwt.IssuedAtKey: now,
jwt.NotBeforeKey: now,
jwt.SubjectKey: opts.Subject.String(),
} {
if err = tkn.Set(key, val); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
}
if opts.Issuer != nil { if opts.Issuer != nil {
t.Set(jwt.IssuerKey, opts.Issuer.String()) if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
} }
if opts.Expiration != 0 { if opts.Expiration != 0 {
t.Set(jwt.ExpirationKey, now.Add(opts.Expiration)) if err = tkn.Set(jwt.ExpirationKey, now.Add(opts.Expiration)); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
} }
accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret) accessToken, err := jwt.Sign(tkn, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot sign a new access token: %w", err) return nil, fmt.Errorf("cannot sign a new access token: %w", err)
} }
@ -81,19 +99,20 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
ClientID: opts.Issuer, ClientID: opts.Issuer,
Me: opts.Subject, Me: opts.Subject,
Scope: opts.Scope, Scope: opts.Scope,
}, err }, nil
} }
// TestToken returns valid random generated token for tests. // TestToken returns valid random generated token for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestToken(tb testing.TB) *Token { func TestToken(tb testing.TB) *Token {
tb.Helper() tb.Helper()
nonce, err := random.String(22) nonce, err := random.String(22)
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
t := jwt.New() tkn := jwt.New()
cid := TestClientID(tb) cid := TestClientID(tb)
me := TestMe(tb, "https://user.example.net/") me := TestMe(tb, "https://user.example.net/")
now := time.Now().UTC().Round(time.Second) now := time.Now().UTC().Round(time.Second)
@ -103,22 +122,25 @@ func TestToken(tb testing.TB) *Token {
ScopeUpdate, ScopeUpdate,
} }
// NOTE(toby3d): required for key, val := range map[string]interface{}{
t.Set(jwt.IssuerKey, cid.String()) // NOTE(toby3d): required
t.Set(jwt.SubjectKey, me.String()) jwt.IssuerKey: cid.String(),
// TODO(toby3d): t.Set(jwt.AudienceKey, nil) jwt.SubjectKey: me.String(),
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour)) jwt.ExpirationKey: now.Add(1 * time.Hour),
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour)) jwt.NotBeforeKey: now.Add(-1 * time.Hour),
t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) jwt.IssuedAtKey: now.Add(-1 * time.Hour),
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil) // TODO(toby3d): jwt.AudienceKey
// TODO(toby3d): jwt.JwtIDKey
// NOTE(toby3d): optional
"scope": scope,
"nonce": nonce,
} {
_ = tkn.Set(key, val)
}
// optional accessToken, err := jwt.Sign(tkn, jwa.HS256, []byte("hackme"))
t.Set("scope", scope)
t.Set("nonce", nonce)
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
return &Token{ return &Token{
@ -144,5 +166,5 @@ func (t Token) String() string {
return "" return ""
} }
return "Bearer " + string(t.AccessToken) return "Bearer " + t.AccessToken
} }

View File

@ -71,7 +71,7 @@ func (u URL) URL() *url.URL {
return nil return nil
} }
result, err := url.ParseRequestURI(u.URI.String()) result, err := url.ParseRequestURI(u.String())
if err != nil { if err != nil {
return nil return nil
} }

View File

@ -19,32 +19,32 @@ func TestParseURL(t *testing.T) {
func TestURL_UnmarshalForm(t *testing.T) { func TestURL_UnmarshalForm(t *testing.T) {
t.Parallel() t.Parallel()
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me") url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprint(u)) input := []byte(fmt.Sprint(url))
result := new(domain.URL) result := new(domain.URL)
if err := result.UnmarshalForm(input); err != nil { if err := result.UnmarshalForm(input); err != nil {
t.Fatalf("%+v", err) t.Fatalf("%+v", err)
} }
if fmt.Sprint(result) != fmt.Sprint(u) { if fmt.Sprint(result) != fmt.Sprint(url) {
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, u) t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, url)
} }
} }
func TestURL_UnmarshalJSON(t *testing.T) { func TestURL_UnmarshalJSON(t *testing.T) {
t.Parallel() t.Parallel()
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me") url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprintf(`"%s"`, u)) input := []byte(fmt.Sprintf(`"%s"`, url))
result := new(domain.URL) result := new(domain.URL)
if err := result.UnmarshalJSON(input); err != nil { if err := result.UnmarshalJSON(input); err != nil {
t.Fatalf("%+v", err) t.Fatalf("%+v", err)
} }
if fmt.Sprint(result) != fmt.Sprint(u) { if fmt.Sprint(result) != fmt.Sprint(url) {
t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, u) t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, url)
} }
} }

View File

@ -11,7 +11,7 @@ import (
) )
type ( type (
//nolint: tagliatelle //nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
MetadataResponse struct { MetadataResponse struct {
// The server's issuer identifier. The issuer identifier is a // The server's issuer identifier. The issuer identifier is a
// URL that uses the "https" scheme and has no query or fragment // URL that uses the "https" scheme and has no query or fragment
@ -72,14 +72,11 @@ type (
} }
) )
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be contants. // DefaultMetadataResponse contains all supported types by default.
//nolint: gochecknoglobals // structs cannot be constants
var DefaultMetadataResponse = MetadataResponse{ var DefaultMetadataResponse = MetadataResponse{
ServiceDocumentation: "https://indieauth.net/source/", AuthorizationEndpoint: "",
AuthorizationResponseIssParameterSupported: true, AuthorizationResponseIssParameterSupported: true,
ScopesSupported: []string{
domain.ScopeEmail.String(),
domain.ScopeProfile.String(),
},
CodeChallengeMethodsSupported: []string{ CodeChallengeMethodsSupported: []string{
domain.CodeChallengeMethodMD5.String(), domain.CodeChallengeMethodMD5.String(),
domain.CodeChallengeMethodPLAIN.String(), domain.CodeChallengeMethodPLAIN.String(),
@ -87,13 +84,21 @@ var DefaultMetadataResponse = MetadataResponse{
domain.CodeChallengeMethodS256.String(), domain.CodeChallengeMethodS256.String(),
domain.CodeChallengeMethodS512.String(), domain.CodeChallengeMethodS512.String(),
}, },
GrantTypesSupported: []string{
domain.GrantTypeAuthorizationCode.String(),
domain.GrantTypeTicket.String(),
},
Issuer: "",
ResponseTypesSupported: []string{ ResponseTypesSupported: []string{
domain.ResponseTypeCode.String(), domain.ResponseTypeCode.String(),
domain.ResponseTypeID.String(), domain.ResponseTypeID.String(),
}, },
GrantTypesSupported: []string{ ScopesSupported: []string{
domain.GrantTypeAuthorizationCode.String(), domain.ScopeEmail.String(),
domain.ScopeProfile.String(),
}, },
ServiceDocumentation: "https://indieauth.net/source/",
TokenEndpoint: "",
} }
func NewRequestHandler(config *domain.Config) *RequestHandler { func NewRequestHandler(config *domain.Config) *RequestHandler {
@ -118,5 +123,5 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
ctx.SetStatusCode(http.StatusOK) ctx.SetStatusCode(http.StatusOK)
ctx.SetContentType(common.MIMEApplicationJSON) ctx.SetContentType(common.MIMEApplicationJSON)
json.NewEncoder(ctx).Encode(&resp) _ = json.NewEncoder(ctx).Encode(&resp)
} }

View File

@ -19,6 +19,7 @@ func NewGithubProfileRepository() profile.Repository {
return &githubProfileRepository{} return &githubProfileRepository{}
} }
//nolint: cyclop
func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) { func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
user, _, err := github.NewClient(oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))).Users.Get(ctx, "") user, _, err := github.NewClient(oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))).Users.Get(ctx, "")
if err != nil { if err != nil {
@ -41,6 +42,7 @@ func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Toke
// NOTE(toby3d): Profile URLs. // NOTE(toby3d): Profile URLs.
result.URL = make([]*domain.URL, 0) result.URL = make([]*domain.URL, 0)
var twitterURL *string var twitterURL *string
if user.TwitterUsername != nil { if user.TwitterUsername != nil {

View File

@ -19,7 +19,7 @@ func NewGitlabProfileRepository() profile.Repository {
return &gitlabProfileRepository{} return &gitlabProfileRepository{}
} }
//nolint: funlen // NOTE(toby3d): uses hyphenation on new lines for readability. //nolint: funlen, cyclop
func (repo *gitlabProfileRepository) Get(_ context.Context, token *oauth2.Token) (*domain.Profile, error) { func (repo *gitlabProfileRepository) Get(_ context.Context, token *oauth2.Token) (*domain.Profile, error) {
client, err := gitlab.NewClient(token.AccessToken) client, err := gitlab.NewClient(token.AccessToken)
if err != nil { if err != nil {

View File

@ -25,8 +25,10 @@ func NewMastodonProfileRepository(server string) profile.Repository {
func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) { func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
account, err := mastodon.NewClient(&mastodon.Config{ account, err := mastodon.NewClient(&mastodon.Config{
Server: repo.server, AccessToken: token.AccessToken,
AccessToken: token.AccessToken, ClientID: "",
ClientSecret: "",
Server: repo.server,
}).GetAccountCurrentUser(ctx) }).GetAccountCurrentUser(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s: cannot get account info: %w", ErrPrefix, err) return nil, fmt.Errorf("%s: cannot get account info: %w", ErrPrefix, err)
@ -64,11 +66,9 @@ func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.To
} }
u, err := domain.ParseURL(account.Fields[i].Value) u, err := domain.ParseURL(account.Fields[i].Value)
if err != nil { if err == nil {
continue result.URL = append(result.URL, u)
} }
result.URL = append(result.URL, u)
} }
// WARN(toby3d): Mastodon does not provide an email via API. // WARN(toby3d): Mastodon does not provide an email via API.

View File

@ -2,6 +2,7 @@ package random
import ( import (
"crypto/rand" "crypto/rand"
"fmt"
"math/big" "math/big"
"strings" "strings"
) )
@ -17,32 +18,31 @@ const (
) )
func Bytes(length int) ([]byte, error) { func Bytes(length int) ([]byte, error) {
b := make([]byte, length) bytes := make([]byte, length)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(bytes); err != nil {
return nil, err return nil, fmt.Errorf("cannot read bytes: %w", err)
} }
return b, nil return bytes, nil
} }
func String(length int, charsets ...string) (string, error) { func String(length int, charsets ...string) (string, error) {
charset := strings.Join(charsets, "") charset := strings.Join(charsets, "")
if charset == "" { if charset == "" {
charset = Alphabetic charset = Alphabetic
} }
b := make([]byte, length) bytes := make([]byte, length)
for i := range b { for i := range bytes {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("failed to randomize bytes: %w", err)
} }
b[i] = charset[n.Int64()] bytes[i] = charset[n.Int64()]
} }
return string(b), nil return string(bytes), nil
} }

View File

@ -27,8 +27,7 @@ type (
} }
sqlite3SessionRepository struct { sqlite3SessionRepository struct {
config *domain.Config db *sqlx.DB
db *sqlx.DB
} }
) )
@ -57,7 +56,7 @@ const (
WHERE code=$1;` WHERE code=$1;`
) )
func NewSQLite3SessionRepository(config *domain.Config, db *sqlx.DB) session.Repository { func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository {
db.MustExec(QueryTable) db.MustExec(QueryTable)
return &sqlite3SessionRepository{ return &sqlite3SessionRepository{
@ -74,7 +73,7 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai
} }
func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) { func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session) s := new(Session) //nolint: varnamelen // cannot redaclare import
if err := repo.db.GetContext(ctx, s, QueryGet, code); err != nil { if err := repo.db.GetContext(ctx, s, QueryGet, code); err != nil {
return nil, fmt.Errorf("cannot find session in db: %w", err) return nil, fmt.Errorf("cannot find session in db: %w", err)
} }
@ -86,16 +85,17 @@ func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*do
} }
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) { func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session) s := new(Session) //nolint: varnamelen // cannot redaclare import
tx, err := repo.db.Beginx() tx, err := repo.db.Beginx()
if err != nil { if err != nil {
tx.Rollback() _ = tx.Rollback()
return nil, fmt.Errorf("failed to begin transaction: %w", err) return nil, fmt.Errorf("failed to begin transaction: %w", err)
} }
if err = tx.GetContext(ctx, s, QueryGet, code); err != nil { if err = tx.GetContext(ctx, s, QueryGet, code); err != nil {
//nolint: errcheck // deffered method
defer tx.Rollback() defer tx.Rollback()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -106,7 +106,7 @@ func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code str
} }
if _, err = tx.ExecContext(ctx, QueryDelete, code); err != nil { if _, err = tx.ExecContext(ctx, QueryDelete, code); err != nil {
tx.Rollback() _ = tx.Rollback()
return nil, fmt.Errorf("cannot remove session from db: %w", err) return nil, fmt.Errorf("cannot remove session from db: %w", err)
} }

View File

@ -12,10 +12,9 @@ import (
"source.toby3d.me/website/indieauth/internal/testing/sqltest" "source.toby3d.me/website/indieauth/internal/testing/sqltest"
) )
//nolint: gochecknoglobals //nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{ var tableColumns = []string{
"created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code", "created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code", "code_challenge",
"code_challenge",
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
@ -40,7 +39,7 @@ func TestCreate(t *testing.T) {
). ).
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). if err := repository.NewSQLite3SessionRepository(db).
Create(context.TODO(), session); err != nil { Create(context.TODO(), session); err != nil {
t.Error(err) t.Error(err)
} }
@ -69,7 +68,7 @@ func TestGet(t *testing.T) {
model.CodeChallenge, model.CodeChallenge,
)) ))
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). result, err := repository.NewSQLite3SessionRepository(db).
Get(context.TODO(), session.Code) Get(context.TODO(), session.Code)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -108,7 +107,7 @@ func TestGetAndDelete(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1)) WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit() mock.ExpectCommit()
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db). result, err := repository.NewSQLite3SessionRepository(db).
GetAndDelete(context.TODO(), session.Code) GetAndDelete(context.TODO(), session.Code)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -23,7 +23,6 @@ func New(tb testing.TB) (*bolt.DB, func()) {
db, err := bolt.Open(filePath, os.ModePerm, nil) db, err := bolt.Open(filePath, os.ModePerm, nil)
require.NoError(tb, err) require.NoError(tb, err)
//nolint: errcheck
return db, func() { return db, func() {
db.Close() db.Close()
os.Remove(filePath) os.Remove(filePath)

View File

@ -3,6 +3,8 @@ package httptest
import ( import (
"crypto/tls" "crypto/tls"
// used for running tests.
_ "embed" _ "embed"
"net" "net"
"testing" "testing"
@ -47,9 +49,8 @@ func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server
}, },
} }
//nolint: errcheck
return client, server, func() { return client, server, func() {
server.Shutdown() _ = server.Shutdown()
} }
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
) )
type Time struct{} type Time struct{}
@ -24,17 +23,17 @@ func Open(tb testing.TB) (*sqlx.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
if err != nil { if err != nil {
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
xdb := sqlx.NewDb(db, "sqlite") xdb := sqlx.NewDb(db, "sqlite")
if err = xdb.Ping(); err != nil { if err = xdb.Ping(); err != nil {
_ = db.Close() _ = db.Close()
tb.Fatalf("%+v", err) tb.Fatal(err)
} }
return xdb, mock, func() { return xdb, mock, func() {
_ = db.Close() //nolint: errcheck _ = db.Close()
} }
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/fasthttp/router" "github.com/fasthttp/router"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/jwa"
http "github.com/valyala/fasthttp" http "github.com/valyala/fasthttp"
"golang.org/x/text/language" "golang.org/x/text/language"
"golang.org/x/text/message" "golang.org/x/text/message"
@ -21,7 +22,7 @@ import (
) )
type ( type (
GenerateRequest struct { TicketGenerateRequest struct {
// The access token should be used when acting on behalf of this URL. // The access token should be used when acting on behalf of this URL.
Subject *domain.Me `form:"subject"` Subject *domain.Me `form:"subject"`
@ -29,7 +30,7 @@ type (
Resource *domain.URL `form:"resource"` Resource *domain.URL `form:"resource"`
} }
ExchangeRequest struct { TicketExchangeRequest struct {
// A random string that can be redeemed for an access token. // A random string that can be redeemed for an access token.
Ticket string `form:"ticket"` Ticket string `form:"ticket"`
@ -58,21 +59,40 @@ func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config
func (h *RequestHandler) Register(r *router.Router) { func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{ chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{ middleware.CSRFWithConfig(middleware.CSRFConfig{
CookieSameSite: http.CookieSameSiteLaxMode,
ContextKey: "csrf",
CookieName: "_csrf",
TokenLookup: "form:_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
Skipper: func(ctx *http.RequestCtx) bool { Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/ticket*", string(ctx.Path())) matched, _ := path.Match("/ticket*", string(ctx.Path()))
return ctx.IsPost() && matched return ctx.IsPost() && matched
}, },
CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteLaxMode,
ContextKey: "csrf",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "form:_csrf",
TokenLength: 0,
CookieSecure: true,
CookieHTTPOnly: true,
}),
middleware.JWTWithConfig(middleware.JWTConfig{
AuthScheme: "Bearer",
BeforeFunc: nil,
Claims: nil,
ContextKey: "user",
ErrorHandler: nil,
ErrorHandlerWithContext: nil,
ParseTokenFunc: nil,
SigningKey: []byte(h.config.JWT.Secret),
SigningKeys: nil,
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
Skipper: middleware.DefaultSkipper,
SuccessHandler: nil,
TokenLookup: middleware.SourceHeader + ":" + http.HeaderAuthorization,
}), }),
middleware.LogFmt(), middleware.LogFmt(),
} }
// TODO(toby3d): secure this via JWT middleware
r.GET("/ticket", chain.RequestHandler(h.handleRender)) r.GET("/ticket", chain.RequestHandler(h.handleRender))
r.POST("/api/ticket", chain.RequestHandler(h.handleSend)) r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
r.POST("/ticket", chain.RequestHandler(h.handleRedeem)) r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
@ -102,10 +122,11 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest) req := new(TicketGenerateRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
@ -119,14 +140,16 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
var err error var err error
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil { if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return return
} }
if err = h.tickets.Generate(ctx, ticket); err != nil { if err = h.tickets.Generate(ctx, ticket); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return return
} }
@ -140,10 +163,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest) req := new(TicketExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
@ -155,7 +179,8 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return return
} }
@ -170,7 +195,7 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
}`, token.AccessToken, token.Scope.String(), token.Me.String())) }`, token.AccessToken, token.Scope.String(), token.Me.String()))
} }
func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) { func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {
@ -203,7 +228,7 @@ func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
return nil return nil
} }
func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) { func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil { if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {

View File

@ -46,13 +46,13 @@ func TestUpdate(t *testing.T) {
userClient, _, userCleanup := httptest.New(t, userRouter.Handler) userClient, _, userCleanup := httptest.New(t, userRouter.Handler)
t.Cleanup(userCleanup) t.Cleanup(userCleanup)
r := router.New() router := router.New()
delivery.NewRequestHandler( delivery.NewRequestHandler(
ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config), ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config),
language.NewMatcher(message.DefaultCatalog.Languages()), config, language.NewMatcher(message.DefaultCatalog.Languages()), config,
).Register(r) ).Register(router)
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodPost, "https://example.com/ticket", []byte( req := httptest.NewRequest(http.MethodPost, "https://example.com/ticket", []byte(

View File

@ -59,8 +59,8 @@ func (repo *memoryTicketRepository) GC() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
for ts := range ticker.C { for timeStamp := range ticker.C {
ts := ts.UTC() timeStamp := timeStamp.UTC()
repo.store.Range(func(key, value interface{}) bool { repo.store.Range(func(key, value interface{}) bool {
k, ok := key.(string) k, ok := key.(string)
@ -78,7 +78,7 @@ func (repo *memoryTicketRepository) GC() {
return false return false
} }
if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) { if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) {
return false return false
} }

View File

@ -63,16 +63,17 @@ func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticke
return nil return nil
} }
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string) (*domain.Ticket, error) { func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, rawTicket string) (*domain.Ticket, error) {
tx, err := repo.db.Beginx() tx, err := repo.db.Beginx()
if err != nil { if err != nil {
tx.Rollback() _ = tx.Rollback()
return nil, fmt.Errorf("failed to begin transaction: %w", err) return nil, fmt.Errorf("failed to begin transaction: %w", err)
} }
tkt := new(Ticket) tkt := new(Ticket)
if err = tx.GetContext(ctx, tkt, QueryGet, t); err != nil { if err = tx.GetContext(ctx, tkt, QueryGet, rawTicket); err != nil {
//nolint: errcheck // deffered method
defer tx.Rollback() defer tx.Rollback()
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -82,8 +83,8 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string)
return nil, fmt.Errorf("cannot find ticket in db: %w", err) return nil, fmt.Errorf("cannot find ticket in db: %w", err)
} }
if _, err = tx.ExecContext(ctx, QueryDelete, t); err != nil { if _, err = tx.ExecContext(ctx, QueryDelete, rawTicket); err != nil {
tx.Rollback() _ = tx.Rollback()
return nil, fmt.Errorf("cannot remove ticket from db: %w", err) return nil, fmt.Errorf("cannot remove ticket from db: %w", err)
} }

View File

@ -12,8 +12,8 @@ import (
repository "source.toby3d.me/website/indieauth/internal/ticket/repository/sqlite3" repository "source.toby3d.me/website/indieauth/internal/ticket/repository/sqlite3"
) )
//nolint: gochecknoglobals //nolint: gochecknoglobals // slices cannot be contants
var tableColumns []string = []string{"created_at", "resource", "subject", "ticket"} var tableColumns = []string{"created_at", "resource", "subject", "ticket"}
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -11,7 +11,6 @@ type UseCase interface {
// Redeem transform received ticket into access token. // Redeem transform received ticket into access token.
Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error) Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error)
Exchange(ctx context.Context, ticket string) (*domain.Token, error) Exchange(ctx context.Context, ticket string) (*domain.Token, error)
} }

View File

@ -14,7 +14,7 @@ import (
) )
type ( type (
//nolint: tagliatelle //nolint: tagliatelle // https://indieauth.net/source/#access-token-response
Response struct { Response struct {
Me *domain.Me `json:"me"` Me *domain.Me `json:"me"`
Scope domain.Scopes `json:"scope"` Scope domain.Scopes `json:"scope"`
@ -37,11 +37,11 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do
} }
} }
func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) error { func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error {
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(t.Subject.String()) req.SetRequestURI(tkt.Subject.String())
resp := http.AcquireResponse() resp := http.AcquireResponse()
defer http.ReleaseResponse(resp) defer http.ReleaseResponse(resp)
@ -65,7 +65,7 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
return ticket.ErrTicketEndpointNotExist return ticket.ErrTicketEndpointNotExist
} }
if err := useCase.tickets.Create(ctx, t); err != nil { if err := useCase.tickets.Create(ctx, tkt); err != nil {
return fmt.Errorf("cannot save ticket in store: %w", err) return fmt.Errorf("cannot save ticket in store: %w", err)
} }
@ -73,9 +73,9 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
req.Header.SetMethod(http.MethodPost) req.Header.SetMethod(http.MethodPost)
req.SetRequestURIBytes(ticketEndpoint.RequestURI()) req.SetRequestURIBytes(ticketEndpoint.RequestURI())
req.Header.SetContentType(common.MIMEApplicationForm) req.Header.SetContentType(common.MIMEApplicationForm)
req.PostArgs().Set("ticket", t.Ticket) req.PostArgs().Set("ticket", tkt.Ticket)
req.PostArgs().Set("subject", t.Subject.String()) req.PostArgs().Set("subject", tkt.Subject.String())
req.PostArgs().Set("resource", t.Resource.String()) req.PostArgs().Set("resource", tkt.Resource.String())
resp.Reset() resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { if err := useCase.client.Do(req, resp); err != nil {
@ -85,10 +85,10 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
return nil return nil
} }
func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*domain.Token, error) { func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) {
req := http.AcquireRequest() req := http.AcquireRequest()
defer http.ReleaseRequest(req) defer http.ReleaseRequest(req)
req.SetRequestURI(t.Resource.String()) req.SetRequestURI(tkt.Resource.String())
req.Header.SetMethod(http.MethodGet) req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse() resp := http.AcquireResponse()
@ -119,7 +119,7 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
req.Header.SetContentType(common.MIMEApplicationForm) req.Header.SetContentType(common.MIMEApplicationForm)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON) req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String()) req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
req.PostArgs().Set("ticket", t.Ticket) req.PostArgs().Set("ticket", tkt.Ticket)
resp.Reset() resp.Reset()
if err := useCase.client.Do(req, resp); err != nil { if err := useCase.client.Do(req, resp); err != nil {
@ -142,19 +142,19 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
} }
func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) { func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) {
t, err := useCase.tickets.GetAndDelete(ctx, ticket) tkt, err := useCase.tickets.GetAndDelete(ctx, ticket)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot find provided ticket: %w", err) return nil, fmt.Errorf("cannot find provided ticket: %w", err)
} }
token, err := domain.NewToken(domain.NewTokenOptions{ token, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm, Expiration: useCase.config.JWT.Expiry,
Expiration: useCase.config.JWT.Expiry,
// TODO(toby3d): Issuer: &domain.ClientID{},
NonceLength: useCase.config.JWT.NonceLength,
Scope: domain.Scopes{domain.ScopeRead}, Scope: domain.Scopes{domain.ScopeRead},
Issuer: nil,
Subject: tkt.Subject,
Secret: []byte(useCase.config.JWT.Secret), Secret: []byte(useCase.config.JWT.Secret),
Subject: t.Subject, Algorithm: useCase.config.JWT.Algorithm,
NonceLength: useCase.config.JWT.NonceLength,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot generate a new access token: %w", err) return nil, fmt.Errorf("cannot generate a new access token: %w", err)

View File

@ -22,12 +22,12 @@ func TestRedeem(t *testing.T) {
token := domain.TestToken(t) token := domain.TestToken(t)
ticket := domain.TestTicket(t) ticket := domain.TestTicket(t)
r := router.New() router := router.New()
r.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) { router.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, `<link rel="token_endpoint" href="`+ ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, `<link rel="token_endpoint" href="`+
ticket.Subject.String()+`token">`) ticket.Subject.String()+`token">`)
}) })
r.POST("/token", func(ctx *http.RequestCtx) { router.POST("/token", func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{ ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{
"token_type": "Bearer", "token_type": "Bearer",
"access_token": "%s", "access_token": "%s",
@ -36,7 +36,7 @@ func TestRedeem(t *testing.T) {
}`, token.AccessToken, token.Scope.String(), token.Me.String())) }`, token.AccessToken, token.Scope.String(), token.Me.String()))
}) })
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)). result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)).

View File

@ -17,7 +17,7 @@ import (
) )
type ( type (
ExchangeRequest struct { TokenExchangeRequest struct {
ClientID *domain.ClientID `form:"client_id"` ClientID *domain.ClientID `form:"client_id"`
RedirectURI *domain.URL `form:"redirect_uri"` RedirectURI *domain.URL `form:"redirect_uri"`
GrantType domain.GrantType `form:"grant_type"` GrantType domain.GrantType `form:"grant_type"`
@ -25,40 +25,40 @@ type (
CodeVerifier string `form:"code_verifier"` CodeVerifier string `form:"code_verifier"`
} }
RevokeRequest struct { TokenRevokeRequest struct {
Action domain.Action `form:"action"` Action domain.Action `form:"action"`
Token string `form:"token"` Token string `form:"token"`
} }
TicketRequest struct { TokenTicketRequest struct {
Action domain.Action `form:"action"` Action domain.Action `form:"action"`
Ticket string `form:"ticket"` Ticket string `form:"ticket"`
} }
//nolint: tagliatelle //nolint: tagliatelle // https://indieauth.net/source/#access-token-response
ExchangeResponse struct { TokenExchangeResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
Scope string `json:"scope"` Scope string `json:"scope"`
Me string `json:"me"` Me string `json:"me"`
Profile *ProfileResponse `json:"profile,omitempty"` Profile *TokenProfileResponse `json:"profile,omitempty"`
} }
ProfileResponse struct { TokenProfileResponse struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
URL *domain.URL `json:"url,omitempty"` URL *domain.URL `json:"url,omitempty"`
Photo *domain.URL `json:"photo,omitempty"` Photo *domain.URL `json:"photo,omitempty"`
Email *domain.Email `json:"email,omitempty"` Email *domain.Email `json:"email,omitempty"`
} }
//nolint: tagliatelle //nolint: tagliatelle // https://indieauth.net/source/#access-token-verification-response
VerificationResponse struct { TokenVerificationResponse struct {
Me *domain.Me `json:"me"` Me *domain.Me `json:"me"`
ClientID *domain.ClientID `json:"client_id"` ClientID *domain.ClientID `json:"client_id"`
Scope domain.Scopes `json:"scope"` Scope domain.Scopes `json:"scope"`
} }
RevocationResponse struct{} TokenRevocationResponse struct{}
RequestHandler struct { RequestHandler struct {
tokens token.UseCase tokens token.UseCase
@ -88,11 +88,12 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)), tkt, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
"Bearer ")) "Bearer "))
if err != nil || t == nil { if err != nil || tkt == nil {
ctx.SetStatusCode(http.StatusUnauthorized) ctx.SetStatusCode(http.StatusUnauthorized)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeUnauthorizedClient, domain.ErrorCodeUnauthorizedClient,
err.Error(), err.Error(),
"https://indieauth.net/source/#access-token-verification", "https://indieauth.net/source/#access-token-verification",
@ -101,10 +102,10 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
return return
} }
_ = encoder.Encode(&VerificationResponse{ _ = encoder.Encode(&TokenVerificationResponse{
ClientID: t.ClientID, ClientID: tkt.ClientID,
Me: t.Me, Me: tkt.Me,
Scope: t.Scope, Scope: tkt.Scope,
}) })
} }
@ -120,7 +121,8 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action"))) action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest, domain.ErrorCodeInvalidRequest,
err.Error(), err.Error(),
"", "",
@ -138,15 +140,17 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
} }
} }
//nolint: funlen
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) { func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8) ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest) req := new(TokenExchangeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
@ -159,7 +163,8 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
}) })
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest, domain.ErrorCodeInvalidRequest,
err.Error(), err.Error(),
"https://indieauth.net/source/#request", "https://indieauth.net/source/#request",
@ -168,20 +173,21 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
return return
} }
resp := &ExchangeResponse{ resp := &TokenExchangeResponse{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
TokenType: "Bearer", TokenType: "Bearer",
Scope: token.Scope.String(), Scope: token.Scope.String(),
Me: token.Me.String(), Me: token.Me.String(),
Profile: nil,
} }
if profile == nil { if profile == nil {
encoder.Encode(resp) _ = encoder.Encode(resp)
return return
} }
resp.Profile = new(ProfileResponse) resp.Profile = new(TokenProfileResponse)
if len(profile.Name) > 0 { if len(profile.Name) > 0 {
resp.Profile.Name = profile.Name[0] resp.Profile.Name = profile.Name[0]
@ -208,17 +214,19 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(RevokeRequest) req := new(TokenRevokeRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
if err := h.tokens.Revoke(ctx, req.Token); err != nil { if err := h.tokens.Revoke(ctx, req.Token); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest, domain.ErrorCodeInvalidRequest,
err.Error(), err.Error(),
"", "",
@ -227,7 +235,7 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
return return
} }
_ = encoder.Encode(&RevocationResponse{}) _ = encoder.Encode(&TokenRevocationResponse{})
} }
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) { func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
@ -236,18 +244,20 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx) encoder := json.NewEncoder(ctx)
req := new(TicketRequest) req := new(TokenTicketRequest)
if err := req.bind(ctx); err != nil { if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest) ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return return
} }
t, err := h.tickets.Exchange(ctx, req.Ticket) tkn, err := h.tickets.Exchange(ctx, req.Ticket)
if err != nil { if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError) ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest, domain.ErrorCodeInvalidRequest,
err.Error(), err.Error(),
"https://indieauth.net/source/#request", "https://indieauth.net/source/#request",
@ -256,15 +266,16 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
return return
} }
encoder.Encode(ExchangeResponse{ _ = encoder.Encode(TokenExchangeResponse{
AccessToken: t.AccessToken, AccessToken: tkn.AccessToken,
TokenType: "Bearer", TokenType: "Bearer",
Scope: t.Scope.String(), Scope: tkn.Scope.String(),
Me: t.Me.String(), Me: tkn.Me.String(),
Profile: nil,
}) })
} }
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error { func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {
@ -281,7 +292,7 @@ func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
return nil return nil
} }
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error { func (r *TokenRevokeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {
@ -298,7 +309,7 @@ func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
return nil return nil
} }
func (r *TicketRequest) bind(ctx *http.RequestCtx) error { func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error) indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil { if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) { if errors.As(err, indieAuthError) {

View File

@ -36,7 +36,7 @@ func TestVerification(t *testing.T) {
config := domain.TestConfig(t) config := domain.TestConfig(t)
token := domain.TestToken(t) token := domain.TestToken(t)
r := router.New() router := router.New()
// TODO(toby3d): provide tickets // TODO(toby3d): provide tickets
delivery.NewRequestHandler( delivery.NewRequestHandler(
tokenucase.NewTokenUseCase( tokenucase.NewTokenUseCase(
@ -49,9 +49,9 @@ func TestVerification(t *testing.T) {
new(http.Client), new(http.Client),
config, config,
), ),
).Register(r) ).Register(router)
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil) req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
@ -65,7 +65,7 @@ func TestVerification(t *testing.T) {
require.NoError(t, client.Do(req, resp)) require.NoError(t, client.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode()) assert.Equal(t, http.StatusOK, resp.StatusCode())
result := new(delivery.VerificationResponse) result := new(delivery.TokenVerificationResponse)
require.NoError(t, json.Unmarshal(resp.Body(), result)) require.NoError(t, json.Unmarshal(resp.Body(), result))
assert.Equal(t, token.ClientID.String(), result.ClientID.String()) assert.Equal(t, token.ClientID.String(), result.ClientID.String())
assert.Equal(t, token.Me.String(), result.Me.String()) assert.Equal(t, token.Me.String(), result.Me.String())
@ -80,7 +80,7 @@ func TestRevocation(t *testing.T) {
tokens := tokenrepo.NewMemoryTokenRepository(store) tokens := tokenrepo.NewMemoryTokenRepository(store)
accessToken := domain.TestToken(t) accessToken := domain.TestToken(t)
r := router.New() router := router.New()
delivery.NewRequestHandler( delivery.NewRequestHandler(
tokenucase.NewTokenUseCase( tokenucase.NewTokenUseCase(
tokens, tokens,
@ -92,9 +92,9 @@ func TestRevocation(t *testing.T) {
new(http.Client), new(http.Client),
config, config,
), ),
).Register(r) ).Register(router)
client, _, cleanup := httptest.New(t, r.Handler) client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup) t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil) req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)

View File

@ -62,8 +62,8 @@ func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *dom
} }
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) { func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
t := new(Token) tkn := new(Token)
if err := repo.db.GetContext(ctx, t, QueryGet, accessToken); err != nil { if err := repo.db.GetContext(ctx, tkn, QueryGet, accessToken); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, token.ErrNotExist return nil, token.ErrNotExist
} }
@ -72,7 +72,7 @@ func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string)
} }
result := new(domain.Token) result := new(domain.Token)
t.Populate(result) tkn.Populate(result)
return result, nil return result, nil
} }

View File

@ -12,14 +12,15 @@ import (
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3" repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
) )
//nolint: gochecknoglobals //nolint: gochecknoglobals // slices cannot be contants
var tableColumns []string = []string{"created_at", "access_token", "client_id", "me", "scope"} var tableColumns = []string{"created_at", "access_token", "client_id", "me", "scope"}
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
t.Parallel() t.Parallel()
token := domain.TestToken(t) token := domain.TestToken(t)
model := repository.NewToken(token) model := repository.NewToken(token)
db, mock, cleanup := sqltest.Open(t) db, mock, cleanup := sqltest.Open(t)
t.Cleanup(cleanup) t.Cleanup(cleanup)
@ -44,6 +45,7 @@ func TestGet(t *testing.T) {
token := domain.TestToken(t) token := domain.TestToken(t)
model := repository.NewToken(token) model := repository.NewToken(token)
db, mock, cleanup := sqltest.Open(t) db, mock, cleanup := sqltest.Open(t)
t.Cleanup(cleanup) t.Cleanup(cleanup)

View File

@ -56,7 +56,7 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
return nil, nil, token.ErrEmptyScope return nil, nil, token.ErrEmptyScope
} }
t, err := domain.NewToken(domain.NewTokenOptions{ tkn, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm, Algorithm: useCase.config.JWT.Algorithm,
Expiration: useCase.config.JWT.Expiry, Expiration: useCase.config.JWT.Expiry,
Issuer: session.ClientID, Issuer: session.ClientID,
@ -70,14 +70,14 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
} }
if !session.Scope.Has(domain.ScopeProfile) { if !session.Scope.Has(domain.ScopeProfile) {
return t, nil, nil return tkn, nil, nil
} }
p := new(domain.Profile) p := new(domain.Profile)
// TODO(toby3d): if session.Scope.Has(domain.ScopeEmail) {} // TODO(toby3d): if session.Scope.Has(domain.ScopeEmail) {}
return t, p, nil return tkn, p, nil
} }
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) { func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
@ -90,23 +90,26 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
return nil, token.ErrRevoke return nil, token.ErrRevoke
} }
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm), tkn, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
[]byte(useCase.config.JWT.Secret))) []byte(useCase.config.JWT.Secret)))
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot parse JWT token: %w", err) return nil, fmt.Errorf("cannot parse JWT token: %w", err)
} }
if err = jwt.Validate(t); err != nil { if err = jwt.Validate(tkn); err != nil {
return nil, fmt.Errorf("cannot validate JWT token: %w", err) return nil, fmt.Errorf("cannot validate JWT token: %w", err)
} }
result := &domain.Token{ result := &domain.Token{
Scope: nil,
ClientID: nil,
Me: nil,
AccessToken: accessToken, AccessToken: accessToken,
} }
result.ClientID, _ = domain.ParseClientID(t.Issuer()) result.ClientID, _ = domain.ParseClientID(tkn.Issuer())
result.Me, _ = domain.ParseMe(t.Subject()) result.Me, _ = domain.ParseMe(tkn.Subject())
if scope, ok := t.Get("scope"); ok { if scope, ok := tkn.Get("scope"); ok {
result.Scope, _ = scope.(domain.Scopes) result.Scope, _ = scope.(domain.Scopes)
} }
@ -114,12 +117,12 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
} }
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error { func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
t, err := useCase.Verify(ctx, accessToken) tkn, err := useCase.Verify(ctx, accessToken)
if err != nil { if err != nil {
return fmt.Errorf("cannot verify token: %w", err) return fmt.Errorf("cannot verify token: %w", err)
} }
if err = useCase.tokens.Create(ctx, t); err != nil { if err = useCase.tokens.Create(ctx, tkn); err != nil {
return fmt.Errorf("cannot save token in database: %w", err) return fmt.Errorf("cannot save token in database: %w", err)
} }

View File

@ -52,31 +52,38 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
// TODO(toby3d): handle error here? // TODO(toby3d): handle error here?
resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation))) resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation)))
u := &domain.User{
Me: resolvedMe, user := &domain.User{
AuthorizationEndpoint: nil,
IndieAuthMetadata: nil,
Me: resolvedMe,
Micropub: nil,
Microsub: nil,
Profile: &domain.Profile{ Profile: &domain.Profile{
Name: make([]string, 0),
URL: make([]*domain.URL, 0),
Photo: make([]*domain.URL, 0),
Email: make([]*domain.Email, 0), Email: make([]*domain.Email, 0),
Name: make([]string, 0),
Photo: make([]*domain.URL, 0),
URL: make([]*domain.URL, 0),
}, },
TicketEndpoint: nil,
TokenEndpoint: nil,
} }
metadata, err := util.ExtractMetadata(resp, repo.client) if metadata, err := util.ExtractMetadata(resp, repo.client); err == nil {
if err == nil && metadata != nil { user.AuthorizationEndpoint = metadata.AuthorizationEndpoint
u.AuthorizationEndpoint = metadata.AuthorizationEndpoint user.Micropub = metadata.Micropub
u.Micropub = metadata.Micropub user.Microsub = metadata.Microsub
u.Microsub = metadata.Microsub user.TicketEndpoint = metadata.TicketEndpoint
u.TicketEndpoint = metadata.TicketEndpoint user.TokenEndpoint = metadata.TokenEndpoint
u.TokenEndpoint = metadata.TokenEndpoint
} }
extractUser(u, resp) extractUser(user, resp)
extractProfile(u.Profile, resp) extractProfile(user.Profile, resp)
return u, nil return user, nil
} }
//nolint: cyclop
func extractUser(dst *domain.User, src *http.Response) { func extractUser(dst *domain.User, src *http.Response) {
if dst.IndieAuthMetadata != nil { if dst.IndieAuthMetadata != nil {
if endpoints := util.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 { if endpoints := util.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 {
@ -115,14 +122,12 @@ func extractUser(dst *domain.User, src *http.Response) {
} }
} }
//nolint: cyclop
func extractProfile(dst *domain.Profile, src *http.Response) { func extractProfile(dst *domain.Profile, src *http.Response) {
for _, name := range util.ExtractProperty(src, hCard, propertyName) { for _, name := range util.ExtractProperty(src, hCard, propertyName) {
n, ok := name.(string) if n, ok := name.(string); ok {
if !ok { dst.Name = append(dst.Name, n)
continue
} }
dst.Name = append(dst.Name, n)
} }
for _, rawEmail := range util.ExtractProperty(src, hCard, propertyEmail) { for _, rawEmail := range util.ExtractProperty(src, hCard, propertyEmail) {
@ -131,26 +136,20 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
continue continue
} }
e, err := domain.ParseEmail(email) if e, err := domain.ParseEmail(email); err == nil {
if err != nil { dst.Email = append(dst.Email, e)
continue
} }
dst.Email = append(dst.Email, e)
} }
for _, rawUrl := range util.ExtractProperty(src, hCard, propertyURL) { for _, rawURL := range util.ExtractProperty(src, hCard, propertyURL) {
url, ok := rawUrl.(string) url, ok := rawURL.(string)
if !ok { if !ok {
continue continue
} }
u, err := domain.ParseURL(url) if u, err := domain.ParseURL(url); err == nil {
if err != nil { dst.URL = append(dst.URL, u)
continue
} }
dst.URL = append(dst.URL, u)
} }
for _, rawPhoto := range util.ExtractProperty(src, hCard, propertyPhoto) { for _, rawPhoto := range util.ExtractProperty(src, hCard, propertyPhoto) {
@ -159,11 +158,8 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
continue continue
} }
p, err := domain.ParseURL(photo) if p, err := domain.ParseURL(photo); err == nil {
if err != nil { dst.Photo = append(dst.Photo, p)
continue
} }
dst.Photo = append(dst.Photo, p)
} }
} }

View File

@ -69,8 +69,8 @@ func TestGet(t *testing.T) {
func testHandler(tb testing.TB, user *domain.User) http.RequestHandler { func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
tb.Helper() tb.Helper()
r := router.New() router := router.New()
r.GET("/", func(ctx *http.RequestCtx) { router.GET("/", func(ctx *http.RequestCtx) {
ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{ ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{
`<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`, `<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`,
`<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`, `<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`,
@ -83,7 +83,7 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0], testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0],
)) ))
}) })
r.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) { router.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{ ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{
"issuer": "`+user.Me.String()+`", "issuer": "`+user.Me.String()+`",
"authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`", "authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`",
@ -91,5 +91,5 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
}`) }`)
}) })
return r.Handler return router.Handler
} }

View File

@ -3,6 +3,7 @@ package util
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"net/url" "net/url"
"strings" "strings"
@ -13,6 +14,12 @@ import (
"source.toby3d.me/website/indieauth/internal/domain" "source.toby3d.me/website/indieauth/internal/domain"
) )
var ErrEndpointNotExist = domain.NewError(
domain.ErrorCodeServerError,
"cannot found any endpoints",
"https://indieauth.net/source/#discovery-0",
)
func ExtractEndpoints(resp *http.Response, rel string) []*domain.URL { func ExtractEndpoints(resp *http.Response, rel string) []*domain.URL {
results := make([]*domain.URL, 0) results := make([]*domain.URL, 0)
@ -39,7 +46,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
u := http.AcquireURI() u := http.AcquireURI()
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil { if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
return nil, err return nil, fmt.Errorf("cannot parse header endpoint: %w", err)
} }
results = append(results, &domain.URL{URI: u}) results = append(results, &domain.URL{URI: u})
@ -51,7 +58,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, error) { func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, error) {
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel] endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel]
if !ok || len(endpoints) == 0 { if !ok || len(endpoints) == 0 {
return nil, nil return nil, ErrEndpointNotExist
} }
results := make([]*domain.URL, 0) results := make([]*domain.URL, 0)
@ -59,7 +66,7 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
for i := range endpoints { for i := range endpoints {
u := http.AcquireURI() u := http.AcquireURI()
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(endpoints[i])); err != nil { if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(endpoints[i])); err != nil {
return nil, err return nil, fmt.Errorf("cannot parse body endpoint: %w", err)
} }
results = append(results, &domain.URL{URI: u}) results = append(results, &domain.URL{URI: u})
@ -71,29 +78,30 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) { func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) {
endpoints := ExtractEndpoints(resp, "indieauth-metadata") endpoints := ExtractEndpoints(resp, "indieauth-metadata")
if len(endpoints) == 0 { if len(endpoints) == 0 {
return nil, nil return nil, ErrEndpointNotExist
} }
_, body, err := client.Get(nil, endpoints[len(endpoints)-1].String()) _, body, err := client.Get(nil, endpoints[len(endpoints)-1].String())
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
} }
result := new(domain.Metadata) result := new(domain.Metadata)
if err = json.Unmarshal(body, result); err != nil { if err = json.Unmarshal(body, result); err != nil {
return nil, err return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
} }
return result, nil return result, nil
} }
func ExtractProperty(resp *http.Response, t, key string) []interface{} { func ExtractProperty(resp *http.Response, itemType, key string) []interface{} {
//nolint: exhaustivestruct // only Host part in url.URL is needed
data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{ data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{
Host: string(resp.Header.Peek(http.HeaderHost)), Host: string(resp.Header.Peek(http.HeaderHost)),
}) })
for _, item := range data.Items { for _, item := range data.Items {
if !contains(item.Type, t) { if !contains(item.Type, itemType) {
continue continue
} }

24
main.go
View File

@ -55,7 +55,8 @@ import (
const ( const (
DefaultCacheDuration time.Duration = 8760 * time.Hour // NOTE(toby3d): year DefaultCacheDuration time.Duration = 8760 * time.Hour // NOTE(toby3d): year
DefaultPort int = 3000 DefaultReadTimeout time.Duration = 5 * time.Second
DefaultWriteTimeout time.Duration = 10 * time.Second
) )
//nolint: gochecknoglobals //nolint: gochecknoglobals
@ -126,7 +127,7 @@ func init() {
client.RedirectURI = []*domain.URL{redirectURI} client.RedirectURI = []*domain.URL{redirectURI}
} }
//nolint: funlen //nolint: funlen, cyclop // "god object" and the entry point of all modules
func main() { func main() {
var ( var (
tokens token.Repository tokens token.Repository
@ -146,7 +147,7 @@ func main() {
} }
tokens = tokensqlite3repo.NewSQLite3TokenRepository(store) tokens = tokensqlite3repo.NewSQLite3TokenRepository(store)
sessions = sessionsqlite3repo.NewSQLite3SessionRepository(config, store) sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store)
tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config) tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config)
case "memory": case "memory":
store := new(sync.Map) store := new(sync.Map)
@ -160,12 +161,11 @@ func main() {
go sessions.GC() go sessions.GC()
matcher := language.NewMatcher(message.DefaultCatalog.Languages()) matcher := language.NewMatcher(message.DefaultCatalog.Languages())
//nolint: exhaustivestruct // too many options
httpClient := &http.Client{ httpClient := &http.Client{
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()), Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
MaxConnDuration: 10 * time.Second, ReadTimeout: DefaultReadTimeout,
ReadTimeout: 10 * time.Second, WriteTimeout: DefaultWriteTimeout,
WriteTimeout: 10 * time.Second,
MaxConnWaitTimeout: 10 * time.Second,
} }
ticketService := ticketucase.NewTicketUseCase(tickets, httpClient, config) ticketService := ticketucase.NewTicketUseCase(tickets, httpClient, config)
tokenService := tokenucase.NewTokenUseCase(tokens, sessions, config) tokenService := tokenucase.NewTokenUseCase(tokens, sessions, config)
@ -187,6 +187,7 @@ func main() {
Matcher: matcher, Matcher: matcher,
Config: config, Config: config,
}).Register(r) }).Register(r)
//nolint: exhaustivestruct// too many options
r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{ r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{
Root: config.Server.StaticRootPath, Root: config.Server.StaticRootPath,
CacheDuration: DefaultCacheDuration, CacheDuration: DefaultCacheDuration,
@ -200,11 +201,12 @@ func main() {
r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler) r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler)
} }
//nolint: exhaustivestruct
server := &http.Server{ server := &http.Server{
Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()), Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()),
Handler: r.Handler, Handler: r.Handler,
ReadTimeout: 10 * time.Second, ReadTimeout: DefaultReadTimeout,
WriteTimeout: 10 * time.Second, WriteTimeout: DefaultWriteTimeout,
DisableKeepalive: true, DisableKeepalive: true,
ReduceMemoryUsage: true, ReduceMemoryUsage: true,
SecureErrorLogMessage: true, SecureErrorLogMessage: true,
@ -212,7 +214,7 @@ func main() {
} }
done := make(chan os.Signal, 1) done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL) signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
if cpuProfilePath != "" { if cpuProfilePath != "" {
cpuProfile, err := os.Create(cpuProfilePath) cpuProfile, err := os.Create(cpuProfilePath)