♻️ Refactored value extractors for middlewares

This commit is contained in:
Maxim Lebedev 2022-02-26 03:12:02 +05:00
parent c5f0bbc687
commit a214a7e9e8
Signed by: toby3d
GPG Key ID: 1F14E25B7C119FC5
3 changed files with 339 additions and 240 deletions

234
csrf.go
View File

@ -5,105 +5,104 @@ import (
"crypto/subtle"
"encoding/base64"
"errors"
"strings"
"time"
http "github.com/valyala/fasthttp"
)
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// CSRFConfig defines the config for CSRF middleware.
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Max age (in seconds) of the CSRF cookie.
//
// Optional. Default value 86400 (24hr).
CookieMaxAge time.Duration
// Indicates SameSite mode of the CSRF cookie.
//
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.CookieSameSite
// Indicates SameSite mode of the CSRF cookie.
//
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.CookieSameSite
// Context key to store generated CSRF token into context.
//
// Optional. Default value "csrf".
ContextKey string
// Context key to store generated CSRF token into context.
//
// Optional. Default value "csrf".
ContextKey string
// Domain of the CSRF cookie.
//
// Optional. Default value none.
CookieDomain string
// Domain of the CSRF cookie.
//
// Optional. Default value none.
CookieDomain string
// Name of the CSRF cookie. This cookie will store CSRF token.
//
// Optional. Default value "csrf".
CookieName string
// Name of the CSRF cookie. This cookie will store CSRF token.
//
// Optional. Default value "csrf".
CookieName string
// Path of the CSRF cookie.
//
// Optional. Default value none.
CookiePath string
// Path of the CSRF cookie.
//
// Optional. Default value none.
CookiePath string
// TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
//
// Optional. Default value "header:X-CSRF-Token".
TokenLookup string
// TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
//
// Optional. Default value "header:X-CSRF-Token".
TokenLookup string
// Max age (in seconds) of the CSRF cookie.
//
// Optional. Default value 86400 (24hr).
CookieMaxAge int
// TokenLength is the length of the generated token.
//
// Optional. Default value 32.
TokenLength int
// TokenLength is the length of the generated token.
//
// Optional. Default value 32.
TokenLength int
// Indicates if CSRF cookie is secure.
//
// Optional. Default value false.
CookieSecure bool
// Indicates if CSRF cookie is secure.
//
// Optional. Default value false.
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
//
// Optional. Default value false.
CookieHTTPOnly bool
}
csrfTokenExtractor func(*http.RequestCtx) ([]byte, error)
)
// HeaderXCSRFToken describes the name of the header with the CSRF token.
const HeaderXCSRFToken string = "X-CSRF-Token"
// Possible CSRF errors.
var (
ErrMissingFormToken = errors.New("missing csrf token in the form parameter")
ErrMissingQueryToken = errors.New("missing csrf token in the query string")
)
// DefaultCSRFConfig contains the default CSRF middleware configuration.
//nolint: gochecknoglobals, gomnd
var DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
CookieMaxAge: 24 * time.Hour,
CookieSameSite: http.CookieSameSiteDefaultMode,
ContextKey: "csrf",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "header:" + HeaderXCSRFToken,
TokenLength: 32,
CookieSecure: false,
CookieHTTPOnly: false,
// Indicates if CSRF cookie is HTTP only.
//
// Optional. Default value false.
CookieHTTPOnly bool
}
const (
ExtractorLimit int = 20
// HeaderXCSRFToken describes the name of the header with the CSRF token.
HeaderXCSRFToken string = "X-CSRF-Token"
)
var (
ErrCSRFInvalid = errors.New("invalid csrf token")
// DefaultCSRFConfig contains the default CSRF middleware configuration.
//nolint: gochecknoglobals, gomnd
DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
CookieMaxAge: 86400,
CookieSameSite: http.CookieSameSiteDefaultMode,
ContextKey: "csrf",
CookieDomain: "",
CookieName: "_csrf",
CookiePath: "",
TokenLookup: "header:" + HeaderXCSRFToken,
TokenLength: 32,
CookieSecure: true,
CookieHTTPOnly: false,
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
func CSRF() Interceptor {
return CSRFWithConfig(DefaultCSRFConfig)
config := DefaultCSRFConfig
return CSRFWithConfig(config)
}
// CSRFWithConfig returns a CSRF middleware with config.
@ -134,17 +133,12 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
}
if config.CookieSameSite == http.CookieSameSiteNoneMode {
config.CookieSecure = true
config.CookieSecure = DefaultCSRFConfig.CookieSecure
}
parts := strings.Split(config.TokenLookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
extractors, err := createExtractors(config.TokenLookup, "")
if err != nil {
panic("middleware: csrf: " + err.Error())
}
return func(ctx *http.RequestCtx, next http.RequestHandler) {
@ -171,16 +165,36 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
switch {
case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace():
default:
// NOTE(toby3d): validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(ctx)
if err != nil {
ctx.Error(err.Error(), http.StatusBadRequest)
var lastExtractorErr, lastTokenErr error
return
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(ctx)
if err != nil {
lastExtractorErr = err
continue
}
for _, clientToken := range clientTokens {
if !validateCSRFToken(token, clientToken) {
lastTokenErr = ErrCSRFInvalid
continue
}
lastTokenErr, lastExtractorErr = nil, nil
break outer
}
}
if !validateCSRFToken(token, clientToken) {
ctx.Error("invalid csrf token", http.StatusForbidden)
if lastTokenErr != nil {
ctx.Error(lastTokenErr.Error(), http.StatusInternalServerError)
return
} else if lastExtractorErr != nil {
ctx.Error(lastExtractorErr.Error(), http.StatusBadRequest)
return
}
@ -205,7 +219,7 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
cookie.SetSameSite(config.CookieSameSite)
}
cookie.SetExpire(time.Now().Add(config.CookieMaxAge))
cookie.SetExpire(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
cookie.SetSecure(config.CookieSecure)
cookie.SetHTTPOnly(config.CookieHTTPOnly)
ctx.Response.Header.SetCookie(cookie)
@ -220,32 +234,6 @@ func CSRFWithConfig(config CSRFConfig) Interceptor {
}
}
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
return ctx.Request.Header.Peek(header), nil
}
}
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if token := ctx.FormValue(param); token != nil {
return token, nil
}
return nil, ErrMissingFormToken
}
}
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if !ctx.QueryArgs().Has(param) {
return nil, ErrMissingQueryToken
}
return ctx.QueryArgs().Peek(param), nil
}
}
func validateCSRFToken(token, clientToken []byte) bool {
return subtle.ConstantTimeCompare(token, clientToken) == 1
}

159
extractor.go Normal file
View File

@ -0,0 +1,159 @@
package middleware
import (
"bytes"
"errors"
"fmt"
"net/textproto"
"strings"
http "github.com/valyala/fasthttp"
)
// ValuesExtractor defines a function for extracting values (keys/tokens) from
// the given context.
type ValuesExtractor func(ctx *http.RequestCtx) ([][]byte, error)
const (
// extractorLimit is arbitrary number to limit values extractor can
// return. this limits possible resource exhaustion attack vector.
extractorLimit = 20
)
var (
errHeaderExtractorValueMissing = errors.New("missing value in request header")
errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
errQueryExtractorValueMissing = errors.New("missing value in the query string")
errParamExtractorValueMissing = errors.New("missing value in path params")
errCookieExtractorValueMissing = errors.New("missing value in cookies")
errFormExtractorValueMissing = errors.New("missing value in the form")
)
func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) {
if lookups == "" {
return nil, nil
}
sources := strings.Split(lookups, ",")
extractors := make([]ValuesExtractor, 0)
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %s",
source)
}
switch parts[0] {
case "query":
extractors = append(extractors, valuesFromQuery(parts[1]))
case "cookie":
extractors = append(extractors, valuesFromCookie(parts[1]))
case "form":
extractors = append(extractors, valuesFromForm(parts[1]))
case "header":
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
} else if authScheme != "" && parts[1] == http.HeaderAuthorization {
// backwards compatibility for JWT and KeyAuth:
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
// behaviour for default values and Authorization header.
prefix = authScheme
if !strings.HasSuffix(prefix, " ") {
prefix += " "
}
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
// If prefix is left empty the whole value is returned.
func valuesFromHeader(header, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
return func(ctx *http.RequestCtx) ([][]byte, error) {
value := ctx.Request.Header.Peek(header)
if len(value) == 0 {
return nil, errHeaderExtractorValueMissing
}
if prefixLen == 0 {
return append([][]byte{}, value), nil
}
if len(value) > prefixLen && bytes.EqualFold(value[:prefixLen], []byte(valuePrefix)) {
return append([][]byte{}, value[:prefixLen]), nil
}
return nil, errHeaderExtractorValueInvalid
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) ValuesExtractor {
return func(ctx *http.RequestCtx) ([][]byte, error) {
if !ctx.QueryArgs().Has(param) {
return nil, errQueryExtractorValueMissing
}
result := ctx.QueryArgs().PeekMulti(param)
if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
}
return result, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(ctx *http.RequestCtx) ([][]byte, error) {
if value := ctx.Request.Header.Cookie(name); len(value) != 0 {
return append([][]byte{}, value), nil
}
return nil, errCookieExtractorValueMissing
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(ctx *http.RequestCtx) ([][]byte, error) {
form, err := ctx.MultipartForm()
if err != nil {
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", err)
}
values := form.Value[name]
switch {
case len(values) == 0:
return nil, errFormExtractorValueMissing
case len(values) > extractorLimit-1:
values = values[:extractorLimit]
}
result := make([][]byte, 0)
for i := range values {
result = append(result, []byte(values[i]))
}
return result, nil
}
}

186
jwt.go
View File

@ -1,10 +1,8 @@
package middleware
import (
"bytes"
"errors"
"fmt"
"strings"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
@ -83,6 +81,14 @@ type (
// - "header: Authorization,cookie: myowncookie"
TokenLookup string
// TokenLookupFuncs defines a list of user-defined functions
// that extract JWT token from the given context.
// This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and
// TokenLookup.
// You can also provide both if you want.
TokenLookupFuncs []ValuesExtractor
// AuthScheme to be used in the Authorization header.
//
// Optional. Default value "Bearer".
@ -110,6 +116,17 @@ type (
// fails or parsed token is invalid. Defaults to implementation
// using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error)
// ContinueOnIgnoredError allows the next middleware/handler to
// be called when ErrorHandlerWithContext decides to ignore the
// error (by returning `nil`). This is useful when parts of your
// site/api allow public access and some authorized routes
// provide extra functionality. In that case you can use
// ErrorHandlerWithContext to set a default public JWT token
// value in the request context and continue. Some logic down
// the remaining execution chain needs to check that (public)
// token value then.
ContinueOnIgnoredError bool
}
// JWTSuccessHandler defines a function which is executed for a valid
@ -122,27 +139,18 @@ type (
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler,
// but it's passed the current context.
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx)
JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx) error
jwtExtractor func(ctx *http.RequestCtx) ([]byte, error)
)
// Variants of token sources.
const (
SourceCookie = "cookie"
SourceForm = "form"
SourceHeader = "header"
SourceParam = "param"
SourceQuery = "query"
)
// DefaultJWTConfig is the default JWT auth middleware config.
//nolint: gochecknoglobals
var DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper,
SigningMethod: "HS256",
ContextKey: "user",
TokenLookup: SourceHeader + ":" + http.HeaderAuthorization,
TokenLookup: "header:" + http.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: []jwt.ClaimPair{},
}
@ -159,7 +167,10 @@ var (
// * For invalid token, it returns "401 - Unauthorized" error.
// * For missing token, it returns "400 - Bad Request" error.
func JWT(key interface{}) Interceptor {
return JWTWithConfig(DefaultJWTConfig)
config := DefaultJWTConfig
config.SigningKey = key
return JWTWithConfig(config)
}
// JWTWithConfig returns a JWT auth middleware with config.
@ -198,24 +209,13 @@ func JWTWithConfig(config JWTConfig) Interceptor {
config.ParseTokenFunc = config.defaultParseToken
}
extractors := make([]jwtExtractor, 0)
sources := strings.Split(config.TokenLookup, ",")
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
if err != nil {
panic("middleware: jwt: " + err.Error())
}
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case SourceQuery:
extractors = append(extractors, jwtFromQuery(parts[1]))
case SourceParam:
extractors = append(extractors, jwtFromParam(parts[1]))
case SourceCookie:
extractors = append(extractors, jwtFromCookie(parts[1]))
case SourceForm:
extractors = append(extractors, jwtFromForm(parts[1]))
case SourceHeader:
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
}
if len(config.TokenLookupFuncs) > 0 {
extractors = append(config.TokenLookupFuncs, extractors...)
}
return func(ctx *http.RequestCtx, next http.RequestHandler) {
@ -229,45 +229,39 @@ func JWTWithConfig(config JWTConfig) Interceptor {
config.BeforeFunc(ctx)
}
var (
auth []byte
err error
)
var lastExtractorErr, lastTokenErr error
for _, extractor := range extractors {
if auth, err = extractor(ctx); err == nil {
break
}
}
auths, err := extractor(ctx)
if err != nil {
lastExtractorErr = ErrJWTMissing
if err != nil {
if config.ErrorHandler != nil {
config.ErrorHandler(err)
continue
}
for _, auth := range auths {
token, err := config.ParseTokenFunc(auth, ctx)
if err != nil {
lastTokenErr = err
continue
}
ctx.SetUserValue(config.ContextKey, token)
if config.SuccessHandler != nil {
config.SuccessHandler(ctx)
}
next(ctx)
return
}
if config.ErrorHandlerWithContext != nil {
config.ErrorHandlerWithContext(err, ctx)
return
}
ctx.Error(err.Error(), http.StatusInternalServerError)
return
}
if token, err := config.ParseTokenFunc(auth, ctx); err == nil {
ctx.SetUserValue(config.ContextKey, token)
if config.SuccessHandler != nil {
config.SuccessHandler(ctx)
}
next(ctx)
return
err := lastTokenErr
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil {
@ -277,12 +271,23 @@ func JWTWithConfig(config JWTConfig) Interceptor {
}
if config.ErrorHandlerWithContext != nil {
config.ErrorHandlerWithContext(err, ctx)
tmpErr := config.ErrorHandlerWithContext(err, ctx)
if config.ContinueOnIgnoredError && tmpErr == nil {
next(ctx)
} else {
ctx.Error(tmpErr.Error(), http.StatusUnauthorized)
}
return
}
ctx.Error(ErrJWTInvalid.Error(), http.StatusUnauthorized)
if lastTokenErr != nil {
ctx.Error(lastTokenErr.Error(), http.StatusUnauthorized)
return
}
ctx.Error(err.Error(), http.StatusInternalServerError)
}
}
@ -296,56 +301,3 @@ func (config *JWTConfig) defaultParseToken(auth []byte, ctx *http.RequestCtx) (i
return token, nil
}
func jwtFromHeader(header, authScheme string) jwtExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
auth := ctx.Request.Header.Peek(header)
l := len(authScheme)
if len(auth) > l+1 && bytes.EqualFold(auth[:l], []byte(authScheme)) {
return auth[l+1:], nil
}
return nil, ErrJWTMissing
}
}
func jwtFromQuery(param string) jwtExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if ctx.QueryArgs().Has(param) {
return nil, ErrJWTMissing
}
return ctx.QueryArgs().Peek(param), nil
}
}
func jwtFromParam(param string) jwtExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if !ctx.PostArgs().Has(param) {
return nil, ErrJWTMissing
}
return ctx.PostArgs().Peek(param), nil
}
}
func jwtFromCookie(name string) jwtExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if cookie := ctx.Request.Header.Cookie(name); cookie != nil {
return cookie, nil
}
return nil, ErrJWTMissing
}
}
func jwtFromForm(name string) jwtExtractor {
return func(ctx *http.RequestCtx) ([]byte, error) {
if field := ctx.FormValue(name); field != nil {
return field, nil
}
return nil, ErrJWTMissing
}
}