github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/template/gkr.test.go.tmpl (about)

     1  
     2  import (
     3  	"{{.FieldPackagePath}}"
     4  	"{{.FieldPackagePath}}/mimc"
     5  	"{{.FieldPackagePath}}/polynomial"
     6  	"{{.FieldPackagePath}}/sumcheck"
     7  	"{{.FieldPackagePath}}/test_vector_utils"
     8  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
     9  	"github.com/consensys/gnark-crypto/utils"
    10  	"github.com/stretchr/testify/assert"
    11  	"fmt"
    12  	"hash"
    13  	"os"
    14  	"strconv"
    15  	"testing"
    16  	"path/filepath"
    17  	"encoding/json"
    18  	"reflect"
    19  	"time"
    20  )
    21  
    22  {{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}}
    23  {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}}
    24  
    25  func TestNoGateTwoInstances(t *testing.T) {
    26  	// Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case
    27  	testNoGate(t, []{{.ElementType}}{four, three})
    28  }
    29  
    30  func TestNoGate(t *testing.T) {
    31  	testManyInstances(t, 1, testNoGate)
    32  }
    33  
    34  func TestSingleAddGateTwoInstances(t *testing.T) {
    35  	testSingleAddGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three})
    36  }
    37  
    38  func TestSingleAddGate(t *testing.T) {
    39  	testManyInstances(t, 2, testSingleAddGate)
    40  }
    41  
    42  func TestSingleMulGateTwoInstances(t *testing.T) {
    43  	testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three})
    44  }
    45  
    46  func TestSingleMulGate(t *testing.T) {
    47  	testManyInstances(t, 2, testSingleMulGate)
    48  }
    49  
    50  func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) {
    51  
    52  	testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three})
    53  }
    54  
    55  func TestSingleInputTwoIdentityGates(t *testing.T) {
    56  
    57  	testManyInstances(t, 2, testSingleInputTwoIdentityGates)
    58  }
    59  
    60  func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) {
    61  	testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one})
    62  }
    63  
    64  func TestSingleInputTwoIdentityGatesComposed(t *testing.T) {
    65  	testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed)
    66  }
    67  
    68  func TestSingleMimcCipherGateTwoInstances(t *testing.T) {
    69  	testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two})
    70  }
    71  
    72  func TestSingleMimcCipherGate(t *testing.T) {
    73  	testManyInstances(t, 2, testSingleMimcCipherGate)
    74  }
    75  
    76  func TestATimesBSquaredTwoInstances(t *testing.T) {
    77  	testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two})
    78  }
    79  
    80  func TestShallowMimcTwoInstances(t *testing.T) {
    81  	testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two})
    82  }
    83  
    84  {{- if $GenerateLargeTests}}
    85  func TestMimcTwoInstances(t *testing.T) {
    86  	testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two})
    87  }
    88  
    89  func TestMimc(t *testing.T) {
    90  	testManyInstances(t, 2, generateTestMimc(93))
    91  }
    92  
    93  func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) {
    94  	return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
    95  		testMimc(t, numRounds, inputAssignments...)
    96  	}
    97  }
    98  
    99  {{- end}}
   100  
   101  func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) {
   102  	circuit := Circuit{ Wire{
   103  		Gate:            IdentityGate{},
   104  		Inputs:          []*Wire{},
   105  		nbUniqueOutputs: 2,
   106  	} }
   107  
   108  	wire := &circuit[0]
   109  
   110  	assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}}
   111  	var o settings
   112  	pool := polynomial.NewPool(256, 1<<11)
   113  	workers := utils.NewWorkerPool()
   114  	o.pool = &pool
   115  	o.workers = workers
   116  
   117  	claimsManagerGen := func() *claimsManager {
   118  		manager := newClaimsManager(circuit, assignment, o)
   119  		manager.add(wire, []{{.ElementType}}{three}, five)
   120  		manager.add(wire, []{{.ElementType}}{four}, six)
   121  		return &manager
   122  	}
   123  
   124  	transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1)
   125  
   126  	proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil))
   127  	assert.NoError(t, err)
   128  	err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil))
   129  	assert.NoError(t, err)
   130  }
   131  
   132  var one, two, three, four, five, six {{.ElementType}}
   133  
   134  func init() {
   135  	one.SetOne()
   136  	two.Double(&one)
   137  	three.Add(&two, &one)
   138  	four.Double(&two)
   139  	five.Add(&three, &two)
   140  	six.Double(&three)
   141  }
   142  
   143  var testManyInstancesLogMaxInstances = -1
   144  
   145  func getLogMaxInstances(t *testing.T) int {
   146  	if testManyInstancesLogMaxInstances == -1 {
   147  
   148  		s := os.Getenv("GKR_LOG_INSTANCES")
   149  		if s == "" {
   150  			testManyInstancesLogMaxInstances = 5
   151  		} else {
   152  			var err error
   153  			testManyInstancesLogMaxInstances, err = strconv.Atoi(s)
   154  			if err != nil {
   155  				t.Error(err)
   156  			}
   157  		}
   158  
   159  	}
   160  	return testManyInstancesLogMaxInstances
   161  }
   162  
   163  func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) {
   164  	fullAssignments := make([][]{{.ElementType}}, numInput)
   165  	maxSize := 1 << getLogMaxInstances(t)
   166  
   167  	t.Log("Entered test orchestrator, assigning and randomizing inputs")
   168  
   169  	for i := range fullAssignments {
   170  		fullAssignments[i] = make([]fr.Element, maxSize)
   171  		setRandom(fullAssignments[i])
   172  	}
   173  
   174  	inputAssignments := make([][]{{.ElementType}}, numInput)
   175  	for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 {
   176  		for i, fullAssignment := range fullAssignments {
   177  			inputAssignments[i] = fullAssignment[:numEvals]
   178  		}
   179  
   180  		t.Log("Selected inputs for test")
   181  		test(t, inputAssignments...)
   182  	}
   183  }
   184  
   185  func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   186  	c := Circuit{
   187  		{
   188  			Inputs:     []*Wire{},
   189  			Gate:       nil,
   190  		},
   191  	}
   192  
   193  	assignment := WireAssignment{&c[0]: inputAssignments[0]}
   194  
   195  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   196  	assert.NoError(t, err)
   197  
   198  	// Even though a hash is called here, the proof is empty
   199  
   200  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   201  	assert.NoError(t, err, "proof rejected")
   202  }
   203  
   204  func testSingleAddGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   205  	c := make(Circuit, 3)
   206  	c[2] = Wire{
   207  		Gate: Gates["add"],
   208  		Inputs: []*Wire{&c[0], &c[1]},
   209  	}
   210  
   211  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   212  
   213  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   214  	assert.NoError(t,err)
   215  
   216  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   217  	assert.NoError(t, err, "proof rejected")
   218  
   219  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   220  	assert.NotNil(t, err, "bad proof accepted")
   221  }
   222  
   223  func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   224  
   225  	c := make(Circuit, 3)
   226  	c[2] = Wire{
   227  		Gate:   Gates["mul"],
   228  		Inputs: []*Wire{&c[0], &c[1]},
   229  	}
   230  
   231  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   232  
   233  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   234  	assert.NoError(t, err)
   235  
   236  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   237  	assert.NoError(t, err, "proof rejected")
   238  
   239  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   240  	assert.NotNil(t, err, "bad proof accepted")
   241  }
   242  
   243  func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   244  	c := make(Circuit, 3)
   245  
   246  	c[1] = Wire{
   247  		Gate:   IdentityGate{},
   248  		Inputs: []*Wire{&c[0]},
   249  	}
   250  
   251  	c[2] = Wire{
   252  		Gate:   IdentityGate{},
   253  		Inputs: []*Wire{&c[0]},
   254  	}
   255  
   256  	assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c)
   257  
   258  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   259  	assert.NoError(t, err)
   260  
   261  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   262  	assert.NoError(t, err, "proof rejected")
   263  
   264  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   265  	assert.NotNil(t, err, "bad proof accepted")
   266  }
   267  
   268  func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   269  	c := make(Circuit, 3)
   270  
   271  	c[2] = Wire{
   272  		Gate:   mimcCipherGate{},
   273  		Inputs: []*Wire{&c[0], &c[1]},
   274  	}
   275  
   276  	t.Log("Evaluating all circuit wires")
   277  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   278  	t.Log("Circuit evaluation complete")
   279  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   280  	assert.NoError(t, err)
   281  	t.Log("Proof complete")
   282  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   283  	assert.NoError(t, err, "proof rejected")
   284  
   285  	t.Log("Successful verification complete")
   286  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   287  	assert.NotNil(t, err, "bad proof accepted")
   288  	t.Log("Unsuccessful verification complete")
   289  }
   290  
   291  func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) {
   292  	c := make(Circuit, 3)
   293  
   294  	c[1] = Wire{
   295  		Gate:   IdentityGate{},
   296  		Inputs: []*Wire{&c[0]},
   297  	}
   298  	c[2] = Wire{
   299  		Gate:   IdentityGate{},
   300  		Inputs: []*Wire{&c[1]},
   301  	}
   302  
   303  	assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c)
   304  
   305  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   306  	assert.NoError(t, err)
   307  
   308  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   309  	assert.NoError(t, err, "proof rejected")
   310  
   311  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   312  	assert.NotNil(t, err, "bad proof accepted")
   313  }
   314  
   315  func mimcCircuit(numRounds int) Circuit {
   316  	c := make(Circuit, numRounds+2)
   317  
   318  	for i := 2; i < len(c); i++ {
   319  		c[i] = Wire{
   320  			Gate:   mimcCipherGate{},
   321  			Inputs: []*Wire{&c[i-1], &c[0]},
   322  		}
   323  	}
   324  	return c
   325  }
   326  
   327  func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) {
   328  	//TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b)
   329  	// @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10
   330  
   331  	c := mimcCircuit(numRounds)
   332  
   333  	t.Log("Evaluating all circuit wires")
   334  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   335  	t.Log("Circuit evaluation complete")
   336  
   337  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   338  	assert.NoError(t, err)
   339  
   340  	t.Log("Proof finished")
   341  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   342  	assert.NoError(t, err, "proof rejected")
   343  
   344  	t.Log("Successful verification finished")
   345  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   346  	assert.NotNil(t, err, "bad proof accepted")
   347  	t.Log("Unsuccessful verification finished")
   348  }
   349  
   350  func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) {
   351  	// This imitates the MiMC circuit
   352  
   353  	c := make(Circuit, numRounds+2)
   354  
   355  	for i := 2; i < len(c); i++ {
   356  		c[i] = Wire{
   357  			Gate:   Gates["mul"],
   358  			Inputs: []*Wire{&c[i-1], &c[0]},
   359  		}
   360  	}
   361  
   362  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   363  
   364  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   365  	assert.NoError(t, err)
   366  
   367  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   368  	assert.NoError(t, err, "proof rejected")
   369  
   370  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   371  	assert.NotNil(t, err, "bad proof accepted")
   372  }
   373  
   374  func setRandom(slice []{{.ElementType}}) {
   375  	for i := range slice {
   376  		slice[i].SetRandom()
   377  	}
   378  }
   379  
   380  func generateTestProver(path string) func(t *testing.T) {
   381  	return func(t *testing.T) {
   382  		testCase, err := newTestCase(path)
   383  		assert.NoError(t, err)
   384  		proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash))
   385  		assert.NoError(t, err)
   386  		assert.NoError(t, proofEquals(testCase.Proof, proof))
   387  	}
   388  }
   389  
   390  func generateTestVerifier(path string) func(t *testing.T) {
   391  	return func(t *testing.T) {
   392  		testCase, err := newTestCase(path)
   393  		assert.NoError(t, err)
   394  		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash))
   395  		assert.NoError(t, err, "proof rejected")
   396  		testCase, err = newTestCase(path)
   397  		assert.NoError(t, err)
   398  		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0)))
   399  		assert.NotNil(t, err, "bad proof accepted")
   400  	}
   401  }
   402  
   403  func TestGkrVectors(t *testing.T) {
   404  
   405  	testDirPath := "{{.TestVectorsRelativePath}}"
   406  	dirEntries, err := os.ReadDir(testDirPath)
   407  	assert.NoError(t, err)
   408  	for _, dirEntry := range dirEntries {
   409  		if !dirEntry.IsDir() {
   410  
   411  			if filepath.Ext(dirEntry.Name()) == ".json" {
   412  				path := filepath.Join(testDirPath, dirEntry.Name())
   413  				noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")]
   414  
   415  				t.Run(noExt+"_prover", generateTestProver(path))
   416  				t.Run(noExt+"_verifier", generateTestVerifier(path))
   417  
   418  			}
   419  		}
   420  	}
   421  }
   422  
   423  func proofEquals(expected Proof, seen Proof) error {
   424  	if len(expected) != len(seen) {
   425  		return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen))
   426  	}
   427  	for i, x := range expected {
   428  		xSeen := seen[i]
   429  
   430  		if xSeen.FinalEvalProof == nil {
   431  			if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 {
   432  				return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval))
   433  			}
   434  		} else {
   435  			if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil {
   436  				return fmt.Errorf("final evaluation proof mismatch")
   437  			}
   438  		}
   439  		if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil {
   440  			return err
   441  		}
   442  	}
   443  	return nil
   444  }
   445  
   446  func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) {
   447  	fmt.Println("creating circuit structure")
   448  	c := mimcCircuit(mimcDepth)
   449  
   450  	in0 := make([]fr.Element, nbInstances)
   451  	in1 := make([]fr.Element, nbInstances)
   452  	setRandom(in0)
   453  	setRandom(in1)
   454  
   455  	fmt.Println("evaluating circuit")
   456  	start := time.Now().UnixMicro()
   457  	assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c)
   458  	solved := time.Now().UnixMicro() - start
   459  	fmt.Println("solved in", solved, "μs")
   460  
   461  	//b.ResetTimer()
   462  	fmt.Println("constructing proof")
   463  	start = time.Now().UnixMicro()
   464  	_, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC()))
   465  	proved := time.Now().UnixMicro() - start
   466  	fmt.Println("proved in", proved, "μs")
   467  	assert.NoError(b, err)
   468  }
   469  
   470  func BenchmarkGkrMimc19(b *testing.B) {
   471  	benchmarkGkrMiMC(b, 1<<19, 91)
   472  }
   473  
   474  func BenchmarkGkrMimc17(b *testing.B) {
   475  	benchmarkGkrMiMC(b, 1<<17, 91)
   476  }
   477  
   478  func TestTopSortTrivial(t *testing.T) {
   479  	c := make(Circuit, 2)
   480  	c[0].Inputs = []*Wire{&c[1]}
   481  	sorted := {{$topologicalSort}}(c)
   482  	assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted)
   483  }
   484  
   485  func TestTopSortDeep(t *testing.T) {
   486  	c := make(Circuit, 4)
   487  	c[0].Inputs = []*Wire{&c[2]}
   488  	c[1].Inputs = []*Wire{&c[3]}
   489  	c[2].Inputs = []*Wire{}
   490  	c[3].Inputs = []*Wire{&c[0]}
   491  	sorted := {{$topologicalSort}}(c)
   492  	assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted)
   493  }
   494  
   495  func TestTopSortWide(t *testing.T) {
   496  	c := make(Circuit, 10)
   497  	c[0].Inputs = []*Wire{&c[3], &c[8]}
   498  	c[1].Inputs = []*Wire{&c[6]}
   499  	c[2].Inputs = []*Wire{&c[4]}
   500  	c[3].Inputs = []*Wire{}
   501  	c[4].Inputs = []*Wire{}
   502  	c[5].Inputs = []*Wire{&c[9]}
   503  	c[6].Inputs = []*Wire{&c[9]}
   504  	c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]}
   505  	c[8].Inputs = []*Wire{&c[4], &c[3]}
   506  	c[9].Inputs = []*Wire{}
   507  
   508  	sorted := {{$topologicalSort}}(c)
   509  	sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]}
   510  
   511  	assert.Equal(t, sortedExpected, sorted)
   512  }
   513  
   514  {{template "gkrTestVectors" .}}