github.com/consensys/gnark@v0.11.0/backend/groth16/groth16_test.go (about)

     1  package groth16_test
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"testing"
     7  
     8  	"github.com/consensys/gnark"
     9  	"github.com/consensys/gnark-crypto/ecc"
    10  	"github.com/consensys/gnark/backend"
    11  	"github.com/consensys/gnark/backend/groth16"
    12  	"github.com/consensys/gnark/constraint"
    13  	"github.com/consensys/gnark/frontend"
    14  	"github.com/consensys/gnark/frontend/cs/r1cs"
    15  	"github.com/consensys/gnark/test"
    16  )
    17  
    18  func TestCustomHashToField(t *testing.T) {
    19  	assert := test.NewAssert(t)
    20  	assignment := &commitmentCircuit{X: 1}
    21  	for _, curve := range getCurves() {
    22  		assert.Run(func(assert *test.Assert) {
    23  			ccs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &commitmentCircuit{})
    24  			assert.NoError(err)
    25  			pk, vk, err := groth16.Setup(ccs)
    26  			assert.NoError(err)
    27  			witness, err := frontend.NewWitness(assignment, curve.ScalarField())
    28  			assert.NoError(err)
    29  			assert.Run(func(assert *test.Assert) {
    30  				proof, err := groth16.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{}))
    31  				assert.NoError(err)
    32  				pubWitness, err := witness.Public()
    33  				assert.NoError(err)
    34  				err = groth16.Verify(proof, vk, pubWitness, backend.WithVerifierHashToFieldFunction(constantHash{}))
    35  				assert.NoError(err)
    36  			}, "custom success")
    37  			assert.Run(func(assert *test.Assert) {
    38  				proof, err := groth16.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{}))
    39  				assert.NoError(err)
    40  				pubWitness, err := witness.Public()
    41  				assert.NoError(err)
    42  				err = groth16.Verify(proof, vk, pubWitness)
    43  				assert.Error(err)
    44  			}, "prover_only")
    45  			assert.Run(func(assert *test.Assert) {
    46  				proof, err := groth16.Prove(ccs, pk, witness)
    47  				assert.Error(err)
    48  				_ = proof
    49  			}, "verifier_only")
    50  		}, curve.String())
    51  	}
    52  }
    53  
    54  //--------------------//
    55  //     benches		  //
    56  //--------------------//
    57  
    58  func BenchmarkSetup(b *testing.B) {
    59  	for _, curve := range getCurves() {
    60  		b.Run(curve.String(), func(b *testing.B) {
    61  			r1cs, _ := referenceCircuit(curve)
    62  			b.ResetTimer()
    63  			for i := 0; i < b.N; i++ {
    64  				_, _, _ = groth16.Setup(r1cs)
    65  			}
    66  		})
    67  	}
    68  }
    69  
    70  func BenchmarkProver(b *testing.B) {
    71  	for _, curve := range getCurves() {
    72  		b.Run(curve.String(), func(b *testing.B) {
    73  			r1cs, _solution := referenceCircuit(curve)
    74  			fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
    75  			if err != nil {
    76  				b.Fatal(err)
    77  			}
    78  			pk, err := groth16.DummySetup(r1cs)
    79  			if err != nil {
    80  				b.Fatal(err)
    81  			}
    82  			b.ResetTimer()
    83  			for i := 0; i < b.N; i++ {
    84  				_, _ = groth16.Prove(r1cs, pk, fullWitness)
    85  			}
    86  		})
    87  	}
    88  }
    89  
    90  func BenchmarkVerifier(b *testing.B) {
    91  	for _, curve := range getCurves() {
    92  		b.Run(curve.String(), func(b *testing.B) {
    93  			r1cs, _solution := referenceCircuit(curve)
    94  			fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
    95  			if err != nil {
    96  				b.Fatal(err)
    97  			}
    98  			publicWitness, err := fullWitness.Public()
    99  			if err != nil {
   100  				b.Fatal(err)
   101  			}
   102  
   103  			pk, vk, err := groth16.Setup(r1cs)
   104  			if err != nil {
   105  				b.Fatal(err)
   106  			}
   107  			proof, err := groth16.Prove(r1cs, pk, fullWitness)
   108  			if err != nil {
   109  				panic(err)
   110  			}
   111  
   112  			b.ResetTimer()
   113  			for i := 0; i < b.N; i++ {
   114  				_ = groth16.Verify(proof, vk, publicWitness)
   115  			}
   116  		})
   117  	}
   118  }
   119  
   120  type refCircuit struct {
   121  	nbConstraints int
   122  	X             frontend.Variable
   123  	Y             frontend.Variable `gnark:",public"`
   124  }
   125  
   126  func (circuit *refCircuit) Define(api frontend.API) error {
   127  	for i := 0; i < circuit.nbConstraints; i++ {
   128  		circuit.X = api.Mul(circuit.X, circuit.X)
   129  	}
   130  	api.AssertIsEqual(circuit.X, circuit.Y)
   131  	return nil
   132  }
   133  
   134  func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit) {
   135  	const nbConstraints = 40000
   136  	circuit := refCircuit{
   137  		nbConstraints: nbConstraints,
   138  	}
   139  	r1cs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &circuit)
   140  	if err != nil {
   141  		panic(err)
   142  	}
   143  
   144  	var good refCircuit
   145  	good.X = 2
   146  
   147  	// compute expected Y
   148  	expectedY := new(big.Int).SetUint64(2)
   149  	exp := big.NewInt(1)
   150  	exp.Lsh(exp, nbConstraints)
   151  	expectedY.Exp(expectedY, exp, curve.ScalarField())
   152  
   153  	good.Y = expectedY
   154  
   155  	return r1cs, &good
   156  }
   157  
   158  type commitmentCircuit struct {
   159  	X frontend.Variable
   160  }
   161  
   162  func (c *commitmentCircuit) Define(api frontend.API) error {
   163  	cmt, err := api.(frontend.Committer).Commit(c.X)
   164  	if err != nil {
   165  		return fmt.Errorf("commit: %w", err)
   166  	}
   167  	api.AssertIsEqual(cmt, "0xaabbcc")
   168  	return nil
   169  }
   170  
   171  type constantHash struct{}
   172  
   173  func (h constantHash) Write(p []byte) (n int, err error) { return len(p), nil }
   174  func (h constantHash) Sum(b []byte) []byte               { return []byte{0xaa, 0xbb, 0xcc} }
   175  func (h constantHash) Reset()                            {}
   176  func (h constantHash) Size() int                         { return 3 }
   177  func (h constantHash) BlockSize() int                    { return 32 }
   178  
   179  func getCurves() []ecc.ID {
   180  	if testing.Short() {
   181  		return []ecc.ID{ecc.BN254}
   182  	}
   183  	return gnark.Curves()
   184  }