github.com/emmansun/gmsm@v0.29.1/internal/sm2ec/sm2ec_test.go (about) 1 package sm2ec 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "encoding/hex" 7 "fmt" 8 "math/big" 9 "testing" 10 ) 11 12 // r = 2^256 13 var r = bigFromHex("010000000000000000000000000000000000000000000000000000000000000000") 14 var r0 = bigFromHex("010000000000000000") 15 var sm2Prime = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF") 16 var sm2n = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123") 17 var nistP256Prime = bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951") 18 var nistP256N = bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369") 19 20 func generateMontgomeryDomain(in *big.Int, p *big.Int) *big.Int { 21 tmp := new(big.Int) 22 tmp = tmp.Mul(in, r) 23 return tmp.Mod(tmp, p) 24 } 25 26 func bigFromHex(s string) *big.Int { 27 b, ok := new(big.Int).SetString(s, 16) 28 if !ok { 29 panic("sm2ec: internal error: invalid encoding") 30 } 31 return b 32 } 33 34 func bigFromDecimal(s string) *big.Int { 35 b, ok := new(big.Int).SetString(s, 10) 36 if !ok { 37 panic("sm2ec: internal error: invalid encoding") 38 } 39 return b 40 } 41 42 func TestSM2P256MontgomeryDomain(t *testing.T) { 43 tests := []struct { 44 in string 45 out string 46 }{ 47 { // One 48 "01", 49 "0000000100000000000000000000000000000000ffffffff0000000000000001", 50 }, 51 { // Gx 52 "32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 53 "91167a5ee1c13b05d6a1ed99ac24c3c33e7981eddca6c05061328990f418029e", 54 }, 55 { // Gy 56 "BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 57 "63cd65d481d735bd8d4cfb066e2a48f8c1f5e5788d3295fac1354e593c2d0ddd", 58 }, 59 { // B 60 "28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 61 "240fe188ba20e2c8527981505ea51c3c71cf379ae9b537ab90d230632bc0dd42", 62 }, 63 { // R 64 "010000000000000000000000000000000000000000000000000000000000000000", 65 "0400000002000000010000000100000002ffffffff0000000200000003", 66 }, 67 } 68 for _, test := range tests { 69 out := generateMontgomeryDomain(bigFromHex(test.in), sm2Prime) 70 if out.Cmp(bigFromHex(test.out)) != 0 { 71 t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) 72 } 73 } 74 } 75 76 func TestSM2P256MontgomeryDomainN(t *testing.T) { 77 tests := []struct { 78 in string 79 out string 80 }{ 81 { // One 82 "01", 83 "010000000000000000000000008dfc2094de39fad4ac440bf6c62abedd", 84 }, 85 { // R 86 "010000000000000000000000000000000000000000000000000000000000000000", 87 "1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20", 88 }, 89 } 90 for _, test := range tests { 91 out := generateMontgomeryDomain(bigFromHex(test.in), sm2n) 92 if out.Cmp(bigFromHex(test.out)) != 0 { 93 t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) 94 } 95 } 96 } 97 98 func TestSM2P256MontgomeryK0(t *testing.T) { 99 tests := []struct { 100 in *big.Int 101 out string 102 }{ 103 { 104 sm2n, 105 "327f9e8872350975", 106 }, 107 { 108 sm2Prime, 109 "0000000000000001", 110 }, 111 } 112 for _, test := range tests { 113 // k0 = -in^(-1) mod 2^64 114 k0 := new(big.Int).ModInverse(test.in, r0) 115 k0.Neg(k0) 116 k0.Mod(k0, r0) 117 if k0.Cmp(bigFromHex(test.out)) != 0 { 118 t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(k0.Bytes())) 119 } 120 } 121 } 122 123 func TestNISTP256MontgomeryDomain(t *testing.T) { 124 tests := []struct { 125 in string 126 out string 127 }{ 128 { // One 129 "01", 130 "fffffffeffffffffffffffffffffffff000000000000000000000001", 131 }, 132 { // Gx 133 "6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 134 "18905f76a53755c679fb732b7762251075ba95fc5fedb60179e730d418a9143c", 135 }, 136 { // Gy 137 "4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 138 "8571ff1825885d85d2e88688dd21f3258b4ab8e4ba19e45cddf25357ce95560a", 139 }, 140 { // B 141 "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 142 "dc30061d04874834e5a220abf7212ed6acf005cd78843090d89cdf6229c4bddf", 143 }, 144 { // R 145 "010000000000000000000000000000000000000000000000000000000000000000", 146 "04fffffffdfffffffffffffffefffffffbffffffff0000000000000003", 147 }, 148 } 149 for _, test := range tests { 150 out := generateMontgomeryDomain(bigFromHex(test.in), nistP256Prime) 151 if out.Cmp(bigFromHex(test.out)) != 0 { 152 t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) 153 } 154 } 155 } 156 157 func TestForSqrt(t *testing.T) { 158 mod4 := new(big.Int).Mod(sm2Prime, big.NewInt(4)) 159 if mod4.Cmp(big.NewInt(3)) != 0 { 160 t.Fatal("sm2 prime is not fulfill 3 mod 4") 161 } 162 163 exp := new(big.Int).Add(sm2Prime, big.NewInt(1)) 164 exp.Div(exp, big.NewInt(4)) 165 } 166 167 func TestEquivalents(t *testing.T) { 168 p := NewSM2P256Point().SetGenerator() 169 170 elementSize := 32 171 two := make([]byte, elementSize) 172 two[len(two)-1] = 2 173 nPlusTwo := make([]byte, elementSize) 174 new(big.Int).Add(sm2n, big.NewInt(2)).FillBytes(nPlusTwo) 175 176 p1 := NewSM2P256Point().Double(p) 177 p2 := NewSM2P256Point().Add(p, p) 178 p3, err := NewSM2P256Point().ScalarMult(p, two) 179 fatalIfErr(t, err) 180 p4, err := NewSM2P256Point().ScalarBaseMult(two) 181 fatalIfErr(t, err) 182 p5, err := NewSM2P256Point().ScalarMult(p, nPlusTwo) 183 fatalIfErr(t, err) 184 p6, err := NewSM2P256Point().ScalarBaseMult(nPlusTwo) 185 fatalIfErr(t, err) 186 187 if !bytes.Equal(p1.Bytes(), p2.Bytes()) { 188 t.Error("P+P != 2*P") 189 } 190 if !bytes.Equal(p1.Bytes(), p3.Bytes()) { 191 t.Error("P+P != [2]P") 192 } 193 if !bytes.Equal(p1.Bytes(), p4.Bytes()) { 194 t.Error("G+G != [2]G") 195 } 196 if !bytes.Equal(p1.Bytes(), p5.Bytes()) { 197 t.Error("P+P != [N+2]P") 198 } 199 if !bytes.Equal(p1.Bytes(), p6.Bytes()) { 200 t.Error("G+G != [N+2]G") 201 } 202 } 203 204 func TestBasicScalarMult(t *testing.T) { 205 testvector := []struct { 206 name string 207 scalar *big.Int 208 expected string 209 }{ 210 { 211 "32", 212 big.NewInt(32), 213 "0425d3debd0950d180a6d5c2b5817f2329791734cd03e5565ca32641e56024666c92d99a70679d61efb938c406dd5cb0e10458895120e208b4d39e100303fa10a2", 214 }, 215 { 216 "N-3", 217 new(big.Int).Sub(sm2n, big.NewInt(3)), 218 "04a97f7cd4b3c993b4be2daa8cdb41e24ca13f6bd945302244e26918f1d0509ebfacf4a2267397710a333a313f758deaf083bff11932fbad6e555322fc8ba70919", 219 }, 220 } 221 p := NewSM2P256Point().SetGenerator() 222 223 for _, test := range testvector { 224 scalar := make([]byte, 32) 225 test.scalar.FillBytes(scalar) 226 p1, err := NewSM2P256Point().ScalarBaseMult(scalar) 227 fatalIfErr(t, err) 228 p2, err := NewSM2P256Point().ScalarMult(p, scalar) 229 fatalIfErr(t, err) 230 if hex.EncodeToString(p1.Bytes()) != test.expected { 231 t.Errorf("%s ScalarBaseMult fail, got %x", test.name, p1.Bytes()) 232 } 233 if hex.EncodeToString(p2.Bytes()) != test.expected { 234 t.Errorf("%s ScalarMult fail, got %x", test.name, p2.Bytes()) 235 } 236 } 237 } 238 239 func TestScalarMult(t *testing.T) { 240 G := NewSM2P256Point().SetGenerator() 241 checkScalar := func(t *testing.T, scalar []byte) { 242 p1, err := NewSM2P256Point().ScalarBaseMult(scalar) 243 fatalIfErr(t, err) 244 p2, err := NewSM2P256Point().ScalarMult(G, scalar) 245 fatalIfErr(t, err) 246 if !bytes.Equal(p1.Bytes(), p2.Bytes()) { 247 t.Errorf("[k]G != ScalarBaseMult(k), k=%x, p1=%x, p2=%x", scalar, p1.Bytes(), p2.Bytes()) 248 } 249 250 d := new(big.Int).SetBytes(scalar) 251 d.Sub(sm2n, d) 252 d.Mod(d, sm2n) 253 g1, err := NewSM2P256Point().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar)))) 254 fatalIfErr(t, err) 255 g1.Add(g1, p1) 256 if !bytes.Equal(g1.Bytes(), NewSM2P256Point().Bytes()) { 257 t.Errorf("[N - k]G + [k]G != ∞, k=%x, g1=%x", scalar, g1.Bytes()) 258 } 259 } 260 261 byteLen := len(sm2n.Bytes()) 262 bitLen := sm2n.BitLen() 263 t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) }) 264 t.Run("1", func(t *testing.T) { 265 checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen))) 266 }) 267 t.Run("N-6", func(t *testing.T) { 268 checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(6)).Bytes()) 269 }) 270 t.Run("N-1", func(t *testing.T) { 271 checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes()) 272 }) 273 t.Run("N", func(t *testing.T) { checkScalar(t, sm2n.Bytes()) }) 274 t.Run("N+1", func(t *testing.T) { 275 checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes()) 276 }) 277 t.Run("N+58", func(t *testing.T) { 278 checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(58)).Bytes()) 279 }) 280 t.Run("all1s", func(t *testing.T) { 281 s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen)) 282 s.Sub(s, big.NewInt(1)) 283 checkScalar(t, s.Bytes()) 284 }) 285 if testing.Short() { 286 return 287 } 288 for i := 0; i < bitLen; i++ { 289 t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) { 290 s := new(big.Int).Lsh(big.NewInt(1), uint(i)) 291 checkScalar(t, s.FillBytes(make([]byte, byteLen))) 292 }) 293 } 294 for i := 0; i <= 64; i++ { 295 t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { 296 checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen))) 297 }) 298 } 299 300 // Test N-64...N+64 since they risk overlapping with precomputed table values 301 // in the final additions. 302 for i := int64(-64); i <= 64; i++ { 303 t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) { 304 checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes()) 305 }) 306 } 307 308 } 309 310 func fatalIfErr(t *testing.T, err error) { 311 t.Helper() 312 if err != nil { 313 t.Fatal(err) 314 } 315 } 316 317 func BenchmarkScalarBaseMult(b *testing.B) { 318 p := NewSM2P256Point().SetGenerator() 319 scalar := make([]byte, 32) 320 rand.Read(scalar) 321 b.ReportAllocs() 322 b.ResetTimer() 323 for i := 0; i < b.N; i++ { 324 p.ScalarBaseMult(scalar) 325 } 326 } 327 328 func BenchmarkScalarMult(b *testing.B) { 329 p := NewSM2P256Point().SetGenerator() 330 scalar := make([]byte, 32) 331 rand.Read(scalar) 332 b.ReportAllocs() 333 b.ResetTimer() 334 for i := 0; i < b.N; i++ { 335 p.ScalarMult(p, scalar) 336 } 337 }