auth/vendor/github.com/lestrrat-go/jwx/jws/message.go

449 lines
11 KiB
Go

package jws
import (
"bytes"
"context"
"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/internal/json"
"github.com/lestrrat-go/jwx/internal/pool"
"github.com/lestrrat-go/jwx/jwk"
"github.com/pkg/errors"
)
// collectRawCtx is a decode context whose CollectRaw method always reports
// true. Presumably it is handed to header decoding to request that the raw
// header bytes be retained (compare stdHeaders.raw, cleared by clearRaw) —
// confirm against the consumers of CollectRaw.
type collectRawCtx struct{}

// CollectRaw unconditionally returns true.
func (collectRawCtx) CollectRaw() bool { return true }
// NewSignature allocates an empty Signature. Its headers, protected headers
// and signature bytes stay unset until assigned through the setters or
// filled in by UnmarshalJSON / Sign.
func NewSignature() *Signature {
	return new(Signature)
}

// DecodeCtx reports the decode context currently attached to s.
func (s *Signature) DecodeCtx() DecodeCtx {
	return s.dc
}

// SetDecodeCtx attaches decodeCtx to s. UnmarshalJSON forwards it to the
// decoder of the protected headers.
func (s *Signature) SetDecodeCtx(decodeCtx DecodeCtx) {
	s.dc = decodeCtx
}
// PublicHeaders returns the unprotected headers (the JSON "header" member).
func (s Signature) PublicHeaders() Headers { return s.headers }

// SetPublicHeaders replaces the unprotected headers and returns s so calls
// can be chained.
func (s *Signature) SetPublicHeaders(hdrs Headers) *Signature {
	s.headers = hdrs
	return s
}

// ProtectedHeaders returns the integrity-protected headers (the JSON
// "protected" member, already decoded).
func (s Signature) ProtectedHeaders() Headers { return s.protected }

// SetProtectedHeaders replaces the protected headers and returns s so calls
// can be chained.
func (s *Signature) SetProtectedHeaders(hdrs Headers) *Signature {
	s.protected = hdrs
	return s
}

// Signature returns the raw (binary, not base64-encoded) signature bytes.
func (s Signature) Signature() []byte { return s.signature }

// SetSignature replaces the raw signature bytes and returns s so calls can
// be chained.
func (s *Signature) SetSignature(sig []byte) *Signature {
	s.signature = sig
	return s
}
// signatureUnmarshalProbe receives one signature object of the JWS JSON
// serialization before Signature.UnmarshalJSON post-processes it.
// Protected and Signature are pointers so an absent member can be told
// apart from an empty string.
type signatureUnmarshalProbe struct {
	Header    Headers `json:"header,omitempty"`
	Protected *string `json:"protected,omitempty"`
	Signature *string `json:"signature,omitempty"`
}
// UnmarshalJSON decodes one signature object of the JWS JSON serialization,
// i.e. {"header":{...},"protected":"...","signature":"..."}.
//
// "header" becomes the public headers (a fresh NewHeaders() value when the
// member is absent).
//
// "protected", when present, is base64-decoded and parsed as the protected
// headers. A value that already begins with '{' is accepted as raw JSON
// without decoding. Any DecodeCtx attached to s is handed to the protected
// header decoder for the duration of that parse.
//
// "signature" is required and is stored base64-decoded. If it is missing,
// an error is returned and s is left unmodified.
func (s *Signature) UnmarshalJSON(data []byte) error {
	var sup signatureUnmarshalProbe
	sup.Header = NewHeaders()
	if err := json.Unmarshal(data, &sup); err != nil {
		return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
	}

	// Reject a missing "signature" member up front: dereferencing the nil
	// pointer below would otherwise panic on malformed (possibly
	// attacker-supplied) input.
	if sup.Signature == nil {
		return errors.New(`required field "signature" not present`)
	}

	s.headers = sup.Header
	if buf := sup.Protected; buf != nil {
		src := []byte(*buf)
		// Tolerate an already-decoded JSON object; otherwise expect base64.
		if !bytes.HasPrefix(src, []byte{'{'}) {
			decoded, err := base64.Decode(src)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode protected headers`)
			}
			src = decoded
		}
		prt := NewHeaders()
		//nolint:forcetypeassert
		prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
		if err := json.Unmarshal(src, prt); err != nil {
			return errors.Wrap(err, `failed to unmarshal protected headers`)
		}
		// Detach the decode context again once parsing is done.
		//nolint:forcetypeassert
		prt.(*stdHeaders).SetDecodeCtx(nil)
		s.protected = prt
	}

	decoded, err := base64.DecodeString(*sup.Signature)
	if err != nil {
		return errors.Wrap(err, `failed to base64 decode signature`)
	}
	s.signature = decoded
	return nil
}
// Sign computes a signature over payload using signer and key, stores the
// raw signature in s (see Signature()), and returns it.
//
// The public and protected headers of s are merged via mergeHeaders. "alg"
// is set from signer.Algorithm(), and when key is a jwk.Key with a non-empty
// key ID, "kid" is set from it. The resulting header set is JSON-encoded and
// base64-encoded as the first segment of the signing input. Only
// s.signature is assigned; s.headers and s.protected are not updated with
// the added "alg"/"kid" values by this method.
//
// The payload segment is base64-encoded unless the "b64" header (as read by
// getB64Value) is false. In that case the payload is used verbatim and,
// unless s is detached, must not contain a '.'.
//
// The first return value is the raw signature in binary format.
// The second return value is the full three-segment compact serialization
// (e.g. "eyXXXX.XXXXX.XXXX"). When s is detached, the payload segment is
// signed but left out of this value (e.g. "eyXXXX..XXXX").
func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
	// Context used only for the mergeHeaders call below; cancelled on return.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to merge headers`)
	}
	if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
		return nil, nil, errors.Wrap(err, `failed to set "alg"`)
	}
	// If the key is a jwk.Key, copy its key ID (when set) into the "kid"
	// header. The key value itself is passed to signer.Sign unchanged.
	if jwkKey, ok := key.(jwk.Key); ok {
		// If we have a key ID specified by this jwk.Key, use that in the header
		if kid := jwkKey.KeyID(); kid != "" {
			if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
				return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
			}
		}
	}
	hdrbuf, err := json.Marshal(hdrs)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to marshal headers`)
	}
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)
	buf.WriteString(base64.EncodeToString(hdrbuf))
	buf.WriteByte('.')
	// plen records how many trailing bytes of buf form the payload segment,
	// so a detached signature can cut them back out after signing.
	var plen int
	b64 := getB64Value(hdrs)
	if b64 {
		encoded := base64.EncodeToString(payload)
		plen = len(encoded)
		buf.WriteString(encoded)
	} else {
		// An unencoded payload that is embedded in the output must not
		// contain the segment separator.
		if !s.detached {
			if bytes.Contains(payload, []byte{'.'}) {
				return nil, nil, errors.New(`payload must not contain a "."`)
			}
		}
		plen = len(payload)
		buf.Write(payload)
	}
	// buf now holds the signing input: <header>.<payload>
	signature, err := signer.Sign(buf.Bytes(), key)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to sign payload`)
	}
	s.signature = signature
	// Detached payload: it was signed, but is removed from the returned
	// serialization, leaving "<header>." before the signature is appended.
	if s.detached {
		buf.Truncate(buf.Len() - plen)
	}
	buf.WriteByte('.')
	buf.WriteString(base64.EncodeToString(signature))
	// Copy out of the pooled buffer before the deferred release returns it.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return signature, ret, nil
}
// NewMessage allocates an empty Message with no payload and no signatures.
func NewMessage() *Message {
	return new(Message)
}

// clearRaw clears the internal raw buffer that was accumulated during the
// verify phase. Only protected headers whose concrete type is *stdHeaders
// are touched; other implementations are left alone.
func (m *Message) clearRaw() {
	for _, sig := range m.signatures {
		std, isStd := sig.protected.(*stdHeaders)
		if !isStd {
			continue
		}
		std.raw = nil
	}
}

// DecodeCtx reports the decode context currently attached to m.
func (m *Message) DecodeCtx() DecodeCtx {
	return m.dc
}

// SetDecodeCtx attaches decodeCtx to m. UnmarshalJSON passes it on to the
// decoding of each signature's protected headers.
func (m *Message) SetDecodeCtx(decodeCtx DecodeCtx) {
	m.dc = decodeCtx
}
// Payload returns the decoded payload.
func (m Message) Payload() []byte { return m.payload }

// SetPayload replaces the payload and returns m so calls can be chained.
func (m *Message) SetPayload(payload []byte) *Message {
	m.payload = payload
	return m
}

// Signatures returns the message's signatures. The slice shares its backing
// array with the message; it is not a copy.
func (m Message) Signatures() []*Signature { return m.signatures }

// AppendSignature adds sig to the end of the signature list and returns m.
func (m *Message) AppendSignature(sig *Signature) *Message {
	m.signatures = append(m.signatures, sig)
	return m
}

// ClearSignatures drops every signature from the message and returns m.
func (m *Message) ClearSignatures() *Message {
	m.signatures = nil
	return m
}
// LookupSignature returns every signature whose "kid" header equals kid.
//
// For each signature the public headers are checked first and the protected
// headers second. A signature is included at most once, even if both match.
// A nil header set never matches. Note that an empty kid matches any non-nil
// header set whose KeyID() is empty.
func (m Message) LookupSignature(kid string) []*Signature {
	hasKeyID := func(hdr Headers) bool {
		return hdr != nil && hdr.KeyID() == kid
	}

	var found []*Signature
	for _, sig := range m.signatures {
		if hasKeyID(sig.PublicHeaders()) || hasKeyID(sig.ProtectedHeaders()) {
			found = append(found, sig)
		}
	}
	return found
}
// messageUnmarshalProbe captures every top-level member that can appear in
// either the general JWS JSON serialization ("signatures" array) or the
// flattened one ("header", "protected", "signature"). Message.UnmarshalJSON
// looks at which members were present to decide how to interpret the rest.
// Pointer fields let an absent member be distinguished from an empty string.
type messageUnmarshalProbe struct {
	Payload    *string           `json:"payload"`
	Signatures []json.RawMessage `json:"signatures,omitempty"`
	Header     Headers           `json:"header,omitempty"`
	Protected  *string           `json:"protected,omitempty"`
	Signature  *string           `json:"signature,omitempty"`
}
// UnmarshalJSON parses a JWS in JSON serialization. Both forms are
// accepted:
//   - the general form: {"payload":...,"signatures":[...]}
//   - the flattened form: {"payload":...,"header":...,"protected":...,"signature":...}
//
// Any DecodeCtx set on m is passed down while the protected headers are
// decoded.
//
// All signatures must agree on the "b64" protected header. A signature
// without protected headers counts as b64=true, the RFC 7797 default. When
// b64 is false, the payload is taken verbatim; otherwise it is
// base64-decoded. A missing "payload" member leaves the payload nil.
//
// m is reset before parsing. Payload and signatures are only stored once
// the whole input has been parsed, so on error m carries neither.
func (m *Message) UnmarshalJSON(buf []byte) error {
	m.payload = nil
	m.signatures = nil
	m.b64 = true

	var mup messageUnmarshalProbe
	mup.Header = NewHeaders()
	if err := json.Unmarshal(buf, &mup); err != nil {
		return errors.Wrap(err, `failed to unmarshal into temporary structure`)
	}

	// b64Of reports the effective "b64" value of a protected header set.
	// "b64" may only live in the protected header and defaults to true, so
	// an absent protected header means true. This also avoids handing a nil
	// Headers to getB64Value.
	b64Of := func(protected Headers) bool {
		if protected == nil {
			return true
		}
		return getB64Value(protected)
	}

	var sigs []*Signature
	b64 := true
	if mup.Signature == nil { // flattened signature is NOT present
		if len(mup.Signatures) == 0 {
			return errors.New(`required field "signatures" not present`)
		}
		sigs = make([]*Signature, 0, len(mup.Signatures))
		for i, rawsig := range mup.Signatures {
			var sig Signature
			sig.SetDecodeCtx(m.DecodeCtx())
			if err := json.Unmarshal(rawsig, &sig); err != nil {
				return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1)
			}
			sig.SetDecodeCtx(nil)

			// The first signature fixes the b64 mode; all others must match.
			sigB64 := b64Of(sig.protected)
			if i == 0 {
				b64 = sigB64
			} else if b64 != sigB64 {
				return errors.Errorf(`b64 value must be the same for all signatures`)
			}
			sigs = append(sigs, &sig)
		}
	} else { // .signature is present, it's a flattened structure
		if len(mup.Signatures) != 0 {
			return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
		}
		var sig Signature
		sig.headers = mup.Header
		if src := mup.Protected; src != nil {
			decoded, err := base64.DecodeString(*src)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
			}
			prt := NewHeaders()
			//nolint:forcetypeassert
			prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
			if err := json.Unmarshal(decoded, prt); err != nil {
				return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
			}
			//nolint:forcetypeassert
			prt.(*stdHeaders).SetDecodeCtx(nil)
			sig.protected = prt
		}
		decoded, err := base64.DecodeString(*mup.Signature)
		if err != nil {
			return errors.Wrap(err, `failed to base64 decode flattened signature`)
		}
		sig.signature = decoded
		sigs = []*Signature{&sig}
		b64 = b64Of(sig.protected)
	}

	var payload []byte
	if mup.Payload != nil {
		if b64 {
			decoded, err := base64.DecodeString(*mup.Payload)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode payload`)
			}
			payload = decoded
		} else {
			// b64=false: the payload is carried unencoded.
			payload = []byte(*mup.Payload)
		}
	}

	m.payload = payload
	m.signatures = sigs
	m.b64 = b64
	return nil
}
// MarshalJSON serializes the message using the JWS JSON serialization. A
// message with exactly one signature uses the flattened form. Any other
// count (including zero) uses the general form with a "signatures" array.
func (m Message) MarshalJSON() ([]byte, error) {
	switch len(m.signatures) {
	case 1:
		return m.marshalFlattened()
	default:
		return m.marshalFull()
	}
}
// marshalFlattened renders the first (and expected only) signature of m in
// the flattened JWS JSON serialization:
//
//	{"header":{...},"payload":"...","protected":"...","signature":"..."}
//
// "header" and "protected" are emitted only when the corresponding header
// set is non-nil. The payload, the JSON-encoded protected headers and the
// signature are written base64-encoded.
//
// NOTE(review): m.b64 is not consulted here, so a message decoded with
// b64=false is re-emitted with a base64-encoded payload.
func (m Message) marshalFlattened() ([]byte, error) {
	out := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(out)

	sig := m.signatures[0]

	out.WriteByte('{')
	if sig.headers != nil {
		public, err := sig.headers.MarshalJSON()
		if err != nil {
			return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
		}
		out.WriteString(`"header":`)
		out.Write(public)
		out.WriteByte(',')
	}

	out.WriteString(`"payload":"`)
	out.WriteString(base64.EncodeToString(m.payload))
	out.WriteByte('"')

	if sig.protected != nil {
		protected, err := sig.protected.MarshalJSON()
		if err != nil {
			return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
		}
		out.WriteString(`,"protected":"`)
		out.WriteString(base64.EncodeToString(protected))
		out.WriteByte('"')
	}

	out.WriteString(`,"signature":"`)
	out.WriteString(base64.EncodeToString(sig.signature))
	out.WriteString(`"}`)

	// Copy out of the pooled buffer before the deferred release reclaims it.
	return append([]byte(nil), out.Bytes()...), nil
}
// marshalFull renders m in the general JWS JSON serialization:
//
//	{"payload":"...","signatures":[{"header":{...},"protected":"...","signature":"..."},...]}
//
// Within each signature object, "header" and "protected" are emitted only
// when the corresponding header set is non-nil; "signature" is always
// present. The payload, the JSON-encoded protected headers and the
// signatures are written base64-encoded.
func (m Message) marshalFull() ([]byte, error) {
	out := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(out)

	out.WriteString(`{"payload":"`)
	out.WriteString(base64.EncodeToString(m.payload))
	out.WriteString(`","signatures":[`)
	for idx, sig := range m.signatures {
		if idx > 0 {
			out.WriteByte(',')
		}
		out.WriteByte('{')

		// Optional members each end with a ',' because "signature" always
		// follows them.
		if sig.headers != nil {
			public, err := sig.headers.MarshalJSON()
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, idx+1)
			}
			out.WriteString(`"header":`)
			out.Write(public)
			out.WriteByte(',')
		}

		if sig.protected != nil {
			protected, err := sig.protected.MarshalJSON()
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, idx+1)
			}
			out.WriteString(`"protected":"`)
			out.WriteString(base64.EncodeToString(protected))
			out.WriteString(`",`)
		}

		out.WriteString(`"signature":"`)
		out.WriteString(base64.EncodeToString(sig.signature))
		out.WriteString(`"}`)
	}
	out.WriteString(`]}`)

	// Copy out of the pooled buffer before the deferred release reclaims it.
	return append([]byte(nil), out.Bytes()...), nil
}