github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/rsa.go (about)

     1  package jwk
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rsa"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"math/big"
     9  
    10  	"github.com/lestrrat-go/blackmagic"
    11  	"github.com/lestrrat-go/jwx/v2/internal/base64"
    12  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    13  )
    14  
    15  func (k *rsaPrivateKey) FromRaw(rawKey *rsa.PrivateKey) error {
    16  	k.mu.Lock()
    17  	defer k.mu.Unlock()
    18  
    19  	d, err := bigIntToBytes(rawKey.D)
    20  	if err != nil {
    21  		return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
    22  	}
    23  	k.d = d
    24  
    25  	l := len(rawKey.Primes)
    26  
    27  	if l < 0 /* I know, I'm being paranoid */ || l > 2 {
    28  		return fmt.Errorf(`invalid number of primes in rsa.PrivateKey: need 0 to 2, but got %d`, len(rawKey.Primes))
    29  	}
    30  
    31  	if l > 0 {
    32  		p, err := bigIntToBytes(rawKey.Primes[0])
    33  		if err != nil {
    34  			return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
    35  		}
    36  		k.p = p
    37  	}
    38  
    39  	if l > 1 {
    40  		q, err := bigIntToBytes(rawKey.Primes[1])
    41  		if err != nil {
    42  			return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
    43  		}
    44  		k.q = q
    45  	}
    46  
    47  	// dp, dq, qi are optional values
    48  	if v, err := bigIntToBytes(rawKey.Precomputed.Dp); err == nil {
    49  		k.dp = v
    50  	}
    51  	if v, err := bigIntToBytes(rawKey.Precomputed.Dq); err == nil {
    52  		k.dq = v
    53  	}
    54  	if v, err := bigIntToBytes(rawKey.Precomputed.Qinv); err == nil {
    55  		k.qi = v
    56  	}
    57  
    58  	// public key part
    59  	n, e, err := rsaPublicKeyByteValuesFromRaw(&rawKey.PublicKey)
    60  	if err != nil {
    61  		return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
    62  	}
    63  	k.n = n
    64  	k.e = e
    65  
    66  	return nil
    67  }
    68  
    69  func rsaPublicKeyByteValuesFromRaw(rawKey *rsa.PublicKey) ([]byte, []byte, error) {
    70  	n, err := bigIntToBytes(rawKey.N)
    71  	if err != nil {
    72  		return nil, nil, fmt.Errorf(`invalid rsa.PublicKey: %w`, err)
    73  	}
    74  
    75  	data := make([]byte, 8)
    76  	binary.BigEndian.PutUint64(data, uint64(rawKey.E))
    77  	i := 0
    78  	for ; i < len(data); i++ {
    79  		if data[i] != 0x0 {
    80  			break
    81  		}
    82  	}
    83  	return n, data[i:], nil
    84  }
    85  
    86  func (k *rsaPublicKey) FromRaw(rawKey *rsa.PublicKey) error {
    87  	k.mu.Lock()
    88  	defer k.mu.Unlock()
    89  
    90  	n, e, err := rsaPublicKeyByteValuesFromRaw(rawKey)
    91  	if err != nil {
    92  		return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err)
    93  	}
    94  	k.n = n
    95  	k.e = e
    96  
    97  	return nil
    98  }
    99  
   100  func (k *rsaPrivateKey) Raw(v interface{}) error {
   101  	k.mu.RLock()
   102  	defer k.mu.RUnlock()
   103  
   104  	var d, q, p big.Int // note: do not use from sync.Pool
   105  
   106  	d.SetBytes(k.d)
   107  	q.SetBytes(k.q)
   108  	p.SetBytes(k.p)
   109  
   110  	// optional fields
   111  	var dp, dq, qi *big.Int
   112  	if len(k.dp) > 0 {
   113  		dp = &big.Int{} // note: do not use from sync.Pool
   114  		dp.SetBytes(k.dp)
   115  	}
   116  
   117  	if len(k.dq) > 0 {
   118  		dq = &big.Int{} // note: do not use from sync.Pool
   119  		dq.SetBytes(k.dq)
   120  	}
   121  
   122  	if len(k.qi) > 0 {
   123  		qi = &big.Int{} // note: do not use from sync.Pool
   124  		qi.SetBytes(k.qi)
   125  	}
   126  
   127  	var key rsa.PrivateKey
   128  
   129  	pubk := newRSAPublicKey()
   130  	pubk.n = k.n
   131  	pubk.e = k.e
   132  	if err := pubk.Raw(&key.PublicKey); err != nil {
   133  		return fmt.Errorf(`failed to materialize RSA public key: %w`, err)
   134  	}
   135  
   136  	key.D = &d
   137  	key.Primes = []*big.Int{&p, &q}
   138  
   139  	if dp != nil {
   140  		key.Precomputed.Dp = dp
   141  	}
   142  	if dq != nil {
   143  		key.Precomputed.Dq = dq
   144  	}
   145  	if qi != nil {
   146  		key.Precomputed.Qinv = qi
   147  	}
   148  	key.Precomputed.CRTValues = []rsa.CRTValue{}
   149  
   150  	return blackmagic.AssignIfCompatible(v, &key)
   151  }
   152  
   153  // Raw takes the values stored in the Key object, and creates the
   154  // corresponding *rsa.PublicKey object.
   155  func (k *rsaPublicKey) Raw(v interface{}) error {
   156  	k.mu.RLock()
   157  	defer k.mu.RUnlock()
   158  
   159  	var key rsa.PublicKey
   160  
   161  	n := pool.GetBigInt()
   162  	e := pool.GetBigInt()
   163  	defer pool.ReleaseBigInt(e)
   164  
   165  	n.SetBytes(k.n)
   166  	e.SetBytes(k.e)
   167  
   168  	key.N = n
   169  	key.E = int(e.Int64())
   170  
   171  	return blackmagic.AssignIfCompatible(v, &key)
   172  }
   173  
   174  func makeRSAPublicKey(v interface {
   175  	makePairs() []*HeaderPair
   176  }) (Key, error) {
   177  	newKey := newRSAPublicKey()
   178  
   179  	// Iterate and copy everything except for the bits that should not be in the public key
   180  	for _, pair := range v.makePairs() {
   181  		switch pair.Key {
   182  		case RSADKey, RSADPKey, RSADQKey, RSAPKey, RSAQKey, RSAQIKey:
   183  			continue
   184  		default:
   185  			//nolint:forcetypeassert
   186  			key := pair.Key.(string)
   187  			if err := newKey.Set(key, pair.Value); err != nil {
   188  				return nil, fmt.Errorf(`failed to set field %q: %w`, key, err)
   189  			}
   190  		}
   191  	}
   192  
   193  	return newKey, nil
   194  }
   195  
   196  func (k *rsaPrivateKey) PublicKey() (Key, error) {
   197  	return makeRSAPublicKey(k)
   198  }
   199  
   200  func (k *rsaPublicKey) PublicKey() (Key, error) {
   201  	return makeRSAPublicKey(k)
   202  }
   203  
   204  // Thumbprint returns the JWK thumbprint using the indicated
   205  // hashing algorithm, according to RFC 7638
   206  func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
   207  	k.mu.RLock()
   208  	defer k.mu.RUnlock()
   209  
   210  	var key rsa.PrivateKey
   211  	if err := k.Raw(&key); err != nil {
   212  		return nil, fmt.Errorf(`failed to materialize RSA private key: %w`, err)
   213  	}
   214  	return rsaThumbprint(hash, &key.PublicKey)
   215  }
   216  
   217  func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
   218  	k.mu.RLock()
   219  	defer k.mu.RUnlock()
   220  
   221  	var key rsa.PublicKey
   222  	if err := k.Raw(&key); err != nil {
   223  		return nil, fmt.Errorf(`failed to materialize RSA public key: %w`, err)
   224  	}
   225  	return rsaThumbprint(hash, &key)
   226  }
   227  
   228  func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) {
   229  	buf := pool.GetBytesBuffer()
   230  	defer pool.ReleaseBytesBuffer(buf)
   231  
   232  	buf.WriteString(`{"e":"`)
   233  	buf.WriteString(base64.EncodeUint64ToString(uint64(key.E)))
   234  	buf.WriteString(`","kty":"RSA","n":"`)
   235  	buf.WriteString(base64.EncodeToString(key.N.Bytes()))
   236  	buf.WriteString(`"}`)
   237  
   238  	h := hash.New()
   239  	if _, err := buf.WriteTo(h); err != nil {
   240  		return nil, fmt.Errorf(`failed to write rsaThumbprint: %w`, err)
   241  	}
   242  	return h.Sum(nil), nil
   243  }
   244  
   // validateRSAKey performs structural checks on an RSA JWK.
   //
   // The "n" and "e" values must be non-empty. When checkPrivate is set, the
   // key must also provide a D() []byte method returning a non-empty "d".
   // It returns an error naming the first missing member, checked in the
   // order n, e, d, or nil if every check passes.
   func validateRSAKey(key interface {
   	N() []byte
   	E() []byte
   }, checkPrivate bool) error {
   	switch {
   	case len(key.N()) == 0:
   		// Ideally the actual length would be checked too, but unlike EC keys
   		// nothing in an RSA key says how many bits it is supposed to have.
   		return fmt.Errorf(`missing "n" value`)
   	case len(key.E()) == 0:
   		return fmt.Errorf(`missing "e" value`)
   	case !checkPrivate:
   		return nil
   	}

   	// A key without a D accessor is treated the same as one with an empty d.
   	priv, hasD := key.(interface{ D() []byte })
   	if !hasD || len(priv.D()) == 0 {
   		return fmt.Errorf(`missing "d" value`)
   	}
   	return nil
   }
   270  
   271  func (k *rsaPrivateKey) Validate() error {
   272  	if err := validateRSAKey(k, true); err != nil {
   273  		return NewKeyValidationError(fmt.Errorf(`jwk.RSAPrivateKey: %w`, err))
   274  	}
   275  	return nil
   276  }
   277  
   278  func (k *rsaPublicKey) Validate() error {
   279  	if err := validateRSAKey(k, false); err != nil {
   280  		return NewKeyValidationError(fmt.Errorf(`jwk.RSAPublicKey: %w`, err))
   281  	}
   282  	return nil
   283  }