♻️ Refactored value extractors for middlewares
This commit is contained in:
parent
c5f0bbc687
commit
a214a7e9e8
234
csrf.go
234
csrf.go
|
@ -5,105 +5,104 @@ import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
http "github.com/valyala/fasthttp"
|
http "github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
// CSRFConfig defines the config for CSRF middleware.
|
||||||
// CSRFConfig defines the config for CSRF middleware.
|
type CSRFConfig struct {
|
||||||
CSRFConfig struct {
|
// Skipper defines a function to skip middleware.
|
||||||
// Skipper defines a function to skip middleware.
|
Skipper Skipper
|
||||||
Skipper Skipper
|
|
||||||
|
|
||||||
// Max age (in seconds) of the CSRF cookie.
|
// Indicates SameSite mode of the CSRF cookie.
|
||||||
//
|
//
|
||||||
// Optional. Default value 86400 (24hr).
|
// Optional. Default value SameSiteDefaultMode.
|
||||||
CookieMaxAge time.Duration
|
CookieSameSite http.CookieSameSite
|
||||||
|
|
||||||
// Indicates SameSite mode of the CSRF cookie.
|
// Context key to store generated CSRF token into context.
|
||||||
//
|
//
|
||||||
// Optional. Default value SameSiteDefaultMode.
|
// Optional. Default value "csrf".
|
||||||
CookieSameSite http.CookieSameSite
|
ContextKey string
|
||||||
|
|
||||||
// Context key to store generated CSRF token into context.
|
// Domain of the CSRF cookie.
|
||||||
//
|
//
|
||||||
// Optional. Default value "csrf".
|
// Optional. Default value none.
|
||||||
ContextKey string
|
CookieDomain string
|
||||||
|
|
||||||
// Domain of the CSRF cookie.
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||||
//
|
//
|
||||||
// Optional. Default value none.
|
// Optional. Default value "csrf".
|
||||||
CookieDomain string
|
CookieName string
|
||||||
|
|
||||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
// Path of the CSRF cookie.
|
||||||
//
|
//
|
||||||
// Optional. Default value "csrf".
|
// Optional. Default value none.
|
||||||
CookieName string
|
CookiePath string
|
||||||
|
|
||||||
// Path of the CSRF cookie.
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||||
//
|
// to extract token from the request.
|
||||||
// Optional. Default value none.
|
// Possible values:
|
||||||
CookiePath string
|
// - "header:<name>"
|
||||||
|
// - "form:<name>"
|
||||||
|
// - "query:<name>"
|
||||||
|
//
|
||||||
|
// Optional. Default value "header:X-CSRF-Token".
|
||||||
|
TokenLookup string
|
||||||
|
|
||||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
// Max age (in seconds) of the CSRF cookie.
|
||||||
// to extract token from the request.
|
//
|
||||||
// Possible values:
|
// Optional. Default value 86400 (24hr).
|
||||||
// - "header:<name>"
|
CookieMaxAge int
|
||||||
// - "form:<name>"
|
|
||||||
// - "query:<name>"
|
|
||||||
//
|
|
||||||
// Optional. Default value "header:X-CSRF-Token".
|
|
||||||
TokenLookup string
|
|
||||||
|
|
||||||
// TokenLength is the length of the generated token.
|
// TokenLength is the length of the generated token.
|
||||||
//
|
//
|
||||||
// Optional. Default value 32.
|
// Optional. Default value 32.
|
||||||
TokenLength int
|
TokenLength int
|
||||||
|
|
||||||
// Indicates if CSRF cookie is secure.
|
// Indicates if CSRF cookie is secure.
|
||||||
//
|
//
|
||||||
// Optional. Default value false.
|
// Optional. Default value false.
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
|
|
||||||
// Indicates if CSRF cookie is HTTP only.
|
// Indicates if CSRF cookie is HTTP only.
|
||||||
//
|
//
|
||||||
// Optional. Default value false.
|
// Optional. Default value false.
|
||||||
CookieHTTPOnly bool
|
CookieHTTPOnly bool
|
||||||
}
|
|
||||||
|
|
||||||
csrfTokenExtractor func(*http.RequestCtx) ([]byte, error)
|
|
||||||
)
|
|
||||||
|
|
||||||
// HeaderXCSRFToken describes the name of the header with the CSRF token.
|
|
||||||
const HeaderXCSRFToken string = "X-CSRF-Token"
|
|
||||||
|
|
||||||
// Possible CSRF errors.
|
|
||||||
var (
|
|
||||||
ErrMissingFormToken = errors.New("missing csrf token in the form parameter")
|
|
||||||
ErrMissingQueryToken = errors.New("missing csrf token in the query string")
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultCSRFConfig contains the default CSRF middleware configuration.
|
|
||||||
//nolint: gochecknoglobals, gomnd
|
|
||||||
var DefaultCSRFConfig = CSRFConfig{
|
|
||||||
Skipper: DefaultSkipper,
|
|
||||||
CookieMaxAge: 24 * time.Hour,
|
|
||||||
CookieSameSite: http.CookieSameSiteDefaultMode,
|
|
||||||
ContextKey: "csrf",
|
|
||||||
CookieDomain: "",
|
|
||||||
CookieName: "_csrf",
|
|
||||||
CookiePath: "",
|
|
||||||
TokenLookup: "header:" + HeaderXCSRFToken,
|
|
||||||
TokenLength: 32,
|
|
||||||
CookieSecure: false,
|
|
||||||
CookieHTTPOnly: false,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ExtractorLimit int = 20
|
||||||
|
|
||||||
|
// HeaderXCSRFToken describes the name of the header with the CSRF token.
|
||||||
|
HeaderXCSRFToken string = "X-CSRF-Token"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrCSRFInvalid = errors.New("invalid csrf token")
|
||||||
|
|
||||||
|
// DefaultCSRFConfig contains the default CSRF middleware configuration.
|
||||||
|
//nolint: gochecknoglobals, gomnd
|
||||||
|
DefaultCSRFConfig = CSRFConfig{
|
||||||
|
Skipper: DefaultSkipper,
|
||||||
|
CookieMaxAge: 86400,
|
||||||
|
CookieSameSite: http.CookieSameSiteDefaultMode,
|
||||||
|
ContextKey: "csrf",
|
||||||
|
CookieDomain: "",
|
||||||
|
CookieName: "_csrf",
|
||||||
|
CookiePath: "",
|
||||||
|
TokenLookup: "header:" + HeaderXCSRFToken,
|
||||||
|
TokenLength: 32,
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: false,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||||||
func CSRF() Interceptor {
|
func CSRF() Interceptor {
|
||||||
return CSRFWithConfig(DefaultCSRFConfig)
|
config := DefaultCSRFConfig
|
||||||
|
|
||||||
|
return CSRFWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFWithConfig returns a CSRF middleware with config.
|
// CSRFWithConfig returns a CSRF middleware with config.
|
||||||
|
@ -134,17 +133,12 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.CookieSameSite == http.CookieSameSiteNoneMode {
|
if config.CookieSameSite == http.CookieSameSiteNoneMode {
|
||||||
config.CookieSecure = true
|
config.CookieSecure = DefaultCSRFConfig.CookieSecure
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(config.TokenLookup, ":")
|
extractors, err := createExtractors(config.TokenLookup, "")
|
||||||
extractor := csrfTokenFromHeader(parts[1])
|
if err != nil {
|
||||||
|
panic("middleware: csrf: " + err.Error())
|
||||||
switch parts[0] {
|
|
||||||
case "form":
|
|
||||||
extractor = csrfTokenFromForm(parts[1])
|
|
||||||
case "query":
|
|
||||||
extractor = csrfTokenFromQuery(parts[1])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
||||||
|
@ -171,16 +165,36 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
||||||
switch {
|
switch {
|
||||||
case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace():
|
case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace():
|
||||||
default:
|
default:
|
||||||
// NOTE(toby3d): validate token only for requests which are not defined as 'safe' by RFC7231
|
var lastExtractorErr, lastTokenErr error
|
||||||
clientToken, err := extractor(ctx)
|
|
||||||
if err != nil {
|
|
||||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
outer:
|
||||||
|
for _, extractor := range extractors {
|
||||||
|
clientTokens, err := extractor(ctx)
|
||||||
|
if err != nil {
|
||||||
|
lastExtractorErr = err
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, clientToken := range clientTokens {
|
||||||
|
if !validateCSRFToken(token, clientToken) {
|
||||||
|
lastTokenErr = ErrCSRFInvalid
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lastTokenErr, lastExtractorErr = nil, nil
|
||||||
|
|
||||||
|
break outer
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validateCSRFToken(token, clientToken) {
|
if lastTokenErr != nil {
|
||||||
ctx.Error("invalid csrf token", http.StatusForbidden)
|
ctx.Error(lastTokenErr.Error(), http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
} else if lastExtractorErr != nil {
|
||||||
|
ctx.Error(lastExtractorErr.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -205,7 +219,7 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
||||||
cookie.SetSameSite(config.CookieSameSite)
|
cookie.SetSameSite(config.CookieSameSite)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie.SetExpire(time.Now().Add(config.CookieMaxAge))
|
cookie.SetExpire(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
|
||||||
cookie.SetSecure(config.CookieSecure)
|
cookie.SetSecure(config.CookieSecure)
|
||||||
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
||||||
ctx.Response.Header.SetCookie(cookie)
|
ctx.Response.Header.SetCookie(cookie)
|
||||||
|
@ -220,32 +234,6 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
return ctx.Request.Header.Peek(header), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if token := ctx.FormValue(param); token != nil {
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrMissingFormToken
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if !ctx.QueryArgs().Has(param) {
|
|
||||||
return nil, ErrMissingQueryToken
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.QueryArgs().Peek(param), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateCSRFToken(token, clientToken []byte) bool {
|
func validateCSRFToken(token, clientToken []byte) bool {
|
||||||
return subtle.ConstantTimeCompare(token, clientToken) == 1
|
return subtle.ConstantTimeCompare(token, clientToken) == 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,159 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
http "github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValuesExtractor defines a function for extracting values (keys/tokens) from
|
||||||
|
// the given context.
|
||||||
|
type ValuesExtractor func(ctx *http.RequestCtx) ([][]byte, error)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// extractorLimit is arbitrary number to limit values extractor can
|
||||||
|
// return. this limits possible resource exhaustion attack vector.
|
||||||
|
extractorLimit = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errHeaderExtractorValueMissing = errors.New("missing value in request header")
|
||||||
|
errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
|
||||||
|
errQueryExtractorValueMissing = errors.New("missing value in the query string")
|
||||||
|
errParamExtractorValueMissing = errors.New("missing value in path params")
|
||||||
|
errCookieExtractorValueMissing = errors.New("missing value in cookies")
|
||||||
|
errFormExtractorValueMissing = errors.New("missing value in the form")
|
||||||
|
)
|
||||||
|
|
||||||
|
func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) {
|
||||||
|
if lookups == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sources := strings.Split(lookups, ",")
|
||||||
|
extractors := make([]ValuesExtractor, 0)
|
||||||
|
|
||||||
|
for _, source := range sources {
|
||||||
|
parts := strings.Split(source, ":")
|
||||||
|
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %s",
|
||||||
|
source)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parts[0] {
|
||||||
|
case "query":
|
||||||
|
extractors = append(extractors, valuesFromQuery(parts[1]))
|
||||||
|
case "cookie":
|
||||||
|
extractors = append(extractors, valuesFromCookie(parts[1]))
|
||||||
|
case "form":
|
||||||
|
extractors = append(extractors, valuesFromForm(parts[1]))
|
||||||
|
case "header":
|
||||||
|
prefix := ""
|
||||||
|
|
||||||
|
if len(parts) > 2 {
|
||||||
|
prefix = parts[2]
|
||||||
|
} else if authScheme != "" && parts[1] == http.HeaderAuthorization {
|
||||||
|
// backwards compatibility for JWT and KeyAuth:
|
||||||
|
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
|
||||||
|
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
|
||||||
|
// behaviour for default values and Authorization header.
|
||||||
|
prefix = authScheme
|
||||||
|
|
||||||
|
if !strings.HasSuffix(prefix, " ") {
|
||||||
|
prefix += " "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return extractors, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromHeader returns a functions that extracts values from the request header.
|
||||||
|
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
|
||||||
|
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
|
||||||
|
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
|
||||||
|
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
|
||||||
|
// If prefix is left empty the whole value is returned.
|
||||||
|
func valuesFromHeader(header, valuePrefix string) ValuesExtractor {
|
||||||
|
prefixLen := len(valuePrefix)
|
||||||
|
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
|
||||||
|
header = textproto.CanonicalMIMEHeaderKey(header)
|
||||||
|
|
||||||
|
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||||
|
value := ctx.Request.Header.Peek(header)
|
||||||
|
if len(value) == 0 {
|
||||||
|
return nil, errHeaderExtractorValueMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefixLen == 0 {
|
||||||
|
return append([][]byte{}, value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(value) > prefixLen && bytes.EqualFold(value[:prefixLen], []byte(valuePrefix)) {
|
||||||
|
return append([][]byte{}, value[:prefixLen]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errHeaderExtractorValueInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromQuery returns a function that extracts values from the query string.
|
||||||
|
func valuesFromQuery(param string) ValuesExtractor {
|
||||||
|
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||||
|
if !ctx.QueryArgs().Has(param) {
|
||||||
|
return nil, errQueryExtractorValueMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ctx.QueryArgs().PeekMulti(param)
|
||||||
|
if len(result) > extractorLimit-1 {
|
||||||
|
result = result[:extractorLimit]
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||||
|
func valuesFromCookie(name string) ValuesExtractor {
|
||||||
|
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||||
|
if value := ctx.Request.Header.Cookie(name); len(value) != 0 {
|
||||||
|
return append([][]byte{}, value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errCookieExtractorValueMissing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromForm returns a function that extracts values from the form field.
|
||||||
|
func valuesFromForm(name string) ValuesExtractor {
|
||||||
|
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||||
|
form, err := ctx.MultipartForm()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := form.Value[name]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(values) == 0:
|
||||||
|
return nil, errFormExtractorValueMissing
|
||||||
|
case len(values) > extractorLimit-1:
|
||||||
|
values = values[:extractorLimit]
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([][]byte, 0)
|
||||||
|
for i := range values {
|
||||||
|
result = append(result, []byte(values[i]))
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
186
jwt.go
186
jwt.go
|
@ -1,10 +1,8 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/jwa"
|
"github.com/lestrrat-go/jwx/jwa"
|
||||||
"github.com/lestrrat-go/jwx/jwt"
|
"github.com/lestrrat-go/jwx/jwt"
|
||||||
|
@ -83,6 +81,14 @@ type (
|
||||||
// - "header: Authorization,cookie: myowncookie"
|
// - "header: Authorization,cookie: myowncookie"
|
||||||
TokenLookup string
|
TokenLookup string
|
||||||
|
|
||||||
|
// TokenLookupFuncs defines a list of user-defined functions
|
||||||
|
// that extract JWT token from the given context.
|
||||||
|
// This is one of the two options to provide a token extractor.
|
||||||
|
// The order of precedence is user-defined TokenLookupFuncs, and
|
||||||
|
// TokenLookup.
|
||||||
|
// You can also provide both if you want.
|
||||||
|
TokenLookupFuncs []ValuesExtractor
|
||||||
|
|
||||||
// AuthScheme to be used in the Authorization header.
|
// AuthScheme to be used in the Authorization header.
|
||||||
//
|
//
|
||||||
// Optional. Default value "Bearer".
|
// Optional. Default value "Bearer".
|
||||||
|
@ -110,6 +116,17 @@ type (
|
||||||
// fails or parsed token is invalid. Defaults to implementation
|
// fails or parsed token is invalid. Defaults to implementation
|
||||||
// using `github.com/golang-jwt/jwt` as JWT implementation library
|
// using `github.com/golang-jwt/jwt` as JWT implementation library
|
||||||
ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error)
|
ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error)
|
||||||
|
|
||||||
|
// ContinueOnIgnoredError allows the next middleware/handler to
|
||||||
|
// be called when ErrorHandlerWithContext decides to ignore the
|
||||||
|
// error (by returning `nil`). This is useful when parts of your
|
||||||
|
// site/api allow public access and some authorized routes
|
||||||
|
// provide extra functionality. In that case you can use
|
||||||
|
// ErrorHandlerWithContext to set a default public JWT token
|
||||||
|
// value in the request context and continue. Some logic down
|
||||||
|
// the remaining execution chain needs to check that (public)
|
||||||
|
// token value then.
|
||||||
|
ContinueOnIgnoredError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTSuccessHandler defines a function which is executed for a valid
|
// JWTSuccessHandler defines a function which is executed for a valid
|
||||||
|
@ -122,27 +139,18 @@ type (
|
||||||
|
|
||||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler,
|
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler,
|
||||||
// but it's passed the current context.
|
// but it's passed the current context.
|
||||||
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx)
|
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx) error
|
||||||
|
|
||||||
jwtExtractor func(ctx *http.RequestCtx) ([]byte, error)
|
jwtExtractor func(ctx *http.RequestCtx) ([]byte, error)
|
||||||
)
|
)
|
||||||
|
|
||||||
// Variants of token sources.
|
|
||||||
const (
|
|
||||||
SourceCookie = "cookie"
|
|
||||||
SourceForm = "form"
|
|
||||||
SourceHeader = "header"
|
|
||||||
SourceParam = "param"
|
|
||||||
SourceQuery = "query"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DefaultJWTConfig is the default JWT auth middleware config.
|
// DefaultJWTConfig is the default JWT auth middleware config.
|
||||||
//nolint: gochecknoglobals
|
//nolint: gochecknoglobals
|
||||||
var DefaultJWTConfig = JWTConfig{
|
var DefaultJWTConfig = JWTConfig{
|
||||||
Skipper: DefaultSkipper,
|
Skipper: DefaultSkipper,
|
||||||
SigningMethod: "HS256",
|
SigningMethod: "HS256",
|
||||||
ContextKey: "user",
|
ContextKey: "user",
|
||||||
TokenLookup: SourceHeader + ":" + http.HeaderAuthorization,
|
TokenLookup: "header:" + http.HeaderAuthorization,
|
||||||
AuthScheme: "Bearer",
|
AuthScheme: "Bearer",
|
||||||
Claims: []jwt.ClaimPair{},
|
Claims: []jwt.ClaimPair{},
|
||||||
}
|
}
|
||||||
|
@ -159,7 +167,10 @@ var (
|
||||||
// * For invalid token, it returns "401 - Unauthorized" error.
|
// * For invalid token, it returns "401 - Unauthorized" error.
|
||||||
// * For missing token, it returns "400 - Bad Request" error.
|
// * For missing token, it returns "400 - Bad Request" error.
|
||||||
func JWT(key interface{}) Interceptor {
|
func JWT(key interface{}) Interceptor {
|
||||||
return JWTWithConfig(DefaultJWTConfig)
|
config := DefaultJWTConfig
|
||||||
|
config.SigningKey = key
|
||||||
|
|
||||||
|
return JWTWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTWithConfig returns a JWT auth middleware with config.
|
// JWTWithConfig returns a JWT auth middleware with config.
|
||||||
|
@ -198,24 +209,13 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
||||||
config.ParseTokenFunc = config.defaultParseToken
|
config.ParseTokenFunc = config.defaultParseToken
|
||||||
}
|
}
|
||||||
|
|
||||||
extractors := make([]jwtExtractor, 0)
|
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
|
||||||
sources := strings.Split(config.TokenLookup, ",")
|
if err != nil {
|
||||||
|
panic("middleware: jwt: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
for _, source := range sources {
|
if len(config.TokenLookupFuncs) > 0 {
|
||||||
parts := strings.Split(source, ":")
|
extractors = append(config.TokenLookupFuncs, extractors...)
|
||||||
|
|
||||||
switch parts[0] {
|
|
||||||
case SourceQuery:
|
|
||||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
|
||||||
case SourceParam:
|
|
||||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
|
||||||
case SourceCookie:
|
|
||||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
|
||||||
case SourceForm:
|
|
||||||
extractors = append(extractors, jwtFromForm(parts[1]))
|
|
||||||
case SourceHeader:
|
|
||||||
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
||||||
|
@ -229,45 +229,39 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
||||||
config.BeforeFunc(ctx)
|
config.BeforeFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var lastExtractorErr, lastTokenErr error
|
||||||
auth []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, extractor := range extractors {
|
for _, extractor := range extractors {
|
||||||
if auth, err = extractor(ctx); err == nil {
|
auths, err := extractor(ctx)
|
||||||
break
|
if err != nil {
|
||||||
}
|
lastExtractorErr = ErrJWTMissing
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
continue
|
||||||
if config.ErrorHandler != nil {
|
}
|
||||||
config.ErrorHandler(err)
|
|
||||||
|
for _, auth := range auths {
|
||||||
|
token, err := config.ParseTokenFunc(auth, ctx)
|
||||||
|
if err != nil {
|
||||||
|
lastTokenErr = err
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.SetUserValue(config.ContextKey, token)
|
||||||
|
|
||||||
|
if config.SuccessHandler != nil {
|
||||||
|
config.SuccessHandler(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ErrorHandlerWithContext != nil {
|
|
||||||
config.ErrorHandlerWithContext(err, ctx)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if token, err := config.ParseTokenFunc(auth, ctx); err == nil {
|
err := lastTokenErr
|
||||||
ctx.SetUserValue(config.ContextKey, token)
|
if err == nil {
|
||||||
|
err = lastExtractorErr
|
||||||
if config.SuccessHandler != nil {
|
|
||||||
config.SuccessHandler(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
next(ctx)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ErrorHandler != nil {
|
if config.ErrorHandler != nil {
|
||||||
|
@ -277,12 +271,23 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ErrorHandlerWithContext != nil {
|
if config.ErrorHandlerWithContext != nil {
|
||||||
config.ErrorHandlerWithContext(err, ctx)
|
tmpErr := config.ErrorHandlerWithContext(err, ctx)
|
||||||
|
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||||
|
next(ctx)
|
||||||
|
} else {
|
||||||
|
ctx.Error(tmpErr.Error(), http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Error(ErrJWTInvalid.Error(), http.StatusUnauthorized)
|
if lastTokenErr != nil {
|
||||||
|
ctx.Error(lastTokenErr.Error(), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,56 +301,3 @@ func (config *JWTConfig) defaultParseToken(auth []byte, ctx *http.RequestCtx) (i
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func jwtFromHeader(header, authScheme string) jwtExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
auth := ctx.Request.Header.Peek(header)
|
|
||||||
l := len(authScheme)
|
|
||||||
|
|
||||||
if len(auth) > l+1 && bytes.EqualFold(auth[:l], []byte(authScheme)) {
|
|
||||||
return auth[l+1:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrJWTMissing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func jwtFromQuery(param string) jwtExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if ctx.QueryArgs().Has(param) {
|
|
||||||
return nil, ErrJWTMissing
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.QueryArgs().Peek(param), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func jwtFromParam(param string) jwtExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if !ctx.PostArgs().Has(param) {
|
|
||||||
return nil, ErrJWTMissing
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.PostArgs().Peek(param), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func jwtFromCookie(name string) jwtExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if cookie := ctx.Request.Header.Cookie(name); cookie != nil {
|
|
||||||
return cookie, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrJWTMissing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func jwtFromForm(name string) jwtExtractor {
|
|
||||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
|
||||||
if field := ctx.FormValue(name); field != nil {
|
|
||||||
return field, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, ErrJWTMissing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue