github.com/consensys/gnark@v0.11.0/backend/groth16/bn254/prove.go (about)

     1  // Copyright 2020 ConsenSys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by gnark DO NOT EDIT
    16  
    17  package groth16
    18  
import (
	"fmt"
	"math/big"
	"runtime"
	"sort"
	"time"

	"github.com/consensys/gnark-crypto/ecc"
	curve "github.com/consensys/gnark-crypto/ecc/bn254"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field"
	"github.com/consensys/gnark/backend"
	"github.com/consensys/gnark/backend/groth16/internal"
	"github.com/consensys/gnark/backend/witness"
	"github.com/consensys/gnark/constraint"
	cs "github.com/consensys/gnark/constraint/bn254"
	"github.com/consensys/gnark/constraint/solver"
	fcs "github.com/consensys/gnark/frontend/cs"
	"github.com/consensys/gnark/internal/utils"
	"github.com/consensys/gnark/logger"
)
    40  
    41  // Proof represents a Groth16 proof that was encoded with a ProvingKey and can be verified
    42  // with a valid statement and a VerifyingKey
    43  // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf
    44  type Proof struct {
    45  	Ar, Krs       curve.G1Affine
    46  	Bs            curve.G2Affine
    47  	Commitments   []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072
    48  	CommitmentPok curve.G1Affine   // Batched proof of knowledge of the above commitments
    49  }
    50  
    51  // isValid ensures proof elements are in the correct subgroup
    52  func (proof *Proof) isValid() bool {
    53  	return proof.Ar.IsInSubGroup() && proof.Krs.IsInSubGroup() && proof.Bs.IsInSubGroup()
    54  }
    55  
    56  // CurveID returns the curveID
    57  func (proof *Proof) CurveID() ecc.ID {
    58  	return curve.ID
    59  }
    60  
    61  // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part).
    62  func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) {
    63  	opt, err := backend.NewProverConfig(opts...)
    64  	if err != nil {
    65  		return nil, fmt.Errorf("new prover config: %w", err)
    66  	}
    67  	if opt.HashToFieldFn == nil {
    68  		opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst))
    69  	}
    70  
    71  	log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "none").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger()
    72  
    73  	commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments)
    74  
    75  	proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))}
    76  
    77  	solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)]
    78  
    79  	privateCommittedValues := make([][]fr.Element, len(commitmentInfo))
    80  
    81  	// override hints
    82  	bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder)
    83  	solverOpts = append(solverOpts, solver.OverrideHint(bsb22ID, func(_ *big.Int, in []*big.Int, out []*big.Int) error {
    84  		i := int(in[0].Int64())
    85  		in = in[1:]
    86  		privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted))
    87  		hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)]
    88  		committed := in[+len(hashed):]
    89  		for j, inJ := range committed {
    90  			privateCommittedValues[i][j].SetBigInt(inJ)
    91  		}
    92  
    93  		var err error
    94  		if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil {
    95  			return err
    96  		}
    97  
    98  		opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1))
    99  		hashBts := opt.HashToFieldFn.Sum(nil)
   100  		opt.HashToFieldFn.Reset()
   101  		nbBuf := fr.Bytes
   102  		if opt.HashToFieldFn.Size() < fr.Bytes {
   103  			nbBuf = opt.HashToFieldFn.Size()
   104  		}
   105  		var res fr.Element
   106  		res.SetBytes(hashBts[:nbBuf])
   107  		res.BigInt(out[0])
   108  		return nil
   109  	}))
   110  
   111  	_solution, err := r1cs.Solve(fullWitness, solverOpts...)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	solution := _solution.(*cs.R1CSSolution)
   117  	wireValues := []fr.Element(solution.W)
   118  
   119  	start := time.Now()
   120  	poks := make([]curve.G1Affine, len(pk.CommitmentKeys))
   121  
   122  	for i := range pk.CommitmentKeys {
   123  		var err error
   124  		if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil {
   125  			return nil, err
   126  		}
   127  	}
   128  	// compute challenge for folding the PoKs from the commitments
   129  	commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo))
   130  	for i := range commitmentInfo {
   131  		copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal())
   132  	}
   133  	challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	// H (witness reduction / FFT part)
   142  	var h []fr.Element
   143  	chHDone := make(chan struct{}, 1)
   144  	go func() {
   145  		h = computeH(solution.A, solution.B, solution.C, &pk.Domain)
   146  		solution.A = nil
   147  		solution.B = nil
   148  		solution.C = nil
   149  		chHDone <- struct{}{}
   150  	}()
   151  
   152  	// we need to copy and filter the wireValues for each multi exp
   153  	// as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity
   154  	var wireValuesA, wireValuesB []fr.Element
   155  	chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1)
   156  
   157  	go func() {
   158  		wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA))
   159  		for i, j := 0, 0; j < len(wireValuesA); i++ {
   160  			if pk.InfinityA[i] {
   161  				continue
   162  			}
   163  			wireValuesA[j] = wireValues[i]
   164  			j++
   165  		}
   166  		close(chWireValuesA)
   167  	}()
   168  	go func() {
   169  		wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB))
   170  		for i, j := 0, 0; j < len(wireValuesB); i++ {
   171  			if pk.InfinityB[i] {
   172  				continue
   173  			}
   174  			wireValuesB[j] = wireValues[i]
   175  			j++
   176  		}
   177  		close(chWireValuesB)
   178  	}()
   179  
   180  	// sample random r and s
   181  	var r, s big.Int
   182  	var _r, _s, _kr fr.Element
   183  	if _, err := _r.SetRandom(); err != nil {
   184  		return nil, err
   185  	}
   186  	if _, err := _s.SetRandom(); err != nil {
   187  		return nil, err
   188  	}
   189  	_kr.Mul(&_r, &_s).Neg(&_kr)
   190  
   191  	_r.BigInt(&r)
   192  	_s.BigInt(&s)
   193  
   194  	// computes r[δ], s[δ], kr[δ]
   195  	deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr})
   196  
   197  	var bs1, ar curve.G1Jac
   198  
   199  	n := runtime.NumCPU()
   200  
   201  	chBs1Done := make(chan error, 1)
   202  	computeBS1 := func() {
   203  		<-chWireValuesB
   204  		if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   205  			chBs1Done <- err
   206  			close(chBs1Done)
   207  			return
   208  		}
   209  		bs1.AddMixed(&pk.G1.Beta)
   210  		bs1.AddMixed(&deltas[1])
   211  		chBs1Done <- nil
   212  	}
   213  
   214  	chArDone := make(chan error, 1)
   215  	computeAR1 := func() {
   216  		<-chWireValuesA
   217  		if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   218  			chArDone <- err
   219  			close(chArDone)
   220  			return
   221  		}
   222  		ar.AddMixed(&pk.G1.Alpha)
   223  		ar.AddMixed(&deltas[0])
   224  		proof.Ar.FromJacobian(&ar)
   225  		chArDone <- nil
   226  	}
   227  
   228  	chKrsDone := make(chan error, 1)
   229  	computeKRS := func() {
   230  		// we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z
   231  		// however, having similar lengths for our tasks helps with parallelism
   232  
   233  		var krs, krs2, p1 curve.G1Jac
   234  		chKrs2Done := make(chan error, 1)
   235  		sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2
   236  		go func() {
   237  			_, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2})
   238  			chKrs2Done <- err
   239  		}()
   240  
   241  		// filter the wire values if needed
   242  		// TODO Perf @Tabaie worst memory allocation offender
   243  		toRemove := commitmentInfo.GetPrivateCommitted()
   244  		toRemove = append(toRemove, commitmentInfo.CommitmentIndexes())
   245  		_wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...))
   246  
   247  		if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
   248  			chKrsDone <- err
   249  			return
   250  		}
   251  		krs.AddMixed(&deltas[2])
   252  		n := 3
   253  		for n != 0 {
   254  			select {
   255  			case err := <-chKrs2Done:
   256  				if err != nil {
   257  					chKrsDone <- err
   258  					return
   259  				}
   260  				krs.AddAssign(&krs2)
   261  			case err := <-chArDone:
   262  				if err != nil {
   263  					chKrsDone <- err
   264  					return
   265  				}
   266  				p1.ScalarMultiplication(&ar, &s)
   267  				krs.AddAssign(&p1)
   268  			case err := <-chBs1Done:
   269  				if err != nil {
   270  					chKrsDone <- err
   271  					return
   272  				}
   273  				p1.ScalarMultiplication(&bs1, &r)
   274  				krs.AddAssign(&p1)
   275  			}
   276  			n--
   277  		}
   278  
   279  		proof.Krs.FromJacobian(&krs)
   280  		chKrsDone <- nil
   281  	}
   282  
   283  	computeBS2 := func() error {
   284  		// Bs2 (1 multi exp G2 - size = len(wires))
   285  		var Bs, deltaS curve.G2Jac
   286  
   287  		nbTasks := n
   288  		if nbTasks <= 16 {
   289  			// if we don't have a lot of CPUs, this may artificially split the MSM
   290  			nbTasks *= 2
   291  		}
   292  		<-chWireValuesB
   293  		if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil {
   294  			return err
   295  		}
   296  
   297  		deltaS.FromAffine(&pk.G2.Delta)
   298  		deltaS.ScalarMultiplication(&deltaS, &s)
   299  		Bs.AddAssign(&deltaS)
   300  		Bs.AddMixed(&pk.G2.Beta)
   301  
   302  		proof.Bs.FromJacobian(&Bs)
   303  		return nil
   304  	}
   305  
   306  	// wait for FFT to end, as it uses all our CPUs
   307  	<-chHDone
   308  
   309  	// schedule our proof part computations
   310  	go computeKRS()
   311  	go computeAR1()
   312  	go computeBS1()
   313  	if err := computeBS2(); err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	// wait for all parts of the proof to be computed.
   318  	if err := <-chKrsDone; err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	log.Debug().Dur("took", time.Since(start)).Msg("prover done")
   323  
   324  	return proof, nil
   325  }
   326  
   327  // if len(toRemove) == 0, returns slice
   328  // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex
   329  // this assumes len(slice) > len(toRemove)
   330  // filterHeap modifies toRemove
   331  func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) {
   332  
   333  	if len(toRemove) == 0 {
   334  		return slice
   335  	}
   336  
   337  	heap := utils.IntHeap(toRemove)
   338  	heap.Heapify()
   339  
   340  	r = make([]fr.Element, 0, len(slice))
   341  
   342  	// note: we can optimize that for the likely case where len(slice) >>> len(toRemove)
   343  	for i := 0; i < len(slice); i++ {
   344  		if len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   345  			for len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   346  				heap.Pop()
   347  			}
   348  			continue
   349  		}
   350  		r = append(r, slice[i])
   351  	}
   352  
   353  	return
   354  }
   355  
   356  func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
   357  	// H part of Krs
   358  	// Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
   359  	// 	1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
   360  	// 	2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c)
   361  	// 	3 - h = ifft_coset(ca o cb - cc)
   362  
   363  	n := len(a)
   364  
   365  	// add padding to ensure input length is domain cardinality
   366  	padding := make([]fr.Element, int(domain.Cardinality)-n)
   367  	a = append(a, padding...)
   368  	b = append(b, padding...)
   369  	c = append(c, padding...)
   370  	n = len(a)
   371  
   372  	domain.FFTInverse(a, fft.DIF)
   373  	domain.FFTInverse(b, fft.DIF)
   374  	domain.FFTInverse(c, fft.DIF)
   375  
   376  	domain.FFT(a, fft.DIT, fft.OnCoset())
   377  	domain.FFT(b, fft.DIT, fft.OnCoset())
   378  	domain.FFT(c, fft.DIT, fft.OnCoset())
   379  
   380  	var den, one fr.Element
   381  	one.SetOne()
   382  	den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
   383  	den.Sub(&den, &one).Inverse(&den)
   384  
   385  	// h = ifft_coset(ca o cb - cc)
   386  	// reusing a to avoid unnecessary memory allocation
   387  	utils.Parallelize(n, func(start, end int) {
   388  		for i := start; i < end; i++ {
   389  			a[i].Mul(&a[i], &b[i]).
   390  				Sub(&a[i], &c[i]).
   391  				Mul(&a[i], &den)
   392  		}
   393  	})
   394  
   395  	// ifft_coset
   396  	domain.FFTInverse(a, fft.DIF, fft.OnCoset())
   397  
   398  	return a
   399  }