github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/internal/edwards25519/field/fe_test.go (about)

     1  // Copyright (c) 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package field
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"io"
    12  	"math/big"
    13  	"math/bits"
    14  	mathrand "math/rand"
    15  	"reflect"
    16  	"testing"
    17  	"testing/quick"
    18  )
    19  
    20  func (v Element) String() string {
    21  	return hex.EncodeToString(v.Bytes())
    22  }
    23  
    24  // quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks)
    25  // times. The default value of -quickchecks is 100.
    26  var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10}
    27  
    28  func generateFieldElement(rand *mathrand.Rand) Element {
    29  	const maskLow52Bits = (1 << 52) - 1
    30  	return Element{
    31  		rand.Uint64() & maskLow52Bits,
    32  		rand.Uint64() & maskLow52Bits,
    33  		rand.Uint64() & maskLow52Bits,
    34  		rand.Uint64() & maskLow52Bits,
    35  		rand.Uint64() & maskLow52Bits,
    36  	}
    37  }
    38  
    39  // weirdLimbs can be combined to generate a range of edge-case field elements.
    40  // 0 and -1 are intentionally more weighted, as they combine well.
    41  var (
    42  	weirdLimbs51 = []uint64{
    43  		0, 0, 0, 0,
    44  		1,
    45  		19 - 1,
    46  		19,
    47  		0x2aaaaaaaaaaaa,
    48  		0x5555555555555,
    49  		(1 << 51) - 20,
    50  		(1 << 51) - 19,
    51  		(1 << 51) - 1, (1 << 51) - 1,
    52  		(1 << 51) - 1, (1 << 51) - 1,
    53  	}
    54  	weirdLimbs52 = []uint64{
    55  		0, 0, 0, 0, 0, 0,
    56  		1,
    57  		19 - 1,
    58  		19,
    59  		0x2aaaaaaaaaaaa,
    60  		0x5555555555555,
    61  		(1 << 51) - 20,
    62  		(1 << 51) - 19,
    63  		(1 << 51) - 1, (1 << 51) - 1,
    64  		(1 << 51) - 1, (1 << 51) - 1,
    65  		(1 << 51) - 1, (1 << 51) - 1,
    66  		1 << 51,
    67  		(1 << 51) + 1,
    68  		(1 << 52) - 19,
    69  		(1 << 52) - 1,
    70  	}
    71  )
    72  
    73  func generateWeirdFieldElement(rand *mathrand.Rand) Element {
    74  	return Element{
    75  		weirdLimbs52[rand.Intn(len(weirdLimbs52))],
    76  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    77  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    78  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    79  		weirdLimbs51[rand.Intn(len(weirdLimbs51))],
    80  	}
    81  }
    82  
    83  func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value {
    84  	if rand.Intn(2) == 0 {
    85  		return reflect.ValueOf(generateWeirdFieldElement(rand))
    86  	}
    87  	return reflect.ValueOf(generateFieldElement(rand))
    88  }
    89  
    90  // isInBounds returns whether the element is within the expected bit size bounds
    91  // after a light reduction.
    92  func isInBounds(x *Element) bool {
    93  	return bits.Len64(x.l0) <= 52 &&
    94  		bits.Len64(x.l1) <= 52 &&
    95  		bits.Len64(x.l2) <= 52 &&
    96  		bits.Len64(x.l3) <= 52 &&
    97  		bits.Len64(x.l4) <= 52
    98  }
    99  
   100  func TestMultiplyDistributesOverAdd(t *testing.T) {
   101  	multiplyDistributesOverAdd := func(x, y, z Element) bool {
   102  		// Compute t1 = (x+y)*z
   103  		t1 := new(Element)
   104  		t1.Add(&x, &y)
   105  		t1.Multiply(t1, &z)
   106  
   107  		// Compute t2 = x*z + y*z
   108  		t2 := new(Element)
   109  		t3 := new(Element)
   110  		t2.Multiply(&x, &z)
   111  		t3.Multiply(&y, &z)
   112  		t2.Add(t2, t3)
   113  
   114  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   115  	}
   116  
   117  	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil {
   118  		t.Error(err)
   119  	}
   120  }
   121  
   122  func TestMul64to128(t *testing.T) {
   123  	a := uint64(5)
   124  	b := uint64(5)
   125  	r := mul64(a, b)
   126  	if r.lo != 0x19 || r.hi != 0 {
   127  		t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   128  	}
   129  
   130  	a = uint64(18014398509481983) // 2^54 - 1
   131  	b = uint64(18014398509481983) // 2^54 - 1
   132  	r = mul64(a, b)
   133  	if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff {
   134  		t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi)
   135  	}
   136  
   137  	a = uint64(1125899906842661)
   138  	b = uint64(2097155)
   139  	r = mul64(a, b)
   140  	r = addMul64(r, a, b)
   141  	r = addMul64(r, a, b)
   142  	r = addMul64(r, a, b)
   143  	r = addMul64(r, a, b)
   144  	if r.lo != 16888498990613035 || r.hi != 640 {
   145  		t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi)
   146  	}
   147  }
   148  
   149  func TestSetBytesRoundTrip(t *testing.T) {
   150  	f1 := func(in [32]byte, fe Element) bool {
   151  		fe.SetBytes(in[:])
   152  
   153  		// Mask the most significant bit as it's ignored by SetBytes. (Now
   154  		// instead of earlier so we check the masking in SetBytes is working.)
   155  		in[len(in)-1] &= (1 << 7) - 1
   156  
   157  		return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe)
   158  	}
   159  	if err := quick.Check(f1, nil); err != nil {
   160  		t.Errorf("failed bytes->FE->bytes round-trip: %v", err)
   161  	}
   162  
   163  	f2 := func(fe, r Element) bool {
   164  		r.SetBytes(fe.Bytes())
   165  
   166  		// Intentionally not using Equal not to go through Bytes again.
   167  		// Calling reduce because both Generate and SetBytes can produce
   168  		// non-canonical representations.
   169  		fe.reduce()
   170  		r.reduce()
   171  		return fe == r
   172  	}
   173  	if err := quick.Check(f2, nil); err != nil {
   174  		t.Errorf("failed FE->bytes->FE round-trip: %v", err)
   175  	}
   176  
   177  	// Check some fixed vectors from dalek
   178  	type feRTTest struct {
   179  		fe Element
   180  		b  []byte
   181  	}
   182  	var tests = []feRTTest{
   183  		{
   184  			fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676},
   185  			b:  []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31},
   186  		},
   187  		{
   188  			fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972},
   189  			b:  []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122},
   190  		},
   191  	}
   192  
   193  	for _, tt := range tests {
   194  		b := tt.fe.Bytes()
   195  		fe, _ := new(Element).SetBytes(tt.b)
   196  		if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 {
   197  			t.Errorf("Failed fixed roundtrip: %v", tt)
   198  		}
   199  	}
   200  }
   201  
   202  func swapEndianness(buf []byte) []byte {
   203  	for i := 0; i < len(buf)/2; i++ {
   204  		buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i]
   205  	}
   206  	return buf
   207  }
   208  
   209  func TestBytesBigEquivalence(t *testing.T) {
   210  	f1 := func(in [32]byte, fe, fe1 Element) bool {
   211  		fe.SetBytes(in[:])
   212  
   213  		in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit
   214  		b := new(big.Int).SetBytes(swapEndianness(in[:]))
   215  		fe1.fromBig(b)
   216  
   217  		if fe != fe1 {
   218  			return false
   219  		}
   220  
   221  		buf := make([]byte, 32)
   222  		buf = swapEndianness(fe1.toBig().FillBytes(buf))
   223  
   224  		return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1)
   225  	}
   226  	if err := quick.Check(f1, nil); err != nil {
   227  		t.Error(err)
   228  	}
   229  }
   230  
   231  // fromBig sets v = n, and returns v. The bit length of n must not exceed 256.
   232  func (v *Element) fromBig(n *big.Int) *Element {
   233  	if n.BitLen() > 32*8 {
   234  		panic("edwards25519: invalid field element input size")
   235  	}
   236  
   237  	buf := make([]byte, 0, 32)
   238  	for _, word := range n.Bits() {
   239  		for i := 0; i < bits.UintSize; i += 8 {
   240  			if len(buf) >= cap(buf) {
   241  				break
   242  			}
   243  			buf = append(buf, byte(word))
   244  			word >>= 8
   245  		}
   246  	}
   247  
   248  	v.SetBytes(buf[:32])
   249  	return v
   250  }
   251  
   252  func (v *Element) fromDecimal(s string) *Element {
   253  	n, ok := new(big.Int).SetString(s, 10)
   254  	if !ok {
   255  		panic("not a valid decimal: " + s)
   256  	}
   257  	return v.fromBig(n)
   258  }
   259  
   260  // toBig returns v as a big.Int.
   261  func (v *Element) toBig() *big.Int {
   262  	buf := v.Bytes()
   263  
   264  	words := make([]big.Word, 32*8/bits.UintSize)
   265  	for n := range words {
   266  		for i := 0; i < bits.UintSize; i += 8 {
   267  			if len(buf) == 0 {
   268  				break
   269  			}
   270  			words[n] |= big.Word(buf[0]) << big.Word(i)
   271  			buf = buf[1:]
   272  		}
   273  	}
   274  
   275  	return new(big.Int).SetBits(words)
   276  }
   277  
   278  func TestDecimalConstants(t *testing.T) {
   279  	sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752"
   280  	if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 {
   281  		t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp)
   282  	}
   283  	// d is in the parent package, and we don't want to expose d or fromDecimal.
   284  	// dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555"
   285  	// if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 {
   286  	// 	t.Errorf("d is %v, expected %v", d, exp)
   287  	// }
   288  }
   289  
   290  func TestSetBytesRoundTripEdgeCases(t *testing.T) {
   291  	// TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1,
   292  	// and between 2^255 and 2^256-1. Test both the documented SetBytes
   293  	// behavior, and that Bytes reduces them.
   294  }
   295  
   296  // Tests self-consistency between Multiply and Square.
   297  func TestConsistency(t *testing.T) {
   298  	var x Element
   299  	var x2, x2sq Element
   300  
   301  	x = Element{1, 1, 1, 1, 1}
   302  	x2.Multiply(&x, &x)
   303  	x2sq.Square(&x)
   304  
   305  	if x2 != x2sq {
   306  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   307  	}
   308  
   309  	var bytes [32]byte
   310  
   311  	_, err := io.ReadFull(rand.Reader, bytes[:])
   312  	if err != nil {
   313  		t.Fatal(err)
   314  	}
   315  	x.SetBytes(bytes[:])
   316  
   317  	x2.Multiply(&x, &x)
   318  	x2sq.Square(&x)
   319  
   320  	if x2 != x2sq {
   321  		t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq)
   322  	}
   323  }
   324  
   325  func TestEqual(t *testing.T) {
   326  	x := Element{1, 1, 1, 1, 1}
   327  	y := Element{5, 4, 3, 2, 1}
   328  
   329  	eq := x.Equal(&x)
   330  	if eq != 1 {
   331  		t.Errorf("wrong about equality")
   332  	}
   333  
   334  	eq = x.Equal(&y)
   335  	if eq != 0 {
   336  		t.Errorf("wrong about inequality")
   337  	}
   338  }
   339  
   340  func TestInvert(t *testing.T) {
   341  	x := Element{1, 1, 1, 1, 1}
   342  	one := Element{1, 0, 0, 0, 0}
   343  	var xinv, r Element
   344  
   345  	xinv.Invert(&x)
   346  	r.Multiply(&x, &xinv)
   347  	r.reduce()
   348  
   349  	if one != r {
   350  		t.Errorf("inversion identity failed, got: %x", r)
   351  	}
   352  
   353  	var bytes [32]byte
   354  
   355  	_, err := io.ReadFull(rand.Reader, bytes[:])
   356  	if err != nil {
   357  		t.Fatal(err)
   358  	}
   359  	x.SetBytes(bytes[:])
   360  
   361  	xinv.Invert(&x)
   362  	r.Multiply(&x, &xinv)
   363  	r.reduce()
   364  
   365  	if one != r {
   366  		t.Errorf("random inversion identity failed, got: %x for field element %x", r, x)
   367  	}
   368  
   369  	zero := Element{}
   370  	x.Set(&zero)
   371  	if xx := xinv.Invert(&x); xx != &xinv {
   372  		t.Errorf("inverting zero did not return the receiver")
   373  	} else if xinv.Equal(&zero) != 1 {
   374  		t.Errorf("inverting zero did not return zero")
   375  	}
   376  }
   377  
   378  func TestSelectSwap(t *testing.T) {
   379  	a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}
   380  	b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}
   381  
   382  	var c, d Element
   383  
   384  	c.Select(&a, &b, 1)
   385  	d.Select(&a, &b, 0)
   386  
   387  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   388  		t.Errorf("Select failed")
   389  	}
   390  
   391  	c.Swap(&d, 0)
   392  
   393  	if c.Equal(&a) != 1 || d.Equal(&b) != 1 {
   394  		t.Errorf("Swap failed")
   395  	}
   396  
   397  	c.Swap(&d, 1)
   398  
   399  	if c.Equal(&b) != 1 || d.Equal(&a) != 1 {
   400  		t.Errorf("Swap failed")
   401  	}
   402  }
   403  
   404  func TestMult32(t *testing.T) {
   405  	mult32EquivalentToMul := func(x Element, y uint32) bool {
   406  		t1 := new(Element)
   407  		for i := 0; i < 100; i++ {
   408  			t1.Mult32(&x, y)
   409  		}
   410  
   411  		ty := new(Element)
   412  		ty.l0 = uint64(y)
   413  
   414  		t2 := new(Element)
   415  		for i := 0; i < 100; i++ {
   416  			t2.Multiply(&x, ty)
   417  		}
   418  
   419  		return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2)
   420  	}
   421  
   422  	if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil {
   423  		t.Error(err)
   424  	}
   425  }
   426  
   427  func TestSqrtRatio(t *testing.T) {
   428  	// From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4.
   429  	type test struct {
   430  		u, v      string
   431  		wasSquare int
   432  		r         string
   433  	}
   434  	var tests = []test{
   435  		// If u is 0, the function is defined to return (0, TRUE), even if v
   436  		// is zero. Note that where used in this package, the denominator v
   437  		// is never zero.
   438  		{
   439  			"0000000000000000000000000000000000000000000000000000000000000000",
   440  			"0000000000000000000000000000000000000000000000000000000000000000",
   441  			1, "0000000000000000000000000000000000000000000000000000000000000000",
   442  		},
   443  		// 0/1 == 0²
   444  		{
   445  			"0000000000000000000000000000000000000000000000000000000000000000",
   446  			"0100000000000000000000000000000000000000000000000000000000000000",
   447  			1, "0000000000000000000000000000000000000000000000000000000000000000",
   448  		},
   449  		// If u is non-zero and v is zero, defined to return (0, FALSE).
   450  		{
   451  			"0100000000000000000000000000000000000000000000000000000000000000",
   452  			"0000000000000000000000000000000000000000000000000000000000000000",
   453  			0, "0000000000000000000000000000000000000000000000000000000000000000",
   454  		},
   455  		// 2/1 is not square in this field.
   456  		{
   457  			"0200000000000000000000000000000000000000000000000000000000000000",
   458  			"0100000000000000000000000000000000000000000000000000000000000000",
   459  			0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54",
   460  		},
   461  		// 4/1 == 2²
   462  		{
   463  			"0400000000000000000000000000000000000000000000000000000000000000",
   464  			"0100000000000000000000000000000000000000000000000000000000000000",
   465  			1, "0200000000000000000000000000000000000000000000000000000000000000",
   466  		},
   467  		// 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem
   468  		{
   469  			"0100000000000000000000000000000000000000000000000000000000000000",
   470  			"0400000000000000000000000000000000000000000000000000000000000000",
   471  			1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f",
   472  		},
   473  	}
   474  
   475  	for i, tt := range tests {
   476  		u, _ := new(Element).SetBytes(decodeHex(tt.u))
   477  		v, _ := new(Element).SetBytes(decodeHex(tt.v))
   478  		want, _ := new(Element).SetBytes(decodeHex(tt.r))
   479  		got, wasSquare := new(Element).SqrtRatio(u, v)
   480  		if got.Equal(want) == 0 || wasSquare != tt.wasSquare {
   481  			t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare)
   482  		}
   483  	}
   484  }
   485  
   486  func TestCarryPropagate(t *testing.T) {
   487  	asmLikeGeneric := func(a [5]uint64) bool {
   488  		t1 := &Element{a[0], a[1], a[2], a[3], a[4]}
   489  		t2 := &Element{a[0], a[1], a[2], a[3], a[4]}
   490  
   491  		t1.carryPropagate()
   492  		t2.carryPropagateGeneric()
   493  
   494  		if *t1 != *t2 {
   495  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   496  		}
   497  
   498  		return *t1 == *t2 && isInBounds(t2)
   499  	}
   500  
   501  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   502  		t.Error(err)
   503  	}
   504  
   505  	if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) {
   506  		t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}")
   507  	}
   508  }
   509  
   510  func TestFeSquare(t *testing.T) {
   511  	asmLikeGeneric := func(a Element) bool {
   512  		t1 := a
   513  		t2 := a
   514  
   515  		feSquareGeneric(&t1, &t1)
   516  		feSquare(&t2, &t2)
   517  
   518  		if t1 != t2 {
   519  			t.Logf("got: %#v,\nexpected: %#v", t1, t2)
   520  		}
   521  
   522  		return t1 == t2 && isInBounds(&t2)
   523  	}
   524  
   525  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   526  		t.Error(err)
   527  	}
   528  }
   529  
   530  func TestFeMul(t *testing.T) {
   531  	asmLikeGeneric := func(a, b Element) bool {
   532  		a1 := a
   533  		a2 := a
   534  		b1 := b
   535  		b2 := b
   536  
   537  		feMulGeneric(&a1, &a1, &b1)
   538  		feMul(&a2, &a2, &b2)
   539  
   540  		if a1 != a2 || b1 != b2 {
   541  			t.Logf("got: %#v,\nexpected: %#v", a1, a2)
   542  			t.Logf("got: %#v,\nexpected: %#v", b1, b2)
   543  		}
   544  
   545  		return a1 == a2 && isInBounds(&a2) &&
   546  			b1 == b2 && isInBounds(&b2)
   547  	}
   548  
   549  	if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil {
   550  		t.Error(err)
   551  	}
   552  }
   553  
   554  func decodeHex(s string) []byte {
   555  	b, err := hex.DecodeString(s)
   556  	if err != nil {
   557  		panic(err)
   558  	}
   559  	return b
   560  }