github.com/code-reading/golang@v0.0.0-20220303082512-ba5bc0e589a3/go/src/crypto/elliptic/elliptic_test.go (about) 1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package elliptic 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "math/big" 12 "testing" 13 ) 14 15 // genericParamsForCurve returns the dereferenced CurveParams for 16 // the specified curve. This is used to avoid the logic for 17 // upgrading a curve to it's specific implementation, forcing 18 // usage of the generic implementation. This is only relevant 19 // for the P224, P256, and P521 curves. 20 func genericParamsForCurve(c Curve) *CurveParams { 21 d := *(c.Params()) 22 return &d 23 } 24 25 func testAllCurves(t *testing.T, f func(*testing.T, Curve)) { 26 tests := []struct { 27 name string 28 curve Curve 29 }{ 30 {"P256", P256()}, 31 {"P256/Params", genericParamsForCurve(P256())}, 32 {"P224", P224()}, 33 {"P224/Params", genericParamsForCurve(P224())}, 34 {"P384", P384()}, 35 {"P384/Params", genericParamsForCurve(P384())}, 36 {"P521", P521()}, 37 {"P521/Params", genericParamsForCurve(P521())}, 38 } 39 if testing.Short() { 40 tests = tests[:1] 41 } 42 for _, test := range tests { 43 curve := test.curve 44 t.Run(test.name, func(t *testing.T) { 45 t.Parallel() 46 f(t, curve) 47 }) 48 } 49 } 50 51 func TestOnCurve(t *testing.T) { 52 testAllCurves(t, func(t *testing.T, curve Curve) { 53 if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) { 54 t.Error("basepoint is not on the curve") 55 } 56 }) 57 } 58 59 func TestOffCurve(t *testing.T) { 60 testAllCurves(t, func(t *testing.T, curve Curve) { 61 x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) 62 if curve.IsOnCurve(x, y) { 63 t.Errorf("point off curve is claimed to be on the curve") 64 } 65 b := Marshal(curve, x, y) 66 x1, y1 := Unmarshal(curve, b) 67 if x1 != nil || y1 != nil { 68 t.Errorf("unmarshaling a point not on the curve succeeded") 69 } 70 }) 71 } 72 73 func TestInfinity(t *testing.T) { 74 testAllCurves(t, testInfinity) 75 } 76 77 func testInfinity(t *testing.T, curve Curve) { 78 _, x, y, _ := GenerateKey(curve, rand.Reader) 79 x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes()) 80 if x.Sign() != 0 || y.Sign() != 0 { 81 t.Errorf("x^q != ∞") 82 } 83 84 x, y = curve.ScalarBaseMult([]byte{0}) 85 if x.Sign() != 0 || y.Sign() != 0 { 86 t.Errorf("b^0 != ∞") 87 x.SetInt64(0) 88 y.SetInt64(0) 89 } 90 91 x2, y2 := curve.Double(x, y) 92 if x2.Sign() != 0 || y2.Sign() != 0 { 93 t.Errorf("2∞ != ∞") 94 } 95 96 baseX := curve.Params().Gx 97 baseY := curve.Params().Gy 98 99 x3, y3 := curve.Add(baseX, baseY, x, y) 100 if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 { 101 t.Errorf("x+∞ != x") 102 } 103 104 x4, y4 := curve.Add(x, y, baseX, baseY) 105 if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 { 106 t.Errorf("∞+x != x") 107 } 108 109 if curve.IsOnCurve(x, y) { 110 t.Errorf("IsOnCurve(∞) == true") 111 } 112 } 113 114 func TestMarshal(t *testing.T) { 115 testAllCurves(t, func(t *testing.T, curve Curve) { 116 _, x, y, err := GenerateKey(curve, rand.Reader) 117 if err != nil { 118 t.Fatal(err) 119 } 120 serialized := Marshal(curve, x, y) 121 xx, yy := Unmarshal(curve, serialized) 122 if xx == nil { 123 t.Fatal("failed to unmarshal") 124 } 125 if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { 126 t.Fatal("unmarshal returned different values") 127 } 128 }) 129 } 130 131 func TestUnmarshalToLargeCoordinates(t *testing.T) { 132 // See https://golang.org/issues/20482. 133 testAllCurves(t, testUnmarshalToLargeCoordinates) 134 } 135 136 func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) { 137 p := curve.Params().P 138 byteLen := (p.BitLen() + 7) / 8 139 140 // Set x to be greater than curve's parameter P – specifically, to P+5. 141 // Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the 142 // curve. 143 x := new(big.Int).Add(p, big.NewInt(5)) 144 y := curve.Params().polynomial(x) 145 y.ModSqrt(y, p) 146 147 invalid := make([]byte, byteLen*2+1) 148 invalid[0] = 4 // uncompressed encoding 149 x.FillBytes(invalid[1 : 1+byteLen]) 150 y.FillBytes(invalid[1+byteLen:]) 151 152 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 153 t.Errorf("Unmarshal accepts invalid X coordinate") 154 } 155 156 if curve == p256 { 157 // This is a point on the curve with a small y value, small enough that 158 // we can add p and still be within 32 bytes. 159 x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10) 160 y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10) 161 y.Add(y, p) 162 163 if p.Cmp(y) > 0 || y.BitLen() != 256 { 164 t.Fatal("y not within expected range") 165 } 166 167 // marshal 168 x.FillBytes(invalid[1 : 1+byteLen]) 169 y.FillBytes(invalid[1+byteLen:]) 170 171 if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil { 172 t.Errorf("Unmarshal accepts invalid Y coordinate") 173 } 174 } 175 } 176 177 func TestMarshalCompressed(t *testing.T) { 178 t.Run("P-256/03", func(t *testing.T) { 179 data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 180 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 181 y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10) 182 testMarshalCompressed(t, P256(), x, y, data) 183 }) 184 t.Run("P-256/02", func(t *testing.T) { 185 data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79") 186 x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10) 187 y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10) 188 testMarshalCompressed(t, P256(), x, y, data) 189 }) 190 191 t.Run("Invalid", func(t *testing.T) { 192 data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") 193 X, Y := UnmarshalCompressed(P256(), data) 194 if X != nil || Y != nil { 195 t.Error("expected an error for invalid encoding") 196 } 197 }) 198 199 if testing.Short() { 200 t.Skip("skipping other curves on short test") 201 } 202 203 testAllCurves(t, func(t *testing.T, curve Curve) { 204 _, x, y, err := GenerateKey(curve, rand.Reader) 205 if err != nil { 206 t.Fatal(err) 207 } 208 testMarshalCompressed(t, curve, x, y, nil) 209 }) 210 211 } 212 213 func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) { 214 if !curve.IsOnCurve(x, y) { 215 t.Fatal("invalid test point") 216 } 217 got := MarshalCompressed(curve, x, y) 218 if want != nil && !bytes.Equal(got, want) { 219 t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) 220 } 221 222 X, Y := UnmarshalCompressed(curve, got) 223 if X == nil || Y == nil { 224 t.Fatalf("UnmarshalCompressed failed unexpectedly") 225 } 226 227 if !curve.IsOnCurve(X, Y) { 228 t.Error("UnmarshalCompressed returned a point not on the curve") 229 } 230 if X.Cmp(x) != 0 || Y.Cmp(y) != 0 { 231 t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y) 232 } 233 } 234 235 func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) { 236 tests := []struct { 237 name string 238 curve Curve 239 }{ 240 {"P256", P256()}, 241 {"P224", P224()}, 242 {"P384", P384()}, 243 {"P521", P521()}, 244 } 245 for _, test := range tests { 246 curve := test.curve 247 t.Run(test.name, func(t *testing.B) { 248 f(t, curve) 249 }) 250 } 251 } 252 253 func BenchmarkScalarBaseMult(b *testing.B) { 254 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 255 priv, _, _, _ := GenerateKey(curve, rand.Reader) 256 b.ReportAllocs() 257 b.ResetTimer() 258 for i := 0; i < b.N; i++ { 259 x, _ := curve.ScalarBaseMult(priv) 260 // Prevent the compiler from optimizing out the operation. 261 priv[0] ^= byte(x.Bits()[0]) 262 } 263 }) 264 } 265 266 func BenchmarkScalarMult(b *testing.B) { 267 benchmarkAllCurves(b, func(b *testing.B, curve Curve) { 268 _, x, y, _ := GenerateKey(curve, rand.Reader) 269 priv, _, _, _ := GenerateKey(curve, rand.Reader) 270 b.ReportAllocs() 271 b.ResetTimer() 272 for i := 0; i < b.N; i++ { 273 x, y = curve.ScalarMult(x, y, priv) 274 } 275 }) 276 }