github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/template/gkr.test.go.tmpl (about) 1 2 import ( 3 "{{.FieldPackagePath}}" 4 "{{.FieldPackagePath}}/mimc" 5 "{{.FieldPackagePath}}/polynomial" 6 "{{.FieldPackagePath}}/sumcheck" 7 "{{.FieldPackagePath}}/test_vector_utils" 8 fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" 9 "github.com/consensys/gnark-crypto/utils" 10 "github.com/stretchr/testify/assert" 11 "fmt" 12 "hash" 13 "os" 14 "strconv" 15 "testing" 16 "path/filepath" 17 "encoding/json" 18 "reflect" 19 "time" 20 ) 21 22 {{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}} 23 {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} 24 25 func TestNoGateTwoInstances(t *testing.T) { 26 // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case 27 testNoGate(t, []{{.ElementType}}{four, three}) 28 } 29 30 func TestNoGate(t *testing.T) { 31 testManyInstances(t, 1, testNoGate) 32 } 33 34 func TestSingleAddGateTwoInstances(t *testing.T) { 35 testSingleAddGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) 36 } 37 38 func TestSingleAddGate(t *testing.T) { 39 testManyInstances(t, 2, testSingleAddGate) 40 } 41 42 func TestSingleMulGateTwoInstances(t *testing.T) { 43 testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) 44 } 45 46 func TestSingleMulGate(t *testing.T) { 47 testManyInstances(t, 2, testSingleMulGate) 48 } 49 50 func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { 51 52 testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three}) 53 } 54 55 func TestSingleInputTwoIdentityGates(t *testing.T) { 56 57 testManyInstances(t, 2, testSingleInputTwoIdentityGates) 58 } 59 60 func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { 61 testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one}) 62 } 63 64 func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { 65 testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) 66 } 67 68 func TestSingleMimcCipherGateTwoInstances(t *testing.T) { 69 testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) 70 } 71 72 func TestSingleMimcCipherGate(t *testing.T) { 73 testManyInstances(t, 2, testSingleMimcCipherGate) 74 } 75 76 func TestATimesBSquaredTwoInstances(t *testing.T) { 77 testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) 78 } 79 80 func TestShallowMimcTwoInstances(t *testing.T) { 81 testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) 82 } 83 84 {{- if $GenerateLargeTests}} 85 func TestMimcTwoInstances(t *testing.T) { 86 testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) 87 } 88 89 func TestMimc(t *testing.T) { 90 testManyInstances(t, 2, generateTestMimc(93)) 91 } 92 93 func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { 94 return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 95 testMimc(t, numRounds, inputAssignments...) 96 } 97 } 98 99 {{- end}} 100 101 func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { 102 circuit := Circuit{ Wire{ 103 Gate: IdentityGate{}, 104 Inputs: []*Wire{}, 105 nbUniqueOutputs: 2, 106 } } 107 108 wire := &circuit[0] 109 110 assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}} 111 var o settings 112 pool := polynomial.NewPool(256, 1<<11) 113 workers := utils.NewWorkerPool() 114 o.pool = &pool 115 o.workers = workers 116 117 claimsManagerGen := func() *claimsManager { 118 manager := newClaimsManager(circuit, assignment, o) 119 manager.add(wire, []{{.ElementType}}{three}, five) 120 manager.add(wire, []{{.ElementType}}{four}, six) 121 return &manager 122 } 123 124 transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) 125 126 proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) 127 assert.NoError(t, err) 128 err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) 129 assert.NoError(t, err) 130 } 131 132 var one, two, three, four, five, six {{.ElementType}} 133 134 func init() { 135 one.SetOne() 136 two.Double(&one) 137 three.Add(&two, &one) 138 four.Double(&two) 139 five.Add(&three, &two) 140 six.Double(&three) 141 } 142 143 var testManyInstancesLogMaxInstances = -1 144 145 func getLogMaxInstances(t *testing.T) int { 146 if testManyInstancesLogMaxInstances == -1 { 147 148 s := os.Getenv("GKR_LOG_INSTANCES") 149 if s == "" { 150 testManyInstancesLogMaxInstances = 5 151 } else { 152 var err error 153 testManyInstancesLogMaxInstances, err = strconv.Atoi(s) 154 if err != nil { 155 t.Error(err) 156 } 157 } 158 159 } 160 return testManyInstancesLogMaxInstances 161 } 162 163 func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) { 164 fullAssignments := make([][]{{.ElementType}}, numInput) 165 maxSize := 1 << getLogMaxInstances(t) 166 167 t.Log("Entered test orchestrator, assigning and randomizing inputs") 168 169 for i := range fullAssignments { 170 fullAssignments[i] = make([]fr.Element, maxSize) 171 setRandom(fullAssignments[i]) 172 } 173 174 inputAssignments := make([][]{{.ElementType}}, numInput) 175 for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { 176 for i, fullAssignment := range fullAssignments { 177 inputAssignments[i] = fullAssignment[:numEvals] 178 } 179 180 t.Log("Selected inputs for test") 181 test(t, inputAssignments...) 182 } 183 } 184 185 func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 186 c := Circuit{ 187 { 188 Inputs: []*Wire{}, 189 Gate: nil, 190 }, 191 } 192 193 assignment := WireAssignment{&c[0]: inputAssignments[0]} 194 195 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 196 assert.NoError(t, err) 197 198 // Even though a hash is called here, the proof is empty 199 200 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 201 assert.NoError(t, err, "proof rejected") 202 } 203 204 func testSingleAddGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 205 c := make(Circuit, 3) 206 c[2] = Wire{ 207 Gate: Gates["add"], 208 Inputs: []*Wire{&c[0], &c[1]}, 209 } 210 211 assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) 212 213 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 214 assert.NoError(t,err) 215 216 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 217 assert.NoError(t, err, "proof rejected") 218 219 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 220 assert.NotNil(t, err, "bad proof accepted") 221 } 222 223 func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 224 225 c := make(Circuit, 3) 226 c[2] = Wire{ 227 Gate: Gates["mul"], 228 Inputs: []*Wire{&c[0], &c[1]}, 229 } 230 231 assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) 232 233 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 234 assert.NoError(t, err) 235 236 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 237 assert.NoError(t, err, "proof rejected") 238 239 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 240 assert.NotNil(t, err, "bad proof accepted") 241 } 242 243 func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 244 c := make(Circuit, 3) 245 246 c[1] = Wire{ 247 Gate: IdentityGate{}, 248 Inputs: []*Wire{&c[0]}, 249 } 250 251 c[2] = Wire{ 252 Gate: IdentityGate{}, 253 Inputs: []*Wire{&c[0]}, 254 } 255 256 assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) 257 258 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 259 assert.NoError(t, err) 260 261 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 262 assert.NoError(t, err, "proof rejected") 263 264 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 265 assert.NotNil(t, err, "bad proof accepted") 266 } 267 268 func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 269 c := make(Circuit, 3) 270 271 c[2] = Wire{ 272 Gate: mimcCipherGate{}, 273 Inputs: []*Wire{&c[0], &c[1]}, 274 } 275 276 t.Log("Evaluating all circuit wires") 277 assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) 278 t.Log("Circuit evaluation complete") 279 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 280 assert.NoError(t, err) 281 t.Log("Proof complete") 282 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 283 assert.NoError(t, err, "proof rejected") 284 285 t.Log("Successful verification complete") 286 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 287 assert.NotNil(t, err, "bad proof accepted") 288 t.Log("Unsuccessful verification complete") 289 } 290 291 func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) { 292 c := make(Circuit, 3) 293 294 c[1] = Wire{ 295 Gate: IdentityGate{}, 296 Inputs: []*Wire{&c[0]}, 297 } 298 c[2] = Wire{ 299 Gate: IdentityGate{}, 300 Inputs: []*Wire{&c[1]}, 301 } 302 303 assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) 304 305 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 306 assert.NoError(t, err) 307 308 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 309 assert.NoError(t, err, "proof rejected") 310 311 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 312 assert.NotNil(t, err, "bad proof accepted") 313 } 314 315 func mimcCircuit(numRounds int) Circuit { 316 c := make(Circuit, numRounds+2) 317 318 for i := 2; i < len(c); i++ { 319 c[i] = Wire{ 320 Gate: mimcCipherGate{}, 321 Inputs: []*Wire{&c[i-1], &c[0]}, 322 } 323 } 324 return c 325 } 326 327 func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { 328 //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) 329 // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 330 331 c := mimcCircuit(numRounds) 332 333 t.Log("Evaluating all circuit wires") 334 assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) 335 t.Log("Circuit evaluation complete") 336 337 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 338 assert.NoError(t, err) 339 340 t.Log("Proof finished") 341 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 342 assert.NoError(t, err, "proof rejected") 343 344 t.Log("Successful verification finished") 345 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 346 assert.NotNil(t, err, "bad proof accepted") 347 t.Log("Unsuccessful verification finished") 348 } 349 350 func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { 351 // This imitates the MiMC circuit 352 353 c := make(Circuit, numRounds+2) 354 355 for i := 2; i < len(c); i++ { 356 c[i] = Wire{ 357 Gate: Gates["mul"], 358 Inputs: []*Wire{&c[i-1], &c[0]}, 359 } 360 } 361 362 assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) 363 364 proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 365 assert.NoError(t, err) 366 367 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) 368 assert.NoError(t, err, "proof rejected") 369 370 err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) 371 assert.NotNil(t, err, "bad proof accepted") 372 } 373 374 func setRandom(slice []{{.ElementType}}) { 375 for i := range slice { 376 slice[i].SetRandom() 377 } 378 } 379 380 func generateTestProver(path string) func(t *testing.T) { 381 return func(t *testing.T) { 382 testCase, err := newTestCase(path) 383 assert.NoError(t, err) 384 proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) 385 assert.NoError(t, err) 386 assert.NoError(t, proofEquals(testCase.Proof, proof)) 387 } 388 } 389 390 func generateTestVerifier(path string) func(t *testing.T) { 391 return func(t *testing.T) { 392 testCase, err := newTestCase(path) 393 assert.NoError(t, err) 394 err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) 395 assert.NoError(t, err, "proof rejected") 396 testCase, err = newTestCase(path) 397 assert.NoError(t, err) 398 err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) 399 assert.NotNil(t, err, "bad proof accepted") 400 } 401 } 402 403 func TestGkrVectors(t *testing.T) { 404 405 testDirPath := "{{.TestVectorsRelativePath}}" 406 dirEntries, err := os.ReadDir(testDirPath) 407 assert.NoError(t, err) 408 for _, dirEntry := range dirEntries { 409 if !dirEntry.IsDir() { 410 411 if filepath.Ext(dirEntry.Name()) == ".json" { 412 path := filepath.Join(testDirPath, dirEntry.Name()) 413 noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] 414 415 t.Run(noExt+"_prover", generateTestProver(path)) 416 t.Run(noExt+"_verifier", generateTestVerifier(path)) 417 418 } 419 } 420 } 421 } 422 423 func proofEquals(expected Proof, seen Proof) error { 424 if len(expected) != len(seen) { 425 return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) 426 } 427 for i, x := range expected { 428 xSeen := seen[i] 429 430 if xSeen.FinalEvalProof == nil { 431 if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { 432 return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) 433 } 434 } else { 435 if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { 436 return fmt.Errorf("final evaluation proof mismatch") 437 } 438 } 439 if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { 440 return err 441 } 442 } 443 return nil 444 } 445 446 func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { 447 fmt.Println("creating circuit structure") 448 c := mimcCircuit(mimcDepth) 449 450 in0 := make([]fr.Element, nbInstances) 451 in1 := make([]fr.Element, nbInstances) 452 setRandom(in0) 453 setRandom(in1) 454 455 fmt.Println("evaluating circuit") 456 start := time.Now().UnixMicro() 457 assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) 458 solved := time.Now().UnixMicro() - start 459 fmt.Println("solved in", solved, "μs") 460 461 //b.ResetTimer() 462 fmt.Println("constructing proof") 463 start = time.Now().UnixMicro() 464 _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) 465 proved := time.Now().UnixMicro() - start 466 fmt.Println("proved in", proved, "μs") 467 assert.NoError(b, err) 468 } 469 470 func BenchmarkGkrMimc19(b *testing.B) { 471 benchmarkGkrMiMC(b, 1<<19, 91) 472 } 473 474 func BenchmarkGkrMimc17(b *testing.B) { 475 benchmarkGkrMiMC(b, 1<<17, 91) 476 } 477 478 func TestTopSortTrivial(t *testing.T) { 479 c := make(Circuit, 2) 480 c[0].Inputs = []*Wire{&c[1]} 481 sorted := {{$topologicalSort}}(c) 482 assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) 483 } 484 485 func TestTopSortDeep(t *testing.T) { 486 c := make(Circuit, 4) 487 c[0].Inputs = []*Wire{&c[2]} 488 c[1].Inputs = []*Wire{&c[3]} 489 c[2].Inputs = []*Wire{} 490 c[3].Inputs = []*Wire{&c[0]} 491 sorted := {{$topologicalSort}}(c) 492 assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) 493 } 494 495 func TestTopSortWide(t *testing.T) { 496 c := make(Circuit, 10) 497 c[0].Inputs = []*Wire{&c[3], &c[8]} 498 c[1].Inputs = []*Wire{&c[6]} 499 c[2].Inputs = []*Wire{&c[4]} 500 c[3].Inputs = []*Wire{} 501 c[4].Inputs = []*Wire{} 502 c[5].Inputs = []*Wire{&c[9]} 503 c[6].Inputs = []*Wire{&c[9]} 504 c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} 505 c[8].Inputs = []*Wire{&c[4], &c[3]} 506 c[9].Inputs = []*Wire{} 507 508 sorted := {{$topologicalSort}}(c) 509 sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} 510 511 assert.Equal(t, sortedExpected, sorted) 512 } 513 514 {{template "gkrTestVectors" .}}