github.com/consensys/gnark-crypto@v0.14.0/internal/generator/sumcheck/template/sumcheck.test.go.tmpl (about)

     1  import (
     2  	"fmt"
     3  	"{{.FieldPackagePath}}"
     4  	"{{.FieldPackagePath}}/polynomial"
     5  	"{{.FieldPackagePath}}/test_vector_utils"
     6  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
     7  	"github.com/stretchr/testify/assert"
     8  	"hash"
     9  	"math/bits"
    10  	"strings"
    11  	"testing"
    12  )
    13  
    14  type singleMultilinClaim struct {
    15  	g polynomial.MultiLin
    16  }
    17  
    18  func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} {
    19  	return nil // verifier can compute the final eval itself
    20  }
    21  
    22  func (c singleMultilinClaim) VarsNum() int {
    23  	return bits.TrailingZeros(uint(len(c.g)))
    24  }
    25  
    26  func (c singleMultilinClaim) ClaimsNum() int {
    27  	return 1
    28  }
    29  
    30  func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial {
    31  	sum := g[len(g)/2]
    32  	for i := len(g)/2 + 1; i < len(g); i++ {
    33  		sum.Add(&sum, &g[i])
    34  	}
    35  	return []{{.ElementType}}{sum}
    36  }
    37  
    38  func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial {
    39  	return sumForX1One(c.g)
    40  }
    41  
    42  func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial {
    43  	c.g.Fold(r)
    44  	return sumForX1One(c.g)
    45  }
    46  
    47  type singleMultilinLazyClaim struct {
    48  	g          polynomial.MultiLin
    49  	claimedSum {{.ElementType}}
    50  }
    51  
    52  func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error {
    53  	val := c.g.Evaluate(r, nil)
    54  	if val.Equal(&purportedValue) {
    55  		return nil
    56  	}
    57  	return fmt.Errorf("mismatch")
    58  }
    59  
    60  func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} {
    61  	return c.claimedSum
    62  }
    63  
    64  func (c singleMultilinLazyClaim) Degree(i int) int {
    65  	return 1
    66  }
    67  
    68  func (c singleMultilinLazyClaim) ClaimsNum() int {
    69  	return 1
    70  }
    71  
    72  func (c singleMultilinLazyClaim) VarsNum() int {
    73  	return bits.TrailingZeros(uint(len(c.g)))
    74  }
    75  
    76  func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error {
    77  	poly := make(polynomial.MultiLin, len(polyInt))
    78  	for i, n := range polyInt {
    79  		poly[i].SetUint64(n)
    80  	}
    81  
    82  	claim := singleMultilinClaim{g: poly.Clone()}
    83  
    84  	proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator()))
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	var sb strings.Builder
    90  	for _, p := range proof.PartialSumPolys {
    91  
    92  		sb.WriteString("\t{")
    93  		for i := 0; i < len(p); i++ {
    94  			sb.WriteString(p[i].String())
    95  			if i+1 < len(p) {
    96  				sb.WriteString(", ")
    97  			}
    98  		}
    99  		sb.WriteString("}\n")
   100  	}
   101  
   102  	lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()}
   103  	if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil {
   104  		return err
   105  	}
   106  
   107  	proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1))
   108  	lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()}
   109  	if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil {
   110  		return fmt.Errorf("bad proof accepted")
   111  	}
   112  	return nil
   113  }
   114  
   115  func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) {
   116  	//printMsws(36)
   117  
   118  	polys := [][]uint64{
   119  		{1, 2, 3, 4},             // 1 + 2X₁ + X₂
   120  		{1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃
   121  		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄
   122  	}
   123  
   124  	const MaxStep = 4
   125  	const MaxStart = 4
   126  	hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep)
   127  
   128  	for step := 0; step < MaxStep; step++ {
   129  		for startState := 0; startState < MaxStart; startState++ {
   130  			if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted
   131  				continue
   132  			}
   133  			hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step))
   134  		}
   135  	}
   136  
   137  	for _, poly := range polys {
   138  		for _, hashGen := range hashGens {
   139  			assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen),
   140  				"failed with poly %v and hashGen %v", poly, hashGen())
   141  		}
   142  	}
   143  }