github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/jose/decrypter.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package jose 8 9 import ( 10 "crypto/ecdsa" 11 "encoding/base64" 12 "encoding/json" 13 "fmt" 14 "strings" 15 16 "github.com/google/tink/go/keyset" 17 18 ecdhpb "github.com/hyperledger/aries-framework-go/component/kmscrypto/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto" 19 "github.com/hyperledger/aries-framework-go/component/kmscrypto/doc/jose/jwk" 20 cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto" 21 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite" 22 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/api" 23 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/ecdh" 24 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/keyio" 25 "github.com/hyperledger/aries-framework-go/pkg/doc/jose/kid/resolver" 26 "github.com/hyperledger/aries-framework-go/pkg/kms" 27 ) 28 29 // Decrypter interface to Decrypt JWE messages. 30 type Decrypter interface { 31 // Decrypt a deserialized JWE, extracts the corresponding recipient key to decrypt plaintext and returns it 32 Decrypt(jwe *JSONWebEncryption) ([]byte, error) 33 } 34 35 // JWEDecrypt is responsible for decrypting a JWE message and returns its protected plaintext. 36 type JWEDecrypt struct { 37 kidResolvers []resolver.KIDResolver 38 crypto cryptoapi.Crypto 39 kms kms.KeyManager 40 } 41 42 // NewJWEDecrypt creates a new JWEDecrypt instance to parse and decrypt a JWE message for a given recipient 43 // store is needed for Authcrypt only (to fetch sender's pre agreed upon public key), it is not needed for Anoncrypt. 44 func NewJWEDecrypt(kidResolvers []resolver.KIDResolver, c cryptoapi.Crypto, k kms.KeyManager) *JWEDecrypt { 45 return &JWEDecrypt{ 46 kidResolvers: kidResolvers, 47 crypto: c, 48 kms: k, 49 } 50 } 51 52 func getECDHDecPrimitive(cek []byte, encAlg EncAlg, nistpKW bool) (api.CompositeDecrypt, error) { 53 ceAlg := aeadAlg[encAlg] 54 55 if ceAlg <= 0 { 56 return nil, fmt.Errorf("invalid content encAlg: '%s'", encAlg) 57 } 58 59 kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, ceAlg) 60 61 kh, err := keyset.NewHandle(kt) 62 if err != nil { 63 return nil, err 64 } 65 66 return ecdh.NewECDHDecrypt(kh) 67 } 68 69 // Decrypt a deserialized JWE, decrypts its protected content and returns plaintext. 70 func (jd *JWEDecrypt) Decrypt(jwe *JSONWebEncryption) ([]byte, error) { 71 encAlg, err := jd.validateAndExtractProtectedHeaders(jwe) 72 if err != nil { 73 return nil, fmt.Errorf("jwedecrypt: %w", err) 74 } 75 76 var wkOpts []cryptoapi.WrapKeyOpts 77 78 skid, ok := jwe.ProtectedHeaders.SenderKeyID() 79 if !ok { 80 skid, ok = fetchSKIDFromAPU(jwe) 81 } 82 83 if ok && skid != "" { 84 senderKH, e := jd.fetchSenderPubKey(skid, EncAlg(encAlg)) 85 if e != nil { 86 return nil, fmt.Errorf("jwedecrypt: failed to add sender public key for skid: %w", e) 87 } 88 89 wkOpts = append(wkOpts, cryptoapi.WithSender(senderKH), cryptoapi.WithTag([]byte(jwe.Tag))) 90 } 91 92 recWK, err := buildRecipientsWrappedKey(jwe) 93 if err != nil { 94 return nil, fmt.Errorf("jwedecrypt: failed to build recipients WK: %w", err) 95 } 96 97 cek, err := jd.unwrapCEK(recWK, wkOpts...) 98 if err != nil { 99 return nil, fmt.Errorf("jwedecrypt: %w", err) 100 } 101 102 if len(recWK) == 1 { 103 // ensure EPK is marshalled the same way as during encryption since it is merged into ProtectHeaders. 104 marshalledEPK, err := convertRecEPKToMarshalledJWK(&recWK[0].EPK) 105 if err != nil { 106 return nil, fmt.Errorf("jwedecrypt: %w", err) 107 } 108 109 jwe.ProtectedHeaders["epk"] = json.RawMessage(marshalledEPK) 110 } 111 112 return jd.decryptJWE(jwe, cek) 113 } 114 115 func fetchSKIDFromAPU(jwe *JSONWebEncryption) (string, bool) { 116 // for multi-recipients only: check apu in protectedHeaders if it's found for ECDH-1PU, if skid header is empty then 117 // use apu as skid instead. 118 if len(jwe.Recipients) > 1 { 119 if a, apuOK := jwe.ProtectedHeaders["apu"]; apuOK { 120 skidBytes, err := base64.RawURLEncoding.DecodeString(a.(string)) 121 if err != nil { 122 return "", false 123 } 124 125 return string(skidBytes), true 126 } 127 } 128 129 return "", false 130 } 131 132 //nolint:gocyclo 133 func (jd *JWEDecrypt) unwrapCEK(recWK []*cryptoapi.RecipientWrappedKey, 134 senderOpt ...cryptoapi.WrapKeyOpts) ([]byte, error) { 135 var ( 136 cek []byte 137 errs []error 138 ) 139 140 for _, rec := range recWK { 141 var unwrapOpts []cryptoapi.WrapKeyOpts 142 143 if strings.HasPrefix(rec.KID, "did:key") || strings.Index(rec.KID, "#") > 0 { 144 // resolve and use kms KID if did:key or KeyAgreement.ID. 145 resolvedRec, err := jd.resolveKID(rec.KID) 146 if err != nil { 147 errs = append(errs, err) 148 continue 149 } 150 151 // Need to get the kms KID in order to do kms.Get() since original rec.KID is a did:key/KeyAgreement.ID. 152 // This is necessary to ensure recipient is the owner of the key. 153 rec.KID = resolvedRec.KID 154 } 155 156 recKH, err := jd.kms.Get(rec.KID) 157 if err != nil { 158 continue 159 } 160 161 if rec.EPK.Type == ecdhpb.KeyType_OKP.String() { 162 unwrapOpts = append(unwrapOpts, cryptoapi.WithXC20PKW()) 163 } 164 165 if senderOpt != nil { 166 unwrapOpts = append(unwrapOpts, senderOpt...) 167 } 168 169 if len(unwrapOpts) > 0 { 170 cek, err = jd.crypto.UnwrapKey(rec, recKH, unwrapOpts...) 171 } else { 172 cek, err = jd.crypto.UnwrapKey(rec, recKH) 173 } 174 175 if err == nil { 176 break 177 } 178 179 errs = append(errs, err) 180 } 181 182 if len(cek) == 0 { 183 return nil, fmt.Errorf("failed to unwrap cek: %v", errs) 184 } 185 186 return cek, nil 187 } 188 189 func (jd *JWEDecrypt) resolveKID(kid string) (*cryptoapi.PublicKey, error) { 190 var errs []error 191 192 for _, resolver := range jd.kidResolvers { 193 rKID, err := resolver.Resolve(kid) 194 if err == nil { 195 return rKID, nil 196 } 197 198 errs = append(errs, err) 199 } 200 201 return nil, fmt.Errorf("resolveKID: %v", errs) 202 } 203 204 func (jd *JWEDecrypt) decryptJWE(jwe *JSONWebEncryption, cek []byte) ([]byte, error) { 205 encAlg, ok := jwe.ProtectedHeaders.Encryption() 206 if !ok { 207 return nil, fmt.Errorf("jwedecrypt: JWE 'enc' protected header is missing") 208 } 209 210 decPrimitive, err := getECDHDecPrimitive(cek, EncAlg(encAlg), true) 211 if err != nil { 212 return nil, fmt.Errorf("jwedecrypt: failed to get decryption primitive: %w", err) 213 } 214 215 encryptedData, err := buildEncryptedData(jwe) 216 if err != nil { 217 return nil, fmt.Errorf("jwedecrypt: failed to build encryptedData for Decrypt(): %w", err) 218 } 219 220 aadBytes := []byte(jwe.AAD) 221 222 authData, err := computeAuthData(jwe.ProtectedHeaders, jwe.OrigProtectedHders, aadBytes) 223 if err != nil { 224 return nil, err 225 } 226 227 return decPrimitive.Decrypt(encryptedData, authData) 228 } 229 230 func (jd *JWEDecrypt) fetchSenderPubKey(skid string, encAlg EncAlg) (*keyset.Handle, error) { 231 senderKey, err := jd.resolveKID(skid) 232 if err != nil { 233 return nil, fmt.Errorf("fetchSenderPubKey: %w", err) 234 } 235 236 ceAlg := aeadAlg[encAlg] 237 238 if ceAlg <= 0 { 239 return nil, fmt.Errorf("fetchSenderPubKey: invalid content encAlg: '%s'", encAlg) 240 } 241 242 return keyio.PublicKeyToKeysetHandle(senderKey, ceAlg) 243 } 244 245 func (jd *JWEDecrypt) validateAndExtractProtectedHeaders(jwe *JSONWebEncryption) (string, error) { 246 if jwe == nil { 247 return "", fmt.Errorf("jwe is nil") 248 } 249 250 if len(jwe.ProtectedHeaders) == 0 { 251 return "", fmt.Errorf("jwe is missing protected headers") 252 } 253 254 protectedHeaders := jwe.ProtectedHeaders 255 256 encAlg, ok := protectedHeaders.Encryption() 257 if !ok { 258 return "", fmt.Errorf("jwe is missing encryption algorithm 'enc' header") 259 } 260 261 switch encAlg { 262 case string(A256GCM), string(XC20P), string(A128CBCHS256), 263 string(A192CBCHS384), string(A256CBCHS384), string(A256CBCHS512): 264 default: 265 return "", fmt.Errorf("encryption algorithm '%s' not supported", encAlg) 266 } 267 268 return encAlg, nil 269 } 270 271 func buildRecipientsWrappedKey(jwe *JSONWebEncryption) ([]*cryptoapi.RecipientWrappedKey, error) { 272 var ( 273 recipients []*cryptoapi.RecipientWrappedKey 274 err error 275 ) 276 277 for _, recJWE := range jwe.Recipients { 278 headers := recJWE.Header 279 alg, ok := jwe.ProtectedHeaders.Algorithm() 280 is1PU := ok && strings.Contains(strings.ToUpper(alg), "1PU") 281 282 if len(jwe.Recipients) == 1 || is1PU { 283 // compact serialization: it has only 1 recipient with no headers or 1pu, extract from protectedHeaders. 284 headers, err = extractRecipientHeaders(jwe.ProtectedHeaders) 285 if err != nil { 286 return nil, err 287 } 288 } 289 290 var recWK *cryptoapi.RecipientWrappedKey 291 // set kid if 1PU (authcrypt) with multi recipients since common protected headers don't have the recipient kid. 292 if is1PU && len(jwe.Recipients) > 1 { 293 headers.KID = recJWE.Header.KID 294 } 295 296 recWK, err = createRecWK(headers, []byte(recJWE.EncryptedKey)) 297 if err != nil { 298 return nil, err 299 } 300 301 recipients = append(recipients, recWK) 302 } 303 304 return recipients, nil 305 } 306 307 func createRecWK(headers *RecipientHeaders, encryptedKey []byte) (*cryptoapi.RecipientWrappedKey, error) { 308 recWK, err := convertMarshalledJWKToRecKey(headers.EPK) 309 if err != nil { 310 return nil, err 311 } 312 313 recWK.KID = headers.KID 314 recWK.Alg = headers.Alg 315 316 err = updateAPUAPVInRecWK(recWK, headers) 317 if err != nil { 318 return nil, err 319 } 320 321 recWK.EncryptedCEK = encryptedKey 322 323 return recWK, nil 324 } 325 326 func updateAPUAPVInRecWK(recWK *cryptoapi.RecipientWrappedKey, headers *RecipientHeaders) error { 327 decodedAPU, decodedAPV, err := decodeAPUAPV(headers) 328 if err != nil { 329 return fmt.Errorf("updateAPUAPVInRecWK: %w", err) 330 } 331 332 recWK.APU = decodedAPU 333 recWK.APV = decodedAPV 334 335 return nil 336 } 337 338 func buildEncryptedData(jwe *JSONWebEncryption) ([]byte, error) { 339 encData := new(composite.EncryptedData) 340 encData.Tag = []byte(jwe.Tag) 341 encData.IV = []byte(jwe.IV) 342 encData.Ciphertext = []byte(jwe.Ciphertext) 343 344 return json.Marshal(encData) 345 } 346 347 // extractRecipientHeaders will extract RecipientHeaders from headers argument. 348 func extractRecipientHeaders(headers map[string]interface{}) (*RecipientHeaders, error) { 349 // Since headers is a generic map, epk value is converted to a generic map by Serialize(), ie we lose RawMessage 350 // type of epk. We need to convert epk value (generic map) to marshaled json so we can call RawMessage.Unmarshal() 351 // to get the original epk value (RawMessage type). 352 mapData, ok := headers[HeaderEPK].(map[string]interface{}) 353 if !ok { 354 return nil, fmt.Errorf("JSON value is not a map (%#v)", headers[HeaderEPK]) 355 } 356 357 epkBytes, err := json.Marshal(mapData) 358 if err != nil { 359 return nil, err 360 } 361 362 epk := json.RawMessage{} 363 364 err = epk.UnmarshalJSON(epkBytes) 365 if err != nil { 366 return nil, err 367 } 368 369 alg := "" 370 if headers[HeaderAlgorithm] != nil { 371 alg = fmt.Sprintf("%v", headers[HeaderAlgorithm]) 372 } 373 374 kid := "" 375 if headers[HeaderKeyID] != nil { 376 kid = fmt.Sprintf("%v", headers[HeaderKeyID]) 377 } 378 379 apu := "" 380 if headers["apu"] != nil { 381 apu = fmt.Sprintf("%v", headers["apu"]) 382 } 383 384 apv := "" 385 if headers["apv"] != nil { 386 apv = fmt.Sprintf("%v", headers["apv"]) 387 } 388 389 recHeaders := &RecipientHeaders{ 390 Alg: alg, 391 KID: kid, 392 EPK: epk, 393 APU: apu, 394 APV: apv, 395 } 396 397 // original headers should remain untouched to avoid modifying the original JWE content. 398 return recHeaders, nil 399 } 400 401 func convertMarshalledJWKToRecKey(marshalledJWK []byte) (*cryptoapi.RecipientWrappedKey, error) { 402 j := &jwk.JWK{} 403 404 err := j.UnmarshalJSON(marshalledJWK) 405 if err != nil { 406 return nil, err 407 } 408 409 epk := cryptoapi.PublicKey{ 410 Curve: j.Crv, 411 Type: j.Kty, 412 } 413 414 switch key := j.Key.(type) { 415 case *ecdsa.PublicKey: 416 epk.X = key.X.Bytes() 417 epk.Y = key.Y.Bytes() 418 case []byte: 419 epk.X = key 420 default: 421 return nil, fmt.Errorf("unsupported recipient key type") 422 } 423 424 return &cryptoapi.RecipientWrappedKey{ 425 KID: j.KeyID, 426 EPK: epk, 427 }, nil 428 }