diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go deleted file mode 100644 index acc0a1c..0000000 --- a/internal/middleware/csrf.go +++ /dev/null @@ -1,199 +0,0 @@ -package middleware - -import ( - "crypto/subtle" - "errors" - "strings" - "time" - - http "github.com/valyala/fasthttp" - - "source.toby3d.me/website/oauth/internal/random" -) - -type ( - CSRFConfig struct { - Skipper Skipper - CookieMaxAge time.Duration - CookieSameSite http.CookieSameSite - ContextKey string - CookieDomain string - CookieName string - CookiePath string - TokenLookup string - TokenLength int - CookieSecure bool - CookieHTTPOnly bool - } - - csrfTokenExtractor func(*http.RequestCtx) ([]byte, error) -) - -// HeaderXCSRFToken describes the name of the header with the CSRF token. -const HeaderXCSRFToken string = "X-CSRF-Token" //nolint: gosec - -var ( - ErrMissingFormToken = errors.New("missing csrf token in the form parameter") - ErrMissingQueryToken = errors.New("missing csrf token in the query string") -) - -// DefaultCSRFConfig contains the default CSRF middleware configuration. -//nolint: gochecknoglobals, gomnd -var DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - CookieMaxAge: 24 * time.Hour, - CookieSameSite: http.CookieSameSiteDefaultMode, - ContextKey: "csrf", - CookieDomain: "", - CookieName: "_csrf", - CookiePath: "", - TokenLookup: "header:" + HeaderXCSRFToken, - TokenLength: 32, - CookieSecure: false, - CookieHTTPOnly: false, -} - -func CSRF() Interceptor { - return CSRFWithConfig(DefaultCSRFConfig) -} - -//nolint: funlen, cyclop -func CSRFWithConfig(config CSRFConfig) Interceptor { - if config.Skipper == nil { - config.Skipper = DefaultCSRFConfig.Skipper - } - - if config.TokenLength == 0 { - config.TokenLength = DefaultCSRFConfig.TokenLength - } - - if config.TokenLookup == "" { - config.TokenLookup = DefaultCSRFConfig.TokenLookup - } - - if config.ContextKey == "" { - config.ContextKey = DefaultCSRFConfig.ContextKey - } - - if config.CookieName == "" { - config.CookieName = DefaultCSRFConfig.CookieName - } - - if config.CookieMaxAge == 0 { - config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge - } - - if config.CookieSameSite == http.CookieSameSiteNoneMode { - config.CookieSecure = true - } - - parts := strings.Split(config.TokenLookup, ":") - extractor := csrfTokenFromHeader(parts[1]) - - switch parts[0] { - case "form": - extractor = csrfTokenFromForm(parts[1]) - case "query": - extractor = csrfTokenFromQuery(parts[1]) - } - - return func(ctx *http.RequestCtx, next http.RequestHandler) { - if config.Skipper(ctx) { - next(ctx) - - return - } - - var token []byte - if k := ctx.Request.Header.Cookie(config.CookieName); k != nil { - token = k - } else { - var err error - if token, err = random.Bytes(config.TokenLength); err != nil { - ctx.Error(err.Error(), http.StatusInternalServerError) - - return - } - } - - switch { - case ctx.IsGet(), ctx.IsHead(), ctx.IsOptions(), ctx.IsTrace(): - default: - // NOTE(toby3d): validate token only for requests which are not defined as 'safe' by RFC7231 - clientToken, err := extractor(ctx) - if err != nil { - ctx.Error(err.Error(), http.StatusBadRequest) - - return - } - - if !validateCSRFToken(token, clientToken) { - ctx.Error("invalid csrf token", http.StatusForbidden) - - return - } - } - - // NOTE(toby3d): set CSRF cookie - cookie := http.AcquireCookie() - defer http.ReleaseCookie(cookie) - - cookie.SetKey(config.CookieName) - cookie.SetValueBytes(token) - - if config.CookiePath != "" { - cookie.SetPath(config.CookiePath) - } - - if config.CookieDomain != "" { - cookie.SetDomain(config.CookieDomain) - } - - if config.CookieSameSite != http.CookieSameSiteDefaultMode { - cookie.SetSameSite(config.CookieSameSite) - } - - cookie.SetExpire(time.Now().Add(config.CookieMaxAge)) - cookie.SetSecure(config.CookieSecure) - cookie.SetHTTPOnly(config.CookieHTTPOnly) - ctx.Response.Header.SetCookie(cookie) - - // NOTE(toby3d): store token in the context - ctx.SetUserValue(config.ContextKey, token) - - // NOTE(toby3d): protect clients from caching the response - ctx.Response.Header.Add(http.HeaderVary, http.HeaderCookie) - - next(ctx) - } -} - -func csrfTokenFromHeader(header string) csrfTokenExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - return ctx.Request.Header.Peek(header), nil - } -} - -func csrfTokenFromForm(param string) csrfTokenExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if token := ctx.FormValue(param); token != nil { - return token, nil - } - - return nil, ErrMissingFormToken - } -} - -func csrfTokenFromQuery(param string) csrfTokenExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if !ctx.QueryArgs().Has(param) { - return nil, ErrMissingQueryToken - } - - return ctx.QueryArgs().Peek(param), nil - } -} - -func validateCSRFToken(token, clientToken []byte) bool { - return subtle.ConstantTimeCompare(token, clientToken) == 1 -} diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go deleted file mode 100644 index ba2d319..0000000 --- a/internal/middleware/jwt.go +++ /dev/null @@ -1,346 +0,0 @@ -package middleware - -import ( - "bytes" - "strings" - - "github.com/lestrrat-go/jwx/jwa" - "github.com/lestrrat-go/jwx/jwt" - "github.com/pkg/errors" - http "github.com/valyala/fasthttp" - "golang.org/x/xerrors" - - "source.toby3d.me/website/oauth/internal/domain" -) - -type ( - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before - // the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a - // valid token. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an - // invalid token. It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, - // but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // Signing key to validate token. - // This is one of the three options to provide a token - // validation key. The order of precedence is a user-defined - // KeyFunc, SigningKeys and SigningKey. Required if neither - // user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token - // validation key. The order of precedence is a user-defined - // KeyFunc, SigningKeys and SigningKey. Required if neither - // user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod jwa.SignatureAlgorithm - - // Context key to store user information from the token into - // context. Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. - // Used by default ParseTokenFunc implementation. Not used if - // custom ParseTokenFunc is set. Optional. Default value - // []jwt.ClaimPair - Claims []jwt.ClaimPair - - // TokenLookup is a string in the form of ":" or - // ":,:" that is used to extract - // token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiply sources example: - // - "header: Authorization,cookie: myowncookie" - TokenLookup string - - // AuthScheme to be used in the Authorization header. Optional. - // Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the - // public key for a token validation. The function shall take - // care of verifying the signing algorithm and selecting the - // proper key. A user-defined KeyFunc can be useful if tokens - // are issued by an external party. Used by default - // ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, - // SigningKeys, and SigningMethod are ignored. This is one of - // the three options to provide a token validation key. The - // order of precedence is a user-defined KeyFunc, SigningKeys - // and SigningKey. Required if neither SigningKeys nor - // SigningKey is provided. Not used if custom ParseTokenFunc is - // set. Default to an internal implementation verifying the - // signing algorithm and selecting the proper key. - // KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses - // token from given auth. Returns an error when token parsing - // fails or parsed token is invalid. Defaults to implementation - // using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth []byte, ctx *http.RequestCtx) (interface{}, error) - } - - // JWTSuccessHandler defines a function which is executed for a valid - // token. - JWTSuccessHandler func(ctx *http.RequestCtx) - - // JWTErrorHandler defines a function which is executed for an invalid - // token. - JWTErrorHandler func(err error) - - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, - // but it's passed the current context. - JWTErrorHandlerWithContext func(err error, ctx *http.RequestCtx) - - jwtExtractor func(ctx *http.RequestCtx) ([]byte, error) -) - -const ( - SourceCookie = "cookie" - SourceForm = "form" - SourceHeader = "header" - SourceParam = "param" - SourceQuery = "query" -) - -// DefaultJWTConfig is the default JWT auth middleware config. -//nolint: gochecknoglobals -var DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: "HS256", - ContextKey: "user", - TokenLookup: SourceHeader + ":" + http.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: []jwt.ClaimPair{}, -} - -var ( - ErrJWTMissing = domain.Error{ - Code: "invalid_request", - Description: "missing or malformed jwt", - URI: "", - Frame: xerrors.Caller(1), - } - - ErrJWTInvalid = domain.Error{ - Code: "unauthorized", - Description: "invalid or expired jwt", - URI: "", - Frame: xerrors.Caller(1), - } -) - -func JWT(key interface{}) Interceptor { - return JWTWithConfig(DefaultJWTConfig) -} - -//nolint: funlen, gocognit -func JWTWithConfig(config JWTConfig) Interceptor { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultJWTConfig.Skipper - } - - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.ParseTokenFunc == nil { - panic("jwt middleware requires signing key") - } - - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod - } - - if config.ContextKey == "" { - config.ContextKey = DefaultJWTConfig.ContextKey - } - - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } - - if config.TokenLookup == "" { - config.TokenLookup = DefaultJWTConfig.TokenLookup - } - - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - extractors := make([]jwtExtractor, 0) - sources := strings.Split(config.TokenLookup, ",") - - for _, source := range sources { - parts := strings.Split(source, ":") - - switch parts[0] { - case SourceQuery: - extractors = append(extractors, jwtFromQuery(parts[1])) - case SourceParam: - extractors = append(extractors, jwtFromParam(parts[1])) - case SourceCookie: - extractors = append(extractors, jwtFromCookie(parts[1])) - case SourceForm: - extractors = append(extractors, jwtFromForm(parts[1])) - case SourceHeader: - extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) - } - } - - return func(ctx *http.RequestCtx, next http.RequestHandler) { - if config.Skipper(ctx) { - next(ctx) - - return - } - - if config.BeforeFunc != nil { - config.BeforeFunc(ctx) - } - - var ( - auth []byte - err error - ) - - for _, extractor := range extractors { - if auth, err = extractor(ctx); err == nil { - break - } - } - - if err != nil { - if config.ErrorHandler != nil { - config.ErrorHandler(err) - - return - } - - if config.ErrorHandlerWithContext != nil { - config.ErrorHandlerWithContext(err, ctx) - - return - } - - ctx.Error(err.Error(), http.StatusInternalServerError) - - return - } - - if token, err := config.ParseTokenFunc(auth, ctx); err == nil { - ctx.SetUserValue(config.ContextKey, token) - - if config.SuccessHandler != nil { - config.SuccessHandler(ctx) - } - - next(ctx) - - return - } - - if config.ErrorHandler != nil { - config.ErrorHandler(err) - - return - } - - if config.ErrorHandlerWithContext != nil { - config.ErrorHandlerWithContext(err, ctx) - - return - } - - ctx.Error(ErrJWTInvalid.Description, http.StatusUnauthorized) - } -} - -func (config *JWTConfig) defaultParseToken(auth []byte, ctx *http.RequestCtx) (interface{}, error) { - token, err := jwt.Parse( - auth, jwt.WithVerify(config.SigningMethod, config.SigningKey), jwt.WithValidate(true), - ) - if err != nil { - return nil, errors.Wrap(err, "cannot parse JWT token") - } - - return token, nil -} - -func jwtFromHeader(header, authScheme string) jwtExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - auth := ctx.Request.Header.Peek(header) - l := len(authScheme) - - if len(auth) > l+1 && bytes.EqualFold(auth[:l], []byte(authScheme)) { - return auth[l+1:], nil - } - - return nil, ErrJWTMissing - } -} - -func jwtFromQuery(param string) jwtExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if ctx.QueryArgs().Has(param) { - return nil, ErrJWTMissing - } - - return ctx.QueryArgs().Peek(param), nil - } -} - -func jwtFromParam(param string) jwtExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if !ctx.PostArgs().Has(param) { - return nil, ErrJWTMissing - } - - return ctx.PostArgs().Peek(param), nil - } -} - -func jwtFromCookie(name string) jwtExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if cookie := ctx.Request.Header.Cookie(name); cookie != nil { - return cookie, nil - } - - return nil, ErrJWTMissing - } -} - -func jwtFromForm(name string) jwtExtractor { - return func(ctx *http.RequestCtx) ([]byte, error) { - if field := ctx.FormValue(name); field != nil { - return field, nil - } - - return nil, ErrJWTMissing - } -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go deleted file mode 100644 index 2bf5993..0000000 --- a/internal/middleware/middleware.go +++ /dev/null @@ -1,32 +0,0 @@ -package middleware - -import http "github.com/valyala/fasthttp" - -type ( - BeforeFunc http.RequestHandler - Chain []Interceptor - Interceptor func(*http.RequestCtx, http.RequestHandler) - RequestHandler http.RequestHandler - Skipper func(*http.RequestCtx) bool -) - -// DefaultSkipper is the default skipper, which always returns false. -//nolint: gochecknoglobals -var DefaultSkipper Skipper = func(*http.RequestCtx) bool { return false } - -func (count RequestHandler) Intercept(middleware Interceptor) RequestHandler { - return func(ctx *http.RequestCtx) { - middleware(ctx, http.RequestHandler(count)) - } -} - -func (chain Chain) RequestHandler(handler http.RequestHandler) http.RequestHandler { - current := RequestHandler(handler) - - for i := len(chain) - 1; i >= 0; i-- { - m := chain[i] - current = current.Intercept(m) - } - - return http.RequestHandler(current) -}