🚨 Removed linter warnings
This commit is contained in:
parent
7680845f74
commit
59d4c4988a
|
@ -1,15 +1,54 @@
|
|||
---
|
||||
run:
|
||||
tests: true
|
||||
skip-dirs:
|
||||
- locales
|
||||
- testdata
|
||||
- web
|
||||
skip-dirs-use-default: true
|
||||
skip-files:
|
||||
- ".*_gen\\.go$"
|
||||
output:
|
||||
sort-results: true
|
||||
linters-settings:
|
||||
lll:
|
||||
tab-width: 8
|
||||
gci:
|
||||
local-prefixes: source.toby3d.me
|
||||
goimports:
|
||||
local-prefixes: source.toby3d.me
|
||||
ireturn:
|
||||
allow:
|
||||
- "(Repository|UseCase)$"
|
||||
- error
|
||||
- stdlib
|
||||
lll:
|
||||
tab-width: 8
|
||||
varnamelen:
|
||||
ignore-type-assert-ok: true
|
||||
ignore-map-index-ok: true
|
||||
ignore-chan-recv-ok: true
|
||||
ignore-names:
|
||||
- ctx # context
|
||||
- db # dataBase
|
||||
- err # error
|
||||
- i # index
|
||||
- ip
|
||||
- ln # listener
|
||||
- me
|
||||
- ok
|
||||
- tc # testCase
|
||||
- ts # timeStamp
|
||||
- tx # transaction
|
||||
ignore-decls:
|
||||
- "cid *domain.ClientID"
|
||||
- "ctx *fasthttp.RequestCtx"
|
||||
- "ctx context.Context"
|
||||
- "i int"
|
||||
- "me *domain.Me"
|
||||
- "r *router.Router"
|
||||
linters:
|
||||
enable-all: true
|
||||
disable:
|
||||
- godox
|
||||
issues:
|
||||
exclude-rules:
|
||||
- source: "^//go:generate "
|
||||
|
|
1
go.mod
1
go.mod
|
@ -14,6 +14,7 @@ require (
|
|||
github.com/hashicorp/go-retryablehttp v0.7.0 // indirect
|
||||
github.com/jmoiron/sqlx v1.3.4
|
||||
github.com/klauspost/compress v1.14.2 // indirect
|
||||
github.com/lestrrat-go/iter v1.0.1 // indirect
|
||||
github.com/lestrrat-go/jwx v1.2.18
|
||||
github.com/mattn/go-mastodon v0.0.4
|
||||
github.com/spf13/afero v1.8.0 // indirect
|
||||
|
|
|
@ -23,7 +23,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
AuthorizeRequest struct {
|
||||
AuthAuthorizeRequest struct {
|
||||
// Indicates to the authorization server that an authorization
|
||||
// code should be returned as the response.
|
||||
ResponseType domain.ResponseType `form:"response_type"` // code
|
||||
|
@ -58,7 +58,7 @@ type (
|
|||
Me *domain.Me `form:"me"`
|
||||
}
|
||||
|
||||
VerifyRequest struct {
|
||||
AuthVerifyRequest struct {
|
||||
ClientID *domain.ClientID `form:"client_id"`
|
||||
Me *domain.Me `form:"me"`
|
||||
RedirectURI *domain.URL `form:"redirect_uri"`
|
||||
|
@ -71,7 +71,7 @@ type (
|
|||
Provider string `form:"provider"`
|
||||
}
|
||||
|
||||
ExchangeRequest struct {
|
||||
AuthExchangeRequest struct {
|
||||
GrantType domain.GrantType `form:"grant_type"` // authorization_code
|
||||
|
||||
// The authorization code received from the authorization
|
||||
|
@ -91,50 +91,52 @@ type (
|
|||
CodeVerifier string `form:"code_verifier"`
|
||||
}
|
||||
|
||||
ExchangeResponse struct {
|
||||
AuthExchangeResponse struct {
|
||||
Me *domain.Me `json:"me"`
|
||||
}
|
||||
|
||||
NewRequestHandlerOptions struct {
|
||||
Auth auth.UseCase
|
||||
Clients client.UseCase
|
||||
Config *domain.Config
|
||||
Matcher language.Matcher
|
||||
Providers []*domain.Provider
|
||||
Auth auth.UseCase
|
||||
Clients client.UseCase
|
||||
Config *domain.Config
|
||||
Matcher language.Matcher
|
||||
}
|
||||
|
||||
RequestHandler struct {
|
||||
clients client.UseCase
|
||||
config *domain.Config
|
||||
matcher language.Matcher
|
||||
useCase auth.UseCase
|
||||
providers []*domain.Provider
|
||||
clients client.UseCase
|
||||
config *domain.Config
|
||||
matcher language.Matcher
|
||||
useCase auth.UseCase
|
||||
}
|
||||
)
|
||||
|
||||
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
|
||||
return &RequestHandler{
|
||||
clients: opts.Clients,
|
||||
config: opts.Config,
|
||||
matcher: opts.Matcher,
|
||||
useCase: opts.Auth,
|
||||
providers: opts.Providers,
|
||||
clients: opts.Clients,
|
||||
config: opts.Config,
|
||||
matcher: opts.Matcher,
|
||||
useCase: opts.Auth,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RequestHandler) Register(r *router.Router) {
|
||||
chain := middleware.Chain{
|
||||
middleware.CSRFWithConfig(middleware.CSRFConfig{
|
||||
CookieSameSite: http.CookieSameSiteStrictMode,
|
||||
CookieName: "_csrf",
|
||||
TokenLookup: "form:_csrf",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
Skipper: func(ctx *http.RequestCtx) bool {
|
||||
matched, _ := path.Match("/api/*", string(ctx.Path()))
|
||||
|
||||
return ctx.IsPost() && matched
|
||||
},
|
||||
CookieMaxAge: 0,
|
||||
CookieSameSite: http.CookieSameSiteStrictMode,
|
||||
ContextKey: "",
|
||||
CookieDomain: "",
|
||||
CookieName: "_csrf",
|
||||
CookiePath: "",
|
||||
TokenLookup: "form:_csrf",
|
||||
TokenLength: 0,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
}),
|
||||
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
|
||||
Skipper: func(ctx *http.RequestCtx) bool {
|
||||
|
@ -145,15 +147,14 @@ func (h *RequestHandler) Register(r *router.Router) {
|
|||
return !ctx.IsPost() || !matched || providerMatched
|
||||
},
|
||||
Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) {
|
||||
userMatch := subtle.ConstantTimeCompare(
|
||||
[]byte(login), []byte(h.config.IndieAuth.Username),
|
||||
)
|
||||
passMatch := subtle.ConstantTimeCompare(
|
||||
[]byte(password), []byte(h.config.IndieAuth.Password),
|
||||
)
|
||||
userMatch := subtle.ConstantTimeCompare([]byte(login),
|
||||
[]byte(h.config.IndieAuth.Username))
|
||||
passMatch := subtle.ConstantTimeCompare([]byte(password),
|
||||
[]byte(h.config.IndieAuth.Password))
|
||||
|
||||
return userMatch == 1 && passMatch == 1, nil
|
||||
},
|
||||
Realm: "",
|
||||
}),
|
||||
middleware.LogFmt(),
|
||||
}
|
||||
|
@ -164,7 +165,6 @@ func (h *RequestHandler) Register(r *router.Router) {
|
|||
}
|
||||
|
||||
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
||||
req := new(AuthorizeRequest)
|
||||
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
|
||||
|
||||
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
|
||||
|
@ -175,6 +175,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
|||
Printer: message.NewPrinter(tag),
|
||||
}
|
||||
|
||||
req := new(AuthAuthorizeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
web.WriteTemplate(ctx, &web.ErrorPage{
|
||||
|
@ -213,15 +214,16 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
|||
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
|
||||
web.WriteTemplate(ctx, &web.AuthorizePage{
|
||||
BaseOf: baseOf,
|
||||
Client: client,
|
||||
CodeChallenge: req.CodeChallenge,
|
||||
CodeChallengeMethod: req.CodeChallengeMethod,
|
||||
CSRF: csrf,
|
||||
Scope: req.Scope,
|
||||
Client: client,
|
||||
Me: req.Me,
|
||||
RedirectURI: req.RedirectURI,
|
||||
CodeChallengeMethod: req.CodeChallengeMethod,
|
||||
ResponseType: req.ResponseType,
|
||||
Scope: req.Scope,
|
||||
CodeChallenge: req.CodeChallenge,
|
||||
State: req.State,
|
||||
Providers: make([]*domain.Provider, 0), // TODO(toby3d)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -230,22 +232,23 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(VerifyRequest)
|
||||
req := new(AuthVerifyRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
u := http.AcquireURI()
|
||||
defer http.ReleaseURI(u)
|
||||
req.RedirectURI.CopyTo(u)
|
||||
redirectURL := http.AcquireURI()
|
||||
defer http.ReleaseURI(redirectURL)
|
||||
req.RedirectURI.CopyTo(redirectURL)
|
||||
|
||||
if strings.EqualFold(req.Authorize, "deny") {
|
||||
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
|
||||
SetReirectURI(u)
|
||||
ctx.Redirect(u.String(), http.StatusFound)
|
||||
SetReirectURI(redirectURL)
|
||||
ctx.Redirect(redirectURL.String(), http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -260,7 +263,8 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
|
|||
})
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -270,10 +274,10 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
|
|||
"iss": h.config.Server.GetRootURL(),
|
||||
"state": req.State,
|
||||
} {
|
||||
u.QueryArgs().Set(key, val)
|
||||
redirectURL.QueryArgs().Set(key, val)
|
||||
}
|
||||
|
||||
ctx.Redirect(u.String(), http.StatusFound)
|
||||
ctx.Redirect(redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
||||
|
@ -281,10 +285,11 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(ExchangeRequest)
|
||||
req := new(AuthExchangeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -297,17 +302,18 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
|||
})
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
encoder.Encode(&ExchangeResponse{
|
||||
_ = encoder.Encode(&AuthExchangeResponse{
|
||||
Me: me,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *AuthAuthorizeRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
@ -341,7 +347,7 @@ func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
|
@ -369,6 +375,8 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
|
|||
)
|
||||
}
|
||||
|
||||
// NOTE(toby3d): backwards-compatible support.
|
||||
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
|
||||
if r.ResponseType == domain.ResponseTypeID {
|
||||
r.ResponseType = domain.ResponseTypeCode
|
||||
}
|
||||
|
@ -386,7 +394,7 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
|
|
@ -35,15 +35,15 @@ func TestRender(t *testing.T) {
|
|||
require.NoError(t, s.SetProvider(provider))
|
||||
|
||||
me := domain.TestMe(t, "https://user.example.net")
|
||||
c := domain.TestClient(t)
|
||||
client := domain.TestClient(t)
|
||||
config := domain.TestConfig(t)
|
||||
store := new(sync.Map)
|
||||
user := domain.TestUser(t)
|
||||
store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user)
|
||||
store.Store(path.Join(clientrepo.DefaultPathPrefix, c.ID.String()), c)
|
||||
store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client)
|
||||
store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile)
|
||||
|
||||
r := router.New()
|
||||
router := router.New()
|
||||
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
|
||||
Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)),
|
||||
Config: config,
|
||||
|
@ -52,35 +52,35 @@ func TestRender(t *testing.T) {
|
|||
sessionrepo.NewMemorySessionRepository(config, store),
|
||||
config,
|
||||
),
|
||||
}).Register(r)
|
||||
}).Register(router)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
httpClient, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
u := http.AcquireURI()
|
||||
defer http.ReleaseURI(u)
|
||||
u.Update("https://example.com/authorize")
|
||||
uri := http.AcquireURI()
|
||||
defer http.ReleaseURI(uri)
|
||||
uri.Update("https://example.com/authorize")
|
||||
|
||||
for k, v := range map[string]string{
|
||||
"client_id": c.ID.String(),
|
||||
for key, val := range map[string]string{
|
||||
"client_id": client.ID.String(),
|
||||
"code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
|
||||
"code_challenge_method": domain.CodeChallengeMethodS256.String(),
|
||||
"me": me.String(),
|
||||
"redirect_uri": c.RedirectURI[0].String(),
|
||||
"redirect_uri": client.RedirectURI[0].String(),
|
||||
"response_type": domain.ResponseTypeCode.String(),
|
||||
"scope": "profile email",
|
||||
"state": "1234567890",
|
||||
} {
|
||||
u.QueryArgs().Set(k, v)
|
||||
uri.QueryArgs().Set(key, val)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, u.String(), nil)
|
||||
req := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
||||
defer http.ReleaseRequest(req)
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
defer http.ReleaseResponse(resp)
|
||||
|
||||
require.NoError(t, client.Do(req, resp))
|
||||
require.NoError(t, httpClient.Do(req, resp))
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
||||
assert.Contains(t, string(resp.Body()), `Authorize application`)
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
CallbackRequest struct {
|
||||
ClientCallbackRequest struct {
|
||||
Iss *domain.ClientID `form:"iss"`
|
||||
Code string `form:"code"`
|
||||
Error string `form:"error"`
|
||||
|
@ -60,13 +60,14 @@ func (h *RequestHandler) Register(r *router.Router) {
|
|||
}
|
||||
|
||||
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
||||
redirectUri := make([]string, len(h.client.RedirectURI))
|
||||
redirect := make([]string, len(h.client.RedirectURI))
|
||||
|
||||
for i := range h.client.RedirectURI {
|
||||
redirectUri[i] = h.client.RedirectURI[i].String()
|
||||
redirect[i] = h.client.RedirectURI[i].String()
|
||||
}
|
||||
|
||||
ctx.Response.Header.Set(
|
||||
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
|
||||
http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
|
||||
)
|
||||
|
||||
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
|
||||
|
@ -87,7 +88,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
|
|||
}
|
||||
|
||||
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
|
||||
req := new(CallbackRequest)
|
||||
req := new(ClientCallbackRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
|
||||
|
@ -134,7 +135,7 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
|
|||
})
|
||||
}
|
||||
|
||||
func (req *CallbackRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
|
||||
if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil {
|
||||
|
|
|
@ -25,16 +25,16 @@ func TestRead(t *testing.T) {
|
|||
store := new(sync.Map)
|
||||
config := domain.TestConfig(t)
|
||||
|
||||
r := router.New()
|
||||
router := router.New()
|
||||
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
|
||||
Client: domain.TestClient(t),
|
||||
Config: config,
|
||||
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
|
||||
Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
|
||||
sessionrepo.NewMemorySessionRepository(config, store), config),
|
||||
}).Register(r)
|
||||
}).Register(router)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
client, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)
|
||||
|
|
|
@ -10,4 +10,8 @@ type Repository interface {
|
|||
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
|
||||
}
|
||||
|
||||
var ErrNotExist error = domain.NewError(domain.ErrorCodeInvalidClient, "client with the specified ID does not exist", "")
|
||||
var ErrNotExist error = domain.NewError(
|
||||
domain.ErrorCodeInvalidClient,
|
||||
"client with the specified ID does not exist",
|
||||
"",
|
||||
)
|
||||
|
|
|
@ -32,10 +32,10 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
|
|||
}
|
||||
}
|
||||
|
||||
func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
|
||||
func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
|
||||
req := http.AcquireRequest()
|
||||
defer http.ReleaseRequest(req)
|
||||
req.SetRequestURI(id.String())
|
||||
req.SetRequestURI(cid.String())
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
|
@ -50,7 +50,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
|
|||
}
|
||||
|
||||
client := &domain.Client{
|
||||
ID: id,
|
||||
ID: cid,
|
||||
RedirectURI: make([]*domain.URL, 0),
|
||||
Logo: make([]*domain.URL, 0),
|
||||
URL: make([]*domain.URL, 0),
|
||||
|
@ -62,68 +62,52 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
|
|||
return client, nil
|
||||
}
|
||||
|
||||
//nolint: gocognit, cyclop
|
||||
func extract(dst *domain.Client, src *http.Response) {
|
||||
for _, u := range util.ExtractEndpoints(src, relRedirectURI) {
|
||||
if containsURL(dst.RedirectURI, u) {
|
||||
continue
|
||||
for _, endpoint := range util.ExtractEndpoints(src, relRedirectURI) {
|
||||
if !containsURL(dst.RedirectURI, endpoint) {
|
||||
dst.RedirectURI = append(dst.RedirectURI, endpoint)
|
||||
}
|
||||
|
||||
dst.RedirectURI = append(dst.RedirectURI, u)
|
||||
}
|
||||
|
||||
for _, t := range []string{hXApp, hApp} {
|
||||
for _, name := range util.ExtractProperty(src, t, propertyName) {
|
||||
n, ok := name.(string)
|
||||
if !ok || containsString(dst.Name, n) {
|
||||
continue
|
||||
for _, itemType := range []string{hXApp, hApp} {
|
||||
for _, name := range util.ExtractProperty(src, itemType, propertyName) {
|
||||
if n, ok := name.(string); ok && !containsString(dst.Name, n) {
|
||||
dst.Name = append(dst.Name, n)
|
||||
}
|
||||
|
||||
dst.Name = append(dst.Name, n)
|
||||
}
|
||||
|
||||
for _, logo := range util.ExtractProperty(src, t, propertyLogo) {
|
||||
var err error
|
||||
for _, logo := range util.ExtractProperty(src, itemType, propertyLogo) {
|
||||
var (
|
||||
uri *domain.URL
|
||||
err error
|
||||
)
|
||||
|
||||
var u *domain.URL
|
||||
switch l := logo.(type) {
|
||||
case string:
|
||||
u, err = domain.ParseURL(l)
|
||||
uri, err = domain.ParseURL(l)
|
||||
case map[string]string:
|
||||
value, ok := l["value"]
|
||||
if !ok {
|
||||
continue
|
||||
if value, ok := l["value"]; ok {
|
||||
uri, err = domain.ParseURL(value)
|
||||
}
|
||||
|
||||
u, err = domain.ParseURL(value)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err != nil || containsURL(dst.Logo, uri) {
|
||||
continue
|
||||
}
|
||||
|
||||
if containsURL(dst.Logo, u) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Logo = append(dst.Logo, u)
|
||||
dst.Logo = append(dst.Logo, uri)
|
||||
}
|
||||
|
||||
for _, url := range util.ExtractProperty(src, t, propertyURL) {
|
||||
l, ok := url.(string)
|
||||
for _, property := range util.ExtractProperty(src, itemType, propertyURL) {
|
||||
prop, ok := property.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, err := domain.ParseURL(l)
|
||||
if err != nil {
|
||||
continue
|
||||
if u, err := domain.ParseURL(prop); err == nil || !containsURL(dst.URL, u) {
|
||||
dst.URL = append(dst.URL, u)
|
||||
}
|
||||
|
||||
if containsURL(dst.URL, u) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.URL = append(dst.URL, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint: dupl
|
||||
package domain
|
||||
|
||||
import (
|
||||
|
@ -14,7 +15,7 @@ type Action struct {
|
|||
uid string
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var (
|
||||
ActionUndefined = Action{uid: ""}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint: dupl
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
|
@ -12,13 +13,10 @@ func TestParseAction(t *testing.T) {
|
|||
for _, tc := range []struct {
|
||||
in string
|
||||
out domain.Action
|
||||
}{{
|
||||
in: "revoke",
|
||||
out: domain.ActionRevoke,
|
||||
}, {
|
||||
in: "ticket",
|
||||
out: domain.ActionTicket,
|
||||
}} {
|
||||
}{
|
||||
{in: "revoke", out: domain.ActionRevoke},
|
||||
{in: "ticket", out: domain.ActionTicket},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.in, func(t *testing.T) {
|
||||
|
@ -73,15 +71,10 @@ func TestAction_String(t *testing.T) {
|
|||
name string
|
||||
in domain.Action
|
||||
out string
|
||||
}{{
|
||||
name: "revoke",
|
||||
in: domain.ActionRevoke,
|
||||
out: "revoke",
|
||||
}, {
|
||||
name: "ticket",
|
||||
in: domain.ActionTicket,
|
||||
out: "ticket",
|
||||
}} {
|
||||
}{
|
||||
{name: "revoke", in: domain.ActionRevoke, out: "revoke"},
|
||||
{name: "ticket", in: domain.ActionTicket, out: "ticket"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -8,111 +8,94 @@ import (
|
|||
"testing"
|
||||
|
||||
http "github.com/valyala/fasthttp"
|
||||
"golang.org/x/xerrors"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// ClientID is a URL client identifier.
|
||||
type ClientID struct {
|
||||
clientID *http.URI
|
||||
valid bool
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
//nolint: gochecknoglobals // slices cannot be constants
|
||||
var (
|
||||
localhostIPv4 = netaddr.MustParseIP("127.0.0.1")
|
||||
localhostIPv6 = netaddr.MustParseIP("::1")
|
||||
)
|
||||
|
||||
// ParseClientID parse string as client ID URL identifier.
|
||||
//nolint: funlen
|
||||
//nolint: funlen, cyclop
|
||||
func ParseClientID(src string) (*ClientID, error) {
|
||||
cid := http.AcquireURI()
|
||||
if err := cid.Parse(nil, []byte(src)); err != nil {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: err.Error(),
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
err.Error(),
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
scheme := string(cid.Scheme())
|
||||
if scheme != "http" && scheme != "https" {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client identifier URL MUST have either an https or http scheme",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client identifier URL MUST have either an https or http scheme",
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
path := string(cid.PathOriginal())
|
||||
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client identifier URL MUST contain a path component and MUST NOT contain " +
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client identifier URL MUST contain a path component and MUST NOT contain "+
|
||||
"single-dot or double-dot path segments",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
if cid.Hash() != nil {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client identifier URL MUST NOT contain a fragment component",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client identifier URL MUST NOT contain a fragment component",
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
if cid.Username() != nil || cid.Password() != nil {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client identifier URL MUST NOT contain a username or password component",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client identifier URL MUST NOT contain a username or password component",
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
domain := string(cid.Host())
|
||||
if domain == "" {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client host name MUST be domain name or a loopback interface",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client host name MUST be domain name or a loopback interface",
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
ip, err := netaddr.ParseIP(domain)
|
||||
if err != nil {
|
||||
ipPort, err := netaddr.ParseIPPort(domain)
|
||||
if err != nil {
|
||||
return &ClientID{
|
||||
clientID: cid,
|
||||
}, nil
|
||||
//nolint: nilerr // ClientID does not contain an IP address, so it is valid
|
||||
return &ClientID{clientID: cid}, nil
|
||||
}
|
||||
|
||||
ip = ipPort.IP()
|
||||
}
|
||||
|
||||
if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 {
|
||||
return nil, Error{
|
||||
Code: ErrorCodeInvalidRequest,
|
||||
Description: "client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 " +
|
||||
return nil, NewError(
|
||||
ErrorCodeInvalidRequest,
|
||||
"client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 "+
|
||||
"127.0.0.1 or IPv6 [::1]",
|
||||
URI: "https://indieauth.net/source/#client-identifier",
|
||||
State: "",
|
||||
frame: xerrors.Caller(1),
|
||||
}
|
||||
"https://indieauth.net/source/#client-identifier",
|
||||
)
|
||||
}
|
||||
|
||||
return &ClientID{
|
||||
|
@ -126,7 +109,7 @@ func TestClientID(tb testing.TB) *ClientID {
|
|||
|
||||
clientID, err := ParseClientID("https://app.example.com/")
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return clientID
|
||||
|
@ -167,23 +150,28 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// URI returns copy of parsed *fasthttp.URI.
|
||||
// This copy MUST be released via fasthttp.ReleaseURI.
|
||||
//
|
||||
// WARN(toby3d): This copy MUST be released via fasthttp.ReleaseURI.
|
||||
func (cid ClientID) URI() *http.URI {
|
||||
u := http.AcquireURI()
|
||||
cid.clientID.CopyTo(u)
|
||||
uri := http.AcquireURI()
|
||||
cid.clientID.CopyTo(uri)
|
||||
|
||||
return u
|
||||
return uri
|
||||
}
|
||||
|
||||
// URL returns url.URL representation of client ID.
|
||||
func (cid ClientID) URL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: string(cid.clientID.Scheme()),
|
||||
Host: string(cid.clientID.Host()),
|
||||
Path: string(cid.clientID.Path()),
|
||||
RawPath: string(cid.clientID.PathOriginal()),
|
||||
RawQuery: string(cid.clientID.QueryString()),
|
||||
Fragment: string(cid.clientID.Hash()),
|
||||
ForceQuery: false,
|
||||
Fragment: string(cid.clientID.Hash()),
|
||||
Host: string(cid.clientID.Host()),
|
||||
Opaque: "",
|
||||
Path: string(cid.clientID.Path()),
|
||||
RawFragment: "",
|
||||
RawPath: string(cid.clientID.PathOriginal()),
|
||||
RawQuery: string(cid.clientID.QueryString()),
|
||||
Scheme: string(cid.clientID.Scheme()),
|
||||
User: nil,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
)
|
||||
|
||||
//nolint: funlen
|
||||
func TestParseClientID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -15,51 +14,19 @@ func TestParseClientID(t *testing.T) {
|
|||
name string
|
||||
in string
|
||||
expError bool
|
||||
}{{
|
||||
name: "valid",
|
||||
in: "https://example.com/",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid path",
|
||||
in: "https://example.com/username",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid query",
|
||||
in: "https://example.com/users?id=100",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid port",
|
||||
in: "https://example.com:8443/",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid loopback",
|
||||
in: "https://127.0.0.1:8443/",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "missing scheme",
|
||||
in: "example.com",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "invalid scheme",
|
||||
in: "mailto:user@example.com",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "invalid double-dot path",
|
||||
in: "https://example.com/foo/../bar",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "invalid fragment",
|
||||
in: "https://example.com/#me",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "invalid user",
|
||||
in: "https://user:pass@example.com/",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "host is an IP address",
|
||||
in: "https://172.28.92.51/",
|
||||
expError: true,
|
||||
}} {
|
||||
}{
|
||||
{name: "valid", in: "https://example.com/", expError: false},
|
||||
{name: "valid path", in: "https://example.com/username", expError: false},
|
||||
{name: "valid query", in: "https://example.com/users?id=100", expError: false},
|
||||
{name: "valid port", in: "https://example.com:8443/", expError: false},
|
||||
{name: "valid loopback", in: "https://127.0.0.1:8443/", expError: false},
|
||||
{name: "missing scheme", in: "example.com", expError: true},
|
||||
{name: "invalid scheme", in: "mailto:user@example.com", expError: true},
|
||||
{name: "invalid double-dot path", in: "https://example.com/foo/../bar", expError: true},
|
||||
{name: "invalid fragment", in: "https://example.com/#me", expError: true},
|
||||
{name: "invalid user", in: "https://user:pass@example.com/", expError: true},
|
||||
{name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -131,8 +98,7 @@ func TestClientID_MarshalJSON(t *testing.T) {
|
|||
func TestClientID_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cid := domain.TestClientID(t)
|
||||
if result := cid.String(); result != fmt.Sprint(cid) {
|
||||
t.Errorf("Strig() = %s, want %s", result, fmt.Sprint(cid))
|
||||
if cid := domain.TestClientID(t); cid.String() != fmt.Sprint(cid) {
|
||||
t.Errorf("String() = %s, want %s", cid.String(), fmt.Sprint(cid))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,13 +15,10 @@ func TestClient_ValidateRedirectURI(t *testing.T) {
|
|||
for _, tc := range []struct {
|
||||
name string
|
||||
in *domain.URL
|
||||
}{{
|
||||
name: "client_id prefix",
|
||||
in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback")),
|
||||
}, {
|
||||
name: "registered redirect_uri",
|
||||
in: client.RedirectURI[len(client.RedirectURI)-1],
|
||||
}} {
|
||||
}{
|
||||
{name: "client_id prefix", in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback"))},
|
||||
{name: "registered redirect_uri", in: client.RedirectURI[len(client.RedirectURI)-1]},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package domain
|
||||
|
||||
//nolint: gosec // support old clients
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
|
@ -21,7 +22,7 @@ type CodeChallengeMethod struct {
|
|||
uid string
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var (
|
||||
CodeChallengeMethodUndefined = CodeChallengeMethod{
|
||||
uid: "",
|
||||
|
@ -34,12 +35,14 @@ var (
|
|||
}
|
||||
|
||||
CodeChallengeMethodMD5 = CodeChallengeMethod{
|
||||
uid: "MD5",
|
||||
uid: "MD5",
|
||||
//nolint: gosec // support old clients
|
||||
hash: md5.New(),
|
||||
}
|
||||
|
||||
CodeChallengeMethodS1 = CodeChallengeMethod{
|
||||
uid: "S1",
|
||||
uid: "S1",
|
||||
//nolint: gosec // support old clients
|
||||
hash: sha1.New(),
|
||||
}
|
||||
|
||||
|
@ -60,7 +63,7 @@ var ErrCodeChallengeMethodUnknown error = NewError(
|
|||
"https://indieauth.net/source/#authorization-request",
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants
|
||||
//nolint: gochecknoglobals // maps cannot be constants
|
||||
var slugsMethods = map[string]CodeChallengeMethod{
|
||||
CodeChallengeMethodMD5.uid: CodeChallengeMethodMD5,
|
||||
CodeChallengeMethodPLAIN.uid: CodeChallengeMethodPLAIN,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package domain_test
|
||||
|
||||
//nolint: gosec // support old clients
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
|
@ -23,32 +24,14 @@ func TestParseCodeChallengeMethod(t *testing.T) {
|
|||
in string
|
||||
out domain.CodeChallengeMethod
|
||||
expError bool
|
||||
}{{
|
||||
expError: true,
|
||||
name: "invalid",
|
||||
in: "und",
|
||||
out: domain.CodeChallengeMethodUndefined,
|
||||
}, {
|
||||
name: "PLAIN",
|
||||
in: "plain",
|
||||
out: domain.CodeChallengeMethodPLAIN,
|
||||
}, {
|
||||
name: "MD5",
|
||||
in: "Md5",
|
||||
out: domain.CodeChallengeMethodMD5,
|
||||
}, {
|
||||
name: "S1",
|
||||
in: "S1",
|
||||
out: domain.CodeChallengeMethodS1,
|
||||
}, {
|
||||
name: "S256",
|
||||
in: "S256",
|
||||
out: domain.CodeChallengeMethodS256,
|
||||
}, {
|
||||
name: "S512",
|
||||
in: "S512",
|
||||
out: domain.CodeChallengeMethodS512,
|
||||
}} {
|
||||
}{
|
||||
{name: "invalid", in: "und", out: domain.CodeChallengeMethodUndefined, expError: true},
|
||||
{name: "PLAIN", in: "plain", out: domain.CodeChallengeMethodPLAIN, expError: false},
|
||||
{name: "MD5", in: "Md5", out: domain.CodeChallengeMethodMD5, expError: false},
|
||||
{name: "S1", in: "S1", out: domain.CodeChallengeMethodS1, expError: false},
|
||||
{name: "S256", in: "S256", out: domain.CodeChallengeMethodS256, expError: false},
|
||||
{name: "S512", in: "S512", out: domain.CodeChallengeMethodS512, expError: false},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -107,27 +90,13 @@ func TestCodeChallengeMethod_String(t *testing.T) {
|
|||
name string
|
||||
in domain.CodeChallengeMethod
|
||||
out string
|
||||
}{{
|
||||
name: "plain",
|
||||
in: domain.CodeChallengeMethodPLAIN,
|
||||
out: "PLAIN",
|
||||
}, {
|
||||
name: "md5",
|
||||
in: domain.CodeChallengeMethodMD5,
|
||||
out: "MD5",
|
||||
}, {
|
||||
name: "s1",
|
||||
in: domain.CodeChallengeMethodS1,
|
||||
out: "S1",
|
||||
}, {
|
||||
name: "s256",
|
||||
in: domain.CodeChallengeMethodS256,
|
||||
out: "S256",
|
||||
}, {
|
||||
name: "s512",
|
||||
in: domain.CodeChallengeMethodS512,
|
||||
out: "S512",
|
||||
}} {
|
||||
}{
|
||||
{name: "plain", in: domain.CodeChallengeMethodPLAIN, out: "PLAIN"},
|
||||
{name: "md5", in: domain.CodeChallengeMethodMD5, out: "MD5"},
|
||||
{name: "s1", in: domain.CodeChallengeMethodS1, out: "S1"},
|
||||
{name: "s256", in: domain.CodeChallengeMethodS256, out: "S256"},
|
||||
{name: "s512", in: domain.CodeChallengeMethodS512, out: "S512"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -141,7 +110,7 @@ func TestCodeChallengeMethod_String(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint: funlen
|
||||
//nolint: gosec // support old clients
|
||||
func TestCodeChallengeMethod_Validate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -155,42 +124,15 @@ func TestCodeChallengeMethod_Validate(t *testing.T) {
|
|||
in domain.CodeChallengeMethod
|
||||
name string
|
||||
expError bool
|
||||
}{{
|
||||
name: "invalid",
|
||||
in: domain.CodeChallengeMethodS256,
|
||||
hash: md5.New(),
|
||||
expError: true,
|
||||
}, {
|
||||
name: "MD5",
|
||||
in: domain.CodeChallengeMethodMD5,
|
||||
hash: md5.New(),
|
||||
expError: false,
|
||||
}, {
|
||||
name: "plain",
|
||||
in: domain.CodeChallengeMethodPLAIN,
|
||||
hash: nil,
|
||||
expError: false,
|
||||
}, {
|
||||
name: "S1",
|
||||
in: domain.CodeChallengeMethodS1,
|
||||
hash: sha1.New(),
|
||||
expError: false,
|
||||
}, {
|
||||
name: "S256",
|
||||
in: domain.CodeChallengeMethodS256,
|
||||
hash: sha256.New(),
|
||||
expError: false,
|
||||
}, {
|
||||
name: "S512",
|
||||
in: domain.CodeChallengeMethodS512,
|
||||
hash: sha512.New(),
|
||||
expError: false,
|
||||
}, {
|
||||
name: "undefined",
|
||||
in: domain.CodeChallengeMethodUndefined,
|
||||
hash: nil,
|
||||
expError: true,
|
||||
}} {
|
||||
}{
|
||||
{name: "invalid", in: domain.CodeChallengeMethodS256, hash: md5.New(), expError: true},
|
||||
{name: "MD5", in: domain.CodeChallengeMethodMD5, hash: md5.New(), expError: false},
|
||||
{name: "plain", in: domain.CodeChallengeMethodPLAIN, hash: nil, expError: false},
|
||||
{name: "S1", in: domain.CodeChallengeMethodS1, hash: sha1.New(), expError: false},
|
||||
{name: "S256", in: domain.CodeChallengeMethodS256, hash: sha256.New(), expError: false},
|
||||
{name: "S512", in: domain.CodeChallengeMethodS512, hash: sha512.New(), expError: false},
|
||||
{name: "undefined", in: domain.CodeChallengeMethodUndefined, hash: nil, expError: true},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -78,6 +78,7 @@ type (
|
|||
)
|
||||
|
||||
// TestConfig returns a valid config for tests.
|
||||
//nolint: gomnd // testing domain can contains non-standart values
|
||||
func TestConfig(tb testing.TB) *Config {
|
||||
tb.Helper()
|
||||
|
||||
|
|
|
@ -12,12 +12,14 @@ type Email struct {
|
|||
subAddress string
|
||||
}
|
||||
|
||||
const DefaultEmailPartsLength int = 2
|
||||
|
||||
var ErrEmailInvalid error = NewError(ErrorCodeInvalidRequest, "cannot parse email", "")
|
||||
|
||||
// ParseEmail parse strings to email identifier.
|
||||
func ParseEmail(src string) (*Email, error) {
|
||||
parts := strings.Split(strings.TrimPrefix(src, "mailto:"), "@")
|
||||
if len(parts) != 2 { //nolint: gomnd
|
||||
if len(parts) != DefaultEmailPartsLength {
|
||||
return nil, ErrEmailInvalid
|
||||
}
|
||||
|
||||
|
@ -27,7 +29,7 @@ func ParseEmail(src string) (*Email, error) {
|
|||
subAddress: "",
|
||||
}
|
||||
|
||||
if userParts := strings.SplitN(parts[0], `+`, 2); len(userParts) > 1 {
|
||||
if userParts := strings.SplitN(parts[0], `+`, DefaultEmailPartsLength); len(userParts) > 1 {
|
||||
result.user = userParts[0]
|
||||
result.subAddress = userParts[1]
|
||||
}
|
||||
|
|
|
@ -14,19 +14,11 @@ func TestParseEmail(t *testing.T) {
|
|||
name string
|
||||
in string
|
||||
out string
|
||||
}{{
|
||||
name: "simple",
|
||||
in: "user@example.com",
|
||||
out: "user@example.com",
|
||||
}, {
|
||||
name: "subAddress",
|
||||
in: "user+suffix@example.com",
|
||||
out: "user+suffix@example.com",
|
||||
}, {
|
||||
name: "mailto prefix",
|
||||
in: "mailto:user@example.com",
|
||||
out: "user@example.com",
|
||||
}} {
|
||||
}{
|
||||
{name: "simple", in: "user@example.com", out: "user@example.com"},
|
||||
{name: "subAddress", in: "user+suffix@example.com", out: "user+suffix@example.com"},
|
||||
{name: "mailto prefix", in: "mailto:user@example.com", out: "user@example.com"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
@ -47,7 +39,7 @@ func TestParseEmail(t *testing.T) {
|
|||
func TestEmail_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
email := domain.TestEmail(t)
|
||||
email := domain.TestEmail(t) //nolint: ifshort
|
||||
if result := email.String(); result != fmt.Sprint(email) {
|
||||
t.Errorf("String() = %v, want %v", result, email)
|
||||
}
|
||||
|
|
|
@ -10,18 +10,19 @@ import (
|
|||
|
||||
type (
|
||||
// Error describes the format of a typical IndieAuth error.
|
||||
//nolint: tagliatelle // RFC 6749 section 5.2
|
||||
Error struct {
|
||||
// A single error code.
|
||||
Code ErrorCode `json:"error"`
|
||||
|
||||
// Human-readable ASCII text providing additional information, used to
|
||||
// assist the client developer in understanding the error that occurred.
|
||||
Description string `json:"error_description,omitempty"` //nolint: tagliatelle
|
||||
Description string `json:"error_description,omitempty"`
|
||||
|
||||
// A URI identifying a human-readable web page with information about
|
||||
// the error, used to provide the client developer with additional
|
||||
// information about the error.
|
||||
URI string `json:"error_uri,omitempty"` //nolint: tagliatelle
|
||||
URI string `json:"error_uri,omitempty"`
|
||||
|
||||
// REQUIRED if a "state" parameter was present in the client
|
||||
// authorization request. The exact value received from the
|
||||
|
@ -33,17 +34,29 @@ type (
|
|||
|
||||
// ErrorCode represent error code described in RFC 6749.
|
||||
ErrorCode struct {
|
||||
uid string
|
||||
uid string
|
||||
status int
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorCodeUndefined ErrorCode = ErrorCode{uid: ""}
|
||||
// ErrorCodeUndefined describes an unrecognized error code.
|
||||
ErrorCodeUndefined = ErrorCode{
|
||||
uid: "",
|
||||
status: 0,
|
||||
}
|
||||
|
||||
// ErrorCodeAccessDenied describes the access_denied error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The resource owner or authorization server
|
||||
// denied the request.
|
||||
ErrorCodeAccessDenied ErrorCode = ErrorCode{uid: "access_denied"}
|
||||
ErrorCodeAccessDenied = ErrorCode{
|
||||
uid: "access_denied",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeInvalidClient describes the invalid_client error code.
|
||||
//
|
||||
// RFC 6749 section 5.2: Client authentication failed (e.g., unknown
|
||||
// client, no client authentication included, or unsupported
|
||||
// authentication method).
|
||||
|
@ -56,14 +69,26 @@ var (
|
|||
// HTTP 401 (Unauthorized) status code and include the
|
||||
// "WWW-Authenticate" response header field matching the authentication
|
||||
// scheme used by the client.
|
||||
ErrorCodeInvalidClient ErrorCode = ErrorCode{uid: "invalid_client"}
|
||||
ErrorCodeInvalidClient = ErrorCode{
|
||||
uid: "invalid_client",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeInvalidGrant describes the invalid_grant error code.
|
||||
//
|
||||
// RFC 6749 section 5.2: The provided authorization grant (e.g.,
|
||||
// authorization code, resource owner credentials) or refresh token is
|
||||
// invalid, expired, revoked, does not match the redirection URI used in
|
||||
// the authorization request, or was issued to another client.
|
||||
ErrorCodeInvalidGrant ErrorCode = ErrorCode{uid: "invalid_grant"}
|
||||
ErrorCodeInvalidGrant = ErrorCode{
|
||||
uid: "invalid_grant",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeInvalidRequest describes the invalid_request error code.
|
||||
//
|
||||
// IndieAuth: The request is not valid.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The request is missing a required
|
||||
// parameter, includes an invalid parameter value, includes a parameter
|
||||
// more than once, or is otherwise malformed.
|
||||
|
@ -73,42 +98,91 @@ var (
|
|||
// repeats a parameter, includes multiple credentials, utilizes more
|
||||
// than one mechanism for authenticating the client, or is otherwise
|
||||
// malformed.
|
||||
ErrorCodeInvalidRequest ErrorCode = ErrorCode{uid: "invalid_request"}
|
||||
ErrorCodeInvalidRequest = ErrorCode{
|
||||
uid: "invalid_request",
|
||||
status: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrorCodeInvalidScope describes the invalid_scope error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The requested scope is invalid, unknown, or
|
||||
// malformed.
|
||||
//
|
||||
// RFC 6749 section 5.2: The requested scope is invalid, unknown,
|
||||
// malformed, or exceeds the scope granted by the resource owner.
|
||||
ErrorCodeInvalidScope ErrorCode = ErrorCode{uid: "invalid_scope"}
|
||||
ErrorCodeInvalidScope = ErrorCode{
|
||||
uid: "invalid_scope",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeServerError describes the server_error error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The authorization server encountered an
|
||||
// unexpected condition that prevented it from fulfilling the request.
|
||||
// (This error code is needed because a 500 Internal Server Error HTTP
|
||||
// status code cannot be returned to the client via an HTTP redirect.)
|
||||
ErrorCodeServerError ErrorCode = ErrorCode{uid: "server_error"}
|
||||
ErrorCodeServerError = ErrorCode{
|
||||
uid: "server_error",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeTemporarilyUnavailable describes the temporarily_unavailable error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The authorization server is currently
|
||||
// unable to handle the request due to a temporary overloading or
|
||||
// maintenance of the server. (This error code is needed because a 503
|
||||
// Service Unavailable HTTP status code cannot be returned to the client
|
||||
// via an HTTP redirect.)
|
||||
ErrorCodeTemporarilyUnavailable ErrorCode = ErrorCode{uid: "temporarily_unavailable"}
|
||||
ErrorCodeTemporarilyUnavailable = ErrorCode{
|
||||
uid: "temporarily_unavailable",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeUnauthorizedClient describes the unauthorized_client error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The client is not authorized to request an
|
||||
// authorization code using this method.
|
||||
//
|
||||
// RFC 6749 section 5.2: The authenticated client is not authorized to
|
||||
// use this authorization grant type.
|
||||
ErrorCodeUnauthorizedClient ErrorCode = ErrorCode{uid: "unauthorized_client"}
|
||||
ErrorCodeUnauthorizedClient = ErrorCode{
|
||||
uid: "unauthorized_client",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeUnsupportedGrantType describes the unsupported_grant_type error code.
|
||||
//
|
||||
// RFC 6749 section 5.2: The authorization grant type is not supported
|
||||
// by the authorization server.
|
||||
ErrorCodeUnsupportedGrantType ErrorCode = ErrorCode{uid: "unsupported_grant_type"}
|
||||
ErrorCodeUnsupportedGrantType = ErrorCode{
|
||||
uid: "unsupported_grant_type",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeUnsupportedResponseType describes the unsupported_response_type error code.
|
||||
//
|
||||
// RFC 6749 section 4.1.2.1: The authorization server does not support
|
||||
// obtaining an authorization code using this method.
|
||||
ErrorCodeUnsupportedResponseType ErrorCode = ErrorCode{uid: "unsupported_response_type"}
|
||||
ErrorCodeUnsupportedResponseType = ErrorCode{
|
||||
uid: "unsupported_response_type",
|
||||
status: 0, // TODO(toby3d)
|
||||
}
|
||||
|
||||
// ErrorCodeInvalidToken describes the invalid_token error code.
|
||||
//
|
||||
// IndieAuth: The access token provided is expired, revoked, or invalid.
|
||||
ErrorCodeInvalidToken = ErrorCode{
|
||||
uid: "invalid_token",
|
||||
status: http.StatusUnauthorized,
|
||||
}
|
||||
|
||||
// ErrorCodeInsufficientScope describes the insufficient_scope error code.
|
||||
//
|
||||
// IndieAuth: The request requires higher privileges than provided.
|
||||
ErrorCodeInsufficientScope = ErrorCode{
|
||||
uid: "insufficient_scope",
|
||||
status: http.StatusForbidden,
|
||||
}
|
||||
)
|
||||
|
||||
// String returns a string representation of the error code.
|
||||
|
@ -128,45 +202,45 @@ func (e Error) Error() string {
|
|||
}
|
||||
|
||||
// Format prints the stack as error detail.
|
||||
func (e Error) Format(s fmt.State, r rune) {
|
||||
xerrors.FormatError(e, s, r)
|
||||
func (e Error) Format(state fmt.State, r rune) {
|
||||
xerrors.FormatError(e, state, r)
|
||||
}
|
||||
|
||||
// FormatError prints the receiver's error, if any.
|
||||
func (e Error) FormatError(p xerrors.Printer) error {
|
||||
p.Print(e.Code)
|
||||
func (e Error) FormatError(printer xerrors.Printer) error {
|
||||
printer.Print(e.Code)
|
||||
|
||||
if e.Description != "" {
|
||||
p.Print(": ", e.Description)
|
||||
printer.Print(": ", e.Description)
|
||||
}
|
||||
|
||||
if !p.Detail() {
|
||||
if !printer.Detail() {
|
||||
return nil
|
||||
}
|
||||
|
||||
e.frame.Format(p)
|
||||
e.frame.Format(printer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReirectURI sets fasthttp.QueryArgs with the request state, code,
|
||||
// description and error URI in the provided fasthttp.URI.
|
||||
func (e Error) SetReirectURI(u *http.URI) {
|
||||
if u == nil {
|
||||
func (e Error) SetReirectURI(uri *http.URI) {
|
||||
if uri == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range map[string]string{
|
||||
for key, val := range map[string]string{
|
||||
"error": e.Code.String(),
|
||||
"error_description": e.Description,
|
||||
"error_uri": e.URI,
|
||||
"state": e.State,
|
||||
} {
|
||||
if v == "" {
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
u.QueryArgs().Set(k, v)
|
||||
uri.QueryArgs().Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,19 +1,21 @@
|
|||
//nolint: dupl
|
||||
package domain
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GrantType represent fixed grant_type parameter.
|
||||
//
|
||||
// NOTE(toby3d): Encapsulate enums in structs for extra compile-time safety:
|
||||
// https://threedots.tech/post/safer-enums-in-go/#struct-based-enums
|
||||
type GrantType struct {
|
||||
uid string
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var (
|
||||
GrantTypeUndefined = GrantType{uid: ""}
|
||||
GrantTypeAuthorizationCode = GrantType{uid: "authorization_code"}
|
||||
|
@ -22,7 +24,11 @@ var (
|
|||
GrantTypeTicket = GrantType{uid: "ticket"}
|
||||
)
|
||||
|
||||
var ErrGrantTypeUnknown error = errors.New("unknown grant type")
|
||||
var ErrGrantTypeUnknown error = NewError(
|
||||
ErrorCodeInvalidGrant,
|
||||
"unknown grant type",
|
||||
"",
|
||||
)
|
||||
|
||||
// ParseGrantType parse grant_type value as GrantType struct enum.
|
||||
func ParseGrantType(uid string) (GrantType, error) {
|
||||
|
@ -40,7 +46,7 @@ func ParseGrantType(uid string) (GrantType, error) {
|
|||
func (gt *GrantType) UnmarshalForm(src []byte) error {
|
||||
responseType, err := ParseGrantType(string(src))
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant_type: %w", err)
|
||||
return fmt.Errorf("UnmarshalForm: %w", err)
|
||||
}
|
||||
|
||||
*gt = responseType
|
||||
|
@ -52,12 +58,12 @@ func (gt *GrantType) UnmarshalForm(src []byte) error {
|
|||
func (gt *GrantType) UnmarshalJSON(v []byte) error {
|
||||
src, err := strconv.Unquote(string(v))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
||||
responseType, err := ParseGrantType(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("grant_type: %w", err)
|
||||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
||||
*gt = responseType
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint: dupl
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
|
@ -12,13 +13,10 @@ func TestParseGrantType(t *testing.T) {
|
|||
for _, tc := range []struct {
|
||||
in string
|
||||
out domain.GrantType
|
||||
}{{
|
||||
in: "authorization_code",
|
||||
out: domain.GrantTypeAuthorizationCode,
|
||||
}, {
|
||||
in: "ticket",
|
||||
out: domain.GrantTypeTicket,
|
||||
}} {
|
||||
}{
|
||||
{in: "authorization_code", out: domain.GrantTypeAuthorizationCode},
|
||||
{in: "ticket", out: domain.GrantTypeTicket},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.in, func(t *testing.T) {
|
||||
|
@ -73,15 +71,10 @@ func TestGrantType_String(t *testing.T) {
|
|||
name string
|
||||
in domain.GrantType
|
||||
out string
|
||||
}{{
|
||||
name: "authorization_code",
|
||||
in: domain.GrantTypeAuthorizationCode,
|
||||
out: "authorization_code",
|
||||
}, {
|
||||
name: "ticket",
|
||||
in: domain.GrantTypeTicket,
|
||||
out: "ticket",
|
||||
}} {
|
||||
}{
|
||||
{name: "authorization_code", in: domain.GrantTypeAuthorizationCode, out: "authorization_code"},
|
||||
{name: "ticket", in: domain.GrantTypeTicket, out: "ticket"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -18,7 +18,7 @@ type Me struct {
|
|||
}
|
||||
|
||||
// ParseMe parse string as me URL identifier.
|
||||
//nolint: funlen
|
||||
//nolint: funlen, cyclop
|
||||
func ParseMe(raw string) (*Me, error) {
|
||||
me := http.AcquireURI()
|
||||
if err := me.Parse(nil, []byte(raw)); err != nil {
|
||||
|
@ -114,7 +114,7 @@ func TestMe(tb testing.TB, src string) *Me {
|
|||
|
||||
me, err := ParseMe(src)
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return me
|
||||
|
@ -174,12 +174,16 @@ func (m Me) URL() *url.URL {
|
|||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: string(m.me.Scheme()),
|
||||
Host: string(m.me.Host()),
|
||||
Path: string(m.me.Path()),
|
||||
RawPath: string(m.me.PathOriginal()),
|
||||
RawQuery: string(m.me.QueryString()),
|
||||
Fragment: string(m.me.Hash()),
|
||||
ForceQuery: false,
|
||||
Fragment: string(m.me.Hash()),
|
||||
Host: string(m.me.Host()),
|
||||
Opaque: "",
|
||||
Path: string(m.me.Path()),
|
||||
RawFragment: "",
|
||||
RawPath: string(m.me.PathOriginal()),
|
||||
RawQuery: string(m.me.QueryString()),
|
||||
Scheme: string(m.me.Scheme()),
|
||||
User: nil,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
)
|
||||
|
||||
//nolint: funlen
|
||||
func TestParseMe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -15,47 +14,18 @@ func TestParseMe(t *testing.T) {
|
|||
name string
|
||||
in string
|
||||
expError bool
|
||||
}{{
|
||||
name: "valid",
|
||||
in: "https://example.com/",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid path",
|
||||
in: "https://example.com/username",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "valid query",
|
||||
in: "https://example.com/users?id=100",
|
||||
expError: false,
|
||||
}, {
|
||||
name: "missing scheme",
|
||||
in: "example.com",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "invalid scheme",
|
||||
in: "mailto:user@example.com",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "contains double-dot path",
|
||||
in: "https://example.com/foo/../bar",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "contains fragment",
|
||||
in: "https://example.com/#me",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "contains user",
|
||||
in: "https://user:pass@example.com/",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "contains port",
|
||||
in: "https://example.com:8443/",
|
||||
expError: true,
|
||||
}, {
|
||||
name: "host is an IP address",
|
||||
in: "https://172.28.92.51/",
|
||||
expError: true,
|
||||
}} {
|
||||
}{
|
||||
{name: "valid", in: "https://example.com/", expError: false},
|
||||
{name: "valid path", in: "https://example.com/username", expError: false},
|
||||
{name: "valid query", in: "https://example.com/users?id=100", expError: false},
|
||||
{name: "missing scheme", in: "example.com", expError: true},
|
||||
{name: "invalid scheme", in: "mailto:user@example.com", expError: true},
|
||||
{name: "contains double-dot path", in: "https://example.com/foo/../bar", expError: true},
|
||||
{name: "contains fragment", in: "https://example.com/#me", expError: true},
|
||||
{name: "contains user", in: "https://user:pass@example.com/", expError: true},
|
||||
{name: "contains port", in: "https://example.com:8443/", expError: true},
|
||||
{name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package domain
|
||||
|
||||
//nolint: tagliatelle
|
||||
//nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
|
||||
type Metadata struct {
|
||||
// The server's issuer identifier. The issuer identifier is a URL that
|
||||
// uses the "https" scheme and has no query or fragment components. The
|
||||
|
|
|
@ -21,89 +21,89 @@ type Provider struct {
|
|||
URL string
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
//nolint: gochecknoglobals // structs cannot be contants
|
||||
var (
|
||||
ProviderDirect = Provider{
|
||||
AuthURL: "/authorize",
|
||||
Name: "IndieAuth",
|
||||
Photo: path.Join("static", "icon.svg"),
|
||||
RedirectURL: path.Join("callback"),
|
||||
Scopes: []string{},
|
||||
TokenURL: "/token",
|
||||
UID: "direct",
|
||||
URL: "/",
|
||||
AuthURL: "/authorize",
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Name: "IndieAuth",
|
||||
Photo: path.Join("static", "icon.svg"),
|
||||
RedirectURL: path.Join("callback"),
|
||||
Scopes: []string{},
|
||||
TokenURL: "/token",
|
||||
UID: "direct",
|
||||
URL: "/",
|
||||
}
|
||||
|
||||
ProviderTwitter = Provider{
|
||||
AuthURL: "https://twitter.com/i/oauth2/authorize",
|
||||
Name: "Twitter",
|
||||
Photo: path.Join("static", "providers", "twitter.svg"),
|
||||
RedirectURL: path.Join("callback", "twitter"),
|
||||
Scopes: []string{
|
||||
"tweet.read",
|
||||
"users.read",
|
||||
},
|
||||
TokenURL: "https://api.twitter.com/2/oauth2/token",
|
||||
UID: "twitter",
|
||||
URL: "https://twitter.com/",
|
||||
AuthURL: "https://twitter.com/i/oauth2/authorize",
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Name: "Twitter",
|
||||
Photo: path.Join("static", "providers", "twitter.svg"),
|
||||
RedirectURL: path.Join("callback", "twitter"),
|
||||
Scopes: []string{"tweet.read", "users.read"},
|
||||
TokenURL: "https://api.twitter.com/2/oauth2/token",
|
||||
UID: "twitter",
|
||||
URL: "https://twitter.com/",
|
||||
}
|
||||
|
||||
ProviderGitHub = Provider{
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
Name: "GitHub",
|
||||
Photo: path.Join("static", "providers", "github.svg"),
|
||||
RedirectURL: path.Join("callback", "github"),
|
||||
Scopes: []string{
|
||||
"read:user",
|
||||
"user:email",
|
||||
},
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UID: "github",
|
||||
URL: "https://github.com/",
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Name: "GitHub",
|
||||
Photo: path.Join("static", "providers", "github.svg"),
|
||||
RedirectURL: path.Join("callback", "github"),
|
||||
Scopes: []string{"read:user", "user:email"},
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UID: "github",
|
||||
URL: "https://github.com/",
|
||||
}
|
||||
|
||||
ProviderGitLab = Provider{
|
||||
AuthURL: "https://gitlab.com/oauth/authorize",
|
||||
Name: "GitLab",
|
||||
Photo: path.Join("static", "providers", "gitlab.svg"),
|
||||
RedirectURL: path.Join("callback", "gitlab"),
|
||||
Scopes: []string{
|
||||
"read_user",
|
||||
},
|
||||
TokenURL: "https://gitlab.com/oauth/token",
|
||||
UID: "gitlab",
|
||||
URL: "https://gitlab.com/",
|
||||
AuthURL: "https://gitlab.com/oauth/authorize",
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Name: "GitLab",
|
||||
Photo: path.Join("static", "providers", "gitlab.svg"),
|
||||
RedirectURL: path.Join("callback", "gitlab"),
|
||||
Scopes: []string{"read_user"},
|
||||
TokenURL: "https://gitlab.com/oauth/token",
|
||||
UID: "gitlab",
|
||||
URL: "https://gitlab.com/",
|
||||
}
|
||||
|
||||
ProviderMastodon = Provider{
|
||||
AuthURL: "https://mstdn.io/oauth/authorize",
|
||||
Name: "Mastodon",
|
||||
Photo: path.Join("static", "providers", "mastodon.svg"),
|
||||
RedirectURL: path.Join("callback", "mastodon"),
|
||||
Scopes: []string{
|
||||
"read:accounts",
|
||||
},
|
||||
TokenURL: "https://mstdn.io/oauth/token",
|
||||
UID: "mastodon",
|
||||
URL: "https://mstdn.io/",
|
||||
AuthURL: "https://mstdn.io/oauth/authorize",
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Name: "Mastodon",
|
||||
Photo: path.Join("static", "providers", "mastodon.svg"),
|
||||
RedirectURL: path.Join("callback", "mastodon"),
|
||||
Scopes: []string{"read:accounts"},
|
||||
TokenURL: "https://mstdn.io/oauth/token",
|
||||
UID: "mastodon",
|
||||
URL: "https://mstdn.io/",
|
||||
}
|
||||
)
|
||||
|
||||
// AuthCodeURL returns URL for authorize user in RelMeAuth client.
|
||||
func (p Provider) AuthCodeURL(state string) string {
|
||||
u := http.AcquireURI()
|
||||
defer http.ReleaseURI(u)
|
||||
u.Update(p.AuthURL)
|
||||
uri := http.AcquireURI()
|
||||
defer http.ReleaseURI(uri)
|
||||
uri.Update(p.AuthURL)
|
||||
|
||||
for k, v := range map[string]string{
|
||||
for key, val := range map[string]string{
|
||||
"client_id": p.ClientID,
|
||||
"redirect_uri": p.RedirectURL,
|
||||
"response_type": "code",
|
||||
"scope": strings.Join(p.Scopes, " "),
|
||||
"state": state,
|
||||
} {
|
||||
u.QueryArgs().Set(k, v)
|
||||
uri.QueryArgs().Set(key, val)
|
||||
}
|
||||
|
||||
return u.String()
|
||||
return uri.String()
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint: dupl
|
||||
package domain
|
||||
|
||||
import (
|
||||
|
@ -12,20 +13,20 @@ type ResponseType struct {
|
|||
uid string
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var (
|
||||
ResponseTypeUndefined ResponseType = ResponseType{uid: ""}
|
||||
ResponseTypeUndefined = ResponseType{uid: ""}
|
||||
|
||||
// Deprecated(toby3d): Only accept response_type=code requests, and for
|
||||
// backwards-compatible support, treat response_type=id requests as
|
||||
// response_type=code requests:
|
||||
// https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
|
||||
ResponseTypeID ResponseType = ResponseType{uid: "id"}
|
||||
ResponseTypeID = ResponseType{uid: "id"}
|
||||
|
||||
// Indicates to the authorization server that an authorization code
|
||||
// should be returned as the response:
|
||||
// https://indieauth.net/source/#authorization-request-li-1
|
||||
ResponseTypeCode ResponseType = ResponseType{uid: "code"}
|
||||
ResponseTypeCode = ResponseType{uid: "code"}
|
||||
)
|
||||
|
||||
var ErrResponseTypeUnknown error = NewError(
|
||||
|
@ -65,7 +66,7 @@ func (rt *ResponseType) UnmarshalJSON(v []byte) error {
|
|||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
||||
responseType, err := ParseResponseType(string(uid))
|
||||
responseType, err := ParseResponseType(uid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
//nolint: dupl
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
|
@ -12,13 +13,10 @@ func TestResponseTypeType(t *testing.T) {
|
|||
for _, tc := range []struct {
|
||||
in string
|
||||
out domain.ResponseType
|
||||
}{{
|
||||
in: "id",
|
||||
out: domain.ResponseTypeID,
|
||||
}, {
|
||||
in: "code",
|
||||
out: domain.ResponseTypeCode,
|
||||
}} {
|
||||
}{
|
||||
{in: "id", out: domain.ResponseTypeID},
|
||||
{in: "code", out: domain.ResponseTypeCode},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.in, func(t *testing.T) {
|
||||
|
@ -73,15 +71,10 @@ func TestResponseType_String(t *testing.T) {
|
|||
name string
|
||||
in domain.ResponseType
|
||||
out string
|
||||
}{{
|
||||
name: "id",
|
||||
in: domain.ResponseTypeID,
|
||||
out: "id",
|
||||
}, {
|
||||
name: "code",
|
||||
in: domain.ResponseTypeCode,
|
||||
out: "code",
|
||||
}} {
|
||||
}{
|
||||
{name: "id", in: domain.ResponseTypeID, out: "id"},
|
||||
{name: "code", in: domain.ResponseTypeCode, out: "code"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
|
@ -21,7 +21,7 @@ type (
|
|||
|
||||
var ErrScopeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown scope", "https://indieweb.org/scope")
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var (
|
||||
ScopeUndefined = Scope{uid: ""}
|
||||
|
||||
|
@ -56,7 +56,7 @@ var (
|
|||
ScopeEmail = Scope{uid: "email"}
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants
|
||||
//nolint: gochecknoglobals // maps cannot be constants
|
||||
var uidsScopes = map[string]Scope{
|
||||
ScopeBlock.uid: ScopeBlock,
|
||||
ScopeChannels.uid: ScopeChannels,
|
||||
|
@ -112,17 +112,17 @@ func (s *Scopes) UnmarshalJSON(v []byte) error {
|
|||
|
||||
result := make(Scopes, 0)
|
||||
|
||||
for _, scope := range strings.Fields(src) {
|
||||
s, err := ParseScope(scope)
|
||||
for _, rawScope := range strings.Fields(src) {
|
||||
scope, err := ParseScope(rawScope)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UnmarshalJSON: %w", err)
|
||||
}
|
||||
|
||||
if result.Has(s) {
|
||||
if result.Has(scope) {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, s)
|
||||
result = append(result, scope)
|
||||
}
|
||||
|
||||
*s = result
|
||||
|
|
|
@ -14,43 +14,20 @@ func TestParseScope(t *testing.T) {
|
|||
for _, tc := range []struct {
|
||||
in string
|
||||
out domain.Scope
|
||||
}{{
|
||||
in: "create",
|
||||
out: domain.ScopeCreate,
|
||||
}, {
|
||||
in: "delete",
|
||||
out: domain.ScopeDelete,
|
||||
}, {
|
||||
in: "draft",
|
||||
out: domain.ScopeDraft,
|
||||
}, {
|
||||
in: "media",
|
||||
out: domain.ScopeMedia,
|
||||
}, {
|
||||
in: "update",
|
||||
out: domain.ScopeUpdate,
|
||||
}, {
|
||||
in: "block",
|
||||
out: domain.ScopeBlock,
|
||||
}, {
|
||||
in: "channels",
|
||||
out: domain.ScopeChannels,
|
||||
}, {
|
||||
in: "follow",
|
||||
out: domain.ScopeFollow,
|
||||
}, {
|
||||
in: "mute",
|
||||
out: domain.ScopeMute,
|
||||
}, {
|
||||
in: "read",
|
||||
out: domain.ScopeRead,
|
||||
}, {
|
||||
in: "profile",
|
||||
out: domain.ScopeProfile,
|
||||
}, {
|
||||
in: "email",
|
||||
out: domain.ScopeEmail,
|
||||
}} {
|
||||
}{
|
||||
{in: "create", out: domain.ScopeCreate},
|
||||
{in: "delete", out: domain.ScopeDelete},
|
||||
{in: "draft", out: domain.ScopeDraft},
|
||||
{in: "media", out: domain.ScopeMedia},
|
||||
{in: "update", out: domain.ScopeUpdate},
|
||||
{in: "block", out: domain.ScopeBlock},
|
||||
{in: "channels", out: domain.ScopeChannels},
|
||||
{in: "follow", out: domain.ScopeFollow},
|
||||
{in: "mute", out: domain.ScopeMute},
|
||||
{in: "read", out: domain.ScopeRead},
|
||||
{in: "profile", out: domain.ScopeProfile},
|
||||
{in: "email", out: domain.ScopeEmail},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.in, func(t *testing.T) {
|
||||
|
@ -118,47 +95,24 @@ func TestScopes_MarshalJSON(t *testing.T) {
|
|||
func TestScope_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
//nolint: paralleltest // NOTE(toby3d): false positive, tc.in is used.
|
||||
//nolint: paralleltest // false positive, in is used
|
||||
for _, tc := range []struct {
|
||||
in domain.Scope
|
||||
out string
|
||||
}{{
|
||||
in: domain.ScopeCreate,
|
||||
out: "create",
|
||||
}, {
|
||||
in: domain.ScopeDelete,
|
||||
out: "delete",
|
||||
}, {
|
||||
in: domain.ScopeDraft,
|
||||
out: "draft",
|
||||
}, {
|
||||
in: domain.ScopeMedia,
|
||||
out: "media",
|
||||
}, {
|
||||
in: domain.ScopeUpdate,
|
||||
out: "update",
|
||||
}, {
|
||||
in: domain.ScopeBlock,
|
||||
out: "block",
|
||||
}, {
|
||||
in: domain.ScopeChannels,
|
||||
out: "channels",
|
||||
}, {
|
||||
in: domain.ScopeFollow,
|
||||
out: "follow",
|
||||
}, {
|
||||
in: domain.ScopeMute,
|
||||
out: "mute",
|
||||
}, {
|
||||
in: domain.ScopeRead,
|
||||
out: "read",
|
||||
}, {
|
||||
in: domain.ScopeProfile,
|
||||
out: "profile",
|
||||
}, {
|
||||
in: domain.ScopeEmail,
|
||||
out: "email",
|
||||
}} {
|
||||
}{
|
||||
{in: domain.ScopeCreate, out: "create"},
|
||||
{in: domain.ScopeDelete, out: "delete"},
|
||||
{in: domain.ScopeDraft, out: "draft"},
|
||||
{in: domain.ScopeMedia, out: "media"},
|
||||
{in: domain.ScopeUpdate, out: "update"},
|
||||
{in: domain.ScopeBlock, out: "block"},
|
||||
{in: domain.ScopeChannels, out: "channels"},
|
||||
{in: domain.ScopeFollow, out: "follow"},
|
||||
{in: domain.ScopeMute, out: "mute"},
|
||||
{in: domain.ScopeRead, out: "read"},
|
||||
{in: domain.ScopeProfile, out: "profile"},
|
||||
{in: domain.ScopeEmail, out: "email"},
|
||||
} {
|
||||
tc := tc
|
||||
|
||||
t.Run(fmt.Sprint(tc.in), func(t *testing.T) {
|
||||
|
|
|
@ -8,22 +8,22 @@ import (
|
|||
|
||||
type Session struct {
|
||||
ClientID *ClientID
|
||||
Me *Me
|
||||
RedirectURI *URL
|
||||
Profile *Profile
|
||||
Me *Me
|
||||
CodeChallengeMethod CodeChallengeMethod
|
||||
Scope Scopes
|
||||
Code string
|
||||
CodeChallenge string
|
||||
Code string
|
||||
}
|
||||
|
||||
// TestSession returns valid random generated session for tests.
|
||||
//nolint: gomnd // testing domain can contains non-standart values
|
||||
func TestSession(tb testing.TB) *Session {
|
||||
tb.Helper()
|
||||
|
||||
code, err := random.String(24)
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return &Session{
|
||||
|
|
|
@ -33,13 +33,20 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
// DefaultNewTokenOptions describes the default settings for NewToken.
|
||||
//nolint: gochecknoglobals, gomnd
|
||||
var DefaultNewTokenOptions = NewTokenOptions{
|
||||
Algorithm: "HS256",
|
||||
Expiration: 0,
|
||||
Issuer: nil,
|
||||
NonceLength: 32,
|
||||
Scope: nil,
|
||||
Secret: nil,
|
||||
Subject: nil,
|
||||
}
|
||||
|
||||
// NewToken create a new token by provided options.
|
||||
//nolint: cyclop
|
||||
func NewToken(opts NewTokenOptions) (*Token, error) {
|
||||
if opts.NonceLength == 0 {
|
||||
opts.NonceLength = DefaultNewTokenOptions.NonceLength
|
||||
|
@ -56,22 +63,33 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
|
|||
return nil, fmt.Errorf("cannot generate nonce: %w", err)
|
||||
}
|
||||
|
||||
t := jwt.New()
|
||||
t.Set(jwt.SubjectKey, opts.Subject.String())
|
||||
t.Set(jwt.NotBeforeKey, now)
|
||||
t.Set(jwt.IssuedAtKey, now)
|
||||
t.Set("scope", opts.Scope)
|
||||
t.Set("nonce", nonce)
|
||||
tkn := jwt.New()
|
||||
|
||||
for key, val := range map[string]interface{}{
|
||||
"nonce": nonce,
|
||||
"scope": opts.Scope,
|
||||
jwt.IssuedAtKey: now,
|
||||
jwt.NotBeforeKey: now,
|
||||
jwt.SubjectKey: opts.Subject.String(),
|
||||
} {
|
||||
if err = tkn.Set(key, val); err != nil {
|
||||
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.Issuer != nil {
|
||||
t.Set(jwt.IssuerKey, opts.Issuer.String())
|
||||
if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.Expiration != 0 {
|
||||
t.Set(jwt.ExpirationKey, now.Add(opts.Expiration))
|
||||
if err = tkn.Set(jwt.ExpirationKey, now.Add(opts.Expiration)); err != nil {
|
||||
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
|
||||
accessToken, err := jwt.Sign(tkn, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot sign a new access token: %w", err)
|
||||
}
|
||||
|
@ -81,19 +99,20 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
|
|||
ClientID: opts.Issuer,
|
||||
Me: opts.Subject,
|
||||
Scope: opts.Scope,
|
||||
}, err
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestToken returns valid random generated token for tests.
|
||||
//nolint: gomnd // testing domain can contains non-standart values
|
||||
func TestToken(tb testing.TB) *Token {
|
||||
tb.Helper()
|
||||
|
||||
nonce, err := random.String(22)
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
t := jwt.New()
|
||||
tkn := jwt.New()
|
||||
cid := TestClientID(tb)
|
||||
me := TestMe(tb, "https://user.example.net/")
|
||||
now := time.Now().UTC().Round(time.Second)
|
||||
|
@ -103,22 +122,25 @@ func TestToken(tb testing.TB) *Token {
|
|||
ScopeUpdate,
|
||||
}
|
||||
|
||||
// NOTE(toby3d): required
|
||||
t.Set(jwt.IssuerKey, cid.String())
|
||||
t.Set(jwt.SubjectKey, me.String())
|
||||
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
|
||||
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
|
||||
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
|
||||
t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour))
|
||||
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
|
||||
for key, val := range map[string]interface{}{
|
||||
// NOTE(toby3d): required
|
||||
jwt.IssuerKey: cid.String(),
|
||||
jwt.SubjectKey: me.String(),
|
||||
jwt.ExpirationKey: now.Add(1 * time.Hour),
|
||||
jwt.NotBeforeKey: now.Add(-1 * time.Hour),
|
||||
jwt.IssuedAtKey: now.Add(-1 * time.Hour),
|
||||
// TODO(toby3d): jwt.AudienceKey
|
||||
// TODO(toby3d): jwt.JwtIDKey
|
||||
// NOTE(toby3d): optional
|
||||
"scope": scope,
|
||||
"nonce": nonce,
|
||||
} {
|
||||
_ = tkn.Set(key, val)
|
||||
}
|
||||
|
||||
// optional
|
||||
t.Set("scope", scope)
|
||||
t.Set("nonce", nonce)
|
||||
|
||||
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
|
||||
accessToken, err := jwt.Sign(tkn, jwa.HS256, []byte("hackme"))
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return &Token{
|
||||
|
@ -144,5 +166,5 @@ func (t Token) String() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
return "Bearer " + string(t.AccessToken)
|
||||
return "Bearer " + t.AccessToken
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func (u URL) URL() *url.URL {
|
|||
return nil
|
||||
}
|
||||
|
||||
result, err := url.ParseRequestURI(u.URI.String())
|
||||
result, err := url.ParseRequestURI(u.String())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -19,32 +19,32 @@ func TestParseURL(t *testing.T) {
|
|||
func TestURL_UnmarshalForm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
|
||||
input := []byte(fmt.Sprint(u))
|
||||
url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
|
||||
input := []byte(fmt.Sprint(url))
|
||||
result := new(domain.URL)
|
||||
|
||||
if err := result.UnmarshalForm(input); err != nil {
|
||||
t.Fatalf("%+v", err)
|
||||
}
|
||||
|
||||
if fmt.Sprint(result) != fmt.Sprint(u) {
|
||||
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, u)
|
||||
if fmt.Sprint(result) != fmt.Sprint(url) {
|
||||
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestURL_UnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
|
||||
input := []byte(fmt.Sprintf(`"%s"`, u))
|
||||
url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
|
||||
input := []byte(fmt.Sprintf(`"%s"`, url))
|
||||
result := new(domain.URL)
|
||||
|
||||
if err := result.UnmarshalJSON(input); err != nil {
|
||||
t.Fatalf("%+v", err)
|
||||
}
|
||||
|
||||
if fmt.Sprint(result) != fmt.Sprint(u) {
|
||||
t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, u)
|
||||
if fmt.Sprint(result) != fmt.Sprint(url) {
|
||||
t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, url)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
//nolint: tagliatelle
|
||||
//nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
|
||||
MetadataResponse struct {
|
||||
// The server's issuer identifier. The issuer identifier is a
|
||||
// URL that uses the "https" scheme and has no query or fragment
|
||||
|
@ -72,14 +72,11 @@ type (
|
|||
}
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be contants.
|
||||
// DefaultMetadataResponse contains all supported types by default.
|
||||
//nolint: gochecknoglobals // structs cannot be constants
|
||||
var DefaultMetadataResponse = MetadataResponse{
|
||||
ServiceDocumentation: "https://indieauth.net/source/",
|
||||
AuthorizationEndpoint: "",
|
||||
AuthorizationResponseIssParameterSupported: true,
|
||||
ScopesSupported: []string{
|
||||
domain.ScopeEmail.String(),
|
||||
domain.ScopeProfile.String(),
|
||||
},
|
||||
CodeChallengeMethodsSupported: []string{
|
||||
domain.CodeChallengeMethodMD5.String(),
|
||||
domain.CodeChallengeMethodPLAIN.String(),
|
||||
|
@ -87,13 +84,21 @@ var DefaultMetadataResponse = MetadataResponse{
|
|||
domain.CodeChallengeMethodS256.String(),
|
||||
domain.CodeChallengeMethodS512.String(),
|
||||
},
|
||||
GrantTypesSupported: []string{
|
||||
domain.GrantTypeAuthorizationCode.String(),
|
||||
domain.GrantTypeTicket.String(),
|
||||
},
|
||||
Issuer: "",
|
||||
ResponseTypesSupported: []string{
|
||||
domain.ResponseTypeCode.String(),
|
||||
domain.ResponseTypeID.String(),
|
||||
},
|
||||
GrantTypesSupported: []string{
|
||||
domain.GrantTypeAuthorizationCode.String(),
|
||||
ScopesSupported: []string{
|
||||
domain.ScopeEmail.String(),
|
||||
domain.ScopeProfile.String(),
|
||||
},
|
||||
ServiceDocumentation: "https://indieauth.net/source/",
|
||||
TokenEndpoint: "",
|
||||
}
|
||||
|
||||
func NewRequestHandler(config *domain.Config) *RequestHandler {
|
||||
|
@ -118,5 +123,5 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
|
|||
|
||||
ctx.SetStatusCode(http.StatusOK)
|
||||
ctx.SetContentType(common.MIMEApplicationJSON)
|
||||
json.NewEncoder(ctx).Encode(&resp)
|
||||
_ = json.NewEncoder(ctx).Encode(&resp)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ func NewGithubProfileRepository() profile.Repository {
|
|||
return &githubProfileRepository{}
|
||||
}
|
||||
|
||||
//nolint: cyclop
|
||||
func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
|
||||
user, _, err := github.NewClient(oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))).Users.Get(ctx, "")
|
||||
if err != nil {
|
||||
|
@ -41,6 +42,7 @@ func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Toke
|
|||
|
||||
// NOTE(toby3d): Profile URLs.
|
||||
result.URL = make([]*domain.URL, 0)
|
||||
|
||||
var twitterURL *string
|
||||
|
||||
if user.TwitterUsername != nil {
|
||||
|
|
|
@ -19,7 +19,7 @@ func NewGitlabProfileRepository() profile.Repository {
|
|||
return &gitlabProfileRepository{}
|
||||
}
|
||||
|
||||
//nolint: funlen // NOTE(toby3d): uses hyphenation on new lines for readability.
|
||||
//nolint: funlen, cyclop
|
||||
func (repo *gitlabProfileRepository) Get(_ context.Context, token *oauth2.Token) (*domain.Profile, error) {
|
||||
client, err := gitlab.NewClient(token.AccessToken)
|
||||
if err != nil {
|
||||
|
|
|
@ -25,8 +25,10 @@ func NewMastodonProfileRepository(server string) profile.Repository {
|
|||
|
||||
func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
|
||||
account, err := mastodon.NewClient(&mastodon.Config{
|
||||
Server: repo.server,
|
||||
AccessToken: token.AccessToken,
|
||||
AccessToken: token.AccessToken,
|
||||
ClientID: "",
|
||||
ClientSecret: "",
|
||||
Server: repo.server,
|
||||
}).GetAccountCurrentUser(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: cannot get account info: %w", ErrPrefix, err)
|
||||
|
@ -64,11 +66,9 @@ func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.To
|
|||
}
|
||||
|
||||
u, err := domain.ParseURL(account.Fields[i].Value)
|
||||
if err != nil {
|
||||
continue
|
||||
if err == nil {
|
||||
result.URL = append(result.URL, u)
|
||||
}
|
||||
|
||||
result.URL = append(result.URL, u)
|
||||
}
|
||||
|
||||
// WARN(toby3d): Mastodon does not provide an email via API.
|
||||
|
|
|
@ -2,6 +2,7 @@ package random
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
@ -17,32 +18,31 @@ const (
|
|||
)
|
||||
|
||||
func Bytes(length int) ([]byte, error) {
|
||||
b := make([]byte, length)
|
||||
bytes := make([]byte, length)
|
||||
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, err
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return nil, fmt.Errorf("cannot read bytes: %w", err)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func String(length int, charsets ...string) (string, error) {
|
||||
charset := strings.Join(charsets, "")
|
||||
|
||||
if charset == "" {
|
||||
charset = Alphabetic
|
||||
}
|
||||
|
||||
b := make([]byte, length)
|
||||
bytes := make([]byte, length)
|
||||
|
||||
for i := range b {
|
||||
for i := range bytes {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("failed to randomize bytes: %w", err)
|
||||
}
|
||||
|
||||
b[i] = charset[n.Int64()]
|
||||
bytes[i] = charset[n.Int64()]
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
|
|
@ -27,8 +27,7 @@ type (
|
|||
}
|
||||
|
||||
sqlite3SessionRepository struct {
|
||||
config *domain.Config
|
||||
db *sqlx.DB
|
||||
db *sqlx.DB
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -57,7 +56,7 @@ const (
|
|||
WHERE code=$1;`
|
||||
)
|
||||
|
||||
func NewSQLite3SessionRepository(config *domain.Config, db *sqlx.DB) session.Repository {
|
||||
func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository {
|
||||
db.MustExec(QueryTable)
|
||||
|
||||
return &sqlite3SessionRepository{
|
||||
|
@ -74,7 +73,7 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai
|
|||
}
|
||||
|
||||
func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) {
|
||||
s := new(Session)
|
||||
s := new(Session) //nolint: varnamelen // cannot redaclare import
|
||||
if err := repo.db.GetContext(ctx, s, QueryGet, code); err != nil {
|
||||
return nil, fmt.Errorf("cannot find session in db: %w", err)
|
||||
}
|
||||
|
@ -86,16 +85,17 @@ func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*do
|
|||
}
|
||||
|
||||
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
|
||||
s := new(Session)
|
||||
s := new(Session) //nolint: varnamelen // cannot redaclare import
|
||||
|
||||
tx, err := repo.db.Beginx()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
|
||||
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.GetContext(ctx, s, QueryGet, code); err != nil {
|
||||
//nolint: errcheck // deffered method
|
||||
defer tx.Rollback()
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -106,7 +106,7 @@ func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code str
|
|||
}
|
||||
|
||||
if _, err = tx.ExecContext(ctx, QueryDelete, code); err != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
|
||||
return nil, fmt.Errorf("cannot remove session from db: %w", err)
|
||||
}
|
||||
|
|
|
@ -12,10 +12,9 @@ import (
|
|||
"source.toby3d.me/website/indieauth/internal/testing/sqltest"
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
//nolint: gochecknoglobals // slices cannot be contants
|
||||
var tableColumns = []string{
|
||||
"created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code",
|
||||
"code_challenge",
|
||||
"created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code", "code_challenge",
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
|
@ -40,7 +39,7 @@ func TestCreate(t *testing.T) {
|
|||
).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
if err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
|
||||
if err := repository.NewSQLite3SessionRepository(db).
|
||||
Create(context.TODO(), session); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -69,7 +68,7 @@ func TestGet(t *testing.T) {
|
|||
model.CodeChallenge,
|
||||
))
|
||||
|
||||
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
|
||||
result, err := repository.NewSQLite3SessionRepository(db).
|
||||
Get(context.TODO(), session.Code)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -108,7 +107,7 @@ func TestGetAndDelete(t *testing.T) {
|
|||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
|
||||
result, err := repository.NewSQLite3SessionRepository(db).
|
||||
GetAndDelete(context.TODO(), session.Code)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -23,7 +23,6 @@ func New(tb testing.TB) (*bolt.DB, func()) {
|
|||
db, err := bolt.Open(filePath, os.ModePerm, nil)
|
||||
require.NoError(tb, err)
|
||||
|
||||
//nolint: errcheck
|
||||
return db, func() {
|
||||
db.Close()
|
||||
os.Remove(filePath)
|
||||
|
|
|
@ -3,6 +3,8 @@ package httptest
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
// used for running tests.
|
||||
_ "embed"
|
||||
"net"
|
||||
"testing"
|
||||
|
@ -47,9 +49,8 @@ func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server
|
|||
},
|
||||
}
|
||||
|
||||
//nolint: errcheck
|
||||
return client, server, func() {
|
||||
server.Shutdown()
|
||||
_ = server.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type Time struct{}
|
||||
|
@ -24,17 +23,17 @@ func Open(tb testing.TB) (*sqlx.DB, sqlmock.Sqlmock, func()) {
|
|||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
xdb := sqlx.NewDb(db, "sqlite")
|
||||
if err = xdb.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
|
||||
tb.Fatalf("%+v", err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return xdb, mock, func() {
|
||||
_ = db.Close() //nolint: errcheck
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
http "github.com/valyala/fasthttp"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
|
@ -21,7 +22,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
GenerateRequest struct {
|
||||
TicketGenerateRequest struct {
|
||||
// The access token should be used when acting on behalf of this URL.
|
||||
Subject *domain.Me `form:"subject"`
|
||||
|
||||
|
@ -29,7 +30,7 @@ type (
|
|||
Resource *domain.URL `form:"resource"`
|
||||
}
|
||||
|
||||
ExchangeRequest struct {
|
||||
TicketExchangeRequest struct {
|
||||
// A random string that can be redeemed for an access token.
|
||||
Ticket string `form:"ticket"`
|
||||
|
||||
|
@ -58,21 +59,40 @@ func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config
|
|||
func (h *RequestHandler) Register(r *router.Router) {
|
||||
chain := middleware.Chain{
|
||||
middleware.CSRFWithConfig(middleware.CSRFConfig{
|
||||
CookieSameSite: http.CookieSameSiteLaxMode,
|
||||
ContextKey: "csrf",
|
||||
CookieName: "_csrf",
|
||||
TokenLookup: "form:_csrf",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
Skipper: func(ctx *http.RequestCtx) bool {
|
||||
matched, _ := path.Match("/ticket*", string(ctx.Path()))
|
||||
|
||||
return ctx.IsPost() && matched
|
||||
},
|
||||
CookieMaxAge: 0,
|
||||
CookieSameSite: http.CookieSameSiteLaxMode,
|
||||
ContextKey: "csrf",
|
||||
CookieDomain: "",
|
||||
CookieName: "_csrf",
|
||||
CookiePath: "",
|
||||
TokenLookup: "form:_csrf",
|
||||
TokenLength: 0,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
}),
|
||||
middleware.JWTWithConfig(middleware.JWTConfig{
|
||||
AuthScheme: "Bearer",
|
||||
BeforeFunc: nil,
|
||||
Claims: nil,
|
||||
ContextKey: "user",
|
||||
ErrorHandler: nil,
|
||||
ErrorHandlerWithContext: nil,
|
||||
ParseTokenFunc: nil,
|
||||
SigningKey: []byte(h.config.JWT.Secret),
|
||||
SigningKeys: nil,
|
||||
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
|
||||
Skipper: middleware.DefaultSkipper,
|
||||
SuccessHandler: nil,
|
||||
TokenLookup: middleware.SourceHeader + ":" + http.HeaderAuthorization,
|
||||
}),
|
||||
middleware.LogFmt(),
|
||||
}
|
||||
// TODO(toby3d): secure this via JWT middleware
|
||||
|
||||
r.GET("/ticket", chain.RequestHandler(h.handleRender))
|
||||
r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
|
||||
r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
|
||||
|
@ -102,10 +122,11 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(ExchangeRequest)
|
||||
req := new(TicketGenerateRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -119,14 +140,16 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
|
|||
var err error
|
||||
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.tickets.Generate(ctx, ticket); err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -140,10 +163,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(ExchangeRequest)
|
||||
req := new(TicketExchangeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -155,7 +179,8 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
|
|||
})
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -170,7 +195,7 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
|
|||
}`, token.AccessToken, token.Scope.String(), token.Me.String()))
|
||||
}
|
||||
|
||||
func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
|
||||
func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
@ -203,7 +228,7 @@ func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
|
||||
func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
|
|
@ -46,13 +46,13 @@ func TestUpdate(t *testing.T) {
|
|||
userClient, _, userCleanup := httptest.New(t, userRouter.Handler)
|
||||
t.Cleanup(userCleanup)
|
||||
|
||||
r := router.New()
|
||||
router := router.New()
|
||||
delivery.NewRequestHandler(
|
||||
ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config),
|
||||
language.NewMatcher(message.DefaultCatalog.Languages()), config,
|
||||
).Register(r)
|
||||
).Register(router)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
client, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "https://example.com/ticket", []byte(
|
||||
|
|
|
@ -59,8 +59,8 @@ func (repo *memoryTicketRepository) GC() {
|
|||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for ts := range ticker.C {
|
||||
ts := ts.UTC()
|
||||
for timeStamp := range ticker.C {
|
||||
timeStamp := timeStamp.UTC()
|
||||
|
||||
repo.store.Range(func(key, value interface{}) bool {
|
||||
k, ok := key.(string)
|
||||
|
@ -78,7 +78,7 @@ func (repo *memoryTicketRepository) GC() {
|
|||
return false
|
||||
}
|
||||
|
||||
if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
|
||||
if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -63,16 +63,17 @@ func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticke
|
|||
return nil
|
||||
}
|
||||
|
||||
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string) (*domain.Ticket, error) {
|
||||
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, rawTicket string) (*domain.Ticket, error) {
|
||||
tx, err := repo.db.Beginx()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
_ = tx.Rollback()
|
||||
|
||||
return nil, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
tkt := new(Ticket)
|
||||
if err = tx.GetContext(ctx, tkt, QueryGet, t); err != nil {
|
||||
if err = tx.GetContext(ctx, tkt, QueryGet, rawTicket); err != nil {
|
||||
//nolint: errcheck // deffered method
|
||||
defer tx.Rollback()
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -82,8 +83,8 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string)
|
|||
return nil, fmt.Errorf("cannot find ticket in db: %w", err)
|
||||
}
|
||||
|
||||
if _, err = tx.ExecContext(ctx, QueryDelete, t); err != nil {
|
||||
tx.Rollback()
|
||||
if _, err = tx.ExecContext(ctx, QueryDelete, rawTicket); err != nil {
|
||||
_ = tx.Rollback()
|
||||
|
||||
return nil, fmt.Errorf("cannot remove ticket from db: %w", err)
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ import (
|
|||
repository "source.toby3d.me/website/indieauth/internal/ticket/repository/sqlite3"
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
var tableColumns []string = []string{"created_at", "resource", "subject", "ticket"}
|
||||
//nolint: gochecknoglobals // slices cannot be contants
|
||||
var tableColumns = []string{"created_at", "resource", "subject", "ticket"}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -11,7 +11,6 @@ type UseCase interface {
|
|||
|
||||
// Redeem transform received ticket into access token.
|
||||
Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error)
|
||||
|
||||
Exchange(ctx context.Context, ticket string) (*domain.Token, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
//nolint: tagliatelle
|
||||
//nolint: tagliatelle // https://indieauth.net/source/#access-token-response
|
||||
Response struct {
|
||||
Me *domain.Me `json:"me"`
|
||||
Scope domain.Scopes `json:"scope"`
|
||||
|
@ -37,11 +37,11 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do
|
|||
}
|
||||
}
|
||||
|
||||
func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) error {
|
||||
func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error {
|
||||
req := http.AcquireRequest()
|
||||
defer http.ReleaseRequest(req)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.SetRequestURI(t.Subject.String())
|
||||
req.SetRequestURI(tkt.Subject.String())
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
defer http.ReleaseResponse(resp)
|
||||
|
@ -65,7 +65,7 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
|
|||
return ticket.ErrTicketEndpointNotExist
|
||||
}
|
||||
|
||||
if err := useCase.tickets.Create(ctx, t); err != nil {
|
||||
if err := useCase.tickets.Create(ctx, tkt); err != nil {
|
||||
return fmt.Errorf("cannot save ticket in store: %w", err)
|
||||
}
|
||||
|
||||
|
@ -73,9 +73,9 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
|
|||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetRequestURIBytes(ticketEndpoint.RequestURI())
|
||||
req.Header.SetContentType(common.MIMEApplicationForm)
|
||||
req.PostArgs().Set("ticket", t.Ticket)
|
||||
req.PostArgs().Set("subject", t.Subject.String())
|
||||
req.PostArgs().Set("resource", t.Resource.String())
|
||||
req.PostArgs().Set("ticket", tkt.Ticket)
|
||||
req.PostArgs().Set("subject", tkt.Subject.String())
|
||||
req.PostArgs().Set("resource", tkt.Resource.String())
|
||||
resp.Reset()
|
||||
|
||||
if err := useCase.client.Do(req, resp); err != nil {
|
||||
|
@ -85,10 +85,10 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*domain.Token, error) {
|
||||
func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) {
|
||||
req := http.AcquireRequest()
|
||||
defer http.ReleaseRequest(req)
|
||||
req.SetRequestURI(t.Resource.String())
|
||||
req.SetRequestURI(tkt.Resource.String())
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
resp := http.AcquireResponse()
|
||||
|
@ -119,7 +119,7 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
|
|||
req.Header.SetContentType(common.MIMEApplicationForm)
|
||||
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
|
||||
req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
|
||||
req.PostArgs().Set("ticket", t.Ticket)
|
||||
req.PostArgs().Set("ticket", tkt.Ticket)
|
||||
resp.Reset()
|
||||
|
||||
if err := useCase.client.Do(req, resp); err != nil {
|
||||
|
@ -142,19 +142,19 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
|
|||
}
|
||||
|
||||
func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) {
|
||||
t, err := useCase.tickets.GetAndDelete(ctx, ticket)
|
||||
tkt, err := useCase.tickets.GetAndDelete(ctx, ticket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find provided ticket: %w", err)
|
||||
}
|
||||
|
||||
token, err := domain.NewToken(domain.NewTokenOptions{
|
||||
Algorithm: useCase.config.JWT.Algorithm,
|
||||
Expiration: useCase.config.JWT.Expiry,
|
||||
// TODO(toby3d): Issuer: &domain.ClientID{},
|
||||
NonceLength: useCase.config.JWT.NonceLength,
|
||||
Expiration: useCase.config.JWT.Expiry,
|
||||
Scope: domain.Scopes{domain.ScopeRead},
|
||||
Issuer: nil,
|
||||
Subject: tkt.Subject,
|
||||
Secret: []byte(useCase.config.JWT.Secret),
|
||||
Subject: t.Subject,
|
||||
Algorithm: useCase.config.JWT.Algorithm,
|
||||
NonceLength: useCase.config.JWT.NonceLength,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot generate a new access token: %w", err)
|
||||
|
|
|
@ -22,12 +22,12 @@ func TestRedeem(t *testing.T) {
|
|||
token := domain.TestToken(t)
|
||||
ticket := domain.TestTicket(t)
|
||||
|
||||
r := router.New()
|
||||
r.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) {
|
||||
router := router.New()
|
||||
router.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) {
|
||||
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, `<link rel="token_endpoint" href="`+
|
||||
ticket.Subject.String()+`token">`)
|
||||
})
|
||||
r.POST("/token", func(ctx *http.RequestCtx) {
|
||||
router.POST("/token", func(ctx *http.RequestCtx) {
|
||||
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{
|
||||
"token_type": "Bearer",
|
||||
"access_token": "%s",
|
||||
|
@ -36,7 +36,7 @@ func TestRedeem(t *testing.T) {
|
|||
}`, token.AccessToken, token.Scope.String(), token.Me.String()))
|
||||
})
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
client, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)).
|
||||
|
|
|
@ -17,7 +17,7 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
ExchangeRequest struct {
|
||||
TokenExchangeRequest struct {
|
||||
ClientID *domain.ClientID `form:"client_id"`
|
||||
RedirectURI *domain.URL `form:"redirect_uri"`
|
||||
GrantType domain.GrantType `form:"grant_type"`
|
||||
|
@ -25,40 +25,40 @@ type (
|
|||
CodeVerifier string `form:"code_verifier"`
|
||||
}
|
||||
|
||||
RevokeRequest struct {
|
||||
TokenRevokeRequest struct {
|
||||
Action domain.Action `form:"action"`
|
||||
Token string `form:"token"`
|
||||
}
|
||||
|
||||
TicketRequest struct {
|
||||
TokenTicketRequest struct {
|
||||
Action domain.Action `form:"action"`
|
||||
Ticket string `form:"ticket"`
|
||||
}
|
||||
|
||||
//nolint: tagliatelle
|
||||
ExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
Me string `json:"me"`
|
||||
Profile *ProfileResponse `json:"profile,omitempty"`
|
||||
//nolint: tagliatelle // https://indieauth.net/source/#access-token-response
|
||||
TokenExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
Me string `json:"me"`
|
||||
Profile *TokenProfileResponse `json:"profile,omitempty"`
|
||||
}
|
||||
|
||||
ProfileResponse struct {
|
||||
TokenProfileResponse struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
URL *domain.URL `json:"url,omitempty"`
|
||||
Photo *domain.URL `json:"photo,omitempty"`
|
||||
Email *domain.Email `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
//nolint: tagliatelle
|
||||
VerificationResponse struct {
|
||||
//nolint: tagliatelle // https://indieauth.net/source/#access-token-verification-response
|
||||
TokenVerificationResponse struct {
|
||||
Me *domain.Me `json:"me"`
|
||||
ClientID *domain.ClientID `json:"client_id"`
|
||||
Scope domain.Scopes `json:"scope"`
|
||||
}
|
||||
|
||||
RevocationResponse struct{}
|
||||
TokenRevocationResponse struct{}
|
||||
|
||||
RequestHandler struct {
|
||||
tokens token.UseCase
|
||||
|
@ -88,11 +88,12 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
|
||||
tkt, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
|
||||
"Bearer "))
|
||||
if err != nil || t == nil {
|
||||
if err != nil || tkt == nil {
|
||||
ctx.SetStatusCode(http.StatusUnauthorized)
|
||||
encoder.Encode(domain.NewError(
|
||||
|
||||
_ = encoder.Encode(domain.NewError(
|
||||
domain.ErrorCodeUnauthorizedClient,
|
||||
err.Error(),
|
||||
"https://indieauth.net/source/#access-token-verification",
|
||||
|
@ -101,10 +102,10 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
_ = encoder.Encode(&VerificationResponse{
|
||||
ClientID: t.ClientID,
|
||||
Me: t.Me,
|
||||
Scope: t.Scope,
|
||||
_ = encoder.Encode(&TokenVerificationResponse{
|
||||
ClientID: tkt.ClientID,
|
||||
Me: tkt.Me,
|
||||
Scope: tkt.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -120,7 +121,8 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
|
|||
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(domain.NewError(
|
||||
|
||||
_ = encoder.Encode(domain.NewError(
|
||||
domain.ErrorCodeInvalidRequest,
|
||||
err.Error(),
|
||||
"",
|
||||
|
@ -138,15 +140,17 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint: funlen
|
||||
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
||||
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
|
||||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(ExchangeRequest)
|
||||
req := new(TokenExchangeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -159,7 +163,8 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
|||
})
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(domain.NewError(
|
||||
|
||||
_ = encoder.Encode(domain.NewError(
|
||||
domain.ErrorCodeInvalidRequest,
|
||||
err.Error(),
|
||||
"https://indieauth.net/source/#request",
|
||||
|
@ -168,20 +173,21 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
resp := &ExchangeResponse{
|
||||
resp := &TokenExchangeResponse{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
Scope: token.Scope.String(),
|
||||
Me: token.Me.String(),
|
||||
Profile: nil,
|
||||
}
|
||||
|
||||
if profile == nil {
|
||||
encoder.Encode(resp)
|
||||
_ = encoder.Encode(resp)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
resp.Profile = new(ProfileResponse)
|
||||
resp.Profile = new(TokenProfileResponse)
|
||||
|
||||
if len(profile.Name) > 0 {
|
||||
resp.Profile.Name = profile.Name[0]
|
||||
|
@ -208,17 +214,19 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(RevokeRequest)
|
||||
req := new(TokenRevokeRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.tokens.Revoke(ctx, req.Token); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(domain.NewError(
|
||||
|
||||
_ = encoder.Encode(domain.NewError(
|
||||
domain.ErrorCodeInvalidRequest,
|
||||
err.Error(),
|
||||
"",
|
||||
|
@ -227,7 +235,7 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
_ = encoder.Encode(&RevocationResponse{})
|
||||
_ = encoder.Encode(&TokenRevocationResponse{})
|
||||
}
|
||||
|
||||
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
|
||||
|
@ -236,18 +244,20 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
|
|||
|
||||
encoder := json.NewEncoder(ctx)
|
||||
|
||||
req := new(TicketRequest)
|
||||
req := new(TokenTicketRequest)
|
||||
if err := req.bind(ctx); err != nil {
|
||||
ctx.SetStatusCode(http.StatusBadRequest)
|
||||
encoder.Encode(err)
|
||||
|
||||
_ = encoder.Encode(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
t, err := h.tickets.Exchange(ctx, req.Ticket)
|
||||
tkn, err := h.tickets.Exchange(ctx, req.Ticket)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(http.StatusInternalServerError)
|
||||
encoder.Encode(domain.NewError(
|
||||
|
||||
_ = encoder.Encode(domain.NewError(
|
||||
domain.ErrorCodeInvalidRequest,
|
||||
err.Error(),
|
||||
"https://indieauth.net/source/#request",
|
||||
|
@ -256,15 +266,16 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
|
|||
return
|
||||
}
|
||||
|
||||
encoder.Encode(ExchangeResponse{
|
||||
AccessToken: t.AccessToken,
|
||||
_ = encoder.Encode(TokenExchangeResponse{
|
||||
AccessToken: tkn.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
Scope: t.Scope.String(),
|
||||
Me: t.Me.String(),
|
||||
Scope: tkn.Scope.String(),
|
||||
Me: tkn.Me.String(),
|
||||
Profile: nil,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
@ -281,7 +292,7 @@ func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *TokenRevokeRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
@ -298,7 +309,7 @@ func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *TicketRequest) bind(ctx *http.RequestCtx) error {
|
||||
func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error {
|
||||
indieAuthError := new(domain.Error)
|
||||
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
|
||||
if errors.As(err, indieAuthError) {
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestVerification(t *testing.T) {
|
|||
config := domain.TestConfig(t)
|
||||
token := domain.TestToken(t)
|
||||
|
||||
r := router.New()
|
||||
router := router.New()
|
||||
// TODO(toby3d): provide tickets
|
||||
delivery.NewRequestHandler(
|
||||
tokenucase.NewTokenUseCase(
|
||||
|
@ -49,9 +49,9 @@ func TestVerification(t *testing.T) {
|
|||
new(http.Client),
|
||||
config,
|
||||
),
|
||||
).Register(r)
|
||||
).Register(router)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
client, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
|
||||
|
@ -65,7 +65,7 @@ func TestVerification(t *testing.T) {
|
|||
require.NoError(t, client.Do(req, resp))
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode())
|
||||
|
||||
result := new(delivery.VerificationResponse)
|
||||
result := new(delivery.TokenVerificationResponse)
|
||||
require.NoError(t, json.Unmarshal(resp.Body(), result))
|
||||
assert.Equal(t, token.ClientID.String(), result.ClientID.String())
|
||||
assert.Equal(t, token.Me.String(), result.Me.String())
|
||||
|
@ -80,7 +80,7 @@ func TestRevocation(t *testing.T) {
|
|||
tokens := tokenrepo.NewMemoryTokenRepository(store)
|
||||
accessToken := domain.TestToken(t)
|
||||
|
||||
r := router.New()
|
||||
router := router.New()
|
||||
delivery.NewRequestHandler(
|
||||
tokenucase.NewTokenUseCase(
|
||||
tokens,
|
||||
|
@ -92,9 +92,9 @@ func TestRevocation(t *testing.T) {
|
|||
new(http.Client),
|
||||
config,
|
||||
),
|
||||
).Register(r)
|
||||
).Register(router)
|
||||
|
||||
client, _, cleanup := httptest.New(t, r.Handler)
|
||||
client, _, cleanup := httptest.New(t, router.Handler)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)
|
||||
|
|
|
@ -62,8 +62,8 @@ func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *dom
|
|||
}
|
||||
|
||||
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
t := new(Token)
|
||||
if err := repo.db.GetContext(ctx, t, QueryGet, accessToken); err != nil {
|
||||
tkn := new(Token)
|
||||
if err := repo.db.GetContext(ctx, tkn, QueryGet, accessToken); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, token.ErrNotExist
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string)
|
|||
}
|
||||
|
||||
result := new(domain.Token)
|
||||
t.Populate(result)
|
||||
tkn.Populate(result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -12,14 +12,15 @@ import (
|
|||
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
var tableColumns []string = []string{"created_at", "access_token", "client_id", "me", "scope"}
|
||||
//nolint: gochecknoglobals // slices cannot be contants
|
||||
var tableColumns = []string{"created_at", "access_token", "client_id", "me", "scope"}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := domain.TestToken(t)
|
||||
model := repository.NewToken(token)
|
||||
|
||||
db, mock, cleanup := sqltest.Open(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
|
@ -44,6 +45,7 @@ func TestGet(t *testing.T) {
|
|||
|
||||
token := domain.TestToken(t)
|
||||
model := repository.NewToken(token)
|
||||
|
||||
db, mock, cleanup := sqltest.Open(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
|
|||
return nil, nil, token.ErrEmptyScope
|
||||
}
|
||||
|
||||
t, err := domain.NewToken(domain.NewTokenOptions{
|
||||
tkn, err := domain.NewToken(domain.NewTokenOptions{
|
||||
Algorithm: useCase.config.JWT.Algorithm,
|
||||
Expiration: useCase.config.JWT.Expiry,
|
||||
Issuer: session.ClientID,
|
||||
|
@ -70,14 +70,14 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
|
|||
}
|
||||
|
||||
if !session.Scope.Has(domain.ScopeProfile) {
|
||||
return t, nil, nil
|
||||
return tkn, nil, nil
|
||||
}
|
||||
|
||||
p := new(domain.Profile)
|
||||
|
||||
// TODO(toby3d): if session.Scope.Has(domain.ScopeEmail) {}
|
||||
|
||||
return t, p, nil
|
||||
return tkn, p, nil
|
||||
}
|
||||
|
||||
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
|
||||
|
@ -90,23 +90,26 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
|
|||
return nil, token.ErrRevoke
|
||||
}
|
||||
|
||||
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
|
||||
tkn, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
|
||||
[]byte(useCase.config.JWT.Secret)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse JWT token: %w", err)
|
||||
}
|
||||
|
||||
if err = jwt.Validate(t); err != nil {
|
||||
if err = jwt.Validate(tkn); err != nil {
|
||||
return nil, fmt.Errorf("cannot validate JWT token: %w", err)
|
||||
}
|
||||
|
||||
result := &domain.Token{
|
||||
Scope: nil,
|
||||
ClientID: nil,
|
||||
Me: nil,
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
result.ClientID, _ = domain.ParseClientID(t.Issuer())
|
||||
result.Me, _ = domain.ParseMe(t.Subject())
|
||||
result.ClientID, _ = domain.ParseClientID(tkn.Issuer())
|
||||
result.Me, _ = domain.ParseMe(tkn.Subject())
|
||||
|
||||
if scope, ok := t.Get("scope"); ok {
|
||||
if scope, ok := tkn.Get("scope"); ok {
|
||||
result.Scope, _ = scope.(domain.Scopes)
|
||||
}
|
||||
|
||||
|
@ -114,12 +117,12 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
|
|||
}
|
||||
|
||||
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
|
||||
t, err := useCase.Verify(ctx, accessToken)
|
||||
tkn, err := useCase.Verify(ctx, accessToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot verify token: %w", err)
|
||||
}
|
||||
|
||||
if err = useCase.tokens.Create(ctx, t); err != nil {
|
||||
if err = useCase.tokens.Create(ctx, tkn); err != nil {
|
||||
return fmt.Errorf("cannot save token in database: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -52,31 +52,38 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
|
|||
|
||||
// TODO(toby3d): handle error here?
|
||||
resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation)))
|
||||
u := &domain.User{
|
||||
Me: resolvedMe,
|
||||
|
||||
user := &domain.User{
|
||||
AuthorizationEndpoint: nil,
|
||||
IndieAuthMetadata: nil,
|
||||
Me: resolvedMe,
|
||||
Micropub: nil,
|
||||
Microsub: nil,
|
||||
Profile: &domain.Profile{
|
||||
Name: make([]string, 0),
|
||||
URL: make([]*domain.URL, 0),
|
||||
Photo: make([]*domain.URL, 0),
|
||||
Email: make([]*domain.Email, 0),
|
||||
Name: make([]string, 0),
|
||||
Photo: make([]*domain.URL, 0),
|
||||
URL: make([]*domain.URL, 0),
|
||||
},
|
||||
TicketEndpoint: nil,
|
||||
TokenEndpoint: nil,
|
||||
}
|
||||
|
||||
metadata, err := util.ExtractMetadata(resp, repo.client)
|
||||
if err == nil && metadata != nil {
|
||||
u.AuthorizationEndpoint = metadata.AuthorizationEndpoint
|
||||
u.Micropub = metadata.Micropub
|
||||
u.Microsub = metadata.Microsub
|
||||
u.TicketEndpoint = metadata.TicketEndpoint
|
||||
u.TokenEndpoint = metadata.TokenEndpoint
|
||||
if metadata, err := util.ExtractMetadata(resp, repo.client); err == nil {
|
||||
user.AuthorizationEndpoint = metadata.AuthorizationEndpoint
|
||||
user.Micropub = metadata.Micropub
|
||||
user.Microsub = metadata.Microsub
|
||||
user.TicketEndpoint = metadata.TicketEndpoint
|
||||
user.TokenEndpoint = metadata.TokenEndpoint
|
||||
}
|
||||
|
||||
extractUser(u, resp)
|
||||
extractProfile(u.Profile, resp)
|
||||
extractUser(user, resp)
|
||||
extractProfile(user.Profile, resp)
|
||||
|
||||
return u, nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
//nolint: cyclop
|
||||
func extractUser(dst *domain.User, src *http.Response) {
|
||||
if dst.IndieAuthMetadata != nil {
|
||||
if endpoints := util.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 {
|
||||
|
@ -115,14 +122,12 @@ func extractUser(dst *domain.User, src *http.Response) {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint: cyclop
|
||||
func extractProfile(dst *domain.Profile, src *http.Response) {
|
||||
for _, name := range util.ExtractProperty(src, hCard, propertyName) {
|
||||
n, ok := name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
if n, ok := name.(string); ok {
|
||||
dst.Name = append(dst.Name, n)
|
||||
}
|
||||
|
||||
dst.Name = append(dst.Name, n)
|
||||
}
|
||||
|
||||
for _, rawEmail := range util.ExtractProperty(src, hCard, propertyEmail) {
|
||||
|
@ -131,26 +136,20 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
|
|||
continue
|
||||
}
|
||||
|
||||
e, err := domain.ParseEmail(email)
|
||||
if err != nil {
|
||||
continue
|
||||
if e, err := domain.ParseEmail(email); err == nil {
|
||||
dst.Email = append(dst.Email, e)
|
||||
}
|
||||
|
||||
dst.Email = append(dst.Email, e)
|
||||
}
|
||||
|
||||
for _, rawUrl := range util.ExtractProperty(src, hCard, propertyURL) {
|
||||
url, ok := rawUrl.(string)
|
||||
for _, rawURL := range util.ExtractProperty(src, hCard, propertyURL) {
|
||||
url, ok := rawURL.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
u, err := domain.ParseURL(url)
|
||||
if err != nil {
|
||||
continue
|
||||
if u, err := domain.ParseURL(url); err == nil {
|
||||
dst.URL = append(dst.URL, u)
|
||||
}
|
||||
|
||||
dst.URL = append(dst.URL, u)
|
||||
}
|
||||
|
||||
for _, rawPhoto := range util.ExtractProperty(src, hCard, propertyPhoto) {
|
||||
|
@ -159,11 +158,8 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
|
|||
continue
|
||||
}
|
||||
|
||||
p, err := domain.ParseURL(photo)
|
||||
if err != nil {
|
||||
continue
|
||||
if p, err := domain.ParseURL(photo); err == nil {
|
||||
dst.Photo = append(dst.Photo, p)
|
||||
}
|
||||
|
||||
dst.Photo = append(dst.Photo, p)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,8 +69,8 @@ func TestGet(t *testing.T) {
|
|||
func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
|
||||
tb.Helper()
|
||||
|
||||
r := router.New()
|
||||
r.GET("/", func(ctx *http.RequestCtx) {
|
||||
router := router.New()
|
||||
router.GET("/", func(ctx *http.RequestCtx) {
|
||||
ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{
|
||||
`<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`,
|
||||
`<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`,
|
||||
|
@ -83,7 +83,7 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
|
|||
testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0],
|
||||
))
|
||||
})
|
||||
r.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) {
|
||||
router.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) {
|
||||
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{
|
||||
"issuer": "`+user.Me.String()+`",
|
||||
"authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`",
|
||||
|
@ -91,5 +91,5 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
|
|||
}`)
|
||||
})
|
||||
|
||||
return r.Handler
|
||||
return router.Handler
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package util
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
|
@ -13,6 +14,12 @@ import (
|
|||
"source.toby3d.me/website/indieauth/internal/domain"
|
||||
)
|
||||
|
||||
var ErrEndpointNotExist = domain.NewError(
|
||||
domain.ErrorCodeServerError,
|
||||
"cannot found any endpoints",
|
||||
"https://indieauth.net/source/#discovery-0",
|
||||
)
|
||||
|
||||
func ExtractEndpoints(resp *http.Response, rel string) []*domain.URL {
|
||||
results := make([]*domain.URL, 0)
|
||||
|
||||
|
@ -39,7 +46,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
|
|||
|
||||
u := http.AcquireURI()
|
||||
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("cannot parse header endpoint: %w", err)
|
||||
}
|
||||
|
||||
results = append(results, &domain.URL{URI: u})
|
||||
|
@ -51,7 +58,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
|
|||
func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, error) {
|
||||
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel]
|
||||
if !ok || len(endpoints) == 0 {
|
||||
return nil, nil
|
||||
return nil, ErrEndpointNotExist
|
||||
}
|
||||
|
||||
results := make([]*domain.URL, 0)
|
||||
|
@ -59,7 +66,7 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
|
|||
for i := range endpoints {
|
||||
u := http.AcquireURI()
|
||||
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(endpoints[i])); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("cannot parse body endpoint: %w", err)
|
||||
}
|
||||
|
||||
results = append(results, &domain.URL{URI: u})
|
||||
|
@ -71,29 +78,30 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
|
|||
func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) {
|
||||
endpoints := ExtractEndpoints(resp, "indieauth-metadata")
|
||||
if len(endpoints) == 0 {
|
||||
return nil, nil
|
||||
return nil, ErrEndpointNotExist
|
||||
}
|
||||
|
||||
_, body, err := client.Get(nil, endpoints[len(endpoints)-1].String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
|
||||
}
|
||||
|
||||
result := new(domain.Metadata)
|
||||
if err = json.Unmarshal(body, result); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ExtractProperty(resp *http.Response, t, key string) []interface{} {
|
||||
func ExtractProperty(resp *http.Response, itemType, key string) []interface{} {
|
||||
//nolint: exhaustivestruct // only Host part in url.URL is needed
|
||||
data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{
|
||||
Host: string(resp.Header.Peek(http.HeaderHost)),
|
||||
})
|
||||
|
||||
for _, item := range data.Items {
|
||||
if !contains(item.Type, t) {
|
||||
if !contains(item.Type, itemType) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
24
main.go
24
main.go
|
@ -55,7 +55,8 @@ import (
|
|||
|
||||
const (
|
||||
DefaultCacheDuration time.Duration = 8760 * time.Hour // NOTE(toby3d): year
|
||||
DefaultPort int = 3000
|
||||
DefaultReadTimeout time.Duration = 5 * time.Second
|
||||
DefaultWriteTimeout time.Duration = 10 * time.Second
|
||||
)
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
|
@ -126,7 +127,7 @@ func init() {
|
|||
client.RedirectURI = []*domain.URL{redirectURI}
|
||||
}
|
||||
|
||||
//nolint: funlen
|
||||
//nolint: funlen, cyclop // "god object" and the entry point of all modules
|
||||
func main() {
|
||||
var (
|
||||
tokens token.Repository
|
||||
|
@ -146,7 +147,7 @@ func main() {
|
|||
}
|
||||
|
||||
tokens = tokensqlite3repo.NewSQLite3TokenRepository(store)
|
||||
sessions = sessionsqlite3repo.NewSQLite3SessionRepository(config, store)
|
||||
sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store)
|
||||
tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config)
|
||||
case "memory":
|
||||
store := new(sync.Map)
|
||||
|
@ -160,12 +161,11 @@ func main() {
|
|||
go sessions.GC()
|
||||
|
||||
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
|
||||
//nolint: exhaustivestruct // too many options
|
||||
httpClient := &http.Client{
|
||||
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
|
||||
MaxConnDuration: 10 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
MaxConnWaitTimeout: 10 * time.Second,
|
||||
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
|
||||
ReadTimeout: DefaultReadTimeout,
|
||||
WriteTimeout: DefaultWriteTimeout,
|
||||
}
|
||||
ticketService := ticketucase.NewTicketUseCase(tickets, httpClient, config)
|
||||
tokenService := tokenucase.NewTokenUseCase(tokens, sessions, config)
|
||||
|
@ -187,6 +187,7 @@ func main() {
|
|||
Matcher: matcher,
|
||||
Config: config,
|
||||
}).Register(r)
|
||||
//nolint: exhaustivestruct// too many options
|
||||
r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{
|
||||
Root: config.Server.StaticRootPath,
|
||||
CacheDuration: DefaultCacheDuration,
|
||||
|
@ -200,11 +201,12 @@ func main() {
|
|||
r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler)
|
||||
}
|
||||
|
||||
//nolint: exhaustivestruct
|
||||
server := &http.Server{
|
||||
Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()),
|
||||
Handler: r.Handler,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ReadTimeout: DefaultReadTimeout,
|
||||
WriteTimeout: DefaultWriteTimeout,
|
||||
DisableKeepalive: true,
|
||||
ReduceMemoryUsage: true,
|
||||
SecureErrorLogMessage: true,
|
||||
|
@ -212,7 +214,7 @@ func main() {
|
|||
}
|
||||
|
||||
done := make(chan os.Signal, 1)
|
||||
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL)
|
||||
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
if cpuProfilePath != "" {
|
||||
cpuProfile, err := os.Create(cpuProfilePath)
|
||||
|
|
Loading…
Reference in New Issue