github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/fr/sis/sis_test.go (about)

     1  // Copyright 2023 ConsenSys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sis
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"encoding/json"
    22  	"fmt"
    23  	"io"
    24  	"math/big"
    25  	"math/bits"
    26  	"os"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/bits-and-blooms/bitset"
    31  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    32  	"github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  type sisParams struct {
    37  	logTwoBound, logTwoDegree int
    38  }
    39  
    40  var params128Bits []sisParams = []sisParams{
    41  	{logTwoBound: 2, logTwoDegree: 3},
    42  	{logTwoBound: 4, logTwoDegree: 4},
    43  	{logTwoBound: 6, logTwoDegree: 5},
    44  	{logTwoBound: 8, logTwoDegree: 6},
    45  	{logTwoBound: 10, logTwoDegree: 6},
    46  	{logTwoBound: 16, logTwoDegree: 7},
    47  	{logTwoBound: 32, logTwoDegree: 8},
    48  }
    49  
    50  type TestCases struct {
    51  	Inputs  [][]fr.Element `json:"inputs"`
    52  	Entries []struct {
    53  		Params struct {
    54  			Seed                int64 `json:"seed"`
    55  			LogTwoDegree        int   `json:"logTwoDegree"`
    56  			LogTwoBound         int   `json:"logTwoBound"`
    57  			MaxNbElementsToHash int   `json:"maxNbElementsToHash"`
    58  		} `json:"params"`
    59  		Expected [][]fr.Element `json:"expected"`
    60  	} `json:"entries"`
    61  }
    62  
    63  func TestReference(t *testing.T) {
    64  	if bits.UintSize == 32 {
    65  		t.Skip("skipping this test in 32bit.")
    66  	}
    67  	assert := require.New(t)
    68  
    69  	// read the test case file
    70  	var testCases TestCases
    71  	data, err := os.ReadFile("test_cases.json")
    72  	assert.NoError(err, "reading test cases failed")
    73  	err = json.Unmarshal(data, &testCases)
    74  	assert.NoError(err, "reading test cases failed")
    75  
    76  	for testCaseID, testCase := range testCases.Entries {
    77  		// create the SIS instance
    78  		sis, err := NewRSis(testCase.Params.Seed, testCase.Params.LogTwoDegree, testCase.Params.LogTwoBound, testCase.Params.MaxNbElementsToHash)
    79  		assert.NoError(err)
    80  
    81  		// key generation same than in sage
    82  		makeKeyDeterministic(t, sis, testCase.Params.Seed)
    83  
    84  		for i, in := range testCases.Inputs {
    85  			sis.Reset()
    86  
    87  			// hash test case entry input and compare with expected (computed by sage)
    88  			got, err := sis.Hash(in)
    89  			assert.NoError(err)
    90  			if len(testCase.Expected[i]) == 0 {
    91  				for _, e := range got {
    92  					assert.True(e.IsZero(), "mismatch between reference test and computed value")
    93  				}
    94  			} else {
    95  				assert.EqualValues(
    96  					testCase.Expected[i], got,
    97  					"mismatch between reference test and computed value (testcase %v - input n° %v)",
    98  					testCaseID, i,
    99  				)
   100  			}
   101  
   102  			// ensure max nb elements to hash has no incidence on result.
   103  			if len(in) < testCase.Params.MaxNbElementsToHash {
   104  				sis2, err := NewRSis(testCase.Params.Seed, testCase.Params.LogTwoDegree, testCase.Params.LogTwoBound, len(in))
   105  				assert.NoError(err)
   106  				makeKeyDeterministic(t, sis2, testCase.Params.Seed)
   107  
   108  				got2, err := sis2.Hash(in)
   109  				assert.NoError(err)
   110  				if len(testCase.Expected[i]) == 0 {
   111  					for _, e := range got2 {
   112  						assert.True(e.IsZero(), "mismatch between reference test and computed value")
   113  					}
   114  				} else {
   115  					assert.EqualValues(got, got2, "max nb elements to hash change SIS result")
   116  				}
   117  			}
   118  
   119  		}
   120  	}
   121  
   122  }
   123  
   124  func TestMulMod(t *testing.T) {
   125  
   126  	size := 4
   127  
   128  	p := make([]fr.Element, size)
   129  	p[0].SetString("2389")
   130  	p[1].SetString("987192")
   131  	p[2].SetString("623")
   132  	p[3].SetString("91")
   133  
   134  	q := make([]fr.Element, size)
   135  	q[0].SetString("76755")
   136  	q[1].SetString("232893720")
   137  	q[2].SetString("989273")
   138  	q[3].SetString("675273")
   139  
   140  	// creation of the domain
   141  	var shift fr.Element
   142  	shift.SetString("19540430494807482326159819597004422086093766032135589407132600596362845576832")
   143  	domain := fft.NewDomain(uint64(size), fft.WithShift(shift))
   144  
   145  	// mul mod
   146  	domain.FFT(p, fft.DIF, fft.OnCoset())
   147  	domain.FFT(q, fft.DIF, fft.OnCoset())
   148  	r := mulMod(p, q)
   149  	domain.FFTInverse(r, fft.DIT, fft.OnCoset())
   150  
   151  	// expected result
   152  	expectedr := make([]fr.Element, 4)
   153  	expectedr[0].SetString("21888242871839275222246405745257275088548364400416034343698204185887558114297")
   154  	expectedr[1].SetString("631644300118")
   155  	expectedr[2].SetString("229913166975959")
   156  	expectedr[3].SetString("1123315390878")
   157  
   158  	for i := 0; i < 4; i++ {
   159  		if !expectedr[i].Equal(&r[i]) {
   160  			t.Fatal("product failed")
   161  		}
   162  	}
   163  
   164  }
   165  
   166  // Test the fact that the limb decomposition allows obtaining the original
   167  // field element by evaluating the polynomial whose the coeffiients are the
   168  // limbs.
   169  func TestLimbDecomposition(t *testing.T) {
   170  
   171  	// Skipping the test for 32 bits
   172  	if bits.UintSize == 32 {
   173  		t.Skip("skipping this test in 32bit.")
   174  	}
   175  
   176  	sis, _ := NewRSis(0, 4, 4, 3)
   177  
   178  	testcases := []fr.Vector{
   179  		{fr.One()},
   180  		{fr.NewElement(2)},
   181  		{fr.NewElement(1 << 32), fr.NewElement(2), fr.NewElement(1)},
   182  	}
   183  
   184  	for _, testcase := range testcases {
   185  
   186  		// clean the sis hasher
   187  		sis.bufMValues.ClearAll()
   188  		for i := 0; i < len(sis.bufM); i++ {
   189  			sis.bufM[i].SetZero()
   190  		}
   191  		for i := 0; i < len(sis.bufRes); i++ {
   192  			sis.bufRes[i].SetZero()
   193  		}
   194  
   195  		buf := bytes.Buffer{}
   196  		for _, x := range testcase {
   197  			xBytes := x.Bytes()
   198  			buf.Write(xBytes[:])
   199  		}
   200  		limbDecomposeBytes(buf.Bytes(), sis.bufM, sis.LogTwoBound, sis.Degree, sis.bufMValues)
   201  
   202  		// Just to test, this does not return panic
   203  		dummyBuffer := make(fr.Vector, 192)
   204  		LimbDecomposeBytes(buf.Bytes(), dummyBuffer, sis.LogTwoBound)
   205  
   206  		// b is a field element representing the max norm bound
   207  		// used for limb splitting the input field elements.
   208  		b := fr.NewElement(1 << sis.LogTwoBound)
   209  		numLimbsPerField := fr.Bytes * 8 / sis.LogTwoBound
   210  
   211  		// Compute r (corresponds to the Montgommery constant)
   212  		var r fr.Element
   213  		r.SetString("6350874878119819312338956282401532410528162663560392320966563075034087161851")
   214  
   215  		// Attempt to recompose the entry #i in the test-case
   216  		for i := range testcase {
   217  			// allegedly corresponds to the limbs of the entry i
   218  			subRes := sis.bufM[i*numLimbsPerField : (i+1)*numLimbsPerField]
   219  
   220  			// performs a Horner evaluation of subres by b
   221  			var y fr.Element
   222  			for j := numLimbsPerField - 1; j >= 0; j-- {
   223  				y.Mul(&y, &b)
   224  				y.Add(&y, &subRes[j])
   225  			}
   226  
   227  			y.Mul(&y, &r)
   228  			require.Equal(t, testcase[i].String(), y.String(), "the subRes was %v", subRes)
   229  		}
   230  	}
   231  }
   232  
   233  func makeKeyDeterministic(t *testing.T, sis *RSis, _seed int64) {
   234  	t.Helper()
   235  	// generate the key deterministically, the same way
   236  	// we do in sage to generate the test vectors.
   237  
   238  	polyRand := func(seed fr.Element, deg int) []fr.Element {
   239  		res := make([]fr.Element, deg)
   240  		for i := 0; i < deg; i++ {
   241  			res[i].Square(&seed)
   242  			seed.Set(&res[i])
   243  		}
   244  		return res
   245  	}
   246  
   247  	var seed, one fr.Element
   248  	one.SetOne()
   249  	seed.SetInt64(_seed)
   250  	for i := 0; i < len(sis.A); i++ {
   251  		sis.A[i] = polyRand(seed, sis.Degree)
   252  		copy(sis.Ag[i], sis.A[i])
   253  		sis.Domain.FFT(sis.Ag[i], fft.DIF, fft.OnCoset())
   254  		seed.Add(&seed, &one)
   255  	}
   256  }
   257  
   258  const (
   259  	LATENCY_MUL_FIELD_NS int = 18
   260  	LATENCY_ADD_FIELD_NS int = 4
   261  )
   262  
   263  // Estimate the theoretical performances that are achievable using ring-SIS
   264  // operations. The time is obtained by counting the number of additions and
   265  // multiplications occurring in the computation. This does not account for the
   266  // possibilities to use SIMD instructions or for cache-locality issues. Thus, it
   267  // does not represents a maximum even though it returns a good idea of what is
   268  // achievable . This returns performances in term of ns/field. This also does not
   269  // account for the time taken for "limb-splitting" the input.
   270  func estimateSisTheory(p sisParams) float64 {
   271  
   272  	// Since the FFT occurs over a coset, we need to multiply all the coefficients
   273  	// of the input by some coset factors (for an entire polynomial)
   274  	timeCosetShift := (1 << p.logTwoDegree) * LATENCY_MUL_FIELD_NS
   275  
   276  	// The two additions are from the butterfly, and the multiplication represents
   277  	// the one by the twiddle. (for an entire polynomial)
   278  	timeFFT := (1 << p.logTwoDegree) * p.logTwoDegree * (2*LATENCY_ADD_FIELD_NS + LATENCY_MUL_FIELD_NS)
   279  
   280  	// Time taken to multiply by the key and accumulate (for an entire polynomial)
   281  	timeMulAddKey := (1 << p.logTwoDegree) * (LATENCY_MUL_FIELD_NS + LATENCY_ADD_FIELD_NS)
   282  
   283  	// Total computation time for an entire polynomial
   284  	totalTimePoly := timeCosetShift + timeFFT + timeMulAddKey
   285  
   286  	// Convert this into a time per input field
   287  	r := totalTimePoly * fr.Bits / p.logTwoBound / (1 << p.logTwoDegree)
   288  	return float64(r)
   289  }
   290  
   291  func BenchmarkSIS(b *testing.B) {
   292  
   293  	// max nb field elements to hash
   294  	const nbInputs = 1 << 16
   295  
   296  	// Assign the input with random bytes. In practice, theses bytes encodes
   297  	// a string of field element. It would be more meaningful to take a slice
   298  	// of field element directly because otherwise the conversion time is not
   299  	// accounted for in the benchmark.
   300  	inputs := make(fr.Vector, nbInputs)
   301  	for i := 0; i < len(inputs); i++ {
   302  		inputs[i].SetRandom()
   303  	}
   304  
   305  	for _, param := range params128Bits {
   306  		for n := 1 << 10; n <= nbInputs; n <<= 1 {
   307  			in := inputs[:n]
   308  			benchmarkSIS(b, in, false, param.logTwoBound, param.logTwoDegree, estimateSisTheory(param))
   309  		}
   310  
   311  	}
   312  }
   313  
   314  func benchmarkSIS(b *testing.B, input []fr.Element, sparse bool, logTwoBound, logTwoDegree int, theoretical float64) {
   315  	b.Helper()
   316  
   317  	n := len(input)
   318  
   319  	benchName := "ring-sis/"
   320  	if sparse {
   321  		benchName += "sparse/"
   322  	}
   323  	benchName += fmt.Sprintf("inputs=%v/log2-bound=%v/log2-degree=%v", n, logTwoBound, logTwoDegree)
   324  
   325  	b.Run(benchName, func(b *testing.B) {
   326  		instance, err := NewRSis(0, logTwoDegree, logTwoBound, n)
   327  		if err != nil {
   328  			b.Fatal(err)
   329  		}
   330  
   331  		// We introduce a custom metric which is the time per field element
   332  		// Since the benchmark object allows to report extra meta but does
   333  		// not allow accessing them. We measure the time ourself.
   334  
   335  		startTime := time.Now()
   336  		b.ResetTimer()
   337  		for i := 0; i < b.N; i++ {
   338  			_, err = instance.Hash(input)
   339  			if err != nil {
   340  				b.Fatal(err)
   341  			}
   342  		}
   343  		b.StopTimer()
   344  
   345  		totalDuration := time.Since(startTime)
   346  		nsPerField := totalDuration.Nanoseconds() / int64(b.N) / int64(n)
   347  
   348  		b.ReportMetric(float64(nsPerField), "ns/field")
   349  
   350  		b.ReportMetric(theoretical, "ns/field(theory)")
   351  
   352  	})
   353  }
   354  
   355  // Hash interprets the input vector as a sequence of coefficients of size r.LogTwoBound bits long,
   356  // and return the hash of the polynomial corresponding to the sum sum_i A[i]*m Mod X^{d}+1
   357  //
   358  // It is equivalent to calling r.Write(element.Marshal()); outBytes = r.Sum(nil);
   359  // ! note @gbotrel: this is a place holder, may not make sense
   360  func (r *RSis) Hash(v []fr.Element) ([]fr.Element, error) {
   361  	if len(v) > r.maxNbElementsToHash {
   362  		return nil, fmt.Errorf("can't hash more than %d elements with params provided in constructor", r.maxNbElementsToHash)
   363  	}
   364  
   365  	r.Reset()
   366  	for _, e := range v {
   367  		r.Write(e.Marshal())
   368  	}
   369  	sum := r.Sum(nil)
   370  	var rlen [4]byte
   371  	binary.BigEndian.PutUint32(rlen[:], uint32(len(sum)/fr.Bytes))
   372  	reader := io.MultiReader(bytes.NewReader(rlen[:]), bytes.NewReader(sum))
   373  	var result fr.Vector
   374  	_, err := result.ReadFrom(reader)
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  	return result, nil
   379  }
   380  
   381  func TestLimbDecompositionFastPath(t *testing.T) {
   382  	assert := require.New(t)
   383  
   384  	for size := fr.Bytes; size < 5*fr.Bytes; size += fr.Bytes {
   385  		// Test the fast path of limbDecomposeBytes8_64
   386  		buf := make([]byte, size)
   387  		m := make([]fr.Element, size)
   388  		mValues := bitset.New(uint(size))
   389  		n := make([]fr.Element, size)
   390  		nValues := bitset.New(uint(size))
   391  
   392  		// Generate a random buffer
   393  		_, err := rand.Read(buf)
   394  		assert.NoError(err)
   395  
   396  		limbDecomposeBytes8_64(buf, m, mValues)
   397  		limbDecomposeBytes(buf, n, 8, 64, nValues)
   398  
   399  		for i := 0; i < size; i++ {
   400  			assert.Equal(mValues.Test(uint(i)), nValues.Test(uint(i)))
   401  			assert.True(m[i].Equal(&n[i]))
   402  		}
   403  	}
   404  
   405  }
   406  
   407  func TestUnrolledFFT(t *testing.T) {
   408  
   409  	var shift fr.Element
   410  	shift.SetString("19103219067921713944291392827692070036145651957329286315305642004821462161904") // -> 2²⁸-th root of unity of bn254
   411  	e := int64(1 << (28 - (6 + 1)))
   412  	shift.Exp(shift, big.NewInt(e))
   413  
   414  	const size = 64
   415  	assert := require.New(t)
   416  	domain := fft.NewDomain(size, fft.WithShift(shift))
   417  
   418  	k1 := make([]fr.Element, size)
   419  	for i := 0; i < size; i++ {
   420  		k1[i].SetRandom()
   421  	}
   422  	k2 := make([]fr.Element, size)
   423  	copy(k2, k1)
   424  
   425  	// default FFT
   426  	domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1))
   427  
   428  	// unrolled FFT
   429  	twiddlesCoset := PrecomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen)
   430  	FFT64(k2, twiddlesCoset)
   431  
   432  	// compare results
   433  	for i := 0; i < size; i++ {
   434  		// fmt.Printf("i = %d, k1 = %v, k2 = %v\n", i, k1[i].String(), k2[i].String())
   435  		assert.True(k1[i].Equal(&k2[i]), "i = %d", i)
   436  	}
   437  }