auth/vendor/github.com/lestrrat-go/jwx/v2/jwe/message.go

548 lines
14 KiB
Go

package jwe
import (
"context"
"fmt"
"sort"
"strings"
"github.com/lestrrat-go/jwx/v2/internal/base64"
"github.com/lestrrat-go/jwx/v2/internal/json"
"github.com/lestrrat-go/jwx/v2/internal/pool"
)
// NewRecipient returns a Recipient backed by stdRecipient whose headers are
// initialized to a fresh, empty Headers value. The encrypted key is left unset.
func NewRecipient() Recipient {
	r := &stdRecipient{}
	r.headers = NewHeaders()
	return r
}
// SetHeaders replaces the recipient's headers with h. The value is stored
// as-is (no copy, no validation) and the returned error is always nil.
func (r *stdRecipient) SetHeaders(h Headers) error {
	r.headers = h
	return nil
}
// SetEncryptedKey stores v as the recipient's encrypted key. The slice is
// retained without copying and the returned error is always nil.
func (r *stdRecipient) SetEncryptedKey(v []byte) error {
	r.encryptedKey = v
	return nil
}
// Headers returns the headers currently associated with the recipient.
func (r *stdRecipient) Headers() Headers {
	return r.headers
}
// EncryptedKey returns the recipient's encrypted key. The internal slice is
// returned directly, not a copy.
func (r *stdRecipient) EncryptedKey() []byte {
	return r.encryptedKey
}
// recipientMarshalProxy is the intermediate shape a recipient JSON object is
// decoded into by stdRecipient.UnmarshalJSON.
type recipientMarshalProxy struct {
	// Headers receives the "header" member.
	Headers Headers `json:"header"`
	// EncryptedKey receives the "encrypted_key" member, still base64-encoded.
	EncryptedKey string `json:"encrypted_key"`
}
// UnmarshalJSON decodes a recipient JSON object ("header" and
// "encrypted_key" members) into r.
//
// The headers are assigned as soon as the JSON has been parsed; the encrypted
// key is assigned only after its base64 payload has been decoded successfully.
func (r *stdRecipient) UnmarshalJSON(buf []byte) error {
	// Pre-populate Headers so the decoder has a concrete value to fill in.
	proxy := recipientMarshalProxy{Headers: NewHeaders()}
	if err := json.Unmarshal(buf, &proxy); err != nil {
		return fmt.Errorf(`failed to unmarshal json into recipient: %w`, err)
	}
	r.headers = proxy.Headers

	key, err := base64.DecodeString(proxy.EncryptedKey)
	if err != nil {
		return fmt.Errorf(`failed to decode "encrypted_key": %w`, err)
	}
	r.encryptedKey = key
	return nil
}
// MarshalJSON renders the recipient as a JSON object of the form
// {"header":<headers>,"encrypted_key":"<base64 key>"}.
//
// The object is assembled by hand in a pooled buffer; the returned slice is a
// private copy, so it stays valid after the buffer goes back to the pool.
func (r *stdRecipient) MarshalJSON() ([]byte, error) {
	hdrJSON, err := r.headers.MarshalJSON()
	if err != nil {
		return nil, fmt.Errorf(`failed to marshal recipient header: %w`, err)
	}

	out := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(out)

	out.WriteString(`{"header":`)
	out.Write(hdrJSON)
	out.WriteString(`,"encrypted_key":"`)
	out.WriteString(base64.EncodeToString(r.encryptedKey))
	out.WriteString(`"}`)

	// Copy out of the pooled buffer before it is released.
	result := make([]byte, out.Len())
	copy(result, out.Bytes())
	return result, nil
}
// NewMessage returns a pointer to a zero-valued Message; no fields are
// pre-populated.
func NewMessage() *Message {
	return new(Message)
}
// AuthenticatedData returns the authenticated data stored in the message.
// The internal slice is returned directly, not a copy.
func (m *Message) AuthenticatedData() []byte {
	return m.authenticatedData
}
// CipherText returns the ciphertext stored in the message.
// The internal slice is returned directly, not a copy.
func (m *Message) CipherText() []byte {
	return m.cipherText
}
// InitializationVector returns the initialization vector stored in the
// message. The internal slice is returned directly, not a copy.
func (m *Message) InitializationVector() []byte {
	return m.initializationVector
}
// Tag returns the tag stored in the message.
// The internal slice is returned directly, not a copy.
func (m *Message) Tag() []byte {
	return m.tag
}
// ProtectedHeaders returns the message's protected headers, which may be nil
// if none have been set.
func (m *Message) ProtectedHeaders() Headers {
	return m.protectedHeaders
}
// Recipients returns the message's recipient list. The internal slice is
// returned directly, not a copy.
func (m *Message) Recipients() []Recipient {
	return m.recipients
}
// UnprotectedHeaders returns the message's unprotected headers, which may be
// nil if none have been set.
func (m *Message) UnprotectedHeaders() Headers {
	return m.unprotectedHeaders
}
// String keys used by this package. Most of them name members of the JSON
// form of a message or recipient and are the keys accepted by Message.Set.
const (
	// Message-level members.
	AuthenticatedDataKey    = "aad"
	CipherTextKey           = "ciphertext"
	InitializationVectorKey = "iv"
	ProtectedHeadersKey     = "protected"
	RecipientsKey           = "recipients"
	TagKey                  = "tag"
	UnprotectedHeadersKey   = "unprotected"

	// Per-recipient members (top-level in the flattened form).
	HeadersKey      = "header"
	EncryptedKeyKey = "encrypted_key"

	// Additional key names; not referenced by the JSON (un)marshaling code
	// in this file.
	CountKey = "p2c"
	SaltKey  = "p2s"
)
// Set assigns v to the message member named by k.
//
// The well-known keys require a specific dynamic type for v:
//   - AuthenticatedDataKey, CipherTextKey, InitializationVectorKey, TagKey: []byte
//   - ProtectedHeadersKey, UnprotectedHeadersKey: Headers
//   - RecipientsKey: []Recipient
//
// A mismatching type yields an error and leaves the message untouched. Any
// other key is forwarded to the unprotected headers, which are created on
// demand.
func (m *Message) Set(k string, v interface{}) error {
	// typeErr builds the error reported when v has the wrong type for key.
	typeErr := func(key string) error {
		return fmt.Errorf(`invalid value %T for %s key`, v, key)
	}

	switch k {
	case AuthenticatedDataKey:
		data, ok := v.([]byte)
		if !ok {
			return typeErr(AuthenticatedDataKey)
		}
		m.authenticatedData = data
	case CipherTextKey:
		data, ok := v.([]byte)
		if !ok {
			return typeErr(CipherTextKey)
		}
		m.cipherText = data
	case InitializationVectorKey:
		data, ok := v.([]byte)
		if !ok {
			return typeErr(InitializationVectorKey)
		}
		m.initializationVector = data
	case TagKey:
		data, ok := v.([]byte)
		if !ok {
			return typeErr(TagKey)
		}
		m.tag = data
	case ProtectedHeadersKey:
		hdrs, ok := v.(Headers)
		if !ok {
			return typeErr(ProtectedHeadersKey)
		}
		m.protectedHeaders = hdrs
	case UnprotectedHeadersKey:
		hdrs, ok := v.(Headers)
		if !ok {
			return typeErr(UnprotectedHeadersKey)
		}
		m.unprotectedHeaders = hdrs
	case RecipientsKey:
		list, ok := v.([]Recipient)
		if !ok {
			return typeErr(RecipientsKey)
		}
		m.recipients = list
	default:
		// Unknown keys are treated as unprotected header fields.
		if m.unprotectedHeaders == nil {
			m.unprotectedHeaders = NewHeaders()
		}
		return m.unprotectedHeaders.Set(k, v)
	}
	return nil
}
// messageMarshalProxy is the intermediate shape a message JSON document is
// decoded into by Message.UnmarshalJSON. Binary members are kept as their
// base64 strings and decoded afterwards.
type messageMarshalProxy struct {
	AuthenticatedData    string            `json:"aad,omitempty"`
	CipherText           string            `json:"ciphertext"`
	InitializationVector string            `json:"iv,omitempty"`
	ProtectedHeaders     json.RawMessage   `json:"protected"`
	Recipients           []json.RawMessage `json:"recipients,omitempty"`
	Tag                  string            `json:"tag,omitempty"`
	UnprotectedHeaders   Headers           `json:"unprotected,omitempty"`

	// Members of the flattened structure. Headers is deliberately NOT of
	// type Headers: keeping the raw bytes lets the decoder tell whether the
	// member was present at all by checking proxy.Headers != nil.
	Headers      json.RawMessage `json:"header,omitempty"`
	EncryptedKey string          `json:"encrypted_key,omitempty"`
}
// jsonKV pairs a JSON member name with its already-serialized value. It is
// used to collect members so they can be sorted by key before being written.
type jsonKV struct {
	// Key is the member name, unquoted.
	Key string
	// Value is the member value exactly as it should appear in the output.
	Value string
}
// MarshalJSON serializes the message as a JWE JSON Serialization document
// (RFC 7516, section 7.2).
//
// Empty members are omitted and the remaining ones are written in
// lexicographic key order. A message with exactly one recipient is written in
// the flattened form ("header" / "encrypted_key" at the top level); with more
// than one, a "recipients" array is written instead.
//
// Binary members ("aad", "ciphertext", "iv", "tag", "encrypted_key") are
// written base64-encoded, and "unprotected" is written as a JSON object, so
// the output can be read back by Message.UnmarshalJSON.
func (m *Message) MarshalJSON() ([]byte, error) {
	// This is slightly convoluted, but we need to encode the
	// protected headers, so we do it by hand
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)
	enc := json.NewEncoder(buf)

	var fields []jsonKV

	// addField renders v as JSON through the shared buffer and records the
	// result under key. Encode appends a newline, hence the TrimSpace.
	addField := func(key string, v interface{}) error {
		buf.Reset()
		if err := enc.Encode(v); err != nil {
			return fmt.Errorf(`failed to encode %s field: %w`, key, err)
		}
		fields = append(fields, jsonKV{
			Key:   key,
			Value: strings.TrimSpace(buf.String()),
		})
		return nil
	}

	if cipherText := m.CipherText(); len(cipherText) > 0 {
		if err := addField(CipherTextKey, base64.EncodeToString(cipherText)); err != nil {
			return nil, err
		}
	}

	if iv := m.InitializationVector(); len(iv) > 0 {
		if err := addField(InitializationVectorKey, base64.EncodeToString(iv)); err != nil {
			return nil, err
		}
	}

	if h := m.ProtectedHeaders(); h != nil {
		encoded, err := h.Encode()
		if err != nil {
			return nil, fmt.Errorf(`failed to encode protected headers: %w`, err)
		}
		// NOTE(review): this threshold was written with a literal '{}' in
		// mind, but Encode appears to return the base64 form (Compact uses
		// it verbatim as the first segment), so an empty header set is
		// probably still emitted. Kept as-is because UnmarshalJSON expects
		// "protected" to be present; confirm before tightening.
		if len(encoded) > 2 {
			fields = append(fields, jsonKV{
				Key:   ProtectedHeadersKey,
				Value: fmt.Sprintf("%q", encoded),
			})
		}
	}

	if aad := m.AuthenticatedData(); len(aad) > 0 {
		// The "aad" member carries only the base64-encoded AAD. It must not
		// be prefixed with the protected header, and must not go through the
		// JSON encoder as a []byte (that would base64-encode it a second
		// time); otherwise UnmarshalJSON cannot recover the original value.
		if err := addField(AuthenticatedDataKey, base64.EncodeToString(aad)); err != nil {
			return nil, err
		}
	}

	if recipients := m.Recipients(); len(recipients) > 0 {
		if len(recipients) == 1 { // Use flattened format
			if hdrs := recipients[0].Headers(); hdrs != nil {
				if err := addField(HeadersKey, hdrs); err != nil {
					return nil, err
				}
			}

			if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
				if err := addField(EncryptedKeyKey, base64.EncodeToString(ek)); err != nil {
					return nil, err
				}
			}
		} else {
			if err := addField(RecipientsKey, recipients); err != nil {
				return nil, err
			}
		}
	}

	if tag := m.Tag(); len(tag) > 0 {
		if err := addField(TagKey, base64.EncodeToString(tag)); err != nil {
			return nil, err
		}
	}

	if h := m.UnprotectedHeaders(); h != nil {
		unprotected, err := json.Marshal(h)
		if err != nil {
			return nil, fmt.Errorf(`failed to encode unprotected headers: %w`, err)
		}

		if len(unprotected) > 2 { // anything more than '{}'
			// "unprotected" is a JSON object, so the marshaled bytes are
			// embedded verbatim. Quoting them would turn the object into a
			// string that UnmarshalJSON cannot decode into Headers.
			fields = append(fields, jsonKV{
				Key:   UnprotectedHeadersKey,
				Value: string(unprotected),
			})
		}
	}

	// Keys are unique, so sorting yields a deterministic member order.
	sort.Slice(fields, func(i, j int) bool {
		return fields[i].Key < fields[j].Key
	})

	buf.Reset()
	buf.WriteByte('{')
	for i, kv := range fields {
		if i > 0 {
			buf.WriteByte(',')
		}
		fmt.Fprintf(buf, `%q:%s`, kv.Key, kv.Value)
	}
	buf.WriteByte('}')

	// Copy out of the pooled buffer before it is released.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return ret, nil
}
// UnmarshalJSON populates the message from a JWE JSON Serialization
// document. Both the general form (a "recipients" array) and the flattened
// form (top-level "header" and/or "encrypted_key") are accepted.
//
// Binary members are base64-decoded into the message. Unprotected headers are
// stored only when they are non-empty. If no recipient could be derived from
// the document, a single recipient is synthesized from the protected headers
// (see makeDummyRecipient).
//
// Decoded recipients are appended to m.recipients; the slice is not reset
// first.
func (m *Message) UnmarshalJSON(buf []byte) error {
	var proxy messageMarshalProxy
	proxy.UnprotectedHeaders = NewHeaders()

	if err := json.Unmarshal(buf, &proxy); err != nil {
		return fmt.Errorf(`failed to unmashal JSON into message: %w`, err)
	}

	// Get the string value
	var protectedHeadersStr string
	if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil {
		return fmt.Errorf(`failed to decode protected headers (1): %w`, err)
	}

	// It's now in _quoted_ base64 string. Decode it
	protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr)
	if err != nil {
		return fmt.Errorf(`failed to base64 decoded protected headers buffer: %w`, err)
	}

	h := NewHeaders()
	if err := json.Unmarshal(protectedHeadersRaw, h); err != nil {
		return fmt.Errorf(`failed to decode protected headers (2): %w`, err)
	}

	// If this were a flattened message, we would see a "header" and/or an
	// "encrypted_key" field at the top level. Either one is enough.
	if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
		recipient := NewRecipient()

		// "header" is optional in the flattened form. Only decode it when it
		// is present: unmarshaling an absent (nil) value would fail and
		// reject a message that merely omits per-recipient headers. When it
		// is absent the recipient keeps the empty headers NewRecipient set.
		if proxy.Headers != nil {
			hdrs := NewHeaders()
			if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
				return fmt.Errorf(`failed to decode headers field: %w`, err)
			}

			if err := recipient.SetHeaders(hdrs); err != nil {
				return fmt.Errorf(`failed to set new headers: %w`, err)
			}
		}

		if v := proxy.EncryptedKey; len(v) > 0 {
			ek, err := base64.DecodeString(v)
			if err != nil {
				return fmt.Errorf(`failed to decode encrypted key: %w`, err)
			}
			if err := recipient.SetEncryptedKey(ek); err != nil {
				return fmt.Errorf(`failed to set encrypted key: %w`, err)
			}
		}

		m.recipients = append(m.recipients, recipient)
	} else {
		for i, recipientbuf := range proxy.Recipients {
			recipient := NewRecipient()
			if err := json.Unmarshal(recipientbuf, recipient); err != nil {
				return fmt.Errorf(`failed to decode recipient at index %d: %w`, i, err)
			}

			m.recipients = append(m.recipients, recipient)
		}
	}

	if src := proxy.AuthenticatedData; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return fmt.Errorf(`failed to decode "aad": %w`, err)
		}
		m.authenticatedData = v
	}

	if src := proxy.CipherText; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return fmt.Errorf(`failed to decode "ciphertext": %w`, err)
		}
		m.cipherText = v
	}

	if src := proxy.InitializationVector; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return fmt.Errorf(`failed to decode "iv": %w`, err)
		}
		m.initializationVector = v
	}

	if src := proxy.Tag; len(src) > 0 {
		v, err := base64.DecodeString(src)
		if err != nil {
			return fmt.Errorf(`failed to decode "tag": %w`, err)
		}
		m.tag = v
	}

	m.protectedHeaders = h
	if m.storeProtectedHeaders {
		// this is later used for decryption
		m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw)
	}

	// Keep the unprotected headers only if the document actually carried any.
	if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok {
		if !iz.isZero() {
			m.unprotectedHeaders = proxy.UnprotectedHeaders
		}
	}

	if len(m.recipients) == 0 {
		if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil {
			return fmt.Errorf(`failed to setup recipient: %w`, err)
		}
	}

	return nil
}
// makeDummyRecipient installs a single synthesized recipient on the message.
//
// The recipient's headers are a clone of protected with the
// ContentEncryptionKey entry removed, and its encrypted key is the
// base64-decoded form of enckeybuf. The recipient list is replaced through
// Set(RecipientsKey, ...).
func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error {
	// Recipients in this case should not contain the content encryption key,
	// so work on a copy and strip it from there.
	recipientHdrs, err := protected.Clone(context.TODO())
	if err != nil {
		return fmt.Errorf(`failed to clone headers: %w`, err)
	}
	if err := recipientHdrs.Remove(ContentEncryptionKey); err != nil {
		return fmt.Errorf(`failed to remove %#v from public header: %w`, ContentEncryptionKey, err)
	}

	decodedKey, err := base64.DecodeString(enckeybuf)
	if err != nil {
		return fmt.Errorf(`failed to decode encrypted key: %w`, err)
	}

	dummy := &stdRecipient{
		headers:      recipientHdrs,
		encryptedKey: decodedKey,
	}
	if err := m.Set(RecipientsKey, []Recipient{dummy}); err != nil {
		return fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err)
	}
	return nil
}
// Compact generates a JWE message in compact serialization format from a
// `*jwe.Message` object. The object must contain exactly one recipient and
// a non-nil protected header, or an error is returned. A nil message also
// results in an error rather than a panic.
//
// The output consists of five base64-encoded segments joined by '.':
// header, encrypted key, initialization vector, ciphertext and tag. The
// header segment is built from a copy of the protected header merged with
// the unprotected header and then the recipient header; the message itself
// is not modified.
//
// This function currently does not take any options, but the function
// signature contains `options` for possible future expansion of the API
func Compact(m *Message, _ ...CompactOption) ([]byte, error) {
	// Guard against a nil message so that the field accesses below cannot
	// dereference a nil pointer.
	if m == nil {
		return nil, fmt.Errorf(`invalid message: nil`)
	}

	if len(m.recipients) != 1 {
		return nil, fmt.Errorf(`wrong number of recipients for compact serialization`)
	}

	recipient := m.recipients[0]

	// The protected header must be a merge between the message-wide
	// protected header AND the recipient header

	// There's something wrong if m.protectedHeaders is nil, but
	// it could happen
	if m.protectedHeaders == nil {
		return nil, fmt.Errorf(`invalid protected header`)
	}

	ctx := context.TODO()
	hcopy, err := m.protectedHeaders.Clone(ctx)
	if err != nil {
		return nil, fmt.Errorf(`failed to copy protected header: %w`, err)
	}
	hcopy, err = hcopy.Merge(ctx, m.unprotectedHeaders)
	if err != nil {
		return nil, fmt.Errorf(`failed to merge unprotected header: %w`, err)
	}
	hcopy, err = hcopy.Merge(ctx, recipient.Headers())
	if err != nil {
		return nil, fmt.Errorf(`failed to merge recipient header: %w`, err)
	}

	protected, err := hcopy.Encode()
	if err != nil {
		return nil, fmt.Errorf(`failed to encode header: %w`, err)
	}

	encryptedKey := base64.Encode(recipient.EncryptedKey())
	iv := base64.Encode(m.initializationVector)
	cipher := base64.Encode(m.cipherText)
	tag := base64.Encode(m.tag)

	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	// Five segments plus four separators.
	buf.Grow(len(protected) + len(encryptedKey) + len(iv) + len(cipher) + len(tag) + 4)
	buf.Write(protected)
	buf.WriteByte('.')
	buf.Write(encryptedKey)
	buf.WriteByte('.')
	buf.Write(iv)
	buf.WriteByte('.')
	buf.Write(cipher)
	buf.WriteByte('.')
	buf.Write(tag)

	// Copy out of the pooled buffer before it is released.
	result := make([]byte, buf.Len())
	copy(result, buf.Bytes())
	return result, nil
}