548 lines
14 KiB
Go
548 lines
14 KiB
Go
package jwe
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/internal/base64"
|
|
"github.com/lestrrat-go/jwx/v2/internal/json"
|
|
"github.com/lestrrat-go/jwx/v2/internal/pool"
|
|
)
|
|
|
|
// NewRecipient creates a Recipient object
|
|
func NewRecipient() Recipient {
|
|
return &stdRecipient{
|
|
headers: NewHeaders(),
|
|
}
|
|
}
|
|
|
|
func (r *stdRecipient) SetHeaders(h Headers) error {
|
|
r.headers = h
|
|
return nil
|
|
}
|
|
|
|
func (r *stdRecipient) SetEncryptedKey(v []byte) error {
|
|
r.encryptedKey = v
|
|
return nil
|
|
}
|
|
|
|
func (r *stdRecipient) Headers() Headers {
|
|
return r.headers
|
|
}
|
|
|
|
func (r *stdRecipient) EncryptedKey() []byte {
|
|
return r.encryptedKey
|
|
}
|
|
|
|
type recipientMarshalProxy struct {
|
|
Headers Headers `json:"header"`
|
|
EncryptedKey string `json:"encrypted_key"`
|
|
}
|
|
|
|
func (r *stdRecipient) UnmarshalJSON(buf []byte) error {
|
|
var proxy recipientMarshalProxy
|
|
proxy.Headers = NewHeaders()
|
|
if err := json.Unmarshal(buf, &proxy); err != nil {
|
|
return fmt.Errorf(`failed to unmarshal json into recipient: %w`, err)
|
|
}
|
|
|
|
r.headers = proxy.Headers
|
|
decoded, err := base64.DecodeString(proxy.EncryptedKey)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode "encrypted_key": %w`, err)
|
|
}
|
|
r.encryptedKey = decoded
|
|
return nil
|
|
}
|
|
|
|
func (r *stdRecipient) MarshalJSON() ([]byte, error) {
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
buf.WriteString(`{"header":`)
|
|
hdrbuf, err := r.headers.MarshalJSON()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to marshal recipient header: %w`, err)
|
|
}
|
|
buf.Write(hdrbuf)
|
|
buf.WriteString(`,"encrypted_key":"`)
|
|
buf.WriteString(base64.EncodeToString(r.encryptedKey))
|
|
buf.WriteString(`"}`)
|
|
|
|
ret := make([]byte, buf.Len())
|
|
copy(ret, buf.Bytes())
|
|
return ret, nil
|
|
}
|
|
|
|
// NewMessage creates a new message
|
|
func NewMessage() *Message {
|
|
return &Message{}
|
|
}
|
|
|
|
func (m *Message) AuthenticatedData() []byte {
|
|
return m.authenticatedData
|
|
}
|
|
|
|
func (m *Message) CipherText() []byte {
|
|
return m.cipherText
|
|
}
|
|
|
|
func (m *Message) InitializationVector() []byte {
|
|
return m.initializationVector
|
|
}
|
|
|
|
func (m *Message) Tag() []byte {
|
|
return m.tag
|
|
}
|
|
|
|
func (m *Message) ProtectedHeaders() Headers {
|
|
return m.protectedHeaders
|
|
}
|
|
|
|
func (m *Message) Recipients() []Recipient {
|
|
return m.recipients
|
|
}
|
|
|
|
func (m *Message) UnprotectedHeaders() Headers {
|
|
return m.unprotectedHeaders
|
|
}
|
|
|
|
// Member names used in the JSON form of a message. Several of them also
// serve as the key argument accepted by (*Message).Set.
const (
	// AuthenticatedDataKey names the "aad" member.
	AuthenticatedDataKey = "aad"
	// CipherTextKey names the "ciphertext" member.
	CipherTextKey = "ciphertext"
	// CountKey is "p2c".
	// NOTE(review): presumably the PBES2 iteration count header name;
	// it is not referenced in this part of the file.
	CountKey = "p2c"
	// InitializationVectorKey names the "iv" member.
	InitializationVectorKey = "iv"
	// ProtectedHeadersKey names the "protected" member.
	ProtectedHeadersKey = "protected"
	// RecipientsKey names the "recipients" member.
	RecipientsKey = "recipients"
	// SaltKey is "p2s".
	// NOTE(review): presumably the PBES2 salt input header name;
	// it is not referenced in this part of the file.
	SaltKey = "p2s"
	// TagKey names the "tag" member.
	TagKey = "tag"
	// UnprotectedHeadersKey names the "unprotected" member.
	UnprotectedHeadersKey = "unprotected"
	// HeadersKey names the per-recipient "header" member.
	HeadersKey = "header"
	// EncryptedKeyKey names the per-recipient "encrypted_key" member.
	EncryptedKeyKey = "encrypted_key"
)
|
|
|
|
func (m *Message) Set(k string, v interface{}) error {
|
|
switch k {
|
|
case AuthenticatedDataKey:
|
|
buf, ok := v.([]byte)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, AuthenticatedDataKey)
|
|
}
|
|
m.authenticatedData = buf
|
|
case CipherTextKey:
|
|
buf, ok := v.([]byte)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, CipherTextKey)
|
|
}
|
|
m.cipherText = buf
|
|
case InitializationVectorKey:
|
|
buf, ok := v.([]byte)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, InitializationVectorKey)
|
|
}
|
|
m.initializationVector = buf
|
|
case ProtectedHeadersKey:
|
|
cv, ok := v.(Headers)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey)
|
|
}
|
|
m.protectedHeaders = cv
|
|
case RecipientsKey:
|
|
cv, ok := v.([]Recipient)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, RecipientsKey)
|
|
}
|
|
m.recipients = cv
|
|
case TagKey:
|
|
buf, ok := v.([]byte)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, TagKey)
|
|
}
|
|
m.tag = buf
|
|
case UnprotectedHeadersKey:
|
|
cv, ok := v.(Headers)
|
|
if !ok {
|
|
return fmt.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey)
|
|
}
|
|
m.unprotectedHeaders = cv
|
|
default:
|
|
if m.unprotectedHeaders == nil {
|
|
m.unprotectedHeaders = NewHeaders()
|
|
}
|
|
return m.unprotectedHeaders.Set(k, v)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type messageMarshalProxy struct {
|
|
AuthenticatedData string `json:"aad,omitempty"`
|
|
CipherText string `json:"ciphertext"`
|
|
InitializationVector string `json:"iv,omitempty"`
|
|
ProtectedHeaders json.RawMessage `json:"protected"`
|
|
Recipients []json.RawMessage `json:"recipients,omitempty"`
|
|
Tag string `json:"tag,omitempty"`
|
|
UnprotectedHeaders Headers `json:"unprotected,omitempty"`
|
|
|
|
// For flattened structure. Headers is NOT a Headers type,
|
|
// so that we can detect its presence by checking proxy.Headers != nil
|
|
Headers json.RawMessage `json:"header,omitempty"`
|
|
EncryptedKey string `json:"encrypted_key,omitempty"`
|
|
}
|
|
|
|
// jsonKV pairs a JSON member name with its value, where the value is
// already rendered as JSON text.
type jsonKV struct {
	// Key is the member name, unquoted.
	Key string
	// Value is the JSON text written verbatim after the key.
	Value string
}
|
|
|
|
func (m *Message) MarshalJSON() ([]byte, error) {
|
|
// This is slightly convoluted, but we need to encode the
|
|
// protected headers, so we do it by hand
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
enc := json.NewEncoder(buf)
|
|
|
|
var fields []jsonKV
|
|
|
|
if cipherText := m.CipherText(); len(cipherText) > 0 {
|
|
buf.Reset()
|
|
if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, CipherTextKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: CipherTextKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
|
|
if iv := m.InitializationVector(); len(iv) > 0 {
|
|
buf.Reset()
|
|
if err := enc.Encode(base64.EncodeToString(iv)); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, InitializationVectorKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: InitializationVectorKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
|
|
var encodedProtectedHeaders []byte
|
|
if h := m.ProtectedHeaders(); h != nil {
|
|
v, err := h.Encode()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to encode protected headers: %w`, err)
|
|
}
|
|
|
|
encodedProtectedHeaders = v
|
|
if len(encodedProtectedHeaders) <= 2 { // '{}'
|
|
encodedProtectedHeaders = nil
|
|
} else {
|
|
fields = append(fields, jsonKV{
|
|
Key: ProtectedHeadersKey,
|
|
Value: fmt.Sprintf("%q", encodedProtectedHeaders),
|
|
})
|
|
}
|
|
}
|
|
|
|
if aad := m.AuthenticatedData(); len(aad) > 0 {
|
|
aad = base64.Encode(aad)
|
|
if encodedProtectedHeaders != nil {
|
|
tmp := append(encodedProtectedHeaders, '.')
|
|
aad = append(tmp, aad...)
|
|
}
|
|
|
|
buf.Reset()
|
|
if err := enc.Encode(aad); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, AuthenticatedDataKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: AuthenticatedDataKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
|
|
if recipients := m.Recipients(); len(recipients) > 0 {
|
|
if len(recipients) == 1 { // Use flattened format
|
|
if hdrs := recipients[0].Headers(); hdrs != nil {
|
|
buf.Reset()
|
|
if err := enc.Encode(hdrs); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, HeadersKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: HeadersKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
|
|
if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
|
|
buf.Reset()
|
|
if err := enc.Encode(base64.EncodeToString(ek)); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, EncryptedKeyKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: EncryptedKeyKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
} else {
|
|
buf.Reset()
|
|
if err := enc.Encode(recipients); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, RecipientsKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: RecipientsKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
}
|
|
|
|
if tag := m.Tag(); len(tag) > 0 {
|
|
buf.Reset()
|
|
if err := enc.Encode(base64.EncodeToString(tag)); err != nil {
|
|
return nil, fmt.Errorf(`failed to encode %s field: %w`, TagKey, err)
|
|
}
|
|
fields = append(fields, jsonKV{
|
|
Key: TagKey,
|
|
Value: strings.TrimSpace(buf.String()),
|
|
})
|
|
}
|
|
|
|
if h := m.UnprotectedHeaders(); h != nil {
|
|
unprotected, err := json.Marshal(h)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to encode unprotected headers: %w`, err)
|
|
}
|
|
|
|
if len(unprotected) > 2 {
|
|
fields = append(fields, jsonKV{
|
|
Key: UnprotectedHeadersKey,
|
|
Value: fmt.Sprintf("%q", unprotected),
|
|
})
|
|
}
|
|
}
|
|
|
|
sort.Slice(fields, func(i, j int) bool {
|
|
return fields[i].Key < fields[j].Key
|
|
})
|
|
buf.Reset()
|
|
fmt.Fprintf(buf, `{`)
|
|
for i, kv := range fields {
|
|
if i > 0 {
|
|
fmt.Fprintf(buf, `,`)
|
|
}
|
|
fmt.Fprintf(buf, `%q:%s`, kv.Key, kv.Value)
|
|
}
|
|
fmt.Fprintf(buf, `}`)
|
|
|
|
ret := make([]byte, buf.Len())
|
|
copy(ret, buf.Bytes())
|
|
return ret, nil
|
|
}
|
|
|
|
func (m *Message) UnmarshalJSON(buf []byte) error {
|
|
var proxy messageMarshalProxy
|
|
proxy.UnprotectedHeaders = NewHeaders()
|
|
|
|
if err := json.Unmarshal(buf, &proxy); err != nil {
|
|
return fmt.Errorf(`failed to unmashal JSON into message: %w`, err)
|
|
}
|
|
|
|
// Get the string value
|
|
var protectedHeadersStr string
|
|
if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil {
|
|
return fmt.Errorf(`failed to decode protected headers (1): %w`, err)
|
|
}
|
|
|
|
// It's now in _quoted_ base64 string. Decode it
|
|
protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to base64 decoded protected headers buffer: %w`, err)
|
|
}
|
|
|
|
h := NewHeaders()
|
|
if err := json.Unmarshal(protectedHeadersRaw, h); err != nil {
|
|
return fmt.Errorf(`failed to decode protected headers (2): %w`, err)
|
|
}
|
|
|
|
// if this were a flattened message, we would see a "header" and "ciphertext"
|
|
// field. TODO: do both of these conditions need to meet, or just one?
|
|
if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
|
|
recipient := NewRecipient()
|
|
hdrs := NewHeaders()
|
|
if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
|
|
return fmt.Errorf(`failed to decode headers field: %w`, err)
|
|
}
|
|
|
|
if err := recipient.SetHeaders(hdrs); err != nil {
|
|
return fmt.Errorf(`failed to set new headers: %w`, err)
|
|
}
|
|
|
|
if v := proxy.EncryptedKey; len(v) > 0 {
|
|
buf, err := base64.DecodeString(v)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode encrypted key: %w`, err)
|
|
}
|
|
if err := recipient.SetEncryptedKey(buf); err != nil {
|
|
return fmt.Errorf(`failed to set encrypted key: %w`, err)
|
|
}
|
|
}
|
|
|
|
m.recipients = append(m.recipients, recipient)
|
|
} else {
|
|
for i, recipientbuf := range proxy.Recipients {
|
|
recipient := NewRecipient()
|
|
if err := json.Unmarshal(recipientbuf, recipient); err != nil {
|
|
return fmt.Errorf(`failed to decode recipient at index %d: %w`, i, err)
|
|
}
|
|
|
|
m.recipients = append(m.recipients, recipient)
|
|
}
|
|
}
|
|
|
|
if src := proxy.AuthenticatedData; len(src) > 0 {
|
|
v, err := base64.DecodeString(src)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode "aad": %w`, err)
|
|
}
|
|
m.authenticatedData = v
|
|
}
|
|
|
|
if src := proxy.CipherText; len(src) > 0 {
|
|
v, err := base64.DecodeString(src)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode "ciphertext": %w`, err)
|
|
}
|
|
m.cipherText = v
|
|
}
|
|
|
|
if src := proxy.InitializationVector; len(src) > 0 {
|
|
v, err := base64.DecodeString(src)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode "iv": %w`, err)
|
|
}
|
|
m.initializationVector = v
|
|
}
|
|
|
|
if src := proxy.Tag; len(src) > 0 {
|
|
v, err := base64.DecodeString(src)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode "tag": %w`, err)
|
|
}
|
|
m.tag = v
|
|
}
|
|
|
|
m.protectedHeaders = h
|
|
if m.storeProtectedHeaders {
|
|
// this is later used for decryption
|
|
m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw)
|
|
}
|
|
|
|
if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok {
|
|
if !iz.isZero() {
|
|
m.unprotectedHeaders = proxy.UnprotectedHeaders
|
|
}
|
|
}
|
|
|
|
if len(m.recipients) == 0 {
|
|
if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil {
|
|
return fmt.Errorf(`failed to setup recipient: %w`, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error {
|
|
// Recipients in this case should not contain the content encryption key,
|
|
// so move that out
|
|
hdrs, err := protected.Clone(context.TODO())
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to clone headers: %w`, err)
|
|
}
|
|
|
|
if err := hdrs.Remove(ContentEncryptionKey); err != nil {
|
|
return fmt.Errorf(`failed to remove %#v from public header: %w`, ContentEncryptionKey, err)
|
|
}
|
|
|
|
enckey, err := base64.DecodeString(enckeybuf)
|
|
if err != nil {
|
|
return fmt.Errorf(`failed to decode encrypted key: %w`, err)
|
|
}
|
|
|
|
if err := m.Set(RecipientsKey, []Recipient{
|
|
&stdRecipient{
|
|
headers: hdrs,
|
|
encryptedKey: enckey,
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Compact generates a JWE message in compact serialization format from a
|
|
// `*jwe.Message` object. The object contain exactly one recipient, or
|
|
// an error is returned.
|
|
//
|
|
// This function currently does not take any options, but the function
|
|
// signature contains `options` for possible future expansion of the API
|
|
func Compact(m *Message, _ ...CompactOption) ([]byte, error) {
|
|
if len(m.recipients) != 1 {
|
|
return nil, fmt.Errorf(`wrong number of recipients for compact serialization`)
|
|
}
|
|
|
|
recipient := m.recipients[0]
|
|
|
|
// The protected header must be a merge between the message-wide
|
|
// protected header AND the recipient header
|
|
|
|
// There's something wrong if m.protectedHeaders is nil, but
|
|
// it could happen
|
|
if m.protectedHeaders == nil {
|
|
return nil, fmt.Errorf(`invalid protected header`)
|
|
}
|
|
|
|
ctx := context.TODO()
|
|
hcopy, err := m.protectedHeaders.Clone(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to copy protected header: %w`, err)
|
|
}
|
|
hcopy, err = hcopy.Merge(ctx, m.unprotectedHeaders)
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to merge unprotected header: %w`, err)
|
|
}
|
|
hcopy, err = hcopy.Merge(ctx, recipient.Headers())
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to merge recipient header: %w`, err)
|
|
}
|
|
|
|
protected, err := hcopy.Encode()
|
|
if err != nil {
|
|
return nil, fmt.Errorf(`failed to encode header: %w`, err)
|
|
}
|
|
|
|
encryptedKey := base64.Encode(recipient.EncryptedKey())
|
|
iv := base64.Encode(m.initializationVector)
|
|
cipher := base64.Encode(m.cipherText)
|
|
tag := base64.Encode(m.tag)
|
|
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
buf.Grow(len(protected) + len(encryptedKey) + len(iv) + len(cipher) + len(tag) + 4)
|
|
buf.Write(protected)
|
|
buf.WriteByte('.')
|
|
buf.Write(encryptedKey)
|
|
buf.WriteByte('.')
|
|
buf.Write(iv)
|
|
buf.WriteByte('.')
|
|
buf.Write(cipher)
|
|
buf.WriteByte('.')
|
|
buf.Write(tag)
|
|
|
|
result := make([]byte, buf.Len())
|
|
copy(result, buf.Bytes())
|
|
return result, nil
|
|
}
|