github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/test_vectors/main.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package main
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
    23  	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational"
    24  	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr"
    25  	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial"
    26  	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck"
    27  	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils"
    28  	"hash"
    29  	"os"
    30  	"path/filepath"
    31  	"reflect"
    32  )
    33  
    34  func main() {
    35  	if err := GenerateVectors(); err != nil {
    36  		fmt.Println(err.Error())
    37  		os.Exit(-1)
    38  	}
    39  }
    40  
    41  func GenerateVectors() error {
    42  	testDirPath, err := filepath.Abs("gkr/test_vectors")
    43  	if err != nil {
    44  		return err
    45  	}
    46  
    47  	fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath)
    48  
    49  	dirEntries, err := os.ReadDir(testDirPath)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	for _, dirEntry := range dirEntries {
    54  		if !dirEntry.IsDir() {
    55  
    56  			if filepath.Ext(dirEntry.Name()) == ".json" {
    57  				fmt.Println("\tprocessing", dirEntry.Name())
    58  				path := filepath.Join(testDirPath, dirEntry.Name())
    59  				if err = run(path); err != nil {
    60  					return err
    61  				}
    62  			}
    63  		}
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  func run(absPath string) error {
    70  	testCase, err := newTestCase(absPath)
    71  	if err != nil {
    72  		return err
    73  	}
    74  
    75  	transcriptSetting := fiatshamir.WithHash(testCase.Hash)
    76  
    77  	var proof gkr.Proof
    78  	proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting)
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	if testCase.Info.Proof, err = toPrintableProof(proof); err != nil {
    84  		return err
    85  	}
    86  	var outBytes []byte
    87  	if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil {
    88  		if err = os.WriteFile(absPath, outBytes, 0); err != nil {
    89  			return err
    90  		}
    91  	} else {
    92  		return err
    93  	}
    94  
    95  	testCase, err = newTestCase(absPath)
    96  	if err != nil {
    97  		return err
    98  	}
    99  
   100  	err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	testCase, err = newTestCase(absPath)
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0)))
   111  	if err == nil {
   112  		return fmt.Errorf("bad proof accepted")
   113  	}
   114  	return nil
   115  }
   116  
   117  func toPrintableProof(proof gkr.Proof) (PrintableProof, error) {
   118  	res := make(PrintableProof, len(proof))
   119  
   120  	for i := range proof {
   121  
   122  		partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys))
   123  		for k, partialK := range proof[i].PartialSumPolys {
   124  			partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK)
   125  		}
   126  
   127  		res[i] = PrintableSumcheckProof{
   128  			FinalEvalProof:  test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof),
   129  			PartialSumPolys: partialSumPolys,
   130  		}
   131  	}
   132  	return res, nil
   133  }
   134  
   135  var Gates = gkr.Gates
   136  
   137  type WireInfo struct {
   138  	Gate   string `json:"gate"`
   139  	Inputs []int  `json:"inputs"`
   140  }
   141  
   142  type CircuitInfo []WireInfo
   143  
   144  var circuitCache = make(map[string]gkr.Circuit)
   145  
   146  func getCircuit(path string) (gkr.Circuit, error) {
   147  	path, err := filepath.Abs(path)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	if circuit, ok := circuitCache[path]; ok {
   152  		return circuit, nil
   153  	}
   154  	var bytes []byte
   155  	if bytes, err = os.ReadFile(path); err == nil {
   156  		var circuitInfo CircuitInfo
   157  		if err = json.Unmarshal(bytes, &circuitInfo); err == nil {
   158  			circuit := circuitInfo.toCircuit()
   159  			circuitCache[path] = circuit
   160  			return circuit, nil
   161  		} else {
   162  			return nil, err
   163  		}
   164  	} else {
   165  		return nil, err
   166  	}
   167  }
   168  
   169  func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) {
   170  	circuit = make(gkr.Circuit, len(c))
   171  	for i := range c {
   172  		circuit[i].Gate = Gates[c[i].Gate]
   173  		circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs))
   174  		for k, inputCoord := range c[i].Inputs {
   175  			input := &circuit[inputCoord]
   176  			circuit[i].Inputs[k] = input
   177  		}
   178  	}
   179  	return
   180  }
   181  
   182  func init() {
   183  	Gates["mimc"] = mimcCipherGate{} //TODO: Add ark
   184  	Gates["select-input-3"] = _select(2)
   185  }
   186  
   187  type mimcCipherGate struct {
   188  	ark small_rational.SmallRational
   189  }
   190  
   191  func (m mimcCipherGate) Evaluate(input ...small_rational.SmallRational) (res small_rational.SmallRational) {
   192  	var sum small_rational.SmallRational
   193  
   194  	sum.
   195  		Add(&input[0], &input[1]).
   196  		Add(&sum, &m.ark)
   197  
   198  	res.Square(&sum)    // sum^2
   199  	res.Mul(&res, &sum) // sum^3
   200  	res.Square(&res)    //sum^6
   201  	res.Mul(&res, &sum) //sum^7
   202  
   203  	return
   204  }
   205  
   206  func (m mimcCipherGate) Degree() int {
   207  	return 7
   208  }
   209  
   210  type PrintableProof []PrintableSumcheckProof
   211  
   212  type PrintableSumcheckProof struct {
   213  	FinalEvalProof  interface{}     `json:"finalEvalProof"`
   214  	PartialSumPolys [][]interface{} `json:"partialSumPolys"`
   215  }
   216  
   217  func unmarshalProof(printable PrintableProof) (gkr.Proof, error) {
   218  	proof := make(gkr.Proof, len(printable))
   219  	for i := range printable {
   220  		finalEvalProof := []small_rational.SmallRational(nil)
   221  
   222  		if printable[i].FinalEvalProof != nil {
   223  			finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof)
   224  			finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len())
   225  			for k := range finalEvalProof {
   226  				if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil {
   227  					return nil, err
   228  				}
   229  			}
   230  		}
   231  
   232  		proof[i] = sumcheck.Proof{
   233  			PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)),
   234  			FinalEvalProof:  finalEvalProof,
   235  		}
   236  		for k := range printable[i].PartialSumPolys {
   237  			var err error
   238  			if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil {
   239  				return nil, err
   240  			}
   241  		}
   242  	}
   243  	return proof, nil
   244  }
   245  
   246  type TestCase struct {
   247  	Circuit         gkr.Circuit
   248  	Hash            hash.Hash
   249  	Proof           gkr.Proof
   250  	FullAssignment  gkr.WireAssignment
   251  	InOutAssignment gkr.WireAssignment
   252  	Info            TestCaseInfo
   253  }
   254  
   255  type TestCaseInfo struct {
   256  	Hash    test_vector_utils.HashDescription `json:"hash"`
   257  	Circuit string                            `json:"circuit"`
   258  	Input   [][]interface{}                   `json:"input"`
   259  	Output  [][]interface{}                   `json:"output"`
   260  	Proof   PrintableProof                    `json:"proof"`
   261  }
   262  
   263  var testCases = make(map[string]*TestCase)
   264  
   265  func newTestCase(path string) (*TestCase, error) {
   266  	path, err := filepath.Abs(path)
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	dir := filepath.Dir(path)
   271  
   272  	tCase, ok := testCases[path]
   273  	if !ok {
   274  		var bytes []byte
   275  		if bytes, err = os.ReadFile(path); err == nil {
   276  			var info TestCaseInfo
   277  			err = json.Unmarshal(bytes, &info)
   278  			if err != nil {
   279  				return nil, err
   280  			}
   281  
   282  			var circuit gkr.Circuit
   283  			if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil {
   284  				return nil, err
   285  			}
   286  			var _hash hash.Hash
   287  			if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil {
   288  				return nil, err
   289  			}
   290  			var proof gkr.Proof
   291  			if proof, err = unmarshalProof(info.Proof); err != nil {
   292  				return nil, err
   293  			}
   294  
   295  			fullAssignment := make(gkr.WireAssignment)
   296  			inOutAssignment := make(gkr.WireAssignment)
   297  
   298  			sorted := gkr.TopologicalSort(circuit)
   299  
   300  			inI, outI := 0, 0
   301  			for _, w := range sorted {
   302  				var assignmentRaw []interface{}
   303  				if w.IsInput() {
   304  					if inI == len(info.Input) {
   305  						return nil, fmt.Errorf("fewer input in vector than in circuit")
   306  					}
   307  					assignmentRaw = info.Input[inI]
   308  					inI++
   309  				} else if w.IsOutput() {
   310  					if outI == len(info.Output) {
   311  						return nil, fmt.Errorf("fewer output in vector than in circuit")
   312  					}
   313  					assignmentRaw = info.Output[outI]
   314  					outI++
   315  				}
   316  				if assignmentRaw != nil {
   317  					var wireAssignment []small_rational.SmallRational
   318  					if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil {
   319  						return nil, err
   320  					}
   321  
   322  					fullAssignment[w] = wireAssignment
   323  					inOutAssignment[w] = wireAssignment
   324  				}
   325  			}
   326  
   327  			fullAssignment.Complete(circuit)
   328  
   329  			info.Output = make([][]interface{}, 0, outI)
   330  
   331  			for _, w := range sorted {
   332  				if w.IsOutput() {
   333  
   334  					info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w]))
   335  
   336  				}
   337  			}
   338  
   339  			tCase = &TestCase{
   340  				FullAssignment:  fullAssignment,
   341  				InOutAssignment: inOutAssignment,
   342  				Proof:           proof,
   343  				Hash:            _hash,
   344  				Circuit:         circuit,
   345  				Info:            info,
   346  			}
   347  
   348  			testCases[path] = tCase
   349  		} else {
   350  			return nil, err
   351  		}
   352  	}
   353  
   354  	return tCase, nil
   355  }
   356  
   357  type _select int
   358  
   359  func (g _select) Evaluate(in ...small_rational.SmallRational) small_rational.SmallRational {
   360  	return in[g]
   361  }
   362  
   363  func (g _select) Degree() int {
   364  	return 1
   365  }