github.com/consensys/gnark-crypto@v0.14.0/internal/generator/sumcheck/template/sumcheck.test.go.tmpl (about) 1 import ( 2 "fmt" 3 "{{.FieldPackagePath}}" 4 "{{.FieldPackagePath}}/polynomial" 5 "{{.FieldPackagePath}}/test_vector_utils" 6 fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" 7 "github.com/stretchr/testify/assert" 8 "hash" 9 "math/bits" 10 "strings" 11 "testing" 12 ) 13 14 type singleMultilinClaim struct { 15 g polynomial.MultiLin 16 } 17 18 func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} { 19 return nil // verifier can compute the final eval itself 20 } 21 22 func (c singleMultilinClaim) VarsNum() int { 23 return bits.TrailingZeros(uint(len(c.g))) 24 } 25 26 func (c singleMultilinClaim) ClaimsNum() int { 27 return 1 28 } 29 30 func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { 31 sum := g[len(g)/2] 32 for i := len(g)/2 + 1; i < len(g); i++ { 33 sum.Add(&sum, &g[i]) 34 } 35 return []{{.ElementType}}{sum} 36 } 37 38 func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial { 39 return sumForX1One(c.g) 40 } 41 42 func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial { 43 c.g.Fold(r) 44 return sumForX1One(c.g) 45 } 46 47 type singleMultilinLazyClaim struct { 48 g polynomial.MultiLin 49 claimedSum {{.ElementType}} 50 } 51 52 func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { 53 val := c.g.Evaluate(r, nil) 54 if val.Equal(&purportedValue) { 55 return nil 56 } 57 return fmt.Errorf("mismatch") 58 } 59 60 func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} { 61 return c.claimedSum 62 } 63 64 func (c singleMultilinLazyClaim) Degree(i int) int { 65 return 1 66 } 67 68 func (c singleMultilinLazyClaim) ClaimsNum() int { 69 return 1 70 } 71 72 func (c singleMultilinLazyClaim) VarsNum() int { 73 return bits.TrailingZeros(uint(len(c.g))) 74 } 75 76 func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { 77 poly := make(polynomial.MultiLin, len(polyInt)) 78 for i, n := range polyInt { 79 poly[i].SetUint64(n) 80 } 81 82 claim := singleMultilinClaim{g: poly.Clone()} 83 84 proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) 85 if err != nil { 86 return err 87 } 88 89 var sb strings.Builder 90 for _, p := range proof.PartialSumPolys { 91 92 sb.WriteString("\t{") 93 for i := 0; i < len(p); i++ { 94 sb.WriteString(p[i].String()) 95 if i+1 < len(p) { 96 sb.WriteString(", ") 97 } 98 } 99 sb.WriteString("}\n") 100 } 101 102 lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} 103 if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { 104 return err 105 } 106 107 proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) 108 lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} 109 if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { 110 return fmt.Errorf("bad proof accepted") 111 } 112 return nil 113 } 114 115 func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { 116 //printMsws(36) 117 118 polys := [][]uint64{ 119 {1, 2, 3, 4}, // 1 + 2X₁ + X₂ 120 {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ 121 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ 122 } 123 124 const MaxStep = 4 125 const MaxStart = 4 126 hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) 127 128 for step := 0; step < MaxStep; step++ { 129 for startState := 0; startState < MaxStart; startState++ { 130 if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted 131 continue 132 } 133 hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) 134 } 135 } 136 137 for _, poly := range polys { 138 for _, hashGen := range hashGens { 139 assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), 140 "failed with poly %v and hashGen %v", poly, hashGen()) 141 } 142 } 143 }