github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl (about)

     1  import (
     2  	"fmt"
     3  	"runtime"
     4  	"math/big"
     5  	"time"
     6  
     7  	{{- template "import_fr" . }}
     8  	{{- template "import_curve" . }}
     9  	{{- template "import_backend_cs" . }}
    10  	{{- template "import_fft" . }}
    11  	{{- template "import_hash_to_field" . }}
    12  	"github.com/consensys/gnark/constraint"
    13  	"github.com/consensys/gnark-crypto/ecc"
    14  	"github.com/consensys/gnark/internal/utils"
    15  	"github.com/consensys/gnark/backend"
    16  	"github.com/consensys/gnark/backend/groth16/internal"
    17  	"github.com/consensys/gnark/constraint/solver"
    18  	"github.com/consensys/gnark/backend/witness"
    19  	"github.com/consensys/gnark/logger"
    20  
    21  	fcs "github.com/consensys/gnark/frontend/cs"
    22  )
    23  
    24  
    25  // Proof represents a Groth16 proof that was encoded with a ProvingKey and can be verified
    26  // with a valid statement and a VerifyingKey
    27  // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf
    28  type Proof struct {
    29  	Ar, Krs                   curve.G1Affine
    30  	Bs                        curve.G2Affine
    31  	Commitments []curve.G1Affine	// Pedersen commitments a la https://eprint.iacr.org/2022/1072
    32  	CommitmentPok curve.G1Affine	// Batched proof of knowledge of the above commitments
    33  }
    34  
    35  // isValid ensures proof elements are in the correct subgroup
    36  func (proof *Proof) isValid() bool {
    37  	return proof.Ar.IsInSubGroup() && proof.Krs.IsInSubGroup() && proof.Bs.IsInSubGroup()
    38  }
    39  
    40  // CurveID returns the curveID
    41  func (proof *Proof) CurveID() ecc.ID {
    42  	return curve.ID
    43  }
    44  
    45  // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part).
    46  func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) {
    47  	opt, err := backend.NewProverConfig(opts...)
    48  	if err != nil {
    49  		return nil, fmt.Errorf("new prover config: %w", err)
    50  	}
    51  	if opt.HashToFieldFn == nil {
    52  		opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst))
    53  	}
    54  
    55  	log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "none").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger()
    56  
    57  	commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments)
    58  
    59  	proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))}
    60  
    61  	solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)]
    62  
    63  	privateCommittedValues := make([][]fr.Element, len(commitmentInfo))
    64  
    65  	// override hints
    66  	bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder)
    67  	solverOpts = append(solverOpts,	solver.OverrideHint(bsb22ID,  func(_ *big.Int, in []*big.Int, out []*big.Int) error {
    68  			i := int(in[0].Int64()) 
    69  			in = in[1:]
    70  			privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted))
    71  			hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)]
    72  			committed := in[+len(hashed):]
    73  			for j, inJ := range committed {
    74  				privateCommittedValues[i][j].SetBigInt(inJ)
    75  			}
    76  
    77  			var err error
    78  			if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil {
    79  				return err
    80  			}
    81  
    82  			opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1))
    83  			hashBts := opt.HashToFieldFn.Sum(nil)
    84  			opt.HashToFieldFn.Reset()
    85  			nbBuf := fr.Bytes
    86  			if opt.HashToFieldFn.Size() < fr.Bytes {
    87  				nbBuf = opt.HashToFieldFn.Size()
    88  			}
    89  			var res fr.Element
    90  			res.SetBytes(hashBts[:nbBuf])
    91  			res.BigInt(out[0])
    92  			return nil
    93  	}))
    94  
    95  	_solution, err := r1cs.Solve(fullWitness, solverOpts...)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	solution := _solution.(*cs.R1CSSolution)
   101  	wireValues := []fr.Element(solution.W)
   102  
   103  	start := time.Now()
   104  	poks := make([]curve.G1Affine, len(pk.CommitmentKeys))
   105  
   106  	for i := range pk.CommitmentKeys {
   107  		var err error
   108  		if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil {
   109  			return nil, err
   110  		}
   111  	}
   112  	// compute challenge for folding the PoKs from the commitments
   113  	commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo))
   114  	for i := range commitmentInfo {
   115  		copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal())
   116  	}
   117  	challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	// H (witness reduction / FFT part)
   126  	var h []fr.Element
   127  	chHDone := make(chan struct{}, 1)
   128  	go func() {
   129  		h = computeH(solution.A, solution.B, solution.C, &pk.Domain)
   130  		solution.A = nil
   131  		solution.B = nil
   132  		solution.C = nil
   133  		chHDone <- struct{}{}
   134  	}()
   135  
   136  	// we need to copy and filter the wireValues for each multi exp
   137  	// as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity
   138  	var wireValuesA, wireValuesB []fr.Element
   139  	chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1)
   140  
   141  	go func() {
   142  		wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA))
   143  		for i, j := 0, 0; j < len(wireValuesA); i++ {
   144  			if pk.InfinityA[i] {
   145  				continue
   146  			}
   147  			wireValuesA[j] = wireValues[i]
   148  			j++
   149  		}
   150  		close(chWireValuesA)
   151  	}()
   152  	go func() {
   153  		wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB))
   154  		for i, j := 0, 0; j < len(wireValuesB); i++ {
   155  			if pk.InfinityB[i] {
   156  				continue
   157  			}
   158  			wireValuesB[j] = wireValues[i]
   159  			j++
   160  		}
   161  		close(chWireValuesB)
   162  	}()
   163  
   164  	// sample random r and s
   165  	var r, s big.Int
   166  	var _r, _s, _kr fr.Element
   167  	if _, err := _r.SetRandom(); err != nil {
   168  		return nil, err
   169  	}
   170  	if _, err := _s.SetRandom(); err != nil {
   171  		return nil, err
   172  	}
   173  	_kr.Mul(&_r, &_s).Neg(&_kr)
   174  
   175  	_r.BigInt(&r)
   176  	_s.BigInt(&s)
   177  
   178  	// computes r[δ], s[δ], kr[δ]
   179  	deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr})
   180  
   181  	var bs1, ar curve.G1Jac
   182  
   183  	n := runtime.NumCPU()
   184  
   185  	chBs1Done := make(chan error, 1)
   186  	computeBS1 := func() {
   187  		<-chWireValuesB
   188  		if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   189  			chBs1Done <- err
   190  			close(chBs1Done)
   191  			return
   192  		}
   193  		bs1.AddMixed(&pk.G1.Beta)
   194  		bs1.AddMixed(&deltas[1])
   195  		chBs1Done <- nil
   196  	}
   197  
   198  	chArDone := make(chan error, 1)
   199  	computeAR1 := func() {
   200  		<-chWireValuesA
   201  		if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   202  			chArDone <- err
   203  			close(chArDone)
   204  			return
   205  		}
   206  		ar.AddMixed(&pk.G1.Alpha)
   207  		ar.AddMixed(&deltas[0])
   208  		proof.Ar.FromJacobian(&ar)
   209  		chArDone <- nil
   210  	}
   211  
   212  	chKrsDone := make(chan error, 1)
   213  	computeKRS := func() {
   214  		// we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z
   215  		// however, having similar lengths for our tasks helps with parallelism
   216  
   217  		var krs, krs2, p1 curve.G1Jac
   218  		chKrs2Done := make(chan error, 1)
   219  		sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2
   220  		go func() {
   221  			_, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2})
   222  			chKrs2Done <- err
   223  		}()
   224  
   225  		// filter the wire values if needed
   226  		// TODO Perf @Tabaie worst memory allocation offender
   227  		toRemove := commitmentInfo.GetPrivateCommitted()
   228  		toRemove = append(toRemove, commitmentInfo.CommitmentIndexes())
   229  		_wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...))
   230  
   231  		if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   232  			chKrsDone <- err
   233  			return
   234  		}
   235  		krs.AddMixed(&deltas[2])
   236  		n := 3
   237  		for n != 0 {
   238  			select {
   239  			case err := <-chKrs2Done:
   240  				if err != nil {
   241  					chKrsDone <- err
   242  					return
   243  				}
   244  				krs.AddAssign(&krs2)
   245  			case err := <-chArDone:
   246  				if err != nil {
   247  					chKrsDone <- err
   248  					return
   249  				}
   250  				p1.ScalarMultiplication(&ar, &s)
   251  				krs.AddAssign(&p1)
   252  			case err := <-chBs1Done:
   253  				if err != nil {
   254  					chKrsDone <- err
   255  					return
   256  				}
   257  				p1.ScalarMultiplication(&bs1, &r)
   258  				krs.AddAssign(&p1)
   259  			}
   260  			n--
   261  		}
   262  
   263  		proof.Krs.FromJacobian(&krs)
   264  		chKrsDone <- nil
   265  	}
   266  
   267  	computeBS2 := func() error {
   268  		// Bs2 (1 multi exp G2 - size = len(wires))
   269  		var Bs, deltaS curve.G2Jac
   270  
   271  		nbTasks := n
   272  		if nbTasks <= 16 {
   273  			// if we don't have a lot of CPUs, this may artificially split the MSM
   274  			nbTasks *= 2
   275  		}
   276  		<-chWireValuesB
   277  		if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil {
   278  			return err
   279  		}
   280  
   281  		deltaS.FromAffine(&pk.G2.Delta)
   282  		deltaS.ScalarMultiplication(&deltaS, &s)
   283  		Bs.AddAssign(&deltaS)
   284  		Bs.AddMixed(&pk.G2.Beta)
   285  
   286  		proof.Bs.FromJacobian(&Bs)
   287  		return nil
   288  	}
   289  
   290  	// wait for FFT to end, as it uses all our CPUs
   291  	<-chHDone
   292  
   293  	// schedule our proof part computations
   294  	go computeKRS()
   295  	go computeAR1()
   296  	go computeBS1()
   297  	if err := computeBS2(); err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	// wait for all parts of the proof to be computed.
   302  	if err := <-chKrsDone; err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	log.Debug().Dur("took", time.Since(start)).Msg("prover done")
   307  
   308  	return proof, nil
   309  }
   310  
   311  // if len(toRemove) == 0, returns slice
   312  // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex
   313  // this assumes len(slice) > len(toRemove)
   314  // filterHeap modifies toRemove
   315  func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) {
   316  
   317  	if len(toRemove) == 0 {
   318  		return slice
   319  	}
   320  	
   321  	heap := utils.IntHeap(toRemove)
   322  	heap.Heapify()
   323  
   324  	r = make([]fr.Element, 0, len(slice))
   325  
   326  	// note: we can optimize that for the likely case where len(slice) >>> len(toRemove)
   327  	for i:=0; i < len(slice);i++ {
   328  		if len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   329  			for len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   330  				heap.Pop()
   331  			}
   332  			continue 
   333  		}
   334  		r = append(r, slice[i])
   335  	}
   336  
   337  	return
   338  }
   339  
   340  func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
   341  	// H part of Krs
   342  	// Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
   343  	// 	1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
   344  	// 	2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c)
   345  	// 	3 - h = ifft_coset(ca o cb - cc)
   346  
   347  	n := len(a)
   348  
   349  	// add padding to ensure input length is domain cardinality
   350  	padding := make([]fr.Element, int(domain.Cardinality)-n)
   351  	a = append(a, padding...)
   352  	b = append(b, padding...)
   353  	c = append(c, padding...)
   354  	n = len(a)
   355  
   356  	domain.FFTInverse(a, fft.DIF)
   357  	domain.FFTInverse(b, fft.DIF)
   358  	domain.FFTInverse(c, fft.DIF)
   359  
   360  	domain.FFT(a, fft.DIT, fft.OnCoset())
   361  	domain.FFT(b, fft.DIT, fft.OnCoset())
   362  	domain.FFT(c, fft.DIT, fft.OnCoset())
   363  
   364  	var den, one fr.Element
   365  	one.SetOne()
   366  	den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
   367  	den.Sub(&den, &one).Inverse(&den)
   368  
   369  	// h = ifft_coset(ca o cb - cc)
   370  	// reusing a to avoid unnecessary memory allocation
   371  	utils.Parallelize(n, func(start, end int) {
   372  		for i := start; i < end; i++ {
   373  			a[i].Mul(&a[i], &b[i]).
   374  				Sub(&a[i], &c[i]).
   375  				Mul(&a[i], &den)
   376  		}
   377  	})
   378  
   379  	// ifft_coset
   380  	domain.FFTInverse(a, fft.DIF, fft.OnCoset())
   381  
   382  	return a
   383  }