♻️ Refactored value extractors for middlewares
This commit is contained in:
parent
c5f0bbc687
commit
a214a7e9e8
3 changed files with 339 additions and 240 deletions
234
csrf.go
234
csrf.go
|
@ -5,105 +5,104 @@ import (
|
|||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
http "github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type (
|
||||
// CSRFConfig defines the config for CSRF middleware.
|
||||
CSRFConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// CSRFConfig defines the config for CSRF middleware.
|
||||
type CSRFConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Max age (in seconds) of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value 86400 (24hr).
|
||||
CookieMaxAge time.Duration
|
||||
// Indicates SameSite mode of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.CookieSameSite
|
||||
|
||||
// Indicates SameSite mode of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.CookieSameSite
|
||||
// Context key to store generated CSRF token into context.
|
||||
//
|
||||
// Optional. Default value "csrf".
|
||||
ContextKey string
|
||||
|
||||
// Context key to store generated CSRF token into context.
|
||||
//
|
||||
// Optional. Default value "csrf".
|
||||
ContextKey string
|
||||
// Domain of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value none.
|
||||
CookieDomain string
|
||||
|
||||
// Domain of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value none.
|
||||
CookieDomain string
|
||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||
//
|
||||
// Optional. Default value "csrf".
|
||||
CookieName string
|
||||
|
||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||
//
|
||||
// Optional. Default value "csrf".
|
||||
CookieName string
|
||||
// Path of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value none.
|
||||
CookiePath string
|
||||
|
||||
// Path of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value none.
|
||||
CookiePath string
|
||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "form:<name>"
|
||||
// - "query:<name>"
|
||||
//
|
||||
// Optional. Default value "header:X-CSRF-Token".
|
||||
TokenLookup string
|
||||
|
||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "form:<name>"
|
||||
// - "query:<name>"
|
||||
//
|
||||
// Optional. Default value "header:X-CSRF-Token".
|
||||
TokenLookup string
|
||||
// Max age (in seconds) of the CSRF cookie.
|
||||
//
|
||||
// Optional. Default value 86400 (24hr).
|
||||
CookieMaxAge int
|
||||
|
||||
// TokenLength is the length of the generated token.
|
||||
//
|
||||
// Optional. Default value 32.
|
||||
TokenLength int
|
||||
// TokenLength is the length of the generated token.
|
||||
//
|
||||
// Optional. Default value 32.
|
||||
TokenLength int
|
||||
|
||||
// Indicates if CSRF cookie is secure.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool
|
||||
// Indicates if CSRF cookie is secure.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool
|
||||
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
}
|
||||
|
||||
csrfTokenExtractor func(*http.RequestCtx) ([]byte, error)
|
||||
)
|
||||
|
||||
// HeaderXCSRFToken describes the name of the header with the CSRF token.
|
||||
const HeaderXCSRFToken string = "X-CSRF-Token"
|
||||
|
||||
// Possible CSRF errors.
|
||||
var (
|
||||
ErrMissingFormToken = errors.New("missing csrf token in the form parameter")
|
||||
ErrMissingQueryToken = errors.New("missing csrf token in the query string")
|
||||
)
|
||||
|
||||
// DefaultCSRFConfig contains the default CSRF middleware configuration.
|
||||
//nolint: gochecknoglobals, gomnd
|
||||
var DefaultCSRFConfig = CSRFConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
CookieMaxAge: 24 * time.Hour,
|
||||
CookieSameSite: http.CookieSameSiteDefaultMode,
|
||||
ContextKey: "csrf",
|
||||
CookieDomain: "",
|
||||
CookieName: "_csrf",
|
||||
CookiePath: "",
|
||||
TokenLookup: "header:" + HeaderXCSRFToken,
|
||||
TokenLength: 32,
|
||||
CookieSecure: false,
|
||||
CookieHTTPOnly: false,
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
//
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
}
|
||||
|
||||
const (
|
||||
ExtractorLimit int = 20
|
||||
|
||||
// HeaderXCSRFToken describes the name of the header with the CSRF token.
|
||||
HeaderXCSRFToken string = "X-CSRF-Token"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCSRFInvalid = errors.New("invalid csrf token")
|
||||
|
||||
// DefaultCSRFConfig contains the default CSRF middleware configuration.
|
||||
//nolint: gochecknoglobals, gomnd
|
||||
DefaultCSRFConfig = CSRFConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
CookieMaxAge: 86400,
|
||||
CookieSameSite: http.CookieSameSiteDefaultMode,
|
||||
ContextKey: "csrf",
|
||||
CookieDomain: "",
|
||||
CookieName: "_csrf",
|
||||
CookiePath: "",
|
||||
TokenLookup: "header:" + HeaderXCSRFToken,
|
||||
TokenLength: 32,
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: false,
|
||||
}
|
||||
)
|
||||
|
||||
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||||
func CSRF() Interceptor {
|
||||
return CSRFWithConfig(DefaultCSRFConfig)
|
||||
config := DefaultCSRFConfig
|
||||
|
||||
return CSRFWithConfig(config)
|
||||
}
|
||||
|
||||
// CSRFWithConfig returns a CSRF middleware with config.
|
||||
|
@ -134,17 +133,12 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
|||
}
|
||||
|
||||
if config.CookieSameSite == http.CookieSameSiteNoneMode {
|
||||
config.CookieSecure = true
|
||||
config.CookieSecure = DefaultCSRFConfig.CookieSecure
|
||||
}
|
||||
|
||||
parts := strings.Split(config.TokenLookup, ":")
|
||||
extractor := csrfTokenFromHeader(parts[1])
|
||||
|
||||
switch parts[0] {
|
||||
case "form":
|
||||
extractor = csrfTokenFromForm(parts[1])
|
||||
case "query":
|
||||
extractor = csrfTokenFromQuery(parts[1])
|
||||
extractors, err := createExtractors(config.TokenLookup, "")
|
||||
if err != nil {
|
||||
panic("middleware: csrf: " + err.Error())
|
||||
}
|
||||
|
||||
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
||||
|
@ -171,16 +165,36 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
|||
switch {
|
||||
case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace():
|
||||
default:
|
||||
// NOTE(toby3d): validate token only for requests which are not defined as 'safe' by RFC7231
|
||||
clientToken, err := extractor(ctx)
|
||||
if err != nil {
|
||||
ctx.Error(err.Error(), http.StatusBadRequest)
|
||||
var lastExtractorErr, lastTokenErr error
|
||||
|
||||
return
|
||||
outer:
|
||||
for _, extractor := range extractors {
|
||||
clientTokens, err := extractor(ctx)
|
||||
if err != nil {
|
||||
lastExtractorErr = err
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
for _, clientToken := range clientTokens {
|
||||
if !validateCSRFToken(token, clientToken) {
|
||||
lastTokenErr = ErrCSRFInvalid
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
lastTokenErr, lastExtractorErr = nil, nil
|
||||
|
||||
break outer
|
||||
}
|
||||
}
|
||||
|
||||
if !validateCSRFToken(token, clientToken) {
|
||||
ctx.Error("invalid csrf token", http.StatusForbidden)
|
||||
if lastTokenErr != nil {
|
||||
ctx.Error(lastTokenErr.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
} else if lastExtractorErr != nil {
|
||||
ctx.Error(lastExtractorErr.Error(), http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -205,7 +219,7 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
|||
cookie.SetSameSite(config.CookieSameSite)
|
||||
}
|
||||
|
||||
cookie.SetExpire(time.Now().Add(config.CookieMaxAge))
|
||||
cookie.SetExpire(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
|
||||
cookie.SetSecure(config.CookieSecure)
|
||||
cookie.SetHTTPOnly(config.CookieHTTPOnly)
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
|
@ -220,32 +234,6 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
|
|||
}
|
||||
}
|
||||
|
||||
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
return ctx.Request.Header.Peek(header), nil
|
||||
}
|
||||
}
|
||||
|
||||
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if token := ctx.FormValue(param); token != nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
return nil, ErrMissingFormToken
|
||||
}
|
||||
}
|
||||
|
||||
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if !ctx.QueryArgs().Has(param) {
|
||||
return nil, ErrMissingQueryToken
|
||||
}
|
||||
|
||||
return ctx.QueryArgs().Peek(param), nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateCSRFToken(token, clientToken []byte) bool {
|
||||
return subtle.ConstantTimeCompare(token, clientToken) == 1
|
||||
}
|
||||
|
|
159
extractor.go
Normal file
159
extractor.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
http "github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ValuesExtractor defines a function for extracting values (keys/tokens) from
|
||||
// the given context.
|
||||
type ValuesExtractor func(ctx *http.RequestCtx) ([][]byte, error)
|
||||
|
||||
const (
|
||||
// extractorLimit is arbitrary number to limit values extractor can
|
||||
// return. this limits possible resource exhaustion attack vector.
|
||||
extractorLimit = 20
|
||||
)
|
||||
|
||||
var (
|
||||
errHeaderExtractorValueMissing = errors.New("missing value in request header")
|
||||
errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
|
||||
errQueryExtractorValueMissing = errors.New("missing value in the query string")
|
||||
errParamExtractorValueMissing = errors.New("missing value in path params")
|
||||
errCookieExtractorValueMissing = errors.New("missing value in cookies")
|
||||
errFormExtractorValueMissing = errors.New("missing value in the form")
|
||||
)
|
||||
|
||||
func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) {
|
||||
if lookups == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sources := strings.Split(lookups, ",")
|
||||
extractors := make([]ValuesExtractor, 0)
|
||||
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %s",
|
||||
source)
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractors = append(extractors, valuesFromQuery(parts[1]))
|
||||
case "cookie":
|
||||
extractors = append(extractors, valuesFromCookie(parts[1]))
|
||||
case "form":
|
||||
extractors = append(extractors, valuesFromForm(parts[1]))
|
||||
case "header":
|
||||
prefix := ""
|
||||
|
||||
if len(parts) > 2 {
|
||||
prefix = parts[2]
|
||||
} else if authScheme != "" && parts[1] == http.HeaderAuthorization {
|
||||
// backwards compatibility for JWT and KeyAuth:
|
||||
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
|
||||
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
|
||||
// behaviour for default values and Authorization header.
|
||||
prefix = authScheme
|
||||
|
||||
if !strings.HasSuffix(prefix, " ") {
|
||||
prefix += " "
|
||||
}
|
||||
}
|
||||
|
||||
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
|
||||
}
|
||||
}
|
||||
|
||||
return extractors, nil
|
||||
}
|
||||
|
||||
// valuesFromHeader returns a functions that extracts values from the request header.
|
||||
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
|
||||
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
|
||||
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
|
||||
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
|
||||
// If prefix is left empty the whole value is returned.
|
||||
func valuesFromHeader(header, valuePrefix string) ValuesExtractor {
|
||||
prefixLen := len(valuePrefix)
|
||||
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
|
||||
header = textproto.CanonicalMIMEHeaderKey(header)
|
||||
|
||||
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||
value := ctx.Request.Header.Peek(header)
|
||||
if len(value) == 0 {
|
||||
return nil, errHeaderExtractorValueMissing
|
||||
}
|
||||
|
||||
if prefixLen == 0 {
|
||||
return append([][]byte{}, value), nil
|
||||
}
|
||||
|
||||
if len(value) > prefixLen && bytes.EqualFold(value[:prefixLen], []byte(valuePrefix)) {
|
||||
return append([][]byte{}, value[:prefixLen]), nil
|
||||
}
|
||||
|
||||
return nil, errHeaderExtractorValueInvalid
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromQuery returns a function that extracts values from the query string.
|
||||
func valuesFromQuery(param string) ValuesExtractor {
|
||||
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||
if !ctx.QueryArgs().Has(param) {
|
||||
return nil, errQueryExtractorValueMissing
|
||||
}
|
||||
|
||||
result := ctx.QueryArgs().PeekMulti(param)
|
||||
if len(result) > extractorLimit-1 {
|
||||
result = result[:extractorLimit]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||
func valuesFromCookie(name string) ValuesExtractor {
|
||||
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||
if value := ctx.Request.Header.Cookie(name); len(value) != 0 {
|
||||
return append([][]byte{}, value), nil
|
||||
}
|
||||
|
||||
return nil, errCookieExtractorValueMissing
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromForm returns a function that extracts values from the form field.
|
||||
func valuesFromForm(name string) ValuesExtractor {
|
||||
return func(ctx *http.RequestCtx) ([][]byte, error) {
|
||||
form, err := ctx.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", err)
|
||||
}
|
||||
|
||||
values := form.Value[name]
|
||||
|
||||
switch {
|
||||
case len(values) == 0:
|
||||
return nil, errFormExtractorValueMissing
|
||||
case len(values) > extractorLimit-1:
|
||||
values = values[:extractorLimit]
|
||||
}
|
||||
|
||||
result := make([][]byte, 0)
|
||||
for i := range values {
|
||||
result = append(result, []byte(values[i]))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
186
jwt.go
186
jwt.go
|
@ -1,10 +1,8 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
|
@ -83,6 +81,14 @@ type (
|
|||
// - "header: Authorization,cookie: myowncookie"
|
||||
TokenLookup string
|
||||
|
||||
// TokenLookupFuncs defines a list of user-defined functions
|
||||
// that extract JWT token from the given context.
|
||||
// This is one of the two options to provide a token extractor.
|
||||
// The order of precedence is user-defined TokenLookupFuncs, and
|
||||
// TokenLookup.
|
||||
// You can also provide both if you want.
|
||||
TokenLookupFuncs []ValuesExtractor
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
//
|
||||
// Optional. Default value "Bearer".
|
||||
|
@ -110,6 +116,17 @@ type (
|
|||
// fails or parsed token is invalid. Defaults to implementation
|
||||
// using `github.com/golang-jwt/jwt` as JWT implementation library
|
||||
ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error)
|
||||
|
||||
// ContinueOnIgnoredError allows the next middleware/handler to
|
||||
// be called when ErrorHandlerWithContext decides to ignore the
|
||||
// error (by returning `nil`). This is useful when parts of your
|
||||
// site/api allow public access and some authorized routes
|
||||
// provide extra functionality. In that case you can use
|
||||
// ErrorHandlerWithContext to set a default public JWT token
|
||||
// value in the request context and continue. Some logic down
|
||||
// the remaining execution chain needs to check that (public)
|
||||
// token value then.
|
||||
ContinueOnIgnoredError bool
|
||||
}
|
||||
|
||||
// JWTSuccessHandler defines a function which is executed for a valid
|
||||
|
@ -122,27 +139,18 @@ type (
|
|||
|
||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler,
|
||||
// but it's passed the current context.
|
||||
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx)
|
||||
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx) error
|
||||
|
||||
jwtExtractor func(ctx *http.RequestCtx) ([]byte, error)
|
||||
)
|
||||
|
||||
// Variants of token sources.
|
||||
const (
|
||||
SourceCookie = "cookie"
|
||||
SourceForm = "form"
|
||||
SourceHeader = "header"
|
||||
SourceParam = "param"
|
||||
SourceQuery = "query"
|
||||
)
|
||||
|
||||
// DefaultJWTConfig is the default JWT auth middleware config.
|
||||
//nolint: gochecknoglobals
|
||||
var DefaultJWTConfig = JWTConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
SigningMethod: "HS256",
|
||||
ContextKey: "user",
|
||||
TokenLookup: SourceHeader + ":" + http.HeaderAuthorization,
|
||||
TokenLookup: "header:" + http.HeaderAuthorization,
|
||||
AuthScheme: "Bearer",
|
||||
Claims: []jwt.ClaimPair{},
|
||||
}
|
||||
|
@ -159,7 +167,10 @@ var (
|
|||
// * For invalid token, it returns "401 - Unauthorized" error.
|
||||
// * For missing token, it returns "400 - Bad Request" error.
|
||||
func JWT(key interface{}) Interceptor {
|
||||
return JWTWithConfig(DefaultJWTConfig)
|
||||
config := DefaultJWTConfig
|
||||
config.SigningKey = key
|
||||
|
||||
return JWTWithConfig(config)
|
||||
}
|
||||
|
||||
// JWTWithConfig returns a JWT auth middleware with config.
|
||||
|
@ -198,24 +209,13 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
|||
config.ParseTokenFunc = config.defaultParseToken
|
||||
}
|
||||
|
||||
extractors := make([]jwtExtractor, 0)
|
||||
sources := strings.Split(config.TokenLookup, ",")
|
||||
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
|
||||
if err != nil {
|
||||
panic("middleware: jwt: " + err.Error())
|
||||
}
|
||||
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
|
||||
switch parts[0] {
|
||||
case SourceQuery:
|
||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
||||
case SourceParam:
|
||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
||||
case SourceCookie:
|
||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
||||
case SourceForm:
|
||||
extractors = append(extractors, jwtFromForm(parts[1]))
|
||||
case SourceHeader:
|
||||
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
|
||||
}
|
||||
if len(config.TokenLookupFuncs) > 0 {
|
||||
extractors = append(config.TokenLookupFuncs, extractors...)
|
||||
}
|
||||
|
||||
return func(ctx *http.RequestCtx, next http.RequestHandler) {
|
||||
|
@ -229,45 +229,39 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
|||
config.BeforeFunc(ctx)
|
||||
}
|
||||
|
||||
var (
|
||||
auth []byte
|
||||
err error
|
||||
)
|
||||
var lastExtractorErr, lastTokenErr error
|
||||
|
||||
for _, extractor := range extractors {
|
||||
if auth, err = extractor(ctx); err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
auths, err := extractor(ctx)
|
||||
if err != nil {
|
||||
lastExtractorErr = ErrJWTMissing
|
||||
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
config.ErrorHandler(err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, auth := range auths {
|
||||
token, err := config.ParseTokenFunc(auth, ctx)
|
||||
if err != nil {
|
||||
lastTokenErr = err
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
ctx.SetUserValue(config.ContextKey, token)
|
||||
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(ctx)
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
config.ErrorHandlerWithContext(err, ctx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if token, err := config.ParseTokenFunc(auth, ctx); err == nil {
|
||||
ctx.SetUserValue(config.ContextKey, token)
|
||||
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(ctx)
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
|
||||
return
|
||||
err := lastTokenErr
|
||||
if err == nil {
|
||||
err = lastExtractorErr
|
||||
}
|
||||
|
||||
if config.ErrorHandler != nil {
|
||||
|
@ -277,12 +271,23 @@ func JWTWithConfig(config JWTConfig) Interceptor {
|
|||
}
|
||||
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
config.ErrorHandlerWithContext(err, ctx)
|
||||
tmpErr := config.ErrorHandlerWithContext(err, ctx)
|
||||
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||
next(ctx)
|
||||
} else {
|
||||
ctx.Error(tmpErr.Error(), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Error(ErrJWTInvalid.Error(), http.StatusUnauthorized)
|
||||
if lastTokenErr != nil {
|
||||
ctx.Error(lastTokenErr.Error(), http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Error(err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -296,56 +301,3 @@ func (config *JWTConfig) defaultParseToken(auth []byte, ctx *http.RequestCtx) (i
|
|||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func jwtFromHeader(header, authScheme string) jwtExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
auth := ctx.Request.Header.Peek(header)
|
||||
l := len(authScheme)
|
||||
|
||||
if len(auth) > l+1 && bytes.EqualFold(auth[:l], []byte(authScheme)) {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
|
||||
return nil, ErrJWTMissing
|
||||
}
|
||||
}
|
||||
|
||||
func jwtFromQuery(param string) jwtExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if ctx.QueryArgs().Has(param) {
|
||||
return nil, ErrJWTMissing
|
||||
}
|
||||
|
||||
return ctx.QueryArgs().Peek(param), nil
|
||||
}
|
||||
}
|
||||
|
||||
func jwtFromParam(param string) jwtExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if !ctx.PostArgs().Has(param) {
|
||||
return nil, ErrJWTMissing
|
||||
}
|
||||
|
||||
return ctx.PostArgs().Peek(param), nil
|
||||
}
|
||||
}
|
||||
|
||||
func jwtFromCookie(name string) jwtExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if cookie := ctx.Request.Header.Cookie(name); cookie != nil {
|
||||
return cookie, nil
|
||||
}
|
||||
|
||||
return nil, ErrJWTMissing
|
||||
}
|
||||
}
|
||||
|
||||
func jwtFromForm(name string) jwtExtractor {
|
||||
return func(ctx *http.RequestCtx) ([]byte, error) {
|
||||
if field := ctx.FormValue(name); field != nil {
|
||||
return field, nil
|
||||
}
|
||||
|
||||
return nil, ErrJWTMissing
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue