auth/vendor/github.com/lestrrat-go/jwx/v2/jws/rsa.go

147 lines
3.2 KiB
Go

package jws
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"fmt"
"github.com/lestrrat-go/jwx/v2/internal/keyconv"
"github.com/lestrrat-go/jwx/v2/jwa"
)
// rsaSigners and rsaVerifiers map each RSA signature algorithm to its
// pre-built signer and verifier. Both tables are populated by init.
var (
	rsaSigners   map[jwa.SignatureAlgorithm]*rsaSigner
	rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
)
// init fills the package-level signer and verifier tables with one entry per
// supported RSA algorithm: RS256/RS384/RS512 with the PSS flag unset, and
// PS256/PS384/PS512 with it set.
func init() {
	type spec struct {
		alg  jwa.SignatureAlgorithm
		hash crypto.Hash
		pss  bool
	}

	specs := []spec{
		{alg: jwa.RS256, hash: crypto.SHA256},
		{alg: jwa.RS384, hash: crypto.SHA384},
		{alg: jwa.RS512, hash: crypto.SHA512},
		{alg: jwa.PS256, hash: crypto.SHA256, pss: true},
		{alg: jwa.PS384, hash: crypto.SHA384, pss: true},
		{alg: jwa.PS512, hash: crypto.SHA512, pss: true},
	}

	rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner, len(specs))
	rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier, len(specs))

	// Signer and verifier for the same algorithm share identical settings.
	for _, s := range specs {
		rsaSigners[s.alg] = &rsaSigner{alg: s.alg, hash: s.hash, pss: s.pss}
		rsaVerifiers[s.alg] = &rsaVerifier{alg: s.alg, hash: s.hash, pss: s.pss}
	}
}
// rsaSigner signs payloads with an RSA key for a single signature algorithm.
type rsaSigner struct {
// alg is the algorithm reported by Algorithm.
alg jwa.SignatureAlgorithm
// hash is the hash function applied to the payload before signing.
hash crypto.Hash
// pss, when true, makes Sign pass rsa.PSSOptions to the underlying signer
// instead of the bare hash.
pss bool
}
// newRSASigner returns the pre-built signer registered for alg.
//
// When alg has no entry in rsaSigners, an untyped nil Signer is returned.
// Returning the map lookup directly would wrap a nil *rsaSigner in the
// interface, producing a value that compares unequal to nil yet panics as
// soon as one of its methods dereferences the receiver.
func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
	signer, ok := rsaSigners[alg]
	if !ok {
		return nil
	}
	return signer
}
// Algorithm reports the signature algorithm stored in this signer.
func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
return rs.alg
}
// Sign hashes payload with the signer's configured hash function and signs
// the resulting digest with key.
//
// A key that implements crypto.Signer is used directly, provided
// isValidRSAKey accepts it. Any other key is converted into an
// rsa.PrivateKey via keyconv.RSAPrivateKey. An error is returned when key is
// nil, is rejected, cannot be converted, or when signing itself fails.
func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
	if key == nil {
		return nil, fmt.Errorf(`missing private key while signing payload`)
	}

	cs, isSigner := key.(crypto.Signer)
	switch {
	case !isSigner:
		// Not a crypto.Signer: try to obtain a raw RSA private key instead.
		var converted rsa.PrivateKey
		if err := keyconv.RSAPrivateKey(&converted, key); err != nil {
			return nil, fmt.Errorf(`failed to retrieve rsa.PrivateKey out of %T: %w`, key, err)
		}
		cs = &converted
	case !isValidRSAKey(key):
		return nil, fmt.Errorf(`cannot use key of type %T to generate RSA based signatures`, key)
	}

	hasher := rs.hash.New()
	if _, err := hasher.Write(payload); err != nil {
		return nil, fmt.Errorf(`failed to write payload to hash: %w`, err)
	}
	digest := hasher.Sum(nil)

	// The bare hash is passed as the signer options unless PSS is requested,
	// in which case PSS options with a hash-length salt are passed instead.
	var opts crypto.SignerOpts = rs.hash
	if rs.pss {
		opts = &rsa.PSSOptions{
			Hash:       rs.hash,
			SaltLength: rsa.PSSSaltLengthEqualsHash,
		}
	}
	return cs.Sign(rand.Reader, digest, opts)
}
// rsaVerifier checks RSA signatures for a single signature algorithm.
type rsaVerifier struct {
// alg is the signature algorithm this verifier was registered under.
alg jwa.SignatureAlgorithm
// hash is the hash function applied to the payload before verification.
hash crypto.Hash
// pss, when true, makes Verify use rsa.VerifyPSS instead of
// rsa.VerifyPKCS1v15.
pss bool
}
// newRSAVerifier returns the pre-built verifier registered for alg.
//
// When alg has no entry in rsaVerifiers, an untyped nil Verifier is returned.
// Returning the map lookup directly would wrap a nil *rsaVerifier in the
// interface, producing a value that compares unequal to nil yet panics as
// soon as Verify dereferences the receiver.
func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
	verifier, ok := rsaVerifiers[alg]
	if !ok {
		return nil
	}
	return verifier
}
// Verify hashes payload with the verifier's configured hash function and
// checks signature against the digest using the RSA public key derived from
// key.
//
// When key implements crypto.Signer, its Public() result is used and must be
// an rsa.PublicKey or *rsa.PublicKey. Any other key is converted via
// keyconv.RSAPublicKey. A nil return means the signature verified; an error
// is returned when key is nil, no RSA public key can be obtained, or the
// signature check fails.
func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
	if key == nil {
		return fmt.Errorf(`missing public key while verifying payload`)
	}

	var pub rsa.PublicKey
	if cs, isSigner := key.(crypto.Signer); !isSigner {
		if err := keyconv.RSAPublicKey(&pub, key); err != nil {
			return fmt.Errorf(`failed to retrieve rsa.PublicKey out of %T: %w`, key, err)
		}
	} else {
		// A crypto.Signer carries its own public half; accept it by value
		// or by pointer.
		switch v := cs.Public().(type) {
		case *rsa.PublicKey:
			pub = *v
		case rsa.PublicKey:
			pub = v
		default:
			return fmt.Errorf(`failed to retrieve rsa.PublicKey out of crypto.Signer %T`, key)
		}
	}

	hasher := rv.hash.New()
	if _, err := hasher.Write(payload); err != nil {
		return fmt.Errorf(`failed to write payload to hash: %w`, err)
	}
	digest := hasher.Sum(nil)

	if !rv.pss {
		return rsa.VerifyPKCS1v15(&pub, rv.hash, digest, signature)
	}
	// PSS verification is invoked with nil options.
	return rsa.VerifyPSS(&pub, rv.hash, digest, signature, nil)
}