github.com/consensys/gnark-crypto@v0.14.0/ecc/bls12-377/fr/sis/sis_test.go (about)

     1  // Copyright 2023 ConsenSys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sis
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"fmt"
    23  	"io"
    24  	"math/big"
    25  	"math/bits"
    26  	"os"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/bits-and-blooms/bitset"
    31  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
    32  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
    33  	"github.com/stretchr/testify/assert"
    34  	"github.com/stretchr/testify/require"
    35  )
    36  
    37  type sisParams struct {
    38  	logTwoBound, logTwoDegree int
    39  }
    40  
    41  var params128Bits []sisParams = []sisParams{
    42  	{logTwoBound: 2, logTwoDegree: 3},
    43  	{logTwoBound: 4, logTwoDegree: 4},
    44  	{logTwoBound: 6, logTwoDegree: 5},
    45  	{logTwoBound: 8, logTwoDegree: 6},
    46  	{logTwoBound: 10, logTwoDegree: 6},
    47  	{logTwoBound: 16, logTwoDegree: 7},
    48  	{logTwoBound: 32, logTwoDegree: 8},
    49  }
    50  
    51  type TestCases struct {
    52  	Inputs  [][]fr.Element `json:"inputs"`
    53  	Entries []struct {
    54  		Params struct {
    55  			Seed                int64 `json:"seed"`
    56  			LogTwoDegree        int   `json:"logTwoDegree"`
    57  			LogTwoBound         int   `json:"logTwoBound"`
    58  			MaxNbElementsToHash int   `json:"maxNbElementsToHash"`
    59  		} `json:"params"`
    60  		Expected [][]fr.Element `json:"expected"`
    61  	} `json:"entries"`
    62  }
    63  
    64  func TestReference(t *testing.T) {
    65  	if bits.UintSize == 32 {
    66  		t.Skip("skipping this test in 32bit.")
    67  	}
    68  	assert := require.New(t)
    69  
    70  	// read the test case file
    71  	var testCases TestCases
    72  	data, err := os.ReadFile("test_cases.json")
    73  	assert.NoError(err, "reading test cases failed")
    74  	err = json.Unmarshal(data, &testCases)
    75  	assert.NoError(err, "reading test cases failed")
    76  
    77  	for testCaseID, testCase := range testCases.Entries {
    78  		// create the SIS instance
    79  		sis, err := NewRSis(testCase.Params.Seed, testCase.Params.LogTwoDegree, testCase.Params.LogTwoBound, testCase.Params.MaxNbElementsToHash)
    80  		assert.NoError(err)
    81  
    82  		// key generation same than in sage
    83  		makeKeyDeterministic(t, sis, testCase.Params.Seed)
    84  
    85  		for i, in := range testCases.Inputs {
    86  			sis.Reset()
    87  
    88  			// hash test case entry input and compare with expected (computed by sage)
    89  			got, err := sis.Hash(in)
    90  			assert.NoError(err)
    91  			if len(testCase.Expected[i]) == 0 {
    92  				for _, e := range got {
    93  					assert.True(e.IsZero(), "mismatch between reference test and computed value")
    94  				}
    95  			} else {
    96  				assert.EqualValues(
    97  					testCase.Expected[i], got,
    98  					"mismatch between reference test and computed value (testcase %v - input n° %v)",
    99  					testCaseID, i,
   100  				)
   101  			}
   102  
   103  			// ensure max nb elements to hash has no incidence on result.
   104  			if len(in) < testCase.Params.MaxNbElementsToHash {
   105  				sis2, err := NewRSis(testCase.Params.Seed, testCase.Params.LogTwoDegree, testCase.Params.LogTwoBound, len(in))
   106  				assert.NoError(err)
   107  				makeKeyDeterministic(t, sis2, testCase.Params.Seed)
   108  
   109  				got2, err := sis2.Hash(in)
   110  				assert.NoError(err)
   111  				if len(testCase.Expected[i]) == 0 {
   112  					for _, e := range got2 {
   113  						assert.True(e.IsZero(), "mismatch between reference test and computed value")
   114  					}
   115  				} else {
   116  					assert.EqualValues(got, got2, "max nb elements to hash change SIS result")
   117  				}
   118  			}
   119  
   120  		}
   121  	}
   122  
   123  }
   124  
   125  func TestMulMod(t *testing.T) {
   126  
   127  	size := 4
   128  
   129  	p := make([]fr.Element, size)
   130  	p[0].SetString("2389")
   131  	p[1].SetString("987192")
   132  	p[2].SetString("623")
   133  	p[3].SetString("91")
   134  
   135  	q := make([]fr.Element, size)
   136  	q[0].SetString("76755")
   137  	q[1].SetString("232893720")
   138  	q[2].SetString("989273")
   139  	q[3].SetString("675273")
   140  
   141  	// creation of the domain. Note that the value of the shift is purely arbitrary
   142  	// and random.
   143  	var shift fr.Element
   144  	shift.SetString("19540430494807482326159819597004422086093766032135589407132600596362845576832")
   145  	domain := fft.NewDomain(uint64(size), fft.WithShift(shift))
   146  
   147  	// mul mod
   148  	domain.FFT(p, fft.DIF, fft.OnCoset())
   149  	domain.FFT(q, fft.DIF, fft.OnCoset())
   150  	r := mulMod(p, q)
   151  	domain.FFTInverse(r, fft.DIT, fft.OnCoset())
   152  
   153  	// expected result
   154  	expectedr := make([]fr.Element, 4)
   155  	expectedr[0].SetString("1612335717510792699655803640368030375030102334758014597273992900671209760644")
   156  	expectedr[1].SetString("7801939531701554412773487946871330804032103116849274522170062191886789772982")
   157  	expectedr[2].SetString("3958508915483304256780311126107405810447277056055916093890488600778971951103")
   158  	expectedr[3].SetString("1123315390878")
   159  
   160  	for i := 0; i < 4; i++ {
   161  		assert.Equal(t, expectedr[i].String(), r[i].String())
   162  	}
   163  }
   164  
   165  // Test the fact that the limb decomposition allows obtaining the original
   166  // field element by evaluating the polynomial whose the coeffiients are the
   167  // limbs.
   168  func TestLimbDecomposition(t *testing.T) {
   169  
   170  	// Skipping the test for 32 bits
   171  	if bits.UintSize == 32 {
   172  		t.Skip("skipping this test in 32bit.")
   173  	}
   174  
   175  	testcases := []struct {
   176  		logTwoDegree, logTwoBound int
   177  		vec                       fr.Vector
   178  	}{
   179  		{
   180  			logTwoDegree: 4,
   181  			logTwoBound:  4,
   182  			vec:          fr.Vector{fr.One()},
   183  		},
   184  		{
   185  			logTwoDegree: 4,
   186  			logTwoBound:  4,
   187  			vec:          fr.Vector{fr.NewElement(2)},
   188  		},
   189  		{
   190  			logTwoDegree: 4,
   191  			logTwoBound:  4,
   192  			vec:          fr.Vector{fr.NewElement(1 << 32), fr.NewElement(2), fr.NewElement(1)},
   193  		},
   194  		{
   195  			logTwoDegree: 4,
   196  			logTwoBound:  16,
   197  			vec:          fr.Vector{fr.One()},
   198  		},
   199  		{
   200  			logTwoDegree: 4,
   201  			logTwoBound:  16,
   202  			vec:          fr.Vector{fr.NewElement(2)},
   203  		},
   204  		{
   205  			logTwoDegree: 4,
   206  			logTwoBound:  16,
   207  			vec:          fr.Vector{fr.NewElement(1 << 32), fr.NewElement(2), fr.NewElement(1)},
   208  		},
   209  	}
   210  
   211  	for i, testcase := range testcases {
   212  
   213  		t.Run(fmt.Sprintf("testcase-%v", i), func(t *testing.T) {
   214  
   215  			t.Logf("testcase %v", testcase)
   216  
   217  			sis, _ := NewRSis(0, testcase.logTwoDegree, testcase.logTwoBound, 3)
   218  
   219  			// clean the sis hasher
   220  			sis.bufMValues.ClearAll()
   221  			for i := 0; i < len(sis.bufM); i++ {
   222  				sis.bufM[i].SetZero()
   223  			}
   224  			for i := 0; i < len(sis.bufRes); i++ {
   225  				sis.bufRes[i].SetZero()
   226  			}
   227  
   228  			buf := bytes.Buffer{}
   229  			for _, x := range testcase.vec {
   230  				xBytes := x.Bytes()
   231  				buf.Write(xBytes[:])
   232  			}
   233  
   234  			limbDecomposeBytes(buf.Bytes(), sis.bufM, sis.LogTwoBound, sis.Degree, sis.bufMValues)
   235  
   236  			// Just to test, this does not return panic
   237  			dummyBuffer := make(fr.Vector, 192)
   238  			LimbDecomposeBytes(buf.Bytes(), dummyBuffer, sis.LogTwoBound)
   239  
   240  			// b is a field element representing the max norm bound
   241  			// used for limb splitting the input field elements.
   242  			b := fr.NewElement(1 << sis.LogTwoBound)
   243  			numLimbsPerField := fr.Bytes * 8 / sis.LogTwoBound
   244  
   245  			// Compute r (corresponds to the Montgommery constant)
   246  			var r fr.Element
   247  			r.SetString("6014086494747379908336260804527802945383293308637734276299549080986809532403")
   248  
   249  			// Attempt to recompose the entry #i in the test-case
   250  			for i := range testcase.vec {
   251  				// allegedly corresponds to the limbs of the entry i
   252  				subRes := sis.bufM[i*numLimbsPerField : (i+1)*numLimbsPerField]
   253  
   254  				// performs a Horner evaluation of subres by b
   255  				var y fr.Element
   256  				for j := numLimbsPerField - 1; j >= 0; j-- {
   257  					y.Mul(&y, &b)
   258  					y.Add(&y, &subRes[j])
   259  				}
   260  
   261  				y.Mul(&y, &r)
   262  				require.Equal(t, testcase.vec[i].String(), y.String(), "the subRes was %v", subRes)
   263  			}
   264  		})
   265  
   266  	}
   267  }
   268  
   269  func makeKeyDeterministic(t *testing.T, sis *RSis, _seed int64) {
   270  	t.Helper()
   271  	// generate the key deterministically, the same way
   272  	// we do in sage to generate the test vectors.
   273  
   274  	polyRand := func(seed fr.Element, deg int) []fr.Element {
   275  		res := make([]fr.Element, deg)
   276  		for i := 0; i < deg; i++ {
   277  			res[i].Square(&seed)
   278  			seed.Set(&res[i])
   279  		}
   280  		return res
   281  	}
   282  
   283  	var seed, one fr.Element
   284  	one.SetOne()
   285  	seed.SetInt64(_seed)
   286  	for i := 0; i < len(sis.A); i++ {
   287  		sis.A[i] = polyRand(seed, sis.Degree)
   288  		copy(sis.Ag[i], sis.A[i])
   289  		sis.Domain.FFT(sis.Ag[i], fft.DIF, fft.OnCoset())
   290  		seed.Add(&seed, &one)
   291  	}
   292  }
   293  
   294  const (
   295  	LATENCY_MUL_FIELD_NS int = 18
   296  	LATENCY_ADD_FIELD_NS int = 4
   297  )
   298  
   299  // Estimate the theoretical performances that are achievable using ring-SIS
   300  // operations. The time is obtained by counting the number of additions and
   301  // multiplications occurring in the computation. This does not account for the
   302  // possibilities to use SIMD instructions or for cache-locality issues. Thus, it
   303  // does not represents a maximum even though it returns a good idea of what is
   304  // achievable . This returns performances in term of ns/field. This also does not
   305  // account for the time taken for "limb-splitting" the input.
   306  func estimateSisTheory(p sisParams) float64 {
   307  
   308  	// Since the FFT occurs over a coset, we need to multiply all the coefficients
   309  	// of the input by some coset factors (for an entire polynomial)
   310  	timeCosetShift := (1 << p.logTwoDegree) * LATENCY_MUL_FIELD_NS
   311  
   312  	// The two additions are from the butterfly, and the multiplication represents
   313  	// the one by the twiddle. (for an entire polynomial)
   314  	timeFFT := (1 << p.logTwoDegree) * p.logTwoDegree * (2*LATENCY_ADD_FIELD_NS + LATENCY_MUL_FIELD_NS)
   315  
   316  	// Time taken to multiply by the key and accumulate (for an entire polynomial)
   317  	timeMulAddKey := (1 << p.logTwoDegree) * (LATENCY_MUL_FIELD_NS + LATENCY_ADD_FIELD_NS)
   318  
   319  	// Total computation time for an entire polynomial
   320  	totalTimePoly := timeCosetShift + timeFFT + timeMulAddKey
   321  
   322  	// Convert this into a time per input field
   323  	r := totalTimePoly * fr.Bits / p.logTwoBound / (1 << p.logTwoDegree)
   324  	return float64(r)
   325  }
   326  
   327  func BenchmarkSIS(b *testing.B) {
   328  
   329  	// max nb field elements to hash
   330  	const nbInputs = 1 << 16
   331  
   332  	// Assign the input with random bytes. In practice, theses bytes encodes
   333  	// a string of field element. It would be more meaningful to take a slice
   334  	// of field element directly because otherwise the conversion time is not
   335  	// accounted for in the benchmark.
   336  	inputs := make(fr.Vector, nbInputs)
   337  	for i := 0; i < len(inputs); i++ {
   338  		inputs[i].SetRandom()
   339  	}
   340  
   341  	for _, param := range params128Bits {
   342  		for n := 1 << 10; n <= nbInputs; n <<= 1 {
   343  			in := inputs[:n]
   344  			benchmarkSIS(b, in, false, param.logTwoBound, param.logTwoDegree, estimateSisTheory(param))
   345  		}
   346  
   347  	}
   348  }
   349  
   350  func benchmarkSIS(b *testing.B, input []fr.Element, sparse bool, logTwoBound, logTwoDegree int, theoretical float64) {
   351  	b.Helper()
   352  
   353  	n := len(input)
   354  
   355  	benchName := "ring-sis/"
   356  	if sparse {
   357  		benchName += "sparse/"
   358  	}
   359  	benchName += fmt.Sprintf("inputs=%v/log2-bound=%v/log2-degree=%v", n, logTwoBound, logTwoDegree)
   360  
   361  	b.Run(benchName, func(b *testing.B) {
   362  		instance, err := NewRSis(0, logTwoDegree, logTwoBound, n)
   363  		if err != nil {
   364  			b.Fatal(err)
   365  		}
   366  
   367  		// We introduce a custom metric which is the time per field element
   368  		// Since the benchmark object allows to report extra meta but does
   369  		// not allow accessing them. We measure the time ourself.
   370  
   371  		startTime := time.Now()
   372  		b.ResetTimer()
   373  		for i := 0; i < b.N; i++ {
   374  			_, err = instance.Hash(input)
   375  			if err != nil {
   376  				b.Fatal(err)
   377  			}
   378  		}
   379  		b.StopTimer()
   380  
   381  		totalDuration := time.Since(startTime)
   382  		nsPerField := totalDuration.Nanoseconds() / int64(b.N) / int64(n)
   383  
   384  		b.ReportMetric(float64(nsPerField), "ns/field")
   385  
   386  		b.ReportMetric(theoretical, "ns/field(theory)")
   387  
   388  	})
   389  }
   390  
   391  // Hash interprets the input vector as a sequence of coefficients of size r.LogTwoBound bits long,
   392  // and return the hash of the polynomial corresponding to the sum sum_i A[i]*m Mod X^{d}+1
   393  //
   394  // It is equivalent to calling r.Write(element.Marshal()); outBytes = r.Sum(nil);
   395  // ! note @gbotrel: this is a place holder, may not make sense
   396  func (r *RSis) Hash(v []fr.Element) ([]fr.Element, error) {
   397  	if len(v) > r.maxNbElementsToHash {
   398  		return nil, fmt.Errorf("can't hash more than %d elements with params provided in constructor", r.maxNbElementsToHash)
   399  	}
   400  
   401  	r.Reset()
   402  	for _, e := range v {
   403  		r.Write(e.Marshal())
   404  	}
   405  	sum := r.Sum(nil)
   406  	var rlen [4]byte
   407  	binary.BigEndian.PutUint32(rlen[:], uint32(len(sum)/fr.Bytes))
   408  	reader := io.MultiReader(bytes.NewReader(rlen[:]), bytes.NewReader(sum))
   409  	var result fr.Vector
   410  	_, err := result.ReadFrom(reader)
   411  	if err != nil {
   412  		return nil, err
   413  	}
   414  	return result, nil
   415  }
   416  
   417  func TestLimbDecompositionFastPath(t *testing.T) {
   418  	assert := require.New(t)
   419  
   420  	for size := fr.Bytes; size < 5*fr.Bytes; size += fr.Bytes {
   421  		// Test the fast path of limbDecomposeBytes8_64
   422  		buf := make([]byte, size)
   423  		m := make([]fr.Element, size)
   424  		mValues := bitset.New(uint(size))
   425  		n := make([]fr.Element, size)
   426  		nValues := bitset.New(uint(size))
   427  
   428  		// Generate a random buffer
   429  		_, err := rand.Read(buf) //#nosec G404 weak rng is fine here
   430  		assert.NoError(err)
   431  
   432  		limbDecomposeBytes8_64(buf, m, mValues)
   433  		limbDecomposeBytes(buf, n, 8, 64, nValues)
   434  
   435  		for i := 0; i < size; i++ {
   436  			assert.Equal(mValues.Test(uint(i)), nValues.Test(uint(i)))
   437  			assert.True(m[i].Equal(&n[i]))
   438  		}
   439  	}
   440  
   441  }
   442  
   443  func TestUnrolledFFT(t *testing.T) {
   444  
   445  	var shift fr.Element
   446  	shift.SetString("8065159656716812877374967518403273466521432693661810619979959746626482506078") // -> 2^47-th root of unity of bls12377
   447  	e := int64(1 << (47 - (6 + 1)))
   448  	shift.Exp(shift, big.NewInt(e))
   449  
   450  	const size = 64
   451  	assert := require.New(t)
   452  	domain := fft.NewDomain(size, fft.WithShift(shift))
   453  
   454  	k1 := make([]fr.Element, size)
   455  	for i := 0; i < size; i++ {
   456  		k1[i].SetRandom()
   457  	}
   458  	k2 := make([]fr.Element, size)
   459  	copy(k2, k1)
   460  
   461  	// default FFT
   462  	domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1))
   463  
   464  	// unrolled FFT
   465  	twiddlesCoset := PrecomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen)
   466  	FFT64(k2, twiddlesCoset)
   467  
   468  	// compare results
   469  	for i := 0; i < size; i++ {
   470  		// fmt.Printf("i = %d, k1 = %v, k2 = %v\n", i, k1[i].String(), k2[i].String())
   471  		assert.True(k1[i].Equal(&k2[i]), "i = %d", i)
   472  	}
   473  }