github.com/consensys/gnark@v0.11.0/backend/groth16/bn254/icicle/icicle.go (about)

     1  //go:build icicle
     2  
     3  package icicle_bn254
     4  
     5  import (
     6  	"fmt"
     7  	"math/big"
     8  	"math/bits"
     9  	"time"
    10  	"unsafe"
    11  
    12  	curve "github.com/consensys/gnark-crypto/ecc/bn254"
    13  	"github.com/consensys/gnark-crypto/ecc/bn254/fp"
    14  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    15  	"github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field"
    16  	"github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen"
    17  	"github.com/consensys/gnark/backend"
    18  	groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254"
    19  	"github.com/consensys/gnark/backend/groth16/internal"
    20  	"github.com/consensys/gnark/backend/witness"
    21  	"github.com/consensys/gnark/constraint"
    22  	cs "github.com/consensys/gnark/constraint/bn254"
    23  	"github.com/consensys/gnark/constraint/solver"
    24  	"github.com/consensys/gnark/internal/utils"
    25  	"github.com/consensys/gnark/logger"
    26  	iciclegnark "github.com/ingonyama-zk/iciclegnark/curves/bn254"
    27  )
    28  
// HasIcicle is true in builds made with the `icicle` build tag (the tag that
// guards this file), signalling that the ICICLE GPU-accelerated prover is
// compiled in.
const HasIcicle = true
    30  
    31  func (pk *ProvingKey) setupDevicePointers() error {
    32  	if pk.deviceInfo != nil {
    33  		return nil
    34  	}
    35  	pk.deviceInfo = &deviceInfo{}
    36  	n := int(pk.Domain.Cardinality)
    37  	sizeBytes := n * fr.Bytes
    38  
    39  	/*************************  Start Domain Device Setup  ***************************/
    40  	copyCosetInvDone := make(chan unsafe.Pointer, 1)
    41  	copyCosetDone := make(chan unsafe.Pointer, 1)
    42  	copyDenDone := make(chan unsafe.Pointer, 1)
    43  	/*************************     CosetTableInv      ***************************/
    44  	go iciclegnark.CopyToDevice(pk.Domain.CosetTableInv, sizeBytes, copyCosetInvDone)
    45  
    46  	/*************************     CosetTable      ***************************/
    47  	go iciclegnark.CopyToDevice(pk.Domain.CosetTable, sizeBytes, copyCosetDone)
    48  
    49  	/*************************     Den      ***************************/
    50  	var denI, oneI fr.Element
    51  	oneI.SetOne()
    52  	denI.Exp(pk.Domain.FrMultiplicativeGen, big.NewInt(int64(pk.Domain.Cardinality)))
    53  	denI.Sub(&denI, &oneI).Inverse(&denI)
    54  
    55  	log2SizeFloor := bits.Len(uint(n)) - 1
    56  	denIcicleArr := []fr.Element{denI}
    57  	for i := 0; i < log2SizeFloor; i++ {
    58  		denIcicleArr = append(denIcicleArr, denIcicleArr...)
    59  	}
    60  	pow2Remainder := n - 1<<log2SizeFloor
    61  	for i := 0; i < pow2Remainder; i++ {
    62  		denIcicleArr = append(denIcicleArr, denI)
    63  	}
    64  
    65  	go iciclegnark.CopyToDevice(denIcicleArr, sizeBytes, copyDenDone)
    66  
    67  	/*************************     Twiddles and Twiddles Inv    ***************************/
    68  	twiddlesInv_d_gen, twddles_err := iciclegnark.GenerateTwiddleFactors(n, true)
    69  	if twddles_err != nil {
    70  		return twddles_err
    71  	}
    72  
    73  	twiddles_d_gen, twddles_err := iciclegnark.GenerateTwiddleFactors(n, false)
    74  	if twddles_err != nil {
    75  		return twddles_err
    76  	}
    77  
    78  	/*************************  End Domain Device Setup  ***************************/
    79  	pk.DomainDevice.Twiddles = twiddles_d_gen
    80  	pk.DomainDevice.TwiddlesInv = twiddlesInv_d_gen
    81  
    82  	pk.DomainDevice.CosetTableInv = <-copyCosetInvDone
    83  	pk.DomainDevice.CosetTable = <-copyCosetDone
    84  	pk.DenDevice = <-copyDenDone
    85  
    86  	/*************************  Start G1 Device Setup  ***************************/
    87  	/*************************     A      ***************************/
    88  	pointsBytesA := len(pk.G1.A) * fp.Bytes * 2
    89  	copyADone := make(chan unsafe.Pointer, 1)
    90  	go iciclegnark.CopyPointsToDevice(pk.G1.A, pointsBytesA, copyADone) // Make a function for points
    91  
    92  	/*************************     B      ***************************/
    93  	pointsBytesB := len(pk.G1.B) * fp.Bytes * 2
    94  	copyBDone := make(chan unsafe.Pointer, 1)
    95  	go iciclegnark.CopyPointsToDevice(pk.G1.B, pointsBytesB, copyBDone) // Make a function for points
    96  
    97  	/*************************     K      ***************************/
    98  	var pointsNoInfinity []curve.G1Affine
    99  	for i, gnarkPoint := range pk.G1.K {
   100  		if gnarkPoint.IsInfinity() {
   101  			pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i)
   102  		} else {
   103  			pointsNoInfinity = append(pointsNoInfinity, gnarkPoint)
   104  		}
   105  	}
   106  
   107  	pointsBytesK := len(pointsNoInfinity) * fp.Bytes * 2
   108  	copyKDone := make(chan unsafe.Pointer, 1)
   109  	go iciclegnark.CopyPointsToDevice(pointsNoInfinity, pointsBytesK, copyKDone) // Make a function for points
   110  
   111  	/*************************     Z      ***************************/
   112  	pointsBytesZ := len(pk.G1.Z) * fp.Bytes * 2
   113  	copyZDone := make(chan unsafe.Pointer, 1)
   114  	go iciclegnark.CopyPointsToDevice(pk.G1.Z, pointsBytesZ, copyZDone) // Make a function for points
   115  
   116  	/*************************  End G1 Device Setup  ***************************/
   117  	pk.G1Device.A = <-copyADone
   118  	pk.G1Device.B = <-copyBDone
   119  	pk.G1Device.K = <-copyKDone
   120  	pk.G1Device.Z = <-copyZDone
   121  
   122  	/*************************  Start G2 Device Setup  ***************************/
   123  	pointsBytesB2 := len(pk.G2.B) * fp.Bytes * 4
   124  	copyG2BDone := make(chan unsafe.Pointer, 1)
   125  	go iciclegnark.CopyG2PointsToDevice(pk.G2.B, pointsBytesB2, copyG2BDone) // Make a function for points
   126  	pk.G2Device.B = <-copyG2BDone
   127  
   128  	/*************************  End G2 Device Setup  ***************************/
   129  	return nil
   130  }
   131  
   132  // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part).
   133  func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) {
   134  	opt, err := backend.NewProverConfig(opts...)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("new prover config: %w", err)
   137  	}
   138  	if opt.HashToFieldFn == nil {
   139  		opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst))
   140  	}
   141  	if opt.Accelerator != "icicle" {
   142  		return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...)
   143  	}
   144  	log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "icicle").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger()
   145  	if pk.deviceInfo == nil {
   146  		log.Debug().Msg("precomputing proving key in GPU")
   147  		if err := pk.setupDevicePointers(); err != nil {
   148  			return nil, fmt.Errorf("setup device pointers: %w", err)
   149  		}
   150  	}
   151  
   152  	commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments)
   153  
   154  	proof := &groth16_bn254.Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))}
   155  
   156  	solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)]
   157  
   158  	privateCommittedValues := make([][]fr.Element, len(commitmentInfo))
   159  	for i := range commitmentInfo {
   160  		solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint {
   161  			return func(_ *big.Int, in []*big.Int, out []*big.Int) error {
   162  				privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted))
   163  				hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)]
   164  				committed := in[len(hashed):]
   165  				for j, inJ := range committed {
   166  					privateCommittedValues[i][j].SetBigInt(inJ)
   167  				}
   168  
   169  				var err error
   170  				if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil {
   171  					return err
   172  				}
   173  
   174  				opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1))
   175  				hashBts := opt.HashToFieldFn.Sum(nil)
   176  				opt.HashToFieldFn.Reset()
   177  				nbBuf := fr.Bytes
   178  				if opt.HashToFieldFn.Size() < fr.Bytes {
   179  					nbBuf = opt.HashToFieldFn.Size()
   180  				}
   181  				var res fr.Element
   182  				res.SetBytes(hashBts[:nbBuf])
   183  				res.BigInt(out[0])
   184  				return err
   185  			}
   186  		}(i)))
   187  	}
   188  
   189  	if r1cs.GkrInfo.Is() {
   190  		var gkrData cs.GkrSolvingData
   191  		solverOpts = append(solverOpts,
   192  			solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)),
   193  			solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData)))
   194  	}
   195  
   196  	_solution, err := r1cs.Solve(fullWitness, solverOpts...)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	solution := _solution.(*cs.R1CSSolution)
   202  	wireValues := []fr.Element(solution.W)
   203  
   204  	start := time.Now()
   205  
   206  	commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo))
   207  	for i := range commitmentInfo {
   208  		copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal())
   209  	}
   210  
   211  	if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	// H (witness reduction / FFT part)
   216  	var h unsafe.Pointer
   217  	chHDone := make(chan struct{}, 1)
   218  	go func() {
   219  		h = computeH(solution.A, solution.B, solution.C, pk)
   220  		solution.A = nil
   221  		solution.B = nil
   222  		solution.C = nil
   223  		chHDone <- struct{}{}
   224  	}()
   225  
   226  	// we need to copy and filter the wireValues for each multi exp
   227  	// as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity
   228  	var wireValuesADevice, wireValuesBDevice iciclegnark.OnDeviceData
   229  	chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1)
   230  
   231  	go func() {
   232  		wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA))
   233  		for i, j := 0, 0; j < len(wireValuesA); i++ {
   234  			if pk.InfinityA[i] {
   235  				continue
   236  			}
   237  			wireValuesA[j] = wireValues[i]
   238  			j++
   239  		}
   240  		wireValuesASize := len(wireValuesA)
   241  		scalarBytes := wireValuesASize * fr.Bytes
   242  
   243  		// Copy scalars to the device and retain ptr to them
   244  		copyDone := make(chan unsafe.Pointer, 1)
   245  		iciclegnark.CopyToDevice(wireValuesA, scalarBytes, copyDone)
   246  		wireValuesADevicePtr := <-copyDone
   247  
   248  		wireValuesADevice = iciclegnark.OnDeviceData{
   249  			P:    wireValuesADevicePtr,
   250  			Size: wireValuesASize,
   251  		}
   252  
   253  		close(chWireValuesA)
   254  	}()
   255  	go func() {
   256  		wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB))
   257  		for i, j := 0, 0; j < len(wireValuesB); i++ {
   258  			if pk.InfinityB[i] {
   259  				continue
   260  			}
   261  			wireValuesB[j] = wireValues[i]
   262  			j++
   263  		}
   264  		wireValuesBSize := len(wireValuesB)
   265  		scalarBytes := wireValuesBSize * fr.Bytes
   266  
   267  		// Copy scalars to the device and retain ptr to them
   268  		copyDone := make(chan unsafe.Pointer, 1)
   269  		iciclegnark.CopyToDevice(wireValuesB, scalarBytes, copyDone)
   270  		wireValuesBDevicePtr := <-copyDone
   271  
   272  		wireValuesBDevice = iciclegnark.OnDeviceData{
   273  			P:    wireValuesBDevicePtr,
   274  			Size: wireValuesBSize,
   275  		}
   276  
   277  		close(chWireValuesB)
   278  	}()
   279  
   280  	// sample random r and s
   281  	var r, s big.Int
   282  	var _r, _s, _kr fr.Element
   283  	if _, err := _r.SetRandom(); err != nil {
   284  		return nil, err
   285  	}
   286  	if _, err := _s.SetRandom(); err != nil {
   287  		return nil, err
   288  	}
   289  	_kr.Mul(&_r, &_s).Neg(&_kr)
   290  
   291  	_r.BigInt(&r)
   292  	_s.BigInt(&s)
   293  
   294  	// computes r[δ], s[δ], kr[δ]
   295  	deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr})
   296  
   297  	var bs1, ar curve.G1Jac
   298  
   299  	computeBS1 := func() error {
   300  		<-chWireValuesB
   301  
   302  		if bs1, _, err = iciclegnark.MsmOnDevice(wireValuesBDevice.P, pk.G1Device.B, wireValuesBDevice.Size, true); err != nil {
   303  			return err
   304  		}
   305  
   306  		bs1.AddMixed(&pk.G1.Beta)
   307  		bs1.AddMixed(&deltas[1])
   308  
   309  		return nil
   310  	}
   311  
   312  	computeAR1 := func() error {
   313  		<-chWireValuesA
   314  
   315  		if ar, _, err = iciclegnark.MsmOnDevice(wireValuesADevice.P, pk.G1Device.A, wireValuesADevice.Size, true); err != nil {
   316  			return err
   317  		}
   318  
   319  		ar.AddMixed(&pk.G1.Alpha)
   320  		ar.AddMixed(&deltas[0])
   321  		proof.Ar.FromJacobian(&ar)
   322  
   323  		return nil
   324  	}
   325  
   326  	computeKRS := func() error {
   327  		var krs, krs2, p1 curve.G1Jac
   328  		sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2
   329  
   330  		// check for small circuits as iciclegnark doesn't handle zero sizes well
   331  		if len(pk.G1.Z) > 0 {
   332  			if krs2, _, err = iciclegnark.MsmOnDevice(h, pk.G1Device.Z, sizeH, true); err != nil {
   333  				return err
   334  			}
   335  		}
   336  
   337  		// filter the wire values if needed
   338  		// TODO Perf @Tabaie worst memory allocation offender
   339  		toRemove := commitmentInfo.GetPrivateCommitted()
   340  		toRemove = append(toRemove, commitmentInfo.CommitmentIndexes())
   341  		scalars := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...))
   342  
   343  		// filter zero/infinity points since icicle doesn't handle them
   344  		// See https://github.com/ingonyama-zk/icicle/issues/169 for more info
   345  		for _, indexToRemove := range pk.InfinityPointIndicesK {
   346  			scalars = append(scalars[:indexToRemove], scalars[indexToRemove+1:]...)
   347  		}
   348  
   349  		scalarBytes := len(scalars) * fr.Bytes
   350  
   351  		copyDone := make(chan unsafe.Pointer, 1)
   352  		iciclegnark.CopyToDevice(scalars, scalarBytes, copyDone)
   353  		scalars_d := <-copyDone
   354  
   355  		krs, _, err = iciclegnark.MsmOnDevice(scalars_d, pk.G1Device.K, len(scalars), true)
   356  		iciclegnark.FreeDevicePointer(scalars_d)
   357  
   358  		if err != nil {
   359  			return err
   360  		}
   361  
   362  		krs.AddMixed(&deltas[2])
   363  
   364  		krs.AddAssign(&krs2)
   365  
   366  		p1.ScalarMultiplication(&ar, &s)
   367  		krs.AddAssign(&p1)
   368  
   369  		p1.ScalarMultiplication(&bs1, &r)
   370  		krs.AddAssign(&p1)
   371  
   372  		proof.Krs.FromJacobian(&krs)
   373  
   374  		return nil
   375  	}
   376  
   377  	computeBS2 := func() error {
   378  		// Bs2 (1 multi exp G2 - size = len(wires))
   379  		var Bs, deltaS curve.G2Jac
   380  
   381  		<-chWireValuesB
   382  		if Bs, _, err = iciclegnark.MsmG2OnDevice(wireValuesBDevice.P, pk.G2Device.B, wireValuesBDevice.Size, true); err != nil {
   383  			return err
   384  		}
   385  
   386  		deltaS.FromAffine(&pk.G2.Delta)
   387  		deltaS.ScalarMultiplication(&deltaS, &s)
   388  		Bs.AddAssign(&deltaS)
   389  		Bs.AddMixed(&pk.G2.Beta)
   390  
   391  		proof.Bs.FromJacobian(&Bs)
   392  		return nil
   393  	}
   394  
   395  	// wait for FFT to end
   396  	<-chHDone
   397  
   398  	// schedule our proof part computations
   399  	if err := computeAR1(); err != nil {
   400  		return nil, err
   401  	}
   402  	if err := computeBS1(); err != nil {
   403  		return nil, err
   404  	}
   405  	if err := computeKRS(); err != nil {
   406  		return nil, err
   407  	}
   408  	if err := computeBS2(); err != nil {
   409  		return nil, err
   410  	}
   411  
   412  	log.Debug().Dur("took", time.Since(start)).Msg("prover done")
   413  
   414  	// free device/GPU memory that is not needed for future proofs (scalars/hpoly)
   415  	go func() {
   416  		iciclegnark.FreeDevicePointer(wireValuesADevice.P)
   417  		iciclegnark.FreeDevicePointer(wireValuesBDevice.P)
   418  		iciclegnark.FreeDevicePointer(h)
   419  	}()
   420  
   421  	return proof, nil
   422  }
   423  
   424  // if len(toRemove) == 0, returns slice
   425  // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex
   426  // this assumes len(slice) > len(toRemove)
   427  // filterHeap modifies toRemove
   428  func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) {
   429  
   430  	if len(toRemove) == 0 {
   431  		return slice
   432  	}
   433  
   434  	heap := utils.IntHeap(toRemove)
   435  	heap.Heapify()
   436  
   437  	r = make([]fr.Element, 0, len(slice))
   438  
   439  	// note: we can optimize that for the likely case where len(slice) >>> len(toRemove)
   440  	for i := 0; i < len(slice); i++ {
   441  		if len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   442  			for len(heap) > 0 && i+sliceFirstIndex == heap[0] {
   443  				heap.Pop()
   444  			}
   445  			continue
   446  		}
   447  		r = append(r, slice[i])
   448  	}
   449  
   450  	return
   451  }
   452  
// computeH computes, on the GPU, the quotient polynomial h with h·z = a·b - c,
// where z(x) = x^n - 1 and n = pk.Domain.Cardinality, and returns a device
// pointer to its n coefficients. computeH does not free the returned buffer;
// the caller (Prove) releases it with iciclegnark.FreeDevicePointer.
//
// a, b and c are zero-padded up to the domain cardinality before being copied
// to the device. The padding is appended to the caller's slices, so it may
// write into their spare capacity.
// NOTE(review): assumes len(a) == len(b) == len(c) <= pk.Domain.Cardinality
// (the padding length is derived from len(a) only, and a longer input makes
// it negative, which panics) — confirm with callers.
func computeH(a, b, c []fr.Element, pk *ProvingKey) unsafe.Pointer {
	// H part of Krs
	// Compute h such that h·z = a·b - c, with z(x) = x^n - 1:
	// 	1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
	// 	2 - ca = fft_coset(_a), cb = fft_coset(_b), cc = fft_coset(_c)
	// 	3 - h = ifft_coset((ca o cb - cc) o den), where den = 1/z on the coset
	// 	    (pk.DenDevice, filled by setupDevicePointers with 1/(g^n - 1))

	n := len(a)

	// add padding to ensure input length is domain cardinality
	padding := make([]fr.Element, int(pk.Domain.Cardinality)-n)
	a = append(a, padding...)
	b = append(b, padding...)
	c = append(c, padding...)
	n = len(a)

	sizeBytes := n * fr.Bytes

	/*********** Copy a,b,c to Device Start ************/
	// Individual channels are necessary to know which device pointers
	// point to which vector
	copyADone := make(chan unsafe.Pointer, 1)
	copyBDone := make(chan unsafe.Pointer, 1)
	copyCDone := make(chan unsafe.Pointer, 1)

	go iciclegnark.CopyToDevice(a, sizeBytes, copyADone)
	go iciclegnark.CopyToDevice(b, sizeBytes, copyBDone)
	go iciclegnark.CopyToDevice(c, sizeBytes, copyCDone)

	a_device := <-copyADone
	b_device := <-copyBDone
	c_device := <-copyCDone
	/*********** Copy a,b,c to Device End ************/

	// computeInttNttOnDevice runs steps 1 and 2 for one vector: a plain inverse
	// NTT into a temporary buffer, then a coset NTT of that buffer.
	// NOTE(review): NttOnDevice presumably writes its result back into
	// devicePointer (its first argument) — confirm in iciclegnark.
	computeInttNttDone := make(chan error, 1)
	computeInttNttOnDevice := func(devicePointer unsafe.Pointer) {
		a_intt_d := iciclegnark.INttOnDevice(devicePointer, pk.DomainDevice.TwiddlesInv, nil, n, sizeBytes, false)
		iciclegnark.NttOnDevice(devicePointer, a_intt_d, pk.DomainDevice.Twiddles, pk.DomainDevice.CosetTable, n, n, sizeBytes, true)
		// signal completion first; the temporary buffer is freed afterwards in
		// this goroutine while the caller moves on
		computeInttNttDone <- nil
		iciclegnark.FreeDevicePointer(a_intt_d)
	}

	go computeInttNttOnDevice(a_device)
	go computeInttNttOnDevice(b_device)
	go computeInttNttOnDevice(c_device)
	// the sent values are always nil; the receives only wait for all three
	_, _, _ = <-computeInttNttDone, <-computeInttNttDone, <-computeInttNttDone

	// NOTE(review): presumably computes (a·b - c)·den element-wise into
	// a_device (the next step reads only a_device) — confirm in iciclegnark.
	iciclegnark.PolyOps(a_device, b_device, c_device, pk.DenDevice, n)

	// step 3: inverse coset NTT; h is a new device buffer, distinct from a_device
	h := iciclegnark.INttOnDevice(a_device, pk.DomainDevice.TwiddlesInv, pk.DomainDevice.CosetTableInv, n, sizeBytes, true)

	// the input buffers are no longer needed; release them in the background
	go func() {
		iciclegnark.FreeDevicePointer(a_device)
		iciclegnark.FreeDevicePointer(b_device)
		iciclegnark.FreeDevicePointer(c_device)
	}()

	// NOTE(review): reorders the n entries of h in place, presumably into the
	// order MsmOnDevice expects — confirm the permutation in iciclegnark.
	iciclegnark.ReverseScalars(h, n)

	return h
}