github.com/onflow/flow-go/crypto@v0.24.8/ecdsa_test.go (about)

     1  //go:build !relic
     2  // +build !relic
     3  
     4  package crypto
     5  
     6  import (
     7  	"encoding/hex"
     8  	"testing"
     9  
    10  	"crypto/elliptic"
    11  	crand "crypto/rand"
    12  	"math/big"
    13  
    14  	"github.com/btcsuite/btcd/btcec/v2"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	"github.com/onflow/flow-go/crypto/hash"
    19  )
    20  
    21  var ecdsaCurves = []SigningAlgorithm{
    22  	ECDSAP256,
    23  	ECDSASecp256k1,
    24  }
    25  var ecdsaPrKeyLen = map[SigningAlgorithm]int{
    26  	ECDSAP256:      PrKeyLenECDSAP256,
    27  	ECDSASecp256k1: PrKeyLenECDSASecp256k1,
    28  }
    29  var ecdsaPubKeyLen = map[SigningAlgorithm]int{
    30  	ECDSAP256:      PubKeyLenECDSAP256,
    31  	ECDSASecp256k1: PubKeyLenECDSASecp256k1,
    32  }
    33  var ecdsaSigLen = map[SigningAlgorithm]int{
    34  	ECDSAP256:      SignatureLenECDSAP256,
    35  	ECDSASecp256k1: SignatureLenECDSASecp256k1,
    36  }
    37  
    38  // ECDSA tests
    39  func TestECDSA(t *testing.T) {
    40  
    41  	for _, curve := range ecdsaCurves {
    42  		t.Logf("Testing ECDSA for curve %s", curve)
    43  		// test key generation seed limits
    44  		testKeyGenSeed(t, curve, KeyGenSeedMinLen, KeyGenSeedMaxLen)
    45  		// test consistency
    46  		halg := hash.NewSHA3_256()
    47  		testGenSignVerify(t, curve, halg)
    48  	}
    49  }
    50  
    51  type dummyHasher struct{ size int }
    52  
    53  func newDummyHasher(size int) hash.Hasher               { return &dummyHasher{size} }
    54  func (d *dummyHasher) Algorithm() hash.HashingAlgorithm { return hash.UnknownHashingAlgorithm }
    55  func (d *dummyHasher) Size() int                        { return d.size }
    56  func (d *dummyHasher) ComputeHash([]byte) hash.Hash     { return make([]byte, d.size) }
    57  func (d *dummyHasher) Write([]byte) (int, error)        { return 0, nil }
    58  func (d *dummyHasher) SumHash() hash.Hash               { return make([]byte, d.size) }
    59  func (d *dummyHasher) Reset()                           {}
    60  
    61  func TestECDSAHasher(t *testing.T) {
    62  
    63  	for _, curve := range ecdsaCurves {
    64  
    65  		// generate a key pair
    66  		seed := make([]byte, KeyGenSeedMinLen)
    67  		n, err := crand.Read(seed)
    68  		require.Equal(t, n, KeyGenSeedMinLen)
    69  		require.NoError(t, err)
    70  		sk, err := GeneratePrivateKey(curve, seed)
    71  		require.NoError(t, err)
    72  		sig := make([]byte, ecdsaSigLen[curve])
    73  
    74  		// empty hasher
    75  		t.Run("Empty hasher", func(t *testing.T) {
    76  			_, err := sk.Sign(seed, nil)
    77  			assert.Error(t, err)
    78  			assert.True(t, IsNilHasherError(err))
    79  			_, err = sk.PublicKey().Verify(sig, seed, nil)
    80  			assert.Error(t, err)
    81  			assert.True(t, IsNilHasherError(err))
    82  		})
    83  
    84  		// hasher with large output size
    85  		t.Run("large size hasher is accepted", func(t *testing.T) {
    86  			dummy := newDummyHasher(500)
    87  			_, err := sk.Sign(seed, dummy)
    88  			assert.NoError(t, err)
    89  			_, err = sk.PublicKey().Verify(sig, seed, dummy)
    90  			assert.NoError(t, err)
    91  		})
    92  
    93  		// hasher with small output size
    94  		t.Run("small size hasher is rejected", func(t *testing.T) {
    95  			dummy := newDummyHasher(31) // 31 is one byte less than the supported curves' order
    96  			_, err := sk.Sign(seed, dummy)
    97  			assert.Error(t, err)
    98  			assert.True(t, IsInvalidHasherSizeError(err))
    99  			_, err = sk.PublicKey().Verify(sig, seed, dummy)
   100  			assert.Error(t, err)
   101  			assert.True(t, IsInvalidHasherSizeError(err))
   102  		})
   103  	}
   104  }
   105  
   106  // Signing bench
   107  func BenchmarkECDSAP256Sign(b *testing.B) {
   108  	halg := hash.NewSHA3_256()
   109  	benchSign(b, ECDSAP256, halg)
   110  }
   111  
   112  // Verifying bench
   113  func BenchmarkECDSAP256Verify(b *testing.B) {
   114  	halg := hash.NewSHA3_256()
   115  	benchVerify(b, ECDSAP256, halg)
   116  }
   117  
   118  // Signing bench
   119  func BenchmarkECDSASecp256k1Sign(b *testing.B) {
   120  	halg := hash.NewSHA3_256()
   121  	benchSign(b, ECDSASecp256k1, halg)
   122  }
   123  
   124  // Verifying bench
   125  func BenchmarkECDSASecp256k1Verify(b *testing.B) {
   126  	halg := hash.NewSHA3_256()
   127  	benchVerify(b, ECDSASecp256k1, halg)
   128  }
   129  
   130  // TestECDSAEncodeDecode tests encoding and decoding of ECDSA keys
   131  func TestECDSAEncodeDecode(t *testing.T) {
   132  	for _, curve := range ecdsaCurves {
   133  		testEncodeDecode(t, curve)
   134  	}
   135  }
   136  
   137  // TestECDSAEquals tests equal for ECDSA keys
   138  func TestECDSAEquals(t *testing.T) {
   139  	for i, curve := range ecdsaCurves {
   140  		testEquals(t, curve, ecdsaCurves[i]^1)
   141  	}
   142  }
   143  
   144  // TestECDSAUtils tests some utility functions
   145  func TestECDSAUtils(t *testing.T) {
   146  
   147  	for _, curve := range ecdsaCurves {
   148  		// generate a key pair
   149  		seed := make([]byte, KeyGenSeedMinLen)
   150  		n, err := crand.Read(seed)
   151  		require.Equal(t, n, KeyGenSeedMinLen)
   152  		require.NoError(t, err)
   153  		sk, err := GeneratePrivateKey(curve, seed)
   154  		require.NoError(t, err)
   155  		testKeysAlgorithm(t, sk, curve)
   156  		testKeySize(t, sk, ecdsaPrKeyLen[curve], ecdsaPubKeyLen[curve])
   157  	}
   158  }
   159  
   160  // TestScalarMult is a unit test of the scalar multiplication
   161  // This is only a sanity check meant to make sure the curve implemented
   162  // is checked against an independant test vector
   163  func TestScalarMult(t *testing.T) {
   164  	secp256k1 := secp256k1Instance.curve
   165  	p256 := p256Instance.curve
   166  	genericMultTests := []struct {
   167  		curve elliptic.Curve
   168  		Px    string
   169  		Py    string
   170  		k     string
   171  		Qx    string
   172  		Qy    string
   173  	}{
   174  		{
   175  			secp256k1,
   176  			"858a2ea2498449acf531128892f8ee5eb6d10cfb2f7ebfa851def0e0d8428742",
   177  			"015c59492d794a4f6a3ab3046eecfc85e223d1ce8571aa99b98af6838018286e",
   178  			"6e37a39c31a05181bf77919ace790efd0bdbcaf42b5a52871fc112fceb918c95",
   179  			"fea24b9a6acdd97521f850e782ef4a24f3ef672b5cd51f824499d708bb0c744d",
   180  			"5f0b6db1a2c851cb2959fab5ed36ad377e8b53f1f43b7923f1be21b316df1ea1",
   181  		},
   182  		{
   183  			p256,
   184  			"fa1a85f1ae436e9aa05baabe60eb83b2d7ff52e5766504fda4e18d2d25887481",
   185  			"f7cc347e1ac53f6720ffc511bfb23c2f04c764620be0baf8c44313e92d5404de",
   186  			"6e37a39c31a05181bf77919ace790efd0bdbcaf42b5a52871fc112fceb918c95",
   187  			"28a27fc352f315d5cc562cb0d97e5882b6393fd6571f7d394cc583e65b5c7ffe",
   188  			"4086d17a2d0d9dc365388c91ba2176de7acc5c152c1a8d04e14edc6edaebd772",
   189  		},
   190  	}
   191  
   192  	baseMultTests := []struct {
   193  		curve elliptic.Curve
   194  		k     string
   195  		Qx    string
   196  		Qy    string
   197  	}{
   198  		{
   199  			secp256k1,
   200  			"6e37a39c31a05181bf77919ace790efd0bdbcaf42b5a52871fc112fceb918c95",
   201  			"36f292f6c287b6e72ca8128465647c7f88730f84ab27a1e934dbd2da753930fa",
   202  			"39a09ddcf3d28fb30cc683de3fc725e095ec865c3d41aef6065044cb12b1ff61",
   203  		},
   204  		{
   205  			p256,
   206  			"6e37a39c31a05181bf77919ace790efd0bdbcaf42b5a52871fc112fceb918c95",
   207  			"78a80dfe190a6068be8ddf05644c32d2540402ffc682442f6a9eeb96125d8681",
   208  			"3789f92cf4afabf719aaba79ecec54b27e33a188f83158f6dd15ecb231b49808",
   209  		},
   210  	}
   211  
   212  	t.Run("scalar mult check", func(t *testing.T) {
   213  		for _, test := range genericMultTests {
   214  			Px, _ := new(big.Int).SetString(test.Px, 16)
   215  			Py, _ := new(big.Int).SetString(test.Py, 16)
   216  			k, _ := new(big.Int).SetString(test.k, 16)
   217  			Qx, _ := new(big.Int).SetString(test.Qx, 16)
   218  			Qy, _ := new(big.Int).SetString(test.Qy, 16)
   219  			Rx, Ry := test.curve.ScalarMult(Px, Py, k.Bytes())
   220  			assert.Equal(t, Rx.Cmp(Qx), 0)
   221  			assert.Equal(t, Ry.Cmp(Qy), 0)
   222  		}
   223  	})
   224  
   225  	t.Run("base scalar mult check", func(t *testing.T) {
   226  		for _, test := range baseMultTests {
   227  			k, _ := new(big.Int).SetString(test.k, 16)
   228  			Qx, _ := new(big.Int).SetString(test.Qx, 16)
   229  			Qy, _ := new(big.Int).SetString(test.Qy, 16)
   230  			// base mult
   231  			Rx, Ry := test.curve.ScalarBaseMult(k.Bytes())
   232  			assert.Equal(t, Rx.Cmp(Qx), 0)
   233  			assert.Equal(t, Ry.Cmp(Qy), 0)
   234  			// generic mult with base point
   235  			Px := new(big.Int).Set(test.curve.Params().Gx)
   236  			Py := new(big.Int).Set(test.curve.Params().Gy)
   237  			Rx, Ry = test.curve.ScalarMult(Px, Py, k.Bytes())
   238  			assert.Equal(t, Rx.Cmp(Qx), 0)
   239  			assert.Equal(t, Ry.Cmp(Qy), 0)
   240  		}
   241  	})
   242  }
   243  
   244  func TestSignatureFormatCheck(t *testing.T) {
   245  
   246  	for _, curve := range ecdsaCurves {
   247  		t.Run("valid signature", func(t *testing.T) {
   248  			len := ecdsaSigLen[curve]
   249  			sig := Signature(make([]byte, len))
   250  			_, err := crand.Read(sig)
   251  			require.NoError(t, err)
   252  			sig[len/2] = 0    // force s to be less than the curve order
   253  			sig[len-1] |= 1   // force s to be non zero
   254  			sig[0] = 0        // force r to be less than the curve order
   255  			sig[len/2-1] |= 1 // force r to be non zero
   256  			valid, err := SignatureFormatCheck(curve, sig)
   257  			assert.Nil(t, err)
   258  			assert.True(t, valid)
   259  		})
   260  
   261  		t.Run("invalid length", func(t *testing.T) {
   262  			len := ecdsaSigLen[curve]
   263  			shortSig := Signature(make([]byte, len/2))
   264  			valid, err := SignatureFormatCheck(curve, shortSig)
   265  			assert.Nil(t, err)
   266  			assert.False(t, valid)
   267  
   268  			longSig := Signature(make([]byte, len*2))
   269  			valid, err = SignatureFormatCheck(curve, longSig)
   270  			assert.Nil(t, err)
   271  			assert.False(t, valid)
   272  		})
   273  
   274  		t.Run("zero values", func(t *testing.T) {
   275  			// signature with a zero s
   276  			len := ecdsaSigLen[curve]
   277  			sig0s := Signature(make([]byte, len))
   278  			_, err := crand.Read(sig0s[:len/2])
   279  			require.NoError(t, err)
   280  
   281  			valid, err := SignatureFormatCheck(curve, sig0s)
   282  			assert.Nil(t, err)
   283  			assert.False(t, valid)
   284  
   285  			// signature with a zero r
   286  			sig0r := Signature(make([]byte, len))
   287  			_, err = crand.Read(sig0r[len/2:])
   288  			require.NoError(t, err)
   289  
   290  			valid, err = SignatureFormatCheck(curve, sig0r)
   291  			assert.Nil(t, err)
   292  			assert.False(t, valid)
   293  		})
   294  
   295  		t.Run("large values", func(t *testing.T) {
   296  			len := ecdsaSigLen[curve]
   297  			sigLargeS := Signature(make([]byte, len))
   298  			_, err := crand.Read(sigLargeS[:len/2])
   299  			require.NoError(t, err)
   300  			// make sure s is larger than the curve order
   301  			for i := len / 2; i < len; i++ {
   302  				sigLargeS[i] = 0xFF
   303  			}
   304  
   305  			valid, err := SignatureFormatCheck(curve, sigLargeS)
   306  			assert.Nil(t, err)
   307  			assert.False(t, valid)
   308  
   309  			sigLargeR := Signature(make([]byte, len))
   310  			_, err = crand.Read(sigLargeR[len/2:])
   311  			require.NoError(t, err)
   312  			// make sure s is larger than the curve order
   313  			for i := 0; i < len/2; i++ {
   314  				sigLargeR[i] = 0xFF
   315  			}
   316  
   317  			valid, err = SignatureFormatCheck(curve, sigLargeR)
   318  			assert.Nil(t, err)
   319  			assert.False(t, valid)
   320  		})
   321  	}
   322  }
   323  
   324  func TestEllipticUnmarshalSecp256k1(t *testing.T) {
   325  
   326  	testVectors := []string{
   327  		"028b10bf56476bf7da39a3286e29df389177a2fa0fca2d73348ff78887515d8da1", // IsOnCurve for elliptic returns false
   328  		"03d39427f07f680d202fe8504306eb29041aceaf4b628c2c69b0ec248155443166", // odd, IsOnCurve for elliptic returns false
   329  		"0267d1942a6cbe4daec242ea7e01c6cdb82dadb6e7077092deb55c845bf851433e", // arith of sqrt in elliptic doesn't match secp256k1
   330  		"0345d45eda6d087918b041453a96303b78c478dce89a4ae9b3c933a018888c5e06", // odd, arith of sqrt in elliptic doesn't match secp256k1
   331  	}
   332  
   333  	for _, testVector := range testVectors {
   334  
   335  		// get the compressed bytes
   336  		publicBytes, err := hex.DecodeString(testVector)
   337  		require.NoError(t, err)
   338  
   339  		// decompress, check that those are perfectly valid Secp256k1 public keys
   340  		retrieved, err := DecodePublicKeyCompressed(ECDSASecp256k1, publicBytes)
   341  		require.NoError(t, err)
   342  
   343  		// check the compression is canonical by re-compressing to the same bytes
   344  		require.Equal(t, retrieved.EncodeCompressed(), publicBytes)
   345  
   346  		// check that elliptic fails at decompressing them
   347  		x, y := elliptic.UnmarshalCompressed(btcec.S256(), publicBytes)
   348  		require.Nil(t, x)
   349  		require.Nil(t, y)
   350  	}
   351  }
   352  
   353  func BenchmarkECDSADecode(b *testing.B) {
   354  	// random message
   355  	seed := make([]byte, 50)
   356  	_, _ = crand.Read(seed)
   357  
   358  	for _, curve := range []SigningAlgorithm{ECDSASecp256k1, ECDSAP256} {
   359  		sk, _ := GeneratePrivateKey(curve, seed)
   360  		comp := sk.PublicKey().EncodeCompressed()
   361  		uncomp := sk.PublicKey().Encode()
   362  
   363  		b.Run("compressed point on "+curve.String(), func(b *testing.B) {
   364  			b.ResetTimer()
   365  			for i := 0; i < b.N; i++ {
   366  				_, err := DecodePublicKeyCompressed(curve, comp)
   367  				require.NoError(b, err)
   368  			}
   369  			b.StopTimer()
   370  		})
   371  
   372  		b.Run("uncompressed point on "+curve.String(), func(b *testing.B) {
   373  			b.ResetTimer()
   374  			for i := 0; i < b.N; i++ {
   375  				_, err := DecodePublicKey(curve, uncomp)
   376  				require.NoError(b, err)
   377  			}
   378  			b.StopTimer()
   379  		})
   380  	}
   381  }