github.com/consensys/gnark-crypto@v0.14.0/internal/generator/ecc/template/multiexp.go.tmpl (about)

     1  {{ $G1TAffine := print (toUpper .G1.PointName) "Affine" }}
     2  {{ $G1TJacobian := print (toUpper .G1.PointName) "Jac" }}
     3  {{ $G1TJacobianExtended := print (toLower .G1.PointName) "JacExtended" }}
     4  
     5  {{ $G2TAffine := print (toUpper .G2.PointName) "Affine" }}
     6  {{ $G2TJacobian := print (toUpper .G2.PointName) "Jac" }}
     7  {{ $G2TJacobianExtended := print (toLower .G2.PointName) "JacExtended" }}
     8  
     9  
    10  import (
    11  	"github.com/consensys/gnark-crypto/internal/parallel"
    12  	"github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr"
    13  	"github.com/consensys/gnark-crypto/ecc"
    14  	"errors"
    15  	"math"
    16  	"runtime"
    17  )
    18  
    19  {{- if ne .Name "secp256k1"}}
    20  {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 16}}
    21  {{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange "cmax" 16}}
    22  {{- else}}
    23  {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 15}}
    24  {{- end}}
    25  
    26  
    27  // selector stores the index, mask and shifts needed to select bits from a scalar
    28  // it is used during the multiExp algorithm or the batch scalar multiplication
    29  type selector struct {
    30  	index uint64 			// index in the multi-word scalar to select bits from
    31  	mask uint64				// mask (c-bit wide)
    32  	shift uint64			// shift needed to get our bits on low positions
    33  
    34  	multiWordSelect bool	// set to true if we need to select bits from 2 words (case where c doesn't divide 64)
    35  	maskHigh uint64 	  	// same than mask, for index+1
    36  	shiftHigh uint64		// same than shift, for index+1
    37  }
    38  
    39  // return number of chunks for a given window size c
    40  // the last chunk may be bigger to accommodate a potential carry from the NAF decomposition
    41  func computeNbChunks(c uint64) uint64 {
    42  	return (fr.Bits+c-1) / c
    43  }
    44  
    45  // return the last window size for a scalar;
    46  // this last window should accommodate a carry (from the NAF decomposition)
    47  // it can be == c if we have 1 available bit
    48  // it can be > c if we have 0 available bit
    49  // it can be < c if we have 2+ available bits
    50  func lastC(c uint64) uint64 {
    51  	nbAvailableBits := (computeNbChunks(c)*c) - fr.Bits
    52  	return c+1-nbAvailableBits
    53  }
    54  
    55  type chunkStat struct {
    56  	// relative weight of work compared to other chunks. 100.0 -> nominal weight.
    57  	weight float32
    58  
    59  	// percentage of bucket filled in the window;
    60  	ppBucketFilled float32
    61  	nbBucketFilled int
    62  }
    63  
    64  
    65  
    66  // partitionScalars  compute, for each scalars over c-bit wide windows, nbChunk digits
    67  // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract
    68  // 2^{c} to the current digit, making it negative.
    69  // negative digits can be processed in a later step as adding -G into the bucket instead of G
    70  // (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication)
    71  func partitionScalars(scalars []fr.Element, c uint64,  nbTasks int) ([]uint16, []chunkStat) {
    72  	// no benefit here to have more tasks than CPUs
    73  	if nbTasks > runtime.NumCPU() {
    74  		nbTasks = runtime.NumCPU()
    75  	}
    76  
    77  	// number of c-bit radixes in a scalar
    78  	nbChunks := computeNbChunks(c)
    79  
    80  	digits := make([]uint16, len(scalars)*int(nbChunks))
    81  
    82  	mask  := uint64((1 << c) - 1) 		// low c bits are 1
    83  	max := int(1 << (c -1)) - 1					// max value (inclusive) we want for our digits
    84  	cDivides64 :=  (64 %c ) == 0 				// if c doesn't divide 64, we may need to select over multiple words
    85  
    86  
    87  	// compute offset and word selector / shift to select the right bits of our windows
    88  	selectors := make([]selector, nbChunks)
    89  	for chunk:=uint64(0); chunk < nbChunks; chunk++ {
    90  		jc := uint64(chunk * c)
    91  		d := selector{}
    92  		d.index = jc / 64
    93  		d.shift = jc - (d.index * 64)
    94  		d.mask = mask << d.shift
    95  		d.multiWordSelect = !cDivides64  && d.shift > (64-c) && d.index < (fr.Limbs - 1 )
    96  		if d.multiWordSelect {
    97  			nbBitsHigh := d.shift - uint64(64-c)
    98  			d.maskHigh = (1 << nbBitsHigh) - 1
    99  			d.shiftHigh = (c - nbBitsHigh)
   100  		}
   101  		selectors[chunk] = d
   102  	}
   103  
   104  
   105  	parallel.Execute(len(scalars), func(start, end int) {
   106  		for i:=start; i < end; i++ {
   107  			if scalars[i].IsZero() {
   108  				// everything is 0, no need to process this scalar
   109  				continue
   110  			}
   111  			scalar := scalars[i].Bits()
   112  
   113  			var carry int
   114  
   115  			// for each chunk in the scalar, compute the current digit, and an eventual carry
   116  			for chunk := uint64(0); chunk < nbChunks - 1; chunk++ {
   117  				s := selectors[chunk]
   118  
   119  				// init with carry if any
   120  				digit := carry
   121  				carry = 0
   122  
   123  				// digit = value of the c-bit window
   124  				digit += int((scalar[s.index] & s.mask) >> s.shift)
   125  
   126  				if s.multiWordSelect {
   127  					// we are selecting bits over 2 words
   128  					digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh
   129  				}
   130  
   131  
   132  				// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract
   133  				// 2^{c} to the current digit, making it negative.
   134  				if digit > max {
   135  					digit -= (1 << c)
   136  					carry = 1
   137  				}
   138  
   139  				// if digit is zero, no impact on result
   140  				if digit == 0 {
   141  					continue
   142  				}
   143  
   144  				var bits uint16
   145  				if digit > 0 {
   146  					bits = uint16(digit) << 1
   147  				} else {
   148  					bits = (uint16(-digit-1) << 1) + 1
   149  				}
   150  				digits[int(chunk)*len(scalars)+i] = bits
   151  			}
   152  
   153  			// for the last chunk, we don't want to borrow from a next window
   154  			// (but may have a larger max value)
   155  			chunk := nbChunks - 1
   156  			s := selectors[chunk]
   157  			// init with carry if any
   158  			digit := carry
   159  			// digit = value of the c-bit window
   160  			digit += int((scalar[s.index] & s.mask) >> s.shift)
   161  			if s.multiWordSelect {
   162  				// we are selecting bits over 2 words
   163  				digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh
   164  			}
   165  			digits[int(chunk)*len(scalars)+i] =  uint16(digit) << 1
   166  		}
   167  
   168  	}, nbTasks)
   169  
   170  
   171  	// aggregate  chunk stats
   172  	chunkStats := make([]chunkStat, nbChunks)
   173  	if c <= 9 {
   174  		// no need to compute stats for small window sizes
   175  		return digits, chunkStats
   176  	}
   177  	parallel.Execute(len(chunkStats), func(start, end int) {
   178  		// for each chunk compute the statistics
   179  		for chunkID := start; chunkID < end; chunkID++ {
   180  			// indicates if a bucket is hit.
   181              {{- if eq .Name "secp256k1"}}
   182                  var b bitSetC15
   183              {{- else}}
   184                  var b bitSetC16
   185              {{- end}}
   186  
   187  			// digits for the chunk
   188  			chunkDigits := digits[chunkID*len(scalars):(chunkID+1)*len(scalars)]
   189  
   190  			totalOps := 0
   191  			nz := 0 // non zero buckets count
   192  			for _, digit := range chunkDigits {
   193  				if digit == 0 {
   194  					continue
   195  				}
   196  				totalOps++
   197  				bucketID := digit >> 1
   198  				if digit &1 == 0 {
   199  					bucketID-=1
   200  				}
   201  				if !b[bucketID] {
   202  					nz++
   203  					b[bucketID] = true
   204  				}
   205  			}
   206  			chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after
   207  			chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1 << (c-1)))
   208  			chunkStats[chunkID].nbBucketFilled = nz
   209  		}
   210  	}, nbTasks)
   211  
   212  	totalOps := float32(0.0)
   213  	for _, stat := range chunkStats {
   214  		totalOps+=stat.weight
   215  	}
   216  
   217  	target := totalOps / float32(nbChunks)
   218  	if target != 0.0 {
   219  		// if target == 0, it means all the scalars are 0 everywhere, there is no work to be done.
   220  		for i := 0; i < len(chunkStats); i++ {
   221  			chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target
   222  		}
   223  	}
   224  
   225  
   226  	return digits, chunkStats
   227  }
   228  
   229  {{define "multiexp" }}
   230  
   231  
   232  // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf
   233  //
   234  // This call return an error if len(scalars) != len(points) or if provided config is invalid.
   235  func (p *{{ $.TAffine }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TAffine }}, error) {
   236  	var _p {{$.TJacobian}}
   237  	if _, err := _p.MultiExp(points, scalars, config); err != nil {
   238  		return nil, err
   239  	}
   240  	p.FromJacobian(&_p)
   241  	return p, nil
   242  }
   243  
   244  // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf
   245  //
   246  // This call return an error if len(scalars) != len(points) or if provided config is invalid.
   247  func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TJacobian }}, error) {
   248  	// TODO @gbotrel replace the ecc.MultiExpConfig by a Option pattern for maintainability.
   249  	// note:
   250  	// each of the msmCX method is the same, except for the c constant it declares
   251  	// duplicating (through template generation) these methods allows to declare the buckets on the stack
   252  	// the choice of c needs to be improved:
   253  	// there is a theoretical value that gives optimal asymptotics
   254  	// but in practice, other factors come into play, including:
   255  	// * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1
   256  	// * number of CPUs
   257  	// * cache friendliness (which depends on the host, G1 or G2... )
   258  	//	--> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't.
   259  
   260  	// for each msmCX
   261  	// step 1
   262  	// we compute, for each scalars over c-bit wide windows, nbChunk digits
   263  	// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract
   264  	// 2^{c} to the current digit, making it negative.
   265  	// negative digits will be processed in the next step as adding -G into the bucket instead of G
   266  	// (computing -G is cheap, and this saves us half of the buckets)
   267  	// step 2
   268  	// buckets are declared on the stack
   269  	// notice that we have 2^{c-1} buckets instead of 2^{c} (see step1)
   270  	// we use jacobian extended formulas here as they are faster than mixed addition
   271  	// msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel
   272  	// step 3
   273  	// reduce the buckets weighed sums into our result (msmReduceChunk)
   274  
   275  	// ensure len(points) == len(scalars)
   276  	nbPoints := len(points)
   277  	if nbPoints != len(scalars) {
   278  		return nil, errors.New("len(points) != len(scalars)")
   279  	}
   280  
   281  	// if nbTasks is not set, use all available CPUs
   282  	if config.NbTasks <= 0 {
   283  		config.NbTasks = runtime.NumCPU() * 2
   284  	} else if config.NbTasks > 1024 {
   285  		return nil, errors.New("invalid config: config.NbTasks > 1024")
   286  	}
   287  
   288  	// here, we compute the best C for nbPoints
   289  	// we split recursively until nbChunks(c) >= nbTasks,
   290  	bestC := func(nbPoints int) uint64 {
   291  		// implemented msmC methods (the c we use must be in this slice)
   292  		implementedCs := []uint64{
   293  			{{- range $c :=  $.CRange}}{{- if ge $c 4}}{{$c}},{{- end}}{{- end}}
   294  		}
   295  		var C uint64
   296  		// approximate cost (in group operations)
   297  		// cost = bits/c * (nbPoints + 2^{c})
   298  		// this needs to be verified empirically.
   299  		// for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results
   300  		min := math.MaxFloat64
   301  		for _, c := range implementedCs {
   302  			cc := (fr.Bits+1) * (nbPoints + (1 << c))
   303  			cost := float64(cc) / float64(c)
   304  			if cost < min {
   305  				min = cost
   306  				C = c
   307  			}
   308  		}
   309  		return C
   310  	}
   311  
   312  	C := bestC(nbPoints)
   313  	nbChunks := int(computeNbChunks(C))
   314  
   315  	// should we recursively split the msm in half? (see below)
   316  	// we want to minimize the execution time of the algorithm; 
   317  	// splitting the msm will **add** operations, but if it allows to use more CPU, it might be worth it.
   318  
   319  	// costFunction returns a metric that represent the "wall time" of the algorithm
   320  	costFunction := func(nbTasks, nbCpus, costPerTask int) int {
   321  		// cost for the reduction of all tasks (msmReduceChunk)
   322  		totalCost := nbTasks
   323  
   324  		// cost for the computation of each task (msmProcessChunk)
   325  		for nbTasks >= nbCpus {
   326  			nbTasks -= nbCpus
   327  			totalCost += costPerTask
   328  		}
   329  		if nbTasks > 0 {
   330  			totalCost += costPerTask
   331  		}
   332  		return totalCost
   333  	}
   334  
   335  	// costPerTask is the approximate number of group ops per task
   336  	costPerTask := func(c uint64, nbPoints int) int {return (nbPoints + int((1 << c)))}
   337  
   338  	costPreSplit := costFunction(nbChunks, config.NbTasks, costPerTask(C, nbPoints))
   339  	
   340  	cPostSplit := bestC(nbPoints/2)
   341  	nbChunksPostSplit := int(computeNbChunks(cPostSplit))
   342  	costPostSplit := costFunction(nbChunksPostSplit * 2, config.NbTasks, costPerTask(cPostSplit, nbPoints/2))
   343  
   344  	// if the cost of the split msm is lower than the cost of the non split msm, we split
   345  	if costPostSplit < costPreSplit {
   346  		config.NbTasks = int(math.Ceil(float64(config.NbTasks) / 2.0))
   347  		var _p {{ $.TJacobian }}
   348  		chDone := make(chan struct{}, 1)
   349  		go func() {
   350  			_p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config)
   351  			close(chDone)
   352  		}()
   353  		p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config)
   354  		<-chDone
   355  		p.AddAssign(&_p)
   356  		return p, nil
   357  	}
   358  
   359  	// if we don't split, we use the best C we found
   360  	_innerMsm{{ $.UPointName }}(p, C, points, scalars, config)
   361  
   362  	return p, nil
   363  }
   364  
   365  func _innerMsm{{ $.UPointName }}(p *{{ $.TJacobian }}, c uint64, points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) *{{ $.TJacobian }} {
   366  	// partition the scalars
   367  	digits, chunkStats := partitionScalars(scalars, c, config.NbTasks)
   368  
   369  	nbChunks := computeNbChunks(c)
   370  
   371  	// for each chunk, spawn one go routine that'll loop through all the scalars in the
   372  	// corresponding bit-window
   373  	// note that buckets is an array allocated on the stack and this is critical for performance
   374  
   375  	// each go routine sends its result in chChunks[i] channel
   376  	chChunks := make([]chan {{ $.TJacobianExtended }}, nbChunks)
   377  	for i:=0; i < len(chChunks);i++ {
   378  		chChunks[i] = make(chan {{ $.TJacobianExtended }}, 1)
   379  	}
   380  
   381  	// we use a semaphore to limit the number of go routines running concurrently
   382  	// (only if nbTasks < nbCPU)
   383  	var sem chan struct{}
   384  	if config.NbTasks < runtime.NumCPU() {
   385  		// we add nbChunks because if chunk is overweight we split it in two
   386  		sem = make(chan struct{}, config.NbTasks + int(nbChunks)) 
   387  		for i:=0; i < config.NbTasks; i++ {
   388  			sem <- struct{}{}
   389  		}
   390  		defer func() {
   391  			close(sem)
   392  		}()
   393  	}
   394  
   395  	// the last chunk may be processed with a different method than the rest, as it could be smaller.
   396  	n := len(points)
   397  	for j := int(nbChunks - 1); j >= 0; j-- {
   398  		processChunk := getChunkProcessor{{ $.UPointName }}(c, chunkStats[j])
   399  		if j == int(nbChunks - 1) {
   400  			processChunk = getChunkProcessor{{ $.UPointName }}(lastC(c), chunkStats[j])
   401  		}
   402  		if chunkStats[j].weight >= 115 {
   403  			// we split this in more go routines since this chunk has more work to do than the others.
   404  			// else what would happen is this go routine would finish much later than the others.
   405  			chSplit := make(chan {{ $.TJacobianExtended }}, 2)
   406  			split := n / 2
   407  
   408  			if sem != nil {
   409  				sem <- struct{}{} // add another token to the semaphore, since we split in two.
   410  			}
   411  			go processChunk(uint64(j),chSplit, c, points[:split], digits[j*n:(j*n)+split], sem)
   412  			go processChunk(uint64(j),chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n], sem)
   413  			go func(chunkID int) {
   414  				s1 := <-chSplit
   415  				s2 := <-chSplit
   416  				close(chSplit)
   417  				s1.add(&s2)
   418  				chChunks[chunkID] <- s1
   419  			}(j)
   420  			continue
   421  		}
   422  		go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n], sem)
   423  	}
   424  
   425  	return msmReduceChunk{{ $.TAffine }}(p, int(c), chChunks[:])
   426  }
   427  
   428  
   429  // getChunkProcessor{{ $.UPointName }} decides, depending on c window size and statistics for the chunk
   430  // to return the best algorithm to process the chunk.
   431  func getChunkProcessor{{ $.UPointName }}(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- {{ $.TJacobianExtended }}, c uint64, points []{{ $.TAffine }}, digits []uint16, sem chan struct{}) {
   432  	switch c {
   433  		{{- range $c :=  $.LastCRange}}
   434  		case {{$c}}:
   435  			return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}]
   436  		{{- end }}
   437  		{{range $c :=  $.CRange}}
   438  		case {{$c}}:
   439  			{{- if le $c 9}}
   440  				return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}]
   441  			{{- else}}
   442  				const batchSize = {{batchSize $c}}
   443  				// here we could check some chunk statistic (deviation, ...) to determine if calling
   444  				// the batch affine version is worth it.
   445  				if stat.nbBucketFilled < batchSize {
   446  					// clear indicator that batch affine method is not appropriate here.
   447  					return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}]
   448  				}
   449  				return processChunk{{ $.UPointName }}BatchAffine[bucket{{ $.TJacobianExtended }}C{{$c}}, bucket{{ $.TAffine }}C{{$c}}, bitSetC{{$c}}, p{{$.TAffine}}C{{$c}}, pp{{$.TAffine}}C{{$c}}, q{{$.TAffine}}C{{$c}}, c{{$.TAffine}}C{{$c}}]
   450  			{{- end}}
   451  		{{- end}}
   452  		default:
   453  			// panic("will not happen c != previous values is not generated by templates")
   454              return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$.cmax}}]
   455  	}
   456  }
   457  
   458  
   459  // msmReduceChunk{{ $.TAffine }} reduces the weighted sum of the buckets into the result of the multiExp
   460  func msmReduceChunk{{ $.TAffine }}(p *{{ $.TJacobian }}, c int, chChunks []chan {{ $.TJacobianExtended }})  *{{ $.TJacobian }} {
   461  	var _p {{ $.TJacobianExtended }}
   462  	totalj := <-chChunks[len(chChunks)-1]
   463      _p.Set(&totalj)
   464  	for j := len(chChunks) - 2; j >= 0; j-- {
   465  		for l := 0; l < c; l++ {
   466  			_p.double(&_p)
   467  		}
   468  		totalj := <-chChunks[j]
   469  		_p.add(&totalj)
   470  	}
   471  
   472  	return p.unsafeFromJacExtended(&_p)
   473  }
   474  
   475  // Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] *
   476  // combinationCoeff^i and stores the result in p. It returns error in case
   477  // configuration is invalid.
   478  func (p *{{ $.TAffine }}) Fold(points []{{ $.TAffine }}, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*{{ $.TAffine }}, error) {
   479  	var _p {{ $.TJacobian }}
   480  	if _, err := _p.Fold(points, combinationCoeff, config); err != nil {
   481  		return nil, err
   482  	}
   483  	p.FromJacobian(&_p)
   484  	return p, nil
   485  }
   486  
   487  // Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] *
   488  // combinationCoeff^i and stores the result in p. It returns error in case
   489  // configuration is invalid.
   490  func (p *{{$.TJacobian}}) Fold(points []{{ $.TAffine }}, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*{{ $.TJacobian }}, error) {
   491  	scalars := make([]fr.Element, len(points))
   492  	scalar := fr.NewElement(1)
   493  	for i := 0; i < len(points); i++ {
   494  		scalars[i].Set(&scalar)
   495  		scalar.Mul(&scalar, &combinationCoeff)
   496  	}
   497  	return p.MultiExp(points, scalars, config)
   498  }
   499  
   500  
   501  
   502  {{end }}