github.com/consensys/gnark-crypto@v0.14.0/internal/generator/kzg/template/utils.go.tmpl (about) 1 import ( 2 "math/big" 3 "math/bits" 4 "runtime" 5 "fmt" 6 7 "github.com/consensys/gnark-crypto/ecc" 8 curve "github.com/consensys/gnark-crypto/ecc/{{ .Name }}" 9 "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" 10 "github.com/consensys/gnark-crypto/internal/parallel" 11 ) 12 13 // ToLagrangeG1 in place transform of coeffs canonical form into Lagrange form. 14 // From the formula Lᵢ(τ) = 1/n∑_{j<n}(τ/ωⁱ)ʲ we 15 // see that [L₁(τ),..,Lₙ(τ)] = FFT_inv(∑_{j<n}τʲXʲ), so it suffices to apply the inverse 16 // fft on the vector consisting of the original SRS. 17 // Size of coeffs must be a power of 2. 18 func ToLagrangeG1(coeffs []curve.G1Affine) ([]curve.G1Affine, error) { 19 if bits.OnesCount64(uint64(len(coeffs))) != 1 { 20 return nil, fmt.Errorf("len(coeffs) must be a power of 2") 21 } 22 size := len(coeffs) 23 24 numCPU := uint64(runtime.NumCPU()) 25 maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) << 1 26 27 twiddlesInv, err := computeTwiddlesInv(size) 28 if err != nil { 29 return nil, err 30 } 31 32 // batch convert to Jacobian 33 jCoeffs := make([]curve.G1Jac, len(coeffs)) 34 for i := 0; i < len(coeffs); i++ { 35 jCoeffs[i].FromAffine(&coeffs[i]) 36 } 37 38 difFFTG1(jCoeffs, twiddlesInv, 0, maxSplits, nil) 39 40 // TODO @gbotrel generify the cobra bitreverse function, benchmark it and use it everywhere 41 bitReverse(jCoeffs) 42 43 var invBigint big.Int 44 var frCardinality fr.Element 45 frCardinality.SetUint64(uint64(size)) 46 frCardinality.Inverse(&frCardinality) 47 frCardinality.BigInt(&invBigint) 48 49 parallel.Execute(size, func(start, end int) { 50 for i := start; i < end; i++ { 51 jCoeffs[i].ScalarMultiplication(&jCoeffs[i], &invBigint) 52 } 53 }) 54 55 // batch convert to affine 56 return curve.BatchJacobianToAffineG1(jCoeffs), nil 57 } 58 59 func computeTwiddlesInv(cardinality int) ([]*big.Int, error) { 60 generator, err := fr.Generator(uint64(cardinality)) 61 if err != nil { 62 return nil, err 63 } 64 65 // inverse the generator 66 generator.Inverse(&generator) 67 68 // nb fft stages 69 nbStages := uint64(bits.TrailingZeros64(uint64(cardinality))) 70 71 r := make([]*big.Int, 1+(1<<(nbStages-1))) 72 73 w := generator 74 r[0] = new(big.Int).SetUint64(1) 75 if len(r) == 1 { 76 return r, nil 77 } 78 r[1] = new(big.Int) 79 w.BigInt(r[1]) 80 for j := 2; j < len(r); j++ { 81 w.Mul(&w, &generator) 82 r[j] = new(big.Int) 83 w.BigInt(r[j]) 84 } 85 86 return r, nil 87 } 88 89 func bitReverse[T any](a []T) { 90 n := uint64(len(a)) 91 nn := uint64(64 - bits.TrailingZeros64(n)) 92 93 for i := uint64(0); i < n; i++ { 94 irev := bits.Reverse64(i) >> nn 95 if irev > i { 96 a[i], a[irev] = a[irev], a[i] 97 } 98 } 99 } 100 101 func butterflyG1(a *curve.G1Jac, b *curve.G1Jac) { 102 t := *a 103 a.AddAssign(b) 104 t.SubAssign(b) 105 b.Set(&t) 106 } 107 108 func difFFTG1(a []curve.G1Jac, twiddles []*big.Int, stage, maxSplits int, chDone chan struct{}) { 109 if chDone != nil { 110 defer close(chDone) 111 } 112 113 n := len(a) 114 if n == 1 { 115 return 116 } 117 m := n >> 1 118 119 butterflyG1(&a[0], &a[m]) 120 // stage determines the stride 121 // if stage == 0, then we use 1, w, w**2, w**3, w**4, w**5, w**6, ... 122 // if stage == 1, then we use 1, w**2, w**4, w**6, ... that is, indexes 0, 2, 4, 6, ... of stage 0 123 // if stage == 2, then we use 1, w**4, w**8, w**12, ... that is indexes 0, 4, 8, 12, ... of stage 0 124 stride := 1 << stage 125 126 const butterflyThreshold = 8 127 if m >= butterflyThreshold { 128 // 1 << stage == estimated used CPUs 129 numCPU := runtime.NumCPU() / (1 << (stage)) 130 parallel.Execute(m, func(start, end int) { 131 if start == 0 { 132 start = 1 133 } 134 j := start * stride 135 for i := start; i < end; i++ { 136 butterflyG1(&a[i], &a[i+m]) 137 a[i+m].ScalarMultiplication(&a[i+m], twiddles[j]) 138 j += stride 139 } 140 }, numCPU) 141 } else { 142 j := stride 143 for i := 1; i < m; i++ { 144 butterflyG1(&a[i], &a[i+m]) 145 a[i+m].ScalarMultiplication(&a[i+m], twiddles[j]) 146 j += stride 147 } 148 } 149 150 if m == 1 { 151 return 152 } 153 154 nextStage := stage + 1 155 if stage < maxSplits { 156 chDone := make(chan struct{}, 1) 157 go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone) 158 difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) 159 <-chDone 160 } else { 161 difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil) 162 difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil) 163 } 164 }