449 lines
11 KiB
Go
449 lines
11 KiB
Go
package jws
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
|
|
"github.com/lestrrat-go/jwx/internal/base64"
|
|
"github.com/lestrrat-go/jwx/internal/json"
|
|
"github.com/lestrrat-go/jwx/internal/pool"
|
|
"github.com/lestrrat-go/jwx/jwk"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// collectRawCtx is a stateless decode context whose CollectRaw method always
// reports true. NOTE(review): presumably handed to header decoders as a
// DecodeCtx so that they retain the raw serialized bytes (compare clearRaw,
// which resets stdHeaders.raw) — confirm against the DecodeCtx interface.
type collectRawCtx struct{}

// CollectRaw unconditionally reports true.
func (collectRawCtx) CollectRaw() bool { return true }
|
|
|
|
func NewSignature() *Signature {
|
|
return &Signature{}
|
|
}
|
|
|
|
func (s *Signature) DecodeCtx() DecodeCtx {
|
|
return s.dc
|
|
}
|
|
|
|
func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
|
|
s.dc = dc
|
|
}
|
|
|
|
func (s Signature) PublicHeaders() Headers {
|
|
return s.headers
|
|
}
|
|
|
|
func (s *Signature) SetPublicHeaders(v Headers) *Signature {
|
|
s.headers = v
|
|
return s
|
|
}
|
|
|
|
func (s Signature) ProtectedHeaders() Headers {
|
|
return s.protected
|
|
}
|
|
|
|
func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
|
|
s.protected = v
|
|
return s
|
|
}
|
|
|
|
func (s Signature) Signature() []byte {
|
|
return s.signature
|
|
}
|
|
|
|
func (s *Signature) SetSignature(v []byte) *Signature {
|
|
s.signature = v
|
|
return s
|
|
}
|
|
|
|
type signatureUnmarshalProbe struct {
|
|
Header Headers `json:"header,omitempty"`
|
|
Protected *string `json:"protected,omitempty"`
|
|
Signature *string `json:"signature,omitempty"`
|
|
}
|
|
|
|
func (s *Signature) UnmarshalJSON(data []byte) error {
|
|
var sup signatureUnmarshalProbe
|
|
sup.Header = NewHeaders()
|
|
if err := json.Unmarshal(data, &sup); err != nil {
|
|
return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
|
|
}
|
|
|
|
s.headers = sup.Header
|
|
if buf := sup.Protected; buf != nil {
|
|
src := []byte(*buf)
|
|
if !bytes.HasPrefix(src, []byte{'{'}) {
|
|
decoded, err := base64.Decode(src)
|
|
if err != nil {
|
|
return errors.Wrap(err, `failed to base64 decode protected headers`)
|
|
}
|
|
src = decoded
|
|
}
|
|
|
|
prt := NewHeaders()
|
|
//nolint:forcetypeassert
|
|
prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
|
|
if err := json.Unmarshal(src, prt); err != nil {
|
|
return errors.Wrap(err, `failed to unmarshal protected headers`)
|
|
}
|
|
//nolint:forcetypeassert
|
|
prt.(*stdHeaders).SetDecodeCtx(nil)
|
|
s.protected = prt
|
|
}
|
|
|
|
decoded, err := base64.DecodeString(*sup.Signature)
|
|
if err != nil {
|
|
return errors.Wrap(err, `failed to base decode signature`)
|
|
}
|
|
s.signature = decoded
|
|
return nil
|
|
}
|
|
|
|
// Sign populates the signature field, with a signature generated by
|
|
// given the signer object and payload.
|
|
//
|
|
// The first return value is the raw signature in binary format.
|
|
// The second return value s the full three-segment signature
|
|
// (e.g. "eyXXXX.XXXXX.XXXX")
|
|
func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, `failed to merge headers`)
|
|
}
|
|
|
|
if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
|
|
return nil, nil, errors.Wrap(err, `failed to set "alg"`)
|
|
}
|
|
|
|
// If the key is a jwk.Key instance, obtain the raw key
|
|
if jwkKey, ok := key.(jwk.Key); ok {
|
|
// If we have a key ID specified by this jwk.Key, use that in the header
|
|
if kid := jwkKey.KeyID(); kid != "" {
|
|
if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
|
|
return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
|
|
}
|
|
}
|
|
}
|
|
hdrbuf, err := json.Marshal(hdrs)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, `failed to marshal headers`)
|
|
}
|
|
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
buf.WriteString(base64.EncodeToString(hdrbuf))
|
|
buf.WriteByte('.')
|
|
|
|
var plen int
|
|
b64 := getB64Value(hdrs)
|
|
if b64 {
|
|
encoded := base64.EncodeToString(payload)
|
|
plen = len(encoded)
|
|
buf.WriteString(encoded)
|
|
} else {
|
|
if !s.detached {
|
|
if bytes.Contains(payload, []byte{'.'}) {
|
|
return nil, nil, errors.New(`payload must not contain a "."`)
|
|
}
|
|
}
|
|
plen = len(payload)
|
|
buf.Write(payload)
|
|
}
|
|
|
|
signature, err := signer.Sign(buf.Bytes(), key)
|
|
if err != nil {
|
|
return nil, nil, errors.Wrap(err, `failed to sign payload`)
|
|
}
|
|
s.signature = signature
|
|
|
|
// Detached payload, this should be removed from the end result
|
|
if s.detached {
|
|
buf.Truncate(buf.Len() - plen)
|
|
}
|
|
|
|
buf.WriteByte('.')
|
|
buf.WriteString(base64.EncodeToString(signature))
|
|
ret := make([]byte, buf.Len())
|
|
copy(ret, buf.Bytes())
|
|
|
|
return signature, ret, nil
|
|
}
|
|
|
|
func NewMessage() *Message {
|
|
return &Message{}
|
|
}
|
|
|
|
// Clears the internal raw buffer that was accumulated during
|
|
// the verify phase
|
|
func (m *Message) clearRaw() {
|
|
for _, sig := range m.signatures {
|
|
if protected := sig.protected; protected != nil {
|
|
if cr, ok := protected.(*stdHeaders); ok {
|
|
cr.raw = nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Message) SetDecodeCtx(dc DecodeCtx) {
|
|
m.dc = dc
|
|
}
|
|
|
|
func (m *Message) DecodeCtx() DecodeCtx {
|
|
return m.dc
|
|
}
|
|
|
|
// Payload returns the decoded payload
|
|
func (m Message) Payload() []byte {
|
|
return m.payload
|
|
}
|
|
|
|
func (m *Message) SetPayload(v []byte) *Message {
|
|
m.payload = v
|
|
return m
|
|
}
|
|
|
|
func (m Message) Signatures() []*Signature {
|
|
return m.signatures
|
|
}
|
|
|
|
func (m *Message) AppendSignature(v *Signature) *Message {
|
|
m.signatures = append(m.signatures, v)
|
|
return m
|
|
}
|
|
|
|
func (m *Message) ClearSignatures() *Message {
|
|
m.signatures = nil
|
|
return m
|
|
}
|
|
|
|
// LookupSignature looks up a particular signature entry using
|
|
// the `kid` value
|
|
func (m Message) LookupSignature(kid string) []*Signature {
|
|
var sigs []*Signature
|
|
for _, sig := range m.signatures {
|
|
if hdr := sig.PublicHeaders(); hdr != nil {
|
|
hdrKeyID := hdr.KeyID()
|
|
if hdrKeyID == kid {
|
|
sigs = append(sigs, sig)
|
|
continue
|
|
}
|
|
}
|
|
|
|
if hdr := sig.ProtectedHeaders(); hdr != nil {
|
|
hdrKeyID := hdr.KeyID()
|
|
if hdrKeyID == kid {
|
|
sigs = append(sigs, sig)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
return sigs
|
|
}
|
|
|
|
// This struct is used to first probe for the structure of the
|
|
// incoming JSON object. We then decide how to parse it
|
|
// from the fields that are populated.
|
|
type messageUnmarshalProbe struct {
|
|
Payload *string `json:"payload"`
|
|
Signatures []json.RawMessage `json:"signatures,omitempty"`
|
|
Header Headers `json:"header,omitempty"`
|
|
Protected *string `json:"protected,omitempty"`
|
|
Signature *string `json:"signature,omitempty"`
|
|
}
|
|
|
|
func (m *Message) UnmarshalJSON(buf []byte) error {
|
|
m.payload = nil
|
|
m.signatures = nil
|
|
m.b64 = true
|
|
|
|
var mup messageUnmarshalProbe
|
|
mup.Header = NewHeaders()
|
|
if err := json.Unmarshal(buf, &mup); err != nil {
|
|
return errors.Wrap(err, `failed to unmarshal into temporary structure`)
|
|
}
|
|
|
|
b64 := true
|
|
if mup.Signature == nil { // flattened signature is NOT present
|
|
if len(mup.Signatures) == 0 {
|
|
return errors.New(`required field "signatures" not present`)
|
|
}
|
|
|
|
m.signatures = make([]*Signature, 0, len(mup.Signatures))
|
|
for i, rawsig := range mup.Signatures {
|
|
var sig Signature
|
|
sig.SetDecodeCtx(m.DecodeCtx())
|
|
if err := json.Unmarshal(rawsig, &sig); err != nil {
|
|
return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1)
|
|
}
|
|
sig.SetDecodeCtx(nil)
|
|
|
|
if i == 0 {
|
|
if !getB64Value(sig.protected) {
|
|
b64 = false
|
|
}
|
|
} else {
|
|
if b64 != getB64Value(sig.protected) {
|
|
return errors.Errorf(`b64 value must be the same for all signatures`)
|
|
}
|
|
}
|
|
|
|
m.signatures = append(m.signatures, &sig)
|
|
}
|
|
} else { // .signature is present, it's a flattened structure
|
|
if len(mup.Signatures) != 0 {
|
|
return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
|
|
}
|
|
|
|
var sig Signature
|
|
sig.headers = mup.Header
|
|
if src := mup.Protected; src != nil {
|
|
decoded, err := base64.DecodeString(*src)
|
|
if err != nil {
|
|
return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
|
|
}
|
|
prt := NewHeaders()
|
|
//nolint:forcetypeassert
|
|
prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
|
|
if err := json.Unmarshal(decoded, prt); err != nil {
|
|
return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
|
|
}
|
|
//nolint:forcetypeassert
|
|
prt.(*stdHeaders).SetDecodeCtx(nil)
|
|
sig.protected = prt
|
|
}
|
|
|
|
decoded, err := base64.DecodeString(*mup.Signature)
|
|
if err != nil {
|
|
return errors.Wrap(err, `failed to base64 decode flattened signature`)
|
|
}
|
|
sig.signature = decoded
|
|
|
|
m.signatures = []*Signature{&sig}
|
|
b64 = getB64Value(sig.protected)
|
|
}
|
|
|
|
if mup.Payload != nil {
|
|
if !b64 { // NOT base64 encoded
|
|
m.payload = []byte(*mup.Payload)
|
|
} else {
|
|
decoded, err := base64.DecodeString(*mup.Payload)
|
|
if err != nil {
|
|
return errors.Wrap(err, `failed to base64 decode payload`)
|
|
}
|
|
m.payload = decoded
|
|
}
|
|
}
|
|
m.b64 = b64
|
|
return nil
|
|
}
|
|
|
|
func (m Message) MarshalJSON() ([]byte, error) {
|
|
if len(m.signatures) == 1 {
|
|
return m.marshalFlattened()
|
|
}
|
|
return m.marshalFull()
|
|
}
|
|
|
|
func (m Message) marshalFlattened() ([]byte, error) {
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
sig := m.signatures[0]
|
|
|
|
buf.WriteRune('{')
|
|
var wrote bool
|
|
|
|
if hdr := sig.headers; hdr != nil {
|
|
hdrjs, err := hdr.MarshalJSON()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
|
|
}
|
|
buf.WriteString(`"header":`)
|
|
buf.Write(hdrjs)
|
|
wrote = true
|
|
}
|
|
|
|
if wrote {
|
|
buf.WriteRune(',')
|
|
}
|
|
buf.WriteString(`"payload":"`)
|
|
buf.WriteString(base64.EncodeToString(m.payload))
|
|
buf.WriteRune('"')
|
|
|
|
if protected := sig.protected; protected != nil {
|
|
protectedbuf, err := protected.MarshalJSON()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
|
|
}
|
|
buf.WriteString(`,"protected":"`)
|
|
buf.WriteString(base64.EncodeToString(protectedbuf))
|
|
buf.WriteRune('"')
|
|
}
|
|
|
|
buf.WriteString(`,"signature":"`)
|
|
buf.WriteString(base64.EncodeToString(sig.signature))
|
|
buf.WriteRune('"')
|
|
buf.WriteRune('}')
|
|
|
|
ret := make([]byte, buf.Len())
|
|
copy(ret, buf.Bytes())
|
|
return ret, nil
|
|
}
|
|
|
|
func (m Message) marshalFull() ([]byte, error) {
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
buf.WriteString(`{"payload":"`)
|
|
buf.WriteString(base64.EncodeToString(m.payload))
|
|
buf.WriteString(`","signatures":[`)
|
|
for i, sig := range m.signatures {
|
|
if i > 0 {
|
|
buf.WriteRune(',')
|
|
}
|
|
|
|
buf.WriteRune('{')
|
|
var wrote bool
|
|
if hdr := sig.headers; hdr != nil {
|
|
hdrbuf, err := hdr.MarshalJSON()
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
|
|
}
|
|
buf.WriteString(`"header":`)
|
|
buf.Write(hdrbuf)
|
|
wrote = true
|
|
}
|
|
|
|
if protected := sig.protected; protected != nil {
|
|
protectedbuf, err := protected.MarshalJSON()
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
|
|
}
|
|
if wrote {
|
|
buf.WriteRune(',')
|
|
}
|
|
buf.WriteString(`"protected":"`)
|
|
buf.WriteString(base64.EncodeToString(protectedbuf))
|
|
buf.WriteRune('"')
|
|
wrote = true
|
|
}
|
|
|
|
if wrote {
|
|
buf.WriteRune(',')
|
|
}
|
|
buf.WriteString(`"signature":"`)
|
|
buf.WriteString(base64.EncodeToString(sig.signature))
|
|
buf.WriteString(`"}`)
|
|
}
|
|
buf.WriteString(`]}`)
|
|
|
|
ret := make([]byte, buf.Len())
|
|
copy(ret, buf.Bytes())
|
|
return ret, nil
|
|
}
|