// auth/vendor/github.com/lestrrat-go/jwx/v2/jwk/okp.go
//
// 228 lines
// 5.2 KiB
// Go

package jwk
import (
"bytes"
"crypto"
"crypto/ed25519"
"fmt"
"github.com/lestrrat-go/blackmagic"
"github.com/lestrrat-go/jwx/v2/internal/base64"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/x25519"
)
// FromRaw populates this key from a raw public key. Accepted inputs are
// ed25519.PublicKey and x25519.PublicKey; any other type is rejected with an
// error and the key is left unmodified. The write lock is held for the whole
// call.
func (k *okpPublicKey) FromRaw(rawKeyIf interface{}) error {
	k.mu.Lock()
	defer k.mu.Unlock()

	var (
		curve jwa.EllipticCurveAlgorithm
		point []byte
	)
	switch pub := rawKeyIf.(type) {
	case ed25519.PublicKey:
		curve, point = jwa.Ed25519, pub
	case x25519.PublicKey:
		curve, point = jwa.X25519, pub
	default:
		return fmt.Errorf(`unknown key type %T`, rawKeyIf)
	}

	// The raw key bytes are stored as given; no copy is made.
	k.x = point
	k.crv = &curve
	return nil
}
// FromRaw populates this key from a raw private key. Accepted inputs are
// ed25519.PrivateKey and x25519.PrivateKey; any other type is rejected with
// an error. The seed is stored as the private component and the public half
// of the raw key as the public component.
//
// A raw key whose length is not the package's PrivateKeySize is rejected
// before any field is written: Seed and Public slice into the key, so a short
// key would otherwise panic and a wrongly sized one would yield garbage.
// The write lock is held for the whole call.
func (k *okpPrivateKey) FromRaw(rawKeyIf interface{}) error {
	k.mu.Lock()
	defer k.mu.Unlock()

	var crv jwa.EllipticCurveAlgorithm
	switch rawKey := rawKeyIf.(type) {
	case ed25519.PrivateKey:
		if len(rawKey) != ed25519.PrivateKeySize {
			return fmt.Errorf(`invalid ed25519 private key size %d`, len(rawKey))
		}
		k.d = rawKey.Seed()
		k.x = rawKey.Public().(ed25519.PublicKey) //nolint:forcetypeassert
		crv = jwa.Ed25519
		k.crv = &crv
	case x25519.PrivateKey:
		// NOTE(review): assumes x25519 mirrors ed25519's seed||public layout
		// and exports PrivateKeySize — confirm against the x25519 package.
		if len(rawKey) != x25519.PrivateKeySize {
			return fmt.Errorf(`invalid x25519 private key size %d`, len(rawKey))
		}
		k.d = rawKey.Seed()
		k.x = rawKey.Public().(x25519.PublicKey) //nolint:forcetypeassert
		crv = jwa.X25519
		k.crv = &crv
	default:
		return fmt.Errorf(`unknown key type %T`, rawKeyIf)
	}
	return nil
}
// buildOKPPublicKey wraps xbuf in the raw public key type that matches alg:
// ed25519.PublicKey for jwa.Ed25519 and x25519.PublicKey for jwa.X25519.
// The bytes are converted, not copied or validated. Any other curve is
// reported as an error.
func buildOKPPublicKey(alg jwa.EllipticCurveAlgorithm, xbuf []byte) (interface{}, error) {
	if alg == jwa.Ed25519 {
		return ed25519.PublicKey(xbuf), nil
	}
	if alg == jwa.X25519 {
		return x25519.PublicKey(xbuf), nil
	}
	return nil, fmt.Errorf(`invalid curve algorithm %s`, alg)
}
// Raw builds the raw OKP public key represented by this JWK (an
// ed25519.PublicKey or x25519.PublicKey, depending on the curve) and hands it
// to blackmagic.AssignIfCompatible together with v. An error is returned if
// the key cannot be built for the stored curve, or if the assignment fails.
// The read lock is held for the whole call.
func (k *okpPublicKey) Raw(v interface{}) error {
	k.mu.RLock()
	defer k.mu.RUnlock()

	rawPub, buildErr := buildOKPPublicKey(k.Crv(), k.x)
	if buildErr != nil {
		return fmt.Errorf(`failed to build public key: %w`, buildErr)
	}

	return blackmagic.AssignIfCompatible(v, rawPub)
}
// buildOKPPrivateKey derives the raw private key for alg from the seed dbuf
// and checks that the public key derived from it equals xbuf.
//
// An empty seed is always rejected. For jwa.Ed25519 the seed must be exactly
// ed25519.SeedSize bytes; for jwa.X25519 the seed is validated by
// x25519.NewKeyFromSeed. Any other curve is reported as an error.
func buildOKPPrivateKey(alg jwa.EllipticCurveAlgorithm, xbuf []byte, dbuf []byte) (interface{}, error) {
	if len(dbuf) == 0 {
		return nil, fmt.Errorf(`cannot use empty seed`)
	}

	if alg == jwa.Ed25519 {
		if len(dbuf) != ed25519.SeedSize {
			return nil, fmt.Errorf(`wrong private key size`)
		}
		priv := ed25519.NewKeyFromSeed(dbuf)
		//nolint:forcetypeassert
		derived := priv.Public().(ed25519.PublicKey)
		// The supplied x must be the public half that belongs to this seed.
		if !bytes.Equal(xbuf, derived) {
			return nil, fmt.Errorf(`invalid x value given d value`)
		}
		return priv, nil
	}

	if alg == jwa.X25519 {
		priv, err := x25519.NewKeyFromSeed(dbuf)
		if err != nil {
			return nil, fmt.Errorf(`unable to construct x25519 private key from seed: %w`, err)
		}
		//nolint:forcetypeassert
		derived := priv.Public().(x25519.PublicKey)
		// The supplied x must be the public half that belongs to this seed.
		if !bytes.Equal(xbuf, derived) {
			return nil, fmt.Errorf(`invalid x value given d value`)
		}
		return priv, nil
	}

	return nil, fmt.Errorf(`invalid curve algorithm %s`, alg)
}
// Raw builds the raw OKP private key represented by this JWK (via
// buildOKPPrivateKey, which also checks that the stored public component
// matches the seed) and hands it to blackmagic.AssignIfCompatible together
// with v. An error is returned if the key cannot be built or if the
// assignment fails. The read lock is held for the whole call.
func (k *okpPrivateKey) Raw(v interface{}) error {
	k.mu.RLock()
	defer k.mu.RUnlock()

	privk, err := buildOKPPrivateKey(k.Crv(), k.x, k.d)
	if err != nil {
		// This path builds the private key; the message previously said
		// "public key", which pointed readers at the wrong code path.
		return fmt.Errorf(`failed to build private key: %w`, err)
	}
	return blackmagic.AssignIfCompatible(v, privk)
}
// makeOKPPublicKey creates a fresh OKP public key and copies onto it every
// header pair reported by v.makePairs(), except OKPDKey, which must not
// appear in a public key. The first field that cannot be set aborts the copy
// and is reported as an error.
func makeOKPPublicKey(v interface {
	makePairs() []*HeaderPair
}) (Key, error) {
	pub := newOKPPublicKey()

	for _, p := range v.makePairs() {
		// Skip the private component.
		if p.Key == OKPDKey {
			continue
		}

		//nolint:forcetypeassert
		name := p.Key.(string)
		if err := pub.Set(name, p.Value); err != nil {
			return nil, fmt.Errorf(`failed to set field %q: %w`, name, err)
		}
	}

	return pub, nil
}
// PublicKey returns a new public key built by makeOKPPublicKey from this
// private key's fields; the OKPDKey field is not carried over.
func (k *okpPrivateKey) PublicKey() (Key, error) {
return makeOKPPublicKey(k)
}
// PublicKey returns a new public key built by makeOKPPublicKey from this
// key's fields, i.e. a separate key object rather than the receiver itself.
func (k *okpPublicKey) PublicKey() (Key, error) {
return makeOKPPublicKey(k)
}
// okpThumbprint hashes the JSON object {"crv":<crv>,"kty":"OKP","x":<x>}
// (members in exactly that order, no whitespace) with the given hash and
// returns the digest. crv and x are written verbatim, without escaping.
func okpThumbprint(hash crypto.Hash, crv, x string) []byte {
	hasher := hash.New()

	// Feed the object to the hash piece by piece to avoid building the
	// whole string first.
	for _, part := range []string{
		`{"crv":"`,
		crv,
		`","kty":"OKP","x":"`,
		x,
		`"}`,
	} {
		fmt.Fprint(hasher, part)
	}

	return hasher.Sum(nil)
}
// Thumbprint computes the JWK thumbprint of this key with the given hash
// algorithm, according to RFC 7638 / 8037. The digest covers the curve name
// and the base64-encoded public component. The returned error is always nil.
func (k okpPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
	k.mu.RLock()
	defer k.mu.RUnlock()

	curveName := k.Crv().String()
	encodedX := base64.EncodeToString(k.x)
	return okpThumbprint(hash, curveName, encodedX), nil
}
// Thumbprint computes the JWK thumbprint of this key with the given hash
// algorithm, according to RFC 7638 / 8037. Only the curve name and the
// base64-encoded public component are hashed; the private component is not
// part of the digest. The returned error is always nil.
func (k okpPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
	k.mu.RLock()
	defer k.mu.RUnlock()

	curveName := k.Crv().String()
	encodedX := base64.EncodeToString(k.x)
	return okpThumbprint(hash, curveName, encodedX), nil
}
// validateOKPKey checks that an OKP key carries its mandatory fields: a curve
// other than jwa.InvalidEllipticCurve and a non-empty "x". If key also
// exposes a D() []byte method, a non-empty "d" is required as well.
// The first missing field is reported as an error.
func validateOKPKey(key interface {
	Crv() jwa.EllipticCurveAlgorithm
	X() []byte
}) error {
	switch {
	case key.Crv() == jwa.InvalidEllipticCurve:
		return fmt.Errorf(`invalid curve algorithm`)
	case len(key.X()) == 0:
		return fmt.Errorf(`missing "x" field`)
	}

	// Only keys that expose D() are checked for the private component.
	type privateComponent interface{ D() []byte }
	if priv, isPrivate := key.(privateComponent); isPrivate && len(priv.D()) == 0 {
		return fmt.Errorf(`missing "d" field`)
	}

	return nil
}
// Validate runs validateOKPKey on this key while holding the read lock.
// A failure is prefixed with "jwk.OKPPublicKey" and wrapped with
// NewKeyValidationError; nil is returned when the key is valid.
func (k *okpPublicKey) Validate() error {
	k.mu.RLock()
	defer k.mu.RUnlock()

	err := validateOKPKey(k)
	if err == nil {
		return nil
	}
	return NewKeyValidationError(fmt.Errorf(`jwk.OKPPublicKey: %w`, err))
}
// Validate runs validateOKPKey on this key while holding the read lock.
// A failure is prefixed with "jwk.OKPPrivateKey" and wrapped with
// NewKeyValidationError; nil is returned when the key is valid.
func (k *okpPrivateKey) Validate() error {
	k.mu.RLock()
	defer k.mu.RUnlock()

	err := validateOKPKey(k)
	if err == nil {
		return nil
	}
	return NewKeyValidationError(fmt.Errorf(`jwk.OKPPrivateKey: %w`, err))
}