github.com/cloudflare/circl@v1.5.0/group/group_test.go (about)

     1  package group_test
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"testing"
     8  
     9  	"github.com/cloudflare/circl/group"
    10  	"github.com/cloudflare/circl/internal/test"
    11  )
    12  
    13  var allGroups = []group.Group{
    14  	group.P256,
    15  	group.P384,
    16  	group.P521,
    17  	group.Ristretto255,
    18  }
    19  
    20  func TestGroup(t *testing.T) {
    21  	const testTimes = 1 << 7
    22  	for _, g := range allGroups {
    23  		n := g.(fmt.Stringer).String()
    24  		t.Run(n+"/Add", func(tt *testing.T) { testAdd(tt, testTimes, g) })
    25  		t.Run(n+"/Neg", func(tt *testing.T) { testNeg(tt, testTimes, g) })
    26  		t.Run(n+"/Mul", func(tt *testing.T) { testMul(tt, testTimes, g) })
    27  		t.Run(n+"/MulGen", func(tt *testing.T) { testMulGen(tt, testTimes, g) })
    28  		t.Run(n+"/CMov", func(tt *testing.T) { testCMov(tt, testTimes, g) })
    29  		t.Run(n+"/CSelect", func(tt *testing.T) { testCSelect(tt, testTimes, g) })
    30  		t.Run(n+"/Order", func(tt *testing.T) { testOrder(tt, testTimes, g) })
    31  		t.Run(n+"/Marshal", func(tt *testing.T) { testMarshal(tt, testTimes, g) })
    32  		t.Run(n+"/Scalar", func(tt *testing.T) { testScalar(tt, testTimes, g) })
    33  	}
    34  }
    35  
    36  func testAdd(t *testing.T, testTimes int, g group.Group) {
    37  	Q := g.NewElement()
    38  	for i := 0; i < testTimes; i++ {
    39  		P := g.RandomElement(rand.Reader)
    40  
    41  		got := Q.Dbl(P).Dbl(Q).Dbl(Q).Dbl(Q) // Q = 16P
    42  
    43  		R := g.Identity()
    44  		for j := 0; j < 16; j++ {
    45  			R.Add(R, P)
    46  		}
    47  		want := R // R = 16P = P+P...+P
    48  		if !got.IsEqual(want) {
    49  			test.ReportError(t, got, want, P)
    50  		}
    51  	}
    52  }
    53  
    54  func testNeg(t *testing.T, testTimes int, g group.Group) {
    55  	Q := g.NewElement()
    56  	for i := 0; i < testTimes; i++ {
    57  		P := g.RandomElement(rand.Reader)
    58  		Q.Neg(P)
    59  		Q.Add(Q, P)
    60  		got := Q.IsIdentity()
    61  		want := true
    62  		if got != want {
    63  			test.ReportError(t, got, want, P)
    64  		}
    65  	}
    66  }
    67  
    68  func testMul(t *testing.T, testTimes int, g group.Group) {
    69  	Q := g.NewElement()
    70  	kInv := g.NewScalar()
    71  	for i := 0; i < testTimes; i++ {
    72  		P := g.RandomElement(rand.Reader)
    73  		k := g.RandomScalar(rand.Reader)
    74  		kInv.Inv(k)
    75  
    76  		Q.Mul(P, k)
    77  		Q.Mul(Q, kInv)
    78  
    79  		got := P
    80  		want := Q
    81  		if !got.IsEqual(want) {
    82  			test.ReportError(t, got, want, P, k)
    83  		}
    84  	}
    85  }
    86  
    87  func testMulGen(t *testing.T, testTimes int, g group.Group) {
    88  	G := g.Generator()
    89  	P := g.NewElement()
    90  	Q := g.NewElement()
    91  	for i := 0; i < testTimes; i++ {
    92  		k := g.RandomScalar(rand.Reader)
    93  
    94  		P.Mul(G, k)
    95  		Q.MulGen(k)
    96  
    97  		got := P
    98  		want := Q
    99  		if !got.IsEqual(want) {
   100  			test.ReportError(t, got, want, P, k)
   101  		}
   102  	}
   103  }
   104  
   105  func testCMov(t *testing.T, testTimes int, g group.Group) {
   106  	P := g.RandomElement(rand.Reader)
   107  	Q := g.RandomElement(rand.Reader)
   108  
   109  	err := test.CheckPanic(func() { P.CMov(0, Q) })
   110  	test.CheckIsErr(t, err, "shouldn't fail with 0")
   111  	err = test.CheckPanic(func() { P.CMov(1, Q) })
   112  	test.CheckIsErr(t, err, "shouldn't fail with 1")
   113  	err = test.CheckPanic(func() { P.CMov(2, Q) })
   114  	test.CheckNoErr(t, err, "should fail with dif 0,1")
   115  
   116  	for i := 0; i < testTimes; i++ {
   117  		P = g.RandomElement(rand.Reader)
   118  		Q = g.RandomElement(rand.Reader)
   119  
   120  		want := P.Copy()
   121  		got := P.CMov(0, Q)
   122  		if !got.IsEqual(want) {
   123  			test.ReportError(t, got, want)
   124  		}
   125  
   126  		want = Q.Copy()
   127  		got = P.CMov(1, Q)
   128  		if !got.IsEqual(want) {
   129  			test.ReportError(t, got, want)
   130  		}
   131  	}
   132  }
   133  
   134  func testCSelect(t *testing.T, testTimes int, g group.Group) {
   135  	P := g.RandomElement(rand.Reader)
   136  	Q := g.RandomElement(rand.Reader)
   137  	R := g.RandomElement(rand.Reader)
   138  
   139  	err := test.CheckPanic(func() { P.CSelect(0, Q, R) })
   140  	test.CheckIsErr(t, err, "shouldn't fail with 0")
   141  	err = test.CheckPanic(func() { P.CSelect(1, Q, R) })
   142  	test.CheckIsErr(t, err, "shouldn't fail with 1")
   143  	err = test.CheckPanic(func() { P.CSelect(2, Q, R) })
   144  	test.CheckNoErr(t, err, "should fail with dif 0,1")
   145  
   146  	for i := 0; i < testTimes; i++ {
   147  		P = g.RandomElement(rand.Reader)
   148  		Q = g.RandomElement(rand.Reader)
   149  		R = g.RandomElement(rand.Reader)
   150  
   151  		want := R.Copy()
   152  		got := P.CSelect(0, Q, R)
   153  		if !got.IsEqual(want) {
   154  			test.ReportError(t, got, want)
   155  		}
   156  
   157  		want = Q.Copy()
   158  		got = P.CSelect(1, Q, R)
   159  		if !got.IsEqual(want) {
   160  			test.ReportError(t, got, want)
   161  		}
   162  	}
   163  }
   164  
   165  func testOrder(t *testing.T, testTimes int, g group.Group) {
   166  	I := g.Identity()
   167  	Q := g.NewElement()
   168  	minusOne := g.NewScalar().SetUint64(1)
   169  	minusOne.Neg(minusOne)
   170  	for i := 0; i < testTimes; i++ {
   171  		P := g.RandomElement(rand.Reader)
   172  
   173  		Q.Mul(P, minusOne)
   174  		got := Q.Add(Q, P)
   175  		want := I
   176  		if !got.IsEqual(want) {
   177  			test.ReportError(t, got, want, P)
   178  		}
   179  	}
   180  }
   181  
   182  func isZero(b []byte) bool {
   183  	for i := 0; i < len(b); i++ {
   184  		if b[i] != 0x00 {
   185  			return false
   186  		}
   187  	}
   188  	return true
   189  }
   190  
   191  func testMarshal(t *testing.T, testTimes int, g group.Group) {
   192  	params := g.Params()
   193  	I := g.Identity()
   194  	got, err := I.MarshalBinary()
   195  	test.CheckNoErr(t, err, "error on MarshalBinary")
   196  	if !isZero(got) {
   197  		test.ReportError(t, got, "Non-zero identity")
   198  	}
   199  	if l := uint(len(got)); !(l == 1 || l == params.ElementLength) {
   200  		test.ReportError(t, l, params.ElementLength)
   201  	}
   202  	got, err = I.MarshalBinaryCompress()
   203  	test.CheckNoErr(t, err, "error on MarshalBinaryCompress")
   204  	if !isZero(got) {
   205  		test.ReportError(t, got, "Non-zero identity")
   206  	}
   207  	if l := uint(len(got)); !(l == 1 || l == params.CompressedElementLength) {
   208  		test.ReportError(t, l, params.CompressedElementLength)
   209  	}
   210  	II := g.NewElement()
   211  	err = II.UnmarshalBinary(got)
   212  	if err != nil || !I.IsEqual(II) {
   213  		test.ReportError(t, I, II)
   214  	}
   215  
   216  	got1 := g.NewElement()
   217  	got2 := g.NewElement()
   218  	for i := 0; i < testTimes; i++ {
   219  		x := g.RandomElement(rand.Reader)
   220  		enc1, err1 := x.MarshalBinary()
   221  		enc2, err2 := x.MarshalBinaryCompress()
   222  		test.CheckNoErr(t, err1, "error on marshalling")
   223  		test.CheckNoErr(t, err2, "error on marshalling compress")
   224  
   225  		err1 = got1.UnmarshalBinary(enc1)
   226  		err2 = got2.UnmarshalBinary(enc2)
   227  		test.CheckNoErr(t, err1, "error on unmarshalling")
   228  		test.CheckNoErr(t, err2, "error on unmarshalling compress")
   229  		if !x.IsEqual(got1) {
   230  			test.ReportError(t, got1, x)
   231  		}
   232  		if !x.IsEqual(got2) {
   233  			test.ReportError(t, got2, x)
   234  		}
   235  		if l := uint(len(enc1)); l != params.ElementLength {
   236  			test.ReportError(t, l, params.ElementLength)
   237  		}
   238  		if l := uint(len(enc2)); l != params.CompressedElementLength {
   239  			test.ReportError(t, l, params.CompressedElementLength)
   240  		}
   241  	}
   242  }
   243  
   244  func testScalar(t *testing.T, testTimes int, g group.Group) {
   245  	a := g.RandomScalar(rand.Reader)
   246  	b := g.RandomScalar(rand.Reader)
   247  	c := g.NewScalar()
   248  	d := g.NewScalar()
   249  	e := g.NewScalar()
   250  	f := g.NewScalar()
   251  	one := g.NewScalar()
   252  	one.SetUint64(1)
   253  	params := g.Params()
   254  
   255  	err := test.CheckPanic(func() { a.CMov(0, b) })
   256  	test.CheckIsErr(t, err, "shouldn't fail with 0")
   257  	err = test.CheckPanic(func() { a.CMov(1, b) })
   258  	test.CheckIsErr(t, err, "shouldn't fail with 1")
   259  	err = test.CheckPanic(func() { a.CMov(2, b) })
   260  	test.CheckNoErr(t, err, "should fail with dif 0,1")
   261  
   262  	err = test.CheckPanic(func() { a.CSelect(0, b, c) })
   263  	test.CheckIsErr(t, err, "shouldn't fail with 0")
   264  	err = test.CheckPanic(func() { a.CSelect(1, b, c) })
   265  	test.CheckIsErr(t, err, "shouldn't fail with 1")
   266  	err = test.CheckPanic(func() { a.CSelect(2, b, c) })
   267  	test.CheckNoErr(t, err, "should fail with dif 0,1")
   268  
   269  	for i := 0; i < testTimes; i++ {
   270  		a = g.RandomScalar(rand.Reader)
   271  		b = g.RandomScalar(rand.Reader)
   272  		c.Add(a, b)
   273  		d.Sub(a, b)
   274  		e.Mul(c, d)
   275  		e.Add(e, one)
   276  
   277  		c.Mul(a, a)
   278  		d.Mul(b, b)
   279  		d.Neg(d)
   280  		f.Add(c, d)
   281  		f.Add(f, one)
   282  		enc1, err1 := e.MarshalBinary()
   283  		enc2, err2 := f.MarshalBinary()
   284  		if err1 != nil || err2 != nil || !bytes.Equal(enc1, enc2) {
   285  			test.ReportError(t, enc1, enc2, a, b)
   286  		}
   287  		if l := uint(len(enc1)); l != params.ScalarLength {
   288  			test.ReportError(t, l, params.ScalarLength)
   289  		}
   290  
   291  		want := c.Copy()
   292  		got := c.CMov(0, a)
   293  		if !got.IsEqual(want) {
   294  			test.ReportError(t, got, want)
   295  		}
   296  
   297  		want = b.Copy()
   298  		got = d.CMov(1, b)
   299  		if !got.IsEqual(want) {
   300  			test.ReportError(t, got, want)
   301  		}
   302  
   303  		want = b.Copy()
   304  		got = e.CSelect(0, a, b)
   305  		if !got.IsEqual(want) {
   306  			test.ReportError(t, got, want)
   307  		}
   308  
   309  		want = a.Copy()
   310  		got = f.CSelect(1, a, b)
   311  		if !got.IsEqual(want) {
   312  			test.ReportError(t, got, want)
   313  		}
   314  	}
   315  
   316  	c.Inv(a)
   317  	c.Mul(c, a)
   318  	c.Sub(c, one)
   319  	if !c.IsZero() {
   320  		test.ReportError(t, c, one, a)
   321  	}
   322  }
   323  
   324  func BenchmarkElement(b *testing.B) {
   325  	for _, g := range allGroups {
   326  		x := g.RandomElement(rand.Reader)
   327  		y := g.RandomElement(rand.Reader)
   328  		n := g.RandomScalar(rand.Reader)
   329  		name := g.(fmt.Stringer).String()
   330  		b.Run(name+"/Add", func(b *testing.B) {
   331  			for i := 0; i < b.N; i++ {
   332  				x.Add(x, y)
   333  			}
   334  		})
   335  		b.Run(name+"/Dbl", func(b *testing.B) {
   336  			for i := 0; i < b.N; i++ {
   337  				x.Dbl(x)
   338  			}
   339  		})
   340  		b.Run(name+"/Mul", func(b *testing.B) {
   341  			for i := 0; i < b.N; i++ {
   342  				y.Mul(x, n)
   343  			}
   344  		})
   345  		b.Run(name+"/MulGen", func(b *testing.B) {
   346  			for i := 0; i < b.N; i++ {
   347  				x.MulGen(n)
   348  			}
   349  		})
   350  	}
   351  }
   352  
   353  func BenchmarkScalar(b *testing.B) {
   354  	for _, g := range allGroups {
   355  		x := g.RandomScalar(rand.Reader)
   356  		y := g.RandomScalar(rand.Reader)
   357  		name := g.(fmt.Stringer).String()
   358  		b.Run(name+"/Add", func(b *testing.B) {
   359  			for i := 0; i < b.N; i++ {
   360  				x.Add(x, y)
   361  			}
   362  		})
   363  		b.Run(name+"/Mul", func(b *testing.B) {
   364  			for i := 0; i < b.N; i++ {
   365  				x.Mul(x, y)
   366  			}
   367  		})
   368  		b.Run(name+"/Inv", func(b *testing.B) {
   369  			for i := 0; i < b.N; i++ {
   370  				y.Inv(x)
   371  			}
   372  		})
   373  	}
   374  }