diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go new file mode 100644 index 0000000..c957d4d --- /dev/null +++ b/internal/middleware/csrf.go @@ -0,0 +1,238 @@ +package middleware + +import ( + "context" + "crypto/subtle" + "errors" + "net/http" + "time" + + "source.toby3d.me/toby3d/auth/internal/common" + "source.toby3d.me/toby3d/auth/internal/random" +) + +type ( + // CSRFConfig defines the config for CSRF middleware. + CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that + // is used to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler CSRFErrorHandler + } + + // CSRFErrorHandler is a function which is executed for creating custom errors. + CSRFErrorHandler func(err error, w http.ResponseWriter, r *http.Request) error +) + +// ErrCSRFInvalid is returned when CSRF check fails. +var ErrCSRFInvalid = errors.New("") // echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + common.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} + +// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. +// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery +func CSRF() Interceptor { + c := DefaultCSRFConfig + + return CSRFWithConfig(c) +} + +// CSRFWithConfig returns a CSRF middleware with config. +// See `CSRF()`. +func CSRFWithConfig(config CSRFConfig) Interceptor { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultCSRFConfig.Skipper + } + + if config.TokenLength == 0 { + config.TokenLength = DefaultCSRFConfig.TokenLength + } + + if config.TokenLookup == "" { + config.TokenLookup = DefaultCSRFConfig.TokenLookup + } + + if config.ContextKey == "" { + config.ContextKey = DefaultCSRFConfig.ContextKey + } + + if config.CookieName == "" { + config.CookieName = DefaultCSRFConfig.CookieName + } + + if config.CookieMaxAge == 0 { + config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge + } + + if config.CookieSameSite == http.SameSiteNoneMode { + config.CookieSecure = true + } + + extractors, err := CreateExtractors(config.TokenLookup) + if err != nil { + panic(err) + } + + return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if config.Skipper(w, r) { + next(w, r) + + return + } + + token := "" + if k, err := r.Cookie(config.CookieName); err != nil { + token, _ = random.String(config.TokenLength) // Generate token + } else { + token = k.Value // Reuse token + } + + switch r.Method { + case "", http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: + default: + var lastExtractorErr, lastTokenErr error + outer: + for _, extractor := range extractors { + clientTokens, err := extractor(w, r) + if err != nil { + lastExtractorErr = err + + continue + } + + for _, clientToken := range clientTokens { + if validateCSRFToken(token, clientToken) { + lastTokenErr = nil + lastExtractorErr = nil + + break outer + } + + lastTokenErr = ErrCSRFInvalid + } + } + + var finalErr error + + if lastTokenErr != nil { + finalErr = lastTokenErr + } else if lastExtractorErr != nil { + switch { + case errors.Is(lastExtractorErr, errQueryExtractorValueMissing): + lastExtractorErr = errors.New("missing csrf token in the query string") + case errors.Is(lastExtractorErr, errFormExtractorValueMissing): + lastExtractorErr = errors.New("missing csrf token in the form parameter") + case errors.Is(lastExtractorErr, errHeaderExtractorValueMissing): + lastExtractorErr = errors.New("missing csrf token in request header") + } + + finalErr = lastExtractorErr + } + + if finalErr != nil { + if config.ErrorHandler != nil { + config.ErrorHandler(finalErr, w, r) + + return + } + + http.Error(w, finalErr.Error(), http.StatusBadRequest) + + return + } + } + + // Set CSRF cookie + cookie := new(http.Cookie) + cookie.Name = config.CookieName + cookie.Value = token + + if config.CookiePath != "" { + cookie.Path = config.CookiePath + } + + if config.CookieDomain != "" { + cookie.Domain = config.CookieDomain + } + + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } + + cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) + cookie.Secure = config.CookieSecure + cookie.HttpOnly = config.CookieHTTPOnly + + http.SetCookie(w, cookie) + + // Store token in the context + r = r.WithContext(context.WithValue(r.Context(), config.ContextKey, token)) + + // Protect clients from caching the response + w.Header().Add(common.HeaderVary, common.HeaderCookie) + + next(w, r) + } +} + +func validateCSRFToken(token, clientToken string) bool { + return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 +}