🚨 Removed linter warnings

This commit is contained in:
Maxim Lebedev 2022-02-01 22:27:48 +05:00
parent 7680845f74
commit 59d4c4988a
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
60 changed files with 877 additions and 886 deletions

View File

@ -1,15 +1,54 @@
---
run:
tests: true
skip-dirs:
- locales
- testdata
- web
skip-dirs-use-default: true
skip-files:
- ".*_gen\\.go$"
output:
sort-results: true
linters-settings:
lll:
tab-width: 8
gci:
local-prefixes: source.toby3d.me
goimports:
local-prefixes: source.toby3d.me
ireturn:
allow:
- "(Repository|UseCase)$"
- error
- stdlib
lll:
tab-width: 8
varnamelen:
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-names:
- ctx # context
- db # dataBase
- err # error
- i # index
- ip
- ln # listener
- me
- ok
- tc # testCase
- ts # timeStamp
- tx # transaction
ignore-decls:
- "cid *domain.ClientID"
- "ctx *fasthttp.RequestCtx"
- "ctx context.Context"
- "i int"
- "me *domain.Me"
- "r *router.Router"
linters:
enable-all: true
disable:
- godox
issues:
exclude-rules:
- source: "^//go:generate "

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.0 // indirect
github.com/jmoiron/sqlx v1.3.4
github.com/klauspost/compress v1.14.2 // indirect
github.com/lestrrat-go/iter v1.0.1 // indirect
github.com/lestrrat-go/jwx v1.2.18
github.com/mattn/go-mastodon v0.0.4
github.com/spf13/afero v1.8.0 // indirect

View File

@ -23,7 +23,7 @@ import (
)
type (
AuthorizeRequest struct {
AuthAuthorizeRequest struct {
// Indicates to the authorization server that an authorization
// code should be returned as the response.
ResponseType domain.ResponseType `form:"response_type"` // code
@ -58,7 +58,7 @@ type (
Me *domain.Me `form:"me"`
}
VerifyRequest struct {
AuthVerifyRequest struct {
ClientID *domain.ClientID `form:"client_id"`
Me *domain.Me `form:"me"`
RedirectURI *domain.URL `form:"redirect_uri"`
@ -71,7 +71,7 @@ type (
Provider string `form:"provider"`
}
ExchangeRequest struct {
AuthExchangeRequest struct {
GrantType domain.GrantType `form:"grant_type"` // authorization_code
// The authorization code received from the authorization
@ -91,50 +91,52 @@ type (
CodeVerifier string `form:"code_verifier"`
}
ExchangeResponse struct {
AuthExchangeResponse struct {
Me *domain.Me `json:"me"`
}
NewRequestHandlerOptions struct {
Auth auth.UseCase
Clients client.UseCase
Config *domain.Config
Matcher language.Matcher
Providers []*domain.Provider
Auth auth.UseCase
Clients client.UseCase
Config *domain.Config
Matcher language.Matcher
}
RequestHandler struct {
clients client.UseCase
config *domain.Config
matcher language.Matcher
useCase auth.UseCase
providers []*domain.Provider
clients client.UseCase
config *domain.Config
matcher language.Matcher
useCase auth.UseCase
}
)
func NewRequestHandler(opts NewRequestHandlerOptions) *RequestHandler {
return &RequestHandler{
clients: opts.Clients,
config: opts.Config,
matcher: opts.Matcher,
useCase: opts.Auth,
providers: opts.Providers,
clients: opts.Clients,
config: opts.Config,
matcher: opts.Matcher,
useCase: opts.Auth,
}
}
func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
CookieSameSite: http.CookieSameSiteStrictMode,
CookieName: "_csrf",
TokenLookup: "form:_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/api/*", string(ctx.Path()))
return ctx.IsPost() && matched
},
CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteStrictMode,
ContextKey: "",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "form:_csrf",
TokenLength: 0,
CookieSecure: true,
CookieHTTPOnly: true,
}),
middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{
Skipper: func(ctx *http.RequestCtx) bool {
@ -145,15 +147,14 @@ func (h *RequestHandler) Register(r *router.Router) {
return !ctx.IsPost() || !matched || providerMatched
},
Validator: func(ctx *http.RequestCtx, login, password string) (bool, error) {
userMatch := subtle.ConstantTimeCompare(
[]byte(login), []byte(h.config.IndieAuth.Username),
)
passMatch := subtle.ConstantTimeCompare(
[]byte(password), []byte(h.config.IndieAuth.Password),
)
userMatch := subtle.ConstantTimeCompare([]byte(login),
[]byte(h.config.IndieAuth.Username))
passMatch := subtle.ConstantTimeCompare([]byte(password),
[]byte(h.config.IndieAuth.Password))
return userMatch == 1 && passMatch == 1, nil
},
Realm: "",
}),
middleware.LogFmt(),
}
@ -164,7 +165,6 @@ func (h *RequestHandler) Register(r *router.Router) {
}
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
req := new(AuthorizeRequest)
ctx.SetContentType(common.MIMETextHTMLCharsetUTF8)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
@ -175,6 +175,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
Printer: message.NewPrinter(tag),
}
req := new(AuthAuthorizeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
web.WriteTemplate(ctx, &web.ErrorPage{
@ -213,15 +214,16 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
csrf, _ := ctx.UserValue(middleware.DefaultCSRFConfig.ContextKey).([]byte)
web.WriteTemplate(ctx, &web.AuthorizePage{
BaseOf: baseOf,
Client: client,
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
CSRF: csrf,
Scope: req.Scope,
Client: client,
Me: req.Me,
RedirectURI: req.RedirectURI,
CodeChallengeMethod: req.CodeChallengeMethod,
ResponseType: req.ResponseType,
Scope: req.Scope,
CodeChallenge: req.CodeChallenge,
State: req.State,
Providers: make([]*domain.Provider, 0), // TODO(toby3d)
})
}
@ -230,22 +232,23 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(VerifyRequest)
req := new(AuthVerifyRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
u := http.AcquireURI()
defer http.ReleaseURI(u)
req.RedirectURI.CopyTo(u)
redirectURL := http.AcquireURI()
defer http.ReleaseURI(redirectURL)
req.RedirectURI.CopyTo(redirectURL)
if strings.EqualFold(req.Authorize, "deny") {
domain.NewError(domain.ErrorCodeAccessDenied, "user deny authorization request", "", req.State).
SetReirectURI(u)
ctx.Redirect(u.String(), http.StatusFound)
SetReirectURI(redirectURL)
ctx.Redirect(redirectURL.String(), http.StatusFound)
return
}
@ -260,7 +263,8 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
})
if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
@ -270,10 +274,10 @@ func (h *RequestHandler) handleVerify(ctx *http.RequestCtx) {
"iss": h.config.Server.GetRootURL(),
"state": req.State,
} {
u.QueryArgs().Set(key, val)
redirectURL.QueryArgs().Set(key, val)
}
ctx.Redirect(u.String(), http.StatusFound)
ctx.Redirect(redirectURL.String(), http.StatusFound)
}
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
@ -281,10 +285,11 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
req := new(AuthExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
@ -297,17 +302,18 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
encoder.Encode(&ExchangeResponse{
_ = encoder.Encode(&AuthExchangeResponse{
Me: me,
})
}
func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
func (r *AuthAuthorizeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs(), r); err != nil {
if errors.As(err, indieAuthError) {
@ -341,7 +347,7 @@ func (r *AuthorizeRequest) bind(ctx *http.RequestCtx) error {
return nil
}
func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
func (r *AuthVerifyRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
@ -369,6 +375,8 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
)
}
// NOTE(toby3d): backwards-compatible support.
// See: https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
if r.ResponseType == domain.ResponseTypeID {
r.ResponseType = domain.ResponseTypeCode
}
@ -386,7 +394,7 @@ func (r *VerifyRequest) bind(ctx *http.RequestCtx) error {
return nil
}
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
func (r *AuthExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) {

View File

@ -35,15 +35,15 @@ func TestRender(t *testing.T) {
require.NoError(t, s.SetProvider(provider))
me := domain.TestMe(t, "https://user.example.net")
c := domain.TestClient(t)
client := domain.TestClient(t)
config := domain.TestConfig(t)
store := new(sync.Map)
user := domain.TestUser(t)
store.Store(path.Join(userrepo.DefaultPathPrefix, me.String()), user)
store.Store(path.Join(clientrepo.DefaultPathPrefix, c.ID.String()), c)
store.Store(path.Join(clientrepo.DefaultPathPrefix, client.ID.String()), client)
store.Store(path.Join(profilerepo.DefaultPathPrefix, me.String()), user.Profile)
r := router.New()
router := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Clients: clientucase.NewClientUseCase(clientrepo.NewMemoryClientRepository(store)),
Config: config,
@ -52,35 +52,35 @@ func TestRender(t *testing.T) {
sessionrepo.NewMemorySessionRepository(config, store),
config,
),
}).Register(r)
}).Register(router)
client, _, cleanup := httptest.New(t, r.Handler)
httpClient, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
u := http.AcquireURI()
defer http.ReleaseURI(u)
u.Update("https://example.com/authorize")
uri := http.AcquireURI()
defer http.ReleaseURI(uri)
uri.Update("https://example.com/authorize")
for k, v := range map[string]string{
"client_id": c.ID.String(),
for key, val := range map[string]string{
"client_id": client.ID.String(),
"code_challenge": "OfYAxt8zU2dAPDWQxTAUIteRzMsoj9QBdMIVEDOErUo",
"code_challenge_method": domain.CodeChallengeMethodS256.String(),
"me": me.String(),
"redirect_uri": c.RedirectURI[0].String(),
"redirect_uri": client.RedirectURI[0].String(),
"response_type": domain.ResponseTypeCode.String(),
"scope": "profile email",
"state": "1234567890",
} {
u.QueryArgs().Set(k, v)
uri.QueryArgs().Set(key, val)
}
req := httptest.NewRequest(http.MethodGet, u.String(), nil)
req := httptest.NewRequest(http.MethodGet, uri.String(), nil)
defer http.ReleaseRequest(req)
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
require.NoError(t, client.Do(req, resp))
require.NoError(t, httpClient.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode())
assert.Contains(t, string(resp.Body()), `Authorize application`)

View File

@ -18,7 +18,7 @@ import (
)
type (
CallbackRequest struct {
ClientCallbackRequest struct {
Iss *domain.ClientID `form:"iss"`
Code string `form:"code"`
Error string `form:"error"`
@ -60,13 +60,14 @@ func (h *RequestHandler) Register(r *router.Router) {
}
func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
redirectUri := make([]string, len(h.client.RedirectURI))
redirect := make([]string, len(h.client.RedirectURI))
for i := range h.client.RedirectURI {
redirectUri[i] = h.client.RedirectURI[i].String()
redirect[i] = h.client.RedirectURI[i].String()
}
ctx.Response.Header.Set(
http.HeaderLink, `<`+strings.Join(redirectUri, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
http.HeaderLink, `<`+strings.Join(redirect, `>; rel="redirect_uri", `)+`>; rel="redirect_uri"`,
)
tags, _, _ := language.ParseAcceptLanguage(string(ctx.Request.Header.Peek(http.HeaderAcceptLanguage)))
@ -87,7 +88,7 @@ func (h *RequestHandler) handleRender(ctx *http.RequestCtx) {
}
func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
req := new(CallbackRequest)
req := new(ClientCallbackRequest)
if err := req.bind(ctx); err != nil {
ctx.Error(err.Error(), http.StatusInternalServerError)
@ -134,7 +135,7 @@ func (h *RequestHandler) handleCallback(ctx *http.RequestCtx) {
})
}
func (req *CallbackRequest) bind(ctx *http.RequestCtx) error {
func (req *ClientCallbackRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.QueryArgs(), req); err != nil {

View File

@ -25,16 +25,16 @@ func TestRead(t *testing.T) {
store := new(sync.Map)
config := domain.TestConfig(t)
r := router.New()
router := router.New()
delivery.NewRequestHandler(delivery.NewRequestHandlerOptions{
Client: domain.TestClient(t),
Config: config,
Matcher: language.NewMatcher(message.DefaultCatalog.Languages()),
Tokens: tokenucase.NewTokenUseCase(tokenrepo.NewMemoryTokenRepository(store),
sessionrepo.NewMemorySessionRepository(config, store), config),
}).Register(r)
}).Register(router)
client, _, cleanup := httptest.New(t, r.Handler)
client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/", nil)

View File

@ -10,4 +10,8 @@ type Repository interface {
Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error)
}
var ErrNotExist error = domain.NewError(domain.ErrorCodeInvalidClient, "client with the specified ID does not exist", "")
var ErrNotExist error = domain.NewError(
domain.ErrorCodeInvalidClient,
"client with the specified ID does not exist",
"",
)

View File

@ -32,10 +32,10 @@ func NewHTTPClientRepository(c *http.Client) client.Repository {
}
}
func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID) (*domain.Client, error) {
func (repo *httpClientRepository) Get(ctx context.Context, cid *domain.ClientID) (*domain.Client, error) {
req := http.AcquireRequest()
defer http.ReleaseRequest(req)
req.SetRequestURI(id.String())
req.SetRequestURI(cid.String())
req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse()
@ -50,7 +50,7 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
}
client := &domain.Client{
ID: id,
ID: cid,
RedirectURI: make([]*domain.URL, 0),
Logo: make([]*domain.URL, 0),
URL: make([]*domain.URL, 0),
@ -62,68 +62,52 @@ func (repo *httpClientRepository) Get(ctx context.Context, id *domain.ClientID)
return client, nil
}
//nolint: gocognit, cyclop
func extract(dst *domain.Client, src *http.Response) {
for _, u := range util.ExtractEndpoints(src, relRedirectURI) {
if containsURL(dst.RedirectURI, u) {
continue
for _, endpoint := range util.ExtractEndpoints(src, relRedirectURI) {
if !containsURL(dst.RedirectURI, endpoint) {
dst.RedirectURI = append(dst.RedirectURI, endpoint)
}
dst.RedirectURI = append(dst.RedirectURI, u)
}
for _, t := range []string{hXApp, hApp} {
for _, name := range util.ExtractProperty(src, t, propertyName) {
n, ok := name.(string)
if !ok || containsString(dst.Name, n) {
continue
for _, itemType := range []string{hXApp, hApp} {
for _, name := range util.ExtractProperty(src, itemType, propertyName) {
if n, ok := name.(string); ok && !containsString(dst.Name, n) {
dst.Name = append(dst.Name, n)
}
dst.Name = append(dst.Name, n)
}
for _, logo := range util.ExtractProperty(src, t, propertyLogo) {
var err error
for _, logo := range util.ExtractProperty(src, itemType, propertyLogo) {
var (
uri *domain.URL
err error
)
var u *domain.URL
switch l := logo.(type) {
case string:
u, err = domain.ParseURL(l)
uri, err = domain.ParseURL(l)
case map[string]string:
value, ok := l["value"]
if !ok {
continue
if value, ok := l["value"]; ok {
uri, err = domain.ParseURL(value)
}
u, err = domain.ParseURL(value)
}
if err != nil {
if err != nil || containsURL(dst.Logo, uri) {
continue
}
if containsURL(dst.Logo, u) {
continue
}
dst.Logo = append(dst.Logo, u)
dst.Logo = append(dst.Logo, uri)
}
for _, url := range util.ExtractProperty(src, t, propertyURL) {
l, ok := url.(string)
for _, property := range util.ExtractProperty(src, itemType, propertyURL) {
prop, ok := property.(string)
if !ok {
continue
}
u, err := domain.ParseURL(l)
if err != nil {
continue
if u, err := domain.ParseURL(prop); err == nil || !containsURL(dst.URL, u) {
dst.URL = append(dst.URL, u)
}
if containsURL(dst.URL, u) {
continue
}
dst.URL = append(dst.URL, u)
}
}
}

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain
import (
@ -14,7 +15,7 @@ type Action struct {
uid string
}
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
//nolint: gochecknoglobals // structs cannot be constants
var (
ActionUndefined = Action{uid: ""}

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test
import (
@ -12,13 +13,10 @@ func TestParseAction(t *testing.T) {
for _, tc := range []struct {
in string
out domain.Action
}{{
in: "revoke",
out: domain.ActionRevoke,
}, {
in: "ticket",
out: domain.ActionTicket,
}} {
}{
{in: "revoke", out: domain.ActionRevoke},
{in: "ticket", out: domain.ActionTicket},
} {
tc := tc
t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestAction_String(t *testing.T) {
name string
in domain.Action
out string
}{{
name: "revoke",
in: domain.ActionRevoke,
out: "revoke",
}, {
name: "ticket",
in: domain.ActionTicket,
out: "ticket",
}} {
}{
{name: "revoke", in: domain.ActionRevoke, out: "revoke"},
{name: "ticket", in: domain.ActionTicket, out: "ticket"},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -8,111 +8,94 @@ import (
"testing"
http "github.com/valyala/fasthttp"
"golang.org/x/xerrors"
"inet.af/netaddr"
)
// ClientID is a URL client identifier.
type ClientID struct {
clientID *http.URI
valid bool
}
//nolint: gochecknoglobals
//nolint: gochecknoglobals // slices cannot be constants
var (
localhostIPv4 = netaddr.MustParseIP("127.0.0.1")
localhostIPv6 = netaddr.MustParseIP("::1")
)
// ParseClientID parse string as client ID URL identifier.
//nolint: funlen
//nolint: funlen, cyclop
func ParseClientID(src string) (*ClientID, error) {
cid := http.AcquireURI()
if err := cid.Parse(nil, []byte(src)); err != nil {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: err.Error(),
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
return nil, NewError(
ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#client-identifier",
)
}
scheme := string(cid.Scheme())
if scheme != "http" && scheme != "https" {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client identifier URL MUST have either an https or http scheme",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST have either an https or http scheme",
"https://indieauth.net/source/#client-identifier",
)
}
path := string(cid.PathOriginal())
if path == "" || strings.Contains(path, "/.") || strings.Contains(path, "/..") {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client identifier URL MUST contain a path component and MUST NOT contain " +
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST contain a path component and MUST NOT contain "+
"single-dot or double-dot path segments",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
"https://indieauth.net/source/#client-identifier",
)
}
if cid.Hash() != nil {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT contain a fragment component",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a fragment component",
"https://indieauth.net/source/#client-identifier",
)
}
if cid.Username() != nil || cid.Password() != nil {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT contain a username or password component",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST NOT contain a username or password component",
"https://indieauth.net/source/#client-identifier",
)
}
domain := string(cid.Host())
if domain == "" {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client host name MUST be domain name or a loopback interface",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
return nil, NewError(
ErrorCodeInvalidRequest,
"client host name MUST be domain name or a loopback interface",
"https://indieauth.net/source/#client-identifier",
)
}
ip, err := netaddr.ParseIP(domain)
if err != nil {
ipPort, err := netaddr.ParseIPPort(domain)
if err != nil {
return &ClientID{
clientID: cid,
}, nil
//nolint: nilerr // ClientID does not contain an IP address, so it is valid
return &ClientID{clientID: cid}, nil
}
ip = ipPort.IP()
}
if !ip.IsLoopback() && ip.Compare(localhostIPv4) != 0 && ip.Compare(localhostIPv6) != 0 {
return nil, Error{
Code: ErrorCodeInvalidRequest,
Description: "client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 " +
return nil, NewError(
ErrorCodeInvalidRequest,
"client identifier URL MUST NOT be IPv4 or IPv6 addresses except for IPv4 "+
"127.0.0.1 or IPv6 [::1]",
URI: "https://indieauth.net/source/#client-identifier",
State: "",
frame: xerrors.Caller(1),
}
"https://indieauth.net/source/#client-identifier",
)
}
return &ClientID{
@ -126,7 +109,7 @@ func TestClientID(tb testing.TB) *ClientID {
clientID, err := ParseClientID("https://app.example.com/")
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
return clientID
@ -167,23 +150,28 @@ func (cid ClientID) MarshalJSON() ([]byte, error) {
}
// URI returns copy of parsed *fasthttp.URI.
// This copy MUST be released via fasthttp.ReleaseURI.
//
// WARN(toby3d): This copy MUST be released via fasthttp.ReleaseURI.
func (cid ClientID) URI() *http.URI {
u := http.AcquireURI()
cid.clientID.CopyTo(u)
uri := http.AcquireURI()
cid.clientID.CopyTo(uri)
return u
return uri
}
// URL returns url.URL representation of client ID.
func (cid ClientID) URL() *url.URL {
return &url.URL{
Scheme: string(cid.clientID.Scheme()),
Host: string(cid.clientID.Host()),
Path: string(cid.clientID.Path()),
RawPath: string(cid.clientID.PathOriginal()),
RawQuery: string(cid.clientID.QueryString()),
Fragment: string(cid.clientID.Hash()),
ForceQuery: false,
Fragment: string(cid.clientID.Hash()),
Host: string(cid.clientID.Host()),
Opaque: "",
Path: string(cid.clientID.Path()),
RawFragment: "",
RawPath: string(cid.clientID.PathOriginal()),
RawQuery: string(cid.clientID.QueryString()),
Scheme: string(cid.clientID.Scheme()),
User: nil,
}
}

View File

@ -7,7 +7,6 @@ import (
"source.toby3d.me/website/indieauth/internal/domain"
)
//nolint: funlen
func TestParseClientID(t *testing.T) {
t.Parallel()
@ -15,51 +14,19 @@ func TestParseClientID(t *testing.T) {
name string
in string
expError bool
}{{
name: "valid",
in: "https://example.com/",
expError: false,
}, {
name: "valid path",
in: "https://example.com/username",
expError: false,
}, {
name: "valid query",
in: "https://example.com/users?id=100",
expError: false,
}, {
name: "valid port",
in: "https://example.com:8443/",
expError: false,
}, {
name: "valid loopback",
in: "https://127.0.0.1:8443/",
expError: false,
}, {
name: "missing scheme",
in: "example.com",
expError: true,
}, {
name: "invalid scheme",
in: "mailto:user@example.com",
expError: true,
}, {
name: "invalid double-dot path",
in: "https://example.com/foo/../bar",
expError: true,
}, {
name: "invalid fragment",
in: "https://example.com/#me",
expError: true,
}, {
name: "invalid user",
in: "https://user:pass@example.com/",
expError: true,
}, {
name: "host is an IP address",
in: "https://172.28.92.51/",
expError: true,
}} {
}{
{name: "valid", in: "https://example.com/", expError: false},
{name: "valid path", in: "https://example.com/username", expError: false},
{name: "valid query", in: "https://example.com/users?id=100", expError: false},
{name: "valid port", in: "https://example.com:8443/", expError: false},
{name: "valid loopback", in: "https://127.0.0.1:8443/", expError: false},
{name: "missing scheme", in: "example.com", expError: true},
{name: "invalid scheme", in: "mailto:user@example.com", expError: true},
{name: "invalid double-dot path", in: "https://example.com/foo/../bar", expError: true},
{name: "invalid fragment", in: "https://example.com/#me", expError: true},
{name: "invalid user", in: "https://user:pass@example.com/", expError: true},
{name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@ -131,8 +98,7 @@ func TestClientID_MarshalJSON(t *testing.T) {
func TestClientID_String(t *testing.T) {
t.Parallel()
cid := domain.TestClientID(t)
if result := cid.String(); result != fmt.Sprint(cid) {
t.Errorf("Strig() = %s, want %s", result, fmt.Sprint(cid))
if cid := domain.TestClientID(t); cid.String() != fmt.Sprint(cid) {
t.Errorf("String() = %s, want %s", cid.String(), fmt.Sprint(cid))
}
}

View File

@ -15,13 +15,10 @@ func TestClient_ValidateRedirectURI(t *testing.T) {
for _, tc := range []struct {
name string
in *domain.URL
}{{
name: "client_id prefix",
in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback")),
}, {
name: "registered redirect_uri",
in: client.RedirectURI[len(client.RedirectURI)-1],
}} {
}{
{name: "client_id prefix", in: domain.TestURL(t, fmt.Sprint(client.ID, "/callback"))},
{name: "registered redirect_uri", in: client.RedirectURI[len(client.RedirectURI)-1]},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -1,5 +1,6 @@
package domain
//nolint: gosec // support old clients
import (
"crypto/md5"
"crypto/sha1"
@ -21,7 +22,7 @@ type CodeChallengeMethod struct {
uid string
}
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
//nolint: gochecknoglobals // structs cannot be constants
var (
CodeChallengeMethodUndefined = CodeChallengeMethod{
uid: "",
@ -34,12 +35,14 @@ var (
}
CodeChallengeMethodMD5 = CodeChallengeMethod{
uid: "MD5",
uid: "MD5",
//nolint: gosec // support old clients
hash: md5.New(),
}
CodeChallengeMethodS1 = CodeChallengeMethod{
uid: "S1",
uid: "S1",
//nolint: gosec // support old clients
hash: sha1.New(),
}
@ -60,7 +63,7 @@ var ErrCodeChallengeMethodUnknown error = NewError(
"https://indieauth.net/source/#authorization-request",
)
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants
//nolint: gochecknoglobals // maps cannot be constants
var slugsMethods = map[string]CodeChallengeMethod{
CodeChallengeMethodMD5.uid: CodeChallengeMethodMD5,
CodeChallengeMethodPLAIN.uid: CodeChallengeMethodPLAIN,

View File

@ -1,5 +1,6 @@
package domain_test
//nolint: gosec // support old clients
import (
"crypto/md5"
"crypto/sha1"
@ -23,32 +24,14 @@ func TestParseCodeChallengeMethod(t *testing.T) {
in string
out domain.CodeChallengeMethod
expError bool
}{{
expError: true,
name: "invalid",
in: "und",
out: domain.CodeChallengeMethodUndefined,
}, {
name: "PLAIN",
in: "plain",
out: domain.CodeChallengeMethodPLAIN,
}, {
name: "MD5",
in: "Md5",
out: domain.CodeChallengeMethodMD5,
}, {
name: "S1",
in: "S1",
out: domain.CodeChallengeMethodS1,
}, {
name: "S256",
in: "S256",
out: domain.CodeChallengeMethodS256,
}, {
name: "S512",
in: "S512",
out: domain.CodeChallengeMethodS512,
}} {
}{
{name: "invalid", in: "und", out: domain.CodeChallengeMethodUndefined, expError: true},
{name: "PLAIN", in: "plain", out: domain.CodeChallengeMethodPLAIN, expError: false},
{name: "MD5", in: "Md5", out: domain.CodeChallengeMethodMD5, expError: false},
{name: "S1", in: "S1", out: domain.CodeChallengeMethodS1, expError: false},
{name: "S256", in: "S256", out: domain.CodeChallengeMethodS256, expError: false},
{name: "S512", in: "S512", out: domain.CodeChallengeMethodS512, expError: false},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@ -107,27 +90,13 @@ func TestCodeChallengeMethod_String(t *testing.T) {
name string
in domain.CodeChallengeMethod
out string
}{{
name: "plain",
in: domain.CodeChallengeMethodPLAIN,
out: "PLAIN",
}, {
name: "md5",
in: domain.CodeChallengeMethodMD5,
out: "MD5",
}, {
name: "s1",
in: domain.CodeChallengeMethodS1,
out: "S1",
}, {
name: "s256",
in: domain.CodeChallengeMethodS256,
out: "S256",
}, {
name: "s512",
in: domain.CodeChallengeMethodS512,
out: "S512",
}} {
}{
{name: "plain", in: domain.CodeChallengeMethodPLAIN, out: "PLAIN"},
{name: "md5", in: domain.CodeChallengeMethodMD5, out: "MD5"},
{name: "s1", in: domain.CodeChallengeMethodS1, out: "S1"},
{name: "s256", in: domain.CodeChallengeMethodS256, out: "S256"},
{name: "s512", in: domain.CodeChallengeMethodS512, out: "S512"},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@ -141,7 +110,7 @@ func TestCodeChallengeMethod_String(t *testing.T) {
}
}
//nolint: funlen
//nolint: gosec // support old clients
func TestCodeChallengeMethod_Validate(t *testing.T) {
t.Parallel()
@ -155,42 +124,15 @@ func TestCodeChallengeMethod_Validate(t *testing.T) {
in domain.CodeChallengeMethod
name string
expError bool
}{{
name: "invalid",
in: domain.CodeChallengeMethodS256,
hash: md5.New(),
expError: true,
}, {
name: "MD5",
in: domain.CodeChallengeMethodMD5,
hash: md5.New(),
expError: false,
}, {
name: "plain",
in: domain.CodeChallengeMethodPLAIN,
hash: nil,
expError: false,
}, {
name: "S1",
in: domain.CodeChallengeMethodS1,
hash: sha1.New(),
expError: false,
}, {
name: "S256",
in: domain.CodeChallengeMethodS256,
hash: sha256.New(),
expError: false,
}, {
name: "S512",
in: domain.CodeChallengeMethodS512,
hash: sha512.New(),
expError: false,
}, {
name: "undefined",
in: domain.CodeChallengeMethodUndefined,
hash: nil,
expError: true,
}} {
}{
{name: "invalid", in: domain.CodeChallengeMethodS256, hash: md5.New(), expError: true},
{name: "MD5", in: domain.CodeChallengeMethodMD5, hash: md5.New(), expError: false},
{name: "plain", in: domain.CodeChallengeMethodPLAIN, hash: nil, expError: false},
{name: "S1", in: domain.CodeChallengeMethodS1, hash: sha1.New(), expError: false},
{name: "S256", in: domain.CodeChallengeMethodS256, hash: sha256.New(), expError: false},
{name: "S512", in: domain.CodeChallengeMethodS512, hash: sha512.New(), expError: false},
{name: "undefined", in: domain.CodeChallengeMethodUndefined, hash: nil, expError: true},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -78,6 +78,7 @@ type (
)
// TestConfig returns a valid config for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestConfig(tb testing.TB) *Config {
tb.Helper()

View File

@ -12,12 +12,14 @@ type Email struct {
subAddress string
}
const DefaultEmailPartsLength int = 2
var ErrEmailInvalid error = NewError(ErrorCodeInvalidRequest, "cannot parse email", "")
// ParseEmail parse strings to email identifier.
func ParseEmail(src string) (*Email, error) {
parts := strings.Split(strings.TrimPrefix(src, "mailto:"), "@")
if len(parts) != 2 { //nolint: gomnd
if len(parts) != DefaultEmailPartsLength {
return nil, ErrEmailInvalid
}
@ -27,7 +29,7 @@ func ParseEmail(src string) (*Email, error) {
subAddress: "",
}
if userParts := strings.SplitN(parts[0], `+`, 2); len(userParts) > 1 {
if userParts := strings.SplitN(parts[0], `+`, DefaultEmailPartsLength); len(userParts) > 1 {
result.user = userParts[0]
result.subAddress = userParts[1]
}

View File

@ -14,19 +14,11 @@ func TestParseEmail(t *testing.T) {
name string
in string
out string
}{{
name: "simple",
in: "user@example.com",
out: "user@example.com",
}, {
name: "subAddress",
in: "user+suffix@example.com",
out: "user+suffix@example.com",
}, {
name: "mailto prefix",
in: "mailto:user@example.com",
out: "user@example.com",
}} {
}{
{name: "simple", in: "user@example.com", out: "user@example.com"},
{name: "subAddress", in: "user+suffix@example.com", out: "user+suffix@example.com"},
{name: "mailto prefix", in: "mailto:user@example.com", out: "user@example.com"},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@ -47,7 +39,7 @@ func TestParseEmail(t *testing.T) {
func TestEmail_String(t *testing.T) {
t.Parallel()
email := domain.TestEmail(t)
email := domain.TestEmail(t) //nolint: ifshort
if result := email.String(); result != fmt.Sprint(email) {
t.Errorf("String() = %v, want %v", result, email)
}

View File

@ -10,18 +10,19 @@ import (
type (
// Error describes the format of a typical IndieAuth error.
//nolint: tagliatelle // RFC 6749 section 5.2
Error struct {
// A single error code.
Code ErrorCode `json:"error"`
// Human-readable ASCII text providing additional information, used to
// assist the client developer in understanding the error that occurred.
Description string `json:"error_description,omitempty"` //nolint: tagliatelle
Description string `json:"error_description,omitempty"`
// A URI identifying a human-readable web page with information about
// the error, used to provide the client developer with additional
// information about the error.
URI string `json:"error_uri,omitempty"` //nolint: tagliatelle
URI string `json:"error_uri,omitempty"`
// REQUIRED if a "state" parameter was present in the client
// authorization request. The exact value received from the
@ -33,17 +34,29 @@ type (
// ErrorCode represent error code described in RFC 6749.
ErrorCode struct {
uid string
uid string
status int
}
)
var (
ErrorCodeUndefined ErrorCode = ErrorCode{uid: ""}
// ErrorCodeUndefined describes an unrecognized error code.
ErrorCodeUndefined = ErrorCode{
uid: "",
status: 0,
}
// ErrorCodeAccessDenied describes the access_denied error code.
//
// RFC 6749 section 4.1.2.1: The resource owner or authorization server
// denied the request.
ErrorCodeAccessDenied ErrorCode = ErrorCode{uid: "access_denied"}
ErrorCodeAccessDenied = ErrorCode{
uid: "access_denied",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidClient describes the invalid_client error code.
//
// RFC 6749 section 5.2: Client authentication failed (e.g., unknown
// client, no client authentication included, or unsupported
// authentication method).
@ -56,14 +69,26 @@ var (
// HTTP 401 (Unauthorized) status code and include the
// "WWW-Authenticate" response header field matching the authentication
// scheme used by the client.
ErrorCodeInvalidClient ErrorCode = ErrorCode{uid: "invalid_client"}
ErrorCodeInvalidClient = ErrorCode{
uid: "invalid_client",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidGrant describes the invalid_grant error code.
//
// RFC 6749 section 5.2: The provided authorization grant (e.g.,
// authorization code, resource owner credentials) or refresh token is
// invalid, expired, revoked, does not match the redirection URI used in
// the authorization request, or was issued to another client.
ErrorCodeInvalidGrant ErrorCode = ErrorCode{uid: "invalid_grant"}
ErrorCodeInvalidGrant = ErrorCode{
uid: "invalid_grant",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidRequest describes the invalid_request error code.
//
// IndieAuth: The request is not valid.
//
// RFC 6749 section 4.1.2.1: The request is missing a required
// parameter, includes an invalid parameter value, includes a parameter
// more than once, or is otherwise malformed.
@ -73,42 +98,91 @@ var (
// repeats a parameter, includes multiple credentials, utilizes more
// than one mechanism for authenticating the client, or is otherwise
// malformed.
ErrorCodeInvalidRequest ErrorCode = ErrorCode{uid: "invalid_request"}
ErrorCodeInvalidRequest = ErrorCode{
uid: "invalid_request",
status: http.StatusBadRequest,
}
// ErrorCodeInvalidScope describes the invalid_scope error code.
//
// RFC 6749 section 4.1.2.1: The requested scope is invalid, unknown, or
// malformed.
//
// RFC 6749 section 5.2: The requested scope is invalid, unknown,
// malformed, or exceeds the scope granted by the resource owner.
ErrorCodeInvalidScope ErrorCode = ErrorCode{uid: "invalid_scope"}
ErrorCodeInvalidScope = ErrorCode{
uid: "invalid_scope",
status: 0, // TODO(toby3d)
}
// ErrorCodeServerError describes the server_error error code.
//
// RFC 6749 section 4.1.2.1: The authorization server encountered an
// unexpected condition that prevented it from fulfilling the request.
// (This error code is needed because a 500 Internal Server Error HTTP
// status code cannot be returned to the client via an HTTP redirect.)
ErrorCodeServerError ErrorCode = ErrorCode{uid: "server_error"}
ErrorCodeServerError = ErrorCode{
uid: "server_error",
status: 0, // TODO(toby3d)
}
// ErrorCodeTemporarilyUnavailable describes the temporarily_unavailable error code.
//
// RFC 6749 section 4.1.2.1: The authorization server is currently
// unable to handle the request due to a temporary overloading or
// maintenance of the server. (This error code is needed because a 503
// Service Unavailable HTTP status code cannot be returned to the client
// via an HTTP redirect.)
ErrorCodeTemporarilyUnavailable ErrorCode = ErrorCode{uid: "temporarily_unavailable"}
ErrorCodeTemporarilyUnavailable = ErrorCode{
uid: "temporarily_unavailable",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnauthorizedClient describes the unauthorized_client error code.
//
// RFC 6749 section 4.1.2.1: The client is not authorized to request an
// authorization code using this method.
//
// RFC 6749 section 5.2: The authenticated client is not authorized to
// use this authorization grant type.
ErrorCodeUnauthorizedClient ErrorCode = ErrorCode{uid: "unauthorized_client"}
ErrorCodeUnauthorizedClient = ErrorCode{
uid: "unauthorized_client",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnsupportedGrantType describes the unsupported_grant_type error code.
//
// RFC 6749 section 5.2: The authorization grant type is not supported
// by the authorization server.
ErrorCodeUnsupportedGrantType ErrorCode = ErrorCode{uid: "unsupported_grant_type"}
ErrorCodeUnsupportedGrantType = ErrorCode{
uid: "unsupported_grant_type",
status: 0, // TODO(toby3d)
}
// ErrorCodeUnsupportedResponseType describes the unsupported_response_type error code.
//
// RFC 6749 section 4.1.2.1: The authorization server does not support
// obtaining an authorization code using this method.
ErrorCodeUnsupportedResponseType ErrorCode = ErrorCode{uid: "unsupported_response_type"}
ErrorCodeUnsupportedResponseType = ErrorCode{
uid: "unsupported_response_type",
status: 0, // TODO(toby3d)
}
// ErrorCodeInvalidToken describes the invalid_token error code.
//
// IndieAuth: The access token provided is expired, revoked, or invalid.
ErrorCodeInvalidToken = ErrorCode{
uid: "invalid_token",
status: http.StatusUnauthorized,
}
// ErrorCodeInsufficientScope describes the insufficient_scope error code.
//
// IndieAuth: The request requires higher privileges than provided.
ErrorCodeInsufficientScope = ErrorCode{
uid: "insufficient_scope",
status: http.StatusForbidden,
}
)
// String returns a string representation of the error code.
@ -128,45 +202,45 @@ func (e Error) Error() string {
}
// Format prints the stack as error detail.
func (e Error) Format(s fmt.State, r rune) {
xerrors.FormatError(e, s, r)
func (e Error) Format(state fmt.State, r rune) {
xerrors.FormatError(e, state, r)
}
// FormatError prints the receiver's error, if any.
func (e Error) FormatError(p xerrors.Printer) error {
p.Print(e.Code)
func (e Error) FormatError(printer xerrors.Printer) error {
printer.Print(e.Code)
if e.Description != "" {
p.Print(": ", e.Description)
printer.Print(": ", e.Description)
}
if !p.Detail() {
if !printer.Detail() {
return nil
}
e.frame.Format(p)
e.frame.Format(printer)
return nil
}
// SetReirectURI sets fasthttp.QueryArgs with the request state, code,
// description and error URI in the provided fasthttp.URI.
func (e Error) SetReirectURI(u *http.URI) {
if u == nil {
func (e Error) SetReirectURI(uri *http.URI) {
if uri == nil {
return
}
for k, v := range map[string]string{
for key, val := range map[string]string{
"error": e.Code.String(),
"error_description": e.Description,
"error_uri": e.URI,
"state": e.State,
} {
if v == "" {
if val == "" {
continue
}
u.QueryArgs().Set(k, v)
uri.QueryArgs().Set(key, val)
}
}

View File

@ -1,19 +1,21 @@
//nolint: dupl
package domain
import (
"errors"
"fmt"
"strconv"
"strings"
)
// GrantType represent fixed grant_type parameter.
//
// NOTE(toby3d): Encapsulate enums in structs for extra compile-time safety:
// https://threedots.tech/post/safer-enums-in-go/#struct-based-enums
type GrantType struct {
uid string
}
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
//nolint: gochecknoglobals // structs cannot be constants
var (
GrantTypeUndefined = GrantType{uid: ""}
GrantTypeAuthorizationCode = GrantType{uid: "authorization_code"}
@ -22,7 +24,11 @@ var (
GrantTypeTicket = GrantType{uid: "ticket"}
)
var ErrGrantTypeUnknown error = errors.New("unknown grant type")
var ErrGrantTypeUnknown error = NewError(
ErrorCodeInvalidGrant,
"unknown grant type",
"",
)
// ParseGrantType parse grant_type value as GrantType struct enum.
func ParseGrantType(uid string) (GrantType, error) {
@ -40,7 +46,7 @@ func ParseGrantType(uid string) (GrantType, error) {
func (gt *GrantType) UnmarshalForm(src []byte) error {
responseType, err := ParseGrantType(string(src))
if err != nil {
return fmt.Errorf("grant_type: %w", err)
return fmt.Errorf("UnmarshalForm: %w", err)
}
*gt = responseType
@ -52,12 +58,12 @@ func (gt *GrantType) UnmarshalForm(src []byte) error {
func (gt *GrantType) UnmarshalJSON(v []byte) error {
src, err := strconv.Unquote(string(v))
if err != nil {
return err
return fmt.Errorf("UnmarshalJSON: %w", err)
}
responseType, err := ParseGrantType(src)
if err != nil {
return fmt.Errorf("grant_type: %w", err)
return fmt.Errorf("UnmarshalJSON: %w", err)
}
*gt = responseType

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test
import (
@ -12,13 +13,10 @@ func TestParseGrantType(t *testing.T) {
for _, tc := range []struct {
in string
out domain.GrantType
}{{
in: "authorization_code",
out: domain.GrantTypeAuthorizationCode,
}, {
in: "ticket",
out: domain.GrantTypeTicket,
}} {
}{
{in: "authorization_code", out: domain.GrantTypeAuthorizationCode},
{in: "ticket", out: domain.GrantTypeTicket},
} {
tc := tc
t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestGrantType_String(t *testing.T) {
name string
in domain.GrantType
out string
}{{
name: "authorization_code",
in: domain.GrantTypeAuthorizationCode,
out: "authorization_code",
}, {
name: "ticket",
in: domain.GrantTypeTicket,
out: "ticket",
}} {
}{
{name: "authorization_code", in: domain.GrantTypeAuthorizationCode, out: "authorization_code"},
{name: "ticket", in: domain.GrantTypeTicket, out: "ticket"},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -18,7 +18,7 @@ type Me struct {
}
// ParseMe parse string as me URL identifier.
//nolint: funlen
//nolint: funlen, cyclop
func ParseMe(raw string) (*Me, error) {
me := http.AcquireURI()
if err := me.Parse(nil, []byte(raw)); err != nil {
@ -114,7 +114,7 @@ func TestMe(tb testing.TB, src string) *Me {
me, err := ParseMe(src)
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
return me
@ -174,12 +174,16 @@ func (m Me) URL() *url.URL {
}
return &url.URL{
Scheme: string(m.me.Scheme()),
Host: string(m.me.Host()),
Path: string(m.me.Path()),
RawPath: string(m.me.PathOriginal()),
RawQuery: string(m.me.QueryString()),
Fragment: string(m.me.Hash()),
ForceQuery: false,
Fragment: string(m.me.Hash()),
Host: string(m.me.Host()),
Opaque: "",
Path: string(m.me.Path()),
RawFragment: "",
RawPath: string(m.me.PathOriginal()),
RawQuery: string(m.me.QueryString()),
Scheme: string(m.me.Scheme()),
User: nil,
}
}

View File

@ -7,7 +7,6 @@ import (
"source.toby3d.me/website/indieauth/internal/domain"
)
//nolint: funlen
func TestParseMe(t *testing.T) {
t.Parallel()
@ -15,47 +14,18 @@ func TestParseMe(t *testing.T) {
name string
in string
expError bool
}{{
name: "valid",
in: "https://example.com/",
expError: false,
}, {
name: "valid path",
in: "https://example.com/username",
expError: false,
}, {
name: "valid query",
in: "https://example.com/users?id=100",
expError: false,
}, {
name: "missing scheme",
in: "example.com",
expError: true,
}, {
name: "invalid scheme",
in: "mailto:user@example.com",
expError: true,
}, {
name: "contains double-dot path",
in: "https://example.com/foo/../bar",
expError: true,
}, {
name: "contains fragment",
in: "https://example.com/#me",
expError: true,
}, {
name: "contains user",
in: "https://user:pass@example.com/",
expError: true,
}, {
name: "contains port",
in: "https://example.com:8443/",
expError: true,
}, {
name: "host is an IP address",
in: "https://172.28.92.51/",
expError: true,
}} {
}{
{name: "valid", in: "https://example.com/", expError: false},
{name: "valid path", in: "https://example.com/username", expError: false},
{name: "valid query", in: "https://example.com/users?id=100", expError: false},
{name: "missing scheme", in: "example.com", expError: true},
{name: "invalid scheme", in: "mailto:user@example.com", expError: true},
{name: "contains double-dot path", in: "https://example.com/foo/../bar", expError: true},
{name: "contains fragment", in: "https://example.com/#me", expError: true},
{name: "contains user", in: "https://user:pass@example.com/", expError: true},
{name: "contains port", in: "https://example.com:8443/", expError: true},
{name: "host is an IP address", in: "https://172.28.92.51/", expError: true},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -1,6 +1,6 @@
package domain
//nolint: tagliatelle
//nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
type Metadata struct {
// The server's issuer identifier. The issuer identifier is a URL that
// uses the "https" scheme and has no query or fragment components. The

View File

@ -21,89 +21,89 @@ type Provider struct {
URL string
}
//nolint: gochecknoglobals
//nolint: gochecknoglobals // structs cannot be contants
var (
ProviderDirect = Provider{
AuthURL: "/authorize",
Name: "IndieAuth",
Photo: path.Join("static", "icon.svg"),
RedirectURL: path.Join("callback"),
Scopes: []string{},
TokenURL: "/token",
UID: "direct",
URL: "/",
AuthURL: "/authorize",
ClientID: "",
ClientSecret: "",
Name: "IndieAuth",
Photo: path.Join("static", "icon.svg"),
RedirectURL: path.Join("callback"),
Scopes: []string{},
TokenURL: "/token",
UID: "direct",
URL: "/",
}
ProviderTwitter = Provider{
AuthURL: "https://twitter.com/i/oauth2/authorize",
Name: "Twitter",
Photo: path.Join("static", "providers", "twitter.svg"),
RedirectURL: path.Join("callback", "twitter"),
Scopes: []string{
"tweet.read",
"users.read",
},
TokenURL: "https://api.twitter.com/2/oauth2/token",
UID: "twitter",
URL: "https://twitter.com/",
AuthURL: "https://twitter.com/i/oauth2/authorize",
ClientID: "",
ClientSecret: "",
Name: "Twitter",
Photo: path.Join("static", "providers", "twitter.svg"),
RedirectURL: path.Join("callback", "twitter"),
Scopes: []string{"tweet.read", "users.read"},
TokenURL: "https://api.twitter.com/2/oauth2/token",
UID: "twitter",
URL: "https://twitter.com/",
}
ProviderGitHub = Provider{
AuthURL: "https://github.com/login/oauth/authorize",
Name: "GitHub",
Photo: path.Join("static", "providers", "github.svg"),
RedirectURL: path.Join("callback", "github"),
Scopes: []string{
"read:user",
"user:email",
},
TokenURL: "https://github.com/login/oauth/access_token",
UID: "github",
URL: "https://github.com/",
AuthURL: "https://github.com/login/oauth/authorize",
ClientID: "",
ClientSecret: "",
Name: "GitHub",
Photo: path.Join("static", "providers", "github.svg"),
RedirectURL: path.Join("callback", "github"),
Scopes: []string{"read:user", "user:email"},
TokenURL: "https://github.com/login/oauth/access_token",
UID: "github",
URL: "https://github.com/",
}
ProviderGitLab = Provider{
AuthURL: "https://gitlab.com/oauth/authorize",
Name: "GitLab",
Photo: path.Join("static", "providers", "gitlab.svg"),
RedirectURL: path.Join("callback", "gitlab"),
Scopes: []string{
"read_user",
},
TokenURL: "https://gitlab.com/oauth/token",
UID: "gitlab",
URL: "https://gitlab.com/",
AuthURL: "https://gitlab.com/oauth/authorize",
ClientID: "",
ClientSecret: "",
Name: "GitLab",
Photo: path.Join("static", "providers", "gitlab.svg"),
RedirectURL: path.Join("callback", "gitlab"),
Scopes: []string{"read_user"},
TokenURL: "https://gitlab.com/oauth/token",
UID: "gitlab",
URL: "https://gitlab.com/",
}
ProviderMastodon = Provider{
AuthURL: "https://mstdn.io/oauth/authorize",
Name: "Mastodon",
Photo: path.Join("static", "providers", "mastodon.svg"),
RedirectURL: path.Join("callback", "mastodon"),
Scopes: []string{
"read:accounts",
},
TokenURL: "https://mstdn.io/oauth/token",
UID: "mastodon",
URL: "https://mstdn.io/",
AuthURL: "https://mstdn.io/oauth/authorize",
ClientID: "",
ClientSecret: "",
Name: "Mastodon",
Photo: path.Join("static", "providers", "mastodon.svg"),
RedirectURL: path.Join("callback", "mastodon"),
Scopes: []string{"read:accounts"},
TokenURL: "https://mstdn.io/oauth/token",
UID: "mastodon",
URL: "https://mstdn.io/",
}
)
// AuthCodeURL returns URL for authorize user in RelMeAuth client.
func (p Provider) AuthCodeURL(state string) string {
u := http.AcquireURI()
defer http.ReleaseURI(u)
u.Update(p.AuthURL)
uri := http.AcquireURI()
defer http.ReleaseURI(uri)
uri.Update(p.AuthURL)
for k, v := range map[string]string{
for key, val := range map[string]string{
"client_id": p.ClientID,
"redirect_uri": p.RedirectURL,
"response_type": "code",
"scope": strings.Join(p.Scopes, " "),
"state": state,
} {
u.QueryArgs().Set(k, v)
uri.QueryArgs().Set(key, val)
}
return u.String()
return uri.String()
}

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain
import (
@ -12,20 +13,20 @@ type ResponseType struct {
uid string
}
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
//nolint: gochecknoglobals // structs cannot be constants
var (
ResponseTypeUndefined ResponseType = ResponseType{uid: ""}
ResponseTypeUndefined = ResponseType{uid: ""}
// Deprecated(toby3d): Only accept response_type=code requests, and for
// backwards-compatible support, treat response_type=id requests as
// response_type=code requests:
// https://aaronparecki.com/2020/12/03/1/indieauth-2020#response-type
ResponseTypeID ResponseType = ResponseType{uid: "id"}
ResponseTypeID = ResponseType{uid: "id"}
// Indicates to the authorization server that an authorization code
// should be returned as the response:
// https://indieauth.net/source/#authorization-request-li-1
ResponseTypeCode ResponseType = ResponseType{uid: "code"}
ResponseTypeCode = ResponseType{uid: "code"}
)
var ErrResponseTypeUnknown error = NewError(
@ -65,7 +66,7 @@ func (rt *ResponseType) UnmarshalJSON(v []byte) error {
return fmt.Errorf("UnmarshalJSON: %w", err)
}
responseType, err := ParseResponseType(string(uid))
responseType, err := ParseResponseType(uid)
if err != nil {
return fmt.Errorf("UnmarshalJSON: %w", err)
}

View File

@ -1,3 +1,4 @@
//nolint: dupl
package domain_test
import (
@ -12,13 +13,10 @@ func TestResponseTypeType(t *testing.T) {
for _, tc := range []struct {
in string
out domain.ResponseType
}{{
in: "id",
out: domain.ResponseTypeID,
}, {
in: "code",
out: domain.ResponseTypeCode,
}} {
}{
{in: "id", out: domain.ResponseTypeID},
{in: "code", out: domain.ResponseTypeCode},
} {
tc := tc
t.Run(tc.in, func(t *testing.T) {
@ -73,15 +71,10 @@ func TestResponseType_String(t *testing.T) {
name string
in domain.ResponseType
out string
}{{
name: "id",
in: domain.ResponseTypeID,
out: "id",
}, {
name: "code",
in: domain.ResponseTypeCode,
out: "code",
}} {
}{
{name: "id", in: domain.ResponseTypeID, out: "id"},
{name: "code", in: domain.ResponseTypeCode, out: "code"},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {

View File

@ -21,7 +21,7 @@ type (
var ErrScopeUnknown error = NewError(ErrorCodeInvalidRequest, "unknown scope", "https://indieweb.org/scope")
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be constants
//nolint: gochecknoglobals // structs cannot be constants
var (
ScopeUndefined = Scope{uid: ""}
@ -56,7 +56,7 @@ var (
ScopeEmail = Scope{uid: "email"}
)
//nolint: gochecknoglobals // NOTE(toby3d): maps cannot be constants
//nolint: gochecknoglobals // maps cannot be constants
var uidsScopes = map[string]Scope{
ScopeBlock.uid: ScopeBlock,
ScopeChannels.uid: ScopeChannels,
@ -112,17 +112,17 @@ func (s *Scopes) UnmarshalJSON(v []byte) error {
result := make(Scopes, 0)
for _, scope := range strings.Fields(src) {
s, err := ParseScope(scope)
for _, rawScope := range strings.Fields(src) {
scope, err := ParseScope(rawScope)
if err != nil {
return fmt.Errorf("UnmarshalJSON: %w", err)
}
if result.Has(s) {
if result.Has(scope) {
continue
}
result = append(result, s)
result = append(result, scope)
}
*s = result

View File

@ -14,43 +14,20 @@ func TestParseScope(t *testing.T) {
for _, tc := range []struct {
in string
out domain.Scope
}{{
in: "create",
out: domain.ScopeCreate,
}, {
in: "delete",
out: domain.ScopeDelete,
}, {
in: "draft",
out: domain.ScopeDraft,
}, {
in: "media",
out: domain.ScopeMedia,
}, {
in: "update",
out: domain.ScopeUpdate,
}, {
in: "block",
out: domain.ScopeBlock,
}, {
in: "channels",
out: domain.ScopeChannels,
}, {
in: "follow",
out: domain.ScopeFollow,
}, {
in: "mute",
out: domain.ScopeMute,
}, {
in: "read",
out: domain.ScopeRead,
}, {
in: "profile",
out: domain.ScopeProfile,
}, {
in: "email",
out: domain.ScopeEmail,
}} {
}{
{in: "create", out: domain.ScopeCreate},
{in: "delete", out: domain.ScopeDelete},
{in: "draft", out: domain.ScopeDraft},
{in: "media", out: domain.ScopeMedia},
{in: "update", out: domain.ScopeUpdate},
{in: "block", out: domain.ScopeBlock},
{in: "channels", out: domain.ScopeChannels},
{in: "follow", out: domain.ScopeFollow},
{in: "mute", out: domain.ScopeMute},
{in: "read", out: domain.ScopeRead},
{in: "profile", out: domain.ScopeProfile},
{in: "email", out: domain.ScopeEmail},
} {
tc := tc
t.Run(tc.in, func(t *testing.T) {
@ -118,47 +95,24 @@ func TestScopes_MarshalJSON(t *testing.T) {
func TestScope_String(t *testing.T) {
t.Parallel()
//nolint: paralleltest // NOTE(toby3d): false positive, tc.in is used.
//nolint: paralleltest // false positive, in is used
for _, tc := range []struct {
in domain.Scope
out string
}{{
in: domain.ScopeCreate,
out: "create",
}, {
in: domain.ScopeDelete,
out: "delete",
}, {
in: domain.ScopeDraft,
out: "draft",
}, {
in: domain.ScopeMedia,
out: "media",
}, {
in: domain.ScopeUpdate,
out: "update",
}, {
in: domain.ScopeBlock,
out: "block",
}, {
in: domain.ScopeChannels,
out: "channels",
}, {
in: domain.ScopeFollow,
out: "follow",
}, {
in: domain.ScopeMute,
out: "mute",
}, {
in: domain.ScopeRead,
out: "read",
}, {
in: domain.ScopeProfile,
out: "profile",
}, {
in: domain.ScopeEmail,
out: "email",
}} {
}{
{in: domain.ScopeCreate, out: "create"},
{in: domain.ScopeDelete, out: "delete"},
{in: domain.ScopeDraft, out: "draft"},
{in: domain.ScopeMedia, out: "media"},
{in: domain.ScopeUpdate, out: "update"},
{in: domain.ScopeBlock, out: "block"},
{in: domain.ScopeChannels, out: "channels"},
{in: domain.ScopeFollow, out: "follow"},
{in: domain.ScopeMute, out: "mute"},
{in: domain.ScopeRead, out: "read"},
{in: domain.ScopeProfile, out: "profile"},
{in: domain.ScopeEmail, out: "email"},
} {
tc := tc
t.Run(fmt.Sprint(tc.in), func(t *testing.T) {

View File

@ -8,22 +8,22 @@ import (
type Session struct {
ClientID *ClientID
Me *Me
RedirectURI *URL
Profile *Profile
Me *Me
CodeChallengeMethod CodeChallengeMethod
Scope Scopes
Code string
CodeChallenge string
Code string
}
// TestSession returns valid random generated session for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestSession(tb testing.TB) *Session {
tb.Helper()
code, err := random.String(24)
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
return &Session{

View File

@ -33,13 +33,20 @@ type (
}
)
//nolint: gochecknoglobals
// DefaultNewTokenOptions describes the default settings for NewToken.
//nolint: gochecknoglobals, gomnd
var DefaultNewTokenOptions = NewTokenOptions{
Algorithm: "HS256",
Expiration: 0,
Issuer: nil,
NonceLength: 32,
Scope: nil,
Secret: nil,
Subject: nil,
}
// NewToken create a new token by provided options.
//nolint: cyclop
func NewToken(opts NewTokenOptions) (*Token, error) {
if opts.NonceLength == 0 {
opts.NonceLength = DefaultNewTokenOptions.NonceLength
@ -56,22 +63,33 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
return nil, fmt.Errorf("cannot generate nonce: %w", err)
}
t := jwt.New()
t.Set(jwt.SubjectKey, opts.Subject.String())
t.Set(jwt.NotBeforeKey, now)
t.Set(jwt.IssuedAtKey, now)
t.Set("scope", opts.Scope)
t.Set("nonce", nonce)
tkn := jwt.New()
for key, val := range map[string]interface{}{
"nonce": nonce,
"scope": opts.Scope,
jwt.IssuedAtKey: now,
jwt.NotBeforeKey: now,
jwt.SubjectKey: opts.Subject.String(),
} {
if err = tkn.Set(key, val); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
}
if opts.Issuer != nil {
t.Set(jwt.IssuerKey, opts.Issuer.String())
if err = tkn.Set(jwt.IssuerKey, opts.Issuer.String()); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
}
if opts.Expiration != 0 {
t.Set(jwt.ExpirationKey, now.Add(opts.Expiration))
if err = tkn.Set(jwt.ExpirationKey, now.Add(opts.Expiration)); err != nil {
return nil, fmt.Errorf("failed to set JWT token field: %w", err)
}
}
accessToken, err := jwt.Sign(t, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
accessToken, err := jwt.Sign(tkn, jwa.SignatureAlgorithm(opts.Algorithm), opts.Secret)
if err != nil {
return nil, fmt.Errorf("cannot sign a new access token: %w", err)
}
@ -81,19 +99,20 @@ func NewToken(opts NewTokenOptions) (*Token, error) {
ClientID: opts.Issuer,
Me: opts.Subject,
Scope: opts.Scope,
}, err
}, nil
}
// TestToken returns valid random generated token for tests.
//nolint: gomnd // testing domain can contains non-standart values
func TestToken(tb testing.TB) *Token {
tb.Helper()
nonce, err := random.String(22)
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
t := jwt.New()
tkn := jwt.New()
cid := TestClientID(tb)
me := TestMe(tb, "https://user.example.net/")
now := time.Now().UTC().Round(time.Second)
@ -103,22 +122,25 @@ func TestToken(tb testing.TB) *Token {
ScopeUpdate,
}
// NOTE(toby3d): required
t.Set(jwt.IssuerKey, cid.String())
t.Set(jwt.SubjectKey, me.String())
// TODO(toby3d): t.Set(jwt.AudienceKey, nil)
t.Set(jwt.ExpirationKey, now.Add(1*time.Hour))
t.Set(jwt.NotBeforeKey, now.Add(-1*time.Hour))
t.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour))
// TODO(toby3d): t.Set(jwt.JwtIDKey, nil)
for key, val := range map[string]interface{}{
// NOTE(toby3d): required
jwt.IssuerKey: cid.String(),
jwt.SubjectKey: me.String(),
jwt.ExpirationKey: now.Add(1 * time.Hour),
jwt.NotBeforeKey: now.Add(-1 * time.Hour),
jwt.IssuedAtKey: now.Add(-1 * time.Hour),
// TODO(toby3d): jwt.AudienceKey
// TODO(toby3d): jwt.JwtIDKey
// NOTE(toby3d): optional
"scope": scope,
"nonce": nonce,
} {
_ = tkn.Set(key, val)
}
// optional
t.Set("scope", scope)
t.Set("nonce", nonce)
accessToken, err := jwt.Sign(t, jwa.HS256, []byte("hackme"))
accessToken, err := jwt.Sign(tkn, jwa.HS256, []byte("hackme"))
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
return &Token{
@ -144,5 +166,5 @@ func (t Token) String() string {
return ""
}
return "Bearer " + string(t.AccessToken)
return "Bearer " + t.AccessToken
}

View File

@ -71,7 +71,7 @@ func (u URL) URL() *url.URL {
return nil
}
result, err := url.ParseRequestURI(u.URI.String())
result, err := url.ParseRequestURI(u.String())
if err != nil {
return nil
}

View File

@ -19,32 +19,32 @@ func TestParseURL(t *testing.T) {
func TestURL_UnmarshalForm(t *testing.T) {
t.Parallel()
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprint(u))
url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprint(url))
result := new(domain.URL)
if err := result.UnmarshalForm(input); err != nil {
t.Fatalf("%+v", err)
}
if fmt.Sprint(result) != fmt.Sprint(u) {
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, u)
if fmt.Sprint(result) != fmt.Sprint(url) {
t.Errorf("UnmarshalForm(%s) = %v, want %v", input, result, url)
}
}
func TestURL_UnmarshalJSON(t *testing.T) {
t.Parallel()
u := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprintf(`"%s"`, u))
url := domain.TestURL(t, "https://user:pass@example.com:8443/users?id=100#me")
input := []byte(fmt.Sprintf(`"%s"`, url))
result := new(domain.URL)
if err := result.UnmarshalJSON(input); err != nil {
t.Fatalf("%+v", err)
}
if fmt.Sprint(result) != fmt.Sprint(u) {
t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, u)
if fmt.Sprint(result) != fmt.Sprint(url) {
t.Errorf("UnmarshalJSON(%s) = %v, want %v", input, result, url)
}
}

View File

@ -11,7 +11,7 @@ import (
)
type (
//nolint: tagliatelle
//nolint: tagliatelle // https://indieauth.net/source/#indieauth-server-metadata
MetadataResponse struct {
// The server's issuer identifier. The issuer identifier is a
// URL that uses the "https" scheme and has no query or fragment
@ -72,14 +72,11 @@ type (
}
)
//nolint: gochecknoglobals // NOTE(toby3d): structs cannot be contants.
// DefaultMetadataResponse contains all supported types by default.
//nolint: gochecknoglobals // structs cannot be constants
var DefaultMetadataResponse = MetadataResponse{
ServiceDocumentation: "https://indieauth.net/source/",
AuthorizationEndpoint: "",
AuthorizationResponseIssParameterSupported: true,
ScopesSupported: []string{
domain.ScopeEmail.String(),
domain.ScopeProfile.String(),
},
CodeChallengeMethodsSupported: []string{
domain.CodeChallengeMethodMD5.String(),
domain.CodeChallengeMethodPLAIN.String(),
@ -87,13 +84,21 @@ var DefaultMetadataResponse = MetadataResponse{
domain.CodeChallengeMethodS256.String(),
domain.CodeChallengeMethodS512.String(),
},
GrantTypesSupported: []string{
domain.GrantTypeAuthorizationCode.String(),
domain.GrantTypeTicket.String(),
},
Issuer: "",
ResponseTypesSupported: []string{
domain.ResponseTypeCode.String(),
domain.ResponseTypeID.String(),
},
GrantTypesSupported: []string{
domain.GrantTypeAuthorizationCode.String(),
ScopesSupported: []string{
domain.ScopeEmail.String(),
domain.ScopeProfile.String(),
},
ServiceDocumentation: "https://indieauth.net/source/",
TokenEndpoint: "",
}
func NewRequestHandler(config *domain.Config) *RequestHandler {
@ -118,5 +123,5 @@ func (h *RequestHandler) read(ctx *http.RequestCtx) {
ctx.SetStatusCode(http.StatusOK)
ctx.SetContentType(common.MIMEApplicationJSON)
json.NewEncoder(ctx).Encode(&resp)
_ = json.NewEncoder(ctx).Encode(&resp)
}

View File

@ -19,6 +19,7 @@ func NewGithubProfileRepository() profile.Repository {
return &githubProfileRepository{}
}
//nolint: cyclop
func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
user, _, err := github.NewClient(oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))).Users.Get(ctx, "")
if err != nil {
@ -41,6 +42,7 @@ func (repo *githubProfileRepository) Get(ctx context.Context, token *oauth2.Toke
// NOTE(toby3d): Profile URLs.
result.URL = make([]*domain.URL, 0)
var twitterURL *string
if user.TwitterUsername != nil {

View File

@ -19,7 +19,7 @@ func NewGitlabProfileRepository() profile.Repository {
return &gitlabProfileRepository{}
}
//nolint: funlen // NOTE(toby3d): uses hyphenation on new lines for readability.
//nolint: funlen, cyclop
func (repo *gitlabProfileRepository) Get(_ context.Context, token *oauth2.Token) (*domain.Profile, error) {
client, err := gitlab.NewClient(token.AccessToken)
if err != nil {

View File

@ -25,8 +25,10 @@ func NewMastodonProfileRepository(server string) profile.Repository {
func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.Token) (*domain.Profile, error) {
account, err := mastodon.NewClient(&mastodon.Config{
Server: repo.server,
AccessToken: token.AccessToken,
AccessToken: token.AccessToken,
ClientID: "",
ClientSecret: "",
Server: repo.server,
}).GetAccountCurrentUser(ctx)
if err != nil {
return nil, fmt.Errorf("%s: cannot get account info: %w", ErrPrefix, err)
@ -64,11 +66,9 @@ func (repo *mastodonProfileRepository) Get(ctx context.Context, token *oauth2.To
}
u, err := domain.ParseURL(account.Fields[i].Value)
if err != nil {
continue
if err == nil {
result.URL = append(result.URL, u)
}
result.URL = append(result.URL, u)
}
// WARN(toby3d): Mastodon does not provide an email via API.

View File

@ -2,6 +2,7 @@ package random
import (
"crypto/rand"
"fmt"
"math/big"
"strings"
)
@ -17,32 +18,31 @@ const (
)
func Bytes(length int) ([]byte, error) {
b := make([]byte, length)
bytes := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return nil, err
if _, err := rand.Read(bytes); err != nil {
return nil, fmt.Errorf("cannot read bytes: %w", err)
}
return b, nil
return bytes, nil
}
func String(length int, charsets ...string) (string, error) {
charset := strings.Join(charsets, "")
if charset == "" {
charset = Alphabetic
}
b := make([]byte, length)
bytes := make([]byte, length)
for i := range b {
for i := range bytes {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
return "", fmt.Errorf("failed to randomize bytes: %w", err)
}
b[i] = charset[n.Int64()]
bytes[i] = charset[n.Int64()]
}
return string(b), nil
return string(bytes), nil
}

View File

@ -27,8 +27,7 @@ type (
}
sqlite3SessionRepository struct {
config *domain.Config
db *sqlx.DB
db *sqlx.DB
}
)
@ -57,7 +56,7 @@ const (
WHERE code=$1;`
)
func NewSQLite3SessionRepository(config *domain.Config, db *sqlx.DB) session.Repository {
func NewSQLite3SessionRepository(db *sqlx.DB) session.Repository {
db.MustExec(QueryTable)
return &sqlite3SessionRepository{
@ -74,7 +73,7 @@ func (repo *sqlite3SessionRepository) Create(ctx context.Context, session *domai
}
func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session)
s := new(Session) //nolint: varnamelen // cannot redaclare import
if err := repo.db.GetContext(ctx, s, QueryGet, code); err != nil {
return nil, fmt.Errorf("cannot find session in db: %w", err)
}
@ -86,16 +85,17 @@ func (repo *sqlite3SessionRepository) Get(ctx context.Context, code string) (*do
}
func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code string) (*domain.Session, error) {
s := new(Session)
s := new(Session) //nolint: varnamelen // cannot redaclare import
tx, err := repo.db.Beginx()
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
if err = tx.GetContext(ctx, s, QueryGet, code); err != nil {
//nolint: errcheck // deffered method
defer tx.Rollback()
if errors.Is(err, sql.ErrNoRows) {
@ -106,7 +106,7 @@ func (repo *sqlite3SessionRepository) GetAndDelete(ctx context.Context, code str
}
if _, err = tx.ExecContext(ctx, QueryDelete, code); err != nil {
tx.Rollback()
_ = tx.Rollback()
return nil, fmt.Errorf("cannot remove session from db: %w", err)
}

View File

@ -12,10 +12,9 @@ import (
"source.toby3d.me/website/indieauth/internal/testing/sqltest"
)
//nolint: gochecknoglobals
//nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{
"created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code",
"code_challenge",
"created_at", "client_id", "me", "redirect_uri", "code_challenge_method", "scope", "code", "code_challenge",
}
func TestCreate(t *testing.T) {
@ -40,7 +39,7 @@ func TestCreate(t *testing.T) {
).
WillReturnResult(sqlmock.NewResult(1, 1))
if err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
if err := repository.NewSQLite3SessionRepository(db).
Create(context.TODO(), session); err != nil {
t.Error(err)
}
@ -69,7 +68,7 @@ func TestGet(t *testing.T) {
model.CodeChallenge,
))
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
result, err := repository.NewSQLite3SessionRepository(db).
Get(context.TODO(), session.Code)
if err != nil {
t.Fatal(err)
@ -108,7 +107,7 @@ func TestGetAndDelete(t *testing.T) {
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
result, err := repository.NewSQLite3SessionRepository(domain.TestConfig(t), db).
result, err := repository.NewSQLite3SessionRepository(db).
GetAndDelete(context.TODO(), session.Code)
if err != nil {
t.Fatal(err)

View File

@ -23,7 +23,6 @@ func New(tb testing.TB) (*bolt.DB, func()) {
db, err := bolt.Open(filePath, os.ModePerm, nil)
require.NoError(tb, err)
//nolint: errcheck
return db, func() {
db.Close()
os.Remove(filePath)

View File

@ -3,6 +3,8 @@ package httptest
import (
"crypto/tls"
// used for running tests.
_ "embed"
"net"
"testing"
@ -47,9 +49,8 @@ func New(tb testing.TB, handler http.RequestHandler) (*http.Client, *http.Server
},
}
//nolint: errcheck
return client, server, func() {
server.Shutdown()
_ = server.Shutdown()
}
}

View File

@ -7,7 +7,6 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
)
type Time struct{}
@ -24,17 +23,17 @@ func Open(tb testing.TB) (*sqlx.DB, sqlmock.Sqlmock, func()) {
db, mock, err := sqlmock.New()
if err != nil {
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
xdb := sqlx.NewDb(db, "sqlite")
if err = xdb.Ping(); err != nil {
_ = db.Close()
tb.Fatalf("%+v", err)
tb.Fatal(err)
}
return xdb, mock, func() {
_ = db.Close() //nolint: errcheck
_ = db.Close()
}
}

View File

@ -7,6 +7,7 @@ import (
"github.com/fasthttp/router"
"github.com/goccy/go-json"
"github.com/lestrrat-go/jwx/jwa"
http "github.com/valyala/fasthttp"
"golang.org/x/text/language"
"golang.org/x/text/message"
@ -21,7 +22,7 @@ import (
)
type (
GenerateRequest struct {
TicketGenerateRequest struct {
// The access token should be used when acting on behalf of this URL.
Subject *domain.Me `form:"subject"`
@ -29,7 +30,7 @@ type (
Resource *domain.URL `form:"resource"`
}
ExchangeRequest struct {
TicketExchangeRequest struct {
// A random string that can be redeemed for an access token.
Ticket string `form:"ticket"`
@ -58,21 +59,40 @@ func NewRequestHandler(tickets ticket.UseCase, matcher language.Matcher, config
func (h *RequestHandler) Register(r *router.Router) {
chain := middleware.Chain{
middleware.CSRFWithConfig(middleware.CSRFConfig{
CookieSameSite: http.CookieSameSiteLaxMode,
ContextKey: "csrf",
CookieName: "_csrf",
TokenLookup: "form:_csrf",
CookieSecure: true,
CookieHTTPOnly: true,
Skipper: func(ctx *http.RequestCtx) bool {
matched, _ := path.Match("/ticket*", string(ctx.Path()))
return ctx.IsPost() && matched
},
CookieMaxAge: 0,
CookieSameSite: http.CookieSameSiteLaxMode,
ContextKey: "csrf",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "form:_csrf",
TokenLength: 0,
CookieSecure: true,
CookieHTTPOnly: true,
}),
middleware.JWTWithConfig(middleware.JWTConfig{
AuthScheme: "Bearer",
BeforeFunc: nil,
Claims: nil,
ContextKey: "user",
ErrorHandler: nil,
ErrorHandlerWithContext: nil,
ParseTokenFunc: nil,
SigningKey: []byte(h.config.JWT.Secret),
SigningKeys: nil,
SigningMethod: jwa.SignatureAlgorithm(h.config.JWT.Algorithm),
Skipper: middleware.DefaultSkipper,
SuccessHandler: nil,
TokenLookup: middleware.SourceHeader + ":" + http.HeaderAuthorization,
}),
middleware.LogFmt(),
}
// TODO(toby3d): secure this via JWT middleware
r.GET("/ticket", chain.RequestHandler(h.handleRender))
r.POST("/api/ticket", chain.RequestHandler(h.handleSend))
r.POST("/ticket", chain.RequestHandler(h.handleRedeem))
@ -102,10 +122,11 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
req := new(TicketGenerateRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
@ -119,14 +140,16 @@ func (h *RequestHandler) handleSend(ctx *http.RequestCtx) {
var err error
if ticket.Ticket, err = random.String(h.config.TicketAuth.Length); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return
}
if err = h.tickets.Generate(ctx, ticket); err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return
}
@ -140,10 +163,11 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
req := new(TicketExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
@ -155,7 +179,8 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
_ = encoder.Encode(domain.NewError(domain.ErrorCodeServerError, err.Error(), ""))
return
}
@ -170,7 +195,7 @@ func (h *RequestHandler) handleRedeem(ctx *http.RequestCtx) {
}`, token.AccessToken, token.Scope.String(), token.Me.String()))
}
func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
func (req *TicketGenerateRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
if errors.As(err, indieAuthError) {
@ -203,7 +228,7 @@ func (req *GenerateRequest) bind(ctx *http.RequestCtx) (err error) {
return nil
}
func (req *ExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
func (req *TicketExchangeRequest) bind(ctx *http.RequestCtx) (err error) {
indieAuthError := new(domain.Error)
if err = form.Unmarshal(ctx.Request.PostArgs(), req); err != nil {
if errors.As(err, indieAuthError) {

View File

@ -46,13 +46,13 @@ func TestUpdate(t *testing.T) {
userClient, _, userCleanup := httptest.New(t, userRouter.Handler)
t.Cleanup(userCleanup)
r := router.New()
router := router.New()
delivery.NewRequestHandler(
ucase.NewTicketUseCase(ticketrepo.NewMemoryTicketRepository(new(sync.Map), config), userClient, config),
language.NewMatcher(message.DefaultCatalog.Languages()), config,
).Register(r)
).Register(router)
client, _, cleanup := httptest.New(t, r.Handler)
client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodPost, "https://example.com/ticket", []byte(

View File

@ -59,8 +59,8 @@ func (repo *memoryTicketRepository) GC() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for ts := range ticker.C {
ts := ts.UTC()
for timeStamp := range ticker.C {
timeStamp := timeStamp.UTC()
repo.store.Range(func(key, value interface{}) bool {
k, ok := key.(string)
@ -78,7 +78,7 @@ func (repo *memoryTicketRepository) GC() {
return false
}
if val.CreatedAt.Add(repo.config.Code.Expiry).After(ts) {
if val.CreatedAt.Add(repo.config.Code.Expiry).After(timeStamp) {
return false
}

View File

@ -63,16 +63,17 @@ func (repo *sqlite3TicketRepository) Create(ctx context.Context, t *domain.Ticke
return nil
}
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string) (*domain.Ticket, error) {
func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, rawTicket string) (*domain.Ticket, error) {
tx, err := repo.db.Beginx()
if err != nil {
tx.Rollback()
_ = tx.Rollback()
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
tkt := new(Ticket)
if err = tx.GetContext(ctx, tkt, QueryGet, t); err != nil {
if err = tx.GetContext(ctx, tkt, QueryGet, rawTicket); err != nil {
//nolint: errcheck // deffered method
defer tx.Rollback()
if errors.Is(err, sql.ErrNoRows) {
@ -82,8 +83,8 @@ func (repo *sqlite3TicketRepository) GetAndDelete(ctx context.Context, t string)
return nil, fmt.Errorf("cannot find ticket in db: %w", err)
}
if _, err = tx.ExecContext(ctx, QueryDelete, t); err != nil {
tx.Rollback()
if _, err = tx.ExecContext(ctx, QueryDelete, rawTicket); err != nil {
_ = tx.Rollback()
return nil, fmt.Errorf("cannot remove ticket from db: %w", err)
}

View File

@ -12,8 +12,8 @@ import (
repository "source.toby3d.me/website/indieauth/internal/ticket/repository/sqlite3"
)
//nolint: gochecknoglobals
var tableColumns []string = []string{"created_at", "resource", "subject", "ticket"}
//nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "resource", "subject", "ticket"}
func TestCreate(t *testing.T) {
t.Parallel()

View File

@ -11,7 +11,6 @@ type UseCase interface {
// Redeem transform received ticket into access token.
Redeem(ctx context.Context, ticket *domain.Ticket) (*domain.Token, error)
Exchange(ctx context.Context, ticket string) (*domain.Token, error)
}

View File

@ -14,7 +14,7 @@ import (
)
type (
//nolint: tagliatelle
//nolint: tagliatelle // https://indieauth.net/source/#access-token-response
Response struct {
Me *domain.Me `json:"me"`
Scope domain.Scopes `json:"scope"`
@ -37,11 +37,11 @@ func NewTicketUseCase(tickets ticket.Repository, client *http.Client, config *do
}
}
func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) error {
func (useCase *ticketUseCase) Generate(ctx context.Context, tkt *domain.Ticket) error {
req := http.AcquireRequest()
defer http.ReleaseRequest(req)
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(t.Subject.String())
req.SetRequestURI(tkt.Subject.String())
resp := http.AcquireResponse()
defer http.ReleaseResponse(resp)
@ -65,7 +65,7 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
return ticket.ErrTicketEndpointNotExist
}
if err := useCase.tickets.Create(ctx, t); err != nil {
if err := useCase.tickets.Create(ctx, tkt); err != nil {
return fmt.Errorf("cannot save ticket in store: %w", err)
}
@ -73,9 +73,9 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
req.Header.SetMethod(http.MethodPost)
req.SetRequestURIBytes(ticketEndpoint.RequestURI())
req.Header.SetContentType(common.MIMEApplicationForm)
req.PostArgs().Set("ticket", t.Ticket)
req.PostArgs().Set("subject", t.Subject.String())
req.PostArgs().Set("resource", t.Resource.String())
req.PostArgs().Set("ticket", tkt.Ticket)
req.PostArgs().Set("subject", tkt.Subject.String())
req.PostArgs().Set("resource", tkt.Resource.String())
resp.Reset()
if err := useCase.client.Do(req, resp); err != nil {
@ -85,10 +85,10 @@ func (useCase *ticketUseCase) Generate(ctx context.Context, t *domain.Ticket) er
return nil
}
func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*domain.Token, error) {
func (useCase *ticketUseCase) Redeem(ctx context.Context, tkt *domain.Ticket) (*domain.Token, error) {
req := http.AcquireRequest()
defer http.ReleaseRequest(req)
req.SetRequestURI(t.Resource.String())
req.SetRequestURI(tkt.Resource.String())
req.Header.SetMethod(http.MethodGet)
resp := http.AcquireResponse()
@ -119,7 +119,7 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
req.Header.SetContentType(common.MIMEApplicationForm)
req.Header.Set(http.HeaderAccept, common.MIMEApplicationJSON)
req.PostArgs().Set("grant_type", domain.GrantTypeTicket.String())
req.PostArgs().Set("ticket", t.Ticket)
req.PostArgs().Set("ticket", tkt.Ticket)
resp.Reset()
if err := useCase.client.Do(req, resp); err != nil {
@ -142,19 +142,19 @@ func (useCase *ticketUseCase) Redeem(ctx context.Context, t *domain.Ticket) (*do
}
func (useCase *ticketUseCase) Exchange(ctx context.Context, ticket string) (*domain.Token, error) {
t, err := useCase.tickets.GetAndDelete(ctx, ticket)
tkt, err := useCase.tickets.GetAndDelete(ctx, ticket)
if err != nil {
return nil, fmt.Errorf("cannot find provided ticket: %w", err)
}
token, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm,
Expiration: useCase.config.JWT.Expiry,
// TODO(toby3d): Issuer: &domain.ClientID{},
NonceLength: useCase.config.JWT.NonceLength,
Expiration: useCase.config.JWT.Expiry,
Scope: domain.Scopes{domain.ScopeRead},
Issuer: nil,
Subject: tkt.Subject,
Secret: []byte(useCase.config.JWT.Secret),
Subject: t.Subject,
Algorithm: useCase.config.JWT.Algorithm,
NonceLength: useCase.config.JWT.NonceLength,
})
if err != nil {
return nil, fmt.Errorf("cannot generate a new access token: %w", err)

View File

@ -22,12 +22,12 @@ func TestRedeem(t *testing.T) {
token := domain.TestToken(t)
ticket := domain.TestTicket(t)
r := router.New()
r.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) {
router := router.New()
router.GET(string(ticket.Resource.Path()), func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMETextHTMLCharsetUTF8, `<link rel="token_endpoint" href="`+
ticket.Subject.String()+`token">`)
})
r.POST("/token", func(ctx *http.RequestCtx) {
router.POST("/token", func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, fmt.Sprintf(`{
"token_type": "Bearer",
"access_token": "%s",
@ -36,7 +36,7 @@ func TestRedeem(t *testing.T) {
}`, token.AccessToken, token.Scope.String(), token.Me.String()))
})
client, _, cleanup := httptest.New(t, r.Handler)
client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
result, err := ucase.NewTicketUseCase(nil, client, domain.TestConfig(t)).

View File

@ -17,7 +17,7 @@ import (
)
type (
ExchangeRequest struct {
TokenExchangeRequest struct {
ClientID *domain.ClientID `form:"client_id"`
RedirectURI *domain.URL `form:"redirect_uri"`
GrantType domain.GrantType `form:"grant_type"`
@ -25,40 +25,40 @@ type (
CodeVerifier string `form:"code_verifier"`
}
RevokeRequest struct {
TokenRevokeRequest struct {
Action domain.Action `form:"action"`
Token string `form:"token"`
}
TicketRequest struct {
TokenTicketRequest struct {
Action domain.Action `form:"action"`
Ticket string `form:"ticket"`
}
//nolint: tagliatelle
ExchangeResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
Me string `json:"me"`
Profile *ProfileResponse `json:"profile,omitempty"`
//nolint: tagliatelle // https://indieauth.net/source/#access-token-response
TokenExchangeResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
Me string `json:"me"`
Profile *TokenProfileResponse `json:"profile,omitempty"`
}
ProfileResponse struct {
TokenProfileResponse struct {
Name string `json:"name,omitempty"`
URL *domain.URL `json:"url,omitempty"`
Photo *domain.URL `json:"photo,omitempty"`
Email *domain.Email `json:"email,omitempty"`
}
//nolint: tagliatelle
VerificationResponse struct {
//nolint: tagliatelle // https://indieauth.net/source/#access-token-verification-response
TokenVerificationResponse struct {
Me *domain.Me `json:"me"`
ClientID *domain.ClientID `json:"client_id"`
Scope domain.Scopes `json:"scope"`
}
RevocationResponse struct{}
TokenRevocationResponse struct{}
RequestHandler struct {
tokens token.UseCase
@ -88,11 +88,12 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
t, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
tkt, err := h.tokens.Verify(ctx, strings.TrimPrefix(string(ctx.Request.Header.Peek(http.HeaderAuthorization)),
"Bearer "))
if err != nil || t == nil {
if err != nil || tkt == nil {
ctx.SetStatusCode(http.StatusUnauthorized)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeUnauthorizedClient,
err.Error(),
"https://indieauth.net/source/#access-token-verification",
@ -101,10 +102,10 @@ func (h *RequestHandler) handleValidate(ctx *http.RequestCtx) {
return
}
_ = encoder.Encode(&VerificationResponse{
ClientID: t.ClientID,
Me: t.Me,
Scope: t.Scope,
_ = encoder.Encode(&TokenVerificationResponse{
ClientID: tkt.ClientID,
Me: tkt.Me,
Scope: tkt.Scope,
})
}
@ -120,7 +121,8 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
action, err := domain.ParseAction(string(ctx.PostArgs().Peek("action")))
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"",
@ -138,15 +140,17 @@ func (h *RequestHandler) handleAction(ctx *http.RequestCtx) {
}
}
//nolint: funlen
func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
ctx.SetContentType(common.MIMEApplicationJSONCharsetUTF8)
encoder := json.NewEncoder(ctx)
req := new(ExchangeRequest)
req := new(TokenExchangeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
@ -159,7 +163,8 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
})
if err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
@ -168,20 +173,21 @@ func (h *RequestHandler) handleExchange(ctx *http.RequestCtx) {
return
}
resp := &ExchangeResponse{
resp := &TokenExchangeResponse{
AccessToken: token.AccessToken,
TokenType: "Bearer",
Scope: token.Scope.String(),
Me: token.Me.String(),
Profile: nil,
}
if profile == nil {
encoder.Encode(resp)
_ = encoder.Encode(resp)
return
}
resp.Profile = new(ProfileResponse)
resp.Profile = new(TokenProfileResponse)
if len(profile.Name) > 0 {
resp.Profile.Name = profile.Name[0]
@ -208,17 +214,19 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(RevokeRequest)
req := new(TokenRevokeRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
if err := h.tokens.Revoke(ctx, req.Token); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"",
@ -227,7 +235,7 @@ func (h *RequestHandler) handleRevoke(ctx *http.RequestCtx) {
return
}
_ = encoder.Encode(&RevocationResponse{})
_ = encoder.Encode(&TokenRevocationResponse{})
}
func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
@ -236,18 +244,20 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
encoder := json.NewEncoder(ctx)
req := new(TicketRequest)
req := new(TokenTicketRequest)
if err := req.bind(ctx); err != nil {
ctx.SetStatusCode(http.StatusBadRequest)
encoder.Encode(err)
_ = encoder.Encode(err)
return
}
t, err := h.tickets.Exchange(ctx, req.Ticket)
tkn, err := h.tickets.Exchange(ctx, req.Ticket)
if err != nil {
ctx.SetStatusCode(http.StatusInternalServerError)
encoder.Encode(domain.NewError(
_ = encoder.Encode(domain.NewError(
domain.ErrorCodeInvalidRequest,
err.Error(),
"https://indieauth.net/source/#request",
@ -256,15 +266,16 @@ func (h *RequestHandler) handleTicket(ctx *http.RequestCtx) {
return
}
encoder.Encode(ExchangeResponse{
AccessToken: t.AccessToken,
_ = encoder.Encode(TokenExchangeResponse{
AccessToken: tkn.AccessToken,
TokenType: "Bearer",
Scope: t.Scope.String(),
Me: t.Me.String(),
Scope: tkn.Scope.String(),
Me: tkn.Me.String(),
Profile: nil,
})
}
func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
func (r *TokenExchangeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) {
@ -281,7 +292,7 @@ func (r *ExchangeRequest) bind(ctx *http.RequestCtx) error {
return nil
}
func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
func (r *TokenRevokeRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) {
@ -298,7 +309,7 @@ func (r *RevokeRequest) bind(ctx *http.RequestCtx) error {
return nil
}
func (r *TicketRequest) bind(ctx *http.RequestCtx) error {
func (r *TokenTicketRequest) bind(ctx *http.RequestCtx) error {
indieAuthError := new(domain.Error)
if err := form.Unmarshal(ctx.PostArgs(), r); err != nil {
if errors.As(err, indieAuthError) {

View File

@ -36,7 +36,7 @@ func TestVerification(t *testing.T) {
config := domain.TestConfig(t)
token := domain.TestToken(t)
r := router.New()
router := router.New()
// TODO(toby3d): provide tickets
delivery.NewRequestHandler(
tokenucase.NewTokenUseCase(
@ -49,9 +49,9 @@ func TestVerification(t *testing.T) {
new(http.Client),
config,
),
).Register(r)
).Register(router)
client, _, cleanup := httptest.New(t, r.Handler)
client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodGet, "https://app.example.com/token", nil)
@ -65,7 +65,7 @@ func TestVerification(t *testing.T) {
require.NoError(t, client.Do(req, resp))
assert.Equal(t, http.StatusOK, resp.StatusCode())
result := new(delivery.VerificationResponse)
result := new(delivery.TokenVerificationResponse)
require.NoError(t, json.Unmarshal(resp.Body(), result))
assert.Equal(t, token.ClientID.String(), result.ClientID.String())
assert.Equal(t, token.Me.String(), result.Me.String())
@ -80,7 +80,7 @@ func TestRevocation(t *testing.T) {
tokens := tokenrepo.NewMemoryTokenRepository(store)
accessToken := domain.TestToken(t)
r := router.New()
router := router.New()
delivery.NewRequestHandler(
tokenucase.NewTokenUseCase(
tokens,
@ -92,9 +92,9 @@ func TestRevocation(t *testing.T) {
new(http.Client),
config,
),
).Register(r)
).Register(router)
client, _, cleanup := httptest.New(t, r.Handler)
client, _, cleanup := httptest.New(t, router.Handler)
t.Cleanup(cleanup)
req := httptest.NewRequest(http.MethodPost, "https://app.example.com/token", nil)

View File

@ -62,8 +62,8 @@ func (repo *sqlite3TokenRepository) Create(ctx context.Context, accessToken *dom
}
func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string) (*domain.Token, error) {
t := new(Token)
if err := repo.db.GetContext(ctx, t, QueryGet, accessToken); err != nil {
tkn := new(Token)
if err := repo.db.GetContext(ctx, tkn, QueryGet, accessToken); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, token.ErrNotExist
}
@ -72,7 +72,7 @@ func (repo *sqlite3TokenRepository) Get(ctx context.Context, accessToken string)
}
result := new(domain.Token)
t.Populate(result)
tkn.Populate(result)
return result, nil
}

View File

@ -12,14 +12,15 @@ import (
repository "source.toby3d.me/website/indieauth/internal/token/repository/sqlite3"
)
//nolint: gochecknoglobals
var tableColumns []string = []string{"created_at", "access_token", "client_id", "me", "scope"}
//nolint: gochecknoglobals // slices cannot be contants
var tableColumns = []string{"created_at", "access_token", "client_id", "me", "scope"}
func TestCreate(t *testing.T) {
t.Parallel()
token := domain.TestToken(t)
model := repository.NewToken(token)
db, mock, cleanup := sqltest.Open(t)
t.Cleanup(cleanup)
@ -44,6 +45,7 @@ func TestGet(t *testing.T) {
token := domain.TestToken(t)
model := repository.NewToken(token)
db, mock, cleanup := sqltest.Open(t)
t.Cleanup(cleanup)

View File

@ -56,7 +56,7 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
return nil, nil, token.ErrEmptyScope
}
t, err := domain.NewToken(domain.NewTokenOptions{
tkn, err := domain.NewToken(domain.NewTokenOptions{
Algorithm: useCase.config.JWT.Algorithm,
Expiration: useCase.config.JWT.Expiry,
Issuer: session.ClientID,
@ -70,14 +70,14 @@ func (useCase *tokenUseCase) Exchange(ctx context.Context, opts token.ExchangeOp
}
if !session.Scope.Has(domain.ScopeProfile) {
return t, nil, nil
return tkn, nil, nil
}
p := new(domain.Profile)
// TODO(toby3d): if session.Scope.Has(domain.ScopeEmail) {}
return t, p, nil
return tkn, p, nil
}
func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*domain.Token, error) {
@ -90,23 +90,26 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
return nil, token.ErrRevoke
}
t, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
tkn, err := jwt.ParseString(accessToken, jwt.WithVerify(jwa.SignatureAlgorithm(useCase.config.JWT.Algorithm),
[]byte(useCase.config.JWT.Secret)))
if err != nil {
return nil, fmt.Errorf("cannot parse JWT token: %w", err)
}
if err = jwt.Validate(t); err != nil {
if err = jwt.Validate(tkn); err != nil {
return nil, fmt.Errorf("cannot validate JWT token: %w", err)
}
result := &domain.Token{
Scope: nil,
ClientID: nil,
Me: nil,
AccessToken: accessToken,
}
result.ClientID, _ = domain.ParseClientID(t.Issuer())
result.Me, _ = domain.ParseMe(t.Subject())
result.ClientID, _ = domain.ParseClientID(tkn.Issuer())
result.Me, _ = domain.ParseMe(tkn.Subject())
if scope, ok := t.Get("scope"); ok {
if scope, ok := tkn.Get("scope"); ok {
result.Scope, _ = scope.(domain.Scopes)
}
@ -114,12 +117,12 @@ func (useCase *tokenUseCase) Verify(ctx context.Context, accessToken string) (*d
}
func (useCase *tokenUseCase) Revoke(ctx context.Context, accessToken string) error {
t, err := useCase.Verify(ctx, accessToken)
tkn, err := useCase.Verify(ctx, accessToken)
if err != nil {
return fmt.Errorf("cannot verify token: %w", err)
}
if err = useCase.tokens.Create(ctx, t); err != nil {
if err = useCase.tokens.Create(ctx, tkn); err != nil {
return fmt.Errorf("cannot save token in database: %w", err)
}

View File

@ -52,31 +52,38 @@ func (repo *httpUserRepository) Get(ctx context.Context, me *domain.Me) (*domain
// TODO(toby3d): handle error here?
resolvedMe, _ := domain.ParseMe(string(resp.Header.Peek(http.HeaderLocation)))
u := &domain.User{
Me: resolvedMe,
user := &domain.User{
AuthorizationEndpoint: nil,
IndieAuthMetadata: nil,
Me: resolvedMe,
Micropub: nil,
Microsub: nil,
Profile: &domain.Profile{
Name: make([]string, 0),
URL: make([]*domain.URL, 0),
Photo: make([]*domain.URL, 0),
Email: make([]*domain.Email, 0),
Name: make([]string, 0),
Photo: make([]*domain.URL, 0),
URL: make([]*domain.URL, 0),
},
TicketEndpoint: nil,
TokenEndpoint: nil,
}
metadata, err := util.ExtractMetadata(resp, repo.client)
if err == nil && metadata != nil {
u.AuthorizationEndpoint = metadata.AuthorizationEndpoint
u.Micropub = metadata.Micropub
u.Microsub = metadata.Microsub
u.TicketEndpoint = metadata.TicketEndpoint
u.TokenEndpoint = metadata.TokenEndpoint
if metadata, err := util.ExtractMetadata(resp, repo.client); err == nil {
user.AuthorizationEndpoint = metadata.AuthorizationEndpoint
user.Micropub = metadata.Micropub
user.Microsub = metadata.Microsub
user.TicketEndpoint = metadata.TicketEndpoint
user.TokenEndpoint = metadata.TokenEndpoint
}
extractUser(u, resp)
extractProfile(u.Profile, resp)
extractUser(user, resp)
extractProfile(user.Profile, resp)
return u, nil
return user, nil
}
//nolint: cyclop
func extractUser(dst *domain.User, src *http.Response) {
if dst.IndieAuthMetadata != nil {
if endpoints := util.ExtractEndpoints(src, relIndieAuthMetadata); len(endpoints) > 0 {
@ -115,14 +122,12 @@ func extractUser(dst *domain.User, src *http.Response) {
}
}
//nolint: cyclop
func extractProfile(dst *domain.Profile, src *http.Response) {
for _, name := range util.ExtractProperty(src, hCard, propertyName) {
n, ok := name.(string)
if !ok {
continue
if n, ok := name.(string); ok {
dst.Name = append(dst.Name, n)
}
dst.Name = append(dst.Name, n)
}
for _, rawEmail := range util.ExtractProperty(src, hCard, propertyEmail) {
@ -131,26 +136,20 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
continue
}
e, err := domain.ParseEmail(email)
if err != nil {
continue
if e, err := domain.ParseEmail(email); err == nil {
dst.Email = append(dst.Email, e)
}
dst.Email = append(dst.Email, e)
}
for _, rawUrl := range util.ExtractProperty(src, hCard, propertyURL) {
url, ok := rawUrl.(string)
for _, rawURL := range util.ExtractProperty(src, hCard, propertyURL) {
url, ok := rawURL.(string)
if !ok {
continue
}
u, err := domain.ParseURL(url)
if err != nil {
continue
if u, err := domain.ParseURL(url); err == nil {
dst.URL = append(dst.URL, u)
}
dst.URL = append(dst.URL, u)
}
for _, rawPhoto := range util.ExtractProperty(src, hCard, propertyPhoto) {
@ -159,11 +158,8 @@ func extractProfile(dst *domain.Profile, src *http.Response) {
continue
}
p, err := domain.ParseURL(photo)
if err != nil {
continue
if p, err := domain.ParseURL(photo); err == nil {
dst.Photo = append(dst.Photo, p)
}
dst.Photo = append(dst.Photo, p)
}
}

View File

@ -69,8 +69,8 @@ func TestGet(t *testing.T) {
func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
tb.Helper()
r := router.New()
r.GET("/", func(ctx *http.RequestCtx) {
router := router.New()
router.GET("/", func(ctx *http.RequestCtx) {
ctx.Response.Header.Set(http.HeaderLink, strings.Join([]string{
`<` + user.AuthorizationEndpoint.String() + `>; rel="authorization_endpoint"`,
`<` + user.IndieAuthMetadata.String() + `>; rel="indieauth-metadata"`,
@ -83,7 +83,7 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
testBody, user.Name[0], user.URL[0].String(), user.Photo[0].String(), user.Email[0],
))
})
r.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) {
router.GET(string(user.IndieAuthMetadata.Path()), func(ctx *http.RequestCtx) {
ctx.SuccessString(common.MIMEApplicationJSONCharsetUTF8, `{
"issuer": "`+user.Me.String()+`",
"authorization_endpoint": "`+user.AuthorizationEndpoint.String()+`",
@ -91,5 +91,5 @@ func testHandler(tb testing.TB, user *domain.User) http.RequestHandler {
}`)
})
return r.Handler
return router.Handler
}

View File

@ -3,6 +3,7 @@ package util
import (
"bytes"
"encoding/json"
"fmt"
"net/url"
"strings"
@ -13,6 +14,12 @@ import (
"source.toby3d.me/website/indieauth/internal/domain"
)
var ErrEndpointNotExist = domain.NewError(
domain.ErrorCodeServerError,
"cannot found any endpoints",
"https://indieauth.net/source/#discovery-0",
)
func ExtractEndpoints(resp *http.Response, rel string) []*domain.URL {
results := make([]*domain.URL, 0)
@ -39,7 +46,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
u := http.AcquireURI()
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(link.URL)); err != nil {
return nil, err
return nil, fmt.Errorf("cannot parse header endpoint: %w", err)
}
results = append(results, &domain.URL{URI: u})
@ -51,7 +58,7 @@ func ExtractEndpointsFromHeader(resp *http.Response, rel string) ([]*domain.URL,
func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, error) {
endpoints, ok := microformats.Parse(bytes.NewReader(resp.Body()), nil).Rels[rel]
if !ok || len(endpoints) == 0 {
return nil, nil
return nil, ErrEndpointNotExist
}
results := make([]*domain.URL, 0)
@ -59,7 +66,7 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
for i := range endpoints {
u := http.AcquireURI()
if err := u.Parse(resp.Header.Peek(http.HeaderHost), []byte(endpoints[i])); err != nil {
return nil, err
return nil, fmt.Errorf("cannot parse body endpoint: %w", err)
}
results = append(results, &domain.URL{URI: u})
@ -71,29 +78,30 @@ func ExtractEndpointsFromBody(resp *http.Response, rel string) ([]*domain.URL, e
func ExtractMetadata(resp *http.Response, client *http.Client) (*domain.Metadata, error) {
endpoints := ExtractEndpoints(resp, "indieauth-metadata")
if len(endpoints) == 0 {
return nil, nil
return nil, ErrEndpointNotExist
}
_, body, err := client.Get(nil, endpoints[len(endpoints)-1].String())
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch metadata endpoint configuration: %w", err)
}
result := new(domain.Metadata)
if err = json.Unmarshal(body, result); err != nil {
return nil, err
return nil, fmt.Errorf("cannot unmarshal emtadata configuration: %w", err)
}
return result, nil
}
func ExtractProperty(resp *http.Response, t, key string) []interface{} {
func ExtractProperty(resp *http.Response, itemType, key string) []interface{} {
//nolint: exhaustivestruct // only Host part in url.URL is needed
data := microformats.Parse(bytes.NewReader(resp.Body()), &url.URL{
Host: string(resp.Header.Peek(http.HeaderHost)),
})
for _, item := range data.Items {
if !contains(item.Type, t) {
if !contains(item.Type, itemType) {
continue
}

24
main.go
View File

@ -55,7 +55,8 @@ import (
const (
DefaultCacheDuration time.Duration = 8760 * time.Hour // NOTE(toby3d): year
DefaultPort int = 3000
DefaultReadTimeout time.Duration = 5 * time.Second
DefaultWriteTimeout time.Duration = 10 * time.Second
)
//nolint: gochecknoglobals
@ -126,7 +127,7 @@ func init() {
client.RedirectURI = []*domain.URL{redirectURI}
}
//nolint: funlen
//nolint: funlen, cyclop // "god object" and the entry point of all modules
func main() {
var (
tokens token.Repository
@ -146,7 +147,7 @@ func main() {
}
tokens = tokensqlite3repo.NewSQLite3TokenRepository(store)
sessions = sessionsqlite3repo.NewSQLite3SessionRepository(config, store)
sessions = sessionsqlite3repo.NewSQLite3SessionRepository(store)
tickets = ticketsqlite3repo.NewSQLite3TicketRepository(store, config)
case "memory":
store := new(sync.Map)
@ -160,12 +161,11 @@ func main() {
go sessions.GC()
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
//nolint: exhaustivestruct // too many options
httpClient := &http.Client{
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
MaxConnDuration: 10 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxConnWaitTimeout: 10 * time.Second,
Name: fmt.Sprintf("%s/0.1 (+%s)", config.Name, config.Server.GetAddress()),
ReadTimeout: DefaultReadTimeout,
WriteTimeout: DefaultWriteTimeout,
}
ticketService := ticketucase.NewTicketUseCase(tickets, httpClient, config)
tokenService := tokenucase.NewTokenUseCase(tokens, sessions, config)
@ -187,6 +187,7 @@ func main() {
Matcher: matcher,
Config: config,
}).Register(r)
//nolint: exhaustivestruct// too many options
r.ServeFilesCustom(path.Join(config.Server.StaticURLPrefix, "{filepath:*}"), &http.FS{
Root: config.Server.StaticRootPath,
CacheDuration: DefaultCacheDuration,
@ -200,11 +201,12 @@ func main() {
r.GET("/debug/pprof/{filepath:*}", pprofhandler.PprofHandler)
}
//nolint: exhaustivestruct
server := &http.Server{
Name: fmt.Sprintf("IndieAuth/0.1 (+%s)", config.Server.GetAddress()),
Handler: r.Handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
ReadTimeout: DefaultReadTimeout,
WriteTimeout: DefaultWriteTimeout,
DisableKeepalive: true,
ReduceMemoryUsage: true,
SecureErrorLogMessage: true,
@ -212,7 +214,7 @@ func main() {
}
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL)
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
if cpuProfilePath != "" {
cpuProfile, err := os.Create(cpuProfilePath)