github.com/consensys/gnark@v0.11.0/test/assert_checkcircuit.go (about)

     1  package test
     2  
     3  import (
     4  	"github.com/consensys/gnark-crypto/ecc"
     5  	"github.com/consensys/gnark/backend"
     6  	"github.com/consensys/gnark/backend/groth16"
     7  	"github.com/consensys/gnark/backend/plonk"
     8  	"github.com/consensys/gnark/backend/solidity"
     9  	"github.com/consensys/gnark/backend/witness"
    10  	"github.com/consensys/gnark/constraint"
    11  	"github.com/consensys/gnark/frontend"
    12  	"github.com/consensys/gnark/frontend/schema"
    13  	"github.com/consensys/gnark/logger"
    14  	"github.com/consensys/gnark/test/unsafekzg"
    15  )
    16  
    17  // CheckCircuit performs a series of check on the provided circuit.
    18  //
    19  //	go test -short                  --> testEngineChecks
    20  //	go test                         --> testEngineChecks  + constraintSolverChecks
    21  //	go test -tags=prover_checks     --> ... + proverChecks
    22  //	go test -tags=release_checks    --> ... + releaseChecks (solidity, serialization, ...)
    23  //
    24  // Depending on the above flags, the following checks are performed:
    25  //   - the circuit compiles
    26  //   - the circuit can be solved with the test engine
    27  //   - the circuit can be solved with the constraint system solver
    28  //   - the circuit can be solved with the prover
    29  //   - the circuit can be verified with the verifier
    30  //   - the circuit can be verified with gnark-solidity-checker
    31  //   - the circuit, witness, proving and verifying keys can be serialized and deserialized
    32  func (assert *Assert) CheckCircuit(circuit frontend.Circuit, opts ...TestingOption) {
    33  	// get the testing configuration
    34  	opt := assert.options(opts...)
    35  	log := logger.Logger()
    36  
    37  	// for each {curve, backend} tuple
    38  	for _, curve := range opt.curves {
    39  		curve := curve
    40  
    41  		// run in sub-test to contextualize with curve
    42  		assert.Run(func(assert *Assert) {
    43  
    44  			// parse valid / invalid assignments
    45  			var invalidWitnesses, validWitnesses []_witness
    46  			for _, a := range opt.validAssignments {
    47  				w := assert.parseAssignment(circuit, a, curve, opt.checkSerialization)
    48  				validWitnesses = append(validWitnesses, w)
    49  
    50  				// check that the assignment is valid with the test engine
    51  				if !opt.skipTestEngine {
    52  					err := IsSolved(circuit, w.assignment, curve.ScalarField())
    53  					assert.noError(err, &w)
    54  				}
    55  			}
    56  
    57  			for _, a := range opt.invalidAssignments {
    58  				w := assert.parseAssignment(circuit, a, curve, opt.checkSerialization)
    59  				invalidWitnesses = append(invalidWitnesses, w)
    60  
    61  				// check that the assignment is invalid with the test engine
    62  				if !opt.skipTestEngine {
    63  					err := IsSolved(circuit, w.assignment, curve.ScalarField())
    64  					assert.error(err, &w)
    65  				}
    66  			}
    67  
    68  			// for each backend; compile, prove/verify or solve, check serialization if needed.
    69  			for _, b := range opt.backends {
    70  				b := b
    71  
    72  				// run in sub-test to contextualize with backend
    73  				assert.Run(func(assert *Assert) {
    74  
    75  					// 1- check that the circuit compiles
    76  					ccs, err := assert.compile(circuit, curve, b, opt.compileOpts)
    77  					assert.noError(err, nil)
    78  
    79  					// TODO @gbotrel check serialization round trip with constraint system.
    80  
    81  					// 2- if we are not running the full prover;
    82  					// we need to run the solver on the constraint system only
    83  					if !opt.checkProver {
    84  						for _, w := range invalidWitnesses {
    85  							w := w
    86  							assert.Run(func(assert *Assert) {
    87  								_, err = ccs.Solve(w.full, opt.solverOpts...)
    88  								assert.error(err, &w)
    89  							}, "invalid_witness")
    90  						}
    91  
    92  						for _, w := range validWitnesses {
    93  							w := w
    94  							assert.Run(func(assert *Assert) {
    95  								_, err = ccs.Solve(w.full, opt.solverOpts...)
    96  								assert.noError(err, &w)
    97  							}, "valid_witness")
    98  						}
    99  
   100  						return
   101  					}
   102  
   103  					// we need to run the setup, prove and verify and check serialization
   104  					assert.t.Parallel()
   105  
   106  					var concreteBackend tBackend
   107  
   108  					switch b {
   109  					case backend.GROTH16:
   110  						concreteBackend = _groth16
   111  					case backend.PLONK:
   112  						concreteBackend = _plonk
   113  					default:
   114  						panic("backend not implemented")
   115  					}
   116  
   117  					// proof system setup.
   118  					pk, vk, pkBuilder, vkBuilder, proofBuilder, err := concreteBackend.setup(ccs, curve)
   119  					assert.noError(err, nil)
   120  
   121  					// for each valid witness, run the prover and verifier
   122  					for _, w := range validWitnesses {
   123  						w := w
   124  						assert.Run(func(assert *Assert) {
   125  							checkSolidity := opt.checkSolidity && curve == ecc.BN254
   126  							proverOpts := opt.proverOpts
   127  							verifierOpts := opt.verifierOpts
   128  							if b == backend.GROTH16 {
   129  								// currently groth16 Solidity checker only supports circuits with up to 1 commitment
   130  								if len(ccs.GetCommitments().CommitmentIndexes()) > 1 {
   131  									log.Warn().
   132  										Int("nb_commitments", len(ccs.GetCommitments().CommitmentIndexes())).
   133  										Msg("skipping solidity check, too many commitments")
   134  								}
   135  								checkSolidity = checkSolidity && (len(ccs.GetCommitments().CommitmentIndexes()) <= 1)
   136  								// set the default hash function in case of	custom hash function not set. This is to ensure that the proof can be verified by gnark-solidity-checker
   137  								proverOpts = append([]backend.ProverOption{solidity.WithProverTargetSolidityVerifier(b)}, opt.proverOpts...)
   138  								verifierOpts = append([]backend.VerifierOption{solidity.WithVerifierTargetSolidityVerifier(b)}, opt.verifierOpts...)
   139  							}
   140  							proof, err := concreteBackend.prove(ccs, pk, w.full, proverOpts...)
   141  							assert.noError(err, &w)
   142  
   143  							err = concreteBackend.verify(proof, vk, w.public, verifierOpts...)
   144  							assert.noError(err, &w)
   145  
   146  							if checkSolidity {
   147  								// check that the proof can be verified by gnark-solidity-checker
   148  								if _vk, ok := vk.(solidity.VerifyingKey); ok {
   149  									assert.Run(func(assert *Assert) {
   150  										assert.solidityVerification(b, _vk, proof, w.public, opt.solidityOpts)
   151  									}, "solidity")
   152  								}
   153  							}
   154  
   155  							// check proof serialization
   156  							assert.roundTripCheck(proof, proofBuilder, "proof")
   157  						}, "valid_witness")
   158  					}
   159  
   160  					// for each invalid witness, run the prover only, it should fail.
   161  					for _, w := range invalidWitnesses {
   162  						w := w
   163  						assert.Run(func(assert *Assert) {
   164  							_, err := concreteBackend.prove(ccs, pk, w.full, opt.proverOpts...)
   165  							assert.error(err, &w)
   166  						}, "invalid_witness")
   167  					}
   168  
   169  					// check serialization of proving and verifying keys
   170  					if opt.checkSerialization && ccs.GetNbConstraints() <= serializationThreshold && (curve == ecc.BN254 || curve == ecc.BLS12_381) {
   171  						assert.roundTripCheck(pk, pkBuilder, "proving_key")
   172  						assert.roundTripCheck(vk, vkBuilder, "verifying_key")
   173  					}
   174  
   175  				}, b.String())
   176  			}
   177  
   178  		}, curve.String())
   179  	}
   180  
   181  	// TODO @gbotrel revisit this.
   182  	if false && opt.fuzzing {
   183  		// TODO may not be the right place, but ensures all our tests call these minimal tests
   184  		// (like filling a witness with zeroes, or binary values, ...)
   185  		assert.Run(func(assert *Assert) {
   186  			assert.Fuzz(circuit, 5, opts...)
   187  		}, "fuzz")
   188  	}
   189  }
   190  
   191  type _witness struct {
   192  	full       witness.Witness
   193  	public     witness.Witness
   194  	assignment frontend.Circuit
   195  }
   196  
   197  func (assert *Assert) parseAssignment(circuit frontend.Circuit, assignment frontend.Circuit, curve ecc.ID, checkSerialization bool) _witness {
   198  	if assignment == nil {
   199  		return _witness{}
   200  	}
   201  	full, err := frontend.NewWitness(assignment, curve.ScalarField())
   202  	assert.NoError(err, "can't parse assignment into full witness")
   203  
   204  	public, err := frontend.NewWitness(assignment, curve.ScalarField(), frontend.PublicOnly())
   205  	assert.NoError(err, "can't parse assignment into public witness")
   206  
   207  	if checkSerialization {
   208  		witnessBuilder := func() any {
   209  			w, err := witness.New(curve.ScalarField())
   210  			if err != nil {
   211  				panic(err)
   212  			}
   213  			return w
   214  		}
   215  		assert.roundTripCheck(full, witnessBuilder, "witness", "full")
   216  		assert.roundTripCheck(public, witnessBuilder, "witness", "public")
   217  
   218  		// count number of element in witness.
   219  		// if too many, we don't do JSON serialization.
   220  		s, err := schema.Walk(assignment, tVariable, nil)
   221  		assert.NoError(err)
   222  
   223  		if s.Public+s.Secret <= serializationThreshold {
   224  			assert.Run(func(assert *Assert) {
   225  				s := lazySchema(circuit)()
   226  				assert.marshalWitnessJSON(full, s, curve, false)
   227  			}, curve.String(), "marshal/json")
   228  			assert.Run(func(assert *Assert) {
   229  				s := lazySchema(circuit)()
   230  				assert.marshalWitnessJSON(public, s, curve, true)
   231  			}, curve.String(), "marshal-public/json")
   232  		}
   233  	}
   234  
   235  	return _witness{full: full, public: public, assignment: assignment}
   236  }
   237  
   238  type fnSetup func(ccs constraint.ConstraintSystem, curve ecc.ID) (
   239  	pk, vk any,
   240  	pkBuilder, vkBuilder, proofBuilder func() any,
   241  	err error)
   242  type fnProve func(ccs constraint.ConstraintSystem, pk any, fullWitness witness.Witness, opts ...backend.ProverOption) (proof any, err error)
   243  type fnVerify func(proof, vk any, publicWitness witness.Witness, opts ...backend.VerifierOption) error
   244  
   245  // tBackend abstracts the backend implementation in the test package.
   246  type tBackend struct {
   247  	setup  fnSetup
   248  	prove  fnProve
   249  	verify fnVerify
   250  }
   251  
   252  var (
   253  	_groth16 = tBackend{
   254  		setup: func(ccs constraint.ConstraintSystem, curve ecc.ID) (
   255  			pk, vk any,
   256  			pkBuilder, vkBuilder, proofBuilder func() any,
   257  			err error) {
   258  			pk, vk, err = groth16.Setup(ccs)
   259  			return pk, vk, func() any { return groth16.NewProvingKey(curve) }, func() any { return groth16.NewVerifyingKey(curve) }, func() any { return groth16.NewProof(curve) }, err
   260  		},
   261  		prove: func(ccs constraint.ConstraintSystem, pk any, fullWitness witness.Witness, opts ...backend.ProverOption) (proof any, err error) {
   262  			return groth16.Prove(ccs, pk.(groth16.ProvingKey), fullWitness, opts...)
   263  		},
   264  		verify: func(proof, vk any, publicWitness witness.Witness, opts ...backend.VerifierOption) error {
   265  			return groth16.Verify(proof.(groth16.Proof), vk.(groth16.VerifyingKey), publicWitness, opts...)
   266  		},
   267  	}
   268  
   269  	_plonk = tBackend{
   270  		setup: func(ccs constraint.ConstraintSystem, curve ecc.ID) (
   271  			pk, vk any,
   272  			pkBuilder, vkBuilder, proofBuilder func() any,
   273  			err error) {
   274  			srs, srsLagrange, err := unsafekzg.NewSRS(ccs)
   275  			if err != nil {
   276  				return nil, nil, nil, nil, nil, err
   277  			}
   278  			pk, vk, err = plonk.Setup(ccs, srs, srsLagrange)
   279  			return pk, vk, func() any { return plonk.NewProvingKey(curve) }, func() any { return plonk.NewVerifyingKey(curve) }, func() any { return plonk.NewProof(curve) }, err
   280  		},
   281  		prove: func(ccs constraint.ConstraintSystem, pk any, fullWitness witness.Witness, opts ...backend.ProverOption) (proof any, err error) {
   282  			return plonk.Prove(ccs, pk.(plonk.ProvingKey), fullWitness, opts...)
   283  		},
   284  		verify: func(proof, vk any, publicWitness witness.Witness, opts ...backend.VerifierOption) error {
   285  			return plonk.Verify(proof.(plonk.Proof), vk.(plonk.VerifyingKey), publicWitness, opts...)
   286  		},
   287  	}
   288  )