github.com/cloudflare/circl@v1.5.0/math/fp25519/fp_test.go (about)

     1  package fp25519
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"fmt"
     7  	"math/big"
     8  	"testing"
     9  
    10  	"github.com/cloudflare/circl/internal/conv"
    11  	"github.com/cloudflare/circl/internal/test"
    12  )
    13  
    14  type (
    15  	tcmov   func(x, y *Elt, n uint)
    16  	tcswap  func(x, y *Elt, n uint)
    17  	tadd    func(z, x, y *Elt)
    18  	tsub    func(z, x, y *Elt)
    19  	taddsub func(x, y *Elt)
    20  	tmul    func(z, x, y *Elt)
    21  	tsqr    func(z, x *Elt)
    22  	tmodp   func(z *Elt)
    23  )
    24  
    25  func testCmov(t *testing.T, f tcmov) {
    26  	const numTests = 1 << 9
    27  	var x, y Elt
    28  	for i := 0; i < numTests; i++ {
    29  		_, _ = rand.Read(x[:])
    30  		_, _ = rand.Read(y[:])
    31  		b := uint(y[0] & 0x1)
    32  		want := conv.BytesLe2BigInt(x[:])
    33  		if b != 0 {
    34  			want = conv.BytesLe2BigInt(y[:])
    35  		}
    36  
    37  		f(&x, &y, b)
    38  		got := conv.BytesLe2BigInt(x[:])
    39  
    40  		if got.Cmp(want) != 0 {
    41  			test.ReportError(t, got, want, x, y, b)
    42  		}
    43  	}
    44  }
    45  
    46  func testCswap(t *testing.T, f tcswap) {
    47  	const numTests = 1 << 9
    48  	var x, y Elt
    49  	for i := 0; i < numTests; i++ {
    50  		_, _ = rand.Read(x[:])
    51  		_, _ = rand.Read(y[:])
    52  		b := uint(y[0] & 0x1)
    53  		want0 := conv.BytesLe2BigInt(x[:])
    54  		want1 := conv.BytesLe2BigInt(y[:])
    55  		if b != 0 {
    56  			want0 = conv.BytesLe2BigInt(y[:])
    57  			want1 = conv.BytesLe2BigInt(x[:])
    58  		}
    59  
    60  		f(&x, &y, b)
    61  		got0 := conv.BytesLe2BigInt(x[:])
    62  		got1 := conv.BytesLe2BigInt(y[:])
    63  
    64  		if got0.Cmp(want0) != 0 {
    65  			test.ReportError(t, got0, want0, x, y, b)
    66  		}
    67  		if got1.Cmp(want1) != 0 {
    68  			test.ReportError(t, got1, want1, x, y, b)
    69  		}
    70  	}
    71  }
    72  
    73  func testAdd(t *testing.T, f tadd) {
    74  	const numTests = 1 << 9
    75  	var x, y, z Elt
    76  	prime := P()
    77  	p := conv.BytesLe2BigInt(prime[:])
    78  	for i := 0; i < numTests; i++ {
    79  		_, _ = rand.Read(x[:])
    80  		_, _ = rand.Read(y[:])
    81  		f(&z, &x, &y)
    82  		Modp(&z)
    83  		got := conv.BytesLe2BigInt(z[:])
    84  
    85  		xx, yy := conv.BytesLe2BigInt(x[:]), conv.BytesLe2BigInt(y[:])
    86  		want := xx.Add(xx, yy).Mod(xx, p)
    87  
    88  		if got.Cmp(want) != 0 {
    89  			test.ReportError(t, got, want, x, y)
    90  		}
    91  	}
    92  }
    93  
    94  func testSub(t *testing.T, f tsub) {
    95  	const numTests = 1 << 9
    96  	var x, y, z Elt
    97  	prime := P()
    98  	p := conv.BytesLe2BigInt(prime[:])
    99  	for i := 0; i < numTests; i++ {
   100  		_, _ = rand.Read(x[:])
   101  		_, _ = rand.Read(y[:])
   102  		f(&z, &x, &y)
   103  		Modp(&z)
   104  		got := conv.BytesLe2BigInt(z[:])
   105  
   106  		xx, yy := conv.BytesLe2BigInt(x[:]), conv.BytesLe2BigInt(y[:])
   107  		want := xx.Sub(xx, yy).Mod(xx, p)
   108  
   109  		if got.Cmp(want) != 0 {
   110  			test.ReportError(t, got, want, x, y)
   111  		}
   112  	}
   113  }
   114  
   115  func testAddSub(t *testing.T, f taddsub) {
   116  	const numTests = 1 << 9
   117  	var x, y Elt
   118  	prime := P()
   119  	p := conv.BytesLe2BigInt(prime[:])
   120  	want0, want1 := big.NewInt(0), big.NewInt(0)
   121  	for i := 0; i < numTests; i++ {
   122  		_, _ = rand.Read(x[:])
   123  		_, _ = rand.Read(y[:])
   124  		xx, yy := conv.BytesLe2BigInt(x[:]), conv.BytesLe2BigInt(y[:])
   125  		want0.Add(xx, yy).Mod(want0, p)
   126  		want1.Sub(xx, yy).Mod(want1, p)
   127  
   128  		f(&x, &y)
   129  		Modp(&x)
   130  		Modp(&y)
   131  		got0 := conv.BytesLe2BigInt(x[:])
   132  		got1 := conv.BytesLe2BigInt(y[:])
   133  
   134  		if got0.Cmp(want0) != 0 {
   135  			test.ReportError(t, got0, want0, x, y)
   136  		}
   137  		if got1.Cmp(want1) != 0 {
   138  			test.ReportError(t, got1, want1, x, y)
   139  		}
   140  	}
   141  }
   142  
   143  func testMul(t *testing.T, f tmul) {
   144  	const numTests = 1 << 9
   145  	var x, y, z Elt
   146  	prime := P()
   147  	p := conv.BytesLe2BigInt(prime[:])
   148  	for i := 0; i < numTests; i++ {
   149  		_, _ = rand.Read(x[:])
   150  		_, _ = rand.Read(y[:])
   151  		f(&z, &x, &y)
   152  		Modp(&z)
   153  		got := conv.BytesLe2BigInt(z[:])
   154  
   155  		xx, yy := conv.BytesLe2BigInt(x[:]), conv.BytesLe2BigInt(y[:])
   156  		want := xx.Mul(xx, yy).Mod(xx, p)
   157  
   158  		if got.Cmp(want) != 0 {
   159  			test.ReportError(t, got, want, x, y)
   160  		}
   161  	}
   162  }
   163  
   164  func testSqr(t *testing.T, f tsqr) {
   165  	const numTests = 1 << 9
   166  	var x, z Elt
   167  	prime := P()
   168  	p := conv.BytesLe2BigInt(prime[:])
   169  	for i := 0; i < numTests; i++ {
   170  		_, _ = rand.Read(x[:])
   171  		f(&z, &x)
   172  		Modp(&z)
   173  		got := conv.BytesLe2BigInt(z[:])
   174  
   175  		xx := conv.BytesLe2BigInt(x[:])
   176  		want := xx.Mul(xx, xx).Mod(xx, p)
   177  
   178  		if got.Cmp(want) != 0 {
   179  			test.ReportError(t, got, want, x)
   180  		}
   181  	}
   182  }
   183  
   184  func testModp(t *testing.T, f tmodp) {
   185  	const numTests = 1 << 9
   186  	var x Elt
   187  	prime := P()
   188  	p := conv.BytesLe2BigInt(prime[:])
   189  	two256 := big.NewInt(1)
   190  	two256.Lsh(two256, 256)
   191  	want := new(big.Int)
   192  	for i := 0; i < numTests; i++ {
   193  		bigX, _ := rand.Int(rand.Reader, two256)
   194  		bigX.Add(bigX, p).Mod(bigX, two256)
   195  		conv.BigInt2BytesLe(x[:], bigX)
   196  
   197  		f(&x)
   198  		got := conv.BytesLe2BigInt(x[:])
   199  
   200  		want.Mod(bigX, p)
   201  
   202  		if got.Cmp(want) != 0 {
   203  			test.ReportError(t, got, want, bigX)
   204  		}
   205  	}
   206  }
   207  
   208  func TestIsZero(t *testing.T) {
   209  	var x Elt
   210  	got := IsZero(&x)
   211  	want := true
   212  	if got != want {
   213  		test.ReportError(t, got, want, x)
   214  	}
   215  
   216  	SetOne(&x)
   217  	got = IsZero(&x)
   218  	want = false
   219  	if got != want {
   220  		test.ReportError(t, got, want, x)
   221  	}
   222  
   223  	x = P()
   224  	got = IsZero(&x)
   225  	want = true
   226  	if got != want {
   227  		test.ReportError(t, got, want, x)
   228  	}
   229  
   230  	x = Elt{ // 2P
   231  		0xda, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   232  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   233  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   234  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   235  	}
   236  	got = IsZero(&x)
   237  	want = true
   238  	if got != want {
   239  		test.ReportError(t, got, want, x)
   240  	}
   241  }
   242  
   243  func TestToBytes(t *testing.T) {
   244  	const numTests = 1 << 9
   245  	var x Elt
   246  	var got, want [Size]byte
   247  	for i := 0; i < numTests; i++ {
   248  		_, _ = rand.Read(x[:])
   249  		err := ToBytes(got[:], &x)
   250  		conv.BigInt2BytesLe(want[:], conv.BytesLe2BigInt(x[:]))
   251  		if got != want || err != nil {
   252  			test.ReportError(t, got, want, x)
   253  		}
   254  	}
   255  
   256  	{
   257  		var small [Size + 1]byte
   258  		got := ToBytes(small[:], &x)
   259  		want := errors.New("wrong size")
   260  		if got.Error() != want.Error() {
   261  			test.ReportError(t, got, want)
   262  		}
   263  	}
   264  }
   265  
   266  func TestString(t *testing.T) {
   267  	const numTests = 1 << 9
   268  	var x Elt
   269  	var bigX big.Int
   270  	for i := 0; i < numTests; i++ {
   271  		_, _ = rand.Read(x[:])
   272  		got, _ := bigX.SetString(fmt.Sprint(x), 0)
   273  		want := conv.BytesLe2BigInt(x[:])
   274  
   275  		if got.Cmp(want) != 0 {
   276  			test.ReportError(t, got, want, x)
   277  		}
   278  	}
   279  }
   280  
   281  func TestNeg(t *testing.T) {
   282  	const numTests = 1 << 9
   283  	var x, z Elt
   284  	prime := P()
   285  	p := conv.BytesLe2BigInt(prime[:])
   286  	for i := 0; i < numTests; i++ {
   287  		_, _ = rand.Read(x[:])
   288  		Neg(&z, &x)
   289  		Modp(&z)
   290  		got := conv.BytesLe2BigInt(z[:])
   291  
   292  		bigX := conv.BytesLe2BigInt(x[:])
   293  		want := bigX.Neg(bigX).Mod(bigX, p)
   294  
   295  		if got.Cmp(want) != 0 {
   296  			test.ReportError(t, got, want, bigX)
   297  		}
   298  	}
   299  }
   300  
   301  func TestInv(t *testing.T) {
   302  	const numTests = 1 << 9
   303  	var x, z Elt
   304  	prime := P()
   305  	p := conv.BytesLe2BigInt(prime[:])
   306  	for i := 0; i < numTests; i++ {
   307  		_, _ = rand.Read(x[:])
   308  		Inv(&z, &x)
   309  		Modp(&z)
   310  		got := conv.BytesLe2BigInt(z[:])
   311  
   312  		xx := conv.BytesLe2BigInt(x[:])
   313  		want := xx.ModInverse(xx, p)
   314  
   315  		if got.Cmp(want) != 0 {
   316  			test.ReportError(t, got, want, x)
   317  		}
   318  	}
   319  }
   320  
   321  func TestInvSqrt(t *testing.T) {
   322  	const numTests = 1 << 9
   323  	var x, y, z Elt
   324  	prime := P()
   325  	p := conv.BytesLe2BigInt(prime[:])
   326  	exp := big.NewInt(3)
   327  	exp.Add(p, exp).Rsh(exp, 3)
   328  	var frac, root, sqRoot big.Int
   329  	var wantQR bool
   330  	var want *big.Int
   331  	sqrtMinusOne, _ := new(big.Int).SetString("2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0", 16)
   332  	for i := 0; i < numTests; i++ {
   333  		_, _ = rand.Read(x[:])
   334  		_, _ = rand.Read(y[:])
   335  
   336  		gotQR := InvSqrt(&z, &x, &y)
   337  		Modp(&z)
   338  		got := conv.BytesLe2BigInt(z[:])
   339  
   340  		xx := conv.BytesLe2BigInt(x[:])
   341  		yy := conv.BytesLe2BigInt(y[:])
   342  		frac.ModInverse(yy, p).Mul(&frac, xx).Mod(&frac, p)
   343  		root.Exp(&frac, exp, p)
   344  		sqRoot.Mul(&root, &root).Mod(&sqRoot, p)
   345  
   346  		if sqRoot.Cmp(&frac) == 0 {
   347  			want = &root
   348  			wantQR = true
   349  		} else {
   350  			frac.Neg(&frac).Mod(&frac, p)
   351  			if sqRoot.Cmp(&frac) == 0 {
   352  				want = root.Mul(&root, sqrtMinusOne).Mod(&root, p)
   353  				wantQR = true
   354  			} else {
   355  				want = big.NewInt(0)
   356  				wantQR = false
   357  			}
   358  		}
   359  
   360  		if wantQR {
   361  			if gotQR != wantQR || got.Cmp(want) != 0 {
   362  				test.ReportError(t, got, want, x, y)
   363  			}
   364  		} else {
   365  			if gotQR != wantQR {
   366  				test.ReportError(t, gotQR, wantQR, x, y)
   367  			}
   368  		}
   369  	}
   370  }
   371  
   372  func TestGeneric(t *testing.T) {
   373  	t.Run("Cmov", func(t *testing.T) { testCmov(t, cmovGeneric) })
   374  	t.Run("Cswap", func(t *testing.T) { testCswap(t, cswapGeneric) })
   375  	t.Run("Add", func(t *testing.T) { testAdd(t, addGeneric) })
   376  	t.Run("Sub", func(t *testing.T) { testSub(t, subGeneric) })
   377  	t.Run("AddSub", func(t *testing.T) { testAddSub(t, addsubGeneric) })
   378  	t.Run("Mul", func(t *testing.T) { testMul(t, mulGeneric) })
   379  	t.Run("Sqr", func(t *testing.T) { testSqr(t, sqrGeneric) })
   380  	t.Run("Modp", func(t *testing.T) { testModp(t, modpGeneric) })
   381  }
   382  
   383  func TestNative(t *testing.T) {
   384  	t.Run("Cmov", func(t *testing.T) { testCmov(t, Cmov) })
   385  	t.Run("Cswap", func(t *testing.T) { testCswap(t, Cswap) })
   386  	t.Run("Add", func(t *testing.T) { testAdd(t, Add) })
   387  	t.Run("Sub", func(t *testing.T) { testSub(t, Sub) })
   388  	t.Run("AddSub", func(t *testing.T) { testAddSub(t, AddSub) })
   389  	t.Run("Mul", func(t *testing.T) { testMul(t, Mul) })
   390  	t.Run("Sqr", func(t *testing.T) { testSqr(t, Sqr) })
   391  	t.Run("Modp", func(t *testing.T) { testModp(t, Modp) })
   392  }
   393  
   394  func BenchmarkFp(b *testing.B) {
   395  	var x, y, z Elt
   396  	_, _ = rand.Read(x[:])
   397  	_, _ = rand.Read(y[:])
   398  	_, _ = rand.Read(z[:])
   399  	b.Run("Add", func(b *testing.B) {
   400  		for i := 0; i < b.N; i++ {
   401  			Add(&x, &y, &z)
   402  		}
   403  	})
   404  	b.Run("Sub", func(b *testing.B) {
   405  		for i := 0; i < b.N; i++ {
   406  			Sub(&x, &y, &z)
   407  		}
   408  	})
   409  	b.Run("Mul", func(b *testing.B) {
   410  		for i := 0; i < b.N; i++ {
   411  			Mul(&x, &y, &z)
   412  		}
   413  	})
   414  	b.Run("Sqr", func(b *testing.B) {
   415  		for i := 0; i < b.N; i++ {
   416  			Sqr(&x, &y)
   417  		}
   418  	})
   419  	b.Run("Inv", func(b *testing.B) {
   420  		for i := 0; i < b.N; i++ {
   421  			Inv(&x, &y)
   422  		}
   423  	})
   424  	b.Run("InvSqrt", func(b *testing.B) {
   425  		for i := 0; i < b.N; i++ {
   426  			InvSqrt(&z, &x, &y)
   427  		}
   428  	})
   429  }