github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/elliptic/elliptic_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package elliptic 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "math/big" 12 "testing" 13 ) 14 15 // genericParamsForCurve returns the dereferenced CurveParams for 16 // the specified curve. This is used to avoid the logic for 17 // upgrading a curve to its specific implementation, forcing 18 // usage of the generic implementation. 19 func genericParamsForCurve(c Curve) *CurveParams { 20 d := *(c.Params()) 21 return &d 22 } 23 24 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) { 25 tests := []struct { 26 name string 27 curve Curve 28 }{ 29 {"P256", P256()}, 30 {"P256/Params", genericParamsForCurve(P256())}, 31 {"P224", P224()}, 32 {"P224/Params", genericParamsForCurve(P224())}, 33 {"P384", P384()}, 34 {"P384/Params", genericParamsForCurve(P384())}, 35 {"P521", P521()}, 36 {"P521/Params", genericParamsForCurve(P521())}, 37 } 38 if testing.Short() { 39 tests = tests[:1] 40 } 41 for _, test := range tests { 42 curve := test.curve 43 t.Run(test.name, func(t *testing.T) { 44 t.Parallel() 45 f(t, curve) 46 }) 47 } 48 } 49 50 func TestOnCurve(t *testing.T) { 51 t.Parallel() 52 testAllCurves(t, func(t *testing.T, curve Curve) { 53 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 54 t.Error("basepoint is not on the curve") 55 } 56 }) 57 } 58 59 func TestOffCurve(t *testing.T) { 60 t.Parallel() 61 testAllCurves(t, func(t *testing.T, curve Curve) { 62 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 63 if curve.IsOnCurve(x, y) { 64 t.Errorf("point off curve is claimed to be on the curve") 65 } 66 67 byteLen := (curve.Params().BitSize + 7) / 8 68 b := make([]byte, 1+2*byteLen) 69 b[0] = 4 // uncompressed point 70 x.FillBytes(b[1 : 1+byteLen]) 71 y.FillBytes(b[1+byteLen : 1+2*byteLen]) 72 73 x1, y1 := Unmarshal(curve, b) 74 if x1 != nil || y1 != nil { 75 t.Errorf("unmarshaling a point not on the curve succeeded") 76 } 77 }) 78 } 79 80 func TestInfinity(t *testing.T) { 81 t.Parallel() 82 testAllCurves(t, testInfinity) 83 } 84 85 func isInfinity(x, y *big.Int) bool { 86 return x.Sign() == 0 && y.Sign() == 0 87 } 88 89 func testInfinity(t *testing.T, curve Curve) { 90 x0, y0 := new(big.Int), new(big.Int) 91 xG, yG := curve.Params().Gx, curve.Params().Gy 92 93 if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) { 94 t.Errorf("x^q != ∞") 95 } 96 if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) { 97 t.Errorf("x^0 != ∞") 98 } 99 100 if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) { 101 t.Errorf("∞^k != ∞") 102 } 103 if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) { 104 t.Errorf("∞^0 != ∞") 105 } 106 107 if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) { 108 t.Errorf("b^q != ∞") 109 } 110 if !isInfinity(curve.ScalarBaseMult([]byte{0})) { 111 t.Errorf("b^0 != ∞") 112 } 113 114 if !isInfinity(curve.Double(x0, y0)) { 115 t.Errorf("2∞ != ∞") 116 } 117 // There is no other point of order two on the NIST curves (as they have 118 // cofactor one), so Double can't otherwise return the point at infinity. 119 120 nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1)) 121 x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes()) 122 x, y = curve.Add(x, y, xG, yG) 123 if !isInfinity(x, y) { 124 t.Errorf("x^(q-1) + x != ∞") 125 } 126 x, y = curve.Add(xG, yG, x0, y0) 127 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 128 t.Errorf("x+∞ != x") 129 } 130 x, y = curve.Add(x0, y0, xG, yG) 131 if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { 132 t.Errorf("∞+x != x") 133 } 134 135 if curve.IsOnCurve(x0, y0) { 136 t.Errorf("IsOnCurve(∞) == true") 137 } 138 139 if xx, yy := Unmarshal(curve, Marshal(curve, x0, y0)); xx != nil || yy != nil { 140 t.Errorf("Unmarshal(Marshal(∞)) did not return an error") 141 } 142 // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are 143 // two valid points with x = 0. 144 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { 145 t.Errorf("Unmarshal(∞) did not return an error") 146 } 147 byteLen := (curve.Params().BitSize + 7) / 8 148 buf := make([]byte, byteLen*2+1) 149 buf[0] = 4 // Uncompressed format. 150 if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil { 151 t.Errorf("Unmarshal((0,0)) did not return an error") 152 } 153 } 154 155 func TestMarshal(t *testing.T) { 156 t.Parallel() 157 testAllCurves(t, func(t *testing.T, curve Curve) { 158 _, x, y, err := GenerateKey(curve, rand.Reader) 159 if err != nil { 160 t.Fatal(err) 161 } 162 serialized := Marshal(curve, x, y) 163 xx, yy := Unmarshal(curve, serialized) 164 if xx == nil { 165 t.Fatal("failed to unmarshal") 166 } 167 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 168 t.Fatal("unmarshal returned different values") 169 } 170 }) 171 } 172 173 func TestUnmarshalToLargeCoordinates(t *testing.T) { 174 t.Parallel() 175 // See https://golang.org/issues/20482. 176 testAllCurves(t, testUnmarshalToLargeCoordinates) 177 } 178 179 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) { 180 p := curve.Params().P 181 byteLen := (p.BitLen() + 7) / 8 182 183 // Set x to be greater than curve's parameter P – specifically, to P+5. 184 // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the 185 // curve. 186 x := new(big.Int).Add(p, big.NewInt(5)) 187 y := curve.Params().polynomial(x) 188 y.ModSqrt(y, p) 189 190 invalid := make([]byte, byteLen*2+1) 191 invalid[0] = 4 // uncompressed encoding 192 x.FillBytes(invalid[1 : 1+byteLen]) 193 y.FillBytes(invalid[1+byteLen:]) 194 195 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 196 t.Errorf("Unmarshal accepts invalid X coordinate") 197 } 198 199 if curve == p256 { 200 // This is a point on the curve with a small y value, small enough that 201 // we can add p and still be within 32 bytes. 202 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10) 203 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10) 204 y.Add(y, p) 205 206 if p.Cmp(y) > 0 || y.BitLen() != 256 { 207 t.Fatal("y not within expected range") 208 } 209 210 // marshal 211 x.FillBytes(invalid[1 : 1+byteLen]) 212 y.FillBytes(invalid[1+byteLen:]) 213 214 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 215 t.Errorf("Unmarshal accepts invalid Y coordinate") 216 } 217 } 218 } 219 220 // TestInvalidCoordinates tests big.Int values that are not valid field elements 221 // (negative or bigger than P). They are expected to return false from 222 // IsOnCurve, all other behavior is undefined. 223 func TestInvalidCoordinates(t *testing.T) { 224 t.Parallel() 225 testAllCurves(t, testInvalidCoordinates) 226 } 227 228 func testInvalidCoordinates(t *testing.T, curve Curve) { 229 checkIsOnCurveFalse := func(name string, x, y *big.Int) { 230 if curve.IsOnCurve(x, y) { 231 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name) 232 } 233 } 234 235 p := curve.Params().P 236 _, x, y, _ := GenerateKey(curve, rand.Reader) 237 xx, yy := new(big.Int), new(big.Int) 238 239 // Check if the sign is getting dropped. 240 xx.Neg(x) 241 checkIsOnCurveFalse("-x, y", xx, y) 242 yy.Neg(y) 243 checkIsOnCurveFalse("x, -y", x, yy) 244 245 // Check if negative values are reduced modulo P. 246 xx.Sub(x, p) 247 checkIsOnCurveFalse("x-P, y", xx, y) 248 yy.Sub(y, p) 249 checkIsOnCurveFalse("x, y-P", x, yy) 250 251 // Check if positive values are reduced modulo P. 252 xx.Add(x, p) 253 checkIsOnCurveFalse("x+P, y", xx, y) 254 yy.Add(y, p) 255 checkIsOnCurveFalse("x, y+P", x, yy) 256 257 // Check if the overflow is dropped. 258 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535)) 259 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y) 260 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535)) 261 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy) 262 263 // Check if P is treated like zero (if possible). 264 // y^2 = x^3 - 3x + B 265 // y = mod_sqrt(x^3 - 3x + B) 266 // y = mod_sqrt(B) if x = 0 267 // If there is no modsqrt, there is no point with x = 0, can't test x = P. 268 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil { 269 if !curve.IsOnCurve(big.NewInt(0), yy) { 270 t.Fatal("(0, mod_sqrt(B)) is not on the curve?") 271 } 272 checkIsOnCurveFalse("P, y", p, yy) 273 } 274 } 275 276 func TestMarshalCompressed(t *testing.T) { 277 t.Parallel() 278 t.Run("P-256/03", func(t *testing.T) { 279 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 280 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 281 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) 282 testMarshalCompressed(t, P256(), x, y, data) 283 }) 284 t.Run("P-256/02", func(t *testing.T) { 285 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 286 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 287 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) 288 testMarshalCompressed(t, P256(), x, y, data) 289 }) 290 291 t.Run("Invalid", func(t *testing.T) { 292 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 293 X, Y := UnmarshalCompressed(P256(), data) 294 if X != nil || Y != nil { 295 t.Error("expected an error for invalid encoding") 296 } 297 }) 298 299 if testing.Short() { 300 t.Skip("skipping other curves on short test") 301 } 302 303 testAllCurves(t, func(t *testing.T, curve Curve) { 304 _, x, y, err := GenerateKey(curve, rand.Reader) 305 if err != nil { 306 t.Fatal(err) 307 } 308 testMarshalCompressed(t, curve, x, y, nil) 309 }) 310 311 } 312 313 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { 314 if !curve.IsOnCurve(x, y) { 315 t.Fatal("invalid test point") 316 } 317 got := MarshalCompressed(curve, x, y) 318 if want != nil && !bytes.Equal(got, want) { 319 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 320 } 321 322 X, Y := UnmarshalCompressed(curve, got) 323 if X == nil || Y == nil { 324 t.Fatalf("UnmarshalCompressed failed unexpectedly") 325 } 326 327 if !curve.IsOnCurve(X, Y) { 328 t.Error("UnmarshalCompressed returned a point not on the curve") 329 } 330 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 331 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 332 } 333 } 334 335 func TestLargeIsOnCurve(t *testing.T) { 336 t.Parallel() 337 testAllCurves(t, func(t *testing.T, curve Curve) { 338 large := big.NewInt(1) 339 large.Lsh(large, 1000) 340 if curve.IsOnCurve(large, large) { 341 t.Errorf("(2^1000, 2^1000) is reported on the curve") 342 } 343 }) 344 } 345 346 func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) { 347 tests := []struct { 348 name string 349 curve Curve 350 }{ 351 {"P256", P256()}, 352 {"P224", P224()}, 353 {"P384", P384()}, 354 {"P521", P521()}, 355 } 356 for _, test := range tests { 357 curve := test.curve 358 b.Run(test.name, func(b *testing.B) { 359 f(b, curve) 360 }) 361 } 362 } 363 364 func BenchmarkScalarBaseMult(b *testing.B) { 365 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 366 priv, _, _, _ := GenerateKey(curve, rand.Reader) 367 b.ReportAllocs() 368 b.ResetTimer() 369 for i := 0; i < b.N; i++ { 370 x, _ := curve.ScalarBaseMult(priv) 371 // Prevent the compiler from optimizing out the operation. 372 priv[0] ^= byte(x.Bits()[0]) 373 } 374 }) 375 } 376 377 func BenchmarkScalarMult(b *testing.B) { 378 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 379 _, x, y, _ := GenerateKey(curve, rand.Reader) 380 priv, _, _, _ := GenerateKey(curve, rand.Reader) 381 b.ReportAllocs() 382 b.ResetTimer() 383 for i := 0; i < b.N; i++ { 384 x, y = curve.ScalarMult(x, y, priv) 385 } 386 }) 387 } 388 389 func BenchmarkMarshalUnmarshal(b *testing.B) { 390 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 391 _, x, y, _ := GenerateKey(curve, rand.Reader) 392 b.Run("Uncompressed", func(b *testing.B) { 393 b.ReportAllocs() 394 for i := 0; i < b.N; i++ { 395 buf := Marshal(curve, x, y) 396 xx, yy := Unmarshal(curve, buf) 397 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 398 b.Error("Unmarshal output different from Marshal input") 399 } 400 } 401 }) 402 b.Run("Compressed", func(b *testing.B) { 403 b.ReportAllocs() 404 for i := 0; i < b.N; i++ { 405 buf := MarshalCompressed(curve, x, y) 406 xx, yy := UnmarshalCompressed(curve, buf) 407 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 408 b.Error("Unmarshal output different from Marshal input") 409 } 410 } 411 }) 412 }) 413 }