github.com/AESNooper/go/src@v0.0.0-20220218095104-b56a4ab1bbbb/crypto/elliptic/elliptic_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package elliptic 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "math/big" 12 "testing" 13 ) 14 15 // genericParamsForCurve returns the dereferenced CurveParams for 16 // the specified curve. This is used to avoid the logic for 17 // upgrading a curve to its specific implementation, forcing 18 // usage of the generic implementation. 19 func genericParamsForCurve(c Curve) *CurveParams { 20 d := *(c.Params()) 21 return &d 22 } 23 24 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) { 25 tests := []struct { 26 name string 27 curve Curve 28 }{ 29 {"P256", P256()}, 30 {"P256/Params", genericParamsForCurve(P256())}, 31 {"P224", P224()}, 32 {"P224/Params", genericParamsForCurve(P224())}, 33 {"P384", P384()}, 34 {"P384/Params", genericParamsForCurve(P384())}, 35 {"P521", P521()}, 36 {"P521/Params", genericParamsForCurve(P521())}, 37 } 38 if testing.Short() { 39 tests = tests[:1] 40 } 41 for _, test := range tests { 42 curve := test.curve 43 t.Run(test.name, func(t *testing.T) { 44 t.Parallel() 45 f(t, curve) 46 }) 47 } 48 } 49 50 func TestOnCurve(t *testing.T) { 51 testAllCurves(t, func(t *testing.T, curve Curve) { 52 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 53 t.Error("basepoint is not on the curve") 54 } 55 }) 56 } 57 58 func TestOffCurve(t *testing.T) { 59 testAllCurves(t, func(t *testing.T, curve Curve) { 60 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 61 if curve.IsOnCurve(x, y) { 62 t.Errorf("point off curve is claimed to be on the curve") 63 } 64 b := Marshal(curve, x, y) 65 x1, y1 := Unmarshal(curve, b) 66 if x1 != nil || y1 != nil { 67 t.Errorf("unmarshaling a point not on the curve succeeded") 68 } 69 }) 70 } 71 72 func TestInfinity(t *testing.T) { 73 testAllCurves(t, testInfinity) 74 } 75 76 func testInfinity(t *testing.T, curve Curve) { 77 _, x, y, _ := GenerateKey(curve, rand.Reader) 78 x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes()) 79 if x.Sign() != 0 || y.Sign() != 0 { 80 t.Errorf("x^q != ∞") 81 } 82 83 x, y = curve.ScalarBaseMult([]byte{0}) 84 if x.Sign() != 0 || y.Sign() != 0 { 85 t.Errorf("b^0 != ∞") 86 x.SetInt64(0) 87 y.SetInt64(0) 88 } 89 90 x2, y2 := curve.Double(x, y) 91 if x2.Sign() != 0 || y2.Sign() != 0 { 92 t.Errorf("2∞ != ∞") 93 } 94 95 baseX := curve.Params().Gx 96 baseY := curve.Params().Gy 97 98 x3, y3 := curve.Add(baseX, baseY, x, y) 99 if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 { 100 t.Errorf("x+∞ != x") 101 } 102 103 x4, y4 := curve.Add(x, y, baseX, baseY) 104 if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 { 105 t.Errorf("∞+x != x") 106 } 107 108 if curve.IsOnCurve(x, y) { 109 t.Errorf("IsOnCurve(∞) == true") 110 } 111 112 if xx, yy := Unmarshal(curve, Marshal(curve, x, y)); xx != nil || yy != nil { 113 t.Errorf("Unmarshal(Marshal(∞)) did not return an error") 114 } 115 // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are 116 // two valid points with x = 0. 117 if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { 118 t.Errorf("Unmarshal(∞) did not return an error") 119 } 120 } 121 122 func TestMarshal(t *testing.T) { 123 testAllCurves(t, func(t *testing.T, curve Curve) { 124 _, x, y, err := GenerateKey(curve, rand.Reader) 125 if err != nil { 126 t.Fatal(err) 127 } 128 serialized := Marshal(curve, x, y) 129 xx, yy := Unmarshal(curve, serialized) 130 if xx == nil { 131 t.Fatal("failed to unmarshal") 132 } 133 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 134 t.Fatal("unmarshal returned different values") 135 } 136 }) 137 } 138 139 func TestUnmarshalToLargeCoordinates(t *testing.T) { 140 // See https://golang.org/issues/20482. 141 testAllCurves(t, testUnmarshalToLargeCoordinates) 142 } 143 144 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) { 145 p := curve.Params().P 146 byteLen := (p.BitLen() + 7) / 8 147 148 // Set x to be greater than curve's parameter P – specifically, to P+5. 149 // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the 150 // curve. 151 x := new(big.Int).Add(p, big.NewInt(5)) 152 y := curve.Params().polynomial(x) 153 y.ModSqrt(y, p) 154 155 invalid := make([]byte, byteLen*2+1) 156 invalid[0] = 4 // uncompressed encoding 157 x.FillBytes(invalid[1 : 1+byteLen]) 158 y.FillBytes(invalid[1+byteLen:]) 159 160 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 161 t.Errorf("Unmarshal accepts invalid X coordinate") 162 } 163 164 if curve == p256 { 165 // This is a point on the curve with a small y value, small enough that 166 // we can add p and still be within 32 bytes. 167 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10) 168 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10) 169 y.Add(y, p) 170 171 if p.Cmp(y) > 0 || y.BitLen() != 256 { 172 t.Fatal("y not within expected range") 173 } 174 175 // marshal 176 x.FillBytes(invalid[1 : 1+byteLen]) 177 y.FillBytes(invalid[1+byteLen:]) 178 179 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 180 t.Errorf("Unmarshal accepts invalid Y coordinate") 181 } 182 } 183 } 184 185 func TestMarshalCompressed(t *testing.T) { 186 t.Run("P-256/03", func(t *testing.T) { 187 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 188 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 189 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) 190 testMarshalCompressed(t, P256(), x, y, data) 191 }) 192 t.Run("P-256/02", func(t *testing.T) { 193 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 194 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 195 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) 196 testMarshalCompressed(t, P256(), x, y, data) 197 }) 198 199 t.Run("Invalid", func(t *testing.T) { 200 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 201 X, Y := UnmarshalCompressed(P256(), data) 202 if X != nil || Y != nil { 203 t.Error("expected an error for invalid encoding") 204 } 205 }) 206 207 if testing.Short() { 208 t.Skip("skipping other curves on short test") 209 } 210 211 testAllCurves(t, func(t *testing.T, curve Curve) { 212 _, x, y, err := GenerateKey(curve, rand.Reader) 213 if err != nil { 214 t.Fatal(err) 215 } 216 testMarshalCompressed(t, curve, x, y, nil) 217 }) 218 219 } 220 221 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { 222 if !curve.IsOnCurve(x, y) { 223 t.Fatal("invalid test point") 224 } 225 got := MarshalCompressed(curve, x, y) 226 if want != nil && !bytes.Equal(got, want) { 227 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 228 } 229 230 X, Y := UnmarshalCompressed(curve, got) 231 if X == nil || Y == nil { 232 t.Fatalf("UnmarshalCompressed failed unexpectedly") 233 } 234 235 if !curve.IsOnCurve(X, Y) { 236 t.Error("UnmarshalCompressed returned a point not on the curve") 237 } 238 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 239 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 240 } 241 } 242 243 func TestLargeIsOnCurve(t *testing.T) { 244 testAllCurves(t, func(t *testing.T, curve Curve) { 245 large := big.NewInt(1) 246 large.Lsh(large, 1000) 247 if curve.IsOnCurve(large, large) { 248 t.Errorf("(2^1000, 2^1000) is reported on the curve") 249 } 250 }) 251 } 252 253 func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) { 254 tests := []struct { 255 name string 256 curve Curve 257 }{ 258 {"P256", P256()}, 259 {"P224", P224()}, 260 {"P384", P384()}, 261 {"P521", P521()}, 262 } 263 for _, test := range tests { 264 curve := test.curve 265 t.Run(test.name, func(t *testing.B) { 266 f(t, curve) 267 }) 268 } 269 } 270 271 func BenchmarkScalarBaseMult(b *testing.B) { 272 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 273 priv, _, _, _ := GenerateKey(curve, rand.Reader) 274 b.ReportAllocs() 275 b.ResetTimer() 276 for i := 0; i < b.N; i++ { 277 x, _ := curve.ScalarBaseMult(priv) 278 // Prevent the compiler from optimizing out the operation. 279 priv[0] ^= byte(x.Bits()[0]) 280 } 281 }) 282 } 283 284 func BenchmarkScalarMult(b *testing.B) { 285 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 286 _, x, y, _ := GenerateKey(curve, rand.Reader) 287 priv, _, _, _ := GenerateKey(curve, rand.Reader) 288 b.ReportAllocs() 289 b.ResetTimer() 290 for i := 0; i < b.N; i++ { 291 x, y = curve.ScalarMult(x, y, priv) 292 } 293 }) 294 } 295 296 func BenchmarkMarshalUnmarshal(b *testing.B) { 297 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 298 _, x, y, _ := GenerateKey(curve, rand.Reader) 299 b.Run("Uncompressed", func(b *testing.B) { 300 b.ReportAllocs() 301 for i := 0; i < b.N; i++ { 302 buf := Marshal(curve, x, y) 303 xx, yy := Unmarshal(curve, buf) 304 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 305 b.Error("Unmarshal output different from Marshal input") 306 } 307 } 308 }) 309 b.Run("Compressed", func(b *testing.B) { 310 b.ReportAllocs() 311 for i := 0; i < b.N; i++ { 312 buf := Marshal(curve, x, y) 313 xx, yy := Unmarshal(curve, buf) 314 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 315 b.Error("Unmarshal output different from Marshal input") 316 } 317 } 318 }) 319 }) 320 }