auth/vendor/github.com/lestrrat-go/jwx/v2/jwe/decrypt.go

301 lines
8.2 KiB
Go

package jwe
import (
"crypto/aes"
cryptocipher "crypto/cipher"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"fmt"
"hash"
"golang.org/x/crypto/pbkdf2"
"github.com/lestrrat-go/jwx/v2/internal/keyconv"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwe/internal/cipher"
"github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt"
"github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc"
"github.com/lestrrat-go/jwx/v2/x25519"
)
// decrypter gathers the components needed to decrypt a JWE message: the key
// material, the algorithms, and the per-message values set through the
// setter methods below.
// Its operation is not concurrency safe; callers must provide their own locking.
//nolint:govet
type decrypter struct {
// aad is the optional additional authenticated data. When non-nil, Decrypt
// joins it to computedAad with a '.' separator.
aad []byte
// apu is handed to keyenc.NewECDHESDecrypt for the ECDH-ES key algorithms.
apu []byte
// apv is handed to keyenc.NewECDHESDecrypt for the ECDH-ES key algorithms.
apv []byte
// computedAad is the base authenticated data passed to the content cipher.
computedAad []byte
// iv is the initialization vector passed to the content cipher.
iv []byte
// keyiv is the nonce used to unwrap the key with AES-GCM (A*GCMKW); it must
// be 12 bytes long.
keyiv []byte
// keysalt is appended to the algorithm name and a zero byte to form the
// PBKDF2 salt for the PBES2 key algorithms.
keysalt []byte
// keytag is the AES-GCM authentication tag of the wrapped key (A*GCMKW); it
// must be 16 bytes long.
keytag []byte
// tag is the authentication tag passed to the content cipher.
tag []byte
// privkey is the raw decryption key: a []byte for symmetric key algorithms,
// otherwise a private key accepted by keyconv (or passed through as-is for
// x25519).
privkey interface{}
// pubkey is the raw public key used by the ECDH-ES key algorithms.
pubkey interface{}
// ctalg selects the content cipher built by ContentCipher.
ctalg jwa.ContentEncryptionAlgorithm
// keyalg selects how the content encryption key is recovered.
keyalg jwa.KeyEncryptionAlgorithm
// cipher caches the content cipher once ContentCipher has built it.
cipher content_crypt.Cipher
// keycount is the iteration count passed to PBKDF2 for the PBES2 key
// algorithms.
keycount int
}
// newDecrypter creates a decrypter for the given key encryption algorithm,
// content encryption algorithm, and key. The remaining parameters must be
// supplied through the setter methods before Decrypt is called.
//
// privkey must be a key in its "raw" format (i.e. something like
// *rsa.PrivateKey, instead of jwk.Key).
//
// Treat the returned object as immutable once all values have been assigned.
func newDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *decrypter {
	d := new(decrypter)
	d.keyalg = keyalg
	d.ctalg = ctalg
	d.privkey = privkey
	return d
}
// AgreementPartyUInfo stores the apu value that BuildKeyDecrypter forwards to
// the ECDH-ES key decrypter. It returns the receiver so calls can be chained.
func (d *decrypter) AgreementPartyUInfo(apu []byte) *decrypter {
	d.apu = apu
	return d
}
// AgreementPartyVInfo stores the apv value that BuildKeyDecrypter forwards to
// the ECDH-ES key decrypter. It returns the receiver so calls can be chained.
func (d *decrypter) AgreementPartyVInfo(apv []byte) *decrypter {
	d.apv = apv
	return d
}
// AuthenticatedData stores the optional additional authenticated data. When
// it is non-nil, Decrypt joins it to the computed authenticated data with a
// '.' separator. It returns the receiver so calls can be chained.
func (d *decrypter) AuthenticatedData(aad []byte) *decrypter {
	d.aad = aad
	return d
}
// ComputedAuthenticatedData stores the base authenticated data that Decrypt
// passes to the content cipher. It returns the receiver so calls can be
// chained.
func (d *decrypter) ComputedAuthenticatedData(aad []byte) *decrypter {
	d.computedAad = aad
	return d
}
// ContentEncryptionAlgorithm replaces the content encryption algorithm given
// to newDecrypter. It returns the receiver so calls can be chained.
//
// ContentCipher caches the cipher it builds, so calling this after
// ContentCipher has already run does not replace the cached cipher.
func (d *decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *decrypter {
	d.ctalg = ctalg
	return d
}
// InitializationVector stores the initialization vector that Decrypt passes
// to the content cipher. It returns the receiver so calls can be chained.
func (d *decrypter) InitializationVector(iv []byte) *decrypter {
	d.iv = iv
	return d
}
// KeyCount stores the iteration count passed to PBKDF2 when the key algorithm
// is one of the PBES2 variants. It returns the receiver so calls can be
// chained.
func (d *decrypter) KeyCount(keycount int) *decrypter {
	d.keycount = keycount
	return d
}
// KeyInitializationVector stores the nonce used to unwrap the key for the
// AES-GCM key wrap algorithms; key decryption rejects any value that is not
// 12 bytes long. It returns the receiver so calls can be chained.
func (d *decrypter) KeyInitializationVector(keyiv []byte) *decrypter {
	d.keyiv = keyiv
	return d
}
// KeySalt stores the salt input for the PBES2 key algorithms. During key
// decryption it is appended to the algorithm name and a zero byte to form the
// PBKDF2 salt. It returns the receiver so calls can be chained.
func (d *decrypter) KeySalt(keysalt []byte) *decrypter {
	d.keysalt = keysalt
	return d
}
// KeyTag stores the authentication tag of the wrapped key for the AES-GCM key
// wrap algorithms; key decryption rejects any value that is not 16 bytes
// long. It returns the receiver so calls can be chained.
func (d *decrypter) KeyTag(keytag []byte) *decrypter {
	d.keytag = keytag
	return d
}
// PublicKey stores the public key used when decoding EC based encryptions.
// The key must be in its "raw" format (i.e. *ecdsa.PublicKey, instead of
// jwk.Key). It returns the receiver so calls can be chained.
func (d *decrypter) PublicKey(pubkey interface{}) *decrypter {
	d.pubkey = pubkey
	return d
}
// Tag stores the authentication tag that Decrypt passes to the content
// cipher. It returns the receiver so calls can be chained.
func (d *decrypter) Tag(tag []byte) *decrypter {
	d.tag = tag
	return d
}
// ContentCipher returns the content cipher for the configured content
// encryption algorithm. The cipher is built on first use and cached on the
// decrypter; later calls return the cached instance.
//
// An error is returned when the algorithm is not one of the AES-GCM or
// AES-CBC-HMAC variants, or when the cipher cannot be constructed.
func (d *decrypter) ContentCipher() (content_crypt.Cipher, error) {
	if d.cipher != nil {
		return d.cipher, nil
	}

	switch d.ctalg {
	case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512:
		// Supported; the cipher is built below.
	default:
		return nil, fmt.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg)
	}

	aesCipher, err := cipher.NewAES(d.ctalg)
	if err != nil {
		return nil, fmt.Errorf(`failed to build content cipher for %s: %w`, d.ctalg, err)
	}
	d.cipher = aesCipher
	return d.cipher, nil
}
// Decrypt recovers the content encryption key from recipientKey and uses it
// to decrypt ciphertext with the configured content cipher, initialization
// vector, tag, and authenticated data.
//
// When additional authenticated data has been set, the data handed to the
// cipher is the computed authenticated data, a '.', and the additional data.
//
// On any failure the returned plaintext is nil and err wraps the cause.
func (d *decrypter) Decrypt(recipientKey, ciphertext []byte) (plaintext []byte, err error) {
	cek, err := d.DecryptKey(recipientKey)
	if err != nil {
		return nil, fmt.Errorf(`failed to decrypt key: %w`, err)
	}

	contentCipher, err := d.ContentCipher()
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch content crypt cipher: %w`, err)
	}

	computedAad := d.computedAad
	if d.aad != nil {
		// Build the joined value in a fresh buffer. Appending directly to
		// d.computedAad could write into spare capacity of the caller's
		// backing array and corrupt data that shares it.
		joined := make([]byte, 0, len(d.computedAad)+1+len(d.aad))
		joined = append(joined, d.computedAad...)
		joined = append(joined, '.')
		joined = append(joined, d.aad...)
		computedAad = joined
	}

	plaintext, err = contentCipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad)
	if err != nil {
		// Never hand back whatever the cipher produced alongside an error.
		return nil, fmt.Errorf(`failed to decrypt payload: %w`, err)
	}
	return plaintext, nil
}
// decryptSymmetricKey recovers the content encryption key for the symmetric
// key algorithms.
//
// recipientKey is the encrypted key taken from the message. cek is the shared
// secret held by the decrypter; depending on the algorithm it is returned
// as-is (DIRECT), used as the PBKDF2 password (PBES2), or used as the AES key
// that unwraps recipientKey (A*KW, A*GCMKW).
//
// An error is returned when the algorithm is unsupported, when the GCM nonce
// or tag has the wrong length, or when the AES cipher or unwrap step fails.
func (d *decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) {
	switch d.keyalg {
	case jwa.DIRECT:
		// The shared key is used directly as the content encryption key.
		return cek, nil
	case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
		var hashFunc func() hash.Hash
		var keylen int
		switch d.keyalg {
		case jwa.PBES2_HS256_A128KW:
			hashFunc = sha256.New
			keylen = 16
		case jwa.PBES2_HS384_A192KW:
			hashFunc = sha512.New384
			keylen = 24
		case jwa.PBES2_HS512_A256KW:
			hashFunc = sha512.New
			keylen = 32
		}
		// The PBKDF2 salt is the algorithm name, a zero byte, then the salt
		// input supplied through KeySalt.
		salt := []byte(d.keyalg)
		salt = append(salt, byte(0))
		salt = append(salt, d.keysalt...)
		// NOTE(review): d.keycount is used without an upper bound here;
		// presumably the caller limits it before it reaches this point —
		// confirm, since a very large count makes this call expensive.
		cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc)
		// The derived key unwraps recipientKey exactly like the plain AES
		// key wrap algorithms below.
		fallthrough
	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
		block, err := aes.NewCipher(cek)
		if err != nil {
			return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err)
		}
		jek, err := keyenc.Unwrap(block, recipientKey)
		if err != nil {
			return nil, fmt.Errorf(`failed to unwrap key: %w`, err)
		}
		return jek, nil
	case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
		if len(d.keyiv) != 12 {
			return nil, fmt.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8)
		}
		if len(d.keytag) != 16 {
			return nil, fmt.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8)
		}
		block, err := aes.NewCipher(cek)
		if err != nil {
			return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err)
		}
		aesgcm, err := cryptocipher.NewGCM(block)
		if err != nil {
			return nil, fmt.Errorf(`failed to create new GCM wrap: %w`, err)
		}
		// Open expects the tag to follow the ciphertext. Assemble the two in
		// a fresh buffer: slicing recipientKey does not copy it, so appending
		// to it could overwrite spare capacity in the caller's backing array.
		sealed := make([]byte, 0, len(recipientKey)+len(d.keytag))
		sealed = append(sealed, recipientKey...)
		sealed = append(sealed, d.keytag...)
		jek, err := aesgcm.Open(nil, d.keyiv, sealed, nil)
		if err != nil {
			return nil, fmt.Errorf(`failed to decode key: %w`, err)
		}
		return jek, nil
	default:
		return nil, fmt.Errorf("decrypt key: unsupported algorithm %s", d.keyalg)
	}
}
// DecryptKey recovers the content encryption key from recipientKey.
//
// For symmetric key algorithms the decrypter's key must be a []byte and the
// work is delegated to decryptSymmetricKey. For every other algorithm a key
// decrypter is built with BuildKeyDecrypter and applied to recipientKey.
func (d *decrypter) DecryptKey(recipientKey []byte) (cek []byte, err error) {
	if !d.keyalg.IsSymmetric() {
		keyDecrypter, buildErr := d.BuildKeyDecrypter()
		if buildErr != nil {
			return nil, fmt.Errorf(`failed to build key decrypter: %w`, buildErr)
		}

		decrypted, decryptErr := keyDecrypter.Decrypt(recipientKey)
		if decryptErr != nil {
			return nil, fmt.Errorf(`failed to decrypt key: %w`, decryptErr)
		}
		return decrypted, nil
	}

	sharedKey, isBytes := d.privkey.([]byte)
	if !isBytes {
		return nil, fmt.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey)
	}
	return d.decryptSymmetricKey(recipientKey, sharedKey)
}
// BuildKeyDecrypter constructs the key decrypter matching the configured key
// encryption algorithm.
//
// The content cipher is fetched first, so an invalid content encryption
// algorithm is reported before the key algorithm is examined. The decrypter's
// keys are then converted to the type each algorithm needs:
//   - RSA1_5, RSA-OAEP, RSA-OAEP-256: an RSA private key
//   - A128KW, A192KW, A256KW: a []byte shared key
//   - ECDH-ES variants: an x25519 public key with the private key passed
//     through unchanged, or otherwise an ECDSA public/private key pair
//
// Any other algorithm results in an error.
func (d *decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) {
	contentCipher, err := d.ContentCipher()
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch content crypt cipher: %w`, err)
	}

	alg := d.keyalg
	switch alg {
	case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256:
		var rsaKey rsa.PrivateKey
		if err := keyconv.RSAPrivateKey(&rsaKey, d.privkey); err != nil {
			return nil, fmt.Errorf(`*rsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err)
		}
		if alg == jwa.RSA1_5 {
			return keyenc.NewRSAPKCS15Decrypt(alg, &rsaKey, contentCipher.KeySize()/2), nil
		}
		return keyenc.NewRSAOAEPDecrypt(alg, &rsaKey)
	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
		sharedKey, isBytes := d.privkey.([]byte)
		if !isBytes {
			return nil, fmt.Errorf("[]byte is required as the key to build %s key decrypter", alg)
		}
		return keyenc.NewAES(alg, sharedKey)
	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
		// x25519 keys are handed over untouched; everything else must be
		// convertible to an ECDSA key pair.
		if _, isX25519 := d.pubkey.(x25519.PublicKey); isX25519 {
			return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil
		}

		var ecPub ecdsa.PublicKey
		if err := keyconv.ECDSAPublicKey(&ecPub, d.pubkey); err != nil {
			return nil, fmt.Errorf(`*ecdsa.PublicKey is required as the key to build %s key decrypter: %w`, alg, err)
		}

		var ecPriv ecdsa.PrivateKey
		if err := keyconv.ECDSAPrivateKey(&ecPriv, d.privkey); err != nil {
			return nil, fmt.Errorf(`*ecdsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err)
		}
		return keyenc.NewECDHESDecrypt(alg, d.ctalg, &ecPub, d.apu, d.apv, &ecPriv), nil
	default:
		return nil, fmt.Errorf(`unsupported algorithm for key decryption (%s)`, alg)
	}
}