github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/jose/encrypter.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package jose 8 9 import ( 10 "bytes" 11 "crypto/ecdsa" 12 "crypto/elliptic" 13 "crypto/rand" 14 "crypto/sha256" 15 "encoding/base64" 16 "encoding/json" 17 "errors" 18 "fmt" 19 "math/big" 20 "sort" 21 "strings" 22 23 "github.com/go-jose/go-jose/v3" 24 hybrid "github.com/google/tink/go/hybrid/subtle" 25 "github.com/google/tink/go/keyset" 26 "github.com/google/tink/go/subtle/random" 27 "golang.org/x/crypto/curve25519" 28 29 ecdhpb "github.com/hyperledger/aries-framework-go/component/kmscrypto/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto" 30 31 "github.com/hyperledger/aries-framework-go/component/kmscrypto/doc/jose/jwk" 32 cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto" 33 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto" 34 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/aead/subtle" 35 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite" 36 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/api" 37 "github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/ecdh" 38 "github.com/hyperledger/aries-framework-go/pkg/internal/cryptoutil" 39 ) 40 41 // EncAlg represents the JWE content encryption algorithm. 42 type EncAlg string 43 44 const ( 45 // A256GCM for AES256GCM content encryption. 46 A256GCM = EncAlg(A256GCMALG) 47 // XC20P for XChacha20Poly1305 content encryption. 48 XC20P = EncAlg(XC20PALG) 49 // A128CBCHS256 for A128CBC-HS256 (AES128-CBC+HMAC-SHA256) content encryption. 50 A128CBCHS256 = EncAlg(A128CBCHS256ALG) 51 // A192CBCHS384 for A192CBC-HS384 (AES192-CBC+HMAC-SHA384) content encryption. 52 A192CBCHS384 = EncAlg(A192CBCHS384ALG) 53 // A256CBCHS384 for A256CBC-HS384 (AES256-CBC+HMAC-SHA384) content encryption. 54 A256CBCHS384 = EncAlg(A256CBCHS384ALG) 55 // A256CBCHS512 for A256CBC-HS512 (AES256-CBC+HMAC-SHA512) content encryption. 56 A256CBCHS512 = EncAlg(A256CBCHS512ALG) 57 ) 58 59 // Encrypter interface to Encrypt/Decrypt JWE messages. 60 type Encrypter interface { 61 // EncryptWithAuthData encrypt plaintext and aad sent to more than 1 recipients and returns a valid 62 // JSONWebEncryption instance 63 EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) 64 65 // Encrypt plaintext with empty aad sent to 1 or more recipients and returns a valid JSONWebEncryption instance 66 Encrypt(plaintext []byte) (*JSONWebEncryption, error) 67 } 68 69 // JWEEncrypt is responsible for encrypting a plaintext and its AAD into a protected JWE and decrypting it. 70 type JWEEncrypt struct { 71 recipientsKeys []*cryptoapi.PublicKey 72 skid string 73 senderKH *keyset.Handle 74 encAlg EncAlg 75 encTyp string 76 cty string 77 crypto cryptoapi.Crypto 78 } 79 80 // NewJWEEncrypt creates a new JWEEncrypt instance to build JWE with recipientsPubKeys 81 // senderKID and senderKH are used for Authcrypt (to authenticate the sender), if not set JWEEncrypt assumes Anoncrypt. 82 func NewJWEEncrypt(encAlg EncAlg, envelopMediaType, cty, senderKID string, senderKH *keyset.Handle, 83 recipientsPubKeys []*cryptoapi.PublicKey, crypto cryptoapi.Crypto) (*JWEEncrypt, error) { 84 if len(recipientsPubKeys) == 0 { 85 return nil, fmt.Errorf("empty recipientsPubKeys list") 86 } 87 88 switch encAlg { 89 case A256GCM, XC20P, A128CBCHS256, A192CBCHS384, A256CBCHS384, A256CBCHS512: 90 default: 91 return nil, fmt.Errorf("encryption algorithm '%s' not supported", encAlg) 92 } 93 94 if crypto == nil { 95 return nil, errors.New("crypto service is required to create a JWEEncrypt instance") 96 } 97 98 if senderKH != nil { 99 // senderKID is required with non empty senderKH 100 if senderKID == "" { 101 return nil, errors.New("senderKID is required with senderKH") 102 } 103 } 104 105 return &JWEEncrypt{ 106 recipientsKeys: recipientsPubKeys, 107 skid: senderKID, 108 senderKH: senderKH, 109 encAlg: encAlg, 110 encTyp: envelopMediaType, 111 cty: cty, 112 crypto: crypto, 113 }, nil 114 } 115 116 func (je *JWEEncrypt) getECDHEncPrimitive(cek []byte) (api.CompositeEncrypt, error) { 117 nistpKW := je.useNISTPKW() 118 119 encAlg, ok := aeadAlg[je.encAlg] 120 if !ok { 121 return nil, fmt.Errorf("getECDHEncPrimitive: encAlg not supported: '%v'", je.encAlg) 122 } 123 124 kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, encAlg) 125 126 kh, err := keyset.NewHandle(kt) 127 if err != nil { 128 return nil, err 129 } 130 131 pubKH, err := kh.Public() 132 if err != nil { 133 return nil, err 134 } 135 136 return ecdh.NewECDHEncrypt(pubKH) 137 } 138 139 // Encrypt encrypt plaintext with AAD and returns a JSONWebEncryption instance to serialize a JWE instance. 140 func (je *JWEEncrypt) Encrypt(plaintext []byte) (*JSONWebEncryption, error) { 141 return je.EncryptWithAuthData(plaintext, nil) 142 } 143 144 // EncryptWithAuthData encrypt plaintext with AAD and returns a JSONWebEncryption instance to serialize a JWE instance. 145 func (je *JWEEncrypt) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) { 146 protectedHeaders := map[string]interface{}{ 147 HeaderEncryption: je.encAlg, 148 HeaderType: je.encTyp, 149 } 150 151 je.addExtraProtectedHeaders(protectedHeaders) 152 153 cek := je.newCEK() 154 155 // creating the crypto primitive requires a pre-built cek 156 encPrimitive, err := je.getECDHEncPrimitive(cek) 157 if err != nil { 158 return nil, fmt.Errorf("jweencrypt: failed to get encryption primitive: %w", err) 159 } 160 161 authData, err := computeAuthData(protectedHeaders, "", aad) 162 if err != nil { 163 return nil, fmt.Errorf("jweencrypt: computeAuthData: marshal error %w", err) 164 } 165 166 if je.senderKH != nil && je.skid != "" { 167 // ecdh-1pu encryption requires CBC+HMAC encAlg types. 168 return je.encryptWithSender(encPrimitive, plaintext, authData, cek, aad) 169 } 170 171 return je.encrypt(protectedHeaders, encPrimitive, plaintext, authData, cek, aad) 172 } 173 174 func (je *JWEEncrypt) encrypt(protectedHeaders map[string]interface{}, encPrimitive api.CompositeEncrypt, 175 plaintext, authData, cek, aad []byte) (*JSONWebEncryption, error) { 176 recipients, singleRecipientHeaderADDs, err := je.wrapCEKForRecipients(cek, []byte{}, []byte{}, authData, json.Marshal) 177 if err != nil { 178 return nil, fmt.Errorf("jweencrypt: failed to wrap cek: %w", err) 179 } 180 181 if len(singleRecipientHeaderADDs) > 0 { 182 authData = singleRecipientHeaderADDs 183 } 184 185 recipientsHeaders, singleRecipientHeaders, err := je.buildRecs(recipients, false) 186 if err != nil { 187 return nil, fmt.Errorf("jweencrypt: failed to build recipients: %w", err) 188 } 189 190 serializedEncData, err := encPrimitive.Encrypt(plaintext, authData) 191 if err != nil { 192 return nil, fmt.Errorf("jweencrypt: failed to Encrypt: %w", err) 193 } 194 195 encData := new(composite.EncryptedData) 196 197 err = json.Unmarshal(serializedEncData, encData) 198 if err != nil { 199 return nil, fmt.Errorf("jweencrypt: unmarshal encrypted data failed: %w", err) 200 } 201 202 if singleRecipientHeaders != nil { 203 mergeRecipientHeaders(protectedHeaders, singleRecipientHeaders) 204 } 205 206 return getJSONWebEncryption(encData, recipientsHeaders, protectedHeaders, aad), nil 207 } 208 209 func (je *JWEEncrypt) encryptWithSender(primitive api.CompositeEncrypt, 210 plaintext, authData, cek, aad []byte) (*JSONWebEncryption, error) { 211 // pre-generate an EPK + compute apu and apv to be added to authData 212 apu, apv, err := je.buildAPUAPV() 213 if err != nil { 214 return nil, fmt.Errorf("jweencryptWithSender: %w", err) 215 } 216 217 // protectHeaders must be replaced for 1PU to match authData for encryption/decryption 218 // (including EPK, APU, APV always + alg and kid if single recipient). 219 epk, authData, newProtectedHeaders, err := je.generateEPKAndUpdateAuthDataFor1PU(authData, cek, apu, apv) 220 if err != nil { 221 return nil, fmt.Errorf("jweencryptWithSender: %w", err) 222 } 223 224 serializedEncData, err := primitive.Encrypt(plaintext, authData) 225 if err != nil { 226 return nil, fmt.Errorf("jweencryptWithSender: failed to Encrypt: %w", err) 227 } 228 229 encData := new(composite.EncryptedData) 230 231 err = json.Unmarshal(serializedEncData, encData) 232 if err != nil { 233 return nil, fmt.Errorf("jweencryptWithSender: unmarshal encrypted data failed: %w", err) 234 } 235 236 recipients, _, err := je.wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, authData, 237 encData.Tag, json.Marshal, epk) 238 if err != nil { 239 return nil, fmt.Errorf("jweencryptWithSender: failed to wrap cek: %w", err) 240 } 241 242 recipientsHeaders, _, err := je.buildRecs(recipients, true) 243 if err != nil { 244 return nil, fmt.Errorf("jweencryptWithSender: failed to build recipients: %w", err) 245 } 246 247 return getJSONWebEncryption(encData, recipientsHeaders, newProtectedHeaders, aad), nil 248 } 249 250 func getJSONWebEncryption(encData *composite.EncryptedData, recipientsHeaders []*Recipient, 251 protectedHeaders map[string]interface{}, aad []byte) *JSONWebEncryption { 252 return &JSONWebEncryption{ 253 IV: string(encData.IV), 254 Tag: string(encData.Tag), 255 Ciphertext: string(encData.Ciphertext), 256 Recipients: recipientsHeaders, 257 ProtectedHeaders: protectedHeaders, 258 AAD: string(aad), 259 } 260 } 261 262 func (je *JWEEncrypt) wrapCEKForRecipients(cek, apu, apv, aad []byte, 263 marshaller marshalFunc) ([]*cryptoapi.RecipientWrappedKey, []byte, error) { 264 return je.wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, aad, nil, marshaller, nil) 265 } 266 267 func (je *JWEEncrypt) wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, aad, tag []byte, 268 marshaller marshalFunc, epk *cryptoapi.PrivateKey) ([]*cryptoapi.RecipientWrappedKey, []byte, error) { 269 var ( 270 computedAPU []byte 271 computedAPV []byte 272 err error 273 ) 274 275 if len(tag) > 0 { 276 // build apu/apv prior to key wrapping for 1PU only (tag not empty). 277 computedAPU, computedAPV, err = je.buildAPUAPV() 278 if err != nil { 279 return nil, nil, fmt.Errorf("wrapCEKForRecipientsWithTagAndEPK: %w", err) 280 } 281 } 282 283 if len(apv) == 0 { 284 apv = make([]byte, len(computedAPV)) 285 copy(apv, computedAPV) 286 } 287 288 if len(apu) == 0 && je.skid != "" { 289 apu = make([]byte, len(computedAPU)) 290 copy(apu, computedAPU) 291 } 292 293 wrapOpts := je.getWrapKeyOpts(tag, epk) 294 295 rw, kek, err := je.wrapKey(cek, apu, apv, aad, wrapOpts, marshaller) 296 if err != nil { 297 return nil, nil, fmt.Errorf("wrapCEKForRecipientsWithTagAndEPK: %w", err) 298 } 299 300 return rw, kek, nil 301 } 302 303 func (je *JWEEncrypt) wrapKey(cek, apu, apv, aad []byte, wrapOpts []cryptoapi.WrapKeyOpts, 304 marshaller marshalFunc) ([]*cryptoapi.RecipientWrappedKey, []byte, error) { 305 var ( 306 recipientsWK []*cryptoapi.RecipientWrappedKey 307 singleRecipientAAD []byte 308 ) 309 310 for i, recPubKey := range je.recipientsKeys { 311 var ( 312 kek *cryptoapi.RecipientWrappedKey 313 err error 314 ) 315 316 if len(wrapOpts) > 0 { 317 kek, err = je.crypto.WrapKey(cek, apu, apv, recPubKey, wrapOpts...) 318 } else { 319 kek, err = je.crypto.WrapKey(cek, apu, apv, recPubKey) 320 } 321 322 if err != nil { 323 return nil, nil, fmt.Errorf("wrapKey: %d failed: %w", i+1, err) 324 } 325 326 je.encodeAPUAPV(kek) 327 328 recipientsWK = append(recipientsWK, kek) 329 330 if len(je.recipientsKeys) == 1 { 331 singleRecipientAAD, err = mergeSingleRecipientHeaders(kek, aad, marshaller) 332 if err != nil { 333 return nil, nil, fmt.Errorf("wrapKey: merge recipent headers failed for %d: %w", i+1, err) 334 } 335 } 336 } 337 338 return recipientsWK, singleRecipientAAD, nil 339 } 340 341 func (je *JWEEncrypt) encodeAPUAPV(kek *cryptoapi.RecipientWrappedKey) { 342 // APU and APV must be base64URL encoded. 343 if len(kek.APU) > 0 { 344 apuBytes := make([]byte, len(kek.APU)) 345 copy(apuBytes, kek.APU) 346 kek.APU = make([]byte, base64.RawURLEncoding.EncodedLen(len(apuBytes))) 347 base64.RawURLEncoding.Encode(kek.APU, apuBytes) 348 } 349 350 if len(kek.APV) > 0 { 351 apvBytes := make([]byte, len(kek.APV)) 352 copy(apvBytes, kek.APV) 353 kek.APV = make([]byte, base64.RawURLEncoding.EncodedLen(len(apvBytes))) 354 base64.RawURLEncoding.Encode(kek.APV, apvBytes) 355 } 356 } 357 358 func (je *JWEEncrypt) getWrapKeyOpts(tag []byte, epk *cryptoapi.PrivateKey) []cryptoapi.WrapKeyOpts { 359 var wrapOpts []cryptoapi.WrapKeyOpts 360 361 if je.recipientsKeys[0].Type == "OKP" { 362 wrapOpts = append(wrapOpts, cryptoapi.WithXC20PKW()) 363 } 364 365 if je.skid != "" && je.senderKH != nil { 366 wrapOpts = append(wrapOpts, cryptoapi.WithSender(je.senderKH)) 367 } 368 369 if len(tag) > 0 { 370 wrapOpts = append(wrapOpts, cryptoapi.WithTag(tag)) 371 } 372 373 if epk != nil { 374 wrapOpts = append(wrapOpts, cryptoapi.WithEPK(epk)) 375 } 376 377 return wrapOpts 378 } 379 380 // mergeSingleRecipientHeaders for single recipient encryption, recipient header info is available in the key, update 381 // AAD with this info and return the marshalled merged result. 382 func mergeSingleRecipientHeaders(recipientWK *cryptoapi.RecipientWrappedKey, 383 aad []byte, marshaller marshalFunc) ([]byte, error) { 384 var externalAAD []byte 385 386 aadIdx := len(aad) 387 388 if i := bytes.Index(aad, []byte(".")); i > 0 { 389 aadIdx = i 390 externalAAD = append(externalAAD, aad[aadIdx+1:]...) 391 } 392 393 newAAD, err := base64.RawURLEncoding.DecodeString(string(aad[:aadIdx])) 394 if err != nil { 395 return nil, err 396 } 397 398 rawHeaders := map[string]json.RawMessage{} 399 400 err = json.Unmarshal(newAAD, &rawHeaders) 401 if err != nil { 402 return nil, err 403 } 404 405 if recipientWK.KID != "" { 406 var kid []byte 407 408 kid, err = marshaller(recipientWK.KID) 409 if err != nil { 410 return nil, err 411 } 412 413 rawHeaders["kid"] = kid 414 } 415 416 alg, err := marshaller(recipientWK.Alg) 417 if err != nil { 418 return nil, err 419 } 420 421 rawHeaders["alg"] = alg 422 423 err = addKDFHeaders(rawHeaders, recipientWK, marshaller) 424 if err != nil { 425 return nil, err 426 } 427 428 mAAD, err := marshaller(rawHeaders) 429 if err != nil { 430 return nil, err 431 } 432 433 mAADStr := []byte(base64.RawURLEncoding.EncodeToString(mAAD)) 434 435 if len(externalAAD) > 0 { 436 mAADStr = append(mAADStr, byte('.')) 437 mAADStr = append(mAADStr, externalAAD...) 438 } 439 440 return mAADStr, nil 441 } 442 443 func addKDFHeaders(rawHeaders map[string]json.RawMessage, recipientWK *cryptoapi.RecipientWrappedKey, 444 marshaller marshalFunc) error { 445 var err error 446 447 mEPK, err := convertRecEPKToMarshalledJWK(&recipientWK.EPK) 448 if err != nil { 449 return err 450 } 451 452 rawHeaders["epk"] = mEPK 453 454 if len(recipientWK.APU) != 0 { 455 rawHeaders["apu"], err = marshaller(fmt.Sprintf("%s", recipientWK.APU)) 456 if err != nil { 457 return err 458 } 459 } 460 461 if len(recipientWK.APV) != 0 { 462 rawHeaders["apv"], err = marshaller(fmt.Sprintf("%s", recipientWK.APV)) 463 if err != nil { 464 return err 465 } 466 } 467 468 return nil 469 } 470 471 func mergeRecipientHeaders(headers map[string]interface{}, recHeaders *RecipientHeaders) { 472 headers[HeaderAlgorithm] = recHeaders.Alg 473 if recHeaders.KID != "" { 474 headers[HeaderKeyID] = recHeaders.KID 475 } 476 477 // EPK, APU, APV will be marshalled by Serialize 478 headers[HeaderEPK] = recHeaders.EPK 479 if recHeaders.APU != "" { 480 headers["apu"] = base64.RawURLEncoding.EncodeToString([]byte(recHeaders.APU)) 481 } 482 483 if recHeaders.APV != "" { 484 headers["apv"] = base64.RawURLEncoding.EncodeToString([]byte(recHeaders.APV)) 485 } 486 } 487 488 func (je *JWEEncrypt) buildRecs(recWKs []*cryptoapi.RecipientWrappedKey, 489 forAuthcrypt bool) ([]*Recipient, *RecipientHeaders, error) { 490 var ( 491 recipients []*Recipient 492 singleRecipientHeaders *RecipientHeaders 493 ) 494 495 for _, rec := range recWKs { 496 recHeaders, err := buildRecipientHeaders(rec, forAuthcrypt) 497 if err != nil { 498 return nil, nil, err 499 } 500 501 recipients = append(recipients, &Recipient{ 502 EncryptedKey: string(rec.EncryptedCEK), 503 Header: recHeaders, 504 }) 505 } 506 507 // if we have only 1 recipient, then assume compact JWE serialization format. This means recipient header should 508 // be merged with the JWE envelope's protected headers and not added to the recipients 509 if len(recWKs) == 1 { 510 var ( 511 decodedAPU []byte 512 decodedAPV []byte 513 err error 514 ) 515 516 decodedAPU, decodedAPV, err = decodeAPUAPV(recipients[0].Header) 517 if err != nil { 518 return nil, nil, err 519 } 520 521 singleRecipientHeaders = &RecipientHeaders{ 522 Alg: recipients[0].Header.Alg, 523 KID: recipients[0].Header.KID, 524 EPK: recipients[0].Header.EPK, 525 APU: string(decodedAPU), 526 APV: string(decodedAPV), 527 } 528 529 recipients[0].Header = nil 530 } 531 532 return recipients, singleRecipientHeaders, nil 533 } 534 535 func (je *JWEEncrypt) addExtraProtectedHeaders(protectedHeaders map[string]interface{}) { 536 // set cty if it's not empty 537 if je.cty != "" { 538 protectedHeaders[HeaderContentType] = je.cty 539 } 540 541 // set skid if it's not empty 542 if je.skid != "" { 543 protectedHeaders[HeaderSenderKeyID] = je.skid 544 } 545 } 546 547 func (je *JWEEncrypt) useNISTPKW() bool { 548 if je.senderKH == nil { 549 return true 550 } 551 552 for _, ki := range je.senderKH.KeysetInfo().KeyInfo { 553 switch ki.TypeUrl { 554 case "type.hyperledger.org/hyperledger.aries.crypto.tink.NistPEcdhKwPublicKey", 555 "type.hyperledger.org/hyperledger.aries.crypto.tink.NistPEcdhKwPrivateKey": 556 return true 557 case "type.hyperledger.org/hyperledger.aries.crypto.tink.X25519EcdhKwPublicKey", 558 "type.hyperledger.org/hyperledger.aries.crypto.tink.X25519EcdhKwPrivateKey": 559 return false 560 } 561 } 562 563 return true 564 } 565 566 func (je *JWEEncrypt) newCEK() []byte { 567 twoKeys := 2 568 defKeySize := 32 569 570 switch je.encAlg { 571 case A256GCM, XC20P: 572 return random.GetRandomBytes(uint32(defKeySize)) 573 case A128CBCHS256: 574 return random.GetRandomBytes(uint32(subtle.AES128Size * twoKeys)) // cek: 32 bytes. 575 case A192CBCHS384: 576 return random.GetRandomBytes(uint32(subtle.AES192Size * twoKeys)) // cek: 48 bytes. 577 case A256CBCHS384: 578 return random.GetRandomBytes(uint32(subtle.AES256Size + subtle.AES192Size)) // cek: 56 bytes. 579 case A256CBCHS512: 580 return random.GetRandomBytes(uint32(subtle.AES256Size * twoKeys)) // cek: 64 bytes. 581 default: 582 return random.GetRandomBytes(uint32(defKeySize)) // default cek: 32 bytes. 583 } 584 } 585 586 func (je *JWEEncrypt) buildAPUAPV() ([]byte, []byte, error) { 587 if je.skid == "" { 588 return nil, nil, fmt.Errorf("cannot create APU/APV with empty sender skid") 589 } 590 591 if len(je.recipientsKeys) == 0 { 592 return nil, nil, fmt.Errorf("cannot create APU/APV with empty recipient keys") 593 } 594 595 var recKIDs []string 596 597 apu := make([]byte, len(je.skid)) 598 copy(apu, je.skid) 599 600 for _, r := range je.recipientsKeys { 601 recKIDs = append(recKIDs, r.KID) 602 } 603 604 // set recipients' sorted kids list then SHA256 hashed in apv. 605 sort.Strings(recKIDs) 606 607 apvList := []byte(strings.Join(recKIDs, ".")) 608 apv32 := sha256.Sum256(apvList) 609 apv := make([]byte, 32) 610 copy(apv, apv32[:]) 611 612 return apu, apv, nil 613 } 614 615 func (je *JWEEncrypt) generateEPKAndUpdateAuthDataFor1PU(auth, 616 cek, apu, apv []byte) (*cryptoapi.PrivateKey, []byte, map[string]interface{}, error) { 617 var epk *cryptoapi.PrivateKey 618 619 // generate an EPK based on the first recipient. 620 epk, kwAlg, err := je.newEPK(cek) 621 if err != nil { 622 return nil, nil, nil, fmt.Errorf("generateEPKAndUpdateAuthDataFor1PU: %w", err) 623 } 624 625 aadIndex := bytes.Index(auth, []byte(".")) 626 lastIndex := aadIndex 627 628 if lastIndex < 0 { 629 lastIndex = len(auth) 630 } 631 632 return je.buildCommonAuthData(aadIndex, kwAlg, string(auth[:lastIndex]), auth, apu, apv, epk) 633 } 634 635 func (je *JWEEncrypt) buildCommonAuthData(aadIndex int, kwAlg, authData string, auth, apu, apv []byte, 636 epk *cryptoapi.PrivateKey) (*cryptoapi.PrivateKey, []byte, map[string]interface{}, error) { 637 authDataBytes, err := base64.RawURLEncoding.DecodeString(authData) 638 if err != nil { 639 return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authdata decode: %w", err) 640 } 641 642 authDataJSON := map[string]interface{}{} 643 644 err = json.Unmarshal(authDataBytes, &authDataJSON) 645 if err != nil { 646 return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authData unmarshall: %w", err) 647 } 648 649 if len(je.recipientsKeys) == 1 { 650 // kid is part of the protected headers for single recipient JWEs. 651 authDataJSON["kid"] = je.recipientsKeys[0].KID 652 } 653 654 authDataJSON["alg"] = kwAlg 655 656 marshalledEPK, err := convertRecEPKToMarshalledJWK(&epk.PublicKey) 657 if err != nil { 658 return nil, nil, nil, fmt.Errorf("buildCommonAuthData: epk marshall: %w", err) 659 } 660 661 authDataJSON["epk"] = json.RawMessage(marshalledEPK) 662 663 encodedAPU := []byte(base64.RawURLEncoding.EncodeToString(apu)) 664 authDataJSON["apu"] = string(encodedAPU) 665 666 encodedAPV := []byte(base64.RawURLEncoding.EncodeToString(apv)) 667 authDataJSON["apv"] = string(encodedAPV) 668 669 newAuth, err := json.Marshal(authDataJSON) 670 if err != nil { 671 return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authData marshall: %w", err) 672 } 673 674 authData = base64.RawURLEncoding.EncodeToString(newAuth) 675 676 if aadIndex > 0 { 677 authData += string(auth[aadIndex:]) 678 } 679 680 return epk, []byte(authData), authDataJSON, nil 681 } 682 683 func (je *JWEEncrypt) newEPK(cek []byte) (*cryptoapi.PrivateKey, string, error) { 684 var ( 685 kwAlg string 686 epk *cryptoapi.PrivateKey 687 err error 688 ) 689 690 switch je.recipientsKeys[0].Type { 691 case "EC": 692 epk, kwAlg, err = je.ecEPKAndAlg(cek) 693 if err != nil { 694 return nil, "", fmt.Errorf("newEPK: %w", err) 695 } 696 case "OKP": 697 epk, kwAlg, err = je.okpEPKAndAlg() 698 if err != nil { 699 return nil, "", fmt.Errorf("newEPK: %w", err) 700 } 701 default: 702 return nil, "", fmt.Errorf("newEPK: invalid key type '%v'", je.recipientsKeys[0].Type) 703 } 704 705 return epk, kwAlg, nil 706 } 707 708 func (je *JWEEncrypt) ecEPKAndAlg(cek []byte) (*cryptoapi.PrivateKey, string, error) { 709 var kwAlg string 710 711 curve, err := hybrid.GetCurve(je.recipientsKeys[0].Curve) 712 if err != nil { 713 return nil, "", fmt.Errorf("ecEPKAndAlg: getCurve: %w", err) 714 } 715 716 pk, err := ecdsa.GenerateKey(curve, rand.Reader) 717 if err != nil { 718 return nil, "", fmt.Errorf("ecEPKAndAlg: generate ec key: %w", err) 719 } 720 721 epk := &cryptoapi.PrivateKey{ 722 PublicKey: cryptoapi.PublicKey{ 723 Type: "EC", 724 Curve: pk.Curve.Params().Name, 725 X: pk.X.Bytes(), 726 Y: pk.Y.Bytes(), 727 }, 728 D: pk.D.Bytes(), 729 } 730 731 two := 2 732 733 switch len(cek) { 734 case subtle.AES128Size * two: 735 kwAlg = tinkcrypto.ECDH1PUA128KWAlg 736 case subtle.AES192Size * two: 737 kwAlg = tinkcrypto.ECDH1PUA192KWAlg 738 case subtle.AES256Size * two: 739 kwAlg = tinkcrypto.ECDH1PUA256KWAlg 740 } 741 742 return epk, kwAlg, nil 743 } 744 745 func (je *JWEEncrypt) okpEPKAndAlg() (*cryptoapi.PrivateKey, string, error) { 746 ephemeralPrivKey := make([]byte, cryptoutil.Curve25519KeySize) 747 748 _, err := rand.Read(ephemeralPrivKey) 749 if err != nil { 750 return nil, "", fmt.Errorf("okpEPKAndAlg: generate random key for OKP: %w", err) 751 } 752 753 ephemeralPubKey, err := curve25519.X25519(ephemeralPrivKey, curve25519.Basepoint) 754 if err != nil { 755 return nil, "", fmt.Errorf("okpEPKAndAlg: get public epk for OKP: %w", err) 756 } 757 758 kwAlg := tinkcrypto.ECDH1PUXC20PKWAlg 759 760 epk := &cryptoapi.PrivateKey{ 761 PublicKey: cryptoapi.PublicKey{ 762 Type: "OKP", 763 Curve: "X25519", 764 X: ephemeralPubKey, 765 }, 766 D: ephemeralPrivKey, 767 } 768 769 return epk, kwAlg, nil 770 } 771 772 func decodeAPUAPV(headers *RecipientHeaders) ([]byte, []byte, error) { 773 var ( 774 decodedAPU []byte 775 decodedAPV []byte 776 err error 777 ) 778 779 if len(headers.APU) > 0 { 780 decodedAPU, err = base64.RawURLEncoding.DecodeString(headers.APU) 781 if err != nil { 782 return nil, nil, err 783 } 784 } 785 786 if len(headers.APV) > 0 { 787 decodedAPV, err = base64.RawURLEncoding.DecodeString(headers.APV) 788 if err != nil { 789 return nil, nil, err 790 } 791 } 792 793 return decodedAPU, decodedAPV, nil 794 } 795 796 func buildRecipientHeaders(rec *cryptoapi.RecipientWrappedKey, forAuthcrypt bool) (*RecipientHeaders, error) { 797 mRecJWK, err := convertRecEPKToMarshalledJWK(&rec.EPK) 798 if err != nil { 799 return nil, fmt.Errorf("failed to convert recipient key to marshalled JWK: %w", err) 800 } 801 802 rh := &RecipientHeaders{ 803 KID: rec.KID, 804 } 805 806 // authcrypt envelopes have these headers shared for all recipients in the protected headers. 807 if !forAuthcrypt { 808 rh.Alg = rec.Alg 809 rh.EPK = mRecJWK 810 rh.APU = string(rec.APU) 811 rh.APV = string(rec.APV) 812 } 813 814 return rh, nil 815 } 816 817 func convertRecEPKToMarshalledJWK(recEPK *cryptoapi.PublicKey) ([]byte, error) { 818 var ( 819 c elliptic.Curve 820 err error 821 key interface{} 822 ) 823 824 switch recEPK.Type { 825 case ecdhpb.KeyType_EC.String(): 826 c, err = hybrid.GetCurve(recEPK.Curve) 827 if err != nil { 828 return nil, err 829 } 830 831 key = &ecdsa.PublicKey{ 832 Curve: c, 833 X: new(big.Int).SetBytes(recEPK.X), 834 Y: new(big.Int).SetBytes(recEPK.Y), 835 } 836 case ecdhpb.KeyType_OKP.String(): 837 key = recEPK.X 838 default: 839 return nil, errors.New("invalid key type") 840 } 841 842 recJWK := jwk.JWK{ 843 JSONWebKey: jose.JSONWebKey{ 844 Key: key, 845 }, 846 Kty: recEPK.Type, 847 Crv: recEPK.Curve, 848 } 849 850 return recJWK.MarshalJSON() 851 } 852 853 // Get the additional authenticated data from a JWE object. 854 func computeAuthData(protectedHeaders map[string]interface{}, origProtectedHeader string, aad []byte) ([]byte, error) { 855 var protected string 856 857 if len(origProtectedHeader) > 0 { 858 // use origProtectedheader if set instead of marshal/unmarshal existing headers. This is important especially 859 // for ECDH-1PU because protectHeaders are used in the sender authentication mechanism. JSON keys order must 860 // remain untouched. This is critical for successful verification. 861 protected = origProtectedHeader 862 } else if protectedHeaders != nil { 863 protectedHeadersJSON := map[string]json.RawMessage{} 864 865 for k, v := range protectedHeaders { 866 mV, err := json.Marshal(v) 867 if err != nil { 868 return nil, fmt.Errorf("computeAuthData: %w", err) 869 } 870 871 rawMsg := json.RawMessage(mV) // need to explicitly convert []byte to RawMessage (same as go-jose) 872 protectedHeadersJSON[k] = rawMsg 873 } 874 875 err := jwkMarshalEPK(protectedHeadersJSON) 876 if err != nil { 877 return nil, fmt.Errorf("computeAuthData: %w", err) 878 } 879 880 mProtected, err := json.Marshal(protectedHeadersJSON) 881 if err != nil { 882 return nil, fmt.Errorf("computeAuthData: %w", err) 883 } 884 885 protected = base64.RawURLEncoding.EncodeToString(mProtected) 886 } 887 888 output := []byte(protected) 889 if len(aad) > 0 { 890 output = append(output, '.') 891 892 encLen := base64.RawURLEncoding.EncodedLen(len(aad)) 893 aadEncoded := make([]byte, encLen) 894 895 base64.RawURLEncoding.Encode(aadEncoded, aad) 896 output = append(output, aadEncoded...) 897 } 898 899 return output, nil 900 } 901 902 func jwkMarshalEPK(protectedHeadersJSON map[string]json.RawMessage) error { 903 // must use jwk.MarshalJSON() to marshal EPK to maintain headers order. 904 if protectedHeadersJSON[HeaderEPK] != nil { 905 epk := &jwk.JWK{} 906 907 err := epk.UnmarshalJSON(protectedHeadersJSON[HeaderEPK]) 908 if err != nil { 909 return err 910 } 911 912 mEPK, err := epk.MarshalJSON() 913 if err != nil { 914 return fmt.Errorf("jwkMarshalEPK: %w", err) 915 } 916 917 protectedHeadersJSON[HeaderEPK] = mEPK 918 } 919 920 return nil 921 }