github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/jwe.go (about) 1 //go:generate ../tools/cmd/genjwe.sh 2 3 // Package jwe implements JWE as described in https://tools.ietf.org/html/rfc7516 4 package jwe 5 6 import ( 7 "bytes" 8 "context" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "fmt" 12 "io" 13 "sync" 14 15 "github.com/lestrrat-go/blackmagic" 16 "github.com/lestrrat-go/jwx/v2/internal/base64" 17 "github.com/lestrrat-go/jwx/v2/internal/json" 18 "github.com/lestrrat-go/jwx/v2/internal/keyconv" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20 21 "github.com/lestrrat-go/jwx/v2/jwa" 22 "github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc" 23 "github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt" 24 "github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc" 25 "github.com/lestrrat-go/jwx/v2/jwe/internal/keygen" 26 "github.com/lestrrat-go/jwx/v2/x25519" 27 ) 28 29 var muSettings sync.RWMutex 30 var maxPBES2Count = 10000 31 var maxDecompressBufferSize int64 = 10 * 1024 * 1024 // 10MB 32 33 func Settings(options ...GlobalOption) { 34 muSettings.Lock() 35 defer muSettings.Unlock() 36 //nolint:forcetypeassert 37 for _, option := range options { 38 switch option.Ident() { 39 case identMaxPBES2Count{}: 40 maxPBES2Count = option.Value().(int) 41 case identMaxDecompressBufferSize{}: 42 maxDecompressBufferSize = option.Value().(int64) 43 case identMaxBufferSize{}: 44 aescbc.SetMaxBufferSize(option.Value().(int64)) 45 } 46 } 47 } 48 49 const ( 50 fmtInvalid = iota 51 fmtCompact 52 fmtJSON 53 fmtJSONPretty 54 fmtMax 55 ) 56 57 var _ = fmtInvalid 58 var _ = fmtMax 59 60 var registry = json.NewRegistry() 61 62 type keyEncrypterWrapper struct { 63 encrypter KeyEncrypter 64 } 65 66 func (w *keyEncrypterWrapper) Algorithm() jwa.KeyEncryptionAlgorithm { 67 return w.encrypter.Algorithm() 68 } 69 70 func (w *keyEncrypterWrapper) EncryptKey(cek []byte) (keygen.ByteSource, error) { 71 encrypted, err := w.encrypter.EncryptKey(cek) 72 if err != nil { 73 return nil, err 74 } 75 return keygen.ByteKey(encrypted), nil 76 } 77 78 type recipientBuilder struct { 79 alg jwa.KeyEncryptionAlgorithm 80 key interface{} 81 headers Headers 82 } 83 84 func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm, cc *content_crypt.Generic) (Recipient, []byte, error) { 85 var enc keyenc.Encrypter 86 87 // we need the raw key for later use 88 rawKey := b.key 89 90 var keyID string 91 if ke, ok := b.key.(KeyEncrypter); ok { 92 enc = &keyEncrypterWrapper{encrypter: ke} 93 if kider, ok := enc.(KeyIDer); ok { 94 keyID = kider.KeyID() 95 } 96 } else if jwkKey, ok := b.key.(jwk.Key); ok { 97 // Meanwhile, grab the kid as well 98 keyID = jwkKey.KeyID() 99 100 var raw interface{} 101 if err := jwkKey.Raw(&raw); err != nil { 102 return nil, nil, fmt.Errorf(`failed to retrieve raw key out of %T: %w`, b.key, err) 103 } 104 105 rawKey = raw 106 } 107 108 if enc == nil { 109 switch b.alg { 110 case jwa.RSA1_5: 111 var pubkey rsa.PublicKey 112 if err := keyconv.RSAPublicKey(&pubkey, rawKey); err != nil { 113 return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, rawKey, err) 114 } 115 116 v, err := keyenc.NewRSAPKCSEncrypt(b.alg, &pubkey) 117 if err != nil { 118 return nil, nil, fmt.Errorf(`failed to create RSA PKCS encrypter: %w`, err) 119 } 120 enc = v 121 case jwa.RSA_OAEP, jwa.RSA_OAEP_256: 122 var pubkey rsa.PublicKey 123 if err := keyconv.RSAPublicKey(&pubkey, rawKey); err != nil { 124 return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, rawKey, err) 125 } 126 127 v, err := keyenc.NewRSAOAEPEncrypt(b.alg, &pubkey) 128 if err != nil { 129 return nil, nil, fmt.Errorf(`failed to create RSA OAEP encrypter: %w`, err) 130 } 131 enc = v 132 case jwa.A128KW, jwa.A192KW, jwa.A256KW, 133 jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW, 134 jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: 135 sharedkey, ok := rawKey.([]byte) 136 if !ok { 137 return nil, nil, fmt.Errorf(`invalid key: []byte required (%T)`, rawKey) 138 } 139 140 var err error 141 switch b.alg { 142 case jwa.A128KW, jwa.A192KW, jwa.A256KW: 143 enc, err = keyenc.NewAES(b.alg, sharedkey) 144 case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: 145 enc, err = keyenc.NewPBES2Encrypt(b.alg, sharedkey) 146 default: 147 enc, err = keyenc.NewAESGCMEncrypt(b.alg, sharedkey) 148 } 149 if err != nil { 150 return nil, nil, fmt.Errorf(`failed to create key wrap encrypter: %w`, err) 151 } 152 // NOTE: there was formerly a restriction, introduced 153 // in PR #26, which disallowed certain key/content 154 // algorithm combinations. This seemed bogus, and 155 // interop with the jose tool demonstrates it. 156 case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: 157 var keysize int 158 switch b.alg { 159 case jwa.ECDH_ES: 160 // https://tools.ietf.org/html/rfc7518#page-15 161 // In Direct Key Agreement mode, the output of the Concat KDF MUST be a 162 // key of the same length as that used by the "enc" algorithm. 163 keysize = cc.KeySize() 164 case jwa.ECDH_ES_A128KW: 165 keysize = 16 166 case jwa.ECDH_ES_A192KW: 167 keysize = 24 168 case jwa.ECDH_ES_A256KW: 169 keysize = 32 170 } 171 172 switch key := rawKey.(type) { 173 case x25519.PublicKey: 174 var apu, apv []byte 175 if hdrs := b.headers; hdrs != nil { 176 apu = hdrs.AgreementPartyUInfo() 177 apv = hdrs.AgreementPartyVInfo() 178 } 179 180 v, err := keyenc.NewECDHESEncrypt(b.alg, calg, keysize, rawKey, apu, apv) 181 if err != nil { 182 return nil, nil, fmt.Errorf(`failed to create ECDHS key wrap encrypter: %w`, err) 183 } 184 enc = v 185 default: 186 var pubkey ecdsa.PublicKey 187 if err := keyconv.ECDSAPublicKey(&pubkey, rawKey); err != nil { 188 return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, key, err) 189 } 190 191 var apu, apv []byte 192 if hdrs := b.headers; hdrs != nil { 193 apu = hdrs.AgreementPartyUInfo() 194 apv = hdrs.AgreementPartyVInfo() 195 } 196 197 v, err := keyenc.NewECDHESEncrypt(b.alg, calg, keysize, &pubkey, apu, apv) 198 if err != nil { 199 return nil, nil, fmt.Errorf(`failed to create ECDHS key wrap encrypter: %w`, err) 200 } 201 enc = v 202 } 203 case jwa.DIRECT: 204 sharedkey, ok := rawKey.([]byte) 205 if !ok { 206 return nil, nil, fmt.Errorf("invalid key: []byte required") 207 } 208 enc, _ = keyenc.NewNoop(b.alg, sharedkey) 209 default: 210 return nil, nil, fmt.Errorf(`invalid key encryption algorithm (%s)`, b.alg) 211 } 212 } 213 214 r := NewRecipient() 215 if hdrs := b.headers; hdrs != nil { 216 _ = r.SetHeaders(hdrs) 217 } 218 219 if err := r.Headers().Set(AlgorithmKey, b.alg); err != nil { 220 return nil, nil, fmt.Errorf(`failed to set header: %w`, err) 221 } 222 223 if keyID != "" { 224 if err := r.Headers().Set(KeyIDKey, keyID); err != nil { 225 return nil, nil, fmt.Errorf(`failed to set header: %w`, err) 226 } 227 } 228 229 var rawCEK []byte 230 enckey, err := enc.EncryptKey(cek) 231 if err != nil { 232 return nil, nil, fmt.Errorf(`failed to encrypt key: %w`, err) 233 } 234 if enc.Algorithm() == jwa.ECDH_ES || enc.Algorithm() == jwa.DIRECT { 235 rawCEK = enckey.Bytes() 236 } else { 237 if err := r.SetEncryptedKey(enckey.Bytes()); err != nil { 238 return nil, nil, fmt.Errorf(`failed to set encrypted key: %w`, err) 239 } 240 } 241 242 if hp, ok := enckey.(populater); ok { 243 if err := hp.Populate(r.Headers()); err != nil { 244 return nil, nil, fmt.Errorf(`failed to populate: %w`, err) 245 } 246 } 247 248 return r, rawCEK, nil 249 } 250 251 // Encrypt generates a JWE message for the given payload and returns 252 // it in serialized form, which can be in either compact or 253 // JSON format. Default is compact. 254 // 255 // You must pass at least one key to `jwe.Encrypt()` by using `jwe.WithKey()` 256 // option. 257 // 258 // jwe.Encrypt(payload, jwe.WithKey(alg, key)) 259 // jwe.Encrypt(payload, jws.WithJSON(), jws.WithKey(alg1, key1), jws.WithKey(alg2, key2)) 260 // 261 // Note that in the second example the `jws.WithJSON()` option is 262 // specified as well. This is because the compact serialization 263 // format does not support multiple recipients, and users must 264 // specifically ask for the JSON serialization format. 265 // 266 // Read the documentation for `jwe.WithKey()` to learn more about the 267 // possible values that can be used for `alg` and `key`. 268 // 269 // Look for options that return `jwe.EncryptOption` or `jws.EncryptDecryptOption` 270 // for a complete list of options that can be passed to this function. 271 func Encrypt(payload []byte, options ...EncryptOption) ([]byte, error) { 272 return encrypt(payload, nil, options...) 273 } 274 275 // Encryptstatic is exactly like Encrypt, except it accepts a static 276 // content encryption key (CEK). It is separated out from the main 277 // Encrypt function such that the latter does not accidentally use a static 278 // CEK. 279 // 280 // DO NOT attempt to use this function unless you completely understand the 281 // security implications to using static CEKs. You have been warned. 282 // 283 // This function is currently considered EXPERIMENTAL, and is subject to 284 // future changes across minor/micro versions. 285 func EncryptStatic(payload, cek []byte, options ...EncryptOption) ([]byte, error) { 286 if len(cek) <= 0 { 287 return nil, fmt.Errorf(`jwe.EncryptStatic: empty CEK`) 288 } 289 return encrypt(payload, cek, options...) 290 } 291 292 // encrypt is separate so it can receive cek from outside. 293 // (but we don't want to receive it in the options slice) 294 func encrypt(payload, cek []byte, options ...EncryptOption) ([]byte, error) { 295 // default content encryption algorithm 296 calg := jwa.A256GCM 297 298 // default compression is "none" 299 compression := jwa.NoCompress 300 301 // default format is compact serialization 302 format := fmtCompact 303 304 // builds each "recipient" with encrypted_key and headers 305 var builders []*recipientBuilder 306 307 var protected Headers 308 var mergeProtected bool 309 var useRawCEK bool 310 for _, option := range options { 311 //nolint:forcetypeassert 312 switch option.Ident() { 313 case identKey{}: 314 data := option.Value().(*withKey) 315 v, ok := data.alg.(jwa.KeyEncryptionAlgorithm) 316 if !ok { 317 return nil, fmt.Errorf(`jwe.Encrypt: expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, data.alg) 318 } 319 320 switch v { 321 case jwa.DIRECT, jwa.ECDH_ES: 322 useRawCEK = true 323 } 324 325 builders = append(builders, &recipientBuilder{ 326 alg: v, 327 key: data.key, 328 headers: data.headers, 329 }) 330 case identContentEncryptionAlgorithm{}: 331 calg = option.Value().(jwa.ContentEncryptionAlgorithm) 332 case identCompress{}: 333 compression = option.Value().(jwa.CompressionAlgorithm) 334 case identMergeProtectedHeaders{}: 335 mergeProtected = option.Value().(bool) 336 case identProtectedHeaders{}: 337 v := option.Value().(Headers) 338 if !mergeProtected || protected == nil { 339 protected = v 340 } else { 341 ctx := context.TODO() 342 merged, err := protected.Merge(ctx, v) 343 if err != nil { 344 return nil, fmt.Errorf(`jwe.Encrypt: failed to merge headers: %w`, err) 345 } 346 protected = merged 347 } 348 case identSerialization{}: 349 format = option.Value().(int) 350 } 351 } 352 353 // We need to have at least one builder 354 switch l := len(builders); { 355 case l == 0: 356 return nil, fmt.Errorf(`jwe.Encrypt: missing key encryption builders: use jwe.WithKey() to specify one`) 357 case l > 1: 358 if format == fmtCompact { 359 return nil, fmt.Errorf(`jwe.Encrypt: cannot use compact serialization when multiple recipients exist (check the number of WithKey() argument, or use WithJSON())`) 360 } 361 } 362 363 if useRawCEK { 364 if len(builders) != 1 { 365 return nil, fmt.Errorf(`jwe.Encrypt: multiple recipients for ECDH-ES/DIRECT mode supported`) 366 } 367 } 368 369 // There is exactly one content encrypter. 370 contentcrypt, err := content_crypt.NewGeneric(calg) 371 if err != nil { 372 return nil, fmt.Errorf(`jwe.Encrypt: failed to create AES encrypter: %w`, err) 373 } 374 375 if len(cek) <= 0 { 376 generator := keygen.NewRandom(contentcrypt.KeySize()) 377 bk, err := generator.Generate() 378 if err != nil { 379 return nil, fmt.Errorf(`jwe.Encrypt: failed to generate key: %w`, err) 380 } 381 cek = bk.Bytes() 382 } 383 384 recipients := make([]Recipient, len(builders)) 385 for i, builder := range builders { 386 // some builders require hint from the contentcrypt object 387 r, rawCEK, err := builder.Build(cek, calg, contentcrypt) 388 if err != nil { 389 return nil, fmt.Errorf(`jwe.Encrypt: failed to create recipient #%d: %w`, i, err) 390 } 391 recipients[i] = r 392 393 // Kinda feels weird, but if useRawCEK == true, we asserted earlier 394 // that len(builders) == 1, so this is OK 395 if useRawCEK { 396 cek = rawCEK 397 } 398 } 399 400 if protected == nil { 401 protected = NewHeaders() 402 } 403 404 if err := protected.Set(ContentEncryptionKey, calg); err != nil { 405 return nil, fmt.Errorf(`jwe.Encrypt: failed to set "enc" in protected header: %w`, err) 406 } 407 408 if compression != jwa.NoCompress { 409 payload, err = compress(payload) 410 if err != nil { 411 return nil, fmt.Errorf(`jwe.Encrypt: failed to compress payload before encryption: %w`, err) 412 } 413 if err := protected.Set(CompressionKey, compression); err != nil { 414 return nil, fmt.Errorf(`jwe.Encrypt: failed to set "zip" in protected header: %w`, err) 415 } 416 } 417 418 // If there's only one recipient, you want to include that in the 419 // protected header 420 if len(recipients) == 1 { 421 h, err := protected.Merge(context.TODO(), recipients[0].Headers()) 422 if err != nil { 423 return nil, fmt.Errorf(`jwe.Encrypt: failed to merge protected headers: %w`, err) 424 } 425 protected = h 426 } 427 428 aad, err := protected.Encode() 429 if err != nil { 430 return nil, fmt.Errorf(`failed to base64 encode protected headers: %w`, err) 431 } 432 433 iv, ciphertext, tag, err := contentcrypt.Encrypt(cek, payload, aad) 434 if err != nil { 435 return nil, fmt.Errorf(`failed to encrypt payload: %w`, err) 436 } 437 438 msg := NewMessage() 439 440 if err := msg.Set(CipherTextKey, ciphertext); err != nil { 441 return nil, fmt.Errorf(`failed to set %s: %w`, CipherTextKey, err) 442 } 443 if err := msg.Set(InitializationVectorKey, iv); err != nil { 444 return nil, fmt.Errorf(`failed to set %s: %w`, InitializationVectorKey, err) 445 } 446 if err := msg.Set(ProtectedHeadersKey, protected); err != nil { 447 return nil, fmt.Errorf(`failed to set %s: %w`, ProtectedHeadersKey, err) 448 } 449 if err := msg.Set(RecipientsKey, recipients); err != nil { 450 return nil, fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err) 451 } 452 if err := msg.Set(TagKey, tag); err != nil { 453 return nil, fmt.Errorf(`failed to set %s: %w`, TagKey, err) 454 } 455 456 switch format { 457 case fmtCompact: 458 return Compact(msg) 459 case fmtJSON: 460 return json.Marshal(msg) 461 case fmtJSONPretty: 462 return json.MarshalIndent(msg, "", " ") 463 default: 464 return nil, fmt.Errorf(`jwe.Encrypt: invalid serialization`) 465 } 466 } 467 468 type decryptCtx struct { 469 msg *Message 470 aad []byte 471 cek *[]byte 472 computedAad []byte 473 keyProviders []KeyProvider 474 protectedHeaders Headers 475 maxDecompressBufferSize int64 476 } 477 478 // Decrypt takes encrypted payload, and information required to decrypt the 479 // payload (e.g. the key encryption algorithm and the corresponding 480 // key to decrypt the JWE message) in its optional arguments. See 481 // the examples and list of options that return a DecryptOption for possible 482 // values. Upon successful decryptiond returns the decrypted payload. 483 // 484 // The JWE message can be either compact or full JSON format. 485 // 486 // When using `jwe.WithKeyEncryptionAlgorithm()`, you can pass a `jwa.KeyAlgorithm` 487 // for convenience: this is mainly to allow you to directly pass the result of `(jwk.Key).Algorithm()`. 488 // However, do note that while `(jwk.Key).Algorithm()` could very well contain key encryption 489 // algorithms, it could also contain other types of values, such as _signature algorithms_. 490 // In order for `jwe.Decrypt` to work properly, the `alg` parameter must be of type 491 // `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error. 492 // 493 // When using `jwe.WithKey()`, the value must be a private key. 494 // It can be either in its raw format (e.g. *rsa.PrivateKey) or a jwk.Key 495 // 496 // When the encrypted message is also compressed, the decompressed payload must be 497 // smaller than the size specified by the `jwe.WithMaxDecompressBufferSize` setting, 498 // which defaults to 10MB. If the decompressed payload is larger than this size, 499 // an error is returned. 500 // 501 // You can opt to change the MaxDecompressBufferSize setting globally, or on a 502 // per-call basis by passing the `jwe.WithMaxDecompressBufferSize` option to 503 // either `jwe.Settings()` or `jwe.Decrypt()`: 504 // 505 // jwe.Settings(jwe.WithMaxDecompressBufferSize(10*1024*1024)) // changes value globally 506 // jwe.Decrypt(..., jwe.WithMaxDecompressBufferSize(250*1024)) // changes just for this call 507 func Decrypt(buf []byte, options ...DecryptOption) ([]byte, error) { 508 var keyProviders []KeyProvider 509 var keyUsed interface{} 510 var cek *[]byte 511 var dst *Message 512 perCallMaxDecompressBufferSize := maxDecompressBufferSize 513 //nolint:forcetypeassert 514 for _, option := range options { 515 switch option.Ident() { 516 case identMessage{}: 517 dst = option.Value().(*Message) 518 case identKeyProvider{}: 519 keyProviders = append(keyProviders, option.Value().(KeyProvider)) 520 case identKeyUsed{}: 521 keyUsed = option.Value() 522 case identKey{}: 523 pair := option.Value().(*withKey) 524 alg, ok := pair.alg.(jwa.KeyEncryptionAlgorithm) 525 if !ok { 526 return nil, fmt.Errorf(`WithKey() option must be specified using jwa.KeyEncryptionAlgorithm (got %T)`, pair.alg) 527 } 528 keyProviders = append(keyProviders, &staticKeyProvider{ 529 alg: alg, 530 key: pair.key, 531 }) 532 case identCEK{}: 533 cek = option.Value().(*[]byte) 534 case identMaxDecompressBufferSize{}: 535 perCallMaxDecompressBufferSize = option.Value().(int64) 536 } 537 } 538 539 if len(keyProviders) < 1 { 540 return nil, fmt.Errorf(`jwe.Decrypt: no key providers have been provided (see jwe.WithKey(), jwe.WithKeySet(), and jwe.WithKeyProvider()`) 541 } 542 543 msg, err := parseJSONOrCompact(buf, true) 544 if err != nil { 545 return nil, fmt.Errorf(`failed to parse buffer for Decrypt: %w`, err) 546 } 547 548 // Process things that are common to the message 549 ctx := context.TODO() 550 h, err := msg.protectedHeaders.Clone(ctx) 551 if err != nil { 552 return nil, fmt.Errorf(`failed to copy protected headers: %w`, err) 553 } 554 h, err = h.Merge(ctx, msg.unprotectedHeaders) 555 if err != nil { 556 return nil, fmt.Errorf(`failed to merge headers for message decryption: %w`, err) 557 } 558 559 var aad []byte 560 if aadContainer := msg.authenticatedData; aadContainer != nil { 561 aad = base64.Encode(aadContainer) 562 } 563 564 var computedAad []byte 565 if len(msg.rawProtectedHeaders) > 0 { 566 computedAad = msg.rawProtectedHeaders 567 } else { 568 // this is probably not required once msg.Decrypt is deprecated 569 var err error 570 computedAad, err = msg.protectedHeaders.Encode() 571 if err != nil { 572 return nil, fmt.Errorf(`failed to encode protected headers: %w`, err) 573 } 574 } 575 576 // for each recipient, attempt to match the key providers 577 // if we have no recipients, pretend like we only have one 578 recipients := msg.recipients 579 if len(recipients) == 0 { 580 r := NewRecipient() 581 if err := r.SetHeaders(msg.protectedHeaders); err != nil { 582 return nil, fmt.Errorf(`failed to set headers to recipient: %w`, err) 583 } 584 recipients = append(recipients, r) 585 } 586 587 var dctx decryptCtx 588 589 dctx.aad = aad 590 dctx.computedAad = computedAad 591 dctx.msg = msg 592 dctx.keyProviders = keyProviders 593 dctx.protectedHeaders = h 594 dctx.cek = cek 595 dctx.maxDecompressBufferSize = perCallMaxDecompressBufferSize 596 597 var lastError error 598 for _, recipient := range recipients { 599 decrypted, err := dctx.try(ctx, recipient, keyUsed) 600 if err != nil { 601 lastError = err 602 continue 603 } 604 if dst != nil { 605 *dst = *msg 606 dst.rawProtectedHeaders = nil 607 dst.storeProtectedHeaders = false 608 } 609 return decrypted, nil 610 } 611 return nil, fmt.Errorf(`jwe.Decrypt: failed to decrypt any of the recipients (last error = %w)`, lastError) 612 } 613 614 func (dctx *decryptCtx) try(ctx context.Context, recipient Recipient, keyUsed interface{}) ([]byte, error) { 615 var tried int 616 var lastError error 617 for i, kp := range dctx.keyProviders { 618 var sink algKeySink 619 if err := kp.FetchKeys(ctx, &sink, recipient, dctx.msg); err != nil { 620 return nil, fmt.Errorf(`key provider %d failed: %w`, i, err) 621 } 622 623 for _, pair := range sink.list { 624 tried++ 625 // alg is converted here because pair.alg is of type jwa.KeyAlgorithm. 626 // this may seem ugly, but we're trying to avoid declaring separate 627 // structs for `alg jwa.KeyAlgorithm` and `alg jwa.SignatureAlgorithm` 628 //nolint:forcetypeassert 629 alg := pair.alg.(jwa.KeyEncryptionAlgorithm) 630 key := pair.key 631 632 decrypted, err := dctx.decryptContent(ctx, alg, key, recipient) 633 if err != nil { 634 lastError = err 635 continue 636 } 637 638 if keyUsed != nil { 639 if err := blackmagic.AssignIfCompatible(keyUsed, key); err != nil { 640 return nil, fmt.Errorf(`failed to assign used key (%T) to %T: %w`, key, keyUsed, err) 641 } 642 } 643 return decrypted, nil 644 } 645 } 646 return nil, fmt.Errorf(`jwe.Decrypt: tried %d keys, but failed to match any of the keys with recipient (last error = %s)`, tried, lastError) 647 } 648 649 func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptionAlgorithm, key interface{}, recipient Recipient) ([]byte, error) { 650 if jwkKey, ok := key.(jwk.Key); ok { 651 var raw interface{} 652 if err := jwkKey.Raw(&raw); err != nil { 653 return nil, fmt.Errorf(`failed to retrieve raw key from %T: %w`, key, err) 654 } 655 key = raw 656 } 657 658 dec := newDecrypter(alg, dctx.msg.protectedHeaders.ContentEncryption(), key). 659 AuthenticatedData(dctx.aad). 660 ComputedAuthenticatedData(dctx.computedAad). 661 InitializationVector(dctx.msg.initializationVector). 662 Tag(dctx.msg.tag). 663 CEK(dctx.cek) 664 665 if recipient.Headers().Algorithm() != alg { 666 // algorithms don't match 667 return nil, fmt.Errorf(`jwe.Decrypt: key and recipient algorithms do not match`) 668 } 669 670 h2, err := dctx.protectedHeaders.Clone(ctx) 671 if err != nil { 672 return nil, fmt.Errorf(`jwe.Decrypt: failed to copy headers (1): %w`, err) 673 } 674 675 h2, err = h2.Merge(ctx, recipient.Headers()) 676 if err != nil { 677 return nil, fmt.Errorf(`failed to copy headers (2): %w`, err) 678 } 679 680 switch alg { 681 case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW: 682 epkif, ok := h2.Get(EphemeralPublicKeyKey) 683 if !ok { 684 return nil, fmt.Errorf(`failed to get 'epk' field`) 685 } 686 switch epk := epkif.(type) { 687 case jwk.ECDSAPublicKey: 688 var pubkey ecdsa.PublicKey 689 if err := epk.Raw(&pubkey); err != nil { 690 return nil, fmt.Errorf(`failed to get public key: %w`, err) 691 } 692 dec.PublicKey(&pubkey) 693 case jwk.OKPPublicKey: 694 var pubkey interface{} 695 if err := epk.Raw(&pubkey); err != nil { 696 return nil, fmt.Errorf(`failed to get public key: %w`, err) 697 } 698 dec.PublicKey(pubkey) 699 default: 700 return nil, fmt.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg) 701 } 702 703 if apu := h2.AgreementPartyUInfo(); len(apu) > 0 { 704 dec.AgreementPartyUInfo(apu) 705 } 706 if apv := h2.AgreementPartyVInfo(); len(apv) > 0 { 707 dec.AgreementPartyVInfo(apv) 708 } 709 case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW: 710 ivB64, ok := h2.Get(InitializationVectorKey) 711 if ok { 712 ivB64Str, ok := ivB64.(string) 713 if !ok { 714 return nil, fmt.Errorf("unexpected type for 'iv': %T", ivB64) 715 } 716 iv, err := base64.DecodeString(ivB64Str) 717 if err != nil { 718 return nil, fmt.Errorf(`failed to b64-decode 'iv': %w`, err) 719 } 720 dec.KeyInitializationVector(iv) 721 } 722 tagB64, ok := h2.Get(TagKey) 723 if ok { 724 tagB64Str, ok := tagB64.(string) 725 if !ok { 726 return nil, fmt.Errorf("unexpected type for 'tag': %T", tagB64) 727 } 728 tag, err := base64.DecodeString(tagB64Str) 729 if err != nil { 730 return nil, fmt.Errorf(`failed to b64-decode 'tag': %w`, err) 731 } 732 dec.KeyTag(tag) 733 } 734 case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW: 735 saltB64, ok := h2.Get(SaltKey) 736 if !ok { 737 return nil, fmt.Errorf(`failed to get 'p2s' field`) 738 } 739 saltB64Str, ok := saltB64.(string) 740 if !ok { 741 return nil, fmt.Errorf("unexpected type for 'p2s': %T", saltB64) 742 } 743 744 count, ok := h2.Get(CountKey) 745 if !ok { 746 return nil, fmt.Errorf(`failed to get 'p2c' field`) 747 } 748 countFlt, ok := count.(float64) 749 if !ok { 750 return nil, fmt.Errorf("unexpected type for 'p2c': %T", count) 751 } 752 muSettings.RLock() 753 maxCount := maxPBES2Count 754 muSettings.RUnlock() 755 if countFlt > float64(maxCount) { 756 return nil, fmt.Errorf("invalid 'p2c' value") 757 } 758 salt, err := base64.DecodeString(saltB64Str) 759 if err != nil { 760 return nil, fmt.Errorf(`failed to b64-decode 'salt': %w`, err) 761 } 762 dec.KeySalt(salt) 763 dec.KeyCount(int(countFlt)) 764 } 765 766 plaintext, err := dec.Decrypt(recipient, dctx.msg.cipherText, dctx.msg) 767 if err != nil { 768 return nil, fmt.Errorf(`jwe.Decrypt: decryption failed: %w`, err) 769 } 770 771 if h2.Compression() == jwa.Deflate { 772 buf, err := uncompress(plaintext, dctx.maxDecompressBufferSize) 773 if err != nil { 774 return nil, fmt.Errorf(`jwe.Derypt: failed to uncompress payload: %w`, err) 775 } 776 plaintext = buf 777 } 778 779 if plaintext == nil { 780 return nil, fmt.Errorf(`failed to find matching recipient`) 781 } 782 783 return plaintext, nil 784 } 785 786 // Parse parses the JWE message into a Message object. The JWE message 787 // can be either compact or full JSON format. 788 // 789 // Parse() currently does not take any options, but the API accepts it 790 // in anticipation of future addition. 791 func Parse(buf []byte, _ ...ParseOption) (*Message, error) { 792 return parseJSONOrCompact(buf, false) 793 } 794 795 func parseJSONOrCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) { 796 buf = bytes.TrimSpace(buf) 797 if len(buf) == 0 { 798 return nil, fmt.Errorf(`empty buffer`) 799 } 800 801 if buf[0] == '{' { 802 return parseJSON(buf, storeProtectedHeaders) 803 } 804 return parseCompact(buf, storeProtectedHeaders) 805 } 806 807 // ParseString is the same as Parse, but takes a string. 808 func ParseString(s string) (*Message, error) { 809 return Parse([]byte(s)) 810 } 811 812 // ParseReader is the same as Parse, but takes an io.Reader. 813 func ParseReader(src io.Reader) (*Message, error) { 814 buf, err := io.ReadAll(src) 815 if err != nil { 816 return nil, fmt.Errorf(`failed to read from io.Reader: %w`, err) 817 } 818 return Parse(buf) 819 } 820 821 func parseJSON(buf []byte, storeProtectedHeaders bool) (*Message, error) { 822 m := NewMessage() 823 m.storeProtectedHeaders = storeProtectedHeaders 824 if err := json.Unmarshal(buf, &m); err != nil { 825 return nil, fmt.Errorf(`failed to parse JSON: %w`, err) 826 } 827 return m, nil 828 } 829 830 func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) { 831 parts := bytes.Split(buf, []byte{'.'}) 832 if len(parts) != 5 { 833 return nil, fmt.Errorf(`compact JWE format must have five parts (%d)`, len(parts)) 834 } 835 836 hdrbuf, err := base64.Decode(parts[0]) 837 if err != nil { 838 return nil, fmt.Errorf(`failed to parse first part of compact form: %w`, err) 839 } 840 841 protected := NewHeaders() 842 if err := json.Unmarshal(hdrbuf, protected); err != nil { 843 return nil, fmt.Errorf(`failed to parse header JSON: %w`, err) 844 } 845 846 ivbuf, err := base64.Decode(parts[2]) 847 if err != nil { 848 return nil, fmt.Errorf(`failed to base64 decode iv: %w`, err) 849 } 850 851 ctbuf, err := base64.Decode(parts[3]) 852 if err != nil { 853 return nil, fmt.Errorf(`failed to base64 decode content: %w`, err) 854 } 855 856 tagbuf, err := base64.Decode(parts[4]) 857 if err != nil { 858 return nil, fmt.Errorf(`failed to base64 decode tag: %w`, err) 859 } 860 861 m := NewMessage() 862 if err := m.Set(CipherTextKey, ctbuf); err != nil { 863 return nil, fmt.Errorf(`failed to set %s: %w`, CipherTextKey, err) 864 } 865 if err := m.Set(InitializationVectorKey, ivbuf); err != nil { 866 return nil, fmt.Errorf(`failed to set %s: %w`, InitializationVectorKey, err) 867 } 868 if err := m.Set(ProtectedHeadersKey, protected); err != nil { 869 return nil, fmt.Errorf(`failed to set %s: %w`, ProtectedHeadersKey, err) 870 } 871 872 if err := m.makeDummyRecipient(string(parts[1]), protected); err != nil { 873 return nil, fmt.Errorf(`failed to setup recipient: %w`, err) 874 } 875 876 if err := m.Set(TagKey, tagbuf); err != nil { 877 return nil, fmt.Errorf(`failed to set %s: %w`, TagKey, err) 878 } 879 880 if storeProtectedHeaders { 881 // This is later used for decryption. 882 m.rawProtectedHeaders = parts[0] 883 } 884 885 return m, nil 886 } 887 888 // RegisterCustomField allows users to specify that a private field 889 // be decoded as an instance of the specified type. This option has 890 // a global effect. 891 // 892 // For example, suppose you have a custom field `x-birthday`, which 893 // you want to represent as a string formatted in RFC3339 in JSON, 894 // but want it back as `time.Time`. 895 // 896 // In that case you would register a custom field as follows 897 // 898 // jwe.RegisterCustomField(`x-birthday`, timeT) 899 // 900 // Then `hdr.Get("x-birthday")` will still return an `interface{}`, 901 // but you can convert its type to `time.Time` 902 // 903 // bdayif, _ := hdr.Get(`x-birthday`) 904 // bday := bdayif.(time.Time) 905 func RegisterCustomField(name string, object interface{}) { 906 registry.Register(name, object) 907 }