github.com/AESNooper/go/src@v0.0.0-20220218095104-b56a4ab1bbbb/crypto/elliptic/elliptic_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package elliptic
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"math/big"
    12  	"testing"
    13  )
    14  
    15  // genericParamsForCurve returns the dereferenced CurveParams for
    16  // the specified curve. This is used to avoid the logic for
    17  // upgrading a curve to its specific implementation, forcing
    18  // usage of the generic implementation.
    19  func genericParamsForCurve(c Curve) *CurveParams {
    20  	d := *(c.Params())
    21  	return &d
    22  }
    23  
    24  func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
    25  	tests := []struct {
    26  		name  string
    27  		curve Curve
    28  	}{
    29  		{"P256", P256()},
    30  		{"P256/Params", genericParamsForCurve(P256())},
    31  		{"P224", P224()},
    32  		{"P224/Params", genericParamsForCurve(P224())},
    33  		{"P384", P384()},
    34  		{"P384/Params", genericParamsForCurve(P384())},
    35  		{"P521", P521()},
    36  		{"P521/Params", genericParamsForCurve(P521())},
    37  	}
    38  	if testing.Short() {
    39  		tests = tests[:1]
    40  	}
    41  	for _, test := range tests {
    42  		curve := test.curve
    43  		t.Run(test.name, func(t *testing.T) {
    44  			t.Parallel()
    45  			f(t, curve)
    46  		})
    47  	}
    48  }
    49  
    50  func TestOnCurve(t *testing.T) {
    51  	testAllCurves(t, func(t *testing.T, curve Curve) {
    52  		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
    53  			t.Error("basepoint is not on the curve")
    54  		}
    55  	})
    56  }
    57  
    58  func TestOffCurve(t *testing.T) {
    59  	testAllCurves(t, func(t *testing.T, curve Curve) {
    60  		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
    61  		if curve.IsOnCurve(x, y) {
    62  			t.Errorf("point off curve is claimed to be on the curve")
    63  		}
    64  		b := Marshal(curve, x, y)
    65  		x1, y1 := Unmarshal(curve, b)
    66  		if x1 != nil || y1 != nil {
    67  			t.Errorf("unmarshaling a point not on the curve succeeded")
    68  		}
    69  	})
    70  }
    71  
    72  func TestInfinity(t *testing.T) {
    73  	testAllCurves(t, testInfinity)
    74  }
    75  
    76  func testInfinity(t *testing.T, curve Curve) {
    77  	_, x, y, _ := GenerateKey(curve, rand.Reader)
    78  	x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
    79  	if x.Sign() != 0 || y.Sign() != 0 {
    80  		t.Errorf("x^q != ∞")
    81  	}
    82  
    83  	x, y = curve.ScalarBaseMult([]byte{0})
    84  	if x.Sign() != 0 || y.Sign() != 0 {
    85  		t.Errorf("b^0 != ∞")
    86  		x.SetInt64(0)
    87  		y.SetInt64(0)
    88  	}
    89  
    90  	x2, y2 := curve.Double(x, y)
    91  	if x2.Sign() != 0 || y2.Sign() != 0 {
    92  		t.Errorf("2∞ != ∞")
    93  	}
    94  
    95  	baseX := curve.Params().Gx
    96  	baseY := curve.Params().Gy
    97  
    98  	x3, y3 := curve.Add(baseX, baseY, x, y)
    99  	if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
   100  		t.Errorf("x+∞ != x")
   101  	}
   102  
   103  	x4, y4 := curve.Add(x, y, baseX, baseY)
   104  	if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
   105  		t.Errorf("∞+x != x")
   106  	}
   107  
   108  	if curve.IsOnCurve(x, y) {
   109  		t.Errorf("IsOnCurve(∞) == true")
   110  	}
   111  
   112  	if xx, yy := Unmarshal(curve, Marshal(curve, x, y)); xx != nil || yy != nil {
   113  		t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
   114  	}
   115  	// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
   116  	// two valid points with x = 0.
   117  	if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
   118  		t.Errorf("Unmarshal(∞) did not return an error")
   119  	}
   120  }
   121  
   122  func TestMarshal(t *testing.T) {
   123  	testAllCurves(t, func(t *testing.T, curve Curve) {
   124  		_, x, y, err := GenerateKey(curve, rand.Reader)
   125  		if err != nil {
   126  			t.Fatal(err)
   127  		}
   128  		serialized := Marshal(curve, x, y)
   129  		xx, yy := Unmarshal(curve, serialized)
   130  		if xx == nil {
   131  			t.Fatal("failed to unmarshal")
   132  		}
   133  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   134  			t.Fatal("unmarshal returned different values")
   135  		}
   136  	})
   137  }
   138  
   139  func TestUnmarshalToLargeCoordinates(t *testing.T) {
   140  	// See https://golang.org/issues/20482.
   141  	testAllCurves(t, testUnmarshalToLargeCoordinates)
   142  }
   143  
   144  func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
   145  	p := curve.Params().P
   146  	byteLen := (p.BitLen() + 7) / 8
   147  
   148  	// Set x to be greater than curve's parameter P – specifically, to P+5.
   149  	// Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
   150  	// curve.
   151  	x := new(big.Int).Add(p, big.NewInt(5))
   152  	y := curve.Params().polynomial(x)
   153  	y.ModSqrt(y, p)
   154  
   155  	invalid := make([]byte, byteLen*2+1)
   156  	invalid[0] = 4 // uncompressed encoding
   157  	x.FillBytes(invalid[1 : 1+byteLen])
   158  	y.FillBytes(invalid[1+byteLen:])
   159  
   160  	if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   161  		t.Errorf("Unmarshal accepts invalid X coordinate")
   162  	}
   163  
   164  	if curve == p256 {
   165  		// This is a point on the curve with a small y value, small enough that
   166  		// we can add p and still be within 32 bytes.
   167  		x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
   168  		y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
   169  		y.Add(y, p)
   170  
   171  		if p.Cmp(y) > 0 || y.BitLen() != 256 {
   172  			t.Fatal("y not within expected range")
   173  		}
   174  
   175  		// marshal
   176  		x.FillBytes(invalid[1 : 1+byteLen])
   177  		y.FillBytes(invalid[1+byteLen:])
   178  
   179  		if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   180  			t.Errorf("Unmarshal accepts invalid Y coordinate")
   181  		}
   182  	}
   183  }
   184  
   185  func TestMarshalCompressed(t *testing.T) {
   186  	t.Run("P-256/03", func(t *testing.T) {
   187  		data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   188  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   189  		y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
   190  		testMarshalCompressed(t, P256(), x, y, data)
   191  	})
   192  	t.Run("P-256/02", func(t *testing.T) {
   193  		data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   194  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   195  		y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
   196  		testMarshalCompressed(t, P256(), x, y, data)
   197  	})
   198  
   199  	t.Run("Invalid", func(t *testing.T) {
   200  		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
   201  		X, Y := UnmarshalCompressed(P256(), data)
   202  		if X != nil || Y != nil {
   203  			t.Error("expected an error for invalid encoding")
   204  		}
   205  	})
   206  
   207  	if testing.Short() {
   208  		t.Skip("skipping other curves on short test")
   209  	}
   210  
   211  	testAllCurves(t, func(t *testing.T, curve Curve) {
   212  		_, x, y, err := GenerateKey(curve, rand.Reader)
   213  		if err != nil {
   214  			t.Fatal(err)
   215  		}
   216  		testMarshalCompressed(t, curve, x, y, nil)
   217  	})
   218  
   219  }
   220  
   221  func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
   222  	if !curve.IsOnCurve(x, y) {
   223  		t.Fatal("invalid test point")
   224  	}
   225  	got := MarshalCompressed(curve, x, y)
   226  	if want != nil && !bytes.Equal(got, want) {
   227  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
   228  	}
   229  
   230  	X, Y := UnmarshalCompressed(curve, got)
   231  	if X == nil || Y == nil {
   232  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
   233  	}
   234  
   235  	if !curve.IsOnCurve(X, Y) {
   236  		t.Error("UnmarshalCompressed returned a point not on the curve")
   237  	}
   238  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
   239  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
   240  	}
   241  }
   242  
   243  func TestLargeIsOnCurve(t *testing.T) {
   244  	testAllCurves(t, func(t *testing.T, curve Curve) {
   245  		large := big.NewInt(1)
   246  		large.Lsh(large, 1000)
   247  		if curve.IsOnCurve(large, large) {
   248  			t.Errorf("(2^1000, 2^1000) is reported on the curve")
   249  		}
   250  	})
   251  }
   252  
   253  func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
   254  	tests := []struct {
   255  		name  string
   256  		curve Curve
   257  	}{
   258  		{"P256", P256()},
   259  		{"P224", P224()},
   260  		{"P384", P384()},
   261  		{"P521", P521()},
   262  	}
   263  	for _, test := range tests {
   264  		curve := test.curve
   265  		t.Run(test.name, func(t *testing.B) {
   266  			f(t, curve)
   267  		})
   268  	}
   269  }
   270  
   271  func BenchmarkScalarBaseMult(b *testing.B) {
   272  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   273  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   274  		b.ReportAllocs()
   275  		b.ResetTimer()
   276  		for i := 0; i < b.N; i++ {
   277  			x, _ := curve.ScalarBaseMult(priv)
   278  			// Prevent the compiler from optimizing out the operation.
   279  			priv[0] ^= byte(x.Bits()[0])
   280  		}
   281  	})
   282  }
   283  
   284  func BenchmarkScalarMult(b *testing.B) {
   285  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   286  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   287  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   288  		b.ReportAllocs()
   289  		b.ResetTimer()
   290  		for i := 0; i < b.N; i++ {
   291  			x, y = curve.ScalarMult(x, y, priv)
   292  		}
   293  	})
   294  }
   295  
   296  func BenchmarkMarshalUnmarshal(b *testing.B) {
   297  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   298  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   299  		b.Run("Uncompressed", func(b *testing.B) {
   300  			b.ReportAllocs()
   301  			for i := 0; i < b.N; i++ {
   302  				buf := Marshal(curve, x, y)
   303  				xx, yy := Unmarshal(curve, buf)
   304  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   305  					b.Error("Unmarshal output different from Marshal input")
   306  				}
   307  			}
   308  		})
   309  		b.Run("Compressed", func(b *testing.B) {
   310  			b.ReportAllocs()
   311  			for i := 0; i < b.N; i++ {
   312  				buf := Marshal(curve, x, y)
   313  				xx, yy := Unmarshal(curve, buf)
   314  				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   315  					b.Error("Unmarshal output different from Marshal input")
   316  				}
   317  			}
   318  		})
   319  	})
   320  }