github.com/cloudflare/circl@v1.5.0/ecc/bls12381/ff/fp_test.go (about)

     1  package ff
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/cloudflare/circl/internal/test"
     9  )
    10  
    11  func randomFp(t testing.TB) *Fp {
    12  	t.Helper()
    13  	f := new(Fp)
    14  	err := f.Random(rand.Reader)
    15  	if err != nil {
    16  		t.Error(err)
    17  	}
    18  	return f
    19  }
    20  
    21  func TestFp(t *testing.T) {
    22  	const testTimes = 1 << 10
    23  	t.Run("no_alias", func(t *testing.T) {
    24  		var want, got Fp
    25  		x := randomFp(t)
    26  		got = *x
    27  		got.Sqr(&got)
    28  		want = *x
    29  		want.Mul(&want, &want)
    30  		if got.IsEqual(&want) == 0 {
    31  			test.ReportError(t, got, want, x)
    32  		}
    33  	})
    34  	t.Run("mul_inv", func(t *testing.T) {
    35  		var z Fp
    36  		for i := 0; i < testTimes; i++ {
    37  			x := randomFp(t)
    38  			y := randomFp(t)
    39  
    40  			// x*y*x^1 - y = 0
    41  			z.Inv(x)
    42  			z.Mul(&z, y)
    43  			z.Mul(&z, x)
    44  			z.Sub(&z, y)
    45  			got := z.IsZero()
    46  			want := 1
    47  			if got != want {
    48  				test.ReportError(t, got, want, x, y)
    49  			}
    50  		}
    51  	})
    52  	t.Run("mul_sqr", func(t *testing.T) {
    53  		var l0, l1, r0, r1 Fp
    54  		for i := 0; i < testTimes; i++ {
    55  			x := randomFp(t)
    56  			y := randomFp(t)
    57  
    58  			// (x+y)(x-y) = (x^2-y^2)
    59  			l0.Add(x, y)
    60  			l1.Sub(x, y)
    61  			l0.Mul(&l0, &l1)
    62  			r0.Sqr(x)
    63  			r1.Sqr(y)
    64  			r0.Sub(&r0, &r1)
    65  			got := &l0
    66  			want := &r0
    67  			if got.IsEqual(want) == 0 {
    68  				test.ReportError(t, got, want, x, y)
    69  			}
    70  		}
    71  	})
    72  	t.Run("sqrt", func(t *testing.T) {
    73  		var r, notRoot, got Fp
    74  		// Check when x has square-root.
    75  		for i := 0; i < testTimes; i++ {
    76  			x := randomFp(t)
    77  			x.Sqr(x)
    78  
    79  			// let x is QR and r = sqrt(x); check (+r)^2 = (-r)^2 = x.
    80  			isQR := r.Sqrt(x)
    81  			test.CheckOk(isQR == 1, fmt.Sprintf("should be a QR: %v", x), t)
    82  			rNeg := r
    83  			rNeg.Neg()
    84  
    85  			want := x
    86  			for _, root := range []*Fp{&r, &rNeg} {
    87  				got.Sqr(root)
    88  				if got.IsEqual(want) == 0 {
    89  					test.ReportError(t, got, want, x, root)
    90  				}
    91  			}
    92  		}
    93  		// Check when x has not square-root.
    94  		for i := 0; i < testTimes; i++ {
    95  			want := randomFp(t)
    96  			x := randomFp(t)
    97  			x.Sqr(x)
    98  			x.Add(x, x) // x = 2*(x^2), since 2 is not QR in Fp.
    99  
   100  			// let x is not QR and r = sqrt(x); check that r was not modified.
   101  			got := want
   102  			isQR := got.Sqrt(x)
   103  			test.CheckOk(isQR == 0, fmt.Sprintf("shouldn't be a QR: %v", x), t)
   104  
   105  			if got.IsEqual(want) != 1 {
   106  				test.ReportError(t, got, want, x, notRoot)
   107  			}
   108  		}
   109  	})
   110  	t.Run("marshal", func(t *testing.T) {
   111  		var b Fp
   112  		for i := 0; i < testTimes; i++ {
   113  			a := randomFp(t)
   114  			s, err := a.MarshalBinary()
   115  			test.CheckNoErr(t, err, "MarshalBinary failed")
   116  			err = b.UnmarshalBinary(s)
   117  			test.CheckNoErr(t, err, "UnmarshalBinary failed")
   118  			if b.IsEqual(a) == 0 {
   119  				test.ReportError(t, a, b)
   120  			}
   121  		}
   122  	})
   123  }
   124  
   125  func BenchmarkFp(b *testing.B) {
   126  	x := randomFp(b)
   127  	y := randomFp(b)
   128  	z := randomFp(b)
   129  	b.Run("Add", func(b *testing.B) {
   130  		for i := 0; i < b.N; i++ {
   131  			z.Add(x, y)
   132  		}
   133  	})
   134  	b.Run("Mul", func(b *testing.B) {
   135  		for i := 0; i < b.N; i++ {
   136  			z.Mul(x, y)
   137  		}
   138  	})
   139  	b.Run("Sqr", func(b *testing.B) {
   140  		for i := 0; i < b.N; i++ {
   141  			z.Sqr(x)
   142  		}
   143  	})
   144  	b.Run("Inv", func(b *testing.B) {
   145  		for i := 0; i < b.N; i++ {
   146  			z.Inv(x)
   147  		}
   148  	})
   149  }