package jwe import ( "context" "crypto/ecdsa" "fmt" "github.com/lestrrat-go/jwx/internal/json" "github.com/lestrrat-go/jwx/internal/pool" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/internal/base64" "github.com/lestrrat-go/jwx/jwa" "github.com/pkg/errors" ) // NewRecipient creates a Recipient object func NewRecipient() Recipient { return &stdRecipient{ headers: NewHeaders(), } } func (r *stdRecipient) SetHeaders(h Headers) error { r.headers = h return nil } func (r *stdRecipient) SetEncryptedKey(v []byte) error { r.encryptedKey = v return nil } func (r *stdRecipient) Headers() Headers { return r.headers } func (r *stdRecipient) EncryptedKey() []byte { return r.encryptedKey } type recipientMarshalProxy struct { Headers Headers `json:"header"` EncryptedKey string `json:"encrypted_key"` } func (r *stdRecipient) UnmarshalJSON(buf []byte) error { var proxy recipientMarshalProxy proxy.Headers = NewHeaders() if err := json.Unmarshal(buf, &proxy); err != nil { return errors.Wrap(err, `failed to unmarshal json into recipient`) } r.headers = proxy.Headers decoded, err := base64.DecodeString(proxy.EncryptedKey) if err != nil { return errors.Wrap(err, `failed to decode "encrypted_key"`) } r.encryptedKey = decoded return nil } func (r *stdRecipient) MarshalJSON() ([]byte, error) { buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) buf.WriteString(`{"header":`) hdrbuf, err := r.headers.MarshalJSON() if err != nil { return nil, errors.Wrap(err, `failed to marshal recipient header`) } buf.Write(hdrbuf) buf.WriteString(`,"encrypted_key":"`) buf.WriteString(base64.EncodeToString(r.encryptedKey)) buf.WriteString(`"}`) ret := make([]byte, buf.Len()) copy(ret, buf.Bytes()) return ret, nil } // NewMessage creates a new message func NewMessage() *Message { return &Message{} } func (m *Message) AuthenticatedData() []byte { return m.authenticatedData } func (m *Message) CipherText() []byte { return m.cipherText } func (m *Message) InitializationVector() []byte { return m.initializationVector } func (m *Message) Tag() []byte { return m.tag } func (m *Message) ProtectedHeaders() Headers { return m.protectedHeaders } func (m *Message) Recipients() []Recipient { return m.recipients } func (m *Message) UnprotectedHeaders() Headers { return m.unprotectedHeaders } const ( AuthenticatedDataKey = "aad" CipherTextKey = "ciphertext" CountKey = "p2c" InitializationVectorKey = "iv" ProtectedHeadersKey = "protected" RecipientsKey = "recipients" SaltKey = "p2s" TagKey = "tag" UnprotectedHeadersKey = "unprotected" HeadersKey = "header" EncryptedKeyKey = "encrypted_key" ) func (m *Message) Set(k string, v interface{}) error { switch k { case AuthenticatedDataKey: buf, ok := v.([]byte) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, AuthenticatedDataKey) } m.authenticatedData = buf case CipherTextKey: buf, ok := v.([]byte) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, CipherTextKey) } m.cipherText = buf case InitializationVectorKey: buf, ok := v.([]byte) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, InitializationVectorKey) } m.initializationVector = buf case ProtectedHeadersKey: cv, ok := v.(Headers) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey) } m.protectedHeaders = cv case RecipientsKey: cv, ok := v.([]Recipient) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, RecipientsKey) } m.recipients = cv case TagKey: buf, ok := v.([]byte) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, TagKey) } m.tag = buf case UnprotectedHeadersKey: cv, ok := v.(Headers) if !ok { return errors.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey) } m.unprotectedHeaders = cv default: if m.unprotectedHeaders == nil { m.unprotectedHeaders = NewHeaders() } return m.unprotectedHeaders.Set(k, v) } return nil } type messageMarshalProxy struct { AuthenticatedData string `json:"aad,omitempty"` CipherText string `json:"ciphertext"` InitializationVector string `json:"iv,omitempty"` ProtectedHeaders json.RawMessage `json:"protected"` Recipients []json.RawMessage `json:"recipients,omitempty"` Tag string `json:"tag,omitempty"` UnprotectedHeaders Headers `json:"unprotected,omitempty"` // For flattened structure. Headers is NOT a Headers type, // so that we can detect its presence by checking proxy.Headers != nil Headers json.RawMessage `json:"header,omitempty"` EncryptedKey string `json:"encrypted_key,omitempty"` } func (m *Message) MarshalJSON() ([]byte, error) { // This is slightly convoluted, but we need to encode the // protected headers, so we do it by hand buf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(buf) enc := json.NewEncoder(buf) fmt.Fprintf(buf, `{`) var wrote bool if aad := m.AuthenticatedData(); len(aad) > 0 { wrote = true fmt.Fprintf(buf, `%#v:`, AuthenticatedDataKey) if err := enc.Encode(base64.EncodeToString(aad)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, AuthenticatedDataKey) } } if cipherText := m.CipherText(); len(cipherText) > 0 { if wrote { fmt.Fprintf(buf, `,`) } wrote = true fmt.Fprintf(buf, `%#v:`, CipherTextKey) if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, CipherTextKey) } } if iv := m.InitializationVector(); len(iv) > 0 { if wrote { fmt.Fprintf(buf, `,`) } wrote = true fmt.Fprintf(buf, `%#v:`, InitializationVectorKey) if err := enc.Encode(base64.EncodeToString(iv)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, InitializationVectorKey) } } if h := m.ProtectedHeaders(); h != nil { encodedHeaders, err := h.Encode() if err != nil { return nil, errors.Wrap(err, `failed to encode protected headers`) } if len(encodedHeaders) > 2 { if wrote { fmt.Fprintf(buf, `,`) } wrote = true fmt.Fprintf(buf, `%#v:%#v`, ProtectedHeadersKey, string(encodedHeaders)) } } if recipients := m.Recipients(); len(recipients) > 0 { if wrote { fmt.Fprintf(buf, `,`) } if len(recipients) == 1 { // Use flattened format fmt.Fprintf(buf, `%#v:`, HeadersKey) if err := enc.Encode(recipients[0].Headers()); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, HeadersKey) } if ek := recipients[0].EncryptedKey(); len(ek) > 0 { fmt.Fprintf(buf, `,%#v:`, EncryptedKeyKey) if err := enc.Encode(base64.EncodeToString(ek)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, EncryptedKeyKey) } } } else { fmt.Fprintf(buf, `%#v:`, RecipientsKey) if err := enc.Encode(recipients); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, RecipientsKey) } } } if tag := m.Tag(); len(tag) > 0 { if wrote { fmt.Fprintf(buf, `,`) } fmt.Fprintf(buf, `%#v:`, TagKey) if err := enc.Encode(base64.EncodeToString(tag)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, TagKey) } } if h := m.UnprotectedHeaders(); h != nil { unprotected, err := json.Marshal(h) if err != nil { return nil, errors.Wrap(err, `failed to encode unprotected headers`) } if len(unprotected) > 2 { fmt.Fprintf(buf, `,%#v:%#v`, UnprotectedHeadersKey, string(unprotected)) } } fmt.Fprintf(buf, `}`) ret := make([]byte, buf.Len()) copy(ret, buf.Bytes()) return ret, nil } func (m *Message) UnmarshalJSON(buf []byte) error { var proxy messageMarshalProxy proxy.UnprotectedHeaders = NewHeaders() if err := json.Unmarshal(buf, &proxy); err != nil { return errors.Wrap(err, `failed to unmashal JSON into message`) } // Get the string value var protectedHeadersStr string if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil { return errors.Wrap(err, `failed to decode protected headers (1)`) } // It's now in _quoted_ base64 string. Decode it protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr) if err != nil { return errors.Wrap(err, "failed to base64 decoded protected headers buffer") } h := NewHeaders() if err := json.Unmarshal(protectedHeadersRaw, h); err != nil { return errors.Wrap(err, `failed to decode protected headers (2)`) } // if this were a flattened message, we would see a "header" and "ciphertext" // field. TODO: do both of these conditions need to meet, or just one? if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 { recipient := NewRecipient() hdrs := NewHeaders() if err := json.Unmarshal(proxy.Headers, hdrs); err != nil { return errors.Wrap(err, `failed to decode headers field`) } if err := recipient.SetHeaders(hdrs); err != nil { return errors.Wrap(err, `failed to set new headers`) } if v := proxy.EncryptedKey; len(v) > 0 { buf, err := base64.DecodeString(v) if err != nil { return errors.Wrap(err, `failed to decode encrypted key`) } if err := recipient.SetEncryptedKey(buf); err != nil { return errors.Wrap(err, `failed to set encrypted key`) } } m.recipients = append(m.recipients, recipient) } else { for i, recipientbuf := range proxy.Recipients { recipient := NewRecipient() if err := json.Unmarshal(recipientbuf, recipient); err != nil { return errors.Wrapf(err, `failed to decode recipient at index %d`, i) } m.recipients = append(m.recipients, recipient) } } if src := proxy.AuthenticatedData; len(src) > 0 { v, err := base64.DecodeString(src) if err != nil { return errors.Wrap(err, `failed to decode "aad"`) } m.authenticatedData = v } if src := proxy.CipherText; len(src) > 0 { v, err := base64.DecodeString(src) if err != nil { return errors.Wrap(err, `failed to decode "ciphertext"`) } m.cipherText = v } if src := proxy.InitializationVector; len(src) > 0 { v, err := base64.DecodeString(src) if err != nil { return errors.Wrap(err, `failed to decode "iv"`) } m.initializationVector = v } if src := proxy.Tag; len(src) > 0 { v, err := base64.DecodeString(src) if err != nil { return errors.Wrap(err, `failed to decode "tag"`) } m.tag = v } m.protectedHeaders = h if m.storeProtectedHeaders { // this is later used for decryption m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw) } if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok { if !iz.isZero() { m.unprotectedHeaders = proxy.UnprotectedHeaders } } if len(m.recipients) == 0 { if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil { return errors.Wrap(err, `failed to setup recipient`) } } return nil } func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error { // Recipients in this case should not contain the content encryption key, // so move that out hdrs, err := protected.Clone(context.TODO()) if err != nil { return errors.Wrap(err, `failed to clone headers`) } if err := hdrs.Remove(ContentEncryptionKey); err != nil { return errors.Wrapf(err, "failed to remove %#v from public header", ContentEncryptionKey) } enckey, err := base64.DecodeString(enckeybuf) if err != nil { return errors.Wrap(err, `failed to decode encrypted key`) } if err := m.Set(RecipientsKey, []Recipient{ &stdRecipient{ headers: hdrs, encryptedKey: enckey, }, }); err != nil { return errors.Wrapf(err, `failed to set %s`, RecipientsKey) } return nil } // Decrypt decrypts the message using the specified algorithm and key. // // `key` must be a private key in its "raw" format (i.e. something like // *rsa.PrivateKey, instead of jwk.Key) // // This method is marked for deprecation. It will be removed from the API // in the next major release. You should not rely on this method // to work 100% of the time, especially when it was obtained via jwe.Parse // instead of being constructed from scratch by this library. func (m *Message) Decrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]byte, error) { var ctx decryptCtx ctx.alg = alg ctx.key = key ctx.msg = m return doDecryptCtx(&ctx) } func doDecryptCtx(dctx *decryptCtx) ([]byte, error) { m := dctx.msg alg := dctx.alg key := dctx.key if jwkKey, ok := key.(jwk.Key); ok { var raw interface{} if err := jwkKey.Raw(&raw); err != nil { return nil, errors.Wrapf(err, `failed to retrieve raw key from %T`, key) } key = raw } var err error ctx := context.TODO() h, err := m.protectedHeaders.Clone(ctx) if err != nil { return nil, errors.Wrap(err, `failed to copy protected headers`) } h, err = h.Merge(ctx, m.unprotectedHeaders) if err != nil { return nil, errors.Wrap(err, "failed to merge headers for message decryption") } enc := m.protectedHeaders.ContentEncryption() var aad []byte if aadContainer := m.authenticatedData; aadContainer != nil { aad = base64.Encode(aadContainer) } var computedAad []byte if len(m.rawProtectedHeaders) > 0 { computedAad = m.rawProtectedHeaders } else { // this is probably not required once msg.Decrypt is deprecated var err error computedAad, err = m.protectedHeaders.Encode() if err != nil { return nil, errors.Wrap(err, "failed to encode protected headers") } } dec := NewDecrypter(alg, enc, key). AuthenticatedData(aad). ComputedAuthenticatedData(computedAad). InitializationVector(m.initializationVector). Tag(m.tag) var plaintext []byte var lastError error // if we have no recipients, pretend like we only have one recipients := m.recipients if len(recipients) == 0 { r := NewRecipient() if err := r.SetHeaders(m.protectedHeaders); err != nil { return nil, errors.Wrap(err, `failed to set headers to recipient`) } recipients = append(recipients, r) } for _, recipient := range recipients { // strategy: try each recipient. If we fail in one of the steps, // keep looping because there might be another key with the same algo if recipient.Headers().Algorithm() != alg { // algorithms don't match continue } h2, err := h.Clone(ctx) if err != nil { lastError = errors.Wrap(err, `failed to copy headers (1)`) continue } h2, err = h2.Merge(ctx, recipient.Headers()) if err != nil { lastError = errors.Wrap(err, `failed to copy headers (2)`) continue } switch alg { case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: epkif, ok := h2.Get(EphemeralPublicKeyKey) if !ok { return nil, errors.New("failed to get 'epk' field") } switch epk := epkif.(type) { case jwk.ECDSAPublicKey: var pubkey ecdsa.PublicKey if err := epk.Raw(&pubkey); err != nil { return nil, errors.Wrap(err, "failed to get public key") } dec.PublicKey(&pubkey) case jwk.OKPPublicKey: var pubkey interface{} if err := epk.Raw(&pubkey); err != nil { return nil, errors.Wrap(err, "failed to get public key") } dec.PublicKey(pubkey) default: return nil, errors.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg) } if apu := h2.AgreementPartyUInfo(); len(apu) > 0 { dec.AgreementPartyUInfo(apu) } if apv := h2.AgreementPartyVInfo(); len(apv) > 0 { dec.AgreementPartyVInfo(apv) } case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW: ivB64, ok := h2.Get(InitializationVectorKey) if !ok { return nil, errors.New("failed to get 'iv' field") } ivB64Str, ok := ivB64.(string) if !ok { return nil, errors.Errorf("unexpected type for 'iv': %T", ivB64) } tagB64, ok := h2.Get(TagKey) if !ok { return nil, errors.New("failed to get 'tag' field") } tagB64Str, ok := tagB64.(string) if !ok { return nil, errors.Errorf("unexpected type for 'tag': %T", tagB64) } iv, err := base64.DecodeString(ivB64Str) if err != nil { return nil, errors.Wrap(err, "failed to b64-decode 'iv'") } tag, err := base64.DecodeString(tagB64Str) if err != nil { return nil, errors.Wrap(err, "failed to b64-decode 'tag'") } dec.KeyInitializationVector(iv) dec.KeyTag(tag) case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: saltB64, ok := h2.Get(SaltKey) if !ok { return nil, errors.New("failed to get 'p2s' field") } saltB64Str, ok := saltB64.(string) if !ok { return nil, errors.Errorf("unexpected type for 'p2s': %T", saltB64) } count, ok := h2.Get(CountKey) if !ok { return nil, errors.New("failed to get 'p2c' field") } countFlt, ok := count.(float64) if !ok { return nil, errors.Errorf("unexpected type for 'p2c': %T", count) } salt, err := base64.DecodeString(saltB64Str) if err != nil { return nil, errors.Wrap(err, "failed to b64-decode 'salt'") } dec.KeySalt(salt) dec.KeyCount(int(countFlt)) } plaintext, err = dec.Decrypt(recipient.EncryptedKey(), m.cipherText) if err != nil { lastError = errors.Wrap(err, `failed to decrypt`) continue } if h2.Compression() == jwa.Deflate { buf, err := uncompress(plaintext) if err != nil { lastError = errors.Wrap(err, `failed to uncompress payload`) continue } plaintext = buf } break } if plaintext == nil { if lastError != nil { return nil, errors.Errorf(`failed to find matching recipient to decrypt key (last error = %s)`, lastError) } return nil, errors.New("failed to find matching recipient") } return plaintext, nil }