github.com/cloudflare/circl@v1.5.0/math/mlsbset/mlsbset_test.go (about)

     1  package mlsbset_test
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"math/big"
     7  	"testing"
     8  
     9  	"github.com/cloudflare/circl/internal/conv"
    10  	"github.com/cloudflare/circl/internal/test"
    11  	"github.com/cloudflare/circl/math/mlsbset"
    12  )
    13  
    14  func TestExp(t *testing.T) {
    15  	T := uint(126)
    16  	for v := uint(1); v <= 5; v++ {
    17  		for w := uint(2); w <= 5; w++ {
    18  			m, err := mlsbset.New(T, v, w)
    19  			if err != nil {
    20  				test.ReportError(t, err, nil)
    21  			}
    22  			testExp(t, m)
    23  		}
    24  	}
    25  }
    26  
    27  func testExp(t *testing.T, m mlsbset.Encoder) {
    28  	const testTimes = 1 << 8
    29  	params := m.GetParams()
    30  	TBytes := (params.T + 7) / 8
    31  	topBits := (byte(1) << (params.T % 8)) - 1
    32  	k := make([]byte, TBytes)
    33  	for i := 0; i < testTimes; i++ {
    34  		_, _ = rand.Read(k)
    35  		k[0] |= 1
    36  		k[TBytes-1] &= topBits
    37  
    38  		c, err := m.Encode(k)
    39  		if err != nil {
    40  			test.ReportError(t, err, nil)
    41  		}
    42  
    43  		g := zzAdd{m.GetParams()}
    44  		a := c.Exp(g)
    45  
    46  		got := a.(*big.Int)
    47  		want := conv.BytesLe2BigInt(k)
    48  		if got.Cmp(want) != 0 {
    49  			test.ReportError(t, got, want, m)
    50  		}
    51  	}
    52  }
    53  
    54  type zzAdd struct{ set mlsbset.Params }
    55  
    56  func (zzAdd) Identity() mlsbset.EltG { return big.NewInt(0) }
    57  func (zzAdd) NewEltP() mlsbset.EltP  { return new(big.Int) }
    58  func (zzAdd) Sqr(x mlsbset.EltG) {
    59  	a := x.(*big.Int)
    60  	a.Add(a, a)
    61  }
    62  
    63  func (zzAdd) Mul(x mlsbset.EltG, y mlsbset.EltP) {
    64  	a := x.(*big.Int)
    65  	b := y.(*big.Int)
    66  	a.Add(a, b)
    67  }
    68  
    69  func (z zzAdd) ExtendedEltP() mlsbset.EltP {
    70  	a := big.NewInt(1)
    71  	a.Lsh(a, z.set.W*z.set.D)
    72  	return a
    73  }
    74  
    75  func (z zzAdd) Lookup(x mlsbset.EltP, idTable uint, sgnElt int32, idElt int32) {
    76  	a := x.(*big.Int)
    77  	a.SetInt64(1)
    78  	a.Lsh(a, z.set.E*idTable) // 2^(e*v)
    79  	sum := big.NewInt(0)
    80  	for i := int(z.set.W - 2); i >= 0; i-- {
    81  		ui := big.NewInt(int64((idElt >> uint(i)) & 0x1))
    82  		sum.Add(sum, ui)
    83  		sum.Lsh(sum, z.set.D)
    84  	}
    85  	sum.Add(sum, big.NewInt(1))
    86  	a.Mul(a, sum)
    87  	if sgnElt == -1 {
    88  		a.Neg(a)
    89  	}
    90  }
    91  
    92  func TestEncodeErr(t *testing.T) {
    93  	t.Run("mArgs", func(t *testing.T) {
    94  		_, got := mlsbset.New(0, 0, 0)
    95  		want := errors.New("t>1, v>=1, w>=2")
    96  		if got.Error() != want.Error() {
    97  			test.ReportError(t, got, want)
    98  		}
    99  	})
   100  	t.Run("kOdd", func(t *testing.T) {
   101  		m, _ := mlsbset.New(16, 2, 2)
   102  		k := make([]byte, 2)
   103  		_, got := m.Encode(k)
   104  		want := errors.New("k must be odd")
   105  		if got.Error() != want.Error() {
   106  			test.ReportError(t, got, want)
   107  		}
   108  	})
   109  	t.Run("kBig", func(t *testing.T) {
   110  		m, _ := mlsbset.New(16, 2, 2)
   111  		k := make([]byte, 4)
   112  		_, got := m.Encode(k)
   113  		want := errors.New("k too big")
   114  		if got.Error() != want.Error() {
   115  			test.ReportError(t, got, want)
   116  		}
   117  	})
   118  	t.Run("kEmpty", func(t *testing.T) {
   119  		m, _ := mlsbset.New(16, 2, 2)
   120  		k := []byte{}
   121  		_, got := m.Encode(k)
   122  		want := errors.New("empty slice")
   123  		if got.Error() != want.Error() {
   124  			test.ReportError(t, got, want)
   125  		}
   126  	})
   127  }
   128  
   129  func BenchmarkEncode(b *testing.B) {
   130  	t, v, w := uint(256), uint(2), uint(3)
   131  	m, _ := mlsbset.New(t, v, w)
   132  	params := m.GetParams()
   133  	TBytes := (params.T + 7) / 8
   134  	topBits := (byte(1) << (params.T % 8)) - 1
   135  
   136  	k := make([]byte, TBytes)
   137  	_, _ = rand.Read(k)
   138  	k[0] |= 1
   139  	k[TBytes-1] &= topBits
   140  
   141  	c, _ := m.Encode(k)
   142  	g := zzAdd{params}
   143  
   144  	b.Run("Encode", func(b *testing.B) {
   145  		for i := 0; i < b.N; i++ {
   146  			_, _ = m.Encode(k)
   147  		}
   148  	})
   149  	b.Run("Exp", func(b *testing.B) {
   150  		for i := 0; i < b.N; i++ {
   151  			c.Exp(g)
   152  		}
   153  	})
   154  }