284 lines
6.2 KiB
Go
284 lines
6.2 KiB
Go
package jwk
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/rsa"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math/big"
|
|
|
|
"github.com/lestrrat-go/blackmagic"
|
|
"github.com/lestrrat-go/jwx/v2/internal/base64"
|
|
"github.com/lestrrat-go/jwx/v2/internal/pool"
|
|
)
|
|
|
|
func (k *rsaPrivateKey) FromRaw(rawKey *rsa.PrivateKey) error {
|
|
k.mu.Lock()
|
|
defer k.mu.Unlock()
|
|
|
|
d, err := bigIntToBytes(rawKey.D)
|
|
if err != nil {
|
|
return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
|
|
}
|
|
k.d = d
|
|
|
|
l := len(rawKey.Primes)
|
|
|
|
if l < 0 /* I know, I'm being paranoid */ || l > 2 {
|
|
return fmt.Errorf(`invalid number of primes in rsa.PrivateKey: need 0 to 2, but got %d`, len(rawKey.Primes))
|
|
}
|
|
|
|
if l > 0 {
|
|
p, err := bigIntToBytes(rawKey.Primes[0])
|
|
if err != nil {
|
|
return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
|
|
}
|
|
k.p = p
|
|
}
|
|
|
|
if l > 1 {
|
|
q, err := bigIntToBytes(rawKey.Primes[1])
|
|
if err != nil {
|
|
return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
|
|
}
|
|
k.q = q
|
|
}
|
|
|
|
// dp, dq, qi are optional values
|
|
if v, err := bigIntToBytes(rawKey.Precomputed.Dp); err == nil {
|
|
k.dp = v
|
|
}
|
|
if v, err := bigIntToBytes(rawKey.Precomputed.Dq); err == nil {
|
|
k.dq = v
|
|
}
|
|
if v, err := bigIntToBytes(rawKey.Precomputed.Qinv); err == nil {
|
|
k.qi = v
|
|
}
|
|
|
|
// public key part
|
|
n, e, err := rsaPublicKeyByteValuesFromRaw(&rawKey.PublicKey)
|
|
if err != nil {
|
|
return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
|
|
}
|
|
k.n = n
|
|
k.e = e
|
|
|
|
return nil
|
|
}
|
|
|
|
func rsaPublicKeyByteValuesFromRaw(rawKey *rsa.PublicKey) ([]byte, []byte, error) {
|
|
n, err := bigIntToBytes(rawKey.N)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf(`invalid rsa.PublicKey: %w`, err)
|
|
}
|
|
|
|
data := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(data, uint64(rawKey.E))
|
|
i := 0
|
|
for ; i < len(data); i++ {
|
|
if data[i] != 0x0 {
|
|
break
|
|
}
|
|
}
|
|
return n, data[i:], nil
|
|
}
|
|
|
|
func (k *rsaPublicKey) FromRaw(rawKey *rsa.PublicKey) error {
|
|
k.mu.Lock()
|
|
defer k.mu.Unlock()
|
|
|
|
n, e, err := rsaPublicKeyByteValuesFromRaw(rawKey)
|
|
if err != nil {
|
|
return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
|
|
}
|
|
k.n = n
|
|
k.e = e
|
|
|
|
return nil
|
|
}
|
|
|
|
func (k *rsaPrivateKey) Raw(v interface{}) error {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
var d, q, p big.Int // note: do not use from sync.Pool
|
|
|
|
d.SetBytes(k.d)
|
|
q.SetBytes(k.q)
|
|
p.SetBytes(k.p)
|
|
|
|
// optional fields
|
|
var dp, dq, qi *big.Int
|
|
if len(k.dp) > 0 {
|
|
dp = &big.Int{} // note: do not use from sync.Pool
|
|
dp.SetBytes(k.dp)
|
|
}
|
|
|
|
if len(k.dq) > 0 {
|
|
dq = &big.Int{} // note: do not use from sync.Pool
|
|
dq.SetBytes(k.dq)
|
|
}
|
|
|
|
if len(k.qi) > 0 {
|
|
qi = &big.Int{} // note: do not use from sync.Pool
|
|
qi.SetBytes(k.qi)
|
|
}
|
|
|
|
var key rsa.PrivateKey
|
|
|
|
pubk := newRSAPublicKey()
|
|
pubk.n = k.n
|
|
pubk.e = k.e
|
|
if err := pubk.Raw(&key.PublicKey); err != nil {
|
|
return fmt.Errorf(`failed to materialize RSA public key: %w`, err)
|
|
}
|
|
|
|
key.D = &d
|
|
key.Primes = []*big.Int{&p, &q}
|
|
|
|
if dp != nil {
|
|
key.Precomputed.Dp = dp
|
|
}
|
|
if dq != nil {
|
|
key.Precomputed.Dq = dq
|
|
}
|
|
if qi != nil {
|
|
key.Precomputed.Qinv = qi
|
|
}
|
|
key.Precomputed.CRTValues = []rsa.CRTValue{}
|
|
|
|
return blackmagic.AssignIfCompatible(v, &key)
|
|
}
|
|
|
|
// Raw takes the values stored in the Key object, and creates the
|
|
// corresponding *rsa.PublicKey object.
|
|
func (k *rsaPublicKey) Raw(v interface{}) error {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
var key rsa.PublicKey
|
|
|
|
n := pool.GetBigInt()
|
|
e := pool.GetBigInt()
|
|
defer pool.ReleaseBigInt(e)
|
|
|
|
n.SetBytes(k.n)
|
|
e.SetBytes(k.e)
|
|
|
|
key.N = n
|
|
key.E = int(e.Int64())
|
|
|
|
return blackmagic.AssignIfCompatible(v, &key)
|
|
}
|
|
|
|
func makeRSAPublicKey(v interface {
|
|
makePairs() []*HeaderPair
|
|
}) (Key, error) {
|
|
newKey := newRSAPublicKey()
|
|
|
|
// Iterate and copy everything except for the bits that should not be in the public key
|
|
for _, pair := range v.makePairs() {
|
|
switch pair.Key {
|
|
case RSADKey, RSADPKey, RSADQKey, RSAPKey, RSAQKey, RSAQIKey:
|
|
continue
|
|
default:
|
|
//nolint:forcetypeassert
|
|
key := pair.Key.(string)
|
|
if err := newKey.Set(key, pair.Value); err != nil {
|
|
return nil, fmt.Errorf(`failed to set field %q: %w`, key, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return newKey, nil
|
|
}
|
|
|
|
func (k *rsaPrivateKey) PublicKey() (Key, error) {
|
|
return makeRSAPublicKey(k)
|
|
}
|
|
|
|
func (k *rsaPublicKey) PublicKey() (Key, error) {
|
|
return makeRSAPublicKey(k)
|
|
}
|
|
|
|
// Thumbprint returns the JWK thumbprint using the indicated
|
|
// hashing algorithm, according to RFC 7638
|
|
func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
var key rsa.PrivateKey
|
|
if err := k.Raw(&key); err != nil {
|
|
return nil, fmt.Errorf(`failed to materialize RSA private key: %w`, err)
|
|
}
|
|
return rsaThumbprint(hash, &key.PublicKey)
|
|
}
|
|
|
|
func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
var key rsa.PublicKey
|
|
if err := k.Raw(&key); err != nil {
|
|
return nil, fmt.Errorf(`failed to materialize RSA public key: %w`, err)
|
|
}
|
|
return rsaThumbprint(hash, &key)
|
|
}
|
|
|
|
func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) {
|
|
buf := pool.GetBytesBuffer()
|
|
defer pool.ReleaseBytesBuffer(buf)
|
|
|
|
buf.WriteString(`{"e":"`)
|
|
buf.WriteString(base64.EncodeUint64ToString(uint64(key.E)))
|
|
buf.WriteString(`","kty":"RSA","n":"`)
|
|
buf.WriteString(base64.EncodeToString(key.N.Bytes()))
|
|
buf.WriteString(`"}`)
|
|
|
|
h := hash.New()
|
|
if _, err := buf.WriteTo(h); err != nil {
|
|
return nil, fmt.Errorf(`failed to write rsaThumbprint: %w`, err)
|
|
}
|
|
return h.Sum(nil), nil
|
|
}
|
|
|
|
// validateRSAKey checks that the members required for an RSA JWK are present.
// "n" and "e" must be non-empty. When checkPrivate is true, key must also
// provide a D() []byte method returning a non-empty "d".
//
// Only presence is checked. Unlike EC keys, an RSA key carries nothing that
// says how many bits it should have, so the actual length of "n" cannot be
// verified here.
func validateRSAKey(key interface {
	N() []byte
	E() []byte
}, checkPrivate bool) error {
	switch {
	case len(key.N()) == 0:
		return fmt.Errorf(`missing "n" value`)
	case len(key.E()) == 0:
		return fmt.Errorf(`missing "e" value`)
	case !checkPrivate:
		return nil
	}

	// A key that cannot report "d" at all is treated like one with an empty "d".
	if priv, ok := key.(interface{ D() []byte }); !ok || len(priv.D()) == 0 {
		return fmt.Errorf(`missing "d" value`)
	}
	return nil
}
|
|
|
|
func (k *rsaPrivateKey) Validate() error {
|
|
if err := validateRSAKey(k, true); err != nil {
|
|
return NewKeyValidationError(fmt.Errorf(`jwk.RSAPrivateKey: %w`, err))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (k *rsaPublicKey) Validate() error {
|
|
if err := validateRSAKey(k, false); err != nil {
|
|
return NewKeyValidationError(fmt.Errorf(`jwk.RSAPublicKey: %w`, err))
|
|
}
|
|
return nil
|
|
}
|