github.com/consensys/gnark-crypto@v0.14.0/internal/generator/ecc/template/multiexp.go.tmpl (about) 1 {{ $G1TAffine := print (toUpper .G1.PointName) "Affine" }} 2 {{ $G1TJacobian := print (toUpper .G1.PointName) "Jac" }} 3 {{ $G1TJacobianExtended := print (toLower .G1.PointName) "JacExtended" }} 4 5 {{ $G2TAffine := print (toUpper .G2.PointName) "Affine" }} 6 {{ $G2TJacobian := print (toUpper .G2.PointName) "Jac" }} 7 {{ $G2TJacobianExtended := print (toLower .G2.PointName) "JacExtended" }} 8 9 10 import ( 11 "github.com/consensys/gnark-crypto/internal/parallel" 12 "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" 13 "github.com/consensys/gnark-crypto/ecc" 14 "errors" 15 "math" 16 "runtime" 17 ) 18 19 {{- if ne .Name "secp256k1"}} 20 {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 16}} 21 {{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange "cmax" 16}} 22 {{- else}} 23 {{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange "cmax" 15}} 24 {{- end}} 25 26 27 // selector stores the index, mask and shifts needed to select bits from a scalar 28 // it is used during the multiExp algorithm or the batch scalar multiplication 29 type selector struct { 30 index uint64 // index in the multi-word scalar to select bits from 31 mask uint64 // mask (c-bit wide) 32 shift uint64 // shift needed to get our bits on low positions 33 34 multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) 35 maskHigh uint64 // same than mask, for index+1 36 shiftHigh uint64 // same than shift, for index+1 37 } 38 39 // return number of chunks for a given window size c 40 // the last chunk may be bigger to accommodate a potential carry from the NAF decomposition 41 func computeNbChunks(c uint64) uint64 { 42 return (fr.Bits+c-1) / c 43 } 44 45 // return the last window size for a scalar; 46 // this last window should accommodate a carry (from the NAF decomposition) 47 // it can be == c if we have 1 available bit 48 // it can be > c if we have 0 available bit 49 // it can be < c if we have 2+ available bits 50 func lastC(c uint64) uint64 { 51 nbAvailableBits := (computeNbChunks(c)*c) - fr.Bits 52 return c+1-nbAvailableBits 53 } 54 55 type chunkStat struct { 56 // relative weight of work compared to other chunks. 100.0 -> nominal weight. 57 weight float32 58 59 // percentage of bucket filled in the window; 60 ppBucketFilled float32 61 nbBucketFilled int 62 } 63 64 65 66 // partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits 67 // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract 68 // 2^{c} to the current digit, making it negative. 69 // negative digits can be processed in a later step as adding -G into the bucket instead of G 70 // (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) 71 func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { 72 // no benefit here to have more tasks than CPUs 73 if nbTasks > runtime.NumCPU() { 74 nbTasks = runtime.NumCPU() 75 } 76 77 // number of c-bit radixes in a scalar 78 nbChunks := computeNbChunks(c) 79 80 digits := make([]uint16, len(scalars)*int(nbChunks)) 81 82 mask := uint64((1 << c) - 1) // low c bits are 1 83 max := int(1 << (c -1)) - 1 // max value (inclusive) we want for our digits 84 cDivides64 := (64 %c ) == 0 // if c doesn't divide 64, we may need to select over multiple words 85 86 87 // compute offset and word selector / shift to select the right bits of our windows 88 selectors := make([]selector, nbChunks) 89 for chunk:=uint64(0); chunk < nbChunks; chunk++ { 90 jc := uint64(chunk * c) 91 d := selector{} 92 d.index = jc / 64 93 d.shift = jc - (d.index * 64) 94 d.mask = mask << d.shift 95 d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs - 1 ) 96 if d.multiWordSelect { 97 nbBitsHigh := d.shift - uint64(64-c) 98 d.maskHigh = (1 << nbBitsHigh) - 1 99 d.shiftHigh = (c - nbBitsHigh) 100 } 101 selectors[chunk] = d 102 } 103 104 105 parallel.Execute(len(scalars), func(start, end int) { 106 for i:=start; i < end; i++ { 107 if scalars[i].IsZero() { 108 // everything is 0, no need to process this scalar 109 continue 110 } 111 scalar := scalars[i].Bits() 112 113 var carry int 114 115 // for each chunk in the scalar, compute the current digit, and an eventual carry 116 for chunk := uint64(0); chunk < nbChunks - 1; chunk++ { 117 s := selectors[chunk] 118 119 // init with carry if any 120 digit := carry 121 carry = 0 122 123 // digit = value of the c-bit window 124 digit += int((scalar[s.index] & s.mask) >> s.shift) 125 126 if s.multiWordSelect { 127 // we are selecting bits over 2 words 128 digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh 129 } 130 131 132 // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract 133 // 2^{c} to the current digit, making it negative. 134 if digit > max { 135 digit -= (1 << c) 136 carry = 1 137 } 138 139 // if digit is zero, no impact on result 140 if digit == 0 { 141 continue 142 } 143 144 var bits uint16 145 if digit > 0 { 146 bits = uint16(digit) << 1 147 } else { 148 bits = (uint16(-digit-1) << 1) + 1 149 } 150 digits[int(chunk)*len(scalars)+i] = bits 151 } 152 153 // for the last chunk, we don't want to borrow from a next window 154 // (but may have a larger max value) 155 chunk := nbChunks - 1 156 s := selectors[chunk] 157 // init with carry if any 158 digit := carry 159 // digit = value of the c-bit window 160 digit += int((scalar[s.index] & s.mask) >> s.shift) 161 if s.multiWordSelect { 162 // we are selecting bits over 2 words 163 digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh 164 } 165 digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 166 } 167 168 }, nbTasks) 169 170 171 // aggregate chunk stats 172 chunkStats := make([]chunkStat, nbChunks) 173 if c <= 9 { 174 // no need to compute stats for small window sizes 175 return digits, chunkStats 176 } 177 parallel.Execute(len(chunkStats), func(start, end int) { 178 // for each chunk compute the statistics 179 for chunkID := start; chunkID < end; chunkID++ { 180 // indicates if a bucket is hit. 181 {{- if eq .Name "secp256k1"}} 182 var b bitSetC15 183 {{- else}} 184 var b bitSetC16 185 {{- end}} 186 187 // digits for the chunk 188 chunkDigits := digits[chunkID*len(scalars):(chunkID+1)*len(scalars)] 189 190 totalOps := 0 191 nz := 0 // non zero buckets count 192 for _, digit := range chunkDigits { 193 if digit == 0 { 194 continue 195 } 196 totalOps++ 197 bucketID := digit >> 1 198 if digit &1 == 0 { 199 bucketID-=1 200 } 201 if !b[bucketID] { 202 nz++ 203 b[bucketID] = true 204 } 205 } 206 chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after 207 chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1 << (c-1))) 208 chunkStats[chunkID].nbBucketFilled = nz 209 } 210 }, nbTasks) 211 212 totalOps := float32(0.0) 213 for _, stat := range chunkStats { 214 totalOps+=stat.weight 215 } 216 217 target := totalOps / float32(nbChunks) 218 if target != 0.0 { 219 // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. 220 for i := 0; i < len(chunkStats); i++ { 221 chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target 222 } 223 } 224 225 226 return digits, chunkStats 227 } 228 229 {{define "multiexp" }} 230 231 232 // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf 233 // 234 // This call return an error if len(scalars) != len(points) or if provided config is invalid. 235 func (p *{{ $.TAffine }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TAffine }}, error) { 236 var _p {{$.TJacobian}} 237 if _, err := _p.MultiExp(points, scalars, config); err != nil { 238 return nil, err 239 } 240 p.FromJacobian(&_p) 241 return p, nil 242 } 243 244 // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf 245 // 246 // This call return an error if len(scalars) != len(points) or if provided config is invalid. 247 func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TJacobian }}, error) { 248 // TODO @gbotrel replace the ecc.MultiExpConfig by a Option pattern for maintainability. 249 // note: 250 // each of the msmCX method is the same, except for the c constant it declares 251 // duplicating (through template generation) these methods allows to declare the buckets on the stack 252 // the choice of c needs to be improved: 253 // there is a theoretical value that gives optimal asymptotics 254 // but in practice, other factors come into play, including: 255 // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 256 // * number of CPUs 257 // * cache friendliness (which depends on the host, G1 or G2... ) 258 // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. 259 260 // for each msmCX 261 // step 1 262 // we compute, for each scalars over c-bit wide windows, nbChunk digits 263 // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and subtract 264 // 2^{c} to the current digit, making it negative. 265 // negative digits will be processed in the next step as adding -G into the bucket instead of G 266 // (computing -G is cheap, and this saves us half of the buckets) 267 // step 2 268 // buckets are declared on the stack 269 // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) 270 // we use jacobian extended formulas here as they are faster than mixed addition 271 // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel 272 // step 3 273 // reduce the buckets weighed sums into our result (msmReduceChunk) 274 275 // ensure len(points) == len(scalars) 276 nbPoints := len(points) 277 if nbPoints != len(scalars) { 278 return nil, errors.New("len(points) != len(scalars)") 279 } 280 281 // if nbTasks is not set, use all available CPUs 282 if config.NbTasks <= 0 { 283 config.NbTasks = runtime.NumCPU() * 2 284 } else if config.NbTasks > 1024 { 285 return nil, errors.New("invalid config: config.NbTasks > 1024") 286 } 287 288 // here, we compute the best C for nbPoints 289 // we split recursively until nbChunks(c) >= nbTasks, 290 bestC := func(nbPoints int) uint64 { 291 // implemented msmC methods (the c we use must be in this slice) 292 implementedCs := []uint64{ 293 {{- range $c := $.CRange}}{{- if ge $c 4}}{{$c}},{{- end}}{{- end}} 294 } 295 var C uint64 296 // approximate cost (in group operations) 297 // cost = bits/c * (nbPoints + 2^{c}) 298 // this needs to be verified empirically. 299 // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results 300 min := math.MaxFloat64 301 for _, c := range implementedCs { 302 cc := (fr.Bits+1) * (nbPoints + (1 << c)) 303 cost := float64(cc) / float64(c) 304 if cost < min { 305 min = cost 306 C = c 307 } 308 } 309 return C 310 } 311 312 C := bestC(nbPoints) 313 nbChunks := int(computeNbChunks(C)) 314 315 // should we recursively split the msm in half? (see below) 316 // we want to minimize the execution time of the algorithm; 317 // splitting the msm will **add** operations, but if it allows to use more CPU, it might be worth it. 318 319 // costFunction returns a metric that represent the "wall time" of the algorithm 320 costFunction := func(nbTasks, nbCpus, costPerTask int) int { 321 // cost for the reduction of all tasks (msmReduceChunk) 322 totalCost := nbTasks 323 324 // cost for the computation of each task (msmProcessChunk) 325 for nbTasks >= nbCpus { 326 nbTasks -= nbCpus 327 totalCost += costPerTask 328 } 329 if nbTasks > 0 { 330 totalCost += costPerTask 331 } 332 return totalCost 333 } 334 335 // costPerTask is the approximate number of group ops per task 336 costPerTask := func(c uint64, nbPoints int) int {return (nbPoints + int((1 << c)))} 337 338 costPreSplit := costFunction(nbChunks, config.NbTasks, costPerTask(C, nbPoints)) 339 340 cPostSplit := bestC(nbPoints/2) 341 nbChunksPostSplit := int(computeNbChunks(cPostSplit)) 342 costPostSplit := costFunction(nbChunksPostSplit * 2, config.NbTasks, costPerTask(cPostSplit, nbPoints/2)) 343 344 // if the cost of the split msm is lower than the cost of the non split msm, we split 345 if costPostSplit < costPreSplit { 346 config.NbTasks = int(math.Ceil(float64(config.NbTasks) / 2.0)) 347 var _p {{ $.TJacobian }} 348 chDone := make(chan struct{}, 1) 349 go func() { 350 _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) 351 close(chDone) 352 }() 353 p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) 354 <-chDone 355 p.AddAssign(&_p) 356 return p, nil 357 } 358 359 // if we don't split, we use the best C we found 360 _innerMsm{{ $.UPointName }}(p, C, points, scalars, config) 361 362 return p, nil 363 } 364 365 func _innerMsm{{ $.UPointName }}(p *{{ $.TJacobian }}, c uint64, points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) *{{ $.TJacobian }} { 366 // partition the scalars 367 digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) 368 369 nbChunks := computeNbChunks(c) 370 371 // for each chunk, spawn one go routine that'll loop through all the scalars in the 372 // corresponding bit-window 373 // note that buckets is an array allocated on the stack and this is critical for performance 374 375 // each go routine sends its result in chChunks[i] channel 376 chChunks := make([]chan {{ $.TJacobianExtended }}, nbChunks) 377 for i:=0; i < len(chChunks);i++ { 378 chChunks[i] = make(chan {{ $.TJacobianExtended }}, 1) 379 } 380 381 // we use a semaphore to limit the number of go routines running concurrently 382 // (only if nbTasks < nbCPU) 383 var sem chan struct{} 384 if config.NbTasks < runtime.NumCPU() { 385 // we add nbChunks because if chunk is overweight we split it in two 386 sem = make(chan struct{}, config.NbTasks + int(nbChunks)) 387 for i:=0; i < config.NbTasks; i++ { 388 sem <- struct{}{} 389 } 390 defer func() { 391 close(sem) 392 }() 393 } 394 395 // the last chunk may be processed with a different method than the rest, as it could be smaller. 396 n := len(points) 397 for j := int(nbChunks - 1); j >= 0; j-- { 398 processChunk := getChunkProcessor{{ $.UPointName }}(c, chunkStats[j]) 399 if j == int(nbChunks - 1) { 400 processChunk = getChunkProcessor{{ $.UPointName }}(lastC(c), chunkStats[j]) 401 } 402 if chunkStats[j].weight >= 115 { 403 // we split this in more go routines since this chunk has more work to do than the others. 404 // else what would happen is this go routine would finish much later than the others. 405 chSplit := make(chan {{ $.TJacobianExtended }}, 2) 406 split := n / 2 407 408 if sem != nil { 409 sem <- struct{}{} // add another token to the semaphore, since we split in two. 410 } 411 go processChunk(uint64(j),chSplit, c, points[:split], digits[j*n:(j*n)+split], sem) 412 go processChunk(uint64(j),chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n], sem) 413 go func(chunkID int) { 414 s1 := <-chSplit 415 s2 := <-chSplit 416 close(chSplit) 417 s1.add(&s2) 418 chChunks[chunkID] <- s1 419 }(j) 420 continue 421 } 422 go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n], sem) 423 } 424 425 return msmReduceChunk{{ $.TAffine }}(p, int(c), chChunks[:]) 426 } 427 428 429 // getChunkProcessor{{ $.UPointName }} decides, depending on c window size and statistics for the chunk 430 // to return the best algorithm to process the chunk. 431 func getChunkProcessor{{ $.UPointName }}(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- {{ $.TJacobianExtended }}, c uint64, points []{{ $.TAffine }}, digits []uint16, sem chan struct{}) { 432 switch c { 433 {{- range $c := $.LastCRange}} 434 case {{$c}}: 435 return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] 436 {{- end }} 437 {{range $c := $.CRange}} 438 case {{$c}}: 439 {{- if le $c 9}} 440 return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] 441 {{- else}} 442 const batchSize = {{batchSize $c}} 443 // here we could check some chunk statistic (deviation, ...) to determine if calling 444 // the batch affine version is worth it. 445 if stat.nbBucketFilled < batchSize { 446 // clear indicator that batch affine method is not appropriate here. 447 return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] 448 } 449 return processChunk{{ $.UPointName }}BatchAffine[bucket{{ $.TJacobianExtended }}C{{$c}}, bucket{{ $.TAffine }}C{{$c}}, bitSetC{{$c}}, p{{$.TAffine}}C{{$c}}, pp{{$.TAffine}}C{{$c}}, q{{$.TAffine}}C{{$c}}, c{{$.TAffine}}C{{$c}}] 450 {{- end}} 451 {{- end}} 452 default: 453 // panic("will not happen c != previous values is not generated by templates") 454 return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$.cmax}}] 455 } 456 } 457 458 459 // msmReduceChunk{{ $.TAffine }} reduces the weighted sum of the buckets into the result of the multiExp 460 func msmReduceChunk{{ $.TAffine }}(p *{{ $.TJacobian }}, c int, chChunks []chan {{ $.TJacobianExtended }}) *{{ $.TJacobian }} { 461 var _p {{ $.TJacobianExtended }} 462 totalj := <-chChunks[len(chChunks)-1] 463 _p.Set(&totalj) 464 for j := len(chChunks) - 2; j >= 0; j-- { 465 for l := 0; l < c; l++ { 466 _p.double(&_p) 467 } 468 totalj := <-chChunks[j] 469 _p.add(&totalj) 470 } 471 472 return p.unsafeFromJacExtended(&_p) 473 } 474 475 // Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] * 476 // combinationCoeff^i and stores the result in p. It returns error in case 477 // configuration is invalid. 478 func (p *{{ $.TAffine }}) Fold(points []{{ $.TAffine }}, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*{{ $.TAffine }}, error) { 479 var _p {{ $.TJacobian }} 480 if _, err := _p.Fold(points, combinationCoeff, config); err != nil { 481 return nil, err 482 } 483 p.FromJacobian(&_p) 484 return p, nil 485 } 486 487 // Fold computes the multi-exponentiation \sum_{i=0}^{len(points)-1} points[i] * 488 // combinationCoeff^i and stores the result in p. It returns error in case 489 // configuration is invalid. 490 func (p *{{$.TJacobian}}) Fold(points []{{ $.TAffine }}, combinationCoeff fr.Element, config ecc.MultiExpConfig) (*{{ $.TJacobian }}, error) { 491 scalars := make([]fr.Element, len(points)) 492 scalar := fr.NewElement(1) 493 for i := 0; i < len(points); i++ { 494 scalars[i].Set(&scalar) 495 scalar.Mul(&scalar, &combinationCoeff) 496 } 497 return p.MultiExp(points, scalars, config) 498 } 499 500 501 502 {{end }}