auth/vendor/github.com/lestrrat-go/jwx/jwe/message.go

649 lines
17 KiB
Go

package jwe
import (
"context"
"crypto/ecdsa"
"fmt"
"github.com/lestrrat-go/jwx/internal/json"
"github.com/lestrrat-go/jwx/internal/pool"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/jwa"
"github.com/pkg/errors"
)
// NewRecipient returns a new, empty Recipient whose header set is
// initialized with NewHeaders and whose encrypted key is unset.
func NewRecipient() Recipient {
	r := new(stdRecipient)
	r.headers = NewHeaders()
	return r
}
// SetHeaders replaces the recipient's per-recipient headers with h.
// It always returns nil.
func (r *stdRecipient) SetHeaders(h Headers) error {
	r.headers = h
	return nil
}

// SetEncryptedKey stores v as the recipient's encrypted key. The slice is
// kept as-is (not copied). It always returns nil.
func (r *stdRecipient) SetEncryptedKey(v []byte) error {
	r.encryptedKey = v
	return nil
}

// Headers returns the recipient's per-recipient headers.
func (r *stdRecipient) Headers() Headers {
	return r.headers
}

// EncryptedKey returns the recipient's encrypted key in raw (decoded) form.
// The returned slice is the recipient's own storage, not a copy.
func (r *stdRecipient) EncryptedKey() []byte {
	return r.encryptedKey
}
// recipientMarshalProxy mirrors the JSON shape of a single recipient entry.
// It is the decoding target used by stdRecipient.UnmarshalJSON, which
// pre-populates Headers with a concrete value before unmarshalling into it.
type recipientMarshalProxy struct {
	// Headers receives the per-recipient "header" object.
	Headers Headers `json:"header"`
	// EncryptedKey receives "encrypted_key" still in base64 string form;
	// it is decoded separately after unmarshalling.
	EncryptedKey string `json:"encrypted_key"`
}
// UnmarshalJSON decodes one recipient entry: its per-recipient "header"
// object and its base64-encoded "encrypted_key".
//
// A missing or null "header" leaves the recipient with an empty header set
// (from NewHeaders) rather than a nil Headers value.
func (r *stdRecipient) UnmarshalJSON(buf []byte) error {
	var proxy recipientMarshalProxy
	// Give the decoder a concrete Headers value to unmarshal "header" into.
	proxy.Headers = NewHeaders()
	if err := json.Unmarshal(buf, &proxy); err != nil {
		return errors.Wrap(err, `failed to unmarshal json into recipient`)
	}

	// An explicit `"header": null` resets an interface-typed field to nil
	// (encoding/json semantics). A nil Headers would make later calls such
	// as Headers().Algorithm() or MarshalJSON panic, so fall back to an
	// empty header set instead.
	if proxy.Headers == nil {
		proxy.Headers = NewHeaders()
	}
	r.headers = proxy.Headers

	decoded, err := base64.DecodeString(proxy.EncryptedKey)
	if err != nil {
		return errors.Wrap(err, `failed to decode "encrypted_key"`)
	}
	r.encryptedKey = decoded
	return nil
}
// MarshalJSON renders the recipient as
// {"header":<headers JSON>,"encrypted_key":"<base64 key>"}.
// Both members are always written, even when empty.
func (r *stdRecipient) MarshalJSON() ([]byte, error) {
	headerJSON, err := r.headers.MarshalJSON()
	if err != nil {
		return nil, errors.Wrap(err, `failed to marshal recipient header`)
	}

	encodedKey := base64.EncodeToString(r.encryptedKey)

	const (
		prefix = `{"header":`
		middle = `,"encrypted_key":"`
		suffix = `"}`
	)
	out := make([]byte, 0, len(prefix)+len(headerJSON)+len(middle)+len(encodedKey)+len(suffix))
	out = append(out, prefix...)
	out = append(out, headerJSON...)
	out = append(out, middle...)
	out = append(out, encodedKey...)
	out = append(out, suffix...)
	return out, nil
}
// NewMessage returns a new, empty Message with every member unset.
func NewMessage() *Message {
	return new(Message)
}
// AuthenticatedData returns the message's additional authenticated data
// ("aad") in raw, base64-decoded form. The returned slice is the message's
// own storage, not a copy.
func (m *Message) AuthenticatedData() []byte {
	return m.authenticatedData
}

// CipherText returns the raw (base64-decoded) ciphertext. The returned
// slice is not copied.
func (m *Message) CipherText() []byte {
	return m.cipherText
}

// InitializationVector returns the raw (base64-decoded) content
// initialization vector. The returned slice is not copied.
func (m *Message) InitializationVector() []byte {
	return m.initializationVector
}

// Tag returns the raw (base64-decoded) authentication tag. The returned
// slice is not copied.
func (m *Message) Tag() []byte {
	return m.tag
}

// ProtectedHeaders returns the message's protected headers.
func (m *Message) ProtectedHeaders() Headers {
	return m.protectedHeaders
}

// Recipients returns the message's recipients. The returned slice is the
// message's own storage, not a copy.
func (m *Message) Recipients() []Recipient {
	return m.recipients
}

// UnprotectedHeaders returns the message-level unprotected headers. This
// may be nil, e.g. after UnmarshalJSON of a message whose "unprotected"
// member was absent or empty.
func (m *Message) UnprotectedHeaders() Headers {
	return m.unprotectedHeaders
}
// Member and header parameter names used by this package when reading and
// writing JWE messages (see Message.Set, Message.MarshalJSON and
// doDecryptCtx).
const (
	// Top-level members of the JSON serialization.
	AuthenticatedDataKey    = "aad"
	CipherTextKey           = "ciphertext"
	InitializationVectorKey = "iv"
	ProtectedHeadersKey     = "protected"
	RecipientsKey           = "recipients"
	TagKey                  = "tag"
	UnprotectedHeadersKey   = "unprotected"

	// Per-recipient members. MarshalJSON also writes them at the top level
	// (flattened form) when a message has exactly one recipient.
	HeadersKey      = "header"
	EncryptedKeyKey = "encrypted_key"

	// PBES2 header parameters read during decryption (iteration count and
	// salt). Note that InitializationVectorKey and TagKey are likewise read
	// as header parameters for the AES-GCM key wrapping algorithms.
	CountKey = "p2c"
	SaltKey  = "p2s"
)
// Set assigns the message member named by k.
//
// The raw-byte members (AuthenticatedDataKey, CipherTextKey,
// InitializationVectorKey, TagKey) require a []byte; ProtectedHeadersKey and
// UnprotectedHeadersKey require a Headers value; RecipientsKey requires a
// []Recipient. A value of the wrong type is rejected with an error. Any other
// key is stored as a parameter of the unprotected headers, which are created
// on demand.
func (m *Message) Set(k string, v interface{}) error {
	var byteField *[]byte
	var headerField *Headers

	switch k {
	case AuthenticatedDataKey:
		byteField = &m.authenticatedData
	case CipherTextKey:
		byteField = &m.cipherText
	case InitializationVectorKey:
		byteField = &m.initializationVector
	case TagKey:
		byteField = &m.tag
	case ProtectedHeadersKey:
		headerField = &m.protectedHeaders
	case UnprotectedHeadersKey:
		headerField = &m.unprotectedHeaders
	case RecipientsKey:
		list, ok := v.([]Recipient)
		if !ok {
			return errors.Errorf(`invalid value %T for %s key`, v, k)
		}
		m.recipients = list
		return nil
	default:
		if m.unprotectedHeaders == nil {
			m.unprotectedHeaders = NewHeaders()
		}
		return m.unprotectedHeaders.Set(k, v)
	}

	if byteField != nil {
		raw, ok := v.([]byte)
		if !ok {
			return errors.Errorf(`invalid value %T for %s key`, v, k)
		}
		*byteField = raw
		return nil
	}

	hdrs, ok := v.(Headers)
	if !ok {
		return errors.Errorf(`invalid value %T for %s key`, v, k)
	}
	*headerField = hdrs
	return nil
}
// messageMarshalProxy mirrors the JSON serialization of a whole message and
// is the decoding target used by Message.UnmarshalJSON. Base64-encoded
// members are kept as strings here and decoded afterwards.
type messageMarshalProxy struct {
	AuthenticatedData    string `json:"aad,omitempty"`
	CipherText           string `json:"ciphertext"`
	InitializationVector string `json:"iv,omitempty"`
	// ProtectedHeaders holds the raw JSON of "protected" (a quoted base64
	// string); Message.UnmarshalJSON unquotes, decodes, then parses it.
	ProtectedHeaders json.RawMessage `json:"protected"`
	// Recipients holds each "recipients" entry undecoded.
	Recipients []json.RawMessage `json:"recipients,omitempty"`
	Tag        string            `json:"tag,omitempty"`
	// UnprotectedHeaders must be pre-populated with a concrete Headers value
	// before unmarshalling; Message.UnmarshalJSON does so.
	UnprotectedHeaders Headers `json:"unprotected,omitempty"`
	// For flattened structure. Headers is NOT a Headers type,
	// so that we can detect its presence by checking proxy.Headers != nil
	Headers json.RawMessage `json:"header,omitempty"`
	// EncryptedKey is the flattened form's base64-encoded "encrypted_key".
	EncryptedKey string `json:"encrypted_key,omitempty"`
}
// MarshalJSON serializes the message in JWE JSON serialization.
//
// Empty members are omitted. A message with exactly one recipient is written
// in the flattened form (top-level "header" and, if non-empty,
// "encrypted_key"); otherwise the recipients are written under "recipients".
// The protected headers are written as their encoded string form (via
// Headers.Encode), and the unprotected headers as a JSON object.
//
// The output is assembled by hand because the protected headers must be
// written in encoded form rather than as a nested object. Values written
// through the json.Encoder are followed by a newline, which is insignificant
// JSON whitespace.
func (m *Message) MarshalJSON() ([]byte, error) {
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)
	enc := json.NewEncoder(buf)

	buf.WriteString(`{`)

	// writeKey emits `"key":`, preceded by a comma unless this is the first
	// member. Routing every member through it keeps the separators correct
	// regardless of which members are omitted.
	var wrote bool
	writeKey := func(key string) {
		if wrote {
			buf.WriteString(`,`)
		}
		wrote = true
		fmt.Fprintf(buf, `%#v:`, key)
	}

	if aad := m.AuthenticatedData(); len(aad) > 0 {
		writeKey(AuthenticatedDataKey)
		if err := enc.Encode(base64.EncodeToString(aad)); err != nil {
			return nil, errors.Wrapf(err, `failed to encode %s field`, AuthenticatedDataKey)
		}
	}

	if cipherText := m.CipherText(); len(cipherText) > 0 {
		writeKey(CipherTextKey)
		if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil {
			return nil, errors.Wrapf(err, `failed to encode %s field`, CipherTextKey)
		}
	}

	if iv := m.InitializationVector(); len(iv) > 0 {
		writeKey(InitializationVectorKey)
		if err := enc.Encode(base64.EncodeToString(iv)); err != nil {
			return nil, errors.Wrapf(err, `failed to encode %s field`, InitializationVectorKey)
		}
	}

	if h := m.ProtectedHeaders(); h != nil {
		encodedHeaders, err := h.Encode()
		if err != nil {
			return nil, errors.Wrap(err, `failed to encode protected headers`)
		}
		if len(encodedHeaders) > 2 {
			writeKey(ProtectedHeadersKey)
			fmt.Fprintf(buf, `%#v`, string(encodedHeaders))
		}
	}

	if recipients := m.Recipients(); len(recipients) > 0 {
		if len(recipients) == 1 { // Use flattened format
			writeKey(HeadersKey)
			if err := enc.Encode(recipients[0].Headers()); err != nil {
				return nil, errors.Wrapf(err, `failed to encode %s field`, HeadersKey)
			}
			if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
				writeKey(EncryptedKeyKey)
				if err := enc.Encode(base64.EncodeToString(ek)); err != nil {
					return nil, errors.Wrapf(err, `failed to encode %s field`, EncryptedKeyKey)
				}
			}
		} else {
			writeKey(RecipientsKey)
			if err := enc.Encode(recipients); err != nil {
				return nil, errors.Wrapf(err, `failed to encode %s field`, RecipientsKey)
			}
		}
	}

	if tag := m.Tag(); len(tag) > 0 {
		writeKey(TagKey)
		if err := enc.Encode(base64.EncodeToString(tag)); err != nil {
			return nil, errors.Wrapf(err, `failed to encode %s field`, TagKey)
		}
	}

	if h := m.UnprotectedHeaders(); h != nil {
		unprotected, err := json.Marshal(h)
		if err != nil {
			return nil, errors.Wrap(err, `failed to encode unprotected headers`)
		}
		// Skip an empty object ("{}"). The member is written as a raw JSON
		// object, not as a quoted string, so UnmarshalJSON can decode it
		// back into Headers.
		if len(unprotected) > 2 {
			writeKey(UnprotectedHeadersKey)
			buf.Write(unprotected)
		}
	}

	buf.WriteString(`}`)

	// buf returns to the pool on exit, so hand back an independent copy.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return ret, nil
}
// UnmarshalJSON decodes a message in JWE JSON serialization. It accepts both
// the general form (a "recipients" array) and the flattened form (a single
// recipient given by top-level "header" and/or "encrypted_key").
//
// The "protected" member is required: it is read as a string,
// base64-decoded, and parsed as a header object. When storeProtectedHeaders
// is set, the protected headers are also kept in re-encoded form for later
// use as decryption input. If no recipient is found in either form, a single
// recipient is synthesized from the protected headers.
func (m *Message) UnmarshalJSON(buf []byte) error {
	var proxy messageMarshalProxy
	proxy.UnprotectedHeaders = NewHeaders()
	if err := json.Unmarshal(buf, &proxy); err != nil {
		return errors.Wrap(err, `failed to unmarshal JSON into message`)
	}

	// Get the string value
	var protectedHeadersStr string
	if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil {
		return errors.Wrap(err, `failed to decode protected headers (1)`)
	}

	// It's now in _quoted_ base64 string. Decode it
	protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr)
	if err != nil {
		return errors.Wrap(err, "failed to base64 decoded protected headers buffer")
	}

	h := NewHeaders()
	if err := json.Unmarshal(protectedHeadersRaw, h); err != nil {
		return errors.Wrap(err, `failed to decode protected headers (2)`)
	}

	// A flattened message carries its only recipient in the top-level
	// "header" and "encrypted_key" members; the presence of either one marks
	// the flattened form. "header" is optional, so parse it only when it was
	// actually present -- json.Unmarshal on an absent (nil) RawMessage fails
	// with "unexpected end of JSON input".
	if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
		recipient := NewRecipient()
		hdrs := NewHeaders()
		if len(proxy.Headers) > 0 {
			if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
				return errors.Wrap(err, `failed to decode headers field`)
			}
		}
		if err := recipient.SetHeaders(hdrs); err != nil {
			return errors.Wrap(err, `failed to set new headers`)
		}

		if v := proxy.EncryptedKey; len(v) > 0 {
			buf, err := base64.DecodeString(v)
			if err != nil {
				return errors.Wrap(err, `failed to decode encrypted key`)
			}
			if err := recipient.SetEncryptedKey(buf); err != nil {
				return errors.Wrap(err, `failed to set encrypted key`)
			}
		}

		m.recipients = append(m.recipients, recipient)
	} else {
		for i, recipientbuf := range proxy.Recipients {
			recipient := NewRecipient()
			if err := json.Unmarshal(recipientbuf, recipient); err != nil {
				return errors.Wrapf(err, `failed to decode recipient at index %d`, i)
			}

			m.recipients = append(m.recipients, recipient)
		}
	}

	if src := proxy.AuthenticatedData; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return errors.Wrap(err, `failed to decode "aad"`)
		}
		m.authenticatedData = v
	}

	if src := proxy.CipherText; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return errors.Wrap(err, `failed to decode "ciphertext"`)
		}
		m.cipherText = v
	}

	if src := proxy.InitializationVector; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return errors.Wrap(err, `failed to decode "iv"`)
		}
		m.initializationVector = v
	}

	if src := proxy.Tag; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return errors.Wrap(err, `failed to decode "tag"`)
		}
		m.tag = v
	}

	m.protectedHeaders = h
	if m.storeProtectedHeaders {
		// this is later used for decryption
		m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw)
	}

	// Only keep the unprotected headers when they carry something.
	if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok {
		if !iz.isZero() {
			m.unprotectedHeaders = proxy.UnprotectedHeaders
		}
	}

	if len(m.recipients) == 0 {
		if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil {
			return errors.Wrap(err, `failed to setup recipient`)
		}
	}

	return nil
}
// makeDummyRecipient replaces m's recipients (via m.Set) with a single
// synthesized recipient. The recipient's headers are a clone of protected
// with the ContentEncryptionKey parameter removed. Its encrypted key is
// enckeybuf after base64 decoding.
func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error {
	// Recipients in this case should not contain the content encryption
	// parameter, so drop it from a private copy of the protected headers.
	recipientHeaders, err := protected.Clone(context.TODO())
	if err != nil {
		return errors.Wrap(err, `failed to clone headers`)
	}
	if err = recipientHeaders.Remove(ContentEncryptionKey); err != nil {
		return errors.Wrapf(err, "failed to remove %#v from public header", ContentEncryptionKey)
	}

	decodedKey, err := base64.DecodeString(enckeybuf)
	if err != nil {
		return errors.Wrap(err, `failed to decode encrypted key`)
	}

	dummy := &stdRecipient{
		headers:      recipientHeaders,
		encryptedKey: decodedKey,
	}
	if err = m.Set(RecipientsKey, []Recipient{dummy}); err != nil {
		return errors.Wrapf(err, `failed to set %s`, RecipientsKey)
	}
	return nil
}
// Decrypt decrypts the message with the given key encryption algorithm and
// key, returning the plaintext.
//
// `key` must be a private key in its "raw" format (i.e. something like
// *rsa.PrivateKey, instead of jwk.Key)
//
// This method is marked for deprecation. It will be removed from the API
// in the next major release. You should not rely on this method
// to work 100% of the time, especially when it was obtained via jwe.Parse
// instead of being constructed from scratch by this library.
func (m *Message) Decrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]byte, error) {
	return doDecryptCtx(&decryptCtx{
		alg: alg,
		key: key,
		msg: m,
	})
}
// doDecryptCtx decrypts dctx.msg with dctx.key using the key encryption
// algorithm dctx.alg and returns the plaintext.
//
// How headers are combined:
//   - A jwk.Key is first converted to its raw key.
//   - The shared header is the protected headers merged with the
//     message-level unprotected headers.
//   - Each recipient's own headers are merged on top of the shared header.
//
// How recipients are tried:
//   - Recipients are tried in order. Those whose merged header does not name
//     alg are skipped.
//   - The first recipient that decrypts successfully supplies the result. If
//     the merged header's compression is jwa.Deflate, that recipient must
//     also decompress successfully.
//   - A message without recipients is treated as having one recipient that
//     carries the protected headers.
//
// How errors are handled:
//   - Failing to convert the key, or to clone/merge the shared header, is
//     returned immediately.
//   - A missing or malformed algorithm-specific parameter (epk, iv/tag,
//     p2s/p2c) is also returned immediately.
//   - Per-recipient header, decryption and decompression failures move on to
//     the next recipient. The last such error is reported if none succeeds.
//
// Success is tracked explicitly rather than by a non-nil plaintext. An empty
// payload may legitimately decrypt to a nil slice. A recipient that decrypts
// but fails to decompress must not leak its compressed bytes as a result.
func doDecryptCtx(dctx *decryptCtx) ([]byte, error) {
	m := dctx.msg
	alg := dctx.alg
	key := dctx.key

	if jwkKey, ok := key.(jwk.Key); ok {
		var raw interface{}
		if err := jwkKey.Raw(&raw); err != nil {
			return nil, errors.Wrapf(err, `failed to retrieve raw key from %T`, key)
		}
		key = raw
	}

	ctx := context.TODO()
	h, err := m.protectedHeaders.Clone(ctx)
	if err != nil {
		return nil, errors.Wrap(err, `failed to copy protected headers`)
	}
	h, err = h.Merge(ctx, m.unprotectedHeaders)
	if err != nil {
		return nil, errors.Wrap(err, "failed to merge headers for message decryption")
	}

	enc := m.protectedHeaders.ContentEncryption()
	var aad []byte
	if aadContainer := m.authenticatedData; aadContainer != nil {
		aad = base64.Encode(aadContainer)
	}

	// Prefer the protected header bytes captured during parsing; otherwise
	// re-encode the protected headers.
	computedAad := m.rawProtectedHeaders
	if len(computedAad) == 0 {
		// this is probably not required once msg.Decrypt is deprecated
		computedAad, err = m.protectedHeaders.Encode()
		if err != nil {
			return nil, errors.Wrap(err, "failed to encode protected headers")
		}
	}

	dec := NewDecrypter(alg, enc, key).
		AuthenticatedData(aad).
		ComputedAuthenticatedData(computedAad).
		InitializationVector(m.initializationVector).
		Tag(m.tag)

	var plaintext []byte
	var decrypted bool
	var lastError error

	// if we have no recipients, pretend like we only have one
	recipients := m.recipients
	if len(recipients) == 0 {
		r := NewRecipient()
		if err := r.SetHeaders(m.protectedHeaders); err != nil {
			return nil, errors.Wrap(err, `failed to set headers to recipient`)
		}
		recipients = append(recipients, r)
	}

	for _, recipient := range recipients {
		// strategy: try each recipient. If we fail in one of the steps,
		// keep looping because there might be another key with the same algo
		h2, err := h.Clone(ctx)
		if err != nil {
			lastError = errors.Wrap(err, `failed to copy headers (1)`)
			continue
		}
		h2, err = h2.Merge(ctx, recipient.Headers())
		if err != nil {
			lastError = errors.Wrap(err, `failed to copy headers (2)`)
			continue
		}

		// Match on the combined header so that an "alg" carried in the
		// protected or shared unprotected header is honored even when the
		// recipient's own header does not repeat it.
		if h2.Algorithm() != alg {
			// algorithms don't match
			continue
		}

		switch alg {
		case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
			epkif, ok := h2.Get(EphemeralPublicKeyKey)
			if !ok {
				return nil, errors.New("failed to get 'epk' field")
			}
			switch epk := epkif.(type) {
			case jwk.ECDSAPublicKey:
				var pubkey ecdsa.PublicKey
				if err := epk.Raw(&pubkey); err != nil {
					return nil, errors.Wrap(err, "failed to get public key")
				}
				dec.PublicKey(&pubkey)
			case jwk.OKPPublicKey:
				var pubkey interface{}
				if err := epk.Raw(&pubkey); err != nil {
					return nil, errors.Wrap(err, "failed to get public key")
				}
				dec.PublicKey(pubkey)
			default:
				return nil, errors.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg)
			}

			if apu := h2.AgreementPartyUInfo(); len(apu) > 0 {
				dec.AgreementPartyUInfo(apu)
			}
			if apv := h2.AgreementPartyVInfo(); len(apv) > 0 {
				dec.AgreementPartyVInfo(apv)
			}
		case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
			ivB64, ok := h2.Get(InitializationVectorKey)
			if !ok {
				return nil, errors.New("failed to get 'iv' field")
			}
			ivB64Str, ok := ivB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'iv': %T", ivB64)
			}
			tagB64, ok := h2.Get(TagKey)
			if !ok {
				return nil, errors.New("failed to get 'tag' field")
			}
			tagB64Str, ok := tagB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'tag': %T", tagB64)
			}
			iv, err := base64.DecodeString(ivB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'iv'")
			}
			tag, err := base64.DecodeString(tagB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'tag'")
			}
			dec.KeyInitializationVector(iv)
			dec.KeyTag(tag)
		case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
			saltB64, ok := h2.Get(SaltKey)
			if !ok {
				return nil, errors.New("failed to get 'p2s' field")
			}
			saltB64Str, ok := saltB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'p2s': %T", saltB64)
			}
			count, ok := h2.Get(CountKey)
			if !ok {
				return nil, errors.New("failed to get 'p2c' field")
			}
			countFlt, ok := count.(float64)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'p2c': %T", count)
			}
			salt, err := base64.DecodeString(saltB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'salt'")
			}
			dec.KeySalt(salt)
			dec.KeyCount(int(countFlt))
		}

		pt, err := dec.Decrypt(recipient.EncryptedKey(), m.cipherText)
		if err != nil {
			lastError = errors.Wrap(err, `failed to decrypt`)
			continue
		}

		if h2.Compression() == jwa.Deflate {
			buf, err := uncompress(pt)
			if err != nil {
				lastError = errors.Wrap(err, `failed to uncompress payload`)
				continue
			}
			pt = buf
		}

		plaintext = pt
		decrypted = true
		break
	}

	if !decrypted {
		if lastError != nil {
			return nil, errors.Errorf(`failed to find matching recipient to decrypt key (last error = %s)`, lastError)
		}
		return nil, errors.New("failed to find matching recipient")
	}

	return plaintext, nil
}