auth/vendor/github.com/lestrrat-go/jwx/v2/jwt/jwt.go

475 lines
14 KiB
Go

//go:generate ../tools/cmd/genjwt.sh
//go:generate stringer -type=TokenOption -output=token_options_gen.go
// Package jwt implements JSON Web Tokens as described in https://tools.ietf.org/html/rfc7519
package jwt
import (
"bytes"
"errors"
"fmt"
"io"
"sync/atomic"
"github.com/lestrrat-go/jwx/v2"
"github.com/lestrrat-go/jwx/v2/internal/json"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt/internal/types"
)
var errInvalidJWT = errors.New(`invalid JWT`)
// ErrInvalidJWT returns the opaque error value that is returned when
// `jwt.Parse` fails due to not being able to deduce the format of
// the incoming buffer
func ErrInvalidJWT() error {
return errInvalidJWT
}
// Settings controls global settings that are specific to JWTs.
func Settings(options ...GlobalOption) {
var flattenAudienceBool bool
var parsePedantic bool
var parsePrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set
var formatPrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set
//nolint:forcetypeassert
for _, option := range options {
switch option.Ident() {
case identFlattenAudience{}:
flattenAudienceBool = option.Value().(bool)
case identNumericDateParsePedantic{}:
parsePedantic = option.Value().(bool)
case identNumericDateParsePrecision{}:
v := option.Value().(int)
// only accept this value if it's in our desired range
if v >= 0 && v <= int(types.MaxPrecision) {
parsePrecision = uint32(v)
}
case identNumericDateFormatPrecision{}:
v := option.Value().(int)
// only accept this value if it's in our desired range
if v >= 0 && v <= int(types.MaxPrecision) {
formatPrecision = uint32(v)
}
}
}
if parsePrecision <= types.MaxPrecision { // remember we set default to max + 1
v := atomic.LoadUint32(&types.ParsePrecision)
if v != parsePrecision {
atomic.CompareAndSwapUint32(&types.ParsePrecision, v, parsePrecision)
}
}
if formatPrecision <= types.MaxPrecision { // remember we set default to max + 1
v := atomic.LoadUint32(&types.FormatPrecision)
if v != formatPrecision {
atomic.CompareAndSwapUint32(&types.FormatPrecision, v, formatPrecision)
}
}
{
v := atomic.LoadUint32(&types.Pedantic)
if (v == 1) != parsePedantic {
var newVal uint32
if parsePedantic {
newVal = 1
}
atomic.CompareAndSwapUint32(&types.Pedantic, v, newVal)
}
}
{
defaultOptionsMu.Lock()
if flattenAudienceBool {
defaultOptions.Enable(FlattenAudience)
} else {
defaultOptions.Disable(FlattenAudience)
}
defaultOptionsMu.Unlock()
}
}
var registry = json.NewRegistry()
// ParseString calls Parse against a string
func ParseString(s string, options ...ParseOption) (Token, error) {
return parseBytes([]byte(s), options...)
}
// Parse parses the JWT token payload and creates a new `jwt.Token` object.
// The token must be encoded in either JSON format or compact format.
//
// This function can only work with either raw JWT (JSON) and JWS (Compact or JSON).
// If you need JWE support on top of it, you will need to rollout your
// own workaround.
//
// If the token is signed and you want to verify the payload matches the signature,
// you must pass the jwt.WithKey(alg, key) or jwt.WithKeySet(jwk.Set) option.
// If you do not specify these parameters, no verification will be performed.
//
// During verification, if the JWS headers specify a key ID (`kid`), the
// key used for verification must match the specified ID. If you are somehow
// using a key without a `kid` (which is highly unlikely if you are working
// with a JWT from a well know provider), you can workaround this by modifying
// the `jwk.Key` and setting the `kid` header.
//
// If you also want to assert the validity of the JWT itself (i.e. expiration
// and such), use the `Validate()` function on the returned token, or pass the
// `WithValidate(true)` option. Validate options can also be passed to
// `Parse`
//
// This function takes both ParseOption and ValidateOption types:
// ParseOptions control the parsing behavior, and ValidateOptions are
// passed to `Validate()` when `jwt.WithValidate` is specified.
func Parse(s []byte, options ...ParseOption) (Token, error) {
return parseBytes(s, options...)
}
// ParseInsecure is exactly the same as Parse(), but it disables
// signature verification and token validation.
//
// You cannot override `jwt.WithVerify()` or `jwt.WithValidate()`
// using this function. Providing these options would result in
// an error
func ParseInsecure(s []byte, options ...ParseOption) (Token, error) {
for _, option := range options {
switch option.Ident() {
case identVerify{}, identValidate{}:
return nil, fmt.Errorf(`jwt.ParseInsecure: jwt.WithVerify() and jwt.WithValidate() may not be specified`)
}
}
options = append(options, WithVerify(false), WithValidate(false))
return Parse(s, options...)
}
// ParseReader calls Parse against an io.Reader
func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
// We're going to need the raw bytes regardless. Read it.
data, err := io.ReadAll(src)
if err != nil {
return nil, fmt.Errorf(`failed to read from token data source: %w`, err)
}
return parseBytes(data, options...)
}
type parseCtx struct {
token Token
validateOpts []ValidateOption
verifyOpts []jws.VerifyOption
localReg *json.Registry
pedantic bool
skipVerification bool
validate bool
}
func parseBytes(data []byte, options ...ParseOption) (Token, error) {
var ctx parseCtx
// Validation is turned on by default. You need to specify
// jwt.WithValidate(false) if you want to disable it
ctx.validate = true
// Verification is required (i.e., it is assumed that the incoming
// data is in JWS format) unless the user explicitly asks for
// it to be skipped.
verification := true
var verifyOpts []Option
for _, o := range options {
if v, ok := o.(ValidateOption); ok {
ctx.validateOpts = append(ctx.validateOpts, v)
continue
}
//nolint:forcetypeassert
switch o.Ident() {
case identKey{}, identKeySet{}, identVerifyAuto{}, identKeyProvider{}:
verifyOpts = append(verifyOpts, o)
case identToken{}:
token, ok := o.Value().(Token)
if !ok {
return nil, fmt.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
}
ctx.token = token
case identPedantic{}:
ctx.pedantic = o.Value().(bool)
case identValidate{}:
ctx.validate = o.Value().(bool)
case identVerify{}:
verification = o.Value().(bool)
case identTypedClaim{}:
pair := o.Value().(claimPair)
if ctx.localReg == nil {
ctx.localReg = json.NewRegistry()
}
ctx.localReg.Register(pair.Name, pair.Value)
}
}
if !verification {
ctx.skipVerification = true
}
lvo := len(verifyOpts)
if lvo == 0 && verification {
return nil, fmt.Errorf(`jwt.Parse: no keys for verification are provided (use jwt.WithVerify(false) to explicitly skip)`)
}
if lvo > 0 {
converted, err := toVerifyOptions(verifyOpts...)
if err != nil {
return nil, fmt.Errorf(`jwt.Parse: failed to convert options into jws.VerifyOption: %w`, err)
}
ctx.verifyOpts = converted
}
data = bytes.TrimSpace(data)
return parse(&ctx, data)
}
const (
_JwsVerifyInvalid = iota
_JwsVerifyDone
_JwsVerifyExpectNested
_JwsVerifySkipped
)
var _ = _JwsVerifyInvalid
func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
if len(ctx.verifyOpts) == 0 {
return nil, _JwsVerifySkipped, nil
}
verified, err := jws.Verify(payload, ctx.verifyOpts...)
return verified, _JwsVerifyDone, err
}
// verify parameter exists to make sure that we don't accidentally skip
// over verification just because alg == "" or key == nil or something.
func parse(ctx *parseCtx, data []byte) (Token, error) {
payload := data
const maxDecodeLevels = 2
// If cty = `JWT`, we expect this to be a nested structure
var expectNested bool
OUTER:
for i := 0; i < maxDecodeLevels; i++ {
switch kind := jwx.GuessFormat(payload); kind {
case jwx.JWT:
if ctx.pedantic {
if expectNested {
return nil, fmt.Errorf(`expected nested encrypted/signed payload, got raw JWT`)
}
}
if i == 0 {
// We were NOT enveloped in other formats
if !ctx.skipVerification {
if _, _, err := verifyJWS(ctx, payload); err != nil {
return nil, err
}
}
}
break OUTER
case jwx.InvalidFormat:
return nil, ErrInvalidJWT()
case jwx.UnknownFormat:
// "Unknown" may include invalid JWTs, for example, those who lack "aud"
// claim. We could be pedantic and reject these
if ctx.pedantic {
return nil, fmt.Errorf(`unknown JWT format (pedantic)`)
}
if i == 0 {
// We were NOT enveloped in other formats
if !ctx.skipVerification {
if _, _, err := verifyJWS(ctx, payload); err != nil {
return nil, err
}
}
}
break OUTER
case jwx.JWS:
// Food for thought: This is going to break if you have multiple layers of
// JWS enveloping using different keys. It is highly unlikely use case,
// but it might happen.
// skipVerification should only be set to true by us. It's used
// when we just want to parse the JWT out of a payload
if !ctx.skipVerification {
// nested return value means:
// false (next envelope _may_ need to be processed)
// true (next envelope MUST be processed)
v, state, err := verifyJWS(ctx, payload)
if err != nil {
return nil, err
}
if state != _JwsVerifySkipped {
payload = v
// We only check for cty and typ if the pedantic flag is enabled
if !ctx.pedantic {
continue
}
if state == _JwsVerifyExpectNested {
expectNested = true
continue OUTER
}
// if we're not nested, we found our target. bail out of this loop
break OUTER
}
}
// No verification.
m, err := jws.Parse(data)
if err != nil {
return nil, fmt.Errorf(`invalid jws message: %w`, err)
}
payload = m.Payload()
default:
return nil, fmt.Errorf(`unsupported format (layer: #%d)`, i+1)
}
expectNested = false
}
if ctx.token == nil {
ctx.token = New()
}
if ctx.localReg != nil {
dcToken, ok := ctx.token.(TokenWithDecodeCtx)
if !ok {
return nil, fmt.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
}
dc := json.NewDecodeCtx(ctx.localReg)
dcToken.SetDecodeCtx(dc)
defer func() { dcToken.SetDecodeCtx(nil) }()
}
if err := json.Unmarshal(payload, ctx.token); err != nil {
return nil, fmt.Errorf(`failed to parse token: %w`, err)
}
if ctx.validate {
if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
return nil, err
}
}
return ctx.token, nil
}
// Sign is a convenience function to create a signed JWT token serialized in
// compact form.
//
// It accepts either a raw key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
// or a jwk.Key, and the name of the algorithm that should be used to sign
// the token.
//
// If the key is a jwk.Key and the key contains a key ID (`kid` field),
// then it is added to the protected header generated by the signature
//
// The algorithm specified in the `alg` parameter must be able to support
// the type of key you provided, otherwise an error is returned.
// For convenience `alg` is of type jwa.KeyAlgorithm so you can pass
// the return value of `(jwk.Key).Algorithm()` directly, but in practice
// it must be an instance of jwa.SignatureAlgorithm, otherwise an error
// is returned.
//
// The protected header will also automatically have the `typ` field set
// to the literal value `JWT`, unless you provide a custom value for it
// by jwt.WithHeaders option.
func Sign(t Token, options ...SignOption) ([]byte, error) {
var soptions []jws.SignOption
if l := len(options); l > 0 {
// we need to from SignOption to Option because ... reasons
// (todo: when go1.18 prevails, use type parameters
rawoptions := make([]Option, l)
for i, option := range options {
rawoptions[i] = option
}
converted, err := toSignOptions(rawoptions...)
if err != nil {
return nil, fmt.Errorf(`jwt.Sign: failed to convert options into jws.SignOption: %w`, err)
}
soptions = converted
}
return NewSerializer().sign(soptions...).Serialize(t)
}
// Equal compares two JWT tokens. Do not use `reflect.Equal` or the like
// to compare tokens as they will also compare extra detail such as
// sync.Mutex objects used to control concurrent access.
//
// The comparison for values is currently done using a simple equality ("=="),
// except for time.Time, which uses time.Equal after dropping the monotonic
// clock and truncating the values to 1 second accuracy.
//
// if both t1 and t2 are nil, returns true
func Equal(t1, t2 Token) bool {
if t1 == nil && t2 == nil {
return true
}
// we already checked for t1 == t2 == nil, so safe to do this
if t1 == nil || t2 == nil {
return false
}
j1, err := json.Marshal(t1)
if err != nil {
return false
}
j2, err := json.Marshal(t2)
if err != nil {
return false
}
return bytes.Equal(j1, j2)
}
func (t *stdToken) Clone() (Token, error) {
dst := New()
dst.Options().Set(*(t.Options()))
for _, pair := range t.makePairs() {
//nolint:forcetypeassert
key := pair.Key.(string)
if err := dst.Set(key, pair.Value); err != nil {
return nil, fmt.Errorf(`failed to set %s: %w`, key, err)
}
}
return dst, nil
}
// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
//
// For example, suppose you have a custom field `x-birthday`, which
// you want to represent as a string formatted in RFC3339 in JSON,
// but want it back as `time.Time`.
//
// In that case you would register a custom field as follows
//
// jwt.RegisterCustomField(`x-birthday`, timeT)
//
// Then `token.Get("x-birthday")` will still return an `interface{}`,
// but you can convert its type to `time.Time`
//
// bdayif, _ := token.Get(`x-birthday`)
// bday := bdayif.(time.Time)
func RegisterCustomField(name string, object interface{}) {
registry.Register(name, object)
}