github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/rsa.go (about) 1 package jwk 2 3 import ( 4 "crypto" 5 "crypto/rsa" 6 "encoding/binary" 7 "fmt" 8 "math/big" 9 10 "github.com/lestrrat-go/blackmagic" 11 "github.com/lestrrat-go/jwx/v2/internal/base64" 12 "github.com/lestrrat-go/jwx/v2/internal/pool" 13 ) 14 15 func (k *rsaPrivateKey) FromRaw(rawKey *rsa.PrivateKey) error { 16 k.mu.Lock() 17 defer k.mu.Unlock() 18 19 d, err := bigIntToBytes(rawKey.D) 20 if err != nil { 21 return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err) 22 } 23 k.d = d 24 25 l := len(rawKey.Primes) 26 27 if l < 0 /* I know, I'm being paranoid */ || l > 2 { 28 return fmt.Errorf(`invalid number of primes in rsa.PrivateKey: need 0 to 2, but got %d`, len(rawKey.Primes)) 29 } 30 31 if l > 0 { 32 p, err := bigIntToBytes(rawKey.Primes[0]) 33 if err != nil { 34 return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err) 35 } 36 k.p = p 37 } 38 39 if l > 1 { 40 q, err := bigIntToBytes(rawKey.Primes[1]) 41 if err != nil { 42 return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err) 43 } 44 k.q = q 45 } 46 47 // dp, dq, qi are optional values 48 if v, err := bigIntToBytes(rawKey.Precomputed.Dp); err == nil { 49 k.dp = v 50 } 51 if v, err := bigIntToBytes(rawKey.Precomputed.Dq); err == nil { 52 k.dq = v 53 } 54 if v, err := bigIntToBytes(rawKey.Precomputed.Qinv); err == nil { 55 k.qi = v 56 } 57 58 // public key part 59 n, e, err := rsaPublicKeyByteValuesFromRaw(&rawKey.PublicKey) 60 if err != nil { 61 return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err) 62 } 63 k.n = n 64 k.e = e 65 66 return nil 67 } 68 69 func rsaPublicKeyByteValuesFromRaw(rawKey *rsa.PublicKey) ([]byte, []byte, error) { 70 n, err := bigIntToBytes(rawKey.N) 71 if err != nil { 72 return nil, nil, fmt.Errorf(`invalid rsa.PublicKey: %w`, err) 73 } 74 75 data := make([]byte, 8) 76 binary.BigEndian.PutUint64(data, uint64(rawKey.E)) 77 i := 0 78 for ; i < len(data); i++ { 79 if data[i] != 0x0 { 80 break 81 } 82 } 83 return n, data[i:], nil 84 } 85 86 func (k *rsaPublicKey) FromRaw(rawKey *rsa.PublicKey) error { 87 k.mu.Lock() 88 defer k.mu.Unlock() 89 90 n, e, err := rsaPublicKeyByteValuesFromRaw(rawKey) 91 if err != nil { 92 return fmt.Errorf(`invalid rsa.PrivateKey: %w`, err) 93 } 94 k.n = n 95 k.e = e 96 97 return nil 98 } 99 100 func (k *rsaPrivateKey) Raw(v interface{}) error { 101 k.mu.RLock() 102 defer k.mu.RUnlock() 103 104 var d, q, p big.Int // note: do not use from sync.Pool 105 106 d.SetBytes(k.d) 107 q.SetBytes(k.q) 108 p.SetBytes(k.p) 109 110 // optional fields 111 var dp, dq, qi *big.Int 112 if len(k.dp) > 0 { 113 dp = &big.Int{} // note: do not use from sync.Pool 114 dp.SetBytes(k.dp) 115 } 116 117 if len(k.dq) > 0 { 118 dq = &big.Int{} // note: do not use from sync.Pool 119 dq.SetBytes(k.dq) 120 } 121 122 if len(k.qi) > 0 { 123 qi = &big.Int{} // note: do not use from sync.Pool 124 qi.SetBytes(k.qi) 125 } 126 127 var key rsa.PrivateKey 128 129 pubk := newRSAPublicKey() 130 pubk.n = k.n 131 pubk.e = k.e 132 if err := pubk.Raw(&key.PublicKey); err != nil { 133 return fmt.Errorf(`failed to materialize RSA public key: %w`, err) 134 } 135 136 key.D = &d 137 key.Primes = []*big.Int{&p, &q} 138 139 if dp != nil { 140 key.Precomputed.Dp = dp 141 } 142 if dq != nil { 143 key.Precomputed.Dq = dq 144 } 145 if qi != nil { 146 key.Precomputed.Qinv = qi 147 } 148 key.Precomputed.CRTValues = []rsa.CRTValue{} 149 150 return blackmagic.AssignIfCompatible(v, &key) 151 } 152 153 // Raw takes the values stored in the Key object, and creates the 154 // corresponding *rsa.PublicKey object. 155 func (k *rsaPublicKey) Raw(v interface{}) error { 156 k.mu.RLock() 157 defer k.mu.RUnlock() 158 159 var key rsa.PublicKey 160 161 n := pool.GetBigInt() 162 e := pool.GetBigInt() 163 defer pool.ReleaseBigInt(e) 164 165 n.SetBytes(k.n) 166 e.SetBytes(k.e) 167 168 key.N = n 169 key.E = int(e.Int64()) 170 171 return blackmagic.AssignIfCompatible(v, &key) 172 } 173 174 func makeRSAPublicKey(v interface { 175 makePairs() []*HeaderPair 176 }) (Key, error) { 177 newKey := newRSAPublicKey() 178 179 // Iterate and copy everything except for the bits that should not be in the public key 180 for _, pair := range v.makePairs() { 181 switch pair.Key { 182 case RSADKey, RSADPKey, RSADQKey, RSAPKey, RSAQKey, RSAQIKey: 183 continue 184 default: 185 //nolint:forcetypeassert 186 key := pair.Key.(string) 187 if err := newKey.Set(key, pair.Value); err != nil { 188 return nil, fmt.Errorf(`failed to set field %q: %w`, key, err) 189 } 190 } 191 } 192 193 return newKey, nil 194 } 195 196 func (k *rsaPrivateKey) PublicKey() (Key, error) { 197 return makeRSAPublicKey(k) 198 } 199 200 func (k *rsaPublicKey) PublicKey() (Key, error) { 201 return makeRSAPublicKey(k) 202 } 203 204 // Thumbprint returns the JWK thumbprint using the indicated 205 // hashing algorithm, according to RFC 7638 206 func (k rsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { 207 k.mu.RLock() 208 defer k.mu.RUnlock() 209 210 var key rsa.PrivateKey 211 if err := k.Raw(&key); err != nil { 212 return nil, fmt.Errorf(`failed to materialize RSA private key: %w`, err) 213 } 214 return rsaThumbprint(hash, &key.PublicKey) 215 } 216 217 func (k rsaPublicKey) Thumbprint(hash crypto.Hash) ([]byte, error) { 218 k.mu.RLock() 219 defer k.mu.RUnlock() 220 221 var key rsa.PublicKey 222 if err := k.Raw(&key); err != nil { 223 return nil, fmt.Errorf(`failed to materialize RSA public key: %w`, err) 224 } 225 return rsaThumbprint(hash, &key) 226 } 227 228 func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) { 229 buf := pool.GetBytesBuffer() 230 defer pool.ReleaseBytesBuffer(buf) 231 232 buf.WriteString(`{"e":"`) 233 buf.WriteString(base64.EncodeUint64ToString(uint64(key.E))) 234 buf.WriteString(`","kty":"RSA","n":"`) 235 buf.WriteString(base64.EncodeToString(key.N.Bytes())) 236 buf.WriteString(`"}`) 237 238 h := hash.New() 239 if _, err := buf.WriteTo(h); err != nil { 240 return nil, fmt.Errorf(`failed to write rsaThumbprint: %w`, err) 241 } 242 return h.Sum(nil), nil 243 } 244 245 func validateRSAKey(key interface { 246 N() []byte 247 E() []byte 248 }, checkPrivate bool) error { 249 if len(key.N()) == 0 { 250 // Ideally we would like to check for the actual length, but unlike 251 // EC keys, we have nothing in the key itself that will tell us 252 // how many bits this key should have. 253 return fmt.Errorf(`missing "n" value`) 254 } 255 if len(key.E()) == 0 { 256 return fmt.Errorf(`missing "e" value`) 257 } 258 if checkPrivate { 259 if priv, ok := key.(interface{ D() []byte }); ok { 260 if len(priv.D()) == 0 { 261 return fmt.Errorf(`missing "d" value`) 262 } 263 } else { 264 return fmt.Errorf(`missing "d" value`) 265 } 266 } 267 268 return nil 269 } 270 271 func (k *rsaPrivateKey) Validate() error { 272 if err := validateRSAKey(k, true); err != nil { 273 return NewKeyValidationError(fmt.Errorf(`jwk.RSAPrivateKey: %w`, err)) 274 } 275 return nil 276 } 277 278 func (k *rsaPublicKey) Validate() error { 279 if err := validateRSAKey(k, false); err != nil { 280 return NewKeyValidationError(fmt.Errorf(`jwk.RSAPublicKey: %w`, err)) 281 } 282 return nil 283 }