diff --git a/internal/common/common.go b/internal/common/common.go index a14817a..dfd2445 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -13,7 +13,8 @@ const ( ) const ( - HeaderContentType string = "Content-Type" + HeaderContentType string = "Content-Type" + HeaderAuthorization string = "Authorization" ) const Und string = "und" diff --git a/internal/middleware/extractor.go b/internal/middleware/extractor.go new file mode 100644 index 0000000..a5e35b7 --- /dev/null +++ b/internal/middleware/extractor.go @@ -0,0 +1,240 @@ +package middleware + +import ( + "errors" + "fmt" + "net/http" + "net/textproto" + "strings" + + "source.toby3d.me/toby3d/auth/internal/common" +) + +// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. +type ValuesExtractor func(w http.ResponseWriter, r *http.Request) ([]string, error) + +const ( + // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource + // exhaustion attack vector. + extractorLimit = 20 +) + +var ( + errHeaderExtractorValueMissing = errors.New("missing value in request header") + errHeaderExtractorValueInvalid = errors.New("invalid value in request header") + errQueryExtractorValueMissing = errors.New("missing value in the query string") + errParamExtractorValueMissing = errors.New("missing value in path params") + errCookieExtractorValueMissing = errors.New("missing value in cookies") + errFormExtractorValueMissing = errors.New("missing value in the form") +) + +// CreateExtractors creates ValuesExtractors from given lookups. +// Lookups is a string in the form of ":" or ":,:" that is used +// to extract key from the request. +// Possible values: +// - "header:" or "header::" +// `` is argument value to cut/trim prefix of the extracted value. This is useful if header +// value has static prefix like `Authorization: ` where part that we +// want to cut is ` ` note the space at the end. +// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. +// - "query:" +// - "param:" +// - "form:" +// - "cookie:" +// +// Multiple sources example: +// - "header:Authorization,header:X-Api-Key". +func CreateExtractors(lookups string) ([]ValuesExtractor, error) { + return createExtractors(lookups, "") +} + +func createExtractors(lookups, authScheme string) ([]ValuesExtractor, error) { + if lookups == "" { + return nil, nil + } + + sources := strings.Split(lookups, ",") + extractors := make([]ValuesExtractor, 0) + + for _, source := range sources { + parts := strings.Split(source, ":") + + if len(parts) < 2 { + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", + source) + } + + switch parts[0] { + case "query": + extractors = append(extractors, valuesFromQuery(parts[1])) + // case "param": + // extractors = append(extractors, valuesFromParam(parts[1])) + case "cookie": + extractors = append(extractors, valuesFromCookie(parts[1])) + case "form": + extractors = append(extractors, valuesFromForm(parts[1])) + case "header": + prefix := "" + if len(parts) > 2 { + prefix = parts[2] + } else if authScheme != "" && parts[1] == common.HeaderAuthorization { + // backwards compatibility for JWT and KeyAuth: + // * we only apply this fix to Authorization as header we use and uses prefixes like + // "Bearer " etc + // * previously header extractor assumed that auth-scheme/prefix had a space as suffix + // we need to retain that behaviour for default values and Authorization header. + prefix = authScheme + if !strings.HasSuffix(prefix, " ") { + prefix += " " + } + } + + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) + } + } + + return extractors, nil +} + +// valuesFromHeader returns a functions that extracts values from the request header. +// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has +// static prefix like `Authorization: ` where part that we want to remove is +// ` ` note the space at the end. In case of basic authentication `Authorization: Basic ` +// prefix we want to remove is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. +// If prefix is left empty the whole value is returned. +func valuesFromHeader(header, valuePrefix string) ValuesExtractor { + prefixLen := len(valuePrefix) + // standard library parses http.Request header keys in canonical form but we may provide something else so fix + // this + header = textproto.CanonicalMIMEHeaderKey(header) + + return func(w http.ResponseWriter, r *http.Request) ([]string, error) { + values := r.Header.Values(header) + if len(values) == 0 { + return nil, errHeaderExtractorValueMissing + } + + result := make([]string, 0) + + for i, value := range values { + if prefixLen == 0 { + result = append(result, value) + + if i >= extractorLimit-1 { + break + } + + continue + } + + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + result = append(result, value[prefixLen:]) + + if i >= extractorLimit-1 { + break + } + } + } + + if len(result) == 0 { + if prefixLen > 0 { + return nil, errHeaderExtractorValueInvalid + } + + return nil, errHeaderExtractorValueMissing + } + + return result, nil + } +} + +// valuesFromQuery returns a function that extracts values from the query string. +func valuesFromQuery(param string) ValuesExtractor { + return func(w http.ResponseWriter, r *http.Request) ([]string, error) { + result := r.URL.Query()[param] + + if len(result) == 0 { + return nil, errQueryExtractorValueMissing + } else if len(result) > extractorLimit-1 { + result = result[:extractorLimit] + } + + return result, nil + } +} + +// valuesFromParam returns a function that extracts values from the url param string. +/* +func valuesFromParam(param string) ValuesExtractor { + return func(w http.ResponseWriter, r *http.Request) ([]string, error) { + result := make([]string, 0) + paramVales := r.ParamValues() + + for i, p := range r.ParamNames() { + if param == p { + result = append(result, paramVales[i]) + if i >= extractorLimit-1 { + break + } + } + } + + if len(result) == 0 { + return nil, errParamExtractorValueMissing + } + + return result, nil + } +} +*/ + +// valuesFromCookie returns a function that extracts values from the named cookie. +func valuesFromCookie(name string) ValuesExtractor { + return func(w http.ResponseWriter, r *http.Request) ([]string, error) { + cookies := r.Cookies() + if len(cookies) == 0 { + return nil, errCookieExtractorValueMissing + } + + result := make([]string, 0) + + for i, cookie := range cookies { + if name == cookie.Name { + result = append(result, cookie.Value) + + if i >= extractorLimit-1 { + break + } + } + } + + if len(result) == 0 { + return nil, errCookieExtractorValueMissing + } + + return result, nil + } +} + +// valuesFromForm returns a function that extracts values from the form field. +func valuesFromForm(name string) ValuesExtractor { + return func(w http.ResponseWriter, r *http.Request) ([]string, error) { + if r.Form == nil { + _ = r.ParseMultipartForm(32 << 20) // same what `r.FormValue(name)` does + } + + values := r.Form[name] + + if len(values) == 0 { + return nil, errFormExtractorValueMissing + } + + if len(values) > extractorLimit-1 { + values = values[:extractorLimit] + } + + result := append([]string{}, values...) + + return result, nil + } +} diff --git a/internal/middleware/jwt.go b/internal/middleware/jwt.go new file mode 100644 index 0000000..386ab8f --- /dev/null +++ b/internal/middleware/jwt.go @@ -0,0 +1,304 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" + + "source.toby3d.me/toby3d/auth/internal/common" +) + +type ( + // JWTConfig defines the config for JWT middleware. + JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before + // the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a + // valid token. + SuccessHandler JWTSuccessHandler + + // ErrorHandler defines a function which is executed for an + // invalid token. It may be used to define a custom JWT error. + ErrorHandler JWTErrorHandler + + // ErrorHandlerWithContext is almost identical to ErrorHandler, + // but it's passed the current context. + ErrorHandlerWithContext JWTErrorHandlerWithContext + + // Signing key to validate token. + // This is one of the three options to provide a token + // validation key. The order of precedence is a user-defined + // KeyFunc, SigningKeys and SigningKey. + // + // Required if neither user-defined KeyFunc nor SigningKeys is + // provided. + SigningKey any + + // Map of signing keys to validate token with kid field usage. + // This is one of the three options to provide a token + // validation key. The order of precedence is a user-defined + // KeyFunc, SigningKeys and SigningKey. + // + // Required if neither user-defined KeyFunc nor SigningKey is + // provided. + SigningKeys map[string]any + + // Signing method used to check the token's signing algorithm. + // + // Optional. Default value HS256. + SigningMethod jwa.SignatureAlgorithm + + // Context key to store user information from the token into + // context. + // + // Optional. Default value "user". + ContextKey string + + // Claims are extendable claims data defining token content. + // Used by default ParseTokenFunc implementation. Not used if + // custom ParseTokenFunc is set. + // + // Optional. Default value []jwt.ClaimPair + Claims []jwt.ClaimPair + + // TokenLookup is a string in the form of ":" or + // ":,:" that is used to extract + // token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiply sources example: + // - "header: Authorization,cookie: myowncookie" + TokenLookup string + + // TokenLookupFuncs defines a list of user-defined functions + // that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and + // TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []ValuesExtractor + + // AuthScheme to be used in the Authorization header. + // + // Optional. Default value "Bearer". + AuthScheme string + + // KeyFunc defines a user-defined function that supplies the + // public key for a token validation. The function shall take + // care of verifying the signing algorithm and selecting the + // proper key. A user-defined KeyFunc can be useful if tokens + // are issued by an external party. Used by default + // ParseTokenFunc implementation. + // + // When a user-defined KeyFunc is provided, SigningKey, + // SigningKeys, and SigningMethod are ignored. This is one of + // the three options to provide a token validation key. The + // order of precedence is a user-defined KeyFunc, SigningKeys + // and SigningKey. Required if neither SigningKeys nor + // SigningKey is provided. Not used if custom ParseTokenFunc is + // set. Default to an internal implementation verifying the + // signing algorithm and selecting the proper key. + // KeyFunc jwt.Keyfunc + + // ParseTokenFunc defines a user-defined function that parses + // token from given auth. Returns an error when token parsing + // fails or parsed token is invalid. Defaults to implementation + // using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(auth []byte, w http.ResponseWriter, r *http.Request) (any, error) + + // ContinueOnIgnoredError allows the next middleware/handler to + // be called when ErrorHandlerWithContext decides to ignore the + // error (by returning `nil`). This is useful when parts of your + // site/api allow public access and some authorized routes + // provide extra functionality. In that case you can use + // ErrorHandlerWithContext to set a default public JWT token + // value in the request context and continue. Some logic down + // the remaining execution chain needs to check that (public) + // token value then. + ContinueOnIgnoredError bool + } + + // JWTSuccessHandler defines a function which is executed for a valid + // token. + JWTSuccessHandler func(w http.ResponseWriter, r *http.Request) + + // JWTErrorHandler defines a function which is executed for an invalid + // token. + JWTErrorHandler func(err error) + + // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, + // but it's passed the current context. + JWTErrorHandlerWithContext func(err error, w http.ResponseWriter, r *http.Request) error + + jwtExtractor func(w http.ResponseWriter, r *http.Request) ([]byte, error) +) + +// DefaultJWTConfig is the default JWT auth middleware config. +//nolint: gochecknoglobals +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + SigningMethod: "HS256", + ContextKey: "user", + TokenLookup: "header:" + common.HeaderAuthorization, + AuthScheme: "Bearer", + Claims: []jwt.ClaimPair{}, +} + +// Possible token errors. +var ( + ErrJWTMissing = errors.New("jwt: missing or malformed jwt") + ErrJWTInvalid = errors.New("jwt: invalid or expired jwt") +) + +// JWT returns a JSON Web Token (JWT) auth middleware. +// +// * For valid token, it sets the user in context and calls next handler. +// * For invalid token, it returns "401 - Unauthorized" error. +// * For missing token, it returns "400 - Bad Request" error. +func JWT(key any) Interceptor { + config := DefaultJWTConfig + config.SigningKey = key + + return JWTWithConfig(config) +} + +// JWTWithConfig returns a JWT auth middleware with config. +//nolint: funlen, gocognit +func JWTWithConfig(config JWTConfig) Interceptor { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultJWTConfig.Skipper + } + + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.ParseTokenFunc == nil { + panic("jwt: middleware requires signing key") + } + + if config.SigningMethod == "" { + config.SigningMethod = DefaultJWTConfig.SigningMethod + } + + if config.ContextKey == "" { + config.ContextKey = DefaultJWTConfig.ContextKey + } + + if config.Claims == nil { + config.Claims = DefaultJWTConfig.Claims + } + + if config.TokenLookup == "" { + config.TokenLookup = DefaultJWTConfig.TokenLookup + } + + if config.AuthScheme == "" { + config.AuthScheme = DefaultJWTConfig.AuthScheme + } + + if config.ParseTokenFunc == nil { + config.ParseTokenFunc = config.defaultParseToken + } + + extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) + if err != nil { + panic("middleware: jwt: " + err.Error()) + } + + if len(config.TokenLookupFuncs) > 0 { + extractors = append(config.TokenLookupFuncs, extractors...) + } + + return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if config.Skipper(w, r) { + next(w, r) + + return + } + + if config.BeforeFunc != nil { + config.BeforeFunc(w, r) + } + + var lastExtractorErr, lastTokenErr error + + for _, extractor := range extractors { + auths, err := extractor(w, r) + if err != nil { + lastExtractorErr = ErrJWTMissing + + continue + } + + for _, auth := range auths { + token, err := config.ParseTokenFunc([]byte(auth), w, r) + if err != nil { + lastTokenErr = err + + continue + } + + r = r.WithContext(context.WithValue(r.Context(), config.ContextKey, token)) + + if config.SuccessHandler != nil { + config.SuccessHandler(w, r) + } + + next(w, r) + + return + } + } + + err := lastTokenErr + if err == nil { + err = lastExtractorErr + } + + if config.ErrorHandler != nil { + config.ErrorHandler(err) + + return + } + + if config.ErrorHandlerWithContext != nil { + tmpErr := config.ErrorHandlerWithContext(err, w, r) + if config.ContinueOnIgnoredError && tmpErr == nil { + next(w, r) + } else { + http.Error(w, tmpErr.Error(), http.StatusUnauthorized) + } + + return + } + + if lastTokenErr != nil { + http.Error(w, lastTokenErr.Error(), http.StatusUnauthorized) + + return + } + + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (config *JWTConfig) defaultParseToken(auth []byte, w http.ResponseWriter, r *http.Request) (any, error) { + token, err := jwt.Parse(auth, jwt.WithKey(config.SigningMethod, config.SigningKey), jwt.WithVerify(true)) + if err != nil { + return nil, fmt.Errorf("cannot parse JWT token: %w", err) + } + + return token, nil +}