github.com/consensys/gnark-crypto@v0.14.0/ecc/bls12-377/fr/gkr/gkr_test.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package gkr
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
    23  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc"
    24  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial"
    25  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck"
    26  	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils"
    27  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
    28  	"github.com/consensys/gnark-crypto/utils"
    29  	"github.com/stretchr/testify/assert"
    30  	"hash"
    31  	"os"
    32  	"path/filepath"
    33  	"reflect"
    34  	"strconv"
    35  	"testing"
    36  	"time"
    37  )
    38  
    39  func TestNoGateTwoInstances(t *testing.T) {
    40  	// Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case
    41  	testNoGate(t, []fr.Element{four, three})
    42  }
    43  
    44  func TestNoGate(t *testing.T) {
    45  	testManyInstances(t, 1, testNoGate)
    46  }
    47  
    48  func TestSingleAddGateTwoInstances(t *testing.T) {
    49  	testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three})
    50  }
    51  
    52  func TestSingleAddGate(t *testing.T) {
    53  	testManyInstances(t, 2, testSingleAddGate)
    54  }
    55  
    56  func TestSingleMulGateTwoInstances(t *testing.T) {
    57  	testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three})
    58  }
    59  
    60  func TestSingleMulGate(t *testing.T) {
    61  	testManyInstances(t, 2, testSingleMulGate)
    62  }
    63  
    64  func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) {
    65  
    66  	testSingleInputTwoIdentityGates(t, []fr.Element{two, three})
    67  }
    68  
    69  func TestSingleInputTwoIdentityGates(t *testing.T) {
    70  
    71  	testManyInstances(t, 2, testSingleInputTwoIdentityGates)
    72  }
    73  
    74  func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) {
    75  	testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one})
    76  }
    77  
    78  func TestSingleInputTwoIdentityGatesComposed(t *testing.T) {
    79  	testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed)
    80  }
    81  
    82  func TestSingleMimcCipherGateTwoInstances(t *testing.T) {
    83  	testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two})
    84  }
    85  
    86  func TestSingleMimcCipherGate(t *testing.T) {
    87  	testManyInstances(t, 2, testSingleMimcCipherGate)
    88  }
    89  
    90  func TestATimesBSquaredTwoInstances(t *testing.T) {
    91  	testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two})
    92  }
    93  
    94  func TestShallowMimcTwoInstances(t *testing.T) {
    95  	testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two})
    96  }
    97  func TestMimcTwoInstances(t *testing.T) {
    98  	testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two})
    99  }
   100  
   101  func TestMimc(t *testing.T) {
   102  	testManyInstances(t, 2, generateTestMimc(93))
   103  }
   104  
   105  func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) {
   106  	return func(t *testing.T, inputAssignments ...[]fr.Element) {
   107  		testMimc(t, numRounds, inputAssignments...)
   108  	}
   109  }
   110  
   111  func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) {
   112  	circuit := Circuit{Wire{
   113  		Gate:            IdentityGate{},
   114  		Inputs:          []*Wire{},
   115  		nbUniqueOutputs: 2,
   116  	}}
   117  
   118  	wire := &circuit[0]
   119  
   120  	assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}}
   121  	var o settings
   122  	pool := polynomial.NewPool(256, 1<<11)
   123  	workers := utils.NewWorkerPool()
   124  	o.pool = &pool
   125  	o.workers = workers
   126  
   127  	claimsManagerGen := func() *claimsManager {
   128  		manager := newClaimsManager(circuit, assignment, o)
   129  		manager.add(wire, []fr.Element{three}, five)
   130  		manager.add(wire, []fr.Element{four}, six)
   131  		return &manager
   132  	}
   133  
   134  	transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1)
   135  
   136  	proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil))
   137  	assert.NoError(t, err)
   138  	err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil))
   139  	assert.NoError(t, err)
   140  }
   141  
   142  var one, two, three, four, five, six fr.Element
   143  
   144  func init() {
   145  	one.SetOne()
   146  	two.Double(&one)
   147  	three.Add(&two, &one)
   148  	four.Double(&two)
   149  	five.Add(&three, &two)
   150  	six.Double(&three)
   151  }
   152  
   153  var testManyInstancesLogMaxInstances = -1
   154  
   155  func getLogMaxInstances(t *testing.T) int {
   156  	if testManyInstancesLogMaxInstances == -1 {
   157  
   158  		s := os.Getenv("GKR_LOG_INSTANCES")
   159  		if s == "" {
   160  			testManyInstancesLogMaxInstances = 5
   161  		} else {
   162  			var err error
   163  			testManyInstancesLogMaxInstances, err = strconv.Atoi(s)
   164  			if err != nil {
   165  				t.Error(err)
   166  			}
   167  		}
   168  
   169  	}
   170  	return testManyInstancesLogMaxInstances
   171  }
   172  
   173  func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) {
   174  	fullAssignments := make([][]fr.Element, numInput)
   175  	maxSize := 1 << getLogMaxInstances(t)
   176  
   177  	t.Log("Entered test orchestrator, assigning and randomizing inputs")
   178  
   179  	for i := range fullAssignments {
   180  		fullAssignments[i] = make([]fr.Element, maxSize)
   181  		setRandom(fullAssignments[i])
   182  	}
   183  
   184  	inputAssignments := make([][]fr.Element, numInput)
   185  	for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 {
   186  		for i, fullAssignment := range fullAssignments {
   187  			inputAssignments[i] = fullAssignment[:numEvals]
   188  		}
   189  
   190  		t.Log("Selected inputs for test")
   191  		test(t, inputAssignments...)
   192  	}
   193  }
   194  
   195  func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) {
   196  	c := Circuit{
   197  		{
   198  			Inputs: []*Wire{},
   199  			Gate:   nil,
   200  		},
   201  	}
   202  
   203  	assignment := WireAssignment{&c[0]: inputAssignments[0]}
   204  
   205  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   206  	assert.NoError(t, err)
   207  
   208  	// Even though a hash is called here, the proof is empty
   209  
   210  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   211  	assert.NoError(t, err, "proof rejected")
   212  }
   213  
   214  func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) {
   215  	c := make(Circuit, 3)
   216  	c[2] = Wire{
   217  		Gate:   Gates["add"],
   218  		Inputs: []*Wire{&c[0], &c[1]},
   219  	}
   220  
   221  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   222  
   223  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   224  	assert.NoError(t, err)
   225  
   226  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   227  	assert.NoError(t, err, "proof rejected")
   228  
   229  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   230  	assert.NotNil(t, err, "bad proof accepted")
   231  }
   232  
   233  func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) {
   234  
   235  	c := make(Circuit, 3)
   236  	c[2] = Wire{
   237  		Gate:   Gates["mul"],
   238  		Inputs: []*Wire{&c[0], &c[1]},
   239  	}
   240  
   241  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   242  
   243  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   244  	assert.NoError(t, err)
   245  
   246  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   247  	assert.NoError(t, err, "proof rejected")
   248  
   249  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   250  	assert.NotNil(t, err, "bad proof accepted")
   251  }
   252  
   253  func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) {
   254  	c := make(Circuit, 3)
   255  
   256  	c[1] = Wire{
   257  		Gate:   IdentityGate{},
   258  		Inputs: []*Wire{&c[0]},
   259  	}
   260  
   261  	c[2] = Wire{
   262  		Gate:   IdentityGate{},
   263  		Inputs: []*Wire{&c[0]},
   264  	}
   265  
   266  	assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c)
   267  
   268  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   269  	assert.NoError(t, err)
   270  
   271  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   272  	assert.NoError(t, err, "proof rejected")
   273  
   274  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   275  	assert.NotNil(t, err, "bad proof accepted")
   276  }
   277  
   278  func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) {
   279  	c := make(Circuit, 3)
   280  
   281  	c[2] = Wire{
   282  		Gate:   mimcCipherGate{},
   283  		Inputs: []*Wire{&c[0], &c[1]},
   284  	}
   285  
   286  	t.Log("Evaluating all circuit wires")
   287  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   288  	t.Log("Circuit evaluation complete")
   289  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   290  	assert.NoError(t, err)
   291  	t.Log("Proof complete")
   292  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   293  	assert.NoError(t, err, "proof rejected")
   294  
   295  	t.Log("Successful verification complete")
   296  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   297  	assert.NotNil(t, err, "bad proof accepted")
   298  	t.Log("Unsuccessful verification complete")
   299  }
   300  
   301  func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) {
   302  	c := make(Circuit, 3)
   303  
   304  	c[1] = Wire{
   305  		Gate:   IdentityGate{},
   306  		Inputs: []*Wire{&c[0]},
   307  	}
   308  	c[2] = Wire{
   309  		Gate:   IdentityGate{},
   310  		Inputs: []*Wire{&c[1]},
   311  	}
   312  
   313  	assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c)
   314  
   315  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   316  	assert.NoError(t, err)
   317  
   318  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   319  	assert.NoError(t, err, "proof rejected")
   320  
   321  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   322  	assert.NotNil(t, err, "bad proof accepted")
   323  }
   324  
   325  func mimcCircuit(numRounds int) Circuit {
   326  	c := make(Circuit, numRounds+2)
   327  
   328  	for i := 2; i < len(c); i++ {
   329  		c[i] = Wire{
   330  			Gate:   mimcCipherGate{},
   331  			Inputs: []*Wire{&c[i-1], &c[0]},
   332  		}
   333  	}
   334  	return c
   335  }
   336  
   337  func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) {
   338  	//TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b)
   339  	// @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10
   340  
   341  	c := mimcCircuit(numRounds)
   342  
   343  	t.Log("Evaluating all circuit wires")
   344  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   345  	t.Log("Circuit evaluation complete")
   346  
   347  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   348  	assert.NoError(t, err)
   349  
   350  	t.Log("Proof finished")
   351  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   352  	assert.NoError(t, err, "proof rejected")
   353  
   354  	t.Log("Successful verification finished")
   355  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   356  	assert.NotNil(t, err, "bad proof accepted")
   357  	t.Log("Unsuccessful verification finished")
   358  }
   359  
   360  func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) {
   361  	// This imitates the MiMC circuit
   362  
   363  	c := make(Circuit, numRounds+2)
   364  
   365  	for i := 2; i < len(c); i++ {
   366  		c[i] = Wire{
   367  			Gate:   Gates["mul"],
   368  			Inputs: []*Wire{&c[i-1], &c[0]},
   369  		}
   370  	}
   371  
   372  	assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c)
   373  
   374  	proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   375  	assert.NoError(t, err)
   376  
   377  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1)))
   378  	assert.NoError(t, err, "proof rejected")
   379  
   380  	err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1)))
   381  	assert.NotNil(t, err, "bad proof accepted")
   382  }
   383  
   384  func setRandom(slice []fr.Element) {
   385  	for i := range slice {
   386  		slice[i].SetRandom()
   387  	}
   388  }
   389  
   390  func generateTestProver(path string) func(t *testing.T) {
   391  	return func(t *testing.T) {
   392  		testCase, err := newTestCase(path)
   393  		assert.NoError(t, err)
   394  		proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash))
   395  		assert.NoError(t, err)
   396  		assert.NoError(t, proofEquals(testCase.Proof, proof))
   397  	}
   398  }
   399  
   400  func generateTestVerifier(path string) func(t *testing.T) {
   401  	return func(t *testing.T) {
   402  		testCase, err := newTestCase(path)
   403  		assert.NoError(t, err)
   404  		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash))
   405  		assert.NoError(t, err, "proof rejected")
   406  		testCase, err = newTestCase(path)
   407  		assert.NoError(t, err)
   408  		err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0)))
   409  		assert.NotNil(t, err, "bad proof accepted")
   410  	}
   411  }
   412  
   413  func TestGkrVectors(t *testing.T) {
   414  
   415  	testDirPath := "../../../../internal/generator/gkr/test_vectors"
   416  	dirEntries, err := os.ReadDir(testDirPath)
   417  	assert.NoError(t, err)
   418  	for _, dirEntry := range dirEntries {
   419  		if !dirEntry.IsDir() {
   420  
   421  			if filepath.Ext(dirEntry.Name()) == ".json" {
   422  				path := filepath.Join(testDirPath, dirEntry.Name())
   423  				noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")]
   424  
   425  				t.Run(noExt+"_prover", generateTestProver(path))
   426  				t.Run(noExt+"_verifier", generateTestVerifier(path))
   427  
   428  			}
   429  		}
   430  	}
   431  }
   432  
   433  func proofEquals(expected Proof, seen Proof) error {
   434  	if len(expected) != len(seen) {
   435  		return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen))
   436  	}
   437  	for i, x := range expected {
   438  		xSeen := seen[i]
   439  
   440  		if xSeen.FinalEvalProof == nil {
   441  			if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 {
   442  				return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval))
   443  			}
   444  		} else {
   445  			if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil {
   446  				return fmt.Errorf("final evaluation proof mismatch")
   447  			}
   448  		}
   449  		if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil {
   450  			return err
   451  		}
   452  	}
   453  	return nil
   454  }
   455  
   456  func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) {
   457  	fmt.Println("creating circuit structure")
   458  	c := mimcCircuit(mimcDepth)
   459  
   460  	in0 := make([]fr.Element, nbInstances)
   461  	in1 := make([]fr.Element, nbInstances)
   462  	setRandom(in0)
   463  	setRandom(in1)
   464  
   465  	fmt.Println("evaluating circuit")
   466  	start := time.Now().UnixMicro()
   467  	assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c)
   468  	solved := time.Now().UnixMicro() - start
   469  	fmt.Println("solved in", solved, "μs")
   470  
   471  	//b.ResetTimer()
   472  	fmt.Println("constructing proof")
   473  	start = time.Now().UnixMicro()
   474  	_, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC()))
   475  	proved := time.Now().UnixMicro() - start
   476  	fmt.Println("proved in", proved, "μs")
   477  	assert.NoError(b, err)
   478  }
   479  
   480  func BenchmarkGkrMimc19(b *testing.B) {
   481  	benchmarkGkrMiMC(b, 1<<19, 91)
   482  }
   483  
   484  func BenchmarkGkrMimc17(b *testing.B) {
   485  	benchmarkGkrMiMC(b, 1<<17, 91)
   486  }
   487  
   488  func TestTopSortTrivial(t *testing.T) {
   489  	c := make(Circuit, 2)
   490  	c[0].Inputs = []*Wire{&c[1]}
   491  	sorted := topologicalSort(c)
   492  	assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted)
   493  }
   494  
   495  func TestTopSortDeep(t *testing.T) {
   496  	c := make(Circuit, 4)
   497  	c[0].Inputs = []*Wire{&c[2]}
   498  	c[1].Inputs = []*Wire{&c[3]}
   499  	c[2].Inputs = []*Wire{}
   500  	c[3].Inputs = []*Wire{&c[0]}
   501  	sorted := topologicalSort(c)
   502  	assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted)
   503  }
   504  
   505  func TestTopSortWide(t *testing.T) {
   506  	c := make(Circuit, 10)
   507  	c[0].Inputs = []*Wire{&c[3], &c[8]}
   508  	c[1].Inputs = []*Wire{&c[6]}
   509  	c[2].Inputs = []*Wire{&c[4]}
   510  	c[3].Inputs = []*Wire{}
   511  	c[4].Inputs = []*Wire{}
   512  	c[5].Inputs = []*Wire{&c[9]}
   513  	c[6].Inputs = []*Wire{&c[9]}
   514  	c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]}
   515  	c[8].Inputs = []*Wire{&c[4], &c[3]}
   516  	c[9].Inputs = []*Wire{}
   517  
   518  	sorted := topologicalSort(c)
   519  	sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]}
   520  
   521  	assert.Equal(t, sortedExpected, sorted)
   522  }
   523  
   524  type WireInfo struct {
   525  	Gate   string `json:"gate"`
   526  	Inputs []int  `json:"inputs"`
   527  }
   528  
   529  type CircuitInfo []WireInfo
   530  
   531  var circuitCache = make(map[string]Circuit)
   532  
   533  func getCircuit(path string) (Circuit, error) {
   534  	path, err := filepath.Abs(path)
   535  	if err != nil {
   536  		return nil, err
   537  	}
   538  	if circuit, ok := circuitCache[path]; ok {
   539  		return circuit, nil
   540  	}
   541  	var bytes []byte
   542  	if bytes, err = os.ReadFile(path); err == nil {
   543  		var circuitInfo CircuitInfo
   544  		if err = json.Unmarshal(bytes, &circuitInfo); err == nil {
   545  			circuit := circuitInfo.toCircuit()
   546  			circuitCache[path] = circuit
   547  			return circuit, nil
   548  		} else {
   549  			return nil, err
   550  		}
   551  	} else {
   552  		return nil, err
   553  	}
   554  }
   555  
   556  func (c CircuitInfo) toCircuit() (circuit Circuit) {
   557  	circuit = make(Circuit, len(c))
   558  	for i := range c {
   559  		circuit[i].Gate = Gates[c[i].Gate]
   560  		circuit[i].Inputs = make([]*Wire, len(c[i].Inputs))
   561  		for k, inputCoord := range c[i].Inputs {
   562  			input := &circuit[inputCoord]
   563  			circuit[i].Inputs[k] = input
   564  		}
   565  	}
   566  	return
   567  }
   568  
   569  func init() {
   570  	Gates["mimc"] = mimcCipherGate{} //TODO: Add ark
   571  	Gates["select-input-3"] = _select(2)
   572  }
   573  
   574  type mimcCipherGate struct {
   575  	ark fr.Element
   576  }
   577  
   578  func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) {
   579  	var sum fr.Element
   580  
   581  	sum.
   582  		Add(&input[0], &input[1]).
   583  		Add(&sum, &m.ark)
   584  
   585  	res.Square(&sum)    // sum^2
   586  	res.Mul(&res, &sum) // sum^3
   587  	res.Square(&res)    //sum^6
   588  	res.Mul(&res, &sum) //sum^7
   589  
   590  	return
   591  }
   592  
   593  func (m mimcCipherGate) Degree() int {
   594  	return 7
   595  }
   596  
   597  type PrintableProof []PrintableSumcheckProof
   598  
   599  type PrintableSumcheckProof struct {
   600  	FinalEvalProof  interface{}     `json:"finalEvalProof"`
   601  	PartialSumPolys [][]interface{} `json:"partialSumPolys"`
   602  }
   603  
   604  func unmarshalProof(printable PrintableProof) (Proof, error) {
   605  	proof := make(Proof, len(printable))
   606  	for i := range printable {
   607  		finalEvalProof := []fr.Element(nil)
   608  
   609  		if printable[i].FinalEvalProof != nil {
   610  			finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof)
   611  			finalEvalProof = make([]fr.Element, finalEvalSlice.Len())
   612  			for k := range finalEvalProof {
   613  				if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil {
   614  					return nil, err
   615  				}
   616  			}
   617  		}
   618  
   619  		proof[i] = sumcheck.Proof{
   620  			PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)),
   621  			FinalEvalProof:  finalEvalProof,
   622  		}
   623  		for k := range printable[i].PartialSumPolys {
   624  			var err error
   625  			if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil {
   626  				return nil, err
   627  			}
   628  		}
   629  	}
   630  	return proof, nil
   631  }
   632  
   633  type TestCase struct {
   634  	Circuit         Circuit
   635  	Hash            hash.Hash
   636  	Proof           Proof
   637  	FullAssignment  WireAssignment
   638  	InOutAssignment WireAssignment
   639  }
   640  
   641  type TestCaseInfo struct {
   642  	Hash    test_vector_utils.HashDescription `json:"hash"`
   643  	Circuit string                            `json:"circuit"`
   644  	Input   [][]interface{}                   `json:"input"`
   645  	Output  [][]interface{}                   `json:"output"`
   646  	Proof   PrintableProof                    `json:"proof"`
   647  }
   648  
   649  var testCases = make(map[string]*TestCase)
   650  
   651  func newTestCase(path string) (*TestCase, error) {
   652  	path, err := filepath.Abs(path)
   653  	if err != nil {
   654  		return nil, err
   655  	}
   656  	dir := filepath.Dir(path)
   657  
   658  	tCase, ok := testCases[path]
   659  	if !ok {
   660  		var bytes []byte
   661  		if bytes, err = os.ReadFile(path); err == nil {
   662  			var info TestCaseInfo
   663  			err = json.Unmarshal(bytes, &info)
   664  			if err != nil {
   665  				return nil, err
   666  			}
   667  
   668  			var circuit Circuit
   669  			if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil {
   670  				return nil, err
   671  			}
   672  			var _hash hash.Hash
   673  			if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil {
   674  				return nil, err
   675  			}
   676  			var proof Proof
   677  			if proof, err = unmarshalProof(info.Proof); err != nil {
   678  				return nil, err
   679  			}
   680  
   681  			fullAssignment := make(WireAssignment)
   682  			inOutAssignment := make(WireAssignment)
   683  
   684  			sorted := topologicalSort(circuit)
   685  
   686  			inI, outI := 0, 0
   687  			for _, w := range sorted {
   688  				var assignmentRaw []interface{}
   689  				if w.IsInput() {
   690  					if inI == len(info.Input) {
   691  						return nil, fmt.Errorf("fewer input in vector than in circuit")
   692  					}
   693  					assignmentRaw = info.Input[inI]
   694  					inI++
   695  				} else if w.IsOutput() {
   696  					if outI == len(info.Output) {
   697  						return nil, fmt.Errorf("fewer output in vector than in circuit")
   698  					}
   699  					assignmentRaw = info.Output[outI]
   700  					outI++
   701  				}
   702  				if assignmentRaw != nil {
   703  					var wireAssignment []fr.Element
   704  					if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil {
   705  						return nil, err
   706  					}
   707  
   708  					fullAssignment[w] = wireAssignment
   709  					inOutAssignment[w] = wireAssignment
   710  				}
   711  			}
   712  
   713  			fullAssignment.Complete(circuit)
   714  
   715  			for _, w := range sorted {
   716  				if w.IsOutput() {
   717  
   718  					if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil {
   719  						return nil, fmt.Errorf("assignment mismatch: %v", err)
   720  					}
   721  
   722  				}
   723  			}
   724  
   725  			tCase = &TestCase{
   726  				FullAssignment:  fullAssignment,
   727  				InOutAssignment: inOutAssignment,
   728  				Proof:           proof,
   729  				Hash:            _hash,
   730  				Circuit:         circuit,
   731  			}
   732  
   733  			testCases[path] = tCase
   734  		} else {
   735  			return nil, err
   736  		}
   737  	}
   738  
   739  	return tCase, nil
   740  }
   741  
   742  type _select int
   743  
   744  func (g _select) Evaluate(in ...fr.Element) fr.Element {
   745  	return in[g]
   746  }
   747  
   748  func (g _select) Degree() int {
   749  	return 1
   750  }