github.com/consensys/gnark@v0.11.0/backend/plonk/plonk_test.go (about)

     1  package plonk_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math/big"
     7  	"testing"
     8  
     9  	"github.com/consensys/gnark"
    10  	"github.com/consensys/gnark-crypto/ecc"
    11  	"github.com/consensys/gnark-crypto/kzg"
    12  	"github.com/consensys/gnark/backend"
    13  	"github.com/consensys/gnark/backend/plonk"
    14  	"github.com/consensys/gnark/constraint"
    15  	"github.com/consensys/gnark/frontend"
    16  	"github.com/consensys/gnark/frontend/cs/scs"
    17  	"github.com/consensys/gnark/test"
    18  	"github.com/consensys/gnark/test/unsafekzg"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  //--------------------//
    23  //     benches		  //
    24  //--------------------//
    25  
    26  func TestProver(t *testing.T) {
    27  
    28  	for _, curve := range getCurves() {
    29  		t.Run(curve.String(), func(t *testing.T) {
    30  			var b1, b2 bytes.Buffer
    31  			assert := require.New(t)
    32  
    33  			ccs, _solution, srs, srsLagrange := referenceCircuit(curve)
    34  			fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
    35  			assert.NoError(err)
    36  
    37  			publicWitness, err := fullWitness.Public()
    38  			assert.NoError(err)
    39  
    40  			pk, vk, err := plonk.Setup(ccs, srs, srsLagrange)
    41  			assert.NoError(err)
    42  
    43  			// write the PK to ensure it is not mutated
    44  			_, err = pk.WriteTo(&b1)
    45  			assert.NoError(err)
    46  
    47  			proof, err := plonk.Prove(ccs, pk, fullWitness)
    48  			assert.NoError(err)
    49  
    50  			// check pk
    51  			_, err = pk.WriteTo(&b2)
    52  			assert.NoError(err)
    53  
    54  			assert.True(bytes.Equal(b1.Bytes(), b2.Bytes()), "plonk prover mutated the proving key")
    55  
    56  			err = plonk.Verify(proof, vk, publicWitness)
    57  			assert.NoError(err)
    58  
    59  			// testing with full witness should output a clear error.
    60  			err = plonk.Verify(proof, vk, fullWitness)
    61  			assert.Error(err)
    62  
    63  			// check that error contains "witness length is invalid"
    64  			assert.Contains(err.Error(), "witness length is invalid")
    65  
    66  		})
    67  
    68  	}
    69  }
    70  
    71  func TestCustomHashToField(t *testing.T) {
    72  	assert := test.NewAssert(t)
    73  	assignment := &commitmentCircuit{X: 1}
    74  	for _, curve := range getCurves() {
    75  		curve := curve
    76  		assert.Run(func(assert *test.Assert) {
    77  			ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &commitmentCircuit{})
    78  			assert.NoError(err)
    79  			srs, srsLagrange, err := unsafekzg.NewSRS(ccs)
    80  			assert.NoError(err)
    81  
    82  			pk, vk, err := plonk.Setup(ccs, srs, srsLagrange)
    83  			assert.NoError(err)
    84  			witness, err := frontend.NewWitness(assignment, curve.ScalarField())
    85  			assert.NoError(err)
    86  			assert.Run(func(assert *test.Assert) {
    87  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{}))
    88  				assert.NoError(err)
    89  				pubWitness, err := witness.Public()
    90  				assert.NoError(err)
    91  				err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierHashToFieldFunction(constantHash{}))
    92  				assert.NoError(err)
    93  			}, "prover_verifier")
    94  			assert.Run(func(assert *test.Assert) {
    95  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{}))
    96  				assert.NoError(err)
    97  				pubWitness, err := witness.Public()
    98  				assert.NoError(err)
    99  				err = plonk.Verify(proof, vk, pubWitness)
   100  				assert.Error(err)
   101  			}, "prover_only")
   102  			assert.Run(func(assert *test.Assert) {
   103  				proof, err := plonk.Prove(ccs, pk, witness)
   104  				assert.Error(err)
   105  				_ = proof
   106  			}, "verifier_only")
   107  		}, curve.String())
   108  	}
   109  }
   110  
   111  func TestCustomChallengeHash(t *testing.T) {
   112  	assert := test.NewAssert(t)
   113  	assignment := &smallCircuit{X: 1}
   114  	for _, curve := range getCurves() {
   115  		curve := curve
   116  		assert.Run(func(assert *test.Assert) {
   117  			ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{})
   118  			assert.NoError(err)
   119  			srs, srsLagrange, err := unsafekzg.NewSRS(ccs)
   120  			assert.NoError(err)
   121  
   122  			pk, vk, err := plonk.Setup(ccs, srs, srsLagrange)
   123  			assert.NoError(err)
   124  			witness, err := frontend.NewWitness(assignment, curve.ScalarField())
   125  			assert.NoError(err)
   126  			assert.Run(func(assert *test.Assert) {
   127  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverChallengeHashFunction(constantHash{}))
   128  				assert.NoError(err)
   129  				pubWitness, err := witness.Public()
   130  				assert.NoError(err)
   131  				err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierChallengeHashFunction(constantHash{}))
   132  				assert.NoError(err)
   133  			}, "prover_verifier")
   134  			assert.Run(func(assert *test.Assert) {
   135  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverChallengeHashFunction(constantHash{}))
   136  				assert.NoError(err)
   137  				pubWitness, err := witness.Public()
   138  				assert.NoError(err)
   139  				err = plonk.Verify(proof, vk, pubWitness)
   140  				assert.Error(err)
   141  			}, "prover_only")
   142  			assert.Run(func(assert *test.Assert) {
   143  				proof, err := plonk.Prove(ccs, pk, witness)
   144  				assert.NoError(err)
   145  				pubWitness, err := witness.Public()
   146  				assert.NoError(err)
   147  				err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierChallengeHashFunction(constantHash{}))
   148  				assert.Error(err)
   149  			}, "verifier_only")
   150  		}, curve.String())
   151  	}
   152  }
   153  
   154  func TestCustomKZGFoldingHash(t *testing.T) {
   155  	assert := test.NewAssert(t)
   156  	assignment := &smallCircuit{X: 1}
   157  	for _, curve := range getCurves() {
   158  		curve := curve
   159  		assert.Run(func(assert *test.Assert) {
   160  			ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{})
   161  			assert.NoError(err)
   162  			srs, srsLagrange, err := unsafekzg.NewSRS(ccs)
   163  			assert.NoError(err)
   164  
   165  			pk, vk, err := plonk.Setup(ccs, srs, srsLagrange)
   166  			assert.NoError(err)
   167  			witness, err := frontend.NewWitness(assignment, curve.ScalarField())
   168  			assert.NoError(err)
   169  			assert.Run(func(assert *test.Assert) {
   170  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverKZGFoldingHashFunction(constantHash{}))
   171  				assert.NoError(err)
   172  				pubWitness, err := witness.Public()
   173  				assert.NoError(err)
   174  				err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierKZGFoldingHashFunction(constantHash{}))
   175  				assert.NoError(err)
   176  			}, "prover_verifier")
   177  			assert.Run(func(assert *test.Assert) {
   178  				proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverKZGFoldingHashFunction(constantHash{}))
   179  				assert.NoError(err)
   180  				pubWitness, err := witness.Public()
   181  				assert.NoError(err)
   182  				err = plonk.Verify(proof, vk, pubWitness)
   183  				assert.Error(err)
   184  			}, "prover_only")
   185  			assert.Run(func(assert *test.Assert) {
   186  				proof, err := plonk.Prove(ccs, pk, witness)
   187  				assert.NoError(err)
   188  				pubWitness, err := witness.Public()
   189  				assert.NoError(err)
   190  				err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierKZGFoldingHashFunction(constantHash{}))
   191  				assert.Error(err)
   192  			}, "verifier_only")
   193  		}, curve.String())
   194  	}
   195  }
   196  
   197  func BenchmarkSetup(b *testing.B) {
   198  	for _, curve := range getCurves() {
   199  		b.Run(curve.String(), func(b *testing.B) {
   200  			ccs, _, srs, srsLagrange := referenceCircuit(curve)
   201  			b.ResetTimer()
   202  			for i := 0; i < b.N; i++ {
   203  				_, _, _ = plonk.Setup(ccs, srs, srsLagrange)
   204  			}
   205  		})
   206  	}
   207  }
   208  
   209  func BenchmarkProver(b *testing.B) {
   210  	for _, curve := range getCurves() {
   211  		b.Run(curve.String(), func(b *testing.B) {
   212  			ccs, _solution, srs, srsLagrange := referenceCircuit(curve)
   213  			fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
   214  			if err != nil {
   215  				b.Fatal(err)
   216  			}
   217  			pk, _, err := plonk.Setup(ccs, srs, srsLagrange)
   218  			if err != nil {
   219  				b.Fatal(err)
   220  			}
   221  			b.ResetTimer()
   222  			for i := 0; i < b.N; i++ {
   223  				_, _ = plonk.Prove(ccs, pk, fullWitness)
   224  			}
   225  		})
   226  	}
   227  }
   228  
   229  func BenchmarkVerifier(b *testing.B) {
   230  	for _, curve := range getCurves() {
   231  		b.Run(curve.String(), func(b *testing.B) {
   232  			ccs, _solution, srs, srsLagrange := referenceCircuit(curve)
   233  			fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
   234  			if err != nil {
   235  				b.Fatal(err)
   236  			}
   237  			publicWitness, err := fullWitness.Public()
   238  			if err != nil {
   239  				b.Fatal(err)
   240  			}
   241  
   242  			pk, vk, err := plonk.Setup(ccs, srs, srsLagrange)
   243  			if err != nil {
   244  				b.Fatal(err)
   245  			}
   246  			proof, err := plonk.Prove(ccs, pk, fullWitness)
   247  			if err != nil {
   248  				panic(err)
   249  			}
   250  
   251  			b.ResetTimer()
   252  			for i := 0; i < b.N; i++ {
   253  				_ = plonk.Verify(proof, vk, publicWitness)
   254  			}
   255  		})
   256  	}
   257  }
   258  
   259  type refCircuit struct {
   260  	nbConstraints int
   261  	X             frontend.Variable
   262  	Y             frontend.Variable `gnark:",public"`
   263  }
   264  
   265  func (circuit *refCircuit) Define(api frontend.API) error {
   266  	for i := 0; i < circuit.nbConstraints; i++ {
   267  		circuit.X = api.Mul(circuit.X, circuit.X)
   268  	}
   269  	api.AssertIsEqual(circuit.X, circuit.Y)
   270  	return nil
   271  }
   272  
   273  func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit, kzg.SRS, kzg.SRS) {
   274  	const nbConstraints = (1 << 12) - 3
   275  	circuit := refCircuit{
   276  		nbConstraints: nbConstraints,
   277  	}
   278  	ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &circuit)
   279  	if err != nil {
   280  		panic(err)
   281  	}
   282  
   283  	var good refCircuit
   284  	good.X = 2
   285  
   286  	// compute expected Y
   287  	expectedY := new(big.Int).SetUint64(2)
   288  	exp := big.NewInt(1)
   289  	exp.Lsh(exp, nbConstraints)
   290  	expectedY.Exp(expectedY, exp, curve.ScalarField())
   291  
   292  	good.Y = expectedY
   293  	srs, srsLagrange, err := unsafekzg.NewSRS(ccs, unsafekzg.WithFSCache())
   294  	if err != nil {
   295  		panic(err)
   296  	}
   297  	return ccs, &good, srs, srsLagrange
   298  }
   299  
   300  type commitmentCircuit struct {
   301  	X frontend.Variable
   302  }
   303  
   304  func (c *commitmentCircuit) Define(api frontend.API) error {
   305  	cmt, err := api.(frontend.Committer).Commit(c.X)
   306  	if err != nil {
   307  		return fmt.Errorf("commit: %w", err)
   308  	}
   309  	api.AssertIsEqual(cmt, "0xaabbcc")
   310  	return nil
   311  }
   312  
   313  type smallCircuit struct {
   314  	X frontend.Variable
   315  }
   316  
   317  func (c *smallCircuit) Define(api frontend.API) error {
   318  	res := api.Mul(c.X, c.X)
   319  	api.AssertIsEqual(c.X, res)
   320  	return nil
   321  }
   322  
   323  type constantHash struct{}
   324  
   325  func (h constantHash) Write(p []byte) (n int, err error) { return len(p), nil }
   326  func (h constantHash) Sum(b []byte) []byte               { return []byte{0xaa, 0xbb, 0xcc} }
   327  func (h constantHash) Reset()                            {}
   328  func (h constantHash) Size() int                         { return 3 }
   329  func (h constantHash) BlockSize() int                    { return 32 }
   330  
   331  func getCurves() []ecc.ID {
   332  	if testing.Short() {
   333  		return []ecc.ID{ecc.BN254}
   334  	}
   335  	return gnark.Curves()
   336  }