package middleware import ( "crypto/rand" "crypto/subtle" "encoding/base64" "errors" "strings" "time" http "github.com/valyala/fasthttp" ) type ( // CSRFConfig defines the config for CSRF middleware. CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Max age (in seconds) of the CSRF cookie. // // Optional. Default value 86400 (24hr). CookieMaxAge time.Duration // Indicates SameSite mode of the CSRF cookie. // // Optional. Default value SameSiteDefaultMode. CookieSameSite http.CookieSameSite // Context key to store generated CSRF token into context. // // Optional. Default value "csrf". ContextKey string // Domain of the CSRF cookie. // // Optional. Default value none. CookieDomain string // Name of the CSRF cookie. This cookie will store CSRF token. // // Optional. Default value "csrf". CookieName string // Path of the CSRF cookie. // // Optional. Default value none. CookiePath string // TokenLookup is a string in the form of ":" that is used // to extract token from the request. // Possible values: // - "header:" // - "form:" // - "query:" // // Optional. Default value "header:X-CSRF-Token". TokenLookup string // TokenLength is the length of the generated token. // // Optional. Default value 32. TokenLength int // Indicates if CSRF cookie is secure. // // Optional. Default value false. CookieSecure bool // Indicates if CSRF cookie is HTTP only. // // Optional. Default value false. CookieHTTPOnly bool } csrfTokenExtractor func(*http.RequestCtx) ([]byte, error) ) // HeaderXCSRFToken describes the name of the header with the CSRF token. const HeaderXCSRFToken string = "X-CSRF-Token" // Possible CSRF errors. var ( ErrMissingFormToken = errors.New("missing csrf token in the form parameter") ErrMissingQueryToken = errors.New("missing csrf token in the query string") ) // DefaultCSRFConfig contains the default CSRF middleware configuration. //nolint: gochecknoglobals, gomnd var DefaultCSRFConfig = CSRFConfig{ Skipper: DefaultSkipper, CookieMaxAge: 24 * time.Hour, CookieSameSite: http.CookieSameSiteDefaultMode, ContextKey: "csrf", CookieDomain: "", CookieName: "_csrf", CookiePath: "", TokenLookup: "header:" + HeaderXCSRFToken, TokenLength: 32, CookieSecure: false, CookieHTTPOnly: false, } // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. func CSRF() Interceptor { return CSRFWithConfig(DefaultCSRFConfig) } // CSRFWithConfig returns a CSRF middleware with config. //nolint: funlen, cyclop func CSRFWithConfig(config CSRFConfig) Interceptor { if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } if config.ContextKey == "" { config.ContextKey = DefaultCSRFConfig.ContextKey } if config.CookieName == "" { config.CookieName = DefaultCSRFConfig.CookieName } if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } if config.CookieSameSite == http.CookieSameSiteNoneMode { config.CookieSecure = true } parts := strings.Split(config.TokenLookup, ":") extractor := csrfTokenFromHeader(parts[1]) switch parts[0] { case "form": extractor = csrfTokenFromForm(parts[1]) case "query": extractor = csrfTokenFromQuery(parts[1]) } return func(ctx *http.RequestCtx, next http.RequestHandler) { if config.Skipper(ctx) { next(ctx) return } var token []byte if k := ctx.Request.Header.Cookie(config.CookieName); k != nil { token = k } else { token = make([]byte, config.TokenLength) if _, err := rand.Read(token); err != nil { ctx.Error(err.Error(), http.StatusInternalServerError) return } token = []byte(base64.RawURLEncoding.EncodeToString(token)) } switch { case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace(): default: // NOTE(toby3d): validate token only for requests which are not defined as 'safe' by RFC7231 clientToken, err := extractor(ctx) if err != nil { ctx.Error(err.Error(), http.StatusBadRequest) return } if !validateCSRFToken(token, clientToken) { ctx.Error("invalid csrf token", http.StatusForbidden) return } } // NOTE(toby3d): set CSRF cookie cookie := http.AcquireCookie() defer http.ReleaseCookie(cookie) cookie.SetKey(config.CookieName) cookie.SetValueBytes(token) if config.CookiePath != "" { cookie.SetPath(config.CookiePath) } if config.CookieDomain != "" { cookie.SetDomain(config.CookieDomain) } if config.CookieSameSite != http.CookieSameSiteDefaultMode { cookie.SetSameSite(config.CookieSameSite) } cookie.SetExpire(time.Now().Add(config.CookieMaxAge)) cookie.SetSecure(config.CookieSecure) cookie.SetHTTPOnly(config.CookieHTTPOnly) ctx.Response.Header.SetCookie(cookie) // NOTE(toby3d): store token in the context ctx.SetUserValue(config.ContextKey, token) // NOTE(toby3d): protect clients from caching the response ctx.Response.Header.Add(http.HeaderVary, http.HeaderCookie) next(ctx) } } func csrfTokenFromHeader(header string) csrfTokenExtractor { return func(ctx *http.RequestCtx) ([]byte, error) { return ctx.Request.Header.Peek(header), nil } } func csrfTokenFromForm(param string) csrfTokenExtractor { return func(ctx *http.RequestCtx) ([]byte, error) { if token := ctx.FormValue(param); token != nil { return token, nil } return nil, ErrMissingFormToken } } func csrfTokenFromQuery(param string) csrfTokenExtractor { return func(ctx *http.RequestCtx) ([]byte, error) { if !ctx.QueryArgs().Has(param) { return nil, ErrMissingQueryToken } return ctx.QueryArgs().Peek(param), nil } } func validateCSRFToken(token, clientToken []byte) bool { return subtle.ConstantTimeCompare(token, clientToken) == 1 }