github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/sdjwt/common/common.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package common 8 9 import ( 10 "crypto" 11 "encoding/base64" 12 "encoding/json" 13 "fmt" 14 "reflect" 15 "strings" 16 17 "github.com/hyperledger/aries-framework-go/pkg/common/utils" 18 afgjwt "github.com/hyperledger/aries-framework-go/pkg/doc/jwt" 19 ) 20 21 // CombinedFormatSeparator is disclosure separator. 22 const ( 23 CombinedFormatSeparator = "~" 24 25 SDAlgorithmKey = "_sd_alg" 26 SDKey = "_sd" 27 CNFKey = "cnf" 28 29 disclosureParts = 3 30 saltIndex = 0 31 nameIndex = 1 32 valueIndex = 2 33 ) 34 35 // CombinedFormatForIssuance holds SD-JWT and disclosures. 36 type CombinedFormatForIssuance struct { 37 SDJWT string 38 Disclosures []string 39 } 40 41 // Serialize will assemble combined format for issuance. 42 func (cf *CombinedFormatForIssuance) Serialize() string { 43 presentation := cf.SDJWT 44 for _, disclosure := range cf.Disclosures { 45 presentation += CombinedFormatSeparator + disclosure 46 } 47 48 return presentation 49 } 50 51 // CombinedFormatForPresentation holds SD-JWT, disclosures and optional holder binding info. 52 type CombinedFormatForPresentation struct { 53 SDJWT string 54 Disclosures []string 55 HolderBinding string 56 } 57 58 // Serialize will assemble combined format for presentation. 59 func (cf *CombinedFormatForPresentation) Serialize() string { 60 presentation := cf.SDJWT 61 for _, disclosure := range cf.Disclosures { 62 presentation += CombinedFormatSeparator + disclosure 63 } 64 65 if len(cf.Disclosures) > 0 || cf.HolderBinding != "" { 66 presentation += CombinedFormatSeparator 67 } 68 69 presentation += cf.HolderBinding 70 71 return presentation 72 } 73 74 // DisclosureClaim defines claim. 
75 type DisclosureClaim struct { 76 Disclosure string 77 Salt string 78 Name string 79 Value interface{} 80 } 81 82 // GetDisclosureClaims de-codes disclosures. 83 func GetDisclosureClaims(disclosures []string) ([]*DisclosureClaim, error) { 84 var claims []*DisclosureClaim 85 86 for _, disclosure := range disclosures { 87 claim, err := getDisclosureClaim(disclosure) 88 if err != nil { 89 return nil, err 90 } 91 92 claims = append(claims, claim) 93 } 94 95 return claims, nil 96 } 97 98 func getDisclosureClaim(disclosure string) (*DisclosureClaim, error) { 99 decoded, err := base64.RawURLEncoding.DecodeString(disclosure) 100 if err != nil { 101 return nil, fmt.Errorf("failed to decode disclosure: %w", err) 102 } 103 104 var disclosureArr []interface{} 105 106 err = json.Unmarshal(decoded, &disclosureArr) 107 if err != nil { 108 return nil, fmt.Errorf("failed to unmarshal disclosure array: %w", err) 109 } 110 111 if len(disclosureArr) != disclosureParts { 112 return nil, fmt.Errorf("disclosure array size[%d] must be %d", len(disclosureArr), disclosureParts) 113 } 114 115 salt, ok := disclosureArr[saltIndex].(string) 116 if !ok { 117 return nil, fmt.Errorf("disclosure salt type[%T] must be string", disclosureArr[saltIndex]) 118 } 119 120 name, ok := disclosureArr[nameIndex].(string) 121 if !ok { 122 return nil, fmt.Errorf("disclosure name type[%T] must be string", disclosureArr[nameIndex]) 123 } 124 125 claim := &DisclosureClaim{Disclosure: disclosure, Salt: salt, Name: name, Value: disclosureArr[valueIndex]} 126 127 return claim, nil 128 } 129 130 // ParseCombinedFormatForIssuance parses combined format for issuance into CombinedFormatForIssuance parts. 
131 func ParseCombinedFormatForIssuance(combinedFormatForIssuance string) *CombinedFormatForIssuance { 132 parts := strings.Split(combinedFormatForIssuance, CombinedFormatSeparator) 133 134 var disclosures []string 135 if len(parts) > 1 { 136 disclosures = parts[1:] 137 } 138 139 sdJWT := parts[0] 140 141 return &CombinedFormatForIssuance{SDJWT: sdJWT, Disclosures: disclosures} 142 } 143 144 // ParseCombinedFormatForPresentation parses combined format for presentation into CombinedFormatForPresentation parts. 145 func ParseCombinedFormatForPresentation(combinedFormatForPresentation string) *CombinedFormatForPresentation { 146 parts := strings.Split(combinedFormatForPresentation, CombinedFormatSeparator) 147 148 var disclosures []string 149 if len(parts) > 2 { 150 disclosures = parts[1 : len(parts)-1] 151 } 152 153 var holderBinding string 154 if len(parts) > 1 { 155 holderBinding = parts[len(parts)-1] 156 } 157 158 sdJWT := parts[0] 159 160 return &CombinedFormatForPresentation{SDJWT: sdJWT, Disclosures: disclosures, HolderBinding: holderBinding} 161 } 162 163 // GetHash calculates hash of data using hash function identified by hash. 164 func GetHash(hash crypto.Hash, value string) (string, error) { 165 if !hash.Available() { 166 return "", fmt.Errorf("hash function not available for: %d", hash) 167 } 168 169 h := hash.New() 170 171 if _, hashErr := h.Write([]byte(value)); hashErr != nil { 172 return "", hashErr 173 } 174 175 result := h.Sum(nil) 176 177 return base64.RawURLEncoding.EncodeToString(result), nil 178 } 179 180 // VerifyDisclosuresInSDJWT checks for disclosure inclusion in SD-JWT. 181 func VerifyDisclosuresInSDJWT(disclosures []string, signedJWT *afgjwt.JSONWebToken) error { 182 claims := utils.CopyMap(signedJWT.Payload) 183 184 // check that the _sd_alg claim is present 185 // check that _sd_alg value is understood and the hash algorithm is deemed secure. 
186 cryptoHash, err := GetCryptoHashFromClaims(claims) 187 if err != nil { 188 return err 189 } 190 191 for _, disclosure := range disclosures { 192 digest, err := GetHash(cryptoHash, disclosure) 193 if err != nil { 194 return err 195 } 196 197 found, err := isDigestInClaims(digest, claims) 198 if err != nil { 199 return err 200 } 201 202 if !found { 203 return fmt.Errorf("disclosure digest '%s' not found in SD-JWT disclosure digests", digest) 204 } 205 } 206 207 return nil 208 } 209 210 func isDigestInClaims(digest string, claims map[string]interface{}) (bool, error) { 211 var found bool 212 213 digests, err := GetDisclosureDigests(claims) 214 if err != nil { 215 return false, err 216 } 217 218 for _, value := range claims { 219 if obj, ok := value.(map[string]interface{}); ok { 220 found, err = isDigestInClaims(digest, obj) 221 if err != nil { 222 return false, err 223 } 224 225 if found { 226 return found, nil 227 } 228 } 229 } 230 231 _, ok := digests[digest] 232 233 return ok, nil 234 } 235 236 // GetCryptoHashFromClaims returns crypto hash from claims. 237 func GetCryptoHashFromClaims(claims map[string]interface{}) (crypto.Hash, error) { 238 var cryptoHash crypto.Hash 239 240 // check that the _sd_alg claim is present 241 sdAlg, err := GetSDAlg(claims) 242 if err != nil { 243 return cryptoHash, err 244 } 245 246 // check that _sd_alg value is understood and the hash algorithm is deemed secure. 247 return GetCryptoHash(sdAlg) 248 } 249 250 // GetCryptoHash returns crypto hash from SD algorithm. 251 func GetCryptoHash(sdAlg string) (crypto.Hash, error) { 252 var err error 253 254 var cryptoHash crypto.Hash 255 256 // From spec: the hash algorithms MD2, MD4, MD5, RIPEMD-160, and SHA-1 revealed fundamental weaknesses 257 // and they MUST NOT be used. 
258 259 switch strings.ToUpper(sdAlg) { 260 case crypto.SHA256.String(): 261 cryptoHash = crypto.SHA256 262 case crypto.SHA384.String(): 263 cryptoHash = crypto.SHA384 264 case crypto.SHA512.String(): 265 cryptoHash = crypto.SHA512 266 default: 267 err = fmt.Errorf("%s '%s' not supported", SDAlgorithmKey, sdAlg) 268 } 269 270 return cryptoHash, err 271 } 272 273 // GetSDAlg returns SD algorithm from claims. 274 func GetSDAlg(claims map[string]interface{}) (string, error) { 275 var alg string 276 277 obj, ok := claims[SDAlgorithmKey] 278 if !ok { 279 // if claims contain 'vc' claim it may be present in vc 280 obj, ok = GetKeyFromVC(SDAlgorithmKey, claims) 281 if !ok { 282 return "", fmt.Errorf("%s must be present in SD-JWT", SDAlgorithmKey) 283 } 284 } 285 286 alg, ok = obj.(string) 287 if !ok { 288 return "", fmt.Errorf("%s must be a string", SDAlgorithmKey) 289 } 290 291 return alg, nil 292 } 293 294 // GetKeyFromVC returns key value from VC. 295 func GetKeyFromVC(key string, claims map[string]interface{}) (interface{}, bool) { 296 vcObj, ok := claims["vc"] 297 if !ok { 298 return nil, false 299 } 300 301 vc, ok := vcObj.(map[string]interface{}) 302 if !ok { 303 return nil, false 304 } 305 306 obj, ok := vc[key] 307 if !ok { 308 return nil, false 309 } 310 311 return obj, true 312 } 313 314 // GetCNF returns confirmation claim 'cnf'. 315 func GetCNF(claims map[string]interface{}) (map[string]interface{}, error) { 316 obj, ok := claims[CNFKey] 317 if !ok { 318 obj, ok = GetKeyFromVC(CNFKey, claims) 319 if !ok { 320 return nil, fmt.Errorf("%s must be present in SD-JWT", CNFKey) 321 } 322 } 323 324 cnf, ok := obj.(map[string]interface{}) 325 if !ok { 326 return nil, fmt.Errorf("%s must be an object", CNFKey) 327 } 328 329 return cnf, nil 330 } 331 332 // GetDisclosureDigests returns digests from claims map. 
333 func GetDisclosureDigests(claims map[string]interface{}) (map[string]bool, error) { 334 disclosuresObj, ok := claims[SDKey] 335 if !ok { 336 return nil, nil 337 } 338 339 disclosures, err := stringArray(disclosuresObj) 340 if err != nil { 341 return nil, fmt.Errorf("get disclosure digests: %w", err) 342 } 343 344 return SliceToMap(disclosures), nil 345 } 346 347 // GetDisclosedClaims returns disclosed claims only. 348 func GetDisclosedClaims(disclosureClaims []*DisclosureClaim, claims map[string]interface{}) (map[string]interface{}, error) { // nolint:lll 349 hash, err := GetCryptoHashFromClaims(claims) 350 if err != nil { 351 return nil, fmt.Errorf("failed to get crypto hash from claims: %w", err) 352 } 353 354 output := utils.CopyMap(claims) 355 includedDigests := make(map[string]bool) 356 357 err = processDisclosedClaims(disclosureClaims, output, includedDigests, hash) 358 if err != nil { 359 return nil, fmt.Errorf("failed to process disclosed claims: %w", err) 360 } 361 362 return output, nil 363 } 364 365 func processDisclosedClaims(disclosureClaims []*DisclosureClaim, claims map[string]interface{}, includedDigests map[string]bool, hash crypto.Hash) error { // nolint:lll 366 digests, err := GetDisclosureDigests(claims) 367 if err != nil { 368 return err 369 } 370 371 for key, value := range claims { 372 if obj, ok := value.(map[string]interface{}); ok { 373 err := processDisclosedClaims(disclosureClaims, obj, includedDigests, hash) 374 if err != nil { 375 return err 376 } 377 378 claims[key] = obj 379 } 380 } 381 382 for _, dc := range disclosureClaims { 383 digest, err := GetHash(hash, dc.Disclosure) 384 if err != nil { 385 return err 386 } 387 388 if _, ok := digests[digest]; !ok { 389 continue 390 } 391 392 _, digestAlreadyIncluded := includedDigests[digest] 393 if digestAlreadyIncluded { 394 // If there is more than one place where the digest is included, 395 // the Verifier MUST reject the Presentation. 
396 return fmt.Errorf("digest '%s' has been included in more than one place", digest) 397 } 398 399 err = validateClaim(dc, claims) 400 if err != nil { 401 return err 402 } 403 404 claims[dc.Name] = dc.Value 405 406 includedDigests[digest] = true 407 } 408 409 delete(claims, SDKey) 410 delete(claims, SDAlgorithmKey) 411 412 return nil 413 } 414 415 func validateClaim(dc *DisclosureClaim, claims map[string]interface{}) error { 416 _, claimNameExists := claims[dc.Name] 417 if claimNameExists { 418 // If the claim name already exists at the same level, the Verifier MUST reject the Presentation. 419 return fmt.Errorf("claim name '%s' already exists at the same level", dc.Name) 420 } 421 422 m, ok := getMap(dc.Value) 423 if ok { 424 if KeyExistsInMap(SDKey, m) { 425 // If the claim value contains an object with an _sd key (at the top level or nested deeper), 426 // the Verifier MUST reject the Presentation. 427 return fmt.Errorf("claim value contains an object with an '%s' key", SDKey) 428 } 429 } 430 431 return nil 432 } 433 434 func getMap(value interface{}) (map[string]interface{}, bool) { 435 val, ok := value.(map[string]interface{}) 436 437 return val, ok 438 } 439 440 func stringArray(entry interface{}) ([]string, error) { 441 if entry == nil { 442 return nil, nil 443 } 444 445 sliceValue := reflect.ValueOf(entry) 446 if sliceValue.Kind() != reflect.Slice { 447 return nil, fmt.Errorf("entry type[%T] is not an array", entry) 448 } 449 450 // Iterate over the slice and convert each element to a string 451 stringSlice := make([]string, sliceValue.Len()) 452 453 for i := 0; i < sliceValue.Len(); i++ { 454 sliceVal := sliceValue.Index(i).Interface() 455 val, ok := sliceVal.(string) 456 457 if !ok { 458 return nil, fmt.Errorf("entry item type[%T] is not a string", sliceVal) 459 } 460 461 stringSlice[i] = val 462 } 463 464 return stringSlice, nil 465 } 466 467 // SliceToMap converts slice to map. 
func SliceToMap(ids []string) map[string]bool {
	// The set has at most len(ids) entries, so size it up front. The result is
	// always a non-nil map, even for empty input.
	values := make(map[string]bool, len(ids))
	for _, id := range ids {
		values[id] = true
	}

	return values
}

// KeyExistsInMap reports whether key is present in m, or in any map nested
// inside m through map values. Values of other types, including slices, are
// not searched. A nil m yields false.
func KeyExistsInMap(key string, m map[string]interface{}) bool {
	// A direct lookup answers the common case without walking the children.
	if _, ok := m[key]; ok {
		return true
	}

	for _, v := range m {
		if nested, ok := v.(map[string]interface{}); ok && KeyExistsInMap(key, nested) {
			return true
		}
	}

	return false
}