auth/vendor/github.com/lestrrat-go/jwx/jwt/serialize.go

240 lines
5.8 KiB
Go

package jwt
import (
"fmt"
"github.com/lestrrat-go/jwx/internal/json"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwe"
"github.com/lestrrat-go/jwx/jws"
"github.com/pkg/errors"
)
type SerializeCtx interface {
Step() int
Nested() bool
}
type serializeCtx struct {
step int
nested bool
}
func (ctx *serializeCtx) Step() int {
return ctx.step
}
func (ctx *serializeCtx) Nested() bool {
return ctx.nested
}
type SerializeStep interface {
Serialize(SerializeCtx, interface{}) (interface{}, error)
}
// Serializer is a generic serializer for JWTs. Whereas other conveinience
// functions can only do one thing (such as generate a JWS signed JWT),
// Using this construct you can serialize the token however you want.
//
// By default the serializer only marshals the token into a JSON payload.
// You must set up the rest of the steps that should be taken by the
// serializer.
//
// For example, to marshal the token into JSON, then apply JWS and JWE
// in that order, you would do:
//
// serialized, err := jwt.NewSerialer().
// Sign(jwa.RS256, key).
// Encrypt(jwa.RSA_OAEP, key.PublicKey).
// Serialize(token)
//
// The `jwt.Sign()` function is equivalent to
//
// serialized, err := jwt.NewSerializer().
// Sign(...args...).
// Serialize(token)
type Serializer struct {
steps []SerializeStep
}
// NewSerializer creates a new empty serializer.
func NewSerializer() *Serializer {
return &Serializer{}
}
// Reset clears all of the registered steps.
func (s *Serializer) Reset() *Serializer {
s.steps = nil
return s
}
// Step adds a new Step to the serialization process
func (s *Serializer) Step(step SerializeStep) *Serializer {
s.steps = append(s.steps, step)
return s
}
type jsonSerializer struct{}
func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
token, ok := v.(Token)
if !ok {
return nil, errors.Errorf(`invalid input: expected jwt.Token`)
}
buf, err := json.Marshal(token)
if err != nil {
return nil, errors.Errorf(`failed to serialize as JSON`)
}
return buf, nil
}
type genericHeader interface {
Get(string) (interface{}, bool)
Set(string, interface{}) error
}
func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
// cty and typ are common between JWE/JWS, so we don't use
// the constants in jws/jwe package here
const typKey = `typ`
const ctyKey = `cty`
if ctx.Step() == 1 {
// We are executed immediately after json marshaling
if _, ok := hdrs.Get(typKey); !ok {
if err := hdrs.Set(typKey, `JWT`); err != nil {
return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
}
}
} else {
if ctx.Nested() {
// If this is part of a nested sequence, we should set cty = 'JWT'
// https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
if err := hdrs.Set(ctyKey, `JWT`); err != nil {
return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
}
}
}
return nil
}
type jwsSerializer struct {
alg jwa.SignatureAlgorithm
key interface{}
options []SignOption
}
func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
payload, ok := v.([]byte)
if !ok {
return nil, errors.New(`expected []byte as input`)
}
var hdrs jws.Headers
//nolint:forcetypeassert
for _, option := range s.options {
switch option.Ident() {
case identJwsHeaders{}:
hdrs = option.Value().(jws.Headers)
}
}
if hdrs == nil {
hdrs = jws.NewHeaders()
}
if err := setTypeOrCty(ctx, hdrs); err != nil {
return nil, err // this is already wrapped
}
// JWTs MUST NOT use b64 = false
// https://datatracker.ietf.org/doc/html/rfc7797#section-7
if v, ok := hdrs.Get("b64"); ok {
if bval, bok := v.(bool); bok {
if !bval { // b64 = false
return nil, errors.New(`b64 cannot be false for JWTs`)
}
}
}
return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
}
func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
return s.Step(&jwsSerializer{
alg: alg,
key: key,
options: options,
})
}
type jweSerializer struct {
keyalg jwa.KeyEncryptionAlgorithm
key interface{}
contentalg jwa.ContentEncryptionAlgorithm
compressalg jwa.CompressionAlgorithm
options []EncryptOption
}
func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
payload, ok := v.([]byte)
if !ok {
return nil, fmt.Errorf(`expected []byte as input`)
}
var hdrs jwe.Headers
//nolint:forcetypeassert
for _, option := range s.options {
switch option.Ident() {
case identJweHeaders{}:
hdrs = option.Value().(jwe.Headers)
}
}
if hdrs == nil {
hdrs = jwe.NewHeaders()
}
if err := setTypeOrCty(ctx, hdrs); err != nil {
return nil, err // this is already wrapped
}
return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
}
func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
return s.Step(&jweSerializer{
keyalg: keyalg,
key: key,
contentalg: contentalg,
compressalg: compressalg,
options: options,
})
}
func (s *Serializer) Serialize(t Token) ([]byte, error) {
steps := make([]SerializeStep, len(s.steps)+1)
steps[0] = jsonSerializer{}
for i, step := range s.steps {
steps[i+1] = step
}
var ctx serializeCtx
ctx.nested = len(s.steps) > 1
var payload interface{} = t
for i, step := range steps {
ctx.step = i
v, err := step.Serialize(&ctx, payload)
if err != nil {
return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
}
payload = v
}
res, ok := payload.([]byte)
if !ok {
return nil, errors.New(`invalid serialization produced`)
}
return res, nil
}