github.com/emmansun/gmsm@v0.29.1/sm2/sm2ec/elliptic_test.go (about)

     1  package sm2ec
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"encoding/hex"
     8  	"math/big"
     9  	"testing"
    10  )
    11  
    12  var _ = elliptic.P256() // force NIST P curves init, avoid panic when we invoke generic implementation's method
    13  
    14  // genericParamsForCurve returns the dereferenced CurveParams for
    15  // the specified curve. This is used to avoid the logic for
    16  // upgrading a curve to its specific implementation, forcing
    17  // usage of the generic implementation.
    18  func genericParamsForCurve(c elliptic.Curve) *elliptic.CurveParams {
    19  	d := *(c.Params())
    20  	return &d
    21  }
    22  
    23  func testAllCurves(t *testing.T, f func(*testing.T, elliptic.Curve)) {
    24  	tests := []struct {
    25  		name  string
    26  		curve elliptic.Curve
    27  	}{
    28  		{"SM2P256", P256()},
    29  		{"SM2P256/Params", genericParamsForCurve(P256())},
    30  	}
    31  	if testing.Short() {
    32  		tests = tests[:1]
    33  	}
    34  	for _, test := range tests {
    35  		curve := test.curve
    36  		t.Run(test.name, func(t *testing.T) {
    37  			t.Parallel()
    38  			f(t, curve)
    39  		})
    40  	}
    41  }
    42  
    43  func TestOnCurve(t *testing.T) {
    44  	testAllCurves(t, func(t *testing.T, curve elliptic.Curve) {
    45  		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
    46  			t.Error("basepoint is not on the curve")
    47  		}
    48  	})
    49  }
    50  
    51  func TestOffCurve(t *testing.T) {
    52  	testAllCurves(t, func(t *testing.T, curve elliptic.Curve) {
    53  		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
    54  		if curve.IsOnCurve(x, y) {
    55  			t.Errorf("point off curve is claimed to be on the curve")
    56  		}
    57  
    58  		byteLen := (curve.Params().BitSize + 7) / 8
    59  		b := make([]byte, 1+2*byteLen)
    60  		b[0] = 4 // uncompressed point
    61  		x.FillBytes(b[1 : 1+byteLen])
    62  		y.FillBytes(b[1+byteLen : 1+2*byteLen])
    63  
    64  		x1, y1 := Unmarshal(curve, b)
    65  		if x1 != nil || y1 != nil {
    66  			t.Errorf("unmarshaling a point not on the curve succeeded")
    67  		}
    68  	})
    69  }
    70  
    71  func TestInfinity(t *testing.T) {
    72  	testAllCurves(t, testInfinity)
    73  }
    74  
    75  func isInfinity(x, y *big.Int) bool {
    76  	return x.Sign() == 0 && y.Sign() == 0
    77  }
    78  
    79  func testInfinity(t *testing.T, curve elliptic.Curve) {
    80  	x0, y0 := new(big.Int), new(big.Int)
    81  	xG, yG := curve.Params().Gx, curve.Params().Gy
    82  
    83  	if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) {
    84  		t.Errorf("x^q != ∞")
    85  	}
    86  	if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) {
    87  		t.Errorf("x^0 != ∞")
    88  	}
    89  
    90  	if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
    91  		t.Errorf("∞^k != ∞")
    92  	}
    93  	if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) {
    94  		t.Errorf("∞^0 != ∞")
    95  	}
    96  
    97  	if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) {
    98  		t.Errorf("b^q != ∞")
    99  	}
   100  	if !isInfinity(curve.ScalarBaseMult([]byte{0})) {
   101  		t.Errorf("b^0 != ∞")
   102  	}
   103  
   104  	if !isInfinity(curve.Double(x0, y0)) {
   105  		t.Errorf("2∞ != ∞")
   106  	}
   107  	// There is no other point of order two on the NIST curves (as they have
   108  	// cofactor one), so Double can't otherwise return the point at infinity.
   109  
   110  	nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1))
   111  	x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes())
   112  	x, y = curve.Add(x, y, xG, yG)
   113  	if !isInfinity(x, y) {
   114  		t.Errorf("x^(q-1) + x != ∞")
   115  	}
   116  	x, y = curve.Add(xG, yG, x0, y0)
   117  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   118  		t.Errorf("x+∞ != x")
   119  	}
   120  	x, y = curve.Add(x0, y0, xG, yG)
   121  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   122  		t.Errorf("∞+x != x")
   123  	}
   124  
   125  	if curve.IsOnCurve(x0, y0) {
   126  		t.Errorf("IsOnCurve(∞) == true")
   127  	}
   128  
   129  	if xx, yy := Unmarshal(curve, elliptic.Marshal(curve, x0, y0)); xx != nil || yy != nil {
   130  		t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
   131  	}
   132  	// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
   133  	// two valid points with x = 0.
   134  	if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
   135  		t.Errorf("Unmarshal(∞) did not return an error")
   136  	}
   137  	byteLen := (curve.Params().BitSize + 7) / 8
   138  	buf := make([]byte, byteLen*2+1)
   139  	buf[0] = 4 // Uncompressed format.
   140  	if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil {
   141  		t.Errorf("Unmarshal((0,0)) did not return an error")
   142  	}
   143  }
   144  
   145  func TestMarshal(t *testing.T) {
   146  	testAllCurves(t, func(t *testing.T, curve elliptic.Curve) {
   147  		_, x, y, err := elliptic.GenerateKey(curve, rand.Reader)
   148  		if err != nil {
   149  			t.Fatal(err)
   150  		}
   151  		serialized := elliptic.Marshal(curve, x, y)
   152  		xx, yy := Unmarshal(curve, serialized)
   153  		if xx == nil {
   154  			t.Fatal("failed to unmarshal")
   155  		}
   156  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   157  			t.Fatal("unmarshal returned different values")
   158  		}
   159  	})
   160  }
   161  
   162  // TestInvalidCoordinates tests big.Int values that are not valid field elements
   163  // (negative or bigger than P). They are expected to return false from
   164  // IsOnCurve, all other behavior is undefined.
   165  func TestInvalidCoordinates(t *testing.T) {
   166  	testAllCurves(t, testInvalidCoordinates)
   167  }
   168  
   169  func testInvalidCoordinates(t *testing.T, curve elliptic.Curve) {
   170  	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
   171  		if curve.IsOnCurve(x, y) {
   172  			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
   173  		}
   174  	}
   175  
   176  	p := curve.Params().P
   177  	_, x, y, _ := elliptic.GenerateKey(curve, rand.Reader)
   178  	xx, yy := new(big.Int), new(big.Int)
   179  
   180  	// Check if the sign is getting dropped.
   181  	xx.Neg(x)
   182  	checkIsOnCurveFalse("-x, y", xx, y)
   183  	yy.Neg(y)
   184  	checkIsOnCurveFalse("x, -y", x, yy)
   185  
   186  	// Check if negative values are reduced modulo P.
   187  	xx.Sub(x, p)
   188  	checkIsOnCurveFalse("x-P, y", xx, y)
   189  	yy.Sub(y, p)
   190  	checkIsOnCurveFalse("x, y-P", x, yy)
   191  
   192  	// Check if positive values are reduced modulo P.
   193  	xx.Add(x, p)
   194  	checkIsOnCurveFalse("x+P, y", xx, y)
   195  	yy.Add(y, p)
   196  	checkIsOnCurveFalse("x, y+P", x, yy)
   197  
   198  	// Check if the overflow is dropped.
   199  	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
   200  	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
   201  	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
   202  	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
   203  
   204  	// Check if P is treated like zero (if possible).
   205  	// y^2 = x^3 - 3x + B
   206  	// y = mod_sqrt(x^3 - 3x + B)
   207  	// y = mod_sqrt(B) if x = 0
   208  	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
   209  	if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
   210  		if !curve.IsOnCurve(big.NewInt(0), yy) {
   211  			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
   212  		}
   213  		checkIsOnCurveFalse("P, y", p, yy)
   214  	}
   215  }
   216  
   217  func TestMarshalCompressed(t *testing.T) {
   218  	t.Run("P-256/03", func(t *testing.T) {
   219  		data, _ := hex.DecodeString("031b5709a068f5c1d05d0a61c0c70a13310df2d3a6c2ca9c9bba53337ea3e10de3")
   220  		x, _ := new(big.Int).SetString("1b5709a068f5c1d05d0a61c0c70a13310df2d3a6c2ca9c9bba53337ea3e10de3", 16)
   221  		y, _ := new(big.Int).SetString("a7ac81d1fdd4fcd224bbd95183136f948861812594ef24bd867c23d955fee3bb", 16)
   222  		testMarshalCompressed(t, P256(), x, y, data)
   223  	})
   224  	t.Run("P-256/02", func(t *testing.T) {
   225  		data, _ := hex.DecodeString("0258f9a2efca4139f2b07662b937439a719ea3bf59d7de346c365db7c85d4bc32a")
   226  		x, _ := new(big.Int).SetString("58f9a2efca4139f2b07662b937439a719ea3bf59d7de346c365db7c85d4bc32a", 16)
   227  		y, _ := new(big.Int).SetString("02680fbe48b1d8cf023d0b7c1d9ab9b56535384729db5fcb8db29ec72c7fc9ca", 16)
   228  		testMarshalCompressed(t, P256(), x, y, data)
   229  	})
   230  
   231  	t.Run("Invalid", func(t *testing.T) {
   232  		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
   233  		X, Y := UnmarshalCompressed(P256(), data)
   234  		if X != nil || Y != nil {
   235  			t.Error("expected an error for invalid encoding")
   236  		}
   237  	})
   238  
   239  	if testing.Short() {
   240  		t.Skip("skipping other curves on short test")
   241  	}
   242  
   243  	testAllCurves(t, func(t *testing.T, curve elliptic.Curve) {
   244  		_, x, y, err := elliptic.GenerateKey(curve, rand.Reader)
   245  		if err != nil {
   246  			t.Fatal(err)
   247  		}
   248  		testMarshalCompressed(t, curve, x, y, nil)
   249  	})
   250  }
   251  
   252  func testMarshalCompressed(t *testing.T, curve elliptic.Curve, x, y *big.Int, want []byte) {
   253  	if !curve.IsOnCurve(x, y) {
   254  		t.Fatal("invalid test point")
   255  	}
   256  	got := elliptic.MarshalCompressed(curve, x, y)
   257  	if want != nil && !bytes.Equal(got, want) {
   258  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
   259  	}
   260  
   261  	X, Y := UnmarshalCompressed(curve, got)
   262  	if X == nil || Y == nil {
   263  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
   264  	}
   265  
   266  	if !curve.IsOnCurve(X, Y) {
   267  		t.Error("UnmarshalCompressed returned a point not on the curve")
   268  	}
   269  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
   270  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
   271  	}
   272  }
   273  
   274  func TestLargeIsOnCurve(t *testing.T) {
   275  	testAllCurves(t, func(t *testing.T, curve elliptic.Curve) {
   276  		large := big.NewInt(1)
   277  		large.Lsh(large, 1000)
   278  		if curve.IsOnCurve(large, large) {
   279  			t.Errorf("(2^1000, 2^1000) is reported on the curve")
   280  		}
   281  	})
   282  }
   283  
   284  func benchmarkAllCurves(b *testing.B, f func(*testing.B, elliptic.Curve)) {
   285  	tests := []struct {
   286  		name  string
   287  		curve elliptic.Curve
   288  	}{
   289  		{"SM2P256", P256()},
   290  	}
   291  	for _, test := range tests {
   292  		curve := test.curve
   293  		b.Run(test.name, func(b *testing.B) {
   294  			f(b, curve)
   295  		})
   296  	}
   297  }
   298  
   299  func BenchmarkScalarBaseMult(b *testing.B) {
   300  	benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) {
   301  		priv, _, _, _ := elliptic.GenerateKey(curve, rand.Reader)
   302  		b.ReportAllocs()
   303  		b.ResetTimer()
   304  		for i := 0; i < b.N; i++ {
   305  			x, _ := curve.ScalarBaseMult(priv)
   306  			// Prevent the compiler from optimizing out the operation.
   307  			priv[0] ^= byte(x.Bits()[0])
   308  		}
   309  	})
   310  }
   311  
   312  func BenchmarkScalarMult(b *testing.B) {
   313  	benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) {
   314  		_, x, y, _ := elliptic.GenerateKey(curve, rand.Reader)
   315  		priv, _, _, _ := elliptic.GenerateKey(curve, rand.Reader)
   316  		b.ReportAllocs()
   317  		b.ResetTimer()
   318  		for i := 0; i < b.N; i++ {
   319  			x, y = curve.ScalarMult(x, y, priv)
   320  		}
   321  	})
   322  }
   323  
   324  func BenchmarkMarshalUnmarshal(b *testing.B) {
   325  	benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) {
   326  		_, x, y, _ := elliptic.GenerateKey(curve, rand.Reader)
   327  		b.Run("Uncompressed", func(b *testing.B) {
   328  			b.ReportAllocs()
   329  			for i := 0; i < b.N; i++ {
   330  				buf := elliptic.Marshal(curve, x, y)
   331  				xx, yy := Unmarshal(curve, buf)
   332  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   333  					b.Error("Unmarshal output different from Marshal input")
   334  				}
   335  			}
   336  		})
   337  		b.Run("Compressed", func(b *testing.B) {
   338  			b.ReportAllocs()
   339  			for i := 0; i < b.N; i++ {
   340  				buf := elliptic.MarshalCompressed(curve, x, y)
   341  				xx, yy := UnmarshalCompressed(curve, buf)
   342  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   343  					b.Error("Unmarshal output different from Marshal input")
   344  				}
   345  			}
   346  		})
   347  	})
   348  }