github.com/cloudflare/circl@v1.5.0/ecc/bls12381/g1_test.go (about)

     1  package bls12381
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/cloudflare/circl/ecc/bls12381/ff"
     9  	"github.com/cloudflare/circl/internal/test"
    10  )
    11  
    12  func randomScalar(t testing.TB) *Scalar {
    13  	s := &Scalar{}
    14  	err := s.Random(rand.Reader)
    15  	test.CheckNoErr(t, err, "random scalar")
    16  	return s
    17  }
    18  
    19  func randomG1(t testing.TB) *G1 {
    20  	P := &G1{}
    21  	u := &ff.Fp{}
    22  	r := &isogG1Point{}
    23  
    24  	err := u.Random(rand.Reader)
    25  	test.CheckNoErr(t, err, "random fp")
    26  
    27  	r.sswu(u)
    28  	P.evalIsogG1(r)
    29  	P.clearCofactor()
    30  	got := P.IsOnG1()
    31  	want := true
    32  
    33  	if got != want {
    34  		test.ReportError(t, got, want, "point not in G1", u)
    35  	}
    36  	return P
    37  }
    38  
    39  func TestG1Add(t *testing.T) {
    40  	const testTimes = 1 << 6
    41  	var Q, R G1
    42  	for i := 0; i < testTimes; i++ {
    43  		P := randomG1(t)
    44  		Q = *P
    45  		R = *P
    46  		R.Add(&R, &R)
    47  		R.Neg()
    48  		Q.Double()
    49  		Q.Neg()
    50  		got := R
    51  		want := Q
    52  		if !got.IsEqual(&want) {
    53  			test.ReportError(t, got, want, P)
    54  		}
    55  	}
    56  }
    57  
    58  func TestG1ScalarMult(t *testing.T) {
    59  	const testTimes = 1 << 6
    60  	var Q G1
    61  	for i := 0; i < testTimes; i++ {
    62  		P := randomG1(t)
    63  		k := randomScalar(t)
    64  		Q.ScalarMult(k, P)
    65  		Q.toAffine()
    66  		got := Q.IsOnG1()
    67  		want := true
    68  		if got != want {
    69  			test.ReportError(t, got, want, P, k)
    70  		}
    71  	}
    72  }
    73  
    74  func TestG1Hash(t *testing.T) {
    75  	const testTimes = 1 << 8
    76  
    77  	for _, e := range [...]struct {
    78  		Name string
    79  		Enc  func(p *G1, input, dst []byte)
    80  	}{
    81  		{"Encode", func(p *G1, input, dst []byte) { p.Encode(input, dst) }},
    82  		{"Hash", func(p *G1, input, dst []byte) { p.Hash(input, dst) }},
    83  	} {
    84  		var msg, dst [4]byte
    85  		var p G1
    86  		t.Run(e.Name, func(t *testing.T) {
    87  			for i := 0; i < testTimes; i++ {
    88  				_, _ = rand.Read(msg[:])
    89  				_, _ = rand.Read(dst[:])
    90  				e.Enc(&p, msg[:], dst[:])
    91  
    92  				got := p.isRTorsion()
    93  				want := true
    94  				if got != want {
    95  					test.ReportError(t, got, want, e.Name, msg, dst)
    96  				}
    97  			}
    98  		})
    99  	}
   100  }
   101  
   102  func BenchmarkG1(b *testing.B) {
   103  	P := randomG1(b)
   104  	Q := randomG1(b)
   105  	k := randomScalar(b)
   106  	var msg, dst [4]byte
   107  	_, _ = rand.Read(msg[:])
   108  	_, _ = rand.Read(dst[:])
   109  
   110  	b.Run("Add", func(b *testing.B) {
   111  		for i := 0; i < b.N; i++ {
   112  			P.Add(P, Q)
   113  		}
   114  	})
   115  	b.Run("Mul", func(b *testing.B) {
   116  		for i := 0; i < b.N; i++ {
   117  			P.ScalarMult(k, P)
   118  		}
   119  	})
   120  	b.Run("Hash", func(b *testing.B) {
   121  		for i := 0; i < b.N; i++ {
   122  			P.Hash(msg[:], dst[:])
   123  		}
   124  	})
   125  }
   126  
   127  func TestG1Serial(t *testing.T) {
   128  	mustOk := "must be ok"
   129  	mustErr := "must be an error"
   130  	t.Run("valid", func(t *testing.T) {
   131  		testTimes := 1 << 6
   132  		var got, want G1
   133  		want.SetIdentity()
   134  		for i := 0; i < testTimes; i++ {
   135  			for _, b := range [][]byte{want.Bytes(), want.BytesCompressed()} {
   136  				err := got.SetBytes(b)
   137  				test.CheckNoErr(t, err, fmt.Sprintf("failure to deserialize: (P:%v b:%x)", want, b))
   138  
   139  				if !got.IsEqual(&want) {
   140  					test.ReportError(t, got, want, b)
   141  				}
   142  			}
   143  			want = *randomG1(t)
   144  		}
   145  	})
   146  	t.Run("badPrefix", func(t *testing.T) {
   147  		q := new(G1)
   148  		b := make([]byte, G1Size)
   149  		for _, b[0] = range []byte{0x20, 0x60, 0xE0} {
   150  			test.CheckIsErr(t, q.SetBytes(b), mustErr)
   151  		}
   152  	})
   153  	t.Run("badLength", func(t *testing.T) {
   154  		q := new(G1)
   155  		p := randomG1(t)
   156  		b := p.Bytes()
   157  		test.CheckIsErr(t, q.SetBytes(b[:0]), mustErr)
   158  		test.CheckIsErr(t, q.SetBytes(b[:1]), mustErr)
   159  		test.CheckIsErr(t, q.SetBytes(b[:G1Size-1]), mustErr)
   160  		test.CheckIsErr(t, q.SetBytes(b[:G1SizeCompressed]), mustErr)
   161  		test.CheckNoErr(t, q.SetBytes(b), mustOk)
   162  		test.CheckNoErr(t, q.SetBytes(append(b, 0)), mustOk)
   163  		b = p.BytesCompressed()
   164  		test.CheckIsErr(t, q.SetBytes(b[:0]), mustErr)
   165  		test.CheckIsErr(t, q.SetBytes(b[:1]), mustErr)
   166  		test.CheckIsErr(t, q.SetBytes(b[:G1SizeCompressed-1]), mustErr)
   167  		test.CheckNoErr(t, q.SetBytes(b), mustOk)
   168  		test.CheckNoErr(t, q.SetBytes(append(b, 0)), mustOk)
   169  	})
   170  	t.Run("badInfinity", func(t *testing.T) {
   171  		var badInf, p G1
   172  		badInf.SetIdentity()
   173  		b := badInf.Bytes()
   174  		b[0] |= 0x1F
   175  		err := p.SetBytes(b)
   176  		test.CheckIsErr(t, err, mustErr)
   177  		b[0] &= 0xE0
   178  		b[1] = 0xFF
   179  		err = p.SetBytes(b)
   180  		test.CheckIsErr(t, err, mustErr)
   181  	})
   182  	t.Run("badCoords", func(t *testing.T) {
   183  		bad := (&[ff.FpSize]byte{})[:]
   184  		for i := range bad {
   185  			bad[i] = 0xFF
   186  		}
   187  		var e ff.Fp
   188  		_ = e.Random(rand.Reader)
   189  		good, err := e.MarshalBinary()
   190  		test.CheckNoErr(t, err, mustOk)
   191  
   192  		// bad x, good y
   193  		b := append(bad, good...)
   194  		b[0] = b[0]&0x1F | headerEncoding(0, 0, 0)
   195  		test.CheckIsErr(t, new(G1).SetBytes(b), mustErr)
   196  
   197  		// good x, bad y
   198  		b = append(good, bad...)
   199  		b[0] = b[0]&0x1F | headerEncoding(0, 0, 0)
   200  		test.CheckIsErr(t, new(G1).SetBytes(b), mustErr)
   201  	})
   202  	t.Run("noQR", func(t *testing.T) {
   203  		var x ff.Fp
   204  		x.SetUint64(1) // Let x=1, so x^3+4 = 5, which is not QR.
   205  		b, err := x.MarshalBinary()
   206  		test.CheckNoErr(t, err, mustOk)
   207  		b[0] = b[0]&0x1F | headerEncoding(1, 0, 0)
   208  		test.CheckIsErr(t, new(G1).SetBytes(b), mustErr)
   209  	})
   210  	t.Run("notInG1", func(t *testing.T) {
   211  		// p=(0,1) is not on curve.
   212  		var x, y ff.Fp
   213  		y.SetUint64(1)
   214  		bx, err := x.MarshalBinary()
   215  		test.CheckNoErr(t, err, mustOk)
   216  		by, err := y.MarshalBinary()
   217  		test.CheckNoErr(t, err, mustOk)
   218  		b := append(bx, by...)
   219  		b[0] = b[0]&0x1F | headerEncoding(0, 0, 0)
   220  		test.CheckIsErr(t, new(G1).SetBytes(b), mustErr)
   221  	})
   222  }
   223  
   224  func TestG1Affinize(t *testing.T) {
   225  	N := 20
   226  	testTimes := 1 << 6
   227  	g1 := make([]*G1, N)
   228  	for i := 0; i < testTimes; i++ {
   229  		for j := 0; j < N; j++ {
   230  			g1[j] = randomG1(t)
   231  		}
   232  		g2 := affinize(g1)
   233  		for j := 0; j < N; j++ {
   234  			g1[j].toAffine()
   235  			if !g1[j].IsEqual(&g2[j]) {
   236  				t.Fatal("failure to preserve points")
   237  			}
   238  			if g2[j].z.IsEqual(&g1[j].z) != 1 {
   239  				t.Fatal("failure to make affine")
   240  			}
   241  		}
   242  	}
   243  }
   244  
   245  func TestG1Torsion(t *testing.T) {
   246  	if !G1Generator().isRTorsion() {
   247  		t.Fatalf("G1 generator is not r-torsion")
   248  	}
   249  }
   250  
   251  func TestG1Bytes(t *testing.T) {
   252  	got := new(G1)
   253  	id := new(G1)
   254  	id.SetIdentity()
   255  	g := G1Generator()
   256  	minusG := G1Generator()
   257  	minusG.Neg()
   258  
   259  	type testCase struct {
   260  		header  byte
   261  		length  int
   262  		point   *G1
   263  		toBytes func(G1) []byte
   264  	}
   265  
   266  	for i, v := range []testCase{
   267  		{headerEncoding(0, 0, 0), G1Size, randomG1(t), (G1).Bytes},
   268  		{headerEncoding(0, 0, 0), G1Size, g, (G1).Bytes},
   269  		{headerEncoding(1, 0, 0), G1SizeCompressed, g, (G1).BytesCompressed},
   270  		{headerEncoding(1, 0, 1), G1SizeCompressed, minusG, (G1).BytesCompressed},
   271  		{headerEncoding(0, 1, 0), G1Size, id, (G1).Bytes},
   272  		{headerEncoding(1, 1, 0), G1SizeCompressed, id, (G1).BytesCompressed},
   273  	} {
   274  		b := v.toBytes(*v.point)
   275  		test.CheckOk(len(b) == v.length, fmt.Sprintf("bad encoding size (case:%v point:%v b:%x)", i, v.point, b), t)
   276  		test.CheckOk(b[0]&0xE0 == v.header, fmt.Sprintf("bad encoding header (case:%v point:%v b:%x)", i, v.point, b), t)
   277  
   278  		err := got.SetBytes(b)
   279  		want := v.point
   280  		if err != nil || !got.IsEqual(want) {
   281  			test.ReportError(t, got, want, i, b)
   282  		}
   283  	}
   284  }