github.com/trustbloc/kms-go@v1.1.2/doc/jose/decrypter.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package jose 8 9 import ( 10 "crypto/ecdsa" 11 "encoding/base64" 12 "encoding/json" 13 "fmt" 14 "strings" 15 16 "github.com/google/tink/go/keyset" 17 18 "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite" 19 "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/api" 20 "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/ecdh" 21 "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/keyio" 22 ecdhpb "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto" 23 "github.com/trustbloc/kms-go/doc/jose/jwk" 24 resolver "github.com/trustbloc/kms-go/doc/jose/kidresolver" 25 26 cryptoapi "github.com/trustbloc/kms-go/spi/crypto" 27 "github.com/trustbloc/kms-go/spi/kms" 28 ) 29 30 // Decrypter interface to Decrypt JWE messages. 31 type Decrypter interface { 32 // Decrypt a deserialized JWE, extracts the corresponding recipient key to decrypt plaintext and returns it 33 Decrypt(jwe *JSONWebEncryption) ([]byte, error) 34 } 35 36 // JWEDecrypt is responsible for decrypting a JWE message and returns its protected plaintext. 37 type JWEDecrypt struct { 38 kidResolvers []resolver.KIDResolver 39 crypto cryptoapi.Crypto 40 kms kms.KeyManager 41 } 42 43 // NewJWEDecrypt creates a new JWEDecrypt instance to parse and decrypt a JWE message for a given recipient 44 // store is needed for Authcrypt only (to fetch sender's pre agreed upon public key), it is not needed for Anoncrypt. 45 func NewJWEDecrypt(kidResolvers []resolver.KIDResolver, c cryptoapi.Crypto, k kms.KeyManager) *JWEDecrypt { 46 return &JWEDecrypt{ 47 kidResolvers: kidResolvers, 48 crypto: c, 49 kms: k, 50 } 51 } 52 53 func getECDHDecPrimitive(cek []byte, encAlg EncAlg, nistpKW bool) (api.CompositeDecrypt, error) { 54 ceAlg := aeadAlg[encAlg] 55 56 if ceAlg <= 0 { 57 return nil, fmt.Errorf("invalid content encAlg: '%s'", encAlg) 58 } 59 60 kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, ceAlg) 61 62 kh, err := keyset.NewHandle(kt) 63 if err != nil { 64 return nil, err 65 } 66 67 return ecdh.NewECDHDecrypt(kh) 68 } 69 70 // Decrypt a deserialized JWE, decrypts its protected content and returns plaintext. 71 func (jd *JWEDecrypt) Decrypt(jwe *JSONWebEncryption) ([]byte, error) { 72 encAlg, err := jd.validateAndExtractProtectedHeaders(jwe) 73 if err != nil { 74 return nil, fmt.Errorf("jwedecrypt: %w", err) 75 } 76 77 var wkOpts []cryptoapi.WrapKeyOpts 78 79 skid, ok := jwe.ProtectedHeaders.SenderKeyID() 80 if !ok { 81 skid, ok = fetchSKIDFromAPU(jwe) 82 } 83 84 if ok && skid != "" { 85 senderKH, e := jd.fetchSenderPubKey(skid, EncAlg(encAlg)) 86 if e != nil { 87 return nil, fmt.Errorf("jwedecrypt: failed to add sender public key for skid: %w", e) 88 } 89 90 wkOpts = append(wkOpts, cryptoapi.WithSender(senderKH), cryptoapi.WithTag([]byte(jwe.Tag))) 91 } 92 93 recWK, err := buildRecipientsWrappedKey(jwe) 94 if err != nil { 95 return nil, fmt.Errorf("jwedecrypt: failed to build recipients WK: %w", err) 96 } 97 98 cek, err := jd.unwrapCEK(recWK, wkOpts...) 99 if err != nil { 100 return nil, fmt.Errorf("jwedecrypt: %w", err) 101 } 102 103 if len(recWK) == 1 { 104 // ensure EPK is marshalled the same way as during encryption since it is merged into ProtectHeaders. 105 marshalledEPK, err := convertRecEPKToMarshalledJWK(&recWK[0].EPK) 106 if err != nil { 107 return nil, fmt.Errorf("jwedecrypt: %w", err) 108 } 109 110 jwe.ProtectedHeaders["epk"] = json.RawMessage(marshalledEPK) 111 } 112 113 return jd.decryptJWE(jwe, cek) 114 } 115 116 func fetchSKIDFromAPU(jwe *JSONWebEncryption) (string, bool) { 117 // for multi-recipients only: check apu in protectedHeaders if it's found for ECDH-1PU, if skid header is empty then 118 // use apu as skid instead. 119 if len(jwe.Recipients) > 1 { 120 if a, apuOK := jwe.ProtectedHeaders["apu"]; apuOK { 121 skidBytes, err := base64.RawURLEncoding.DecodeString(a.(string)) 122 if err != nil { 123 return "", false 124 } 125 126 return string(skidBytes), true 127 } 128 } 129 130 return "", false 131 } 132 133 //nolint:gocyclo 134 func (jd *JWEDecrypt) unwrapCEK(recWK []*cryptoapi.RecipientWrappedKey, 135 senderOpt ...cryptoapi.WrapKeyOpts) ([]byte, error) { 136 var ( 137 cek []byte 138 errs []error 139 ) 140 141 for _, rec := range recWK { 142 var unwrapOpts []cryptoapi.WrapKeyOpts 143 144 if strings.HasPrefix(rec.KID, "did:key") || strings.Index(rec.KID, "#") > 0 { 145 // resolve and use kms KID if did:key or KeyAgreement.ID. 146 resolvedRec, err := jd.resolveKID(rec.KID) 147 if err != nil { 148 errs = append(errs, err) 149 continue 150 } 151 152 // Need to get the kms KID in order to do kms.Get() since original rec.KID is a did:key/KeyAgreement.ID. 153 // This is necessary to ensure recipient is the owner of the key. 154 rec.KID = resolvedRec.KID 155 } 156 157 recKH, err := jd.kms.Get(rec.KID) 158 if err != nil { 159 continue 160 } 161 162 if rec.EPK.Type == ecdhpb.KeyType_OKP.String() { 163 unwrapOpts = append(unwrapOpts, cryptoapi.WithXC20PKW()) 164 } 165 166 if senderOpt != nil { 167 unwrapOpts = append(unwrapOpts, senderOpt...) 168 } 169 170 if len(unwrapOpts) > 0 { 171 cek, err = jd.crypto.UnwrapKey(rec, recKH, unwrapOpts...) 172 } else { 173 cek, err = jd.crypto.UnwrapKey(rec, recKH) 174 } 175 176 if err == nil { 177 break 178 } 179 180 errs = append(errs, err) 181 } 182 183 if len(cek) == 0 { 184 return nil, fmt.Errorf("failed to unwrap cek: %v", errs) 185 } 186 187 return cek, nil 188 } 189 190 func (jd *JWEDecrypt) resolveKID(kid string) (*cryptoapi.PublicKey, error) { 191 var errs []error 192 193 for _, resolver := range jd.kidResolvers { 194 rKID, err := resolver.Resolve(kid) 195 if err == nil { 196 return rKID, nil 197 } 198 199 errs = append(errs, err) 200 } 201 202 return nil, fmt.Errorf("resolveKID: %v", errs) 203 } 204 205 func (jd *JWEDecrypt) decryptJWE(jwe *JSONWebEncryption, cek []byte) ([]byte, error) { 206 encAlg, ok := jwe.ProtectedHeaders.Encryption() 207 if !ok { 208 return nil, fmt.Errorf("jwedecrypt: JWE 'enc' protected header is missing") 209 } 210 211 decPrimitive, err := getECDHDecPrimitive(cek, EncAlg(encAlg), true) 212 if err != nil { 213 return nil, fmt.Errorf("jwedecrypt: failed to get decryption primitive: %w", err) 214 } 215 216 encryptedData, err := buildEncryptedData(jwe) 217 if err != nil { 218 return nil, fmt.Errorf("jwedecrypt: failed to build encryptedData for Decrypt(): %w", err) 219 } 220 221 aadBytes := []byte(jwe.AAD) 222 223 authData, err := computeAuthData(jwe.ProtectedHeaders, jwe.OrigProtectedHders, aadBytes) 224 if err != nil { 225 return nil, err 226 } 227 228 return decPrimitive.Decrypt(encryptedData, authData) 229 } 230 231 func (jd *JWEDecrypt) fetchSenderPubKey(skid string, encAlg EncAlg) (*keyset.Handle, error) { 232 senderKey, err := jd.resolveKID(skid) 233 if err != nil { 234 return nil, fmt.Errorf("fetchSenderPubKey: %w", err) 235 } 236 237 ceAlg := aeadAlg[encAlg] 238 239 if ceAlg <= 0 { 240 return nil, fmt.Errorf("fetchSenderPubKey: invalid content encAlg: '%s'", encAlg) 241 } 242 243 return keyio.PublicKeyToKeysetHandle(senderKey, ceAlg) 244 } 245 246 func (jd *JWEDecrypt) validateAndExtractProtectedHeaders(jwe *JSONWebEncryption) (string, error) { 247 if jwe == nil { 248 return "", fmt.Errorf("jwe is nil") 249 } 250 251 if len(jwe.ProtectedHeaders) == 0 { 252 return "", fmt.Errorf("jwe is missing protected headers") 253 } 254 255 protectedHeaders := jwe.ProtectedHeaders 256 257 encAlg, ok := protectedHeaders.Encryption() 258 if !ok { 259 return "", fmt.Errorf("jwe is missing encryption algorithm 'enc' header") 260 } 261 262 switch encAlg { 263 case string(A256GCM), string(XC20P), string(A128CBCHS256), 264 string(A192CBCHS384), string(A256CBCHS384), string(A256CBCHS512): 265 default: 266 return "", fmt.Errorf("encryption algorithm '%s' not supported", encAlg) 267 } 268 269 return encAlg, nil 270 } 271 272 func buildRecipientsWrappedKey(jwe *JSONWebEncryption) ([]*cryptoapi.RecipientWrappedKey, error) { 273 var ( 274 recipients []*cryptoapi.RecipientWrappedKey 275 err error 276 ) 277 278 for _, recJWE := range jwe.Recipients { 279 headers := recJWE.Header 280 alg, ok := jwe.ProtectedHeaders.Algorithm() 281 is1PU := ok && strings.Contains(strings.ToUpper(alg), "1PU") 282 283 if len(jwe.Recipients) == 1 || is1PU { 284 // compact serialization: it has only 1 recipient with no headers or 1pu, extract from protectedHeaders. 285 headers, err = extractRecipientHeaders(jwe.ProtectedHeaders) 286 if err != nil { 287 return nil, err 288 } 289 } 290 291 var recWK *cryptoapi.RecipientWrappedKey 292 // set kid if 1PU (authcrypt) with multi recipients since common protected headers don't have the recipient kid. 293 if is1PU && len(jwe.Recipients) > 1 { 294 headers.KID = recJWE.Header.KID 295 } 296 297 recWK, err = createRecWK(headers, []byte(recJWE.EncryptedKey)) 298 if err != nil { 299 return nil, err 300 } 301 302 recipients = append(recipients, recWK) 303 } 304 305 return recipients, nil 306 } 307 308 func createRecWK(headers *RecipientHeaders, encryptedKey []byte) (*cryptoapi.RecipientWrappedKey, error) { 309 recWK, err := convertMarshalledJWKToRecKey(headers.EPK) 310 if err != nil { 311 return nil, err 312 } 313 314 recWK.KID = headers.KID 315 recWK.Alg = headers.Alg 316 317 err = updateAPUAPVInRecWK(recWK, headers) 318 if err != nil { 319 return nil, err 320 } 321 322 recWK.EncryptedCEK = encryptedKey 323 324 return recWK, nil 325 } 326 327 func updateAPUAPVInRecWK(recWK *cryptoapi.RecipientWrappedKey, headers *RecipientHeaders) error { 328 decodedAPU, decodedAPV, err := decodeAPUAPV(headers) 329 if err != nil { 330 return fmt.Errorf("updateAPUAPVInRecWK: %w", err) 331 } 332 333 recWK.APU = decodedAPU 334 recWK.APV = decodedAPV 335 336 return nil 337 } 338 339 func buildEncryptedData(jwe *JSONWebEncryption) ([]byte, error) { 340 encData := new(composite.EncryptedData) 341 encData.Tag = []byte(jwe.Tag) 342 encData.IV = []byte(jwe.IV) 343 encData.Ciphertext = []byte(jwe.Ciphertext) 344 345 return json.Marshal(encData) 346 } 347 348 // extractRecipientHeaders will extract RecipientHeaders from headers argument. 349 func extractRecipientHeaders(headers map[string]interface{}) (*RecipientHeaders, error) { 350 // Since headers is a generic map, epk value is converted to a generic map by Serialize(), ie we lose RawMessage 351 // type of epk. We need to convert epk value (generic map) to marshaled json so we can call RawMessage.Unmarshal() 352 // to get the original epk value (RawMessage type). 353 mapData, ok := headers[HeaderEPK].(map[string]interface{}) 354 if !ok { 355 return nil, fmt.Errorf("JSON value is not a map (%#v)", headers[HeaderEPK]) 356 } 357 358 epkBytes, err := json.Marshal(mapData) 359 if err != nil { 360 return nil, err 361 } 362 363 epk := json.RawMessage{} 364 365 err = epk.UnmarshalJSON(epkBytes) 366 if err != nil { 367 return nil, err 368 } 369 370 alg := "" 371 if headers[HeaderAlgorithm] != nil { 372 alg = fmt.Sprintf("%v", headers[HeaderAlgorithm]) 373 } 374 375 kid := "" 376 if headers[HeaderKeyID] != nil { 377 kid = fmt.Sprintf("%v", headers[HeaderKeyID]) 378 } 379 380 apu := "" 381 if headers["apu"] != nil { 382 apu = fmt.Sprintf("%v", headers["apu"]) 383 } 384 385 apv := "" 386 if headers["apv"] != nil { 387 apv = fmt.Sprintf("%v", headers["apv"]) 388 } 389 390 recHeaders := &RecipientHeaders{ 391 Alg: alg, 392 KID: kid, 393 EPK: epk, 394 APU: apu, 395 APV: apv, 396 } 397 398 // original headers should remain untouched to avoid modifying the original JWE content. 399 return recHeaders, nil 400 } 401 402 func convertMarshalledJWKToRecKey(marshalledJWK []byte) (*cryptoapi.RecipientWrappedKey, error) { 403 j := &jwk.JWK{} 404 405 err := j.UnmarshalJSON(marshalledJWK) 406 if err != nil { 407 return nil, err 408 } 409 410 epk := cryptoapi.PublicKey{ 411 Curve: j.Crv, 412 Type: j.Kty, 413 } 414 415 switch key := j.Key.(type) { 416 case *ecdsa.PublicKey: 417 epk.X = key.X.Bytes() 418 epk.Y = key.Y.Bytes() 419 case []byte: 420 epk.X = key 421 default: 422 return nil, fmt.Errorf("unsupported recipient key type") 423 } 424 425 return &cryptoapi.RecipientWrappedKey{ 426 KID: j.KeyID, 427 EPK: epk, 428 }, nil 429 }