auth/vendor/github.com/lestrrat-go/jwx/jwe/encrypt.go

146 lines
4.1 KiB
Go

package jwe
import (
"context"
"sync"
"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/jwa"
"github.com/pkg/errors"
)
// encryptCtxPool recycles *encryptCtx values so repeated encryptions can
// reuse a context instead of allocating a new one. When the pool is empty,
// New hands back a freshly zero-valued context.
var encryptCtxPool = sync.Pool{
	New: func() interface{} { return new(encryptCtx) },
}
// getEncryptCtx takes an *encryptCtx out of encryptCtxPool. Once the caller
// is finished with it, the context should be handed back through
// releaseEncryptCtx, which clears it and returns it to the pool.
func getEncryptCtx() *encryptCtx {
	pooled := encryptCtxPool.Get()
	ctx := pooled.(*encryptCtx) //nolint:forcetypeassert
	return ctx
}
// releaseEncryptCtx resets the fields populated for an encryption and puts
// ctx back into encryptCtxPool. Callers must not touch ctx after this
// returns, because the pool may hand the same value to another caller.
func releaseEncryptCtx(ctx *encryptCtx) {
	// Drop references first so a pooled context does not keep headers or
	// encrypters from a previous encryption alive.
	ctx.keyEncrypters = nil
	ctx.generator = nil
	ctx.contentEncrypter = nil
	ctx.protected = nil
	ctx.compress = jwa.NoCompress

	encryptCtxPool.Put(ctx)
}
// Encrypt takes the plaintext and encrypts it into a JWE message.
//
// The steps are:
//
//  1. Validate the recipient set: at least one key encrypter is required,
//     and ECDH-ES / "dir" may only be used with a single recipient.
//  2. Generate a content encryption key (CEK) with e.generator.
//  3. Record "enc" (and "zip" when compression is enabled) in the protected
//     header.
//  4. For each key encrypter, build a Recipient carrying "alg" (and "kid"
//     when non-empty). The encrypted CEK is stored in that Recipient. For
//     ECDH-ES and "dir", the encrypter's output replaces the CEK instead.
//  5. With exactly one recipient, merge its headers into the protected
//     header.
//  6. Use the encoded protected header as the AAD, compress the plaintext if
//     requested, and encrypt it with e.contentEncrypter.
//
// e is a value receiver, so reassigning e.protected below does not change
// the caller's context.
//
// NOTE(review): Set is called on e.protected in place. If Headers is
// reference-backed, a caller-supplied header object is mutated. Confirm
// this is intended.
func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
	if len(e.keyEncrypters) == 0 {
		// Without any recipient nobody could recover the CEK, so the
		// resulting message would be undecryptable.
		return nil, errors.New("no key encrypters specified")
	}
	if len(e.keyEncrypters) > 1 {
		// ECDH-ES (direct key agreement) and "dir" turn the encrypter's
		// output into the CEK itself, which cannot be shared by several
		// recipients. Reject this before doing any key work.
		for _, enc := range e.keyEncrypters {
			switch alg := enc.Algorithm(); alg {
			case jwa.ECDH_ES, jwa.DIRECT:
				return nil, errors.Errorf("unable to support multiple recipients for %s", alg)
			}
		}
	}

	bk, err := e.generator.Generate()
	if err != nil {
		return nil, errors.Wrap(err, "failed to generate key")
	}
	cek := bk.Bytes()

	if e.protected == nil {
		// shouldn't happen, but...
		e.protected = NewHeaders()
	}

	if err := e.protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil {
		return nil, errors.Wrap(err, `failed to set "enc" in protected header`)
	}

	compression := e.compress
	if compression != jwa.NoCompress {
		if err := e.protected.Set(CompressionKey, compression); err != nil {
			return nil, errors.Wrap(err, `failed to set "zip" in protected header`)
		}
	}

	// In JWE, multiple recipients may exist -- they receive an
	// encrypted version of the CEK, using their key encryption
	// algorithm of choice.
	recipients := make([]Recipient, len(e.keyEncrypters))
	for i, enc := range e.keyEncrypters {
		alg := enc.Algorithm()

		r := NewRecipient()
		if err := r.Headers().Set(AlgorithmKey, alg); err != nil {
			return nil, errors.Wrap(err, "failed to set header")
		}
		if kid := enc.KeyID(); kid != "" {
			if err := r.Headers().Set(KeyIDKey, kid); err != nil {
				return nil, errors.Wrap(err, "failed to set header")
			}
		}

		enckey, err := enc.Encrypt(cek)
		if err != nil {
			return nil, errors.Wrap(err, `failed to encrypt key`)
		}

		switch alg {
		case jwa.ECDH_ES, jwa.DIRECT:
			// The encrypter's output becomes the CEK. No encrypted key is
			// stored for this recipient. This is safe because the validation
			// above guarantees there is only one recipient.
			cek = enckey.Bytes()
		default:
			if err := r.SetEncryptedKey(enckey.Bytes()); err != nil {
				return nil, errors.Wrap(err, "failed to set encrypted key")
			}
		}

		// Key encryption results that implement populater contribute
		// additional fields to this recipient's headers.
		if hp, ok := enckey.(populater); ok {
			if err := hp.Populate(r.Headers()); err != nil {
				return nil, errors.Wrap(err, "failed to populate")
			}
		}
		recipients[i] = r
	}

	// If there's only one recipient, you want to include that in the
	// protected header
	if len(recipients) == 1 {
		h, err := e.protected.Merge(context.TODO(), recipients[0].Headers())
		if err != nil {
			return nil, errors.Wrap(err, "failed to merge protected headers")
		}
		e.protected = h
	}

	// The encoded protected header is the AAD for content encryption. It
	// must be computed after the single-recipient merge above.
	aad, err := e.protected.Encode()
	if err != nil {
		return nil, errors.Wrap(err, "failed to base64 encode protected headers")
	}

	plaintext, err = compress(plaintext, compression)
	if err != nil {
		return nil, errors.Wrap(err, `failed to compress payload before encryption`)
	}

	// ...on the other hand, there's only one content cipher.
	iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad)
	if err != nil {
		return nil, errors.Wrap(err, "failed to encrypt payload")
	}

	msg := NewMessage()

	decodedAad, err := base64.Decode(aad)
	if err != nil {
		return nil, errors.Wrap(err, "failed to decode base64")
	}
	if err := msg.Set(AuthenticatedDataKey, decodedAad); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
	}
	if err := msg.Set(CipherTextKey, ciphertext); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
	}
	if err := msg.Set(InitializationVectorKey, iv); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
	}
	if err := msg.Set(ProtectedHeadersKey, e.protected); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
	}
	if err := msg.Set(RecipientsKey, recipients); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
	}
	if err := msg.Set(TagKey, tag); err != nil {
		return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
	}

	return msg, nil
}