github.com/emmansun/gmsm@v0.29.1/sm9/bn256/g1_test.go (about)

     1  package bn256
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"math/big"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  func TestG1AddNeg(t *testing.T) {
    14  	g1, g2 := &G1{}, &G1{}
    15  
    16  	g1.Neg(Gen1)
    17  	g2.Add(g1, Gen1)
    18  	if !g2.p.IsInfinity() {
    19  		t.Fail()
    20  	}
    21  	g3 := &G1{}
    22  	g3.Set(Gen1)
    23  	if !g3.Equal(Gen1) {
    24  		t.Fail()
    25  	}
    26  }
    27  
    28  func TestG1AddSame(t *testing.T) {
    29  	g1, g2 := &G1{}, &G1{}
    30  	g1.Add(Gen1, Gen1)
    31  	g2.Double(Gen1)
    32  
    33  	if !g1.Equal(g2) {
    34  		t.Fail()
    35  	}
    36  }
    37  
    38  func TestCurvePointDouble(t *testing.T) {
    39  	p := &curvePoint{}
    40  	p.Double(p)
    41  	if !p.IsInfinity() {
    42  		t.Fail()
    43  	}
    44  }
    45  
    46  func TestCurvePointDobuleComplete(t *testing.T) {
    47  	t.Parallel()
    48  	t.Run("normal case", func(t *testing.T) {
    49  		p2 := &curvePoint{}
    50  		p2.DoubleComplete(curveGen)
    51  		p2.AffineFromProjective()
    52  
    53  		p3 := &curvePoint{}
    54  		curvePointDouble(p3, curveGen)
    55  		p3.AffineFromJacobian()
    56  
    57  		if !p2.Equal(p3) {
    58  			t.Errorf("Got %v, expected %v", p2, p3)
    59  		}
    60  	})
    61  
    62  	t.Run("exception case: IsInfinity", func(t *testing.T) {
    63  		p1 := &curvePoint{}
    64  		p1.SetInfinity()
    65  		p2 := &curvePoint{}
    66  		p2.DoubleComplete(p1)
    67  		p2.AffineFromProjective()
    68  		if !p2.IsInfinity() {
    69  			t.Fatal("should be infinity")
    70  		}
    71  	})
    72  }
    73  
    74  func TestCurvePointAddComplete(t *testing.T) {
    75  	t.Parallel()
    76  	t.Run("normal case", func(t *testing.T) {
    77  		p1 := &curvePoint{}
    78  		curvePointDouble(p1, curveGen)
    79  		p1.AffineFromJacobian()
    80  
    81  		p2 := &curvePoint{}
    82  		p2.AddComplete(p1, curveGen)
    83  		p2.AffineFromProjective()
    84  
    85  		p3 := &curvePoint{}
    86  		curvePointAdd(p3, curveGen, p1)
    87  		p3.AffineFromJacobian()
    88  
    89  		if !p2.Equal(p3) {
    90  			t.Errorf("Got %v, expected %v", p2, p3)
    91  		}
    92  	})
    93  	t.Run("exception case: double", func(t *testing.T) {
    94  		p2 := &curvePoint{}
    95  		p2.AddComplete(curveGen, curveGen)
    96  		p2.AffineFromProjective()
    97  
    98  		p3 := &curvePoint{}
    99  		curvePointDouble(p3, curveGen)
   100  		p3.AffineFromJacobian()
   101  		if !p2.Equal(p3) {
   102  			t.Errorf("Got %v, expected %v", p2, p3)
   103  		}
   104  	})
   105  	t.Run("exception case: neg", func(t *testing.T) {
   106  		p1 := &curvePoint{}
   107  		p1.Neg(curveGen)
   108  		p2 := &curvePoint{}
   109  		p2.AddComplete(curveGen, p1)
   110  		p2.AffineFromProjective()
   111  		if !p2.IsInfinity() {
   112  			t.Fatal("should be infinity")
   113  		}
   114  	})
   115  	t.Run("exception case: IsInfinity", func(t *testing.T) {
   116  		p1 := &curvePoint{}
   117  		p1.SetInfinity()
   118  		p2 := &curvePoint{}
   119  		p2.AddComplete(curveGen, p1)
   120  		p2.AffineFromProjective()
   121  		if !p2.Equal(curveGen) {
   122  			t.Fatal("should be curveGen")
   123  		}
   124  		p2.AddComplete(p1, curveGen)
   125  		p2.AffineFromProjective()
   126  		if !p2.Equal(curveGen) {
   127  			t.Fatal("should be curveGen")
   128  		}
   129  		p2.AddComplete(p1, p1)
   130  		p2.AffineFromProjective()
   131  		if !p2.IsInfinity() {
   132  			t.Fatal("should be infinity")
   133  		}
   134  	})
   135  }
   136  
   137  type g1BaseMultTest struct {
   138  	k string
   139  }
   140  
   141  var baseMultTests = []g1BaseMultTest{
   142  	{
   143  		"112233445566778899",
   144  	},
   145  	{
   146  		"112233445566778899112233445566778899",
   147  	},
   148  	{
   149  		"6950511619965839450988900688150712778015737983940691968051900319680",
   150  	},
   151  	{
   152  		"13479972933410060327035789020509431695094902435494295338570602119423",
   153  	},
   154  	{
   155  		"13479971751745682581351455311314208093898607229429740618390390702079",
   156  	},
   157  	{
   158  		"13479972931865328106486971546324465392952975980343228160962702868479",
   159  	},
   160  	{
   161  		"11795773708834916026404142434151065506931607341523388140225443265536",
   162  	},
   163  	{
   164  		"784254593043826236572847595991346435467177662189391577090",
   165  	},
   166  	{
   167  		"13479767645505654746623887797783387853576174193480695826442858012671",
   168  	},
   169  	{
   170  		"205688069665150753842126177372015544874550518966168735589597183",
   171  	},
   172  	{
   173  		"13479966930919337728895168462090683249159702977113823384618282123295",
   174  	},
   175  	{
   176  		"50210731791415612487756441341851895584393717453129007497216",
   177  	},
   178  	{
   179  		"26959946667150639794667015087019625940457807714424391721682722368041",
   180  	},
   181  	{
   182  		"26959946667150639794667015087019625940457807714424391721682722368042",
   183  	},
   184  	{
   185  		"26959946667150639794667015087019625940457807714424391721682722368043",
   186  	},
   187  	{
   188  		"26959946667150639794667015087019625940457807714424391721682722368044",
   189  	},
   190  	{
   191  		"26959946667150639794667015087019625940457807714424391721682722368045",
   192  	},
   193  	{
   194  		"26959946667150639794667015087019625940457807714424391721682722368046",
   195  	},
   196  	{
   197  		"26959946667150639794667015087019625940457807714424391721682722368047",
   198  	},
   199  	{
   200  		"26959946667150639794667015087019625940457807714424391721682722368048",
   201  	},
   202  	{
   203  		"26959946667150639794667015087019625940457807714424391721682722368049",
   204  	},
   205  	{
   206  		"26959946667150639794667015087019625940457807714424391721682722368050",
   207  	},
   208  	{
   209  		"26959946667150639794667015087019625940457807714424391721682722368051",
   210  	},
   211  	{
   212  		"26959946667150639794667015087019625940457807714424391721682722368052",
   213  	},
   214  	{
   215  		"26959946667150639794667015087019625940457807714424391721682722368053",
   216  	},
   217  	{
   218  		"26959946667150639794667015087019625940457807714424391721682722368054",
   219  	},
   220  	{
   221  		"26959946667150639794667015087019625940457807714424391721682722368055",
   222  	},
   223  	{
   224  		"26959946667150639794667015087019625940457807714424391721682722368056",
   225  	},
   226  	{
   227  		"26959946667150639794667015087019625940457807714424391721682722368057",
   228  	},
   229  	{
   230  		"26959946667150639794667015087019625940457807714424391721682722368058",
   231  	},
   232  	{
   233  		"26959946667150639794667015087019625940457807714424391721682722368059",
   234  	},
   235  	{
   236  		"26959946667150639794667015087019625940457807714424391721682722368060",
   237  	},
   238  }
   239  
   240  func TestG1BaseMult(t *testing.T) {
   241  	g1 := g1Curve
   242  	g1Generic := g1.Params()
   243  
   244  	scalars := make([]*big.Int, 0, len(baseMultTests)+1)
   245  	for i := 1; i <= 20; i++ {
   246  		k := new(big.Int).SetInt64(int64(i))
   247  		scalars = append(scalars, k)
   248  	}
   249  	for _, e := range baseMultTests {
   250  		k, _ := new(big.Int).SetString(e.k, 10)
   251  		scalars = append(scalars, k)
   252  	}
   253  	k := new(big.Int).SetInt64(1)
   254  	k.Lsh(k, 500)
   255  	scalars = append(scalars, k)
   256  
   257  	for i, k := range scalars {
   258  		x, y := g1.ScalarBaseMult(k.Bytes())
   259  		x2, y2 := g1Generic.ScalarBaseMult(k.Bytes())
   260  		if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
   261  			t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, x, y, x2, y2)
   262  		}
   263  
   264  		if testing.Short() && i > 5 {
   265  			break
   266  		}
   267  	}
   268  }
   269  
   270  func TestG1ScalarMult(t *testing.T) {
   271  	checkScalar := func(t *testing.T, scalar []byte) {
   272  		p1, err := (&G1{}).ScalarBaseMult(scalar)
   273  		fatalIfErr(t, err)
   274  		p2, err := (&G1{}).ScalarMult(Gen1, scalar)
   275  		fatalIfErr(t, err)
   276  		p1.p.MakeAffine()
   277  		p2.p.MakeAffine()
   278  		if !p1.Equal(p2) {
   279  			t.Error("[k]G != ScalarBaseMult(k)")
   280  		}
   281  
   282  		d := new(big.Int).SetBytes(scalar)
   283  		d.Sub(Order, d)
   284  		d.Mod(d, Order)
   285  		g1, err := (&G1{}).ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
   286  		fatalIfErr(t, err)
   287  		g1.Add(g1, p1)
   288  		g1.p.MakeAffine()
   289  		if !g1.p.IsInfinity() {
   290  			t.Error("[N - k]G + [k]G != ∞")
   291  		}
   292  	}
   293  
   294  	byteLen := len(Order.Bytes())
   295  	bitLen := Order.BitLen()
   296  	t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
   297  	t.Run("1", func(t *testing.T) {
   298  		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
   299  	})
   300  	t.Run("N-6", func(t *testing.T) {
   301  		checkScalar(t, new(big.Int).Sub(Order, big.NewInt(6)).Bytes())
   302  	})
   303  	t.Run("N-1", func(t *testing.T) {
   304  		checkScalar(t, new(big.Int).Sub(Order, big.NewInt(1)).Bytes())
   305  	})
   306  	t.Run("N", func(t *testing.T) { checkScalar(t, Order.Bytes()) })
   307  	t.Run("N+1", func(t *testing.T) {
   308  		checkScalar(t, new(big.Int).Add(Order, big.NewInt(1)).Bytes())
   309  	})
   310  	t.Run("N+22", func(t *testing.T) {
   311  		checkScalar(t, new(big.Int).Add(Order, big.NewInt(22)).Bytes())
   312  	})
   313  	t.Run("all1s", func(t *testing.T) {
   314  		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
   315  		s.Sub(s, big.NewInt(1))
   316  		checkScalar(t, s.Bytes())
   317  	})
   318  	if testing.Short() {
   319  		return
   320  	}
   321  	for i := 0; i < bitLen; i++ {
   322  		t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
   323  			s := new(big.Int).Lsh(big.NewInt(1), uint(i))
   324  			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
   325  		})
   326  	}
   327  	for i := 0; i <= 64; i++ {
   328  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   329  			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
   330  		})
   331  	}
   332  
   333  	// Test N-64...N+64 since they risk overlapping with precomputed table values
   334  	// in the final additions.
   335  	for i := int64(-64); i <= 64; i++ {
   336  		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
   337  			checkScalar(t, new(big.Int).Add(Order, big.NewInt(i)).Bytes())
   338  		})
   339  	}
   340  
   341  }
   342  
   343  func fatalIfErr(t *testing.T, err error) {
   344  	t.Helper()
   345  	if err != nil {
   346  		t.Fatal(err)
   347  	}
   348  }
   349  
   350  func TestFuzz(t *testing.T) {
   351  	g1 := g1Curve
   352  	g1Generic := g1.Params()
   353  
   354  	var scalar1 [32]byte
   355  	var scalar2 [32]byte
   356  	var timeout *time.Timer
   357  
   358  	if testing.Short() {
   359  		timeout = time.NewTimer(10 * time.Millisecond)
   360  	} else {
   361  		timeout = time.NewTimer(2 * time.Second)
   362  	}
   363  
   364  	for {
   365  		select {
   366  		case <-timeout.C:
   367  			return
   368  		default:
   369  		}
   370  
   371  		io.ReadFull(rand.Reader, scalar1[:])
   372  		io.ReadFull(rand.Reader, scalar2[:])
   373  
   374  		x, y := g1.ScalarBaseMult(scalar1[:])
   375  		x2, y2 := g1Generic.ScalarBaseMult(scalar1[:])
   376  
   377  		xx, yy := g1.ScalarMult(x, y, scalar2[:])
   378  		xx2, yy2 := g1Generic.ScalarMult(x2, y2, scalar2[:])
   379  
   380  		if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
   381  			t.Fatalf("ScalarBaseMult does not match reference result with scalar: %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1)
   382  		}
   383  
   384  		if xx.Cmp(xx2) != 0 || yy.Cmp(yy2) != 0 {
   385  			t.Fatalf("ScalarMult does not match reference result with scalars: %x and %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1, scalar2)
   386  		}
   387  	}
   388  }
   389  
   390  func TestG1OnCurve(t *testing.T) {
   391  	if !g1Curve.IsOnCurve(g1Curve.Params().Gx, g1Curve.Params().Gy) {
   392  		t.Error("basepoint is not on the curve")
   393  	}
   394  }
   395  
   396  func TestOffCurve(t *testing.T) {
   397  	x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
   398  	if g1Curve.IsOnCurve(x, y) {
   399  		t.Errorf("point off curve is claimed to be on the curve")
   400  	}
   401  
   402  	byteLen := (g1Curve.Params().BitSize + 7) / 8
   403  	b := make([]byte, 1+2*byteLen)
   404  	b[0] = 4 // uncompressed point
   405  	x.FillBytes(b[1 : 1+byteLen])
   406  	y.FillBytes(b[1+byteLen : 1+2*byteLen])
   407  
   408  	x1, y1 := Unmarshal(g1Curve, b)
   409  	if x1 != nil || y1 != nil {
   410  		t.Errorf("unmarshaling a point not on the curve succeeded")
   411  	}
   412  }
   413  
   414  func isInfinity(x, y *big.Int) bool {
   415  	return x.Sign() == 0 && y.Sign() == 0
   416  }
   417  
   418  func TestInfinity(t *testing.T) {
   419  	x0, y0 := new(big.Int), new(big.Int)
   420  	xG, yG := g1Curve.Params().Gx, g1Curve.Params().Gy
   421  
   422  	if !isInfinity(g1Curve.ScalarMult(xG, yG, g1Curve.Params().N.Bytes())) {
   423  		t.Errorf("x^q != ∞")
   424  	}
   425  	if !isInfinity(g1Curve.ScalarMult(xG, yG, []byte{0})) {
   426  		t.Errorf("x^0 != ∞")
   427  	}
   428  
   429  	if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
   430  		t.Errorf("∞^k != ∞")
   431  	}
   432  	if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{0})) {
   433  		t.Errorf("∞^0 != ∞")
   434  	}
   435  
   436  	if !isInfinity(g1Curve.ScalarBaseMult(g1Curve.Params().N.Bytes())) {
   437  		t.Errorf("b^q != ∞")
   438  	}
   439  	if !isInfinity(g1Curve.ScalarBaseMult([]byte{0})) {
   440  		t.Errorf("b^0 != ∞")
   441  	}
   442  
   443  	if !isInfinity(g1Curve.Double(x0, y0)) {
   444  		t.Errorf("2∞ != ∞")
   445  	}
   446  	// There is no other point of order two on the NIST curves (as they have
   447  	// cofactor one), so Double can't otherwise return the point at infinity.
   448  
   449  	nMinusOne := new(big.Int).Sub(g1Curve.Params().N, big.NewInt(1))
   450  	x, y := g1Curve.ScalarMult(xG, yG, nMinusOne.Bytes())
   451  	x, y = g1Curve.Add(x, y, xG, yG)
   452  	if !isInfinity(x, y) {
   453  		t.Errorf("x^(q-1) + x != ∞")
   454  	}
   455  	x, y = g1Curve.Add(xG, yG, x0, y0)
   456  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   457  		t.Errorf("x+∞ != x")
   458  	}
   459  	x, y = g1Curve.Add(x0, y0, xG, yG)
   460  	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
   461  		t.Errorf("∞+x != x")
   462  	}
   463  
   464  	if !g1Curve.IsOnCurve(x0, y0) {
   465  		t.Errorf("IsOnCurve(∞) != true")
   466  	}
   467  	/*
   468  			if xx, yy := Unmarshal(g1Curve, Marshal(g1Curve, x0, y0)); xx == nil || yy == nil {
   469  				t.Errorf("Unmarshal(Marshal(∞)) did return an error")
   470  			}
   471  			// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
   472  			// two valid points with x = 0.
   473  			if xx, yy := Unmarshal(g1Curve, []byte{0x00}); xx != nil || yy != nil {
   474  				t.Errorf("Unmarshal(∞) did not return an error")
   475  			}
   476  
   477  		byteLen := (g1Curve.Params().BitSize + 7) / 8
   478  		buf := make([]byte, byteLen*2+1)
   479  		buf[0] = 4 // Uncompressed format.
   480  		if xx, yy := Unmarshal(g1Curve, buf); xx == nil || yy == nil {
   481  			t.Errorf("Unmarshal((0,0)) did return an error")
   482  		}
   483  	*/
   484  }
   485  
   486  func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
   487  	tests := []struct {
   488  		name  string
   489  		curve Curve
   490  	}{
   491  		{"g1", g1Curve},
   492  		{"g1/Params", g1Curve.params},
   493  	}
   494  	for _, test := range tests {
   495  		curve := test.curve
   496  		t.Run(test.name, func(t *testing.T) {
   497  			t.Parallel()
   498  			f(t, curve)
   499  		})
   500  	}
   501  }
   502  
   503  func TestMarshal(t *testing.T) {
   504  	testAllCurves(t, func(t *testing.T, curve Curve) {
   505  		_, x, y, err := GenerateKey(curve, rand.Reader)
   506  		if err != nil {
   507  			t.Fatal(err)
   508  		}
   509  		serialized := Marshal(curve, x, y)
   510  		xx, yy := Unmarshal(curve, serialized)
   511  		if xx == nil {
   512  			t.Fatal("failed to unmarshal")
   513  		}
   514  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   515  			t.Fatal("unmarshal returned different values")
   516  		}
   517  	})
   518  }
   519  
   520  func TestMarshalCompressed(t *testing.T) {
   521  	testAllCurves(t, func(t *testing.T, curve Curve) {
   522  		_, x, y, err := GenerateKey(curve, rand.Reader)
   523  		if err != nil {
   524  			t.Fatal(err)
   525  		}
   526  		testMarshalCompressed(t, curve, x, y, nil)
   527  	})
   528  }
   529  
   530  func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
   531  	if !curve.IsOnCurve(x, y) {
   532  		t.Fatal("invalid test point")
   533  	}
   534  	got := MarshalCompressed(curve, x, y)
   535  	if want != nil && !bytes.Equal(got, want) {
   536  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
   537  	}
   538  
   539  	X, Y := UnmarshalCompressed(curve, got)
   540  	if X == nil || Y == nil {
   541  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
   542  	}
   543  
   544  	if !curve.IsOnCurve(X, Y) {
   545  		t.Error("UnmarshalCompressed returned a point not on the curve")
   546  	}
   547  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
   548  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
   549  	}
   550  }
   551  func TestInvalidCoordinates(t *testing.T) {
   552  	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
   553  		if g1Curve.IsOnCurve(x, y) {
   554  			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
   555  		}
   556  	}
   557  
   558  	p := g1Curve.Params().P
   559  	_, x, y, _ := GenerateKey(g1Curve, rand.Reader)
   560  	xx, yy := new(big.Int), new(big.Int)
   561  
   562  	// Check if the sign is getting dropped.
   563  	xx.Neg(x)
   564  	checkIsOnCurveFalse("-x, y", xx, y)
   565  	yy.Neg(y)
   566  	checkIsOnCurveFalse("x, -y", x, yy)
   567  
   568  	// Check if negative values are reduced modulo P.
   569  	xx.Sub(x, p)
   570  	checkIsOnCurveFalse("x-P, y", xx, y)
   571  	yy.Sub(y, p)
   572  	checkIsOnCurveFalse("x, y-P", x, yy)
   573  
   574  	/*
   575  		// Check if positive values are reduced modulo P.
   576  		xx.Add(x, p)
   577  		checkIsOnCurveFalse("x+P, y", xx, y)
   578  		yy.Add(y, p)
   579  		checkIsOnCurveFalse("x, y+P", x, yy)
   580  	*/
   581  	// Check if the overflow is dropped.
   582  	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
   583  	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
   584  	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
   585  	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
   586  
   587  	// Check if P is treated like zero (if possible).
   588  	// y^2 = x^3 + B
   589  	// y = mod_sqrt(x^3 + B)
   590  	// y = mod_sqrt(B) if x = 0
   591  	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
   592  	if yy := new(big.Int).ModSqrt(g1Curve.Params().B, p); yy != nil {
   593  		if !g1Curve.IsOnCurve(big.NewInt(0), yy) {
   594  			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
   595  		}
   596  		checkIsOnCurveFalse("P, y", p, yy)
   597  	}
   598  }
   599  
   600  func TestLargeIsOnCurve(t *testing.T) {
   601  	large := big.NewInt(1)
   602  	large.Lsh(large, 1000)
   603  	if g1Curve.IsOnCurve(large, large) {
   604  		t.Errorf("(2^1000, 2^1000) is reported on the curve")
   605  	}
   606  }
   607  
   608  func Test_G1MarshalCompressed(t *testing.T) {
   609  	e, e2 := &G1{}, &G1{}
   610  	ret := e.MarshalCompressed()
   611  	_, err := e2.UnmarshalCompressed(ret)
   612  	if err != nil {
   613  		t.Fatal(err)
   614  	}
   615  	if !e2.p.IsInfinity() {
   616  		t.Errorf("not same")
   617  	}
   618  	e.p.Set(curveGen)
   619  	ret = e.MarshalCompressed()
   620  	_, err = e2.UnmarshalCompressed(ret)
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
   625  		t.Errorf("not same")
   626  	}
   627  	e.p.Neg(e.p)
   628  	ret = e.MarshalCompressed()
   629  	_, err = e2.UnmarshalCompressed(ret)
   630  	if err != nil {
   631  		t.Fatal(err)
   632  	}
   633  	if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
   634  		t.Errorf("not same")
   635  	}
   636  }
   637  
   638  func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
   639  	tests := []struct {
   640  		name  string
   641  		curve Curve
   642  	}{
   643  		{"sm9", g1Curve},
   644  		{"sm9Parmas", g1Curve.Params()},
   645  	}
   646  	for _, test := range tests {
   647  		curve := test.curve
   648  		b.Run(test.name, func(b *testing.B) {
   649  			f(b, curve)
   650  		})
   651  	}
   652  }
   653  
   654  func BenchmarkScalarBaseMult(b *testing.B) {
   655  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   656  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   657  		b.ReportAllocs()
   658  		b.ResetTimer()
   659  		for i := 0; i < b.N; i++ {
   660  			x, _ := curve.ScalarBaseMult(priv)
   661  			// Prevent the compiler from optimizing out the operation.
   662  			priv[0] ^= byte(x.Bits()[0])
   663  		}
   664  	})
   665  }
   666  
   667  func BenchmarkScalarMult(b *testing.B) {
   668  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   669  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   670  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   671  		b.ReportAllocs()
   672  		b.ResetTimer()
   673  		for i := 0; i < b.N; i++ {
   674  			x, y = curve.ScalarMult(x, y, priv)
   675  		}
   676  	})
   677  }
   678  
   679  func BenchmarkMarshalUnmarshal(b *testing.B) {
   680  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   681  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   682  		b.Run("Uncompressed", func(b *testing.B) {
   683  			b.ReportAllocs()
   684  			for i := 0; i < b.N; i++ {
   685  				buf := Marshal(curve, x, y)
   686  				xx, yy := Unmarshal(curve, buf)
   687  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   688  					b.Error("Unmarshal output different from Marshal input")
   689  				}
   690  			}
   691  		})
   692  		b.Run("Compressed", func(b *testing.B) {
   693  			b.ReportAllocs()
   694  			for i := 0; i < b.N; i++ {
   695  				buf := MarshalCompressed(curve, x, y)
   696  				xx, yy := UnmarshalCompressed(curve, buf)
   697  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   698  					b.Error("Unmarshal output different from Marshal input")
   699  				}
   700  			}
   701  		})
   702  	})
   703  }
   704  
   705  func BenchmarkAddPoint(b *testing.B) {
   706  	p1 := &curvePoint{}
   707  	curvePointDouble(p1, curveGen)
   708  	p1.AffineFromJacobian()
   709  	p2 := &curvePoint{}
   710  
   711  	b.Run("Add complete", func(b *testing.B) {
   712  		b.ReportAllocs()
   713  		for i := 0; i < b.N; i++ {
   714  			p2.AddComplete(curveGen, p1)
   715  		}
   716  	})
   717  
   718  	b.Run("Add traditional", func(b *testing.B) {
   719  		b.ReportAllocs()
   720  		for i := 0; i < b.N; i++ {
   721  			curvePointAdd(p2, curveGen, p1)
   722  		}
   723  	})
   724  }
   725  
   726  func BenchmarkDoublePoint(b *testing.B) {
   727  	p2 := &curvePoint{}
   728  
   729  	b.Run("Double complete", func(b *testing.B) {
   730  		b.ReportAllocs()
   731  		for i := 0; i < b.N; i++ {
   732  			p2.DoubleComplete(curveGen)
   733  		}
   734  	})
   735  
   736  	b.Run("Double traditional", func(b *testing.B) {
   737  		b.ReportAllocs()
   738  		for i := 0; i < b.N; i++ {
   739  			curvePointDouble(p2, curveGen)
   740  		}
   741  	})
   742  	
   743  }