github.com/emmansun/gmsm@v0.29.1/sm2/sm2ec/elliptic_test.go (about) 1 package sm2ec 2 3 import ( 4 "bytes" 5 "crypto/elliptic" 6 "crypto/rand" 7 "encoding/hex" 8 "math/big" 9 "testing" 10 ) 11 12 var _ = elliptic.P256() // force NIST P curves init, avoid panic when we invoke generic implementation's method 13 14 // genericParamsForCurve returns the dereferenced CurveParams for 15 // the specified curve. This is used to avoid the logic for 16 // upgrading a curve to its specific implementation, forcing 17 // usage of the generic implementation. 18 func genericParamsForCurve(c elliptic.Curve) *elliptic.CurveParams { 19 d := *(c.Params()) 20 return &d 21 } 22 23 func testAllCurves(t *testing.T, f func(*testing.T, elliptic.Curve)) { 24 tests := []struct { 25 name string 26 curve elliptic.Curve 27 }{ 28 {"SM2P256", P256()}, 29 {"SM2P256/Params", genericParamsForCurve(P256())}, 30 } 31 if testing.Short() { 32 tests = tests[:1] 33 } 34 for _, test := range tests { 35 curve := test.curve 36 t.Run(test.name, func(t *testing.T) { 37 t.Parallel() 38 f(t, curve) 39 }) 40 } 41 } 42 43 func TestOnCurve(t *testing.T) { 44 testAllCurves(t, func(t *testing.T, curve elliptic.Curve) { 45 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 46 t.Error("basepoint is not on the curve") 47 } 48 }) 49 } 50 51 func TestOffCurve(t *testing.T) { 52 testAllCurves(t, func(t *testing.T, curve elliptic.Curve) { 53 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 54 if curve.IsOnCurve(x, y) { 55 t.Errorf("point off curve is claimed to be on the curve") 56 } 57 58 byteLen := (curve.Params().BitSize + 7) / 8 59 b := make([]byte, 1+2*byteLen) 60 b[0] = 4 // uncompressed point 61 x.FillBytes(b[1 : 1+byteLen]) 62 y.FillBytes(b[1+byteLen : 1+2*byteLen]) 63 64 x1, y1 := Unmarshal(curve, b) 65 if x1 != nil || y1 != nil { 66 t.Errorf("unmarshaling a point not on the curve succeeded") 67 } 68 }) 69 } 70 71 func TestInfinity(t *testing.T) { 72 testAllCurves(t, testInfinity) 73 } 74 75 func isInfinity(x, y *big.Int) bool { 76 return x.Sign() == 0 && y.Sign() == 0 77 } 78 79 func testInfinity(t *testing.T, curve elliptic.Curve) { 80 x0, y0 := new(big.Int), new(big.Int) 81 xG, yG := curve.Params().Gx, curve.Params().Gy 82 83 if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) { 84 t.Errorf("x^q != ∞") 85 } 86 if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) { 87 t.Errorf("x^0 != ∞") 88 } 89 90 if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) { 91 t.Errorf("∞^k != ∞") 92 } 93 if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) { 94 t.Errorf("∞^0 != ∞") 95 } 96 97 if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) { 98 t.Errorf("b^q != ∞") 99 } 100 if !isInfinity(curve.ScalarBaseMult([]byte{0})) { 101 t.Errorf("b^0 != ∞") 102 } 103 104 if !isInfinity(curve.Double(x0, y0)) { 105 t.Errorf("2∞ != ∞") 106 } 107 // There is no other point of order two on the NIST curves (as they have 108 // cofactor one), so Double can't otherwise return the point at infinity. 109 110 nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1)) 111 x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes()) 112 x, y = curve.Add(x, y, xG, yG) 113 if !isInfinity(x, y) { 114 t.Errorf("x^(q-1) + x != ∞") 115 } 116 x, y = curve.Add(xG, yG, x0, y0) 117 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 118 t.Errorf("x+∞ != x") 119 } 120 x, y = curve.Add(x0, y0, xG, yG) 121 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 122 t.Errorf("∞+x != x") 123 } 124 125 if curve.IsOnCurve(x0, y0) { 126 t.Errorf("IsOnCurve(∞) == true") 127 } 128 129 if xx, yy := Unmarshal(curve, elliptic.Marshal(curve, x0, y0)); xx != nil || yy != nil { 130 t.Errorf("Unmarshal(Marshal(∞)) did not return an error") 131 } 132 // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are 133 // two valid points with x = 0. 134 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { 135 t.Errorf("Unmarshal(∞) did not return an error") 136 } 137 byteLen := (curve.Params().BitSize + 7) / 8 138 buf := make([]byte, byteLen*2+1) 139 buf[0] = 4 // Uncompressed format. 140 if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil { 141 t.Errorf("Unmarshal((0,0)) did not return an error") 142 } 143 } 144 145 func TestMarshal(t *testing.T) { 146 testAllCurves(t, func(t *testing.T, curve elliptic.Curve) { 147 _, x, y, err := elliptic.GenerateKey(curve, rand.Reader) 148 if err != nil { 149 t.Fatal(err) 150 } 151 serialized := elliptic.Marshal(curve, x, y) 152 xx, yy := Unmarshal(curve, serialized) 153 if xx == nil { 154 t.Fatal("failed to unmarshal") 155 } 156 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 157 t.Fatal("unmarshal returned different values") 158 } 159 }) 160 } 161 162 // TestInvalidCoordinates tests big.Int values that are not valid field elements 163 // (negative or bigger than P). They are expected to return false from 164 // IsOnCurve, all other behavior is undefined. 165 func TestInvalidCoordinates(t *testing.T) { 166 testAllCurves(t, testInvalidCoordinates) 167 } 168 169 func testInvalidCoordinates(t *testing.T, curve elliptic.Curve) { 170 checkIsOnCurveFalse := func(name string, x, y *big.Int) { 171 if curve.IsOnCurve(x, y) { 172 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name) 173 } 174 } 175 176 p := curve.Params().P 177 _, x, y, _ := elliptic.GenerateKey(curve, rand.Reader) 178 xx, yy := new(big.Int), new(big.Int) 179 180 // Check if the sign is getting dropped. 181 xx.Neg(x) 182 checkIsOnCurveFalse("-x, y", xx, y) 183 yy.Neg(y) 184 checkIsOnCurveFalse("x, -y", x, yy) 185 186 // Check if negative values are reduced modulo P. 187 xx.Sub(x, p) 188 checkIsOnCurveFalse("x-P, y", xx, y) 189 yy.Sub(y, p) 190 checkIsOnCurveFalse("x, y-P", x, yy) 191 192 // Check if positive values are reduced modulo P. 193 xx.Add(x, p) 194 checkIsOnCurveFalse("x+P, y", xx, y) 195 yy.Add(y, p) 196 checkIsOnCurveFalse("x, y+P", x, yy) 197 198 // Check if the overflow is dropped. 199 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535)) 200 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y) 201 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535)) 202 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy) 203 204 // Check if P is treated like zero (if possible). 205 // y^2 = x^3 - 3x + B 206 // y = mod_sqrt(x^3 - 3x + B) 207 // y = mod_sqrt(B) if x = 0 208 // If there is no modsqrt, there is no point with x = 0, can't test x = P. 209 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil { 210 if !curve.IsOnCurve(big.NewInt(0), yy) { 211 t.Fatal("(0, mod_sqrt(B)) is not on the curve?") 212 } 213 checkIsOnCurveFalse("P, y", p, yy) 214 } 215 } 216 217 func TestMarshalCompressed(t *testing.T) { 218 t.Run("P-256/03", func(t *testing.T) { 219 data, _ := hex.DecodeString("031b5709a068f5c1d05d0a61c0c70a13310df2d3a6c2ca9c9bba53337ea3e10de3") 220 x, _ := new(big.Int).SetString("1b5709a068f5c1d05d0a61c0c70a13310df2d3a6c2ca9c9bba53337ea3e10de3", 16) 221 y, _ := new(big.Int).SetString("a7ac81d1fdd4fcd224bbd95183136f948861812594ef24bd867c23d955fee3bb", 16) 222 testMarshalCompressed(t, P256(), x, y, data) 223 }) 224 t.Run("P-256/02", func(t *testing.T) { 225 data, _ := hex.DecodeString("0258f9a2efca4139f2b07662b937439a719ea3bf59d7de346c365db7c85d4bc32a") 226 x, _ := new(big.Int).SetString("58f9a2efca4139f2b07662b937439a719ea3bf59d7de346c365db7c85d4bc32a", 16) 227 y, _ := new(big.Int).SetString("02680fbe48b1d8cf023d0b7c1d9ab9b56535384729db5fcb8db29ec72c7fc9ca", 16) 228 testMarshalCompressed(t, P256(), x, y, data) 229 }) 230 231 t.Run("Invalid", func(t *testing.T) { 232 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 233 X, Y := UnmarshalCompressed(P256(), data) 234 if X != nil || Y != nil { 235 t.Error("expected an error for invalid encoding") 236 } 237 }) 238 239 if testing.Short() { 240 t.Skip("skipping other curves on short test") 241 } 242 243 testAllCurves(t, func(t *testing.T, curve elliptic.Curve) { 244 _, x, y, err := elliptic.GenerateKey(curve, rand.Reader) 245 if err != nil { 246 t.Fatal(err) 247 } 248 testMarshalCompressed(t, curve, x, y, nil) 249 }) 250 } 251 252 func testMarshalCompressed(t *testing.T, curve elliptic.Curve, x, y *big.Int, want []byte) { 253 if !curve.IsOnCurve(x, y) { 254 t.Fatal("invalid test point") 255 } 256 got := elliptic.MarshalCompressed(curve, x, y) 257 if want != nil && !bytes.Equal(got, want) { 258 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 259 } 260 261 X, Y := UnmarshalCompressed(curve, got) 262 if X == nil || Y == nil { 263 t.Fatalf("UnmarshalCompressed failed unexpectedly") 264 } 265 266 if !curve.IsOnCurve(X, Y) { 267 t.Error("UnmarshalCompressed returned a point not on the curve") 268 } 269 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 270 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 271 } 272 } 273 274 func TestLargeIsOnCurve(t *testing.T) { 275 testAllCurves(t, func(t *testing.T, curve elliptic.Curve) { 276 large := big.NewInt(1) 277 large.Lsh(large, 1000) 278 if curve.IsOnCurve(large, large) { 279 t.Errorf("(2^1000, 2^1000) is reported on the curve") 280 } 281 }) 282 } 283 284 func benchmarkAllCurves(b *testing.B, f func(*testing.B, elliptic.Curve)) { 285 tests := []struct { 286 name string 287 curve elliptic.Curve 288 }{ 289 {"SM2P256", P256()}, 290 } 291 for _, test := range tests { 292 curve := test.curve 293 b.Run(test.name, func(b *testing.B) { 294 f(b, curve) 295 }) 296 } 297 } 298 299 func BenchmarkScalarBaseMult(b *testing.B) { 300 benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) { 301 priv, _, _, _ := elliptic.GenerateKey(curve, rand.Reader) 302 b.ReportAllocs() 303 b.ResetTimer() 304 for i := 0; i < b.N; i++ { 305 x, _ := curve.ScalarBaseMult(priv) 306 // Prevent the compiler from optimizing out the operation. 307 priv[0] ^= byte(x.Bits()[0]) 308 } 309 }) 310 } 311 312 func BenchmarkScalarMult(b *testing.B) { 313 benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) { 314 _, x, y, _ := elliptic.GenerateKey(curve, rand.Reader) 315 priv, _, _, _ := elliptic.GenerateKey(curve, rand.Reader) 316 b.ReportAllocs() 317 b.ResetTimer() 318 for i := 0; i < b.N; i++ { 319 x, y = curve.ScalarMult(x, y, priv) 320 } 321 }) 322 } 323 324 func BenchmarkMarshalUnmarshal(b *testing.B) { 325 benchmarkAllCurves(b, func(b *testing.B, curve elliptic.Curve) { 326 _, x, y, _ := elliptic.GenerateKey(curve, rand.Reader) 327 b.Run("Uncompressed", func(b *testing.B) { 328 b.ReportAllocs() 329 for i := 0; i < b.N; i++ { 330 buf := elliptic.Marshal(curve, x, y) 331 xx, yy := Unmarshal(curve, buf) 332 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 333 b.Error("Unmarshal output different from Marshal input") 334 } 335 } 336 }) 337 b.Run("Compressed", func(b *testing.B) { 338 b.ReportAllocs() 339 for i := 0; i < b.N; i++ { 340 buf := elliptic.MarshalCompressed(curve, x, y) 341 xx, yy := UnmarshalCompressed(curve, buf) 342 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 343 b.Error("Unmarshal output different from Marshal input") 344 } 345 } 346 }) 347 }) 348 }