github.com/comwrg/go/src@v0.0.0-20220319063731-c238d0440370/crypto/elliptic/elliptic_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package elliptic 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "math/big" 12 "testing" 13 ) 14 15 // genericParamsForCurve returns the dereferenced CurveParams for 16 // the specified curve. This is used to avoid the logic for 17 // upgrading a curve to it's specific implementation, forcing 18 // usage of the generic implementation. This is only relevant 19 // for the P224, P256, and P521 curves. 20 func genericParamsForCurve(c Curve) *CurveParams { 21 d := *(c.Params()) 22 return &d 23 } 24 25 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) { 26 tests := []struct { 27 name string 28 curve Curve 29 }{ 30 {"P256", P256()}, 31 {"P256/Params", genericParamsForCurve(P256())}, 32 {"P224", P224()}, 33 {"P224/Params", genericParamsForCurve(P224())}, 34 {"P384", P384()}, 35 {"P384/Params", genericParamsForCurve(P384())}, 36 {"P521", P521()}, 37 {"P521/Params", genericParamsForCurve(P521())}, 38 } 39 if testing.Short() { 40 tests = tests[:1] 41 } 42 for _, test := range tests { 43 curve := test.curve 44 t.Run(test.name, func(t *testing.T) { 45 t.Parallel() 46 f(t, curve) 47 }) 48 } 49 } 50 51 func TestOnCurve(t *testing.T) { 52 testAllCurves(t, func(t *testing.T, curve Curve) { 53 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 54 t.Error("basepoint is not on the curve") 55 } 56 }) 57 } 58 59 func TestOffCurve(t *testing.T) { 60 testAllCurves(t, func(t *testing.T, curve Curve) { 61 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 62 if curve.IsOnCurve(x, y) { 63 t.Errorf("point off curve is claimed to be on the curve") 64 } 65 b := Marshal(curve, x, y) 66 x1, y1 := Unmarshal(curve, b) 67 if x1 != nil || y1 != nil { 68 t.Errorf("unmarshaling a point not on the curve succeeded") 69 } 70 }) 71 } 72 73 func TestInfinity(t *testing.T) { 74 testAllCurves(t, testInfinity) 75 } 76 77 func testInfinity(t *testing.T, curve Curve) { 78 _, x, y, _ := GenerateKey(curve, rand.Reader) 79 x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes()) 80 if x.Sign() != 0 || y.Sign() != 0 { 81 t.Errorf("x^q != ∞") 82 } 83 84 x, y = curve.ScalarBaseMult([]byte{0}) 85 if x.Sign() != 0 || y.Sign() != 0 { 86 t.Errorf("b^0 != ∞") 87 x.SetInt64(0) 88 y.SetInt64(0) 89 } 90 91 x2, y2 := curve.Double(x, y) 92 if x2.Sign() != 0 || y2.Sign() != 0 { 93 t.Errorf("2∞ != ∞") 94 } 95 96 baseX := curve.Params().Gx 97 baseY := curve.Params().Gy 98 99 x3, y3 := curve.Add(baseX, baseY, x, y) 100 if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 { 101 t.Errorf("x+∞ != x") 102 } 103 104 x4, y4 := curve.Add(x, y, baseX, baseY) 105 if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 { 106 t.Errorf("∞+x != x") 107 } 108 109 if curve.IsOnCurve(x, y) { 110 t.Errorf("IsOnCurve(∞) == true") 111 } 112 } 113 114 func TestMarshal(t *testing.T) { 115 testAllCurves(t, func(t *testing.T, curve Curve) { 116 _, x, y, err := GenerateKey(curve, rand.Reader) 117 if err != nil { 118 t.Fatal(err) 119 } 120 serialized := Marshal(curve, x, y) 121 xx, yy := Unmarshal(curve, serialized) 122 if xx == nil { 123 t.Fatal("failed to unmarshal") 124 } 125 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 126 t.Fatal("unmarshal returned different values") 127 } 128 }) 129 } 130 131 func TestUnmarshalToLargeCoordinates(t *testing.T) { 132 // See https://golang.org/issues/20482. 133 testAllCurves(t, testUnmarshalToLargeCoordinates) 134 } 135 136 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) { 137 p := curve.Params().P 138 byteLen := (p.BitLen() + 7) / 8 139 140 // Set x to be greater than curve's parameter P – specifically, to P+5. 141 // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the 142 // curve. 143 x := new(big.Int).Add(p, big.NewInt(5)) 144 y := curve.Params().polynomial(x) 145 y.ModSqrt(y, p) 146 147 invalid := make([]byte, byteLen*2+1) 148 invalid[0] = 4 // uncompressed encoding 149 x.FillBytes(invalid[1 : 1+byteLen]) 150 y.FillBytes(invalid[1+byteLen:]) 151 152 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 153 t.Errorf("Unmarshal accepts invalid X coordinate") 154 } 155 156 if curve == p256 { 157 // This is a point on the curve with a small y value, small enough that 158 // we can add p and still be within 32 bytes. 159 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10) 160 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10) 161 y.Add(y, p) 162 163 if p.Cmp(y) > 0 || y.BitLen() != 256 { 164 t.Fatal("y not within expected range") 165 } 166 167 // marshal 168 x.FillBytes(invalid[1 : 1+byteLen]) 169 y.FillBytes(invalid[1+byteLen:]) 170 171 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 172 t.Errorf("Unmarshal accepts invalid Y coordinate") 173 } 174 } 175 } 176 177 // TestInvalidCoordinates tests big.Int values that are not valid field elements 178 // (negative or bigger than P). They are expected to return false from 179 // IsOnCurve, all other behavior is undefined. 180 func TestInvalidCoordinates(t *testing.T) { 181 testAllCurves(t, testInvalidCoordinates) 182 } 183 184 func testInvalidCoordinates(t *testing.T, curve Curve) { 185 checkIsOnCurveFalse := func(name string, x, y *big.Int) { 186 if curve.IsOnCurve(x, y) { 187 t.Errorf("IsOnCurve(%s) unexpectedly returned true", name) 188 } 189 } 190 191 p := curve.Params().P 192 _, x, y, _ := GenerateKey(curve, rand.Reader) 193 xx, yy := new(big.Int), new(big.Int) 194 195 // Check if the sign is getting dropped. 196 xx.Neg(x) 197 checkIsOnCurveFalse("-x, y", xx, y) 198 yy.Neg(y) 199 checkIsOnCurveFalse("x, -y", x, yy) 200 201 // Check if negative values are reduced modulo P. 202 xx.Sub(x, p) 203 checkIsOnCurveFalse("x-P, y", xx, y) 204 yy.Sub(y, p) 205 checkIsOnCurveFalse("x, y-P", x, yy) 206 207 // Check if positive values are reduced modulo P. 208 xx.Add(x, p) 209 checkIsOnCurveFalse("x+P, y", xx, y) 210 yy.Add(y, p) 211 checkIsOnCurveFalse("x, y+P", x, yy) 212 213 // Check if the overflow is dropped. 214 xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535)) 215 checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y) 216 yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535)) 217 checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy) 218 219 // Check if P is treated like zero (if possible). 220 // y^2 = x^3 - 3x + B 221 // y = mod_sqrt(x^3 - 3x + B) 222 // y = mod_sqrt(B) if x = 0 223 // If there is no modsqrt, there is no point with x = 0, can't test x = P. 224 if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil { 225 if !curve.IsOnCurve(big.NewInt(0), yy) { 226 t.Fatal("(0, mod_sqrt(B)) is not on the curve?") 227 } 228 checkIsOnCurveFalse("P, y", p, yy) 229 } 230 } 231 232 func TestMarshalCompressed(t *testing.T) { 233 t.Run("P-256/03", func(t *testing.T) { 234 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 235 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 236 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) 237 testMarshalCompressed(t, P256(), x, y, data) 238 }) 239 t.Run("P-256/02", func(t *testing.T) { 240 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 241 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 242 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) 243 testMarshalCompressed(t, P256(), x, y, data) 244 }) 245 246 t.Run("Invalid", func(t *testing.T) { 247 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 248 X, Y := UnmarshalCompressed(P256(), data) 249 if X != nil || Y != nil { 250 t.Error("expected an error for invalid encoding") 251 } 252 }) 253 254 if testing.Short() { 255 t.Skip("skipping other curves on short test") 256 } 257 258 testAllCurves(t, func(t *testing.T, curve Curve) { 259 _, x, y, err := GenerateKey(curve, rand.Reader) 260 if err != nil { 261 t.Fatal(err) 262 } 263 testMarshalCompressed(t, curve, x, y, nil) 264 }) 265 266 } 267 268 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { 269 if !curve.IsOnCurve(x, y) { 270 t.Fatal("invalid test point") 271 } 272 got := MarshalCompressed(curve, x, y) 273 if want != nil && !bytes.Equal(got, want) { 274 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 275 } 276 277 X, Y := UnmarshalCompressed(curve, got) 278 if X == nil || Y == nil { 279 t.Fatalf("UnmarshalCompressed failed unexpectedly") 280 } 281 282 if !curve.IsOnCurve(X, Y) { 283 t.Error("UnmarshalCompressed returned a point not on the curve") 284 } 285 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 286 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 287 } 288 } 289 290 func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) { 291 tests := []struct { 292 name string 293 curve Curve 294 }{ 295 {"P256", P256()}, 296 {"P224", P224()}, 297 {"P384", P384()}, 298 {"P521", P521()}, 299 } 300 for _, test := range tests { 301 curve := test.curve 302 t.Run(test.name, func(t *testing.B) { 303 f(t, curve) 304 }) 305 } 306 } 307 308 func BenchmarkScalarBaseMult(b *testing.B) { 309 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 310 priv, _, _, _ := GenerateKey(curve, rand.Reader) 311 b.ReportAllocs() 312 b.ResetTimer() 313 for i := 0; i < b.N; i++ { 314 x, _ := curve.ScalarBaseMult(priv) 315 // Prevent the compiler from optimizing out the operation. 316 priv[0] ^= byte(x.Bits()[0]) 317 } 318 }) 319 } 320 321 func BenchmarkScalarMult(b *testing.B) { 322 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 323 _, x, y, _ := GenerateKey(curve, rand.Reader) 324 priv, _, _, _ := GenerateKey(curve, rand.Reader) 325 b.ReportAllocs() 326 b.ResetTimer() 327 for i := 0; i < b.N; i++ { 328 x, y = curve.ScalarMult(x, y, priv) 329 } 330 }) 331 }