diff --git a/go.mod b/go.mod index 5972d7d..81ddd9b 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,22 @@ module source.toby3d.me/toby3d/middleware go 1.17 -require github.com/valyala/fasthttp v1.31.0 +require ( + github.com/lestrrat-go/jwx v1.2.13 + github.com/valyala/fasthttp v1.31.0 +) require ( github.com/andybalholm/brotli v1.0.4 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect + github.com/goccy/go-json v0.8.1 // indirect github.com/klauspost/compress v1.13.6 // indirect + github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect + github.com/lestrrat-go/blackmagic v1.0.0 // indirect + github.com/lestrrat-go/httpcc v1.0.0 // indirect + github.com/lestrrat-go/iter v1.0.1 // indirect + github.com/lestrrat-go/option v1.0.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect ) diff --git a/go.sum b/go.sum index 8769850..56726c5 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,66 @@ github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d h1:1iy2qD6JEhHKKhUOA9IWs7mjco7lnw2qx8FsRI2wirE= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d/go.mod h1:tmAIfUFEirG/Y8jhZ9M+h36obRZAk/1fcSpXwAVlfqE= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= +github.com/goccy/go-json v0.8.1 h1:4/Wjm0JIJaTDm8K1KcGrLHJoa8EsJ13YWeX+6Kfq6uI= +github.com/goccy/go-json v0.8.1/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= +github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= +github.com/lestrrat-go/blackmagic v1.0.0 h1:XzdxDbuQTz0RZZEmdU7cnQxUtFUzgCSPq8RCz4BxIi4= +github.com/lestrrat-go/blackmagic v1.0.0/go.mod h1:TNgH//0vYSs8VXDCfkZLgIrVTTXQELZffUV0tz3MtdQ= +github.com/lestrrat-go/httpcc v1.0.0 h1:FszVC6cKfDvBKcJv646+lkh4GydQg2Z29scgUfkOpYc= +github.com/lestrrat-go/httpcc v1.0.0/go.mod h1:tGS/u00Vh5N6FHNkExqGGNId8e0Big+++0Gf8MBnAvE= +github.com/lestrrat-go/iter v1.0.1 h1:q8faalr2dY6o8bV45uwrxq12bRa1ezKrB6oM9FUgN4A= +github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= +github.com/lestrrat-go/jwx v1.2.13 h1:GxuOPfAz4+nzL98WaKxBxEEZ9b7qmyDetMGfBm9yVvE= +github.com/lestrrat-go/jwx v1.2.13/go.mod h1:3Q3Re8TaOcVTdpx4Tvz++OWmryDklihTDqrrwQiyS2A= +github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4= +github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.31.0 h1:lrauRLII19afgCs2fnWRJ4M5IkV0lo2FqA61uGkNBfE= github.com/valyala/fasthttp v1.31.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwt.go b/jwt.go new file mode 100644 index 0000000..b978bb8 --- /dev/null +++ b/jwt.go @@ -0,0 +1,351 @@ +package middleware + +import ( + "bytes" + "errors" + "fmt" + "strings" + + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + http "github.com/valyala/fasthttp" +) + +type ( + // JWTConfig defines the config for JWT middleware. + JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before + // the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a + // valid token. + SuccessHandler JWTSuccessHandler + + // ErrorHandler defines a function which is executed for an + // invalid token. It may be used to define a custom JWT error. + ErrorHandler JWTErrorHandler + + // ErrorHandlerWithContext is almost identical to ErrorHandler, + // but it's passed the current context. + ErrorHandlerWithContext JWTErrorHandlerWithContext + + // Signing key to validate token. + // This is one of the three options to provide a token + // validation key. The order of precedence is a user-defined + // KeyFunc, SigningKeys and SigningKey. + // + // Required if neither user-defined KeyFunc nor SigningKeys is + // provided. + SigningKey interface{} + + // Map of signing keys to validate token with kid field usage. + // This is one of the three options to provide a token + // validation key. The order of precedence is a user-defined + // KeyFunc, SigningKeys and SigningKey. + // + // Required if neither user-defined KeyFunc nor SigningKey is + // provided. + SigningKeys map[string]interface{} + + // Signing method used to check the token's signing algorithm. + // + // Optional. Default value HS256. + SigningMethod jwa.SignatureAlgorithm + + // Context key to store user information from the token into + // context. + // + // Optional. Default value "user". + ContextKey string + + // Claims are extendable claims data defining token content. + // Used by default ParseTokenFunc implementation. Not used if + // custom ParseTokenFunc is set. + // + // Optional. Default value []jwt.ClaimPair + Claims []jwt.ClaimPair + + // TokenLookup is a string in the form of ":" or + // ":,:" that is used to extract + // token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiply sources example: + // - "header: Authorization,cookie: myowncookie" + TokenLookup string + + // AuthScheme to be used in the Authorization header. + // + // Optional. Default value "Bearer". + AuthScheme string + + // KeyFunc defines a user-defined function that supplies the + // public key for a token validation. The function shall take + // care of verifying the signing algorithm and selecting the + // proper key. A user-defined KeyFunc can be useful if tokens + // are issued by an external party. Used by default + // ParseTokenFunc implementation. + // + // When a user-defined KeyFunc is provided, SigningKey, + // SigningKeys, and SigningMethod are ignored. This is one of + // the three options to provide a token validation key. The + // order of precedence is a user-defined KeyFunc, SigningKeys + // and SigningKey. Required if neither SigningKeys nor + // SigningKey is provided. Not used if custom ParseTokenFunc is + // set. Default to an internal implementation verifying the + // signing algorithm and selecting the proper key. + // KeyFunc jwt.Keyfunc + + // ParseTokenFunc defines a user-defined function that parses + // token from given auth. Returns an error when token parsing + // fails or parsed token is invalid. Defaults to implementation + // using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error) + } + + // JWTSuccessHandler defines a function which is executed for a valid + // token. + JWTSuccessHandler func(ctx *http.RequestCtx) + + // JWTErrorHandler defines a function which is executed for an invalid + // token. + JWTErrorHandler func(err error) + + // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, + // but it's passed the current context. + JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx) + + jwtExtractor func(ctx *http.RequestCtx) ([]byte, error) +) + +// Variants of token sources. +const ( + SourceCookie = "cookie" + SourceForm = "form" + SourceHeader = "header" + SourceParam = "param" + SourceQuery = "query" +) + +// DefaultJWTConfig is the default JWT auth middleware config. +//nolint: gochecknoglobals +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + SigningMethod: "HS256", + ContextKey: "user", + TokenLookup: SourceHeader + ":" + http.HeaderAuthorization, + AuthScheme: "Bearer", + Claims: []jwt.ClaimPair{}, +} + +// Possible token errors. +var ( + ErrJWTMissing = errors.New("jwt: missing or malformed jwt") + ErrJWTInvalid = errors.New("jwt: invalid or expired jwt") +) + +// JWT returns a JSON Web Token (JWT) auth middleware. +// +// * For valid token, it sets the user in context and calls next handler. +// * For invalid token, it returns "401 - Unauthorized" error. +// * For missing token, it returns "400 - Bad Request" error. +func JWT(key interface{}) Interceptor { + return JWTWithConfig(DefaultJWTConfig) +} + +// JWTWithConfig returns a JWT auth middleware with config. +//nolint: funlen, gocognit +func JWTWithConfig(config JWTConfig) Interceptor { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultJWTConfig.Skipper + } + + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.ParseTokenFunc == nil { + panic("jwt: middleware requires signing key") + } + + if config.SigningMethod == "" { + config.SigningMethod = DefaultJWTConfig.SigningMethod + } + + if config.ContextKey == "" { + config.ContextKey = DefaultJWTConfig.ContextKey + } + + if config.Claims == nil { + config.Claims = DefaultJWTConfig.Claims + } + + if config.TokenLookup == "" { + config.TokenLookup = DefaultJWTConfig.TokenLookup + } + + if config.AuthScheme == "" { + config.AuthScheme = DefaultJWTConfig.AuthScheme + } + + if config.ParseTokenFunc == nil { + config.ParseTokenFunc = config.defaultParseToken + } + + extractors := make([]jwtExtractor, 0) + sources := strings.Split(config.TokenLookup, ",") + + for _, source := range sources { + parts := strings.Split(source, ":") + + switch parts[0] { + case SourceQuery: + extractors = append(extractors, jwtFromQuery(parts[1])) + case SourceParam: + extractors = append(extractors, jwtFromParam(parts[1])) + case SourceCookie: + extractors = append(extractors, jwtFromCookie(parts[1])) + case SourceForm: + extractors = append(extractors, jwtFromForm(parts[1])) + case SourceHeader: + extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) + } + } + + return func(ctx *http.RequestCtx, next http.RequestHandler) { + if config.Skipper(ctx) { + next(ctx) + + return + } + + if config.BeforeFunc != nil { + config.BeforeFunc(ctx) + } + + var ( + auth []byte + err error + ) + + for _, extractor := range extractors { + if auth, err = extractor(ctx); err == nil { + break + } + } + + if err != nil { + if config.ErrorHandler != nil { + config.ErrorHandler(err) + + return + } + + if config.ErrorHandlerWithContext != nil { + config.ErrorHandlerWithContext(err, ctx) + + return + } + + ctx.Error(err.Error(), http.StatusInternalServerError) + + return + } + + if token, err := config.ParseTokenFunc(auth, ctx); err == nil { + ctx.SetUserValue(config.ContextKey, token) + + if config.SuccessHandler != nil { + config.SuccessHandler(ctx) + } + + next(ctx) + + return + } + + if config.ErrorHandler != nil { + config.ErrorHandler(err) + + return + } + + if config.ErrorHandlerWithContext != nil { + config.ErrorHandlerWithContext(err, ctx) + + return + } + + ctx.Error(ErrJWTInvalid.Error(), http.StatusUnauthorized) + } +} + +func (config *JWTConfig) defaultParseToken(auth []byte, ctx *http.RequestCtx) (interface{}, error) { + token, err := jwt.Parse( + auth, jwt.WithVerify(config.SigningMethod, config.SigningKey), jwt.WithValidate(true), + ) + if err != nil { + return nil, fmt.Errorf("cannot parse JWT token: %w", err) + } + + return token, nil +} + +func jwtFromHeader(header, authScheme string) jwtExtractor { + return func(ctx *http.RequestCtx) ([]byte, error) { + auth := ctx.Request.Header.Peek(header) + l := len(authScheme) + + if len(auth) > l+1 && bytes.EqualFold(auth[:l], []byte(authScheme)) { + return auth[l+1:], nil + } + + return nil, ErrJWTMissing + } +} + +func jwtFromQuery(param string) jwtExtractor { + return func(ctx *http.RequestCtx) ([]byte, error) { + if ctx.QueryArgs().Has(param) { + return nil, ErrJWTMissing + } + + return ctx.QueryArgs().Peek(param), nil + } +} + +func jwtFromParam(param string) jwtExtractor { + return func(ctx *http.RequestCtx) ([]byte, error) { + if !ctx.PostArgs().Has(param) { + return nil, ErrJWTMissing + } + + return ctx.PostArgs().Peek(param), nil + } +} + +func jwtFromCookie(name string) jwtExtractor { + return func(ctx *http.RequestCtx) ([]byte, error) { + if cookie := ctx.Request.Header.Cookie(name); cookie != nil { + return cookie, nil + } + + return nil, ErrJWTMissing + } +} + +func jwtFromForm(name string) jwtExtractor { + return func(ctx *http.RequestCtx) ([]byte, error) { + if field := ctx.FormValue(name); field != nil { + return field, nil + } + + return nil, ErrJWTMissing + } +}