github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/template/gkr.test.vectors.go.tmpl (about)

{{define "gkrTestVectors"}}
{{/*
  gkrTestVectors renders the helpers that load GKR test vectors (circuits, proofs
  and wire assignments) from JSON files.

  .OutsideGkrPackage: when true, the generated file is assumed to live outside the
  gkr package. gkr identifiers below are then qualified with "gkr." (assuming
  `select c a b` yields b when c holds, consistent with the TopologicalSort call
  further down). The outputs recorded in a vector are then regenerated rather
  than checked ($CheckOutputCorrectness).

  $Gate and $CircuitLayer are not referenced anywhere in this template.
*/}}

{{$GkrPackagePrefix := select .OutsideGkrPackage "" "gkr."}}
{{$CheckOutputCorrectness := not .OutsideGkrPackage}}

{{$Circuit          := print $GkrPackagePrefix "Circuit"}}
{{$Gate             := print $GkrPackagePrefix "Gate"}}
{{$Proof            := print $GkrPackagePrefix "Proof"}}
{{$WireAssignment   := print $GkrPackagePrefix "WireAssignment"}}
{{$Wire             := print $GkrPackagePrefix "Wire"}}
{{$CircuitLayer     := print $GkrPackagePrefix "CircuitLayer"}}
    12  
// WireInfo is the JSON description of a single wire of a test circuit.
type WireInfo struct {
	// Gate is looked up by name in Gates; a name missing from Gates yields the map's zero value.
	Gate string `json:"gate"`
	// Inputs holds the indices, within the enclosing CircuitInfo, of the wires feeding this one.
	Inputs []int `json:"inputs"`
}

// CircuitInfo is the JSON description of a test circuit: one WireInfo per wire.
type CircuitInfo []WireInfo

// circuitCache memoizes circuits loaded by getCircuit, keyed by absolute file path.
// It is a plain map with no locking.
var circuitCache = make(map[string]{{$Circuit}})
    21  
    22  func getCircuit(path string) ({{$Circuit}}, error) {
    23  	path, err := filepath.Abs(path)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  	if circuit, ok := circuitCache[path]; ok {
    28  		return circuit, nil
    29  	}
    30  	var bytes []byte
    31  	if bytes, err = os.ReadFile(path); err == nil {
    32  		var circuitInfo CircuitInfo
    33  		if err = json.Unmarshal(bytes, &circuitInfo); err == nil {
    34  			circuit := circuitInfo.toCircuit()
    35  			circuitCache[path] = circuit
    36  			return circuit, nil
    37  		} else {
    38  			return nil, err
    39  		}
    40  	} else {
    41  		return nil, err
    42  	}
    43  }
    44  
    45  func (c CircuitInfo) toCircuit() (circuit {{$Circuit}}) {
    46  	circuit = make({{$Circuit}}, len(c))
    47  	for i := range c {
    48  		circuit[i].Gate = Gates[c[i].Gate]
    49  		circuit[i].Inputs = make([]*{{$Wire}}, len(c[i].Inputs))
    50  		for k, inputCoord := range c[i].Inputs {
    51  			input := &circuit[inputCoord]
    52  			circuit[i].Inputs[k] = input
    53  		}
    54  	}
    55  	return
    56  }
    57  
// init registers the test-only gates so that JSON circuit files can refer to
// them by name (see toCircuit's lookup in Gates).
func init() {
	Gates["mimc"] = mimcCipherGate{} //TODO: Add ark
	// forwards the input at index 2, i.e. the third input
	Gates["select-input-3"] = _select(2)
}
    62  
    63  type mimcCipherGate struct {
    64  	ark {{.ElementType}}
    65  }
    66  
    67  func (m mimcCipherGate) Evaluate(input ...{{.ElementType}}) (res {{.ElementType}}) {
    68  	var sum {{.ElementType}}
    69  
    70  	sum.
    71  		Add(&input[0], &input[1]).
    72  		Add(&sum, &m.ark)
    73  
    74  	res.Square(&sum)    // sum^2
    75  	res.Mul(&res, &sum) // sum^3
    76  	res.Square(&res)    //sum^6
    77  	res.Mul(&res, &sum) //sum^7
    78  
    79  	return
    80  }
    81  
    82  func (m mimcCipherGate) Degree() int {
    83  	return 7
    84  }
    85  
// PrintableProof is the JSON form of a GKR proof: a list of sumcheck proofs,
// each converted to a sumcheck.Proof by unmarshalProof.
type PrintableProof []PrintableSumcheckProof

// PrintableSumcheckProof is the JSON form of one sumcheck proof. Field elements
// are kept as undecoded JSON values until unmarshalProof converts them.
type PrintableSumcheckProof struct {
	// FinalEvalProof is either nil or a list of field elements.
	FinalEvalProof interface{} `json:"finalEvalProof"`
	// PartialSumPolys holds one coefficient list per partial-sum polynomial.
	PartialSumPolys [][]interface{} `json:"partialSumPolys"`
}
    92  
    93  func unmarshalProof(printable PrintableProof) ({{$Proof}}, error) {
    94  	proof := make({{$Proof}}, len(printable))
    95  	for i := range printable {
    96  		finalEvalProof := []{{.ElementType}}(nil)
    97  
    98  		if printable[i].FinalEvalProof != nil {
    99  			finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof)
   100  			finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len())
   101  			for k := range finalEvalProof {
   102  				if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil {
   103  					return nil, err
   104  				}
   105  			}
   106  		}
   107  
   108  		proof[i] = sumcheck.Proof{
   109  			PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)),
   110  			FinalEvalProof:  finalEvalProof,
   111  		}
   112  		for k := range printable[i].PartialSumPolys {
   113  			var err error
   114  			if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil {
   115  				return nil, err
   116  			}
   117  		}
   118  	}
   119  	return proof, nil
   120  }
   121  
// TestCase is a fully decoded GKR test vector, as built by newTestCase.
type TestCase struct {
	Circuit {{$Circuit}}
	Hash    hash.Hash
	Proof   {{$Proof}}
	// FullAssignment holds a value list for every wire: inputs and outputs come from
	// the vector, and the rest are filled in by Complete.
	FullAssignment {{$WireAssignment}}
	// InOutAssignment holds value lists for the input and output wires only.
	InOutAssignment {{$WireAssignment}}
	{{if .RetainTestCaseRawInfo}}Info TestCaseInfo{{end}}
}

// TestCaseInfo is the raw JSON form of a test vector file.
type TestCaseInfo struct {
	Hash test_vector_utils.HashDescription `json:"hash"`
	// Circuit is the path of the circuit JSON file, relative to the vector file's directory.
	Circuit string `json:"circuit"`
	// Input holds one value list per input wire, in the circuit's topological order.
	Input [][]interface{} `json:"input"`
	// Output holds one value list per output wire, in the circuit's topological order.
	Output [][]interface{} `json:"output"`
	Proof  PrintableProof  `json:"proof"`
}

// testCases memoizes newTestCase results, keyed by absolute vector path.
// It is a plain map with no locking.
var testCases = make(map[string]*TestCase)
   140  
   141  func newTestCase(path string) (*TestCase, error) {
   142  	path, err := filepath.Abs(path)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	dir := filepath.Dir(path)
   147  
   148  	tCase, ok := testCases[path]
   149  	if !ok {
   150  		var bytes []byte
   151  		if bytes, err = os.ReadFile(path); err == nil {
   152  			var info TestCaseInfo
   153  			err = json.Unmarshal(bytes, &info)
   154  			if err != nil {
   155  				return nil, err
   156  			}
   157  
   158  			var circuit {{$Circuit}}
   159  			if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil {
   160  				return nil, err
   161  			}
   162  			var _hash hash.Hash
   163  			if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil {
   164  				return nil, err
   165  			}
   166  			var proof {{$Proof}}
   167  			if proof, err = unmarshalProof(info.Proof); err != nil {
   168  				return nil, err
   169  			}
   170  
   171  			fullAssignment := make({{$WireAssignment}})
   172  			inOutAssignment := make({{$WireAssignment}})
   173  
   174  			sorted := {{select .OutsideGkrPackage "t" "gkr.T"}}opologicalSort(circuit)
   175  
   176  			inI, outI := 0, 0
   177  			for _, w := range sorted {
   178  				var assignmentRaw []interface{}
   179  				if w.IsInput() {
   180  					if inI == len(info.Input) {
   181  						return nil, fmt.Errorf("fewer input in vector than in circuit")
   182  					}
   183  					assignmentRaw = info.Input[inI]
   184  					inI++
   185  				} else if w.IsOutput() {
   186  					if outI == len(info.Output) {
   187  						return nil, fmt.Errorf("fewer output in vector than in circuit")
   188  					}
   189  					assignmentRaw = info.Output[outI]
   190  					outI++
   191  				}
   192  				if assignmentRaw != nil {
   193  					var wireAssignment []{{.ElementType}}
   194  					if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil {
   195  					return nil, err
   196  					}
   197  
   198  					fullAssignment[w] = wireAssignment
   199  					inOutAssignment[w] = wireAssignment
   200  				}
   201  			}
   202  
   203  			fullAssignment.Complete(circuit)
   204  
   205  			{{if not $CheckOutputCorrectness}}
   206  				info.Output = make([][]interface{}, 0, outI)
   207  			{{end}}
   208  
   209  			for _, w := range sorted {
   210  				if w.IsOutput() {
   211  				{{if $CheckOutputCorrectness}}
   212  					if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil {
   213  						return nil, fmt.Errorf("assignment mismatch: %v", err)
   214  					}
   215  				{{else}}
   216  					info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w]))
   217  				{{end}}
   218  				}
   219  			}
   220  
   221  			tCase = &TestCase{
   222  				FullAssignment:  fullAssignment,
   223  				InOutAssignment: inOutAssignment,
   224  				Proof:           proof,
   225  				Hash:            _hash,
   226  				Circuit:         circuit,
   227  				{{if .RetainTestCaseRawInfo }}Info: info,{{end}}
   228  			}
   229  
   230  			testCases[path] = tCase
   231  		} else {
   232  			return nil, err
   233  		}
   234  	}
   235  
   236  	return tCase, nil
   237  }
   238  
   239  type _select int
   240  
   241  func (g _select) Evaluate(in ...{{.ElementType}}) {{.ElementType}} {
   242  	return in[g]
   243  }
   244  
   245  func (g _select) Degree() int {
   246  	return 1
   247  }
   248  
   249  {{end}}
   250  
   251  {{- define "setElement element value elementType"}}
   252  {{- if eq .elementType "fr.Element"}} test_vector_utils.SetElement(&{{.element}}, {{.value}})
   253  {{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}})
   254  {{- else}}
   255  {{print "\"UNEXPECTED TYPE" .elementType "\""}}
   256  {{- end}}
   257  {{- end}}