github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/sdjwt/common/common.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package common
     8  
     9  import (
    10  	"crypto"
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"fmt"
    14  	"reflect"
    15  	"strings"
    16  
    17  	"github.com/hyperledger/aries-framework-go/pkg/common/utils"
    18  	afgjwt "github.com/hyperledger/aries-framework-go/pkg/doc/jwt"
    19  )
    20  
    21  // CombinedFormatSeparator is disclosure separator.
    22  const (
    23  	CombinedFormatSeparator = "~"
    24  
    25  	SDAlgorithmKey = "_sd_alg"
    26  	SDKey          = "_sd"
    27  	CNFKey         = "cnf"
    28  
    29  	disclosureParts = 3
    30  	saltIndex       = 0
    31  	nameIndex       = 1
    32  	valueIndex      = 2
    33  )
    34  
    35  // CombinedFormatForIssuance holds SD-JWT and disclosures.
    36  type CombinedFormatForIssuance struct {
    37  	SDJWT       string
    38  	Disclosures []string
    39  }
    40  
    41  // Serialize will assemble combined format for issuance.
    42  func (cf *CombinedFormatForIssuance) Serialize() string {
    43  	presentation := cf.SDJWT
    44  	for _, disclosure := range cf.Disclosures {
    45  		presentation += CombinedFormatSeparator + disclosure
    46  	}
    47  
    48  	return presentation
    49  }
    50  
    51  // CombinedFormatForPresentation holds SD-JWT, disclosures and optional holder binding info.
    52  type CombinedFormatForPresentation struct {
    53  	SDJWT         string
    54  	Disclosures   []string
    55  	HolderBinding string
    56  }
    57  
    58  // Serialize will assemble combined format for presentation.
    59  func (cf *CombinedFormatForPresentation) Serialize() string {
    60  	presentation := cf.SDJWT
    61  	for _, disclosure := range cf.Disclosures {
    62  		presentation += CombinedFormatSeparator + disclosure
    63  	}
    64  
    65  	if len(cf.Disclosures) > 0 || cf.HolderBinding != "" {
    66  		presentation += CombinedFormatSeparator
    67  	}
    68  
    69  	presentation += cf.HolderBinding
    70  
    71  	return presentation
    72  }
    73  
    74  // DisclosureClaim defines claim.
    75  type DisclosureClaim struct {
    76  	Disclosure string
    77  	Salt       string
    78  	Name       string
    79  	Value      interface{}
    80  }
    81  
    82  // GetDisclosureClaims de-codes disclosures.
    83  func GetDisclosureClaims(disclosures []string) ([]*DisclosureClaim, error) {
    84  	var claims []*DisclosureClaim
    85  
    86  	for _, disclosure := range disclosures {
    87  		claim, err := getDisclosureClaim(disclosure)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  
    92  		claims = append(claims, claim)
    93  	}
    94  
    95  	return claims, nil
    96  }
    97  
    98  func getDisclosureClaim(disclosure string) (*DisclosureClaim, error) {
    99  	decoded, err := base64.RawURLEncoding.DecodeString(disclosure)
   100  	if err != nil {
   101  		return nil, fmt.Errorf("failed to decode disclosure: %w", err)
   102  	}
   103  
   104  	var disclosureArr []interface{}
   105  
   106  	err = json.Unmarshal(decoded, &disclosureArr)
   107  	if err != nil {
   108  		return nil, fmt.Errorf("failed to unmarshal disclosure array: %w", err)
   109  	}
   110  
   111  	if len(disclosureArr) != disclosureParts {
   112  		return nil, fmt.Errorf("disclosure array size[%d] must be %d", len(disclosureArr), disclosureParts)
   113  	}
   114  
   115  	salt, ok := disclosureArr[saltIndex].(string)
   116  	if !ok {
   117  		return nil, fmt.Errorf("disclosure salt type[%T] must be string", disclosureArr[saltIndex])
   118  	}
   119  
   120  	name, ok := disclosureArr[nameIndex].(string)
   121  	if !ok {
   122  		return nil, fmt.Errorf("disclosure name type[%T] must be string", disclosureArr[nameIndex])
   123  	}
   124  
   125  	claim := &DisclosureClaim{Disclosure: disclosure, Salt: salt, Name: name, Value: disclosureArr[valueIndex]}
   126  
   127  	return claim, nil
   128  }
   129  
   130  // ParseCombinedFormatForIssuance parses combined format for issuance into CombinedFormatForIssuance parts.
   131  func ParseCombinedFormatForIssuance(combinedFormatForIssuance string) *CombinedFormatForIssuance {
   132  	parts := strings.Split(combinedFormatForIssuance, CombinedFormatSeparator)
   133  
   134  	var disclosures []string
   135  	if len(parts) > 1 {
   136  		disclosures = parts[1:]
   137  	}
   138  
   139  	sdJWT := parts[0]
   140  
   141  	return &CombinedFormatForIssuance{SDJWT: sdJWT, Disclosures: disclosures}
   142  }
   143  
   144  // ParseCombinedFormatForPresentation parses combined format for presentation into CombinedFormatForPresentation parts.
   145  func ParseCombinedFormatForPresentation(combinedFormatForPresentation string) *CombinedFormatForPresentation {
   146  	parts := strings.Split(combinedFormatForPresentation, CombinedFormatSeparator)
   147  
   148  	var disclosures []string
   149  	if len(parts) > 2 {
   150  		disclosures = parts[1 : len(parts)-1]
   151  	}
   152  
   153  	var holderBinding string
   154  	if len(parts) > 1 {
   155  		holderBinding = parts[len(parts)-1]
   156  	}
   157  
   158  	sdJWT := parts[0]
   159  
   160  	return &CombinedFormatForPresentation{SDJWT: sdJWT, Disclosures: disclosures, HolderBinding: holderBinding}
   161  }
   162  
   163  // GetHash calculates hash of data using hash function identified by hash.
   164  func GetHash(hash crypto.Hash, value string) (string, error) {
   165  	if !hash.Available() {
   166  		return "", fmt.Errorf("hash function not available for: %d", hash)
   167  	}
   168  
   169  	h := hash.New()
   170  
   171  	if _, hashErr := h.Write([]byte(value)); hashErr != nil {
   172  		return "", hashErr
   173  	}
   174  
   175  	result := h.Sum(nil)
   176  
   177  	return base64.RawURLEncoding.EncodeToString(result), nil
   178  }
   179  
   180  // VerifyDisclosuresInSDJWT checks for disclosure inclusion in SD-JWT.
   181  func VerifyDisclosuresInSDJWT(disclosures []string, signedJWT *afgjwt.JSONWebToken) error {
   182  	claims := utils.CopyMap(signedJWT.Payload)
   183  
   184  	// check that the _sd_alg claim is present
   185  	// check that _sd_alg value is understood and the hash algorithm is deemed secure.
   186  	cryptoHash, err := GetCryptoHashFromClaims(claims)
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	for _, disclosure := range disclosures {
   192  		digest, err := GetHash(cryptoHash, disclosure)
   193  		if err != nil {
   194  			return err
   195  		}
   196  
   197  		found, err := isDigestInClaims(digest, claims)
   198  		if err != nil {
   199  			return err
   200  		}
   201  
   202  		if !found {
   203  			return fmt.Errorf("disclosure digest '%s' not found in SD-JWT disclosure digests", digest)
   204  		}
   205  	}
   206  
   207  	return nil
   208  }
   209  
   210  func isDigestInClaims(digest string, claims map[string]interface{}) (bool, error) {
   211  	var found bool
   212  
   213  	digests, err := GetDisclosureDigests(claims)
   214  	if err != nil {
   215  		return false, err
   216  	}
   217  
   218  	for _, value := range claims {
   219  		if obj, ok := value.(map[string]interface{}); ok {
   220  			found, err = isDigestInClaims(digest, obj)
   221  			if err != nil {
   222  				return false, err
   223  			}
   224  
   225  			if found {
   226  				return found, nil
   227  			}
   228  		}
   229  	}
   230  
   231  	_, ok := digests[digest]
   232  
   233  	return ok, nil
   234  }
   235  
   236  // GetCryptoHashFromClaims returns crypto hash from claims.
   237  func GetCryptoHashFromClaims(claims map[string]interface{}) (crypto.Hash, error) {
   238  	var cryptoHash crypto.Hash
   239  
   240  	// check that the _sd_alg claim is present
   241  	sdAlg, err := GetSDAlg(claims)
   242  	if err != nil {
   243  		return cryptoHash, err
   244  	}
   245  
   246  	// check that _sd_alg value is understood and the hash algorithm is deemed secure.
   247  	return GetCryptoHash(sdAlg)
   248  }
   249  
   250  // GetCryptoHash returns crypto hash from SD algorithm.
   251  func GetCryptoHash(sdAlg string) (crypto.Hash, error) {
   252  	var err error
   253  
   254  	var cryptoHash crypto.Hash
   255  
   256  	// From spec: the hash algorithms MD2, MD4, MD5, RIPEMD-160, and SHA-1 revealed fundamental weaknesses
   257  	// and they MUST NOT be used.
   258  
   259  	switch strings.ToUpper(sdAlg) {
   260  	case crypto.SHA256.String():
   261  		cryptoHash = crypto.SHA256
   262  	case crypto.SHA384.String():
   263  		cryptoHash = crypto.SHA384
   264  	case crypto.SHA512.String():
   265  		cryptoHash = crypto.SHA512
   266  	default:
   267  		err = fmt.Errorf("%s '%s' not supported", SDAlgorithmKey, sdAlg)
   268  	}
   269  
   270  	return cryptoHash, err
   271  }
   272  
   273  // GetSDAlg returns SD algorithm from claims.
   274  func GetSDAlg(claims map[string]interface{}) (string, error) {
   275  	var alg string
   276  
   277  	obj, ok := claims[SDAlgorithmKey]
   278  	if !ok {
   279  		// if claims contain 'vc' claim it may be present in vc
   280  		obj, ok = GetKeyFromVC(SDAlgorithmKey, claims)
   281  		if !ok {
   282  			return "", fmt.Errorf("%s must be present in SD-JWT", SDAlgorithmKey)
   283  		}
   284  	}
   285  
   286  	alg, ok = obj.(string)
   287  	if !ok {
   288  		return "", fmt.Errorf("%s must be a string", SDAlgorithmKey)
   289  	}
   290  
   291  	return alg, nil
   292  }
   293  
   294  // GetKeyFromVC returns key value from VC.
   295  func GetKeyFromVC(key string, claims map[string]interface{}) (interface{}, bool) {
   296  	vcObj, ok := claims["vc"]
   297  	if !ok {
   298  		return nil, false
   299  	}
   300  
   301  	vc, ok := vcObj.(map[string]interface{})
   302  	if !ok {
   303  		return nil, false
   304  	}
   305  
   306  	obj, ok := vc[key]
   307  	if !ok {
   308  		return nil, false
   309  	}
   310  
   311  	return obj, true
   312  }
   313  
   314  // GetCNF returns confirmation claim 'cnf'.
   315  func GetCNF(claims map[string]interface{}) (map[string]interface{}, error) {
   316  	obj, ok := claims[CNFKey]
   317  	if !ok {
   318  		obj, ok = GetKeyFromVC(CNFKey, claims)
   319  		if !ok {
   320  			return nil, fmt.Errorf("%s must be present in SD-JWT", CNFKey)
   321  		}
   322  	}
   323  
   324  	cnf, ok := obj.(map[string]interface{})
   325  	if !ok {
   326  		return nil, fmt.Errorf("%s must be an object", CNFKey)
   327  	}
   328  
   329  	return cnf, nil
   330  }
   331  
   332  // GetDisclosureDigests returns digests from claims map.
   333  func GetDisclosureDigests(claims map[string]interface{}) (map[string]bool, error) {
   334  	disclosuresObj, ok := claims[SDKey]
   335  	if !ok {
   336  		return nil, nil
   337  	}
   338  
   339  	disclosures, err := stringArray(disclosuresObj)
   340  	if err != nil {
   341  		return nil, fmt.Errorf("get disclosure digests: %w", err)
   342  	}
   343  
   344  	return SliceToMap(disclosures), nil
   345  }
   346  
   347  // GetDisclosedClaims returns disclosed claims only.
   348  func GetDisclosedClaims(disclosureClaims []*DisclosureClaim, claims map[string]interface{}) (map[string]interface{}, error) { // nolint:lll
   349  	hash, err := GetCryptoHashFromClaims(claims)
   350  	if err != nil {
   351  		return nil, fmt.Errorf("failed to get crypto hash from claims: %w", err)
   352  	}
   353  
   354  	output := utils.CopyMap(claims)
   355  	includedDigests := make(map[string]bool)
   356  
   357  	err = processDisclosedClaims(disclosureClaims, output, includedDigests, hash)
   358  	if err != nil {
   359  		return nil, fmt.Errorf("failed to process disclosed claims: %w", err)
   360  	}
   361  
   362  	return output, nil
   363  }
   364  
   365  func processDisclosedClaims(disclosureClaims []*DisclosureClaim, claims map[string]interface{}, includedDigests map[string]bool, hash crypto.Hash) error { // nolint:lll
   366  	digests, err := GetDisclosureDigests(claims)
   367  	if err != nil {
   368  		return err
   369  	}
   370  
   371  	for key, value := range claims {
   372  		if obj, ok := value.(map[string]interface{}); ok {
   373  			err := processDisclosedClaims(disclosureClaims, obj, includedDigests, hash)
   374  			if err != nil {
   375  				return err
   376  			}
   377  
   378  			claims[key] = obj
   379  		}
   380  	}
   381  
   382  	for _, dc := range disclosureClaims {
   383  		digest, err := GetHash(hash, dc.Disclosure)
   384  		if err != nil {
   385  			return err
   386  		}
   387  
   388  		if _, ok := digests[digest]; !ok {
   389  			continue
   390  		}
   391  
   392  		_, digestAlreadyIncluded := includedDigests[digest]
   393  		if digestAlreadyIncluded {
   394  			// If there is more than one place where the digest is included,
   395  			// the Verifier MUST reject the Presentation.
   396  			return fmt.Errorf("digest '%s' has been included in more than one place", digest)
   397  		}
   398  
   399  		err = validateClaim(dc, claims)
   400  		if err != nil {
   401  			return err
   402  		}
   403  
   404  		claims[dc.Name] = dc.Value
   405  
   406  		includedDigests[digest] = true
   407  	}
   408  
   409  	delete(claims, SDKey)
   410  	delete(claims, SDAlgorithmKey)
   411  
   412  	return nil
   413  }
   414  
   415  func validateClaim(dc *DisclosureClaim, claims map[string]interface{}) error {
   416  	_, claimNameExists := claims[dc.Name]
   417  	if claimNameExists {
   418  		// If the claim name already exists at the same level, the Verifier MUST reject the Presentation.
   419  		return fmt.Errorf("claim name '%s' already exists at the same level", dc.Name)
   420  	}
   421  
   422  	m, ok := getMap(dc.Value)
   423  	if ok {
   424  		if KeyExistsInMap(SDKey, m) {
   425  			// If the claim value contains an object with an _sd key (at the top level or nested deeper),
   426  			// the Verifier MUST reject the Presentation.
   427  			return fmt.Errorf("claim value contains an object with an '%s' key", SDKey)
   428  		}
   429  	}
   430  
   431  	return nil
   432  }
   433  
   434  func getMap(value interface{}) (map[string]interface{}, bool) {
   435  	val, ok := value.(map[string]interface{})
   436  
   437  	return val, ok
   438  }
   439  
   440  func stringArray(entry interface{}) ([]string, error) {
   441  	if entry == nil {
   442  		return nil, nil
   443  	}
   444  
   445  	sliceValue := reflect.ValueOf(entry)
   446  	if sliceValue.Kind() != reflect.Slice {
   447  		return nil, fmt.Errorf("entry type[%T] is not an array", entry)
   448  	}
   449  
   450  	// Iterate over the slice and convert each element to a string
   451  	stringSlice := make([]string, sliceValue.Len())
   452  
   453  	for i := 0; i < sliceValue.Len(); i++ {
   454  		sliceVal := sliceValue.Index(i).Interface()
   455  		val, ok := sliceVal.(string)
   456  
   457  		if !ok {
   458  			return nil, fmt.Errorf("entry item type[%T] is not a string", sliceVal)
   459  		}
   460  
   461  		stringSlice[i] = val
   462  	}
   463  
   464  	return stringSlice, nil
   465  }
   466  
   467  // SliceToMap converts slice to map.
   468  func SliceToMap(ids []string) map[string]bool {
   469  	// convert slice to map
   470  	values := make(map[string]bool)
   471  	for _, id := range ids {
   472  		values[id] = true
   473  	}
   474  
   475  	return values
   476  }
   477  
   478  // KeyExistsInMap checks if key exists in map.
   479  func KeyExistsInMap(key string, m map[string]interface{}) bool {
   480  	for k, v := range m {
   481  		if k == key {
   482  			return true
   483  		}
   484  
   485  		if obj, ok := v.(map[string]interface{}); ok {
   486  			exists := KeyExistsInMap(key, obj)
   487  			if exists {
   488  				return true
   489  			}
   490  		}
   491  	}
   492  
   493  	return false
   494  }