github.com/consensys/gnark@v0.11.0/test/assert_fuzz.go (about)

     1  package test
     2  
     3  import (
     4  	"crypto/rand"
     5  	"math/big"
     6  	mrand "math/rand"
     7  	"reflect"
     8  	"time"
     9  
    10  	"github.com/consensys/gnark"
    11  	"github.com/consensys/gnark-crypto/ecc"
    12  	"github.com/consensys/gnark/backend"
    13  	"github.com/consensys/gnark/frontend"
    14  	"github.com/consensys/gnark/frontend/schema"
    15  )
    16  
    17  var seedCorpus []*big.Int
    18  
    19  func init() {
    20  	seedCorpus = make([]*big.Int, 0, 300)
    21  
    22  	// small values, including bits
    23  	for i := -5; i <= 5; i++ {
    24  		seedCorpus = append(seedCorpus, big.NewInt(int64(i)))
    25  	}
    26  
    27  	// moduli
    28  	for _, curve := range gnark.Curves() {
    29  		fp := curve.BaseField()
    30  		fr := curve.ScalarField()
    31  		seedCorpus = append(seedCorpus, fp)
    32  		seedCorpus = append(seedCorpus, fr)
    33  
    34  		var bi big.Int
    35  		for i := -3; i <= 3; i++ {
    36  			bi.SetInt64(int64(i))
    37  			var fp1, fr1 big.Int
    38  			fp1.Add(fp, &bi)
    39  			fr1.Add(fr, &bi)
    40  
    41  			seedCorpus = append(seedCorpus, &fp1)
    42  			seedCorpus = append(seedCorpus, &fr1)
    43  		}
    44  	}
    45  
    46  	// powers of 2
    47  	bi := big.NewInt(1)
    48  	bi.Lsh(bi, 32)
    49  	seedCorpus = append(seedCorpus, bi)
    50  
    51  	bi = big.NewInt(1)
    52  	bi.Lsh(bi, 64)
    53  	seedCorpus = append(seedCorpus, bi)
    54  
    55  	bi = big.NewInt(1)
    56  	bi.Lsh(bi, 254)
    57  	seedCorpus = append(seedCorpus, bi)
    58  
    59  	bi = big.NewInt(1)
    60  	bi.Lsh(bi, 255)
    61  	seedCorpus = append(seedCorpus, bi)
    62  
    63  	bi = big.NewInt(1)
    64  	bi.Lsh(bi, 256)
    65  	seedCorpus = append(seedCorpus, bi)
    66  
    67  }
    68  
    69  type filler func(frontend.Circuit, ecc.ID)
    70  
    71  func zeroFiller(w frontend.Circuit, curve ecc.ID) {
    72  	fill(w, func() interface{} {
    73  		return 0
    74  	})
    75  }
    76  
    77  func binaryFiller(w frontend.Circuit, curve ecc.ID) {
    78  	mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here
    79  
    80  	fill(w, func() interface{} {
    81  		return int(mrand.Uint32() % 2) //#nosec G404 weak rng is fine here
    82  	})
    83  }
    84  
    85  func seedFiller(w frontend.Circuit, curve ecc.ID) {
    86  
    87  	mrand := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here
    88  
    89  	m := curve.ScalarField()
    90  
    91  	fill(w, func() interface{} {
    92  		i := int(mrand.Uint32() % uint32(len(seedCorpus))) //#nosec G404 weak rng is fine here
    93  		r := new(big.Int).Set(seedCorpus[i])
    94  		return r.Mod(r, m)
    95  	})
    96  }
    97  
    98  func randomFiller(w frontend.Circuit, curve ecc.ID) {
    99  
   100  	r := mrand.New(mrand.NewSource(time.Now().Unix())) //#nosec G404 weak rng is fine here
   101  	m := curve.ScalarField()
   102  
   103  	fill(w, func() interface{} {
   104  		i := int(mrand.Uint32() % uint32(len(seedCorpus)*2)) //#nosec G404 weak rng is fine here
   105  		if i >= len(seedCorpus) {
   106  			b1, _ := rand.Int(r, m) //#nosec G404 weak rng is fine here
   107  			return b1
   108  		}
   109  		r := new(big.Int).Set(seedCorpus[i])
   110  		return r.Mod(r, m)
   111  	})
   112  }
   113  
   114  func fill(w frontend.Circuit, nextValue func() interface{}) {
   115  	setHandler := func(f schema.LeafInfo, tInput reflect.Value) error {
   116  		v := nextValue()
   117  		tInput.Set(reflect.ValueOf((v)))
   118  		return nil
   119  	}
   120  	// this can't error.
   121  	// TODO @gbotrel it might error with .Walk?
   122  	_, _ = schema.Walk(w, tVariable, setHandler)
   123  }
   124  
   125  var tVariable reflect.Type
   126  
   127  func init() {
   128  	tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type()
   129  }
   130  
   131  // Fuzz fuzzes the given circuit by instantiating "randomized" witnesses and cross checking
   132  // execution result between constraint system solver and big.Int test execution engine
   133  //
   134  // note: this is experimental and will be more tightly integrated with go1.18 built-in fuzzing
   135  func (assert *Assert) Fuzz(circuit frontend.Circuit, fuzzCount int, opts ...TestingOption) {
   136  	opt := assert.options(opts...)
   137  
   138  	// first we clone the circuit
   139  	// then we parse the frontend.Variable and set them to a random value  or from our interesting pool
   140  	// (% of allocations to be tuned)
   141  	w := shallowClone(circuit)
   142  
   143  	fillers := []filler{randomFiller, binaryFiller, seedFiller}
   144  
   145  	for _, curve := range opt.curves {
   146  		for _, b := range opt.backends {
   147  			curve := curve
   148  			b := b
   149  			assert.Run(func(assert *Assert) {
   150  				// this puts the compiled circuit in the cache
   151  				// we do this here in case our fuzzWitness method mutates some references in the circuit
   152  				// (like []frontend.Variable) before cleaning up
   153  				_, err := assert.compile(circuit, curve, b, opt.compileOpts)
   154  				assert.NoError(err)
   155  				valid := 0
   156  				// "fuzz" with zeros
   157  				valid += assert.fuzzer(zeroFiller, circuit, w, b, curve, &opt)
   158  
   159  				for i := 0; i < fuzzCount; i++ {
   160  					for _, f := range fillers {
   161  						valid += assert.fuzzer(f, circuit, w, b, curve, &opt)
   162  					}
   163  				}
   164  
   165  			}, curve.String(), b.String())
   166  
   167  		}
   168  	}
   169  }
   170  
   171  func (assert *Assert) fuzzer(fuzzer filler, circuit, w frontend.Circuit, b backend.ID, curve ecc.ID, opt *testingConfig) int {
   172  	// fuzz a witness
   173  	fuzzer(w, curve)
   174  
   175  	errVars := IsSolved(circuit, w, curve.ScalarField())
   176  	errConsts := IsSolved(circuit, w, curve.ScalarField(), SetAllVariablesAsConstants())
   177  
   178  	if (errVars == nil) != (errConsts == nil) {
   179  		w, err := frontend.NewWitness(w, curve.ScalarField())
   180  		if err != nil {
   181  			panic(err)
   182  		}
   183  		s, err := frontend.NewSchema(circuit)
   184  		if err != nil {
   185  			panic(err)
   186  		}
   187  		bb, err := w.ToJSON(s)
   188  		if err != nil {
   189  			panic(err)
   190  		}
   191  
   192  		assert.Log("errVars", errVars)
   193  		assert.Log("errConsts", errConsts)
   194  		assert.Log("fuzzer witness", string(bb))
   195  		assert.FailNow("solving circuit with values as constants vs non-constants mismatched result")
   196  	}
   197  
   198  	if errVars == nil && errConsts == nil {
   199  		// valid witness
   200  		assert.solvingSucceeded(circuit, w, b, curve, opt)
   201  		return 1
   202  	}
   203  
   204  	// invalid witness
   205  	assert.solvingFailed(circuit, w, b, curve, opt)
   206  	return 0
   207  }
   208  
   209  func (assert *Assert) solvingSucceeded(circuit frontend.Circuit, validAssignment frontend.Circuit, b backend.ID, curve ecc.ID, opt *testingConfig) {
   210  	// parse assignment
   211  	w := assert.parseAssignment(circuit, validAssignment, curve, opt.checkSerialization)
   212  
   213  	checkError := func(err error) { assert.noError(err, &w) }
   214  
   215  	// 1- compile the circuit
   216  	ccs, err := assert.compile(circuit, curve, b, opt.compileOpts)
   217  	checkError(err)
   218  
   219  	// must not error with big int test engine
   220  	err = IsSolved(circuit, validAssignment, curve.ScalarField())
   221  	checkError(err)
   222  
   223  	err = ccs.IsSolved(w.full, opt.solverOpts...)
   224  	checkError(err)
   225  
   226  }
   227  
   228  func (assert *Assert) solvingFailed(circuit frontend.Circuit, invalidAssignment frontend.Circuit, b backend.ID, curve ecc.ID, opt *testingConfig) {
   229  	// parse assignment
   230  	w := assert.parseAssignment(circuit, invalidAssignment, curve, opt.checkSerialization)
   231  
   232  	checkError := func(err error) { assert.noError(err, &w) }
   233  	mustError := func(err error) { assert.error(err, &w) }
   234  
   235  	// 1- compile the circuit
   236  	ccs, err := assert.compile(circuit, curve, b, opt.compileOpts)
   237  	checkError(err)
   238  
   239  	// must error with big int test engine
   240  	err = IsSolved(circuit, invalidAssignment, curve.ScalarField())
   241  	mustError(err)
   242  
   243  	err = ccs.IsSolved(w.full, opt.solverOpts...)
   244  	mustError(err)
   245  
   246  }