github.com/comwrg/go/src@v0.0.0-20220319063731-c238d0440370/crypto/elliptic/elliptic_test.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package elliptic
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"math/big"
    12  	"testing"
    13  )
    14  
    15  // genericParamsForCurve returns the dereferenced CurveParams for
    16  // the specified curve. This is used to avoid the logic for
    17  // upgrading a curve to it's specific implementation, forcing
    18  // usage of the generic implementation. This is only relevant
    19  // for the P224, P256, and P521 curves.
    20  func genericParamsForCurve(c Curve) *CurveParams {
    21  	d := *(c.Params())
    22  	return &d
    23  }
    24  
    25  func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
    26  	tests := []struct {
    27  		name  string
    28  		curve Curve
    29  	}{
    30  		{"P256", P256()},
    31  		{"P256/Params", genericParamsForCurve(P256())},
    32  		{"P224", P224()},
    33  		{"P224/Params", genericParamsForCurve(P224())},
    34  		{"P384", P384()},
    35  		{"P384/Params", genericParamsForCurve(P384())},
    36  		{"P521", P521()},
    37  		{"P521/Params", genericParamsForCurve(P521())},
    38  	}
    39  	if testing.Short() {
    40  		tests = tests[:1]
    41  	}
    42  	for _, test := range tests {
    43  		curve := test.curve
    44  		t.Run(test.name, func(t *testing.T) {
    45  			t.Parallel()
    46  			f(t, curve)
    47  		})
    48  	}
    49  }
    50  
    51  func TestOnCurve(t *testing.T) {
    52  	testAllCurves(t, func(t *testing.T, curve Curve) {
    53  		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
    54  			t.Error("basepoint is not on the curve")
    55  		}
    56  	})
    57  }
    58  
    59  func TestOffCurve(t *testing.T) {
    60  	testAllCurves(t, func(t *testing.T, curve Curve) {
    61  		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
    62  		if curve.IsOnCurve(x, y) {
    63  			t.Errorf("point off curve is claimed to be on the curve")
    64  		}
    65  		b := Marshal(curve, x, y)
    66  		x1, y1 := Unmarshal(curve, b)
    67  		if x1 != nil || y1 != nil {
    68  			t.Errorf("unmarshaling a point not on the curve succeeded")
    69  		}
    70  	})
    71  }
    72  
    73  func TestInfinity(t *testing.T) {
    74  	testAllCurves(t, testInfinity)
    75  }
    76  
    77  func testInfinity(t *testing.T, curve Curve) {
    78  	_, x, y, _ := GenerateKey(curve, rand.Reader)
    79  	x, y = curve.ScalarMult(x, y, curve.Params().N.Bytes())
    80  	if x.Sign() != 0 || y.Sign() != 0 {
    81  		t.Errorf("x^q != ∞")
    82  	}
    83  
    84  	x, y = curve.ScalarBaseMult([]byte{0})
    85  	if x.Sign() != 0 || y.Sign() != 0 {
    86  		t.Errorf("b^0 != ∞")
    87  		x.SetInt64(0)
    88  		y.SetInt64(0)
    89  	}
    90  
    91  	x2, y2 := curve.Double(x, y)
    92  	if x2.Sign() != 0 || y2.Sign() != 0 {
    93  		t.Errorf("2∞ != ∞")
    94  	}
    95  
    96  	baseX := curve.Params().Gx
    97  	baseY := curve.Params().Gy
    98  
    99  	x3, y3 := curve.Add(baseX, baseY, x, y)
   100  	if x3.Cmp(baseX) != 0 || y3.Cmp(baseY) != 0 {
   101  		t.Errorf("x+∞ != x")
   102  	}
   103  
   104  	x4, y4 := curve.Add(x, y, baseX, baseY)
   105  	if x4.Cmp(baseX) != 0 || y4.Cmp(baseY) != 0 {
   106  		t.Errorf("∞+x != x")
   107  	}
   108  
   109  	if curve.IsOnCurve(x, y) {
   110  		t.Errorf("IsOnCurve(∞) == true")
   111  	}
   112  }
   113  
   114  func TestMarshal(t *testing.T) {
   115  	testAllCurves(t, func(t *testing.T, curve Curve) {
   116  		_, x, y, err := GenerateKey(curve, rand.Reader)
   117  		if err != nil {
   118  			t.Fatal(err)
   119  		}
   120  		serialized := Marshal(curve, x, y)
   121  		xx, yy := Unmarshal(curve, serialized)
   122  		if xx == nil {
   123  			t.Fatal("failed to unmarshal")
   124  		}
   125  		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
   126  			t.Fatal("unmarshal returned different values")
   127  		}
   128  	})
   129  }
   130  
   131  func TestUnmarshalToLargeCoordinates(t *testing.T) {
   132  	// See https://golang.org/issues/20482.
   133  	testAllCurves(t, testUnmarshalToLargeCoordinates)
   134  }
   135  
   136  func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
   137  	p := curve.Params().P
   138  	byteLen := (p.BitLen() + 7) / 8
   139  
   140  	// Set x to be greater than curve's parameter P – specifically, to P+5.
   141  	// Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
   142  	// curve.
   143  	x := new(big.Int).Add(p, big.NewInt(5))
   144  	y := curve.Params().polynomial(x)
   145  	y.ModSqrt(y, p)
   146  
   147  	invalid := make([]byte, byteLen*2+1)
   148  	invalid[0] = 4 // uncompressed encoding
   149  	x.FillBytes(invalid[1 : 1+byteLen])
   150  	y.FillBytes(invalid[1+byteLen:])
   151  
   152  	if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   153  		t.Errorf("Unmarshal accepts invalid X coordinate")
   154  	}
   155  
   156  	if curve == p256 {
   157  		// This is a point on the curve with a small y value, small enough that
   158  		// we can add p and still be within 32 bytes.
   159  		x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
   160  		y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
   161  		y.Add(y, p)
   162  
   163  		if p.Cmp(y) > 0 || y.BitLen() != 256 {
   164  			t.Fatal("y not within expected range")
   165  		}
   166  
   167  		// marshal
   168  		x.FillBytes(invalid[1 : 1+byteLen])
   169  		y.FillBytes(invalid[1+byteLen:])
   170  
   171  		if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
   172  			t.Errorf("Unmarshal accepts invalid Y coordinate")
   173  		}
   174  	}
   175  }
   176  
   177  // TestInvalidCoordinates tests big.Int values that are not valid field elements
   178  // (negative or bigger than P). They are expected to return false from
   179  // IsOnCurve, all other behavior is undefined.
   180  func TestInvalidCoordinates(t *testing.T) {
   181  	testAllCurves(t, testInvalidCoordinates)
   182  }
   183  
   184  func testInvalidCoordinates(t *testing.T, curve Curve) {
   185  	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
   186  		if curve.IsOnCurve(x, y) {
   187  			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
   188  		}
   189  	}
   190  
   191  	p := curve.Params().P
   192  	_, x, y, _ := GenerateKey(curve, rand.Reader)
   193  	xx, yy := new(big.Int), new(big.Int)
   194  
   195  	// Check if the sign is getting dropped.
   196  	xx.Neg(x)
   197  	checkIsOnCurveFalse("-x, y", xx, y)
   198  	yy.Neg(y)
   199  	checkIsOnCurveFalse("x, -y", x, yy)
   200  
   201  	// Check if negative values are reduced modulo P.
   202  	xx.Sub(x, p)
   203  	checkIsOnCurveFalse("x-P, y", xx, y)
   204  	yy.Sub(y, p)
   205  	checkIsOnCurveFalse("x, y-P", x, yy)
   206  
   207  	// Check if positive values are reduced modulo P.
   208  	xx.Add(x, p)
   209  	checkIsOnCurveFalse("x+P, y", xx, y)
   210  	yy.Add(y, p)
   211  	checkIsOnCurveFalse("x, y+P", x, yy)
   212  
   213  	// Check if the overflow is dropped.
   214  	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
   215  	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
   216  	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
   217  	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
   218  
   219  	// Check if P is treated like zero (if possible).
   220  	// y^2 = x^3 - 3x + B
   221  	// y = mod_sqrt(x^3 - 3x + B)
   222  	// y = mod_sqrt(B) if x = 0
   223  	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
   224  	if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
   225  		if !curve.IsOnCurve(big.NewInt(0), yy) {
   226  			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
   227  		}
   228  		checkIsOnCurveFalse("P, y", p, yy)
   229  	}
   230  }
   231  
   232  func TestMarshalCompressed(t *testing.T) {
   233  	t.Run("P-256/03", func(t *testing.T) {
   234  		data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   235  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   236  		y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
   237  		testMarshalCompressed(t, P256(), x, y, data)
   238  	})
   239  	t.Run("P-256/02", func(t *testing.T) {
   240  		data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
   241  		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
   242  		y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
   243  		testMarshalCompressed(t, P256(), x, y, data)
   244  	})
   245  
   246  	t.Run("Invalid", func(t *testing.T) {
   247  		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
   248  		X, Y := UnmarshalCompressed(P256(), data)
   249  		if X != nil || Y != nil {
   250  			t.Error("expected an error for invalid encoding")
   251  		}
   252  	})
   253  
   254  	if testing.Short() {
   255  		t.Skip("skipping other curves on short test")
   256  	}
   257  
   258  	testAllCurves(t, func(t *testing.T, curve Curve) {
   259  		_, x, y, err := GenerateKey(curve, rand.Reader)
   260  		if err != nil {
   261  			t.Fatal(err)
   262  		}
   263  		testMarshalCompressed(t, curve, x, y, nil)
   264  	})
   265  
   266  }
   267  
   268  func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
   269  	if !curve.IsOnCurve(x, y) {
   270  		t.Fatal("invalid test point")
   271  	}
   272  	got := MarshalCompressed(curve, x, y)
   273  	if want != nil && !bytes.Equal(got, want) {
   274  		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
   275  	}
   276  
   277  	X, Y := UnmarshalCompressed(curve, got)
   278  	if X == nil || Y == nil {
   279  		t.Fatalf("UnmarshalCompressed failed unexpectedly")
   280  	}
   281  
   282  	if !curve.IsOnCurve(X, Y) {
   283  		t.Error("UnmarshalCompressed returned a point not on the curve")
   284  	}
   285  	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
   286  		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
   287  	}
   288  }
   289  
   290  func benchmarkAllCurves(t *testing.B, f func(*testing.B, Curve)) {
   291  	tests := []struct {
   292  		name  string
   293  		curve Curve
   294  	}{
   295  		{"P256", P256()},
   296  		{"P224", P224()},
   297  		{"P384", P384()},
   298  		{"P521", P521()},
   299  	}
   300  	for _, test := range tests {
   301  		curve := test.curve
   302  		t.Run(test.name, func(t *testing.B) {
   303  			f(t, curve)
   304  		})
   305  	}
   306  }
   307  
   308  func BenchmarkScalarBaseMult(b *testing.B) {
   309  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   310  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   311  		b.ReportAllocs()
   312  		b.ResetTimer()
   313  		for i := 0; i < b.N; i++ {
   314  			x, _ := curve.ScalarBaseMult(priv)
   315  			// Prevent the compiler from optimizing out the operation.
   316  			priv[0] ^= byte(x.Bits()[0])
   317  		}
   318  	})
   319  }
   320  
   321  func BenchmarkScalarMult(b *testing.B) {
   322  	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
   323  		_, x, y, _ := GenerateKey(curve, rand.Reader)
   324  		priv, _, _, _ := GenerateKey(curve, rand.Reader)
   325  		b.ReportAllocs()
   326  		b.ResetTimer()
   327  		for i := 0; i < b.N; i++ {
   328  			x, y = curve.ScalarMult(x, y, priv)
   329  		}
   330  	})
   331  }