github.com/emmansun/gmsm@v0.29.1/internal/sm2ec/sm2ec_test.go (about)

     1  package sm2ec
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"math/big"
     9  	"testing"
    10  )
    11  
    12  // r = 2^256
    13  var r = bigFromHex("010000000000000000000000000000000000000000000000000000000000000000")
    14  var r0 = bigFromHex("010000000000000000")
    15  var sm2Prime = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF")
    16  var sm2n = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123")
    17  var nistP256Prime = bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951")
    18  var nistP256N = bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369")
    19  
    20  func generateMontgomeryDomain(in *big.Int, p *big.Int) *big.Int {
    21  	tmp := new(big.Int)
    22  	tmp = tmp.Mul(in, r)
    23  	return tmp.Mod(tmp, p)
    24  }
    25  
    26  func bigFromHex(s string) *big.Int {
    27  	b, ok := new(big.Int).SetString(s, 16)
    28  	if !ok {
    29  		panic("sm2ec: internal error: invalid encoding")
    30  	}
    31  	return b
    32  }
    33  
    34  func bigFromDecimal(s string) *big.Int {
    35  	b, ok := new(big.Int).SetString(s, 10)
    36  	if !ok {
    37  		panic("sm2ec: internal error: invalid encoding")
    38  	}
    39  	return b
    40  }
    41  
    42  func TestSM2P256MontgomeryDomain(t *testing.T) {
    43  	tests := []struct {
    44  		in  string
    45  		out string
    46  	}{
    47  		{ // One
    48  			"01",
    49  			"0000000100000000000000000000000000000000ffffffff0000000000000001",
    50  		},
    51  		{ // Gx
    52  			"32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7",
    53  			"91167a5ee1c13b05d6a1ed99ac24c3c33e7981eddca6c05061328990f418029e",
    54  		},
    55  		{ // Gy
    56  			"BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0",
    57  			"63cd65d481d735bd8d4cfb066e2a48f8c1f5e5788d3295fac1354e593c2d0ddd",
    58  		},
    59  		{ // B
    60  			"28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93",
    61  			"240fe188ba20e2c8527981505ea51c3c71cf379ae9b537ab90d230632bc0dd42",
    62  		},
    63  		{ // R
    64  			"010000000000000000000000000000000000000000000000000000000000000000",
    65  			"0400000002000000010000000100000002ffffffff0000000200000003",
    66  		},
    67  	}
    68  	for _, test := range tests {
    69  		out := generateMontgomeryDomain(bigFromHex(test.in), sm2Prime)
    70  		if out.Cmp(bigFromHex(test.out)) != 0 {
    71  			t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes()))
    72  		}
    73  	}
    74  }
    75  
    76  func TestSM2P256MontgomeryDomainN(t *testing.T) {
    77  	tests := []struct {
    78  		in  string
    79  		out string
    80  	}{
    81  		{ // One
    82  			"01",
    83  			"010000000000000000000000008dfc2094de39fad4ac440bf6c62abedd",
    84  		},
    85  		{ // R
    86  			"010000000000000000000000000000000000000000000000000000000000000000",
    87  			"1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20",
    88  		},
    89  	}
    90  	for _, test := range tests {
    91  		out := generateMontgomeryDomain(bigFromHex(test.in), sm2n)
    92  		if out.Cmp(bigFromHex(test.out)) != 0 {
    93  			t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes()))
    94  		}
    95  	}
    96  }
    97  
    98  func TestSM2P256MontgomeryK0(t *testing.T) {
    99  	tests := []struct {
   100  		in  *big.Int
   101  		out string
   102  	}{
   103  		{
   104  			sm2n,
   105  			"327f9e8872350975",
   106  		},
   107  		{
   108  			sm2Prime,
   109  			"0000000000000001",
   110  		},
   111  	}
   112  	for _, test := range tests {
   113  		// k0 = -in^(-1) mod 2^64
   114  		k0 := new(big.Int).ModInverse(test.in, r0)
   115  		k0.Neg(k0)
   116  		k0.Mod(k0, r0)
   117  		if k0.Cmp(bigFromHex(test.out)) != 0 {
   118  			t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(k0.Bytes()))
   119  		}
   120  	}
   121  }
   122  
   123  func TestNISTP256MontgomeryDomain(t *testing.T) {
   124  	tests := []struct {
   125  		in  string
   126  		out string
   127  	}{
   128  		{ // One
   129  			"01",
   130  			"fffffffeffffffffffffffffffffffff000000000000000000000001",
   131  		},
   132  		{ // Gx
   133  			"6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296",
   134  			"18905f76a53755c679fb732b7762251075ba95fc5fedb60179e730d418a9143c",
   135  		},
   136  		{ // Gy
   137  			"4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5",
   138  			"8571ff1825885d85d2e88688dd21f3258b4ab8e4ba19e45cddf25357ce95560a",
   139  		},
   140  		{ // B
   141  			"5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b",
   142  			"dc30061d04874834e5a220abf7212ed6acf005cd78843090d89cdf6229c4bddf",
   143  		},
   144  		{ // R
   145  			"010000000000000000000000000000000000000000000000000000000000000000",
   146  			"04fffffffdfffffffffffffffefffffffbffffffff0000000000000003",
   147  		},
   148  	}
   149  	for _, test := range tests {
   150  		out := generateMontgomeryDomain(bigFromHex(test.in), nistP256Prime)
   151  		if out.Cmp(bigFromHex(test.out)) != 0 {
   152  			t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes()))
   153  		}
   154  	}
   155  }
   156  
   157  func TestForSqrt(t *testing.T) {
   158  	mod4 := new(big.Int).Mod(sm2Prime, big.NewInt(4))
   159  	if mod4.Cmp(big.NewInt(3)) != 0 {
   160  		t.Fatal("sm2 prime is not fulfill 3 mod 4")
   161  	}
   162  
   163  	exp := new(big.Int).Add(sm2Prime, big.NewInt(1))
   164  	exp.Div(exp, big.NewInt(4))
   165  }
   166  
   167  func TestEquivalents(t *testing.T) {
   168  	p := NewSM2P256Point().SetGenerator()
   169  
   170  	elementSize := 32
   171  	two := make([]byte, elementSize)
   172  	two[len(two)-1] = 2
   173  	nPlusTwo := make([]byte, elementSize)
   174  	new(big.Int).Add(sm2n, big.NewInt(2)).FillBytes(nPlusTwo)
   175  
   176  	p1 := NewSM2P256Point().Double(p)
   177  	p2 := NewSM2P256Point().Add(p, p)
   178  	p3, err := NewSM2P256Point().ScalarMult(p, two)
   179  	fatalIfErr(t, err)
   180  	p4, err := NewSM2P256Point().ScalarBaseMult(two)
   181  	fatalIfErr(t, err)
   182  	p5, err := NewSM2P256Point().ScalarMult(p, nPlusTwo)
   183  	fatalIfErr(t, err)
   184  	p6, err := NewSM2P256Point().ScalarBaseMult(nPlusTwo)
   185  	fatalIfErr(t, err)
   186  
   187  	if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   188  		t.Error("P+P != 2*P")
   189  	}
   190  	if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
   191  		t.Error("P+P != [2]P")
   192  	}
   193  	if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
   194  		t.Error("G+G != [2]G")
   195  	}
   196  	if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
   197  		t.Error("P+P != [N+2]P")
   198  	}
   199  	if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
   200  		t.Error("G+G != [N+2]G")
   201  	}
   202  }
   203  
   204  func TestBasicScalarMult(t *testing.T) {
   205  	testvector := []struct {
   206  		name     string
   207  		scalar   *big.Int
   208  		expected string
   209  	}{
   210  		{
   211  			"32",
   212  			big.NewInt(32),
   213  			"0425d3debd0950d180a6d5c2b5817f2329791734cd03e5565ca32641e56024666c92d99a70679d61efb938c406dd5cb0e10458895120e208b4d39e100303fa10a2",
   214  		},
   215  		{
   216  			"N-3",
   217  			new(big.Int).Sub(sm2n, big.NewInt(3)),
   218  			"04a97f7cd4b3c993b4be2daa8cdb41e24ca13f6bd945302244e26918f1d0509ebfacf4a2267397710a333a313f758deaf083bff11932fbad6e555322fc8ba70919",
   219  		},
   220  	}
   221  	p := NewSM2P256Point().SetGenerator()
   222  
   223  	for _, test := range testvector {
   224  		scalar := make([]byte, 32)
   225  		test.scalar.FillBytes(scalar)
   226  		p1, err := NewSM2P256Point().ScalarBaseMult(scalar)
   227  		fatalIfErr(t, err)
   228  		p2, err := NewSM2P256Point().ScalarMult(p, scalar)
   229  		fatalIfErr(t, err)
   230  		if hex.EncodeToString(p1.Bytes()) != test.expected {
   231  			t.Errorf("%s ScalarBaseMult fail, got %x", test.name, p1.Bytes())
   232  		}
   233  		if hex.EncodeToString(p2.Bytes()) != test.expected {
   234  			t.Errorf("%s ScalarMult fail, got %x", test.name, p2.Bytes())
   235  		}
   236  	}
   237  }
   238  
   239  func TestScalarMult(t *testing.T) {
   240  	G := NewSM2P256Point().SetGenerator()
   241  	checkScalar := func(t *testing.T, scalar []byte) {
   242  		p1, err := NewSM2P256Point().ScalarBaseMult(scalar)
   243  		fatalIfErr(t, err)
   244  		p2, err := NewSM2P256Point().ScalarMult(G, scalar)
   245  		fatalIfErr(t, err)
   246  		if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
   247  			t.Errorf("[k]G != ScalarBaseMult(k), k=%x, p1=%x, p2=%x", scalar, p1.Bytes(), p2.Bytes())
   248  		}
   249  
   250  		d := new(big.Int).SetBytes(scalar)
   251  		d.Sub(sm2n, d)
   252  		d.Mod(d, sm2n)
   253  		g1, err := NewSM2P256Point().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
   254  		fatalIfErr(t, err)
   255  		g1.Add(g1, p1)
   256  		if !bytes.Equal(g1.Bytes(), NewSM2P256Point().Bytes()) {
   257  			t.Errorf("[N - k]G + [k]G != ∞, k=%x, g1=%x", scalar, g1.Bytes())
   258  		}
   259  	}
   260  
   261  	byteLen := len(sm2n.Bytes())
   262  	bitLen := sm2n.BitLen()
   263  	t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
   264  	t.Run("1", func(t *testing.T) {
   265  		checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
   266  	})
   267  	t.Run("N-6", func(t *testing.T) {
   268  		checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(6)).Bytes())
   269  	})
   270  	t.Run("N-1", func(t *testing.T) {
   271  		checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes())
   272  	})
   273  	t.Run("N", func(t *testing.T) { checkScalar(t, sm2n.Bytes()) })
   274  	t.Run("N+1", func(t *testing.T) {
   275  		checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes())
   276  	})
   277  	t.Run("N+58", func(t *testing.T) {
   278  		checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(58)).Bytes())
   279  	})
   280  	t.Run("all1s", func(t *testing.T) {
   281  		s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
   282  		s.Sub(s, big.NewInt(1))
   283  		checkScalar(t, s.Bytes())
   284  	})
   285  	if testing.Short() {
   286  		return
   287  	}
   288  	for i := 0; i < bitLen; i++ {
   289  		t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
   290  			s := new(big.Int).Lsh(big.NewInt(1), uint(i))
   291  			checkScalar(t, s.FillBytes(make([]byte, byteLen)))
   292  		})
   293  	}
   294  	for i := 0; i <= 64; i++ {
   295  		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   296  			checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
   297  		})
   298  	}
   299  
   300  	// Test N-64...N+64 since they risk overlapping with precomputed table values
   301  	// in the final additions.
   302  	for i := int64(-64); i <= 64; i++ {
   303  		t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
   304  			checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
   305  		})
   306  	}
   307  
   308  }
   309  
   310  func fatalIfErr(t *testing.T, err error) {
   311  	t.Helper()
   312  	if err != nil {
   313  		t.Fatal(err)
   314  	}
   315  }
   316  
   317  func BenchmarkScalarBaseMult(b *testing.B) {
   318  	p := NewSM2P256Point().SetGenerator()
   319  	scalar := make([]byte, 32)
   320  	rand.Read(scalar)
   321  	b.ReportAllocs()
   322  	b.ResetTimer()
   323  	for i := 0; i < b.N; i++ {
   324  		p.ScalarBaseMult(scalar)
   325  	}
   326  }
   327  
   328  func BenchmarkScalarMult(b *testing.B) {
   329  	p := NewSM2P256Point().SetGenerator()
   330  	scalar := make([]byte, 32)
   331  	rand.Read(scalar)
   332  	b.ReportAllocs()
   333  	b.ResetTimer()
   334  	for i := 0; i < b.N; i++ {
   335  		p.ScalarMult(p, scalar)
   336  	}
   337  }