github.com/koko1123/flow-go-1@v0.29.6/fvm/crypto/crypto_test.go (about)

     1  package crypto_test
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"testing"
     7  	"unicode/utf8"
     8  
     9  	"github.com/fxamacker/cbor/v2"
    10  	"github.com/onflow/cadence/runtime"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/koko1123/flow-go-1/fvm/crypto"
    15  	"github.com/koko1123/flow-go-1/fvm/errors"
    16  	"github.com/koko1123/flow-go-1/model/flow"
    17  	msig "github.com/koko1123/flow-go-1/module/signature"
    18  	gocrypto "github.com/onflow/flow-go/crypto"
    19  	"github.com/onflow/flow-go/crypto/hash"
    20  )
    21  
    22  func TestHashWithTag(t *testing.T) {
    23  	t.Run("tag too long", func(t *testing.T) {
    24  		algorithms := []hash.HashingAlgorithm{
    25  			hash.SHA2_256,
    26  			hash.SHA2_384,
    27  			hash.SHA3_256,
    28  			hash.SHA3_384,
    29  			hash.Keccak_256,
    30  		}
    31  
    32  		okTag := [flow.DomainTagLength / 2]byte{}   // tag does not exceed 32 bytes
    33  		longTag := [flow.DomainTagLength + 1]byte{} // tag larger that 32 bytes
    34  
    35  		for i, algorithm := range algorithms {
    36  			t.Run(fmt.Sprintf("algo %d: %v", i, algorithm), func(t *testing.T) {
    37  				_, err := crypto.HashWithTag(algorithm, string(longTag[:]), []byte("some data"))
    38  				require.Error(t, err)
    39  			})
    40  
    41  			t.Run(fmt.Sprintf("algo %d: %v - control (tag ok)", i, algorithm), func(t *testing.T) {
    42  				_, err := crypto.HashWithTag(algorithm, string(okTag[:]), []byte("some data"))
    43  				require.NoError(t, err)
    44  			})
    45  		}
    46  	})
    47  }
    48  
    49  func TestVerifySignatureFromRuntime(t *testing.T) {
    50  
    51  	// make sure the seed length is larger than miniumum seed lengths of all signature schemes
    52  	seedLength := 64
    53  
    54  	correctCombinations := map[runtime.SignatureAlgorithm]map[runtime.HashAlgorithm]struct{}{
    55  
    56  		runtime.SignatureAlgorithmBLS_BLS12_381: {
    57  			runtime.HashAlgorithmKMAC128_BLS_BLS12_381: {},
    58  		},
    59  		runtime.SignatureAlgorithmECDSA_P256: {
    60  			runtime.HashAlgorithmSHA2_256:   {},
    61  			runtime.HashAlgorithmSHA3_256:   {},
    62  			runtime.HashAlgorithmKECCAK_256: {},
    63  		},
    64  		runtime.SignatureAlgorithmECDSA_secp256k1: {
    65  			runtime.HashAlgorithmSHA2_256:   {},
    66  			runtime.HashAlgorithmSHA3_256:   {},
    67  			runtime.HashAlgorithmKECCAK_256: {},
    68  		},
    69  	}
    70  
    71  	t.Run("verify should fail on incorrect combinations", func(t *testing.T) {
    72  
    73  		signatureAlgos := []runtime.SignatureAlgorithm{
    74  			runtime.SignatureAlgorithmECDSA_P256,
    75  			runtime.SignatureAlgorithmECDSA_secp256k1,
    76  			runtime.SignatureAlgorithmBLS_BLS12_381,
    77  		}
    78  		hashAlgos := []runtime.HashAlgorithm{
    79  			runtime.HashAlgorithmSHA2_256,
    80  			runtime.HashAlgorithmSHA2_384,
    81  			runtime.HashAlgorithmSHA3_256,
    82  			runtime.HashAlgorithmSHA3_384,
    83  			runtime.HashAlgorithmKMAC128_BLS_BLS12_381,
    84  			runtime.HashAlgorithmKECCAK_256,
    85  		}
    86  
    87  		for _, s := range signatureAlgos {
    88  			for _, h := range hashAlgos {
    89  				t.Run(fmt.Sprintf("combination: %v, %v", s, h), func(t *testing.T) {
    90  					seed := make([]byte, seedLength)
    91  					rand.Read(seed)
    92  					pk, err := gocrypto.GeneratePrivateKey(crypto.RuntimeToCryptoSigningAlgorithm(s), seed)
    93  					require.NoError(t, err)
    94  
    95  					tag := "random_tag"
    96  					var hasher hash.Hasher
    97  					if h != runtime.HashAlgorithmKMAC128_BLS_BLS12_381 {
    98  						hasher, err = crypto.NewPrefixedHashing(crypto.RuntimeToCryptoHashingAlgorithm(h), tag)
    99  						require.NoError(t, err)
   100  					} else {
   101  						hasher = msig.NewBLSHasher(tag)
   102  					}
   103  
   104  					signature := make([]byte, 0)
   105  					sig, err := pk.Sign([]byte("some data"), hasher)
   106  					if _, shouldBeOk := correctCombinations[s][h]; shouldBeOk {
   107  						require.NoError(t, err)
   108  					}
   109  
   110  					if sig != nil {
   111  						signature = sig.Bytes()
   112  					}
   113  
   114  					ok, err := crypto.VerifySignatureFromRuntime(
   115  						signature,
   116  						tag,
   117  						[]byte("some data"),
   118  						pk.PublicKey().Encode(),
   119  						s,
   120  						h,
   121  					)
   122  
   123  					if _, shouldBeOk := correctCombinations[s][h]; shouldBeOk {
   124  						require.NoError(t, err)
   125  						require.True(t, ok)
   126  					} else {
   127  						require.Error(t, err)
   128  						require.False(t, ok)
   129  					}
   130  				})
   131  			}
   132  		}
   133  	})
   134  
   135  	t.Run("BLS tag combinations", func(t *testing.T) {
   136  		cases := []struct {
   137  			signTag   string
   138  			verifyTag string
   139  			require   func(t *testing.T, sigOk bool, err error)
   140  		}{
   141  			{
   142  				signTag:   "random_tag",
   143  				verifyTag: "random_tag",
   144  				require: func(t *testing.T, sigOk bool, err error) {
   145  					require.NoError(t, err)
   146  					require.True(t, sigOk)
   147  				},
   148  			},
   149  			{
   150  				signTag:   "",
   151  				verifyTag: "",
   152  				require: func(t *testing.T, sigOk bool, err error) {
   153  					require.NoError(t, err)
   154  					require.True(t, sigOk)
   155  				},
   156  			}, {
   157  				signTag:   "padding test",
   158  				verifyTag: "padding test" + string([]byte{0, 0, 0, 0, 0}),
   159  				require: func(t *testing.T, sigOk bool, err error) {
   160  					require.NoError(t, err)
   161  					require.False(t, sigOk)
   162  				},
   163  			}, {
   164  				signTag:   "valid tag",
   165  				verifyTag: "different valid tag",
   166  				require: func(t *testing.T, sigOk bool, err error) {
   167  					require.NoError(t, err)
   168  					require.False(t, sigOk)
   169  				},
   170  			}, {
   171  				signTag:   "a very large tag with more than thirty two bytes",
   172  				verifyTag: "a very large tag with more than thirty two bytes",
   173  				require: func(t *testing.T, sigOk bool, err error) {
   174  					require.NoError(t, err)
   175  					require.True(t, sigOk)
   176  				},
   177  			},
   178  		}
   179  
   180  		for _, c := range cases {
   181  			seed := make([]byte, seedLength)
   182  			rand.Read(seed)
   183  			pk, err := gocrypto.GeneratePrivateKey(gocrypto.BLSBLS12381, seed)
   184  			require.NoError(t, err)
   185  
   186  			hasher := msig.NewBLSHasher(string(c.signTag))
   187  			signature := make([]byte, 0)
   188  			sig, err := pk.Sign([]byte("some data"), hasher)
   189  			require.NoError(t, err)
   190  
   191  			if sig != nil {
   192  				signature = sig.Bytes()
   193  			}
   194  
   195  			ok, err := crypto.VerifySignatureFromRuntime(
   196  				signature,
   197  				string(c.verifyTag),
   198  				[]byte("some data"),
   199  				pk.PublicKey().Encode(),
   200  				runtime.SignatureAlgorithmBLS_BLS12_381,
   201  				runtime.HashAlgorithmKMAC128_BLS_BLS12_381,
   202  			)
   203  
   204  			c.require(t, ok, err)
   205  		}
   206  	})
   207  
   208  	t.Run("ECDSA tag combinations", func(t *testing.T) {
   209  
   210  		cases := []struct {
   211  			signTag   string
   212  			verifyTag string
   213  			require   func(t *testing.T, sigOk bool, err error)
   214  		}{
   215  			{
   216  				signTag:   "random_tag",
   217  				verifyTag: "random_tag",
   218  				require: func(t *testing.T, sigOk bool, err error) {
   219  					require.NoError(t, err)
   220  					require.True(t, sigOk)
   221  				},
   222  			},
   223  			{
   224  				signTag:   "",
   225  				verifyTag: "",
   226  				require: func(t *testing.T, sigOk bool, err error) {
   227  					require.NoError(t, err)
   228  					require.True(t, sigOk)
   229  				},
   230  			}, {
   231  				signTag:   "padding test",
   232  				verifyTag: "padding test" + string([]byte{0, 0, 0, 0, 0}),
   233  				require: func(t *testing.T, sigOk bool, err error) {
   234  					require.NoError(t, err)
   235  					require.True(t, sigOk)
   236  				},
   237  			}, {
   238  				signTag:   "valid tag",
   239  				verifyTag: "different valid tag",
   240  				require: func(t *testing.T, sigOk bool, err error) {
   241  					require.NoError(t, err)
   242  					require.False(t, sigOk)
   243  				},
   244  			}, {
   245  				signTag:   "valid tag",
   246  				verifyTag: "a very large tag with more than thirty two bytes",
   247  				require: func(t *testing.T, sigOk bool, err error) {
   248  					require.Error(t, err)
   249  					require.False(t, sigOk)
   250  				},
   251  			},
   252  		}
   253  
   254  		for _, c := range cases {
   255  			for s, hMaps := range correctCombinations {
   256  				if s == runtime.SignatureAlgorithmBLS_BLS12_381 {
   257  					// skip BLS to only cover ECDSA in this test
   258  					continue
   259  				}
   260  				for h := range hMaps {
   261  					t.Run(fmt.Sprintf("hash tag: %v, verify tag: %v [%v, %v]", c.signTag, c.verifyTag, s, h), func(t *testing.T) {
   262  
   263  						seed := make([]byte, seedLength)
   264  						rand.Read(seed)
   265  						pk, err := gocrypto.GeneratePrivateKey(crypto.RuntimeToCryptoSigningAlgorithm(s), seed)
   266  						require.NoError(t, err)
   267  
   268  						hasher, err := crypto.NewPrefixedHashing(crypto.RuntimeToCryptoHashingAlgorithm(h), c.signTag)
   269  						require.NoError(t, err)
   270  
   271  						data := []byte("some data")
   272  						sig, err := pk.Sign(data, hasher)
   273  						require.NoError(t, err)
   274  						signature := sig.Bytes()
   275  
   276  						ok, err := crypto.VerifySignatureFromRuntime(
   277  							signature,
   278  							c.verifyTag,
   279  							data,
   280  							pk.PublicKey().Encode(),
   281  							s,
   282  							h,
   283  						)
   284  
   285  						c.require(t, ok, err)
   286  					})
   287  				}
   288  			}
   289  		}
   290  	})
   291  }
   292  
   293  func TestVerifySignatureFromTransaction(t *testing.T) {
   294  
   295  	// make sure the seed length is larger than miniumum seed lengths of all signature schemes
   296  	seedLength := 64
   297  
   298  	correctCombinations := map[gocrypto.SigningAlgorithm]map[hash.HashingAlgorithm]struct{}{
   299  		gocrypto.ECDSAP256: {
   300  			hash.SHA2_256: {},
   301  			hash.SHA3_256: {},
   302  		},
   303  		gocrypto.ECDSASecp256k1: {
   304  			hash.SHA2_256: {},
   305  			hash.SHA3_256: {},
   306  		},
   307  	}
   308  
   309  	t.Run("verify should fail on incorrect combinations", func(t *testing.T) {
   310  
   311  		signatureAlgos := []gocrypto.SigningAlgorithm{
   312  			gocrypto.ECDSAP256,
   313  			gocrypto.ECDSASecp256k1,
   314  			gocrypto.BLSBLS12381,
   315  		}
   316  		hashAlgos := []hash.HashingAlgorithm{
   317  			hash.SHA2_256,
   318  			hash.SHA2_384,
   319  			hash.SHA3_256,
   320  			hash.SHA3_384,
   321  			hash.KMAC128,
   322  			hash.Keccak_256,
   323  		}
   324  
   325  		for _, s := range signatureAlgos {
   326  			for _, h := range hashAlgos {
   327  				t.Run(fmt.Sprintf("combination: %v, %v", s, h), func(t *testing.T) {
   328  					seed := make([]byte, seedLength)
   329  					rand.Read(seed)
   330  					sk, err := gocrypto.GeneratePrivateKey(s, seed)
   331  					require.NoError(t, err)
   332  
   333  					tag := string(flow.TransactionDomainTag[:])
   334  					var hasher hash.Hasher
   335  					if h != hash.KMAC128 {
   336  						hasher, err = crypto.NewPrefixedHashing(h, tag)
   337  						require.NoError(t, err)
   338  					} else {
   339  						hasher = msig.NewBLSHasher(tag)
   340  					}
   341  
   342  					signature := make([]byte, 0)
   343  					data := []byte("some_data")
   344  					sig, err := sk.Sign(data, hasher)
   345  					if _, shouldBeOk := correctCombinations[s][h]; shouldBeOk {
   346  						require.NoError(t, err)
   347  					}
   348  
   349  					if sig != nil {
   350  						signature = sig.Bytes()
   351  					}
   352  
   353  					ok, err := crypto.VerifySignatureFromTransaction(signature, data, sk.PublicKey(), h)
   354  
   355  					if _, shouldBeOk := correctCombinations[s][h]; shouldBeOk {
   356  						require.NoError(t, err)
   357  						require.True(t, ok)
   358  					} else {
   359  						require.Error(t, err)
   360  						require.False(t, ok)
   361  					}
   362  				})
   363  			}
   364  		}
   365  	})
   366  
   367  	t.Run("tag combinations", func(t *testing.T) {
   368  
   369  		cases := []struct {
   370  			signTag string
   371  			require func(t *testing.T, sigOk bool, err error)
   372  		}{
   373  			{
   374  				signTag: string(flow.TransactionDomainTag[:]),
   375  				require: func(t *testing.T, sigOk bool, err error) {
   376  					require.NoError(t, err)
   377  					require.True(t, sigOk)
   378  				},
   379  			},
   380  			{
   381  				signTag: "",
   382  				require: func(t *testing.T, sigOk bool, err error) {
   383  					require.NoError(t, err)
   384  					require.False(t, sigOk)
   385  				},
   386  			}, {
   387  				signTag: "random_tag",
   388  				require: func(t *testing.T, sigOk bool, err error) {
   389  					require.NoError(t, err)
   390  					require.False(t, sigOk)
   391  				},
   392  			},
   393  		}
   394  
   395  		for _, c := range cases {
   396  			for s, hMaps := range correctCombinations {
   397  				for h := range hMaps {
   398  					t.Run(fmt.Sprintf("sign tag: %v [%v, %v]", c.signTag, s, h), func(t *testing.T) {
   399  						seed := make([]byte, seedLength)
   400  						rand.Read(seed)
   401  						sk, err := gocrypto.GeneratePrivateKey(s, seed)
   402  						require.NoError(t, err)
   403  
   404  						hasher, err := crypto.NewPrefixedHashing(h, c.signTag)
   405  						require.NoError(t, err)
   406  
   407  						data := []byte("some data")
   408  						sig, err := sk.Sign(data, hasher)
   409  						require.NoError(t, err)
   410  						signature := sig.Bytes()
   411  
   412  						ok, err := crypto.VerifySignatureFromTransaction(signature, data, sk.PublicKey(), h)
   413  						c.require(t, ok, err)
   414  					})
   415  				}
   416  			}
   417  		}
   418  	})
   419  }
   420  
   421  func TestValidatePublicKey(t *testing.T) {
   422  
   423  	// make sure the seed length is larger than miniumum seed lengths of all signature schemes
   424  	seedLength := 64
   425  
   426  	validPublicKey := func(t *testing.T, s runtime.SignatureAlgorithm) []byte {
   427  		seed := make([]byte, seedLength)
   428  		rand.Read(seed)
   429  		pk, err := gocrypto.GeneratePrivateKey(crypto.RuntimeToCryptoSigningAlgorithm(s), seed)
   430  		require.NoError(t, err)
   431  		return pk.PublicKey().Encode()
   432  	}
   433  
   434  	t.Run("Unknown algorithm should return false", func(t *testing.T) {
   435  		err := crypto.ValidatePublicKey(runtime.SignatureAlgorithmUnknown, validPublicKey(t, runtime.SignatureAlgorithmECDSA_P256))
   436  		require.Error(t, err)
   437  	})
   438  
   439  	t.Run("valid public key should return true", func(t *testing.T) {
   440  		signatureAlgos := []runtime.SignatureAlgorithm{
   441  			runtime.SignatureAlgorithmECDSA_P256,
   442  			runtime.SignatureAlgorithmECDSA_secp256k1,
   443  			runtime.SignatureAlgorithmBLS_BLS12_381,
   444  		}
   445  		for i, s := range signatureAlgos {
   446  			t.Run(fmt.Sprintf("case %v: %v", i, s), func(t *testing.T) {
   447  				err := crypto.ValidatePublicKey(s, validPublicKey(t, s))
   448  				require.NoError(t, err)
   449  			})
   450  		}
   451  	})
   452  
   453  	t.Run("invalid public key should return false", func(t *testing.T) {
   454  		signatureAlgos := []runtime.SignatureAlgorithm{
   455  			runtime.SignatureAlgorithmECDSA_P256,
   456  			runtime.SignatureAlgorithmECDSA_secp256k1,
   457  			runtime.SignatureAlgorithmBLS_BLS12_381,
   458  		}
   459  		for i, s := range signatureAlgos {
   460  			t.Run(fmt.Sprintf("case %v: %v", i, s), func(t *testing.T) {
   461  				key := validPublicKey(t, s)
   462  				key[0] ^= 1 // alter one bit of the valid key
   463  
   464  				err := crypto.ValidatePublicKey(s, key)
   465  				require.Error(t, err)
   466  			})
   467  		}
   468  	})
   469  }
   470  
   471  func TestHashingAlgorithmConversion(t *testing.T) {
   472  	hashingAlgoMapping := map[runtime.HashAlgorithm]hash.HashingAlgorithm{
   473  		runtime.HashAlgorithmSHA2_256:              hash.SHA2_256,
   474  		runtime.HashAlgorithmSHA3_256:              hash.SHA3_256,
   475  		runtime.HashAlgorithmSHA2_384:              hash.SHA2_384,
   476  		runtime.HashAlgorithmSHA3_384:              hash.SHA3_384,
   477  		runtime.HashAlgorithmKMAC128_BLS_BLS12_381: hash.KMAC128,
   478  		runtime.HashAlgorithmKECCAK_256:            hash.Keccak_256,
   479  	}
   480  
   481  	for runtimeAlgo, cryptoAlgo := range hashingAlgoMapping {
   482  		assert.Equal(t, cryptoAlgo, crypto.RuntimeToCryptoHashingAlgorithm(runtimeAlgo))
   483  		assert.Equal(t, runtimeAlgo, crypto.CryptoToRuntimeHashingAlgorithm(cryptoAlgo))
   484  	}
   485  }
   486  
   487  func TestSigningAlgorithmConversion(t *testing.T) {
   488  	signingAlgoMapping := map[runtime.SignatureAlgorithm]gocrypto.SigningAlgorithm{
   489  		runtime.SignatureAlgorithmECDSA_P256:      gocrypto.ECDSAP256,
   490  		runtime.SignatureAlgorithmECDSA_secp256k1: gocrypto.ECDSASecp256k1,
   491  		runtime.SignatureAlgorithmBLS_BLS12_381:   gocrypto.BLSBLS12381,
   492  	}
   493  
   494  	for runtimeAlgo, cryptoAlgo := range signingAlgoMapping {
   495  		assert.Equal(t, cryptoAlgo, crypto.RuntimeToCryptoSigningAlgorithm(runtimeAlgo))
   496  		assert.Equal(t, runtimeAlgo, crypto.CryptoToRuntimeSigningAlgorithm(cryptoAlgo))
   497  	}
   498  }
   499  
   500  func TestVerifySignatureFromRuntime_error_handling_produces_valid_utf8_for_invalid_sign_algo(t *testing.T) {
   501  
   502  	invalidSignatureAlgo := runtime.SignatureAlgorithm(164)
   503  
   504  	_, err := crypto.VerifySignatureFromRuntime(
   505  		nil, "", nil, nil, invalidSignatureAlgo, 0,
   506  	)
   507  
   508  	require.True(t, errors.IsValueError(err))
   509  
   510  	require.Contains(t, err.Error(), fmt.Sprintf("%d", invalidSignatureAlgo))
   511  
   512  	errorString := err.Error()
   513  	assert.True(t, utf8.ValidString(errorString))
   514  
   515  	// check if they can encoded and decoded using CBOR
   516  	marshalledBytes, err := cbor.Marshal(errorString)
   517  	require.NoError(t, err)
   518  
   519  	var unmarshalledString string
   520  
   521  	err = cbor.Unmarshal(marshalledBytes, &unmarshalledString)
   522  	require.NoError(t, err)
   523  
   524  	require.Equal(t, errorString, unmarshalledString)
   525  }
   526  
   527  func TestVerifySignatureFromRuntime_error_handling_produces_valid_utf8_for_invalid_hash_algo(t *testing.T) {
   528  
   529  	invalidHashAlgo := runtime.HashAlgorithm(164)
   530  
   531  	_, err := crypto.VerifySignatureFromRuntime(
   532  		nil, "", nil, nil, runtime.SignatureAlgorithmECDSA_P256, invalidHashAlgo,
   533  	)
   534  
   535  	require.True(t, errors.IsValueError(err))
   536  
   537  	require.Contains(t, err.Error(), fmt.Sprintf("%d", invalidHashAlgo))
   538  
   539  	errorString := err.Error()
   540  	assert.True(t, utf8.ValidString(errorString))
   541  
   542  	// check if they can encoded and decoded using CBOR
   543  	marshalledBytes, err := cbor.Marshal(errorString)
   544  	require.NoError(t, err)
   545  
   546  	var unmarshalledString string
   547  
   548  	err = cbor.Unmarshal(marshalledBytes, &unmarshalledString)
   549  	require.NoError(t, err)
   550  
   551  	require.Equal(t, errorString, unmarshalledString)
   552  }
   553  
   554  func TestVerifySignatureFromRuntime_error_handling_produces_valid_utf8_for_invalid_public_key(t *testing.T) {
   555  
   556  	invalidPublicKey := []byte{0xc3, 0x28} // some invalid UTF8
   557  
   558  	_, err := crypto.VerifySignatureFromRuntime(
   559  		nil, "random_tag", nil, invalidPublicKey, runtime.SignatureAlgorithmECDSA_P256, runtime.HashAlgorithmSHA2_256,
   560  	)
   561  
   562  	require.True(t, errors.IsValueError(err))
   563  	errorString := err.Error()
   564  
   565  	require.Contains(t, errorString, fmt.Sprintf("%x", invalidPublicKey))
   566  
   567  	assert.True(t, utf8.ValidString(errorString))
   568  
   569  	// check if they can encoded and decoded using CBOR
   570  	marshalledBytes, err := cbor.Marshal(errorString)
   571  	require.NoError(t, err)
   572  
   573  	var unmarshalledString string
   574  
   575  	err = cbor.Unmarshal(marshalledBytes, &unmarshalledString)
   576  	require.NoError(t, err)
   577  
   578  	require.Equal(t, errorString, unmarshalledString)
   579  }