github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/crypto/elliptic/elliptic_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package elliptic 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "math/big" 12 "testing" 13 ) 14 15 // genericParamsForCurve returns the dereferenced CurveParams for 16 // the specified curve. This is used to avoid the logic for 17 // upgrading a curve to its specific implementation, forcing 18 // usage of the generic implementation. 19 func genericParamsForCurve(c Curve) *CurveParams { 20 d := *(c.Params()) 21 return &d 22 } 23 24 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) { 25 tests := []struct { 26 name string 27 curve Curve 28 }{ 29 {"P256", P256()}, 30 {"P256/Params", genericParamsForCurve(P256())}, 31 {"P224", P224()}, 32 {"P224/Params", genericParamsForCurve(P224())}, 33 {"P384", P384()}, 34 {"P384/Params", genericParamsForCurve(P384())}, 35 {"P521", P521()}, 36 {"P521/Params", genericParamsForCurve(P521())}, 37 } 38 if testing.Short() { 39 tests = tests[:1] 40 } 41 for _, test := range tests { 42 curve := test.curve 43 t.Run(test.name, func(t *testing.T) { 44 t.Parallel() 45 f(t, curve) 46 }) 47 } 48 } 49 50 func TestOnCurve(t *testing.T) { 51 testAllCurves(t, func(t *testing.T, curve Curve) { 52 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 53 t.Error("basepoint is not on the curve") 54 } 55 }) 56 } 57 58 func TestOffCurve(t *testing.T) { 59 testAllCurves(t, func(t *testing.T, curve Curve) { 60 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 61 if curve.IsOnCurve(x, y) { 62 t.Errorf("point off curve is claimed to be on the curve") 63 } 64 65 byteLen := (curve.Params().BitSize + 7) / 8 66 b := make([]byte, 1+2*byteLen) 67 b[0] = 4 // uncompressed point 68 x.FillBytes(b[1 : 1+byteLen]) 69 y.FillBytes(b[1+byteLen : 1+2*byteLen]) 70 71 x1, y1 := Unmarshal(curve, b) 72 if x1 != nil || y1 != nil { 73 t.Errorf("unmarshaling a point not on the curve succeeded") 74 } 75 }) 76 } 77 78 func TestInfinity(t *testing.T) { 79 testAllCurves(t, testInfinity) 80 } 81 82 func isInfinity(x, y *big.Int) bool { 83 return x.Sign() == 0 && y.Sign() == 0 84 } 85 86 func testInfinity(t *testing.T, curve Curve) { 87 x0, y0 := new(big.Int), new(big.Int) 88 xG, yG := curve.Params().Gx, curve.Params().Gy 89 90 if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) { 91 t.Errorf("x^q != ∞") 92 } 93 if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) { 94 t.Errorf("x^0 != ∞") 95 } 96 97 if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) { 98 t.Errorf("∞^k != ∞") 99 } 100 if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) { 101 t.Errorf("∞^0 != ∞") 102 } 103 104 if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) { 105 t.Errorf("b^q != ∞") 106 } 107 if !isInfinity(curve.ScalarBaseMult([]byte{0})) { 108 t.Errorf("b^0 != ∞") 109 } 110 111 if !isInfinity(curve.Double(x0, y0)) { 112 t.Errorf("2∞ != ∞") 113 } 114 // There is no other point of order two on the NIST curves (as they have 115 // cofactor one), so Double can't otherwise return the point at infinity. 116 117 nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1)) 118 x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes()) 119 x, y = curve.Add(x, y, xG, yG) 120 if !isInfinity(x, y) { 121 t.Errorf("x^(q-1) + x != ∞") 122 } 123 x, y = curve.Add(xG, yG, x0, y0) 124 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 125 t.Errorf("x+∞ != x") 126 } 127 x, y = curve.Add(x0, y0, xG, yG) 128 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 129 t.Errorf("∞+x != x") 130 } 131 132 if curve.IsOnCurve(x0, y0) { 133 t.Errorf("IsOnCurve(∞) == true") 134 } 135 136 if xx, yy := Unmarshal(curve, Marshal(curve, x0, y0)); xx != nil || yy != nil { 137 t.Errorf("Unmarshal(Marshal(∞)) did not return an error") 138 } 139 // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are 140 // two valid points with x = 0. 141 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { 142 t.Errorf("Unmarshal(∞) did not return an error") 143 } 144 byteLen := (curve.Params().BitSize + 7) / 8 145 buf := make([]byte, byteLen*2+1) 146 buf[0] = 4 // Uncompressed format. 147 if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil { 148 t.Errorf("Unmarshal((0,0)) did not return an error") 149 } 150 } 151 152 func TestMarshal(t *testing.T) { 153 testAllCurves(t, func(t *testing.T, curve Curve) { 154 _, x, y, err := GenerateKey(curve, rand.Reader) 155 if err != nil { 156 t.Fatal(err) 157 } 158 serialized := Marshal(curve, x, y) 159 xx, yy := Unmarshal(curve, serialized) 160 if xx == nil { 161 t.Fatal("failed to unmarshal") 162 } 163 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 164 t.Fatal("unmarshal returned different values") 165 } 166 }) 167 } 168 169 func TestUnmarshalToLargeCoordinates(t *testing.T) { 170 // See https://golang.org/issues/20482. 171 testAllCurves(t, testUnmarshalToLargeCoordinates) 172 } 173 174 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) { 175 p := curve.Params().P 176 byteLen := (p.BitLen() + 7) / 8 177 178 // Set x to be greater than curve's parameter P – specifically, to P+5. 179 // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the 180 // curve. 181 x := new(big.Int).Add(p, big.NewInt(5)) 182 y := curve.Params().polynomial(x) 183 y.ModSqrt(y, p) 184 185 invalid := make([]byte, byteLen*2+1) 186 invalid[0] = 4 // uncompressed encoding 187 x.FillBytes(invalid[1 : 1+byteLen]) 188 y.FillBytes(invalid[1+byteLen:]) 189 190 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 191 t.Errorf("Unmarshal accepts invalid X coordinate") 192 } 193 194 if curve == p256 { 195 // This is a point on the curve with a small y value, small enough that 196 // we can add p and still be within 32 bytes. 197 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10) 198 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10) 199 y.Add(y, p) 200 201 if p.Cmp(y) > 0 || y.BitLen() != 256 { 202 t.Fatal("y not within expected range") 203 } 204 205 // marshal 206 x.FillBytes(invalid[1 : 1+byteLen]) 207 y.FillBytes(invalid[1+byteLen:]) 208 209 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 210 t.Errorf("Unmarshal accepts invalid Y coordinate") 211 } 212 } 213 } 214 215 // TestInvalidCoordinates tests big.Int values that are not valid field elements 216 // (negative or bigger than P). They are expected to return false from 217 // IsOnCurve, all other behavior is undefined. 218 func TestInvalidCoordinates(t *testing.T) { 219 testAllCurves(t, testInvalidCoordinates) 220 } 221 222 func testInvalidCoordinates(t *testing.T, curve Curve) { 223 checkIsOnCurveFalse := func(name string, x, y *big.Int) { 224 if curve.IsOnCurve(x, y) { 225 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name) 226 } 227 } 228 229 p := curve.Params().P 230 _, x, y, _ := GenerateKey(curve, rand.Reader) 231 xx, yy := new(big.Int), new(big.Int) 232 233 // Check if the sign is getting dropped. 234 xx.Neg(x) 235 checkIsOnCurveFalse("-x, y", xx, y) 236 yy.Neg(y) 237 checkIsOnCurveFalse("x, -y", x, yy) 238 239 // Check if negative values are reduced modulo P. 240 xx.Sub(x, p) 241 checkIsOnCurveFalse("x-P, y", xx, y) 242 yy.Sub(y, p) 243 checkIsOnCurveFalse("x, y-P", x, yy) 244 245 // Check if positive values are reduced modulo P. 246 xx.Add(x, p) 247 checkIsOnCurveFalse("x+P, y", xx, y) 248 yy.Add(y, p) 249 checkIsOnCurveFalse("x, y+P", x, yy) 250 251 // Check if the overflow is dropped. 252 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535)) 253 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y) 254 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535)) 255 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy) 256 257 // Check if P is treated like zero (if possible). 258 // y^2 = x^3 - 3x + B 259 // y = mod_sqrt(x^3 - 3x + B) 260 // y = mod_sqrt(B) if x = 0 261 // If there is no modsqrt, there is no point with x = 0, can't test x = P. 262 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil { 263 if !curve.IsOnCurve(big.NewInt(0), yy) { 264 t.Fatal("(0, mod_sqrt(B)) is not on the curve?") 265 } 266 checkIsOnCurveFalse("P, y", p, yy) 267 } 268 } 269 270 func TestMarshalCompressed(t *testing.T) { 271 t.Run("P-256/03", func(t *testing.T) { 272 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 273 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 274 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) 275 testMarshalCompressed(t, P256(), x, y, data) 276 }) 277 t.Run("P-256/02", func(t *testing.T) { 278 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 279 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 280 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) 281 testMarshalCompressed(t, P256(), x, y, data) 282 }) 283 284 t.Run("Invalid", func(t *testing.T) { 285 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 286 X, Y := UnmarshalCompressed(P256(), data) 287 if X != nil || Y != nil { 288 t.Error("expected an error for invalid encoding") 289 } 290 }) 291 292 if testing.Short() { 293 t.Skip("skipping other curves on short test") 294 } 295 296 testAllCurves(t, func(t *testing.T, curve Curve) { 297 _, x, y, err := GenerateKey(curve, rand.Reader) 298 if err != nil { 299 t.Fatal(err) 300 } 301 testMarshalCompressed(t, curve, x, y, nil) 302 }) 303 304 } 305 306 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { 307 if !curve.IsOnCurve(x, y) { 308 t.Fatal("invalid test point") 309 } 310 got := MarshalCompressed(curve, x, y) 311 if want != nil && !bytes.Equal(got, want) { 312 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 313 } 314 315 X, Y := UnmarshalCompressed(curve, got) 316 if X == nil || Y == nil { 317 t.Fatalf("UnmarshalCompressed failed unexpectedly") 318 } 319 320 if !curve.IsOnCurve(X, Y) { 321 t.Error("UnmarshalCompressed returned a point not on the curve") 322 } 323 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 324 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 325 } 326 } 327 328 func TestLargeIsOnCurve(t *testing.T) { 329 testAllCurves(t, func(t *testing.T, curve Curve) { 330 large := big.NewInt(1) 331 large.Lsh(large, 1000) 332 if curve.IsOnCurve(large, large) { 333 t.Errorf("(2^1000, 2^1000) is reported on the curve") 334 } 335 }) 336 } 337 338 func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) { 339 tests := []struct { 340 name string 341 curve Curve 342 }{ 343 {"P256", P256()}, 344 {"P224", P224()}, 345 {"P384", P384()}, 346 {"P521", P521()}, 347 } 348 for _, test := range tests { 349 curve := test.curve 350 b.Run(test.name, func(b *testing.B) { 351 f(b, curve) 352 }) 353 } 354 } 355 356 func BenchmarkScalarBaseMult(b *testing.B) { 357 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 358 priv, _, _, _ := GenerateKey(curve, rand.Reader) 359 b.ReportAllocs() 360 b.ResetTimer() 361 for i := 0; i < b.N; i++ { 362 x, _ := curve.ScalarBaseMult(priv) 363 // Prevent the compiler from optimizing out the operation. 364 priv[0] ^= byte(x.Bits()[0]) 365 } 366 }) 367 } 368 369 func BenchmarkScalarMult(b *testing.B) { 370 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 371 _, x, y, _ := GenerateKey(curve, rand.Reader) 372 priv, _, _, _ := GenerateKey(curve, rand.Reader) 373 b.ReportAllocs() 374 b.ResetTimer() 375 for i := 0; i < b.N; i++ { 376 x, y = curve.ScalarMult(x, y, priv) 377 } 378 }) 379 } 380 381 func BenchmarkMarshalUnmarshal(b *testing.B) { 382 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 383 _, x, y, _ := GenerateKey(curve, rand.Reader) 384 b.Run("Uncompressed", func(b *testing.B) { 385 b.ReportAllocs() 386 for i := 0; i < b.N; i++ { 387 buf := Marshal(curve, x, y) 388 xx, yy := Unmarshal(curve, buf) 389 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 390 b.Error("Unmarshal output different from Marshal input") 391 } 392 } 393 }) 394 b.Run("Compressed", func(b *testing.B) { 395 b.ReportAllocs() 396 for i := 0; i < b.N; i++ { 397 buf := MarshalCompressed(curve, x, y) 398 xx, yy := UnmarshalCompressed(curve, buf) 399 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 400 b.Error("Unmarshal output different from Marshal input") 401 } 402 } 403 }) 404 }) 405 }