github.com/consensys/gnark@v0.11.0/backend/plonk/plonk_test.go (about) 1 package plonk_test 2 3 import ( 4 "bytes" 5 "fmt" 6 "math/big" 7 "testing" 8 9 "github.com/consensys/gnark" 10 "github.com/consensys/gnark-crypto/ecc" 11 "github.com/consensys/gnark-crypto/kzg" 12 "github.com/consensys/gnark/backend" 13 "github.com/consensys/gnark/backend/plonk" 14 "github.com/consensys/gnark/constraint" 15 "github.com/consensys/gnark/frontend" 16 "github.com/consensys/gnark/frontend/cs/scs" 17 "github.com/consensys/gnark/test" 18 "github.com/consensys/gnark/test/unsafekzg" 19 "github.com/stretchr/testify/require" 20 ) 21 22 //--------------------// 23 // benches // 24 //--------------------// 25 26 func TestProver(t *testing.T) { 27 28 for _, curve := range getCurves() { 29 t.Run(curve.String(), func(t *testing.T) { 30 var b1, b2 bytes.Buffer 31 assert := require.New(t) 32 33 ccs, _solution, srs, srsLagrange := referenceCircuit(curve) 34 fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) 35 assert.NoError(err) 36 37 publicWitness, err := fullWitness.Public() 38 assert.NoError(err) 39 40 pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) 41 assert.NoError(err) 42 43 // write the PK to ensure it is not mutated 44 _, err = pk.WriteTo(&b1) 45 assert.NoError(err) 46 47 proof, err := plonk.Prove(ccs, pk, fullWitness) 48 assert.NoError(err) 49 50 // check pk 51 _, err = pk.WriteTo(&b2) 52 assert.NoError(err) 53 54 assert.True(bytes.Equal(b1.Bytes(), b2.Bytes()), "plonk prover mutated the proving key") 55 56 err = plonk.Verify(proof, vk, publicWitness) 57 assert.NoError(err) 58 59 // testing with full witness should output a clear error. 60 err = plonk.Verify(proof, vk, fullWitness) 61 assert.Error(err) 62 63 // check that error contains "witness length is invalid" 64 assert.Contains(err.Error(), "witness length is invalid") 65 66 }) 67 68 } 69 } 70 71 func TestCustomHashToField(t *testing.T) { 72 assert := test.NewAssert(t) 73 assignment := &commitmentCircuit{X: 1} 74 for _, curve := range getCurves() { 75 curve := curve 76 assert.Run(func(assert *test.Assert) { 77 ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &commitmentCircuit{}) 78 assert.NoError(err) 79 srs, srsLagrange, err := unsafekzg.NewSRS(ccs) 80 assert.NoError(err) 81 82 pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) 83 assert.NoError(err) 84 witness, err := frontend.NewWitness(assignment, curve.ScalarField()) 85 assert.NoError(err) 86 assert.Run(func(assert *test.Assert) { 87 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{})) 88 assert.NoError(err) 89 pubWitness, err := witness.Public() 90 assert.NoError(err) 91 err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierHashToFieldFunction(constantHash{})) 92 assert.NoError(err) 93 }, "prover_verifier") 94 assert.Run(func(assert *test.Assert) { 95 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverHashToFieldFunction(constantHash{})) 96 assert.NoError(err) 97 pubWitness, err := witness.Public() 98 assert.NoError(err) 99 err = plonk.Verify(proof, vk, pubWitness) 100 assert.Error(err) 101 }, "prover_only") 102 assert.Run(func(assert *test.Assert) { 103 proof, err := plonk.Prove(ccs, pk, witness) 104 assert.Error(err) 105 _ = proof 106 }, "verifier_only") 107 }, curve.String()) 108 } 109 } 110 111 func TestCustomChallengeHash(t *testing.T) { 112 assert := test.NewAssert(t) 113 assignment := &smallCircuit{X: 1} 114 for _, curve := range getCurves() { 115 curve := curve 116 assert.Run(func(assert *test.Assert) { 117 ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{}) 118 assert.NoError(err) 119 srs, srsLagrange, err := unsafekzg.NewSRS(ccs) 120 assert.NoError(err) 121 122 pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) 123 assert.NoError(err) 124 witness, err := frontend.NewWitness(assignment, curve.ScalarField()) 125 assert.NoError(err) 126 assert.Run(func(assert *test.Assert) { 127 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverChallengeHashFunction(constantHash{})) 128 assert.NoError(err) 129 pubWitness, err := witness.Public() 130 assert.NoError(err) 131 err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierChallengeHashFunction(constantHash{})) 132 assert.NoError(err) 133 }, "prover_verifier") 134 assert.Run(func(assert *test.Assert) { 135 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverChallengeHashFunction(constantHash{})) 136 assert.NoError(err) 137 pubWitness, err := witness.Public() 138 assert.NoError(err) 139 err = plonk.Verify(proof, vk, pubWitness) 140 assert.Error(err) 141 }, "prover_only") 142 assert.Run(func(assert *test.Assert) { 143 proof, err := plonk.Prove(ccs, pk, witness) 144 assert.NoError(err) 145 pubWitness, err := witness.Public() 146 assert.NoError(err) 147 err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierChallengeHashFunction(constantHash{})) 148 assert.Error(err) 149 }, "verifier_only") 150 }, curve.String()) 151 } 152 } 153 154 func TestCustomKZGFoldingHash(t *testing.T) { 155 assert := test.NewAssert(t) 156 assignment := &smallCircuit{X: 1} 157 for _, curve := range getCurves() { 158 curve := curve 159 assert.Run(func(assert *test.Assert) { 160 ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{}) 161 assert.NoError(err) 162 srs, srsLagrange, err := unsafekzg.NewSRS(ccs) 163 assert.NoError(err) 164 165 pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) 166 assert.NoError(err) 167 witness, err := frontend.NewWitness(assignment, curve.ScalarField()) 168 assert.NoError(err) 169 assert.Run(func(assert *test.Assert) { 170 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverKZGFoldingHashFunction(constantHash{})) 171 assert.NoError(err) 172 pubWitness, err := witness.Public() 173 assert.NoError(err) 174 err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierKZGFoldingHashFunction(constantHash{})) 175 assert.NoError(err) 176 }, "prover_verifier") 177 assert.Run(func(assert *test.Assert) { 178 proof, err := plonk.Prove(ccs, pk, witness, backend.WithProverKZGFoldingHashFunction(constantHash{})) 179 assert.NoError(err) 180 pubWitness, err := witness.Public() 181 assert.NoError(err) 182 err = plonk.Verify(proof, vk, pubWitness) 183 assert.Error(err) 184 }, "prover_only") 185 assert.Run(func(assert *test.Assert) { 186 proof, err := plonk.Prove(ccs, pk, witness) 187 assert.NoError(err) 188 pubWitness, err := witness.Public() 189 assert.NoError(err) 190 err = plonk.Verify(proof, vk, pubWitness, backend.WithVerifierKZGFoldingHashFunction(constantHash{})) 191 assert.Error(err) 192 }, "verifier_only") 193 }, curve.String()) 194 } 195 } 196 197 func BenchmarkSetup(b *testing.B) { 198 for _, curve := range getCurves() { 199 b.Run(curve.String(), func(b *testing.B) { 200 ccs, _, srs, srsLagrange := referenceCircuit(curve) 201 b.ResetTimer() 202 for i := 0; i < b.N; i++ { 203 _, _, _ = plonk.Setup(ccs, srs, srsLagrange) 204 } 205 }) 206 } 207 } 208 209 func BenchmarkProver(b *testing.B) { 210 for _, curve := range getCurves() { 211 b.Run(curve.String(), func(b *testing.B) { 212 ccs, _solution, srs, srsLagrange := referenceCircuit(curve) 213 fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) 214 if err != nil { 215 b.Fatal(err) 216 } 217 pk, _, err := plonk.Setup(ccs, srs, srsLagrange) 218 if err != nil { 219 b.Fatal(err) 220 } 221 b.ResetTimer() 222 for i := 0; i < b.N; i++ { 223 _, _ = plonk.Prove(ccs, pk, fullWitness) 224 } 225 }) 226 } 227 } 228 229 func BenchmarkVerifier(b *testing.B) { 230 for _, curve := range getCurves() { 231 b.Run(curve.String(), func(b *testing.B) { 232 ccs, _solution, srs, srsLagrange := referenceCircuit(curve) 233 fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) 234 if err != nil { 235 b.Fatal(err) 236 } 237 publicWitness, err := fullWitness.Public() 238 if err != nil { 239 b.Fatal(err) 240 } 241 242 pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) 243 if err != nil { 244 b.Fatal(err) 245 } 246 proof, err := plonk.Prove(ccs, pk, fullWitness) 247 if err != nil { 248 panic(err) 249 } 250 251 b.ResetTimer() 252 for i := 0; i < b.N; i++ { 253 _ = plonk.Verify(proof, vk, publicWitness) 254 } 255 }) 256 } 257 } 258 259 type refCircuit struct { 260 nbConstraints int 261 X frontend.Variable 262 Y frontend.Variable `gnark:",public"` 263 } 264 265 func (circuit *refCircuit) Define(api frontend.API) error { 266 for i := 0; i < circuit.nbConstraints; i++ { 267 circuit.X = api.Mul(circuit.X, circuit.X) 268 } 269 api.AssertIsEqual(circuit.X, circuit.Y) 270 return nil 271 } 272 273 func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit, kzg.SRS, kzg.SRS) { 274 const nbConstraints = (1 << 12) - 3 275 circuit := refCircuit{ 276 nbConstraints: nbConstraints, 277 } 278 ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &circuit) 279 if err != nil { 280 panic(err) 281 } 282 283 var good refCircuit 284 good.X = 2 285 286 // compute expected Y 287 expectedY := new(big.Int).SetUint64(2) 288 exp := big.NewInt(1) 289 exp.Lsh(exp, nbConstraints) 290 expectedY.Exp(expectedY, exp, curve.ScalarField()) 291 292 good.Y = expectedY 293 srs, srsLagrange, err := unsafekzg.NewSRS(ccs, unsafekzg.WithFSCache()) 294 if err != nil { 295 panic(err) 296 } 297 return ccs, &good, srs, srsLagrange 298 } 299 300 type commitmentCircuit struct { 301 X frontend.Variable 302 } 303 304 func (c *commitmentCircuit) Define(api frontend.API) error { 305 cmt, err := api.(frontend.Committer).Commit(c.X) 306 if err != nil { 307 return fmt.Errorf("commit: %w", err) 308 } 309 api.AssertIsEqual(cmt, "0xaabbcc") 310 return nil 311 } 312 313 type smallCircuit struct { 314 X frontend.Variable 315 } 316 317 func (c *smallCircuit) Define(api frontend.API) error { 318 res := api.Mul(c.X, c.X) 319 api.AssertIsEqual(c.X, res) 320 return nil 321 } 322 323 type constantHash struct{} 324 325 func (h constantHash) Write(p []byte) (n int, err error) { return len(p), nil } 326 func (h constantHash) Sum(b []byte) []byte { return []byte{0xaa, 0xbb, 0xcc} } 327 func (h constantHash) Reset() {} 328 func (h constantHash) Size() int { return 3 } 329 func (h constantHash) BlockSize() int { return 32 } 330 331 func getCurves() []ecc.ID { 332 if testing.Short() { 333 return []ecc.ID{ecc.BN254} 334 } 335 return gnark.Curves() 336 }