auth/vendor/github.com/lestrrat-go/jwx/v2/jwt/validate.go

560 lines
15 KiB
Go

package jwt
import (
"context"
"fmt"
"log"
"strconv"
"time"
)
// Clock abstracts the source of "now" used during validation, so that the
// wall clock can be replaced (for example by a fixed time in tests).
type Clock interface {
	Now() time.Time
}

// ClockFunc adapts a plain func() time.Time into a Clock.
type ClockFunc func() time.Time

// Now invokes the underlying function and returns its result.
func (f ClockFunc) Now() time.Time { return f() }
// isSupportedTimeClaim returns nil when c names one of the time claims that
// delta validators can compare ("exp", "iat" or "nbf"), and a validation
// error naming the offending claim otherwise.
func isSupportedTimeClaim(c string) error {
	if c == ExpirationKey || c == IssuedAtKey || c == NotBeforeKey {
		return nil
	}
	return NewValidationError(fmt.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
}
// timeClaim returns the time stored in t under the claim name c.
//
// The empty name is special: it stands for "now" as reported by clock. Any
// unrecognized name yields the zero time. Validate screens non-empty names
// with isSupportedTimeClaim before they can reach this function.
func timeClaim(t Token, clock Clock, c string) time.Time {
	var v time.Time
	switch c {
	case ExpirationKey:
		v = t.Expiration()
	case IssuedAtKey:
		v = t.IssuedAt()
	case NotBeforeKey:
		v = t.NotBefore()
	case "":
		v = clock.Now()
	}
	return v
}
// Validate makes sure that the essential claims stand.
//
// It first runs the default "iat", "exp" and "nbf" checks and then every
// validator supplied through options, in order. It returns the first failure,
// or nil when all checks pass.
//
// Options can also override the clock (default: time.Now), the acceptable
// skew (default: 0), the truncation applied to timestamps (default: one
// second) and the base context.Context (default: context.Background()).
// See the various `WithXXX` functions for optional parameters
// that can control the behavior of this method.
func Validate(t Token, options ...ValidateOption) error {
	ctx := context.Background()
	trunc := time.Second
	var clock Clock = ClockFunc(time.Now)
	var skew time.Duration

	checks := []Validator{
		IsIssuedAtValid(),
		IsExpirationValid(),
		IsNbfValid(),
	}

	for _, option := range options {
		//nolint:forcetypeassert
		switch option.Ident() {
		case identClock{}:
			clock = option.Value().(Clock)
		case identAcceptableSkew{}:
			skew = option.Value().(time.Duration)
		case identTruncation{}:
			trunc = option.Value().(time.Duration)
		case identContext{}:
			ctx = option.Value().(context.Context)
		case identValidator{}:
			extra := option.Value().(Validator)
			// Delta validators read their claims without checking for
			// presence, so each named claim is vetted and made required
			// before the delta validator itself runs.
			if rng, ok := extra.(*isInTimeRange); ok {
				for _, claim := range []string{rng.c1, rng.c2} {
					if claim == "" {
						continue
					}
					if err := isSupportedTimeClaim(claim); err != nil {
						return err
					}
					checks = append(checks, IsRequired(claim))
				}
			}
			checks = append(checks, extra)
		}
	}

	ctx = SetValidationCtxSkew(ctx, skew)
	ctx = SetValidationCtxClock(ctx, clock)
	ctx = SetValidationCtxTruncation(ctx, trunc)

	for _, check := range checks {
		if err := check.Validate(ctx, t); err != nil {
			return err
		}
	}
	return nil
}
// isInTimeRange compares the distance between two time claims against a
// duration. An empty claim name stands for the current time as reported by
// the validation clock.
type isInTimeRange struct {
	c1   string
	c2   string
	dur  time.Duration
	less bool // true: require c1 - c2 <= dur (max delta); false: require c1 - c2 >= dur (min delta)
}

// MaxDeltaIs implements the logic behind `WithMaxDelta()` option: the
// validator fails when c1 - c2 exceeds dur (plus the acceptable skew).
func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
	return &isInTimeRange{c1: c1, c2: c2, dur: dur, less: true}
}

// MinDeltaIs implements the logic behind `WithMinDelta()` option: the
// validator fails when c1 - c2 is below dur (minus the acceptable skew).
func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
	return &isInTimeRange{c1: c1, c2: c2, dur: dur}
}

// Validate checks the delta c1 - c2 against the configured bound. The clock
// and skew MUST be present in ctx. Presence of the claims is not checked here;
// Validate() schedules IsRequired checks for them beforehand.
func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) ValidationError {
	clock := ValidationCtxClock(ctx)
	skew := ValidationCtxSkew(ctx)

	delta := timeClaim(t, clock, iitr.c1).Sub(timeClaim(t, clock, iitr.c2))

	switch {
	case iitr.less && delta > iitr.dur+skew:
		return NewValidationError(fmt.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
	case !iitr.less && delta < iitr.dur-skew:
		return NewValidationError(fmt.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
	}
	return nil
}
// ValidationError is implemented by every error this package returns when a
// token fails validation. The unexported marker method keeps implementations
// inside this package.
type ValidationError interface {
	error
	isValidationError()
	Unwrap() error
}

// NewValidationError wraps err in a generic ValidationError.
func NewValidationError(err error) ValidationError {
	return &validationError{error: err}
}

// This is a generic validation error.
type validationError struct {
	error
}

func (validationError) isValidationError() {}

// Error returns the wrapped error's message. If the wrapped error is nil, the
// promoted Error() would dereference a nil interface and panic. In that case
// a generic message is returned instead, matching the nil handling of the
// other validation error types in this file.
func (err *validationError) Error() string {
	if err.error == nil {
		return `validation failed`
	}
	return err.error.Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (err *validationError) Unwrap() error {
	return err.error
}
// missingRequiredClaimError reports that the claim named by `claim` was
// required (see IsRequired) but absent from the token.
type missingRequiredClaimError struct {
	claim string
}

func (err *missingRequiredClaimError) Error() string {
	return fmt.Sprintf("%q not satisfied: required claim not found", err.claim)
}

// Is matches any *missingRequiredClaimError, regardless of the claim name,
// so errors.Is(err, ErrRequiredClaim()) works for every missing claim.
func (err *missingRequiredClaimError) Is(target error) bool {
	switch target.(type) {
	case *missingRequiredClaimError:
		return true
	default:
		return false
	}
}

func (err *missingRequiredClaimError) isValidationError() {}

// Unwrap returns nil: this error does not wrap another one.
func (*missingRequiredClaimError) Unwrap() error { return nil }
// invalidAudienceError reports a failed "aud" check. The embedded error, when
// present, carries the detailed reason.
type invalidAudienceError struct {
	error
}

// Error prefers the wrapped error's message and falls back to a fixed one
// for the bare sentinel value.
func (err *invalidAudienceError) Error() string {
	if err.error != nil {
		return err.error.Error()
	}
	return `"aud" not satisfied`
}

// Is matches any *invalidAudienceError so that detailed audience failures
// compare equal to the sentinel under errors.Is.
func (err *invalidAudienceError) Is(target error) bool {
	switch target.(type) {
	case *invalidAudienceError:
		return true
	default:
		return false
	}
}

func (err *invalidAudienceError) isValidationError() {}

// Unwrap exposes the detailed reason, which may be nil.
func (err *invalidAudienceError) Unwrap() error { return err.error }
// invalidIssuerError reports a failed "iss" check. The embedded error, when
// present, carries the detailed reason.
type invalidIssuerError struct {
	error
}

// Error prefers the wrapped error's message and falls back to a fixed one
// for the bare sentinel value.
func (err *invalidIssuerError) Error() string {
	if err.error != nil {
		return err.error.Error()
	}
	return `"iss" not satisfied`
}

// Is matches any *invalidIssuerError so that detailed issuer failures
// compare equal to the sentinel under errors.Is.
func (err *invalidIssuerError) Is(target error) bool {
	switch target.(type) {
	case *invalidIssuerError:
		return true
	default:
		return false
	}
}

func (err *invalidIssuerError) isValidationError() {}

// Unwrap exposes the detailed reason, which may be nil.
func (err *invalidIssuerError) Unwrap() error { return err.error }
// Package-level sentinel errors handed out by the Err* accessors below.
// They are compared by identity or via the types' Is methods.
var (
	errTokenExpired     = NewValidationError(fmt.Errorf(`"exp" not satisfied`))
	errInvalidIssuedAt  = NewValidationError(fmt.Errorf(`"iat" not satisfied`))
	errTokenNotYetValid = NewValidationError(fmt.Errorf(`"nbf" not satisfied`))
	errInvalidAudience  = &invalidAudienceError{}
	errInvalidIssuer    = &invalidIssuerError{}
	errRequiredClaim    = &missingRequiredClaimError{}
)
// ErrTokenExpired returns the shared, immutable error reported when the
// "exp" claim check fails.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrTokenExpired() ValidationError { return errTokenExpired }

// ErrInvalidIssuedAt returns the shared, immutable error reported when the
// "iat" claim check fails.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrInvalidIssuedAt() ValidationError { return errInvalidIssuedAt }

// ErrTokenNotYetValid returns the shared, immutable error reported when the
// "nbf" claim check fails.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrTokenNotYetValid() ValidationError { return errTokenNotYetValid }

// ErrInvalidAudience returns the shared, immutable error reported when the
// "aud" claim check fails.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrInvalidAudience() ValidationError { return errInvalidAudience }

// ErrInvalidIssuer returns the shared, immutable error reported when the
// "iss" claim check fails.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrInvalidIssuer() ValidationError { return errInvalidIssuer }

// ErrMissingRequiredClaim builds a new error describing the missing claim
// called name.
//
// It was exported by mistake: what callers need is an opaque, immutable value
// to pass to `errors.Is()`, not a constructor. It will be removed in a future
// release.
//
// Deprecated: use ErrRequiredClaim() as the `errors.Is()` target instead.
func ErrMissingRequiredClaim(name string) ValidationError {
	return &missingRequiredClaimError{claim: name}
}

// ErrRequiredClaim returns the shared, immutable error reported when a
// claim demanded by `jwt.IsRequired()` is absent.
//
// Use the returned value only as a target for `errors.Is()`.
func ErrRequiredClaim() ValidationError { return errRequiredClaim }
// Validator is implemented by anything that can inspect a Token and report a
// ValidationError when a condition it enforces does not hold.
type Validator interface {
	// Validate should return an error if a required condition is not met,
	// and nil otherwise.
	Validate(context.Context, Token) ValidationError
}

// ValidatorFunc lets a stateless function be used wherever a Validator is
// expected.
type ValidatorFunc func(context.Context, Token) ValidationError

// Validate forwards to the wrapped function.
func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) ValidationError {
	return vf(ctx, tok)
}
// Context keys under which Validate stores its settings. They are distinct
// unexported empty struct types, so other packages cannot collide with them.
type (
	identValidationCtxClock      struct{}
	identValidationCtxSkew       struct{}
	identValidationCtxTruncation struct{}
)

// SetValidationCtxClock returns a derived context carrying cl as the clock
// used by validators.
func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
	return context.WithValue(ctx, identValidationCtxClock{}, cl)
}

// SetValidationCtxTruncation returns a derived context carrying dur as the
// truncation applied to timestamps by validators.
func SetValidationCtxTruncation(ctx context.Context, dur time.Duration) context.Context {
	return context.WithValue(ctx, identValidationCtxTruncation{}, dur)
}

// SetValidationCtxSkew returns a derived context carrying dur as the
// acceptable clock skew used by validators.
func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
	return context.WithValue(ctx, identValidationCtxSkew{}, dur)
}

// ValidationCtxClock returns the Clock object associated with
// the current validation context. This value will always be available
// during validation of tokens. It panics (failed type assertion) if ctx
// holds no Clock under this key.
func ValidationCtxClock(ctx context.Context) Clock {
	//nolint:forcetypeassert
	return ctx.Value(identValidationCtxClock{}).(Clock)
}

// ValidationCtxSkew returns the acceptable skew stored by
// SetValidationCtxSkew. It panics if ctx holds no time.Duration under this key.
func ValidationCtxSkew(ctx context.Context) time.Duration {
	//nolint:forcetypeassert
	return ctx.Value(identValidationCtxSkew{}).(time.Duration)
}

// ValidationCtxTruncation returns the truncation stored by
// SetValidationCtxTruncation. It panics if ctx holds no time.Duration under
// this key.
func ValidationCtxTruncation(ctx context.Context) time.Duration {
	//nolint:forcetypeassert
	return ctx.Value(identValidationCtxTruncation{}).(time.Duration)
}
// IsExpirationValid returns the "exp" validator. It is one of the default
// validators run by Validate, so users never need to pass it explicitly; it
// is exported so that you can check what it does.
//
// The supplied context.Context object must have the clock, skew and
// truncation populated using SetValidationCtxClock(),
// SetValidationCtxSkew() and SetValidationCtxTruncation().
func IsExpirationValid() Validator {
	return ValidatorFunc(isExpirationValid)
}

// isExpirationValid fails once the (truncated) current time reaches the
// (truncated) expiration plus skew. A zero time or Unix time 0 skips the check.
func isExpirationValid(ctx context.Context, t Token) ValidationError {
	exp := t.Expiration()
	if exp.IsZero() || exp.Unix() == 0 {
		return nil
	}

	clock := ValidationCtxClock(ctx)
	skew := ValidationCtxSkew(ctx)
	trunc := ValidationCtxTruncation(ctx)

	now := clock.Now().Truncate(trunc)
	deadline := exp.Truncate(trunc).Add(skew)
	if now.Before(deadline) {
		return nil
	}
	return ErrTokenExpired()
}
// IsIssuedAtValid returns the "iat" validator. It is one of the default
// validators run by Validate, so users never need to pass it explicitly; it
// is exported so that you can check what it does.
//
// The supplied context.Context object must have the clock, skew and
// truncation populated using SetValidationCtxClock(),
// SetValidationCtxSkew() and SetValidationCtxTruncation().
func IsIssuedAtValid() Validator {
	return ValidatorFunc(isIssuedAtValid)
}

// isIssuedAtValid rejects tokens whose "iat" lies more than `skew` in the
// future relative to the validation clock (both sides truncated). A zero time
// or Unix time 0 skips the check.
func isIssuedAtValid(ctx context.Context, t Token) ValidationError {
	tv := t.IssuedAt()
	if tv.IsZero() || tv.Unix() == 0 {
		return nil
	}

	clock := ValidationCtxClock(ctx)      // MUST be populated
	skew := ValidationCtxSkew(ctx)        // MUST be populated
	trunc := ValidationCtxTruncation(ctx) // MUST be populated

	// A leftover debug log.Printf used to run here. It wrote the current time,
	// the token's iat and the validation settings to the process-wide logger
	// on every validation.
	now := clock.Now().Truncate(trunc)
	ttv := tv.Truncate(trunc)
	if now.Before(ttv.Add(-skew)) {
		return ErrInvalidIssuedAt()
	}
	return nil
}

// The debug trace above was the only user of the "log" import. This blank
// reference keeps the file compiling until that import is removed.
// TODO: delete this line together with the "log" import.
var _ = log.Printf
// IsNbfValid returns the "nbf" validator. It is one of the default
// validators run by Validate, so users never need to pass it explicitly; it
// is exported so that you can check what it does.
//
// The supplied context.Context object must have the clock, skew and
// truncation populated using SetValidationCtxClock(),
// SetValidationCtxSkew() and SetValidationCtxTruncation().
func IsNbfValid() Validator {
	return ValidatorFunc(isNbfValid)
}

// isNbfValid rejects tokens used before "nbf" minus skew (both sides
// truncated). A zero time or Unix time 0 skips the check.
func isNbfValid(ctx context.Context, t Token) ValidationError {
	nbf := t.NotBefore()
	if nbf.IsZero() || nbf.Unix() == 0 {
		return nil
	}

	clock := ValidationCtxClock(ctx)
	skew := ValidationCtxSkew(ctx)
	trunc := ValidationCtxTruncation(ctx)

	// Truncate runs even when trunc is 0, because it also strips the
	// monotonic clock reading from the current time.
	now := clock.Now().Truncate(trunc)
	earliest := nbf.Truncate(trunc).Add(-skew)
	if !now.Before(earliest) {
		return nil
	}
	return ErrTokenNotYetValid()
}
// claimContainsString is the Validator behind ClaimContainsString and the
// audience check. makeErr chooses the error type a failure is reported as.
type claimContainsString struct {
	name    string
	value   string
	makeErr func(error) ValidationError
}

// ClaimContainsString returns a Validator that checks whether the claim
// called name, expected to hold a list of strings, contains value. Failures
// are reported as generic validation errors.
func ClaimContainsString(name, value string) Validator {
	return claimContainsString{name: name, value: value, makeErr: NewValidationError}
}
// IsValidationError returns true if the error is a validation error
func IsValidationError(err error) bool {
switch err {
case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
return true
default:
switch err.(type) {
case *validationError, *invalidAudienceError, *invalidIssuerError, *missingRequiredClaimError:
return true
default:
return false
}
}
}
// Validate checks that the claim ccs.name exists, is a list of strings, and
// contains ccs.value.
//
// Both []string (as used for "aud") and []interface{} are accepted. The
// latter is the shape generic JSON decoding gives to array-valued claims. For
// it, only string elements are compared and other elements are skipped.
// Any other claim type is reported as an error.
func (ccs claimContainsString) Validate(_ context.Context, t Token) ValidationError {
	v, ok := t.Get(ccs.name)
	if !ok {
		return ccs.makeErr(fmt.Errorf(`claim %q not found`, ccs.name))
	}

	switch list := v.(type) {
	case []string:
		for _, s := range list {
			if s == ccs.value {
				return nil
			}
		}
	case []interface{}:
		for _, elem := range list {
			if s, isString := elem.(string); isString && s == ccs.value {
				return nil
			}
		}
	default:
		return ccs.makeErr(fmt.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
	}
	return ccs.makeErr(fmt.Errorf(`%q not satisfied`, ccs.name))
}
// makeInvalidAudienceError wraps err as an *invalidAudienceError, so that
// audience failures match ErrInvalidAudience() under errors.Is.
func makeInvalidAudienceError(err error) ValidationError {
	return &invalidAudienceError{error: err}
}

// audienceClaimContainsString returns a Validator checking that the "aud"
// claim, expected to be a list of strings, contains value. Failures are
// reported as *invalidAudienceError.
func audienceClaimContainsString(value string) Validator {
	return claimContainsString{name: AudienceKey, value: value, makeErr: makeInvalidAudienceError}
}
// claimValueIs is the Validator behind ClaimValueIs and the issuer check.
// makeErr chooses the error type a failure is reported as.
type claimValueIs struct {
	name    string
	value   interface{}
	makeErr func(error) ValidationError
}

// ClaimValueIs creates a Validator that checks if the value of claim `name`
// matches `value`. The comparison uses Go's `==` on interface values, so
// complex values (slices, maps, structs containing them) never match. Such
// comparisons report a mismatch rather than panicking. If you need deeper
// comparisons, use a custom Validator.
func ClaimValueIs(name string, value interface{}) Validator {
	return &claimValueIs{
		name:    name,
		value:   value,
		makeErr: NewValidationError,
	}
}

// Validate fails when the claim is absent or its value differs from cv.value.
func (cv *claimValueIs) Validate(_ context.Context, t Token) ValidationError {
	v, ok := t.Get(cv.name)
	if !ok {
		return cv.makeErr(fmt.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
	}
	if !claimValuesEqual(v, cv.value) {
		return cv.makeErr(fmt.Errorf(`%q not satisfied: values do not match`, cv.name))
	}
	return nil
}

// claimValuesEqual reports whether a == b. Comparing two interface values
// whose shared dynamic type is not comparable (for example []interface{}
// decoded from a token) panics at run time. Because one operand comes from
// untrusted token data, that panic is recovered and treated as "not equal".
func claimValuesEqual(a, b interface{}) (equal bool) {
	defer func() {
		if recover() != nil {
			equal = false
		}
	}()
	return a == b
}
// makeIssuerClaimError wraps err as an *invalidIssuerError, so that issuer
// failures match ErrInvalidIssuer() under errors.Is.
func makeIssuerClaimError(err error) ValidationError {
	return &invalidIssuerError{error: err}
}

// issuerClaimValueIs returns a Validator checking that the "iss" claim equals
// value. Failures are reported as *invalidIssuerError.
func issuerClaimValueIs(value string) Validator {
	return &claimValueIs{name: IssuerKey, value: value, makeErr: makeIssuerClaimError}
}
// IsRequired creates a Validator that fails with a missing-required-claim
// error (matching ErrRequiredClaim() under errors.Is) when the claim called
// name is absent from the token.
func IsRequired(name string) Validator {
	return isRequired(name)
}

// isRequired is a claim name used directly as a Validator.
type isRequired string

// Validate succeeds if the token has the claim, whatever its value.
func (ir isRequired) Validate(_ context.Context, t Token) ValidationError {
	claim := string(ir)
	if _, found := t.Get(claim); found {
		return nil
	}
	return &missingRequiredClaimError{claim: claim}
}