github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/decrypt.go (about) 1 package jwe 2 3 import ( 4 "crypto/aes" 5 cryptocipher "crypto/cipher" 6 "crypto/ecdsa" 7 "crypto/rsa" 8 "crypto/sha256" 9 "crypto/sha512" 10 "fmt" 11 "hash" 12 13 "golang.org/x/crypto/pbkdf2" 14 15 "github.com/lestrrat-go/jwx/v2/internal/keyconv" 16 "github.com/lestrrat-go/jwx/v2/jwa" 17 "github.com/lestrrat-go/jwx/v2/jwe/internal/cipher" 18 "github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt" 19 "github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc" 20 "github.com/lestrrat-go/jwx/v2/x25519" 21 ) 22 23 // decrypter is responsible for taking various components to decrypt a message. 24 // its operation is not concurrency safe. You must provide locking yourself 25 // 26 //nolint:govet 27 type decrypter struct { 28 aad []byte 29 apu []byte 30 apv []byte 31 cek *[]byte 32 computedAad []byte 33 iv []byte 34 keyiv []byte 35 keysalt []byte 36 keytag []byte 37 tag []byte 38 privkey interface{} 39 pubkey interface{} 40 ctalg jwa.ContentEncryptionAlgorithm 41 keyalg jwa.KeyEncryptionAlgorithm 42 cipher content_crypt.Cipher 43 keycount int 44 } 45 46 // newDecrypter Creates a new Decrypter instance. You must supply the 47 // rest of parameters via their respective setter methods before 48 // calling Decrypt(). 49 // 50 // privkey must be a private key in its "raw" format (i.e. something like 51 // *rsa.PrivateKey, instead of jwk.Key) 52 // 53 // You should consider this object immutable once you assign values to it. 54 func newDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *decrypter { 55 return &decrypter{ 56 ctalg: ctalg, 57 keyalg: keyalg, 58 privkey: privkey, 59 } 60 } 61 62 func (d *decrypter) AgreementPartyUInfo(apu []byte) *decrypter { 63 d.apu = apu 64 return d 65 } 66 67 func (d *decrypter) AgreementPartyVInfo(apv []byte) *decrypter { 68 d.apv = apv 69 return d 70 } 71 72 func (d *decrypter) AuthenticatedData(aad []byte) *decrypter { 73 d.aad = aad 74 return d 75 } 76 77 func (d *decrypter) ComputedAuthenticatedData(aad []byte) *decrypter { 78 d.computedAad = aad 79 return d 80 } 81 82 func (d *decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *decrypter { 83 d.ctalg = ctalg 84 return d 85 } 86 87 func (d *decrypter) InitializationVector(iv []byte) *decrypter { 88 d.iv = iv 89 return d 90 } 91 92 func (d *decrypter) KeyCount(keycount int) *decrypter { 93 d.keycount = keycount 94 return d 95 } 96 97 func (d *decrypter) KeyInitializationVector(keyiv []byte) *decrypter { 98 d.keyiv = keyiv 99 return d 100 } 101 102 func (d *decrypter) KeySalt(keysalt []byte) *decrypter { 103 d.keysalt = keysalt 104 return d 105 } 106 107 func (d *decrypter) KeyTag(keytag []byte) *decrypter { 108 d.keytag = keytag 109 return d 110 } 111 112 // PublicKey sets the public key to be used in decoding EC based encryptions. 113 // The key must be in its "raw" format (i.e. *ecdsa.PublicKey, instead of jwk.Key) 114 func (d *decrypter) PublicKey(pubkey interface{}) *decrypter { 115 d.pubkey = pubkey 116 return d 117 } 118 119 func (d *decrypter) Tag(tag []byte) *decrypter { 120 d.tag = tag 121 return d 122 } 123 124 func (d *decrypter) CEK(ptr *[]byte) *decrypter { 125 d.cek = ptr 126 return d 127 } 128 129 func (d *decrypter) ContentCipher() (content_crypt.Cipher, error) { 130 if d.cipher == nil { 131 switch d.ctalg { 132 case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512: 133 cipher, err := cipher.NewAES(d.ctalg) 134 if err != nil { 135 return nil, fmt.Errorf(`failed to build content cipher for %s: %w`, d.ctalg, err) 136 } 137 d.cipher = cipher 138 default: 139 return nil, fmt.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg) 140 } 141 } 142 143 return d.cipher, nil 144 } 145 146 func (d *decrypter) Decrypt(recipient Recipient, ciphertext []byte, msg *Message) (plaintext []byte, err error) { 147 cek, keyerr := d.DecryptKey(recipient, msg) 148 if keyerr != nil { 149 err = fmt.Errorf(`failed to decrypt key: %w`, keyerr) 150 return 151 } 152 153 cipher, ciphererr := d.ContentCipher() 154 if ciphererr != nil { 155 err = fmt.Errorf(`failed to fetch content crypt cipher: %w`, ciphererr) 156 return 157 } 158 159 computedAad := d.computedAad 160 if d.aad != nil { 161 computedAad = append(append(computedAad, '.'), d.aad...) 162 } 163 164 plaintext, err = cipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad) 165 if err != nil { 166 err = fmt.Errorf(`failed to decrypt payload: %w`, err) 167 return 168 } 169 170 if d.cek != nil { 171 *d.cek = cek 172 } 173 return plaintext, nil 174 } 175 176 func (d *decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) { 177 switch d.keyalg { 178 case jwa.DIRECT: 179 return cek, nil 180 case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: 181 var hashFunc func() hash.Hash 182 var keylen int 183 switch d.keyalg { 184 case jwa.PBES2_HS256_A128KW: 185 hashFunc = sha256.New 186 keylen = 16 187 case jwa.PBES2_HS384_A192KW: 188 hashFunc = sha512.New384 189 keylen = 24 190 case jwa.PBES2_HS512_A256KW: 191 hashFunc = sha512.New 192 keylen = 32 193 } 194 salt := []byte(d.keyalg) 195 salt = append(salt, byte(0)) 196 salt = append(salt, d.keysalt...) 197 cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc) 198 fallthrough 199 case jwa.A128KW, jwa.A192KW, jwa.A256KW: 200 block, err := aes.NewCipher(cek) 201 if err != nil { 202 return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err) 203 } 204 205 jek, err := keyenc.Unwrap(block, recipientKey) 206 if err != nil { 207 return nil, fmt.Errorf(`failed to unwrap key: %w`, err) 208 } 209 210 return jek, nil 211 case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW: 212 if len(d.keyiv) != 12 { 213 return nil, fmt.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8) 214 } 215 if len(d.keytag) != 16 { 216 return nil, fmt.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8) 217 } 218 block, err := aes.NewCipher(cek) 219 if err != nil { 220 return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err) 221 } 222 aesgcm, err := cryptocipher.NewGCM(block) 223 if err != nil { 224 return nil, fmt.Errorf(`failed to create new GCM wrap: %w`, err) 225 } 226 ciphertext := recipientKey[:] 227 ciphertext = append(ciphertext, d.keytag...) 228 jek, err := aesgcm.Open(nil, d.keyiv, ciphertext, nil) 229 if err != nil { 230 return nil, fmt.Errorf(`failed to decode key: %w`, err) 231 } 232 return jek, nil 233 default: 234 return nil, fmt.Errorf("decrypt key: unsupported algorithm %s", d.keyalg) 235 } 236 } 237 238 func (d *decrypter) DecryptKey(recipient Recipient, msg *Message) (cek []byte, err error) { 239 recipientKey := recipient.EncryptedKey() 240 if kd, ok := d.privkey.(KeyDecrypter); ok { 241 return kd.DecryptKey(d.keyalg, recipientKey, recipient, msg) 242 } 243 244 if d.keyalg.IsSymmetric() { 245 var ok bool 246 cek, ok = d.privkey.([]byte) 247 if !ok { 248 return nil, fmt.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey) 249 } 250 251 return d.decryptSymmetricKey(recipientKey, cek) 252 } 253 254 k, err := d.BuildKeyDecrypter() 255 if err != nil { 256 return nil, fmt.Errorf(`failed to build key decrypter: %w`, err) 257 } 258 259 cek, err = k.Decrypt(recipientKey) 260 if err != nil { 261 return nil, fmt.Errorf(`failed to decrypt key: %w`, err) 262 } 263 264 return cek, nil 265 } 266 267 func (d *decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) { 268 cipher, err := d.ContentCipher() 269 if err != nil { 270 return nil, fmt.Errorf(`failed to fetch content crypt cipher: %w`, err) 271 } 272 273 switch alg := d.keyalg; alg { 274 case jwa.RSA1_5: 275 var privkey rsa.PrivateKey 276 if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil { 277 return nil, fmt.Errorf(`*rsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err) 278 } 279 280 return keyenc.NewRSAPKCS15Decrypt(alg, &privkey, cipher.KeySize()/2), nil 281 case jwa.RSA_OAEP, jwa.RSA_OAEP_256: 282 var privkey rsa.PrivateKey 283 if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil { 284 return nil, fmt.Errorf(`*rsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err) 285 } 286 287 return keyenc.NewRSAOAEPDecrypt(alg, &privkey) 288 case jwa.A128KW, jwa.A192KW, jwa.A256KW: 289 sharedkey, ok := d.privkey.([]byte) 290 if !ok { 291 return nil, fmt.Errorf("[]byte is required as the key to build %s key decrypter", alg) 292 } 293 294 return keyenc.NewAES(alg, sharedkey) 295 case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: 296 switch d.pubkey.(type) { 297 case x25519.PublicKey: 298 return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil 299 default: 300 var pubkey ecdsa.PublicKey 301 if err := keyconv.ECDSAPublicKey(&pubkey, d.pubkey); err != nil { 302 return nil, fmt.Errorf(`*ecdsa.PublicKey is required as the key to build %s key decrypter: %w`, alg, err) 303 } 304 305 var privkey ecdsa.PrivateKey 306 if err := keyconv.ECDSAPrivateKey(&privkey, d.privkey); err != nil { 307 return nil, fmt.Errorf(`*ecdsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err) 308 } 309 310 return keyenc.NewECDHESDecrypt(alg, d.ctalg, &pubkey, d.apu, d.apv, &privkey), nil 311 } 312 default: 313 return nil, fmt.Errorf(`unsupported algorithm for key decryption (%s)`, alg) 314 } 315 }