github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/fft.go.tmpl (about) 1 import ( 2 "math/bits" 3 "github.com/consensys/gnark-crypto/ecc" 4 "github.com/consensys/gnark-crypto/internal/parallel" 5 "math/big" 6 {{ template "import_fr" . }} 7 ) 8 9 {{- /* these params set the size of the kernel we generate & unroll */}} 10 {{ $sizeKernelLog2 := 8}} 11 {{ $sizeKernel := shl 1 $sizeKernelLog2}} 12 13 // Decimation is used in the FFT call to select decimation in time or in frequency 14 type Decimation uint8 15 16 const ( 17 DIT Decimation = iota 18 DIF 19 ) 20 21 // parallelize threshold for a single butterfly op, if the fft stage is not parallelized already 22 const butterflyThreshold = 16 23 24 // FFT computes (recursively) the discrete Fourier transform of a and stores the result in a 25 // if decimation == DIT (decimation in time), the input must be in bit-reversed order 26 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 27 func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { 28 29 opt := fftOptions(opts...) 30 31 // find the stage where we should stop spawning go routines in our recursive calls 32 // (ie when we have as many go routines running as we have available CPUs) 33 maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) 34 if opt.nbTasks == 1 { 35 maxSplits = -1 36 } 37 38 // if coset != 0, scale by coset table 39 if opt.coset { 40 if decimation == DIT { 41 // scale by coset table (in bit reversed order) 42 cosetTable := domain.cosetTable 43 if !domain.withPrecompute { 44 // we need to build the full table or do a bit reverse dance. 45 cosetTable = make([]fr.Element, len(a)) 46 BuildExpTable(domain.FrMultiplicativeGen, cosetTable) 47 } 48 parallel.Execute(len(a), func(start, end int) { 49 n := uint64(len(a)) 50 nn := uint64(64 - bits.TrailingZeros64(n)) 51 for i := start; i < end; i++ { 52 irev := int(bits.Reverse64(uint64(i)) >> nn) 53 a[i].Mul(&a[i], &cosetTable[irev]) 54 } 55 }, opt.nbTasks) 56 } else { 57 if domain.withPrecompute { 58 parallel.Execute(len(a), func(start, end int) { 59 for i := start; i < end; i++ { 60 a[i].Mul(&a[i], &domain.cosetTable[i]) 61 } 62 }, opt.nbTasks) 63 } else { 64 c := domain.FrMultiplicativeGen 65 parallel.Execute(len(a), func(start, end int) { 66 var at fr.Element 67 at.Exp(c, big.NewInt(int64(start))) 68 for i := start; i < end; i++ { 69 a[i].Mul(&a[i], &at) 70 at.Mul(&at, &c) 71 } 72 }, opt.nbTasks) 73 } 74 75 } 76 } 77 78 twiddles := domain.twiddles 79 twiddlesStartStage := 0 80 if !domain.withPrecompute { 81 twiddlesStartStage = 3 82 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) 83 if nbStages - twiddlesStartStage > 0 { 84 twiddles = make([][]fr.Element, nbStages - twiddlesStartStage) 85 w := domain.Generator 86 w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) 87 buildTwiddles(twiddles, w, uint64(nbStages - twiddlesStartStage)) 88 } // else, we don't need twiddles 89 } 90 91 switch decimation { 92 case DIF: 93 difFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 94 case DIT: 95 ditFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 96 default: 97 panic("not implemented") 98 } 99 } 100 101 102 103 // FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a 104 // if decimation == DIT (decimation in time), the input must be in bit-reversed order 105 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 106 // coset sets the shift of the fft (0 = no shift, standard fft) 107 // len(a) must be a power of 2, and w must be a len(a)th root of unity in field F. 108 func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ...Option) { 109 opt := fftOptions(opts...) 110 111 // find the stage where we should stop spawning go routines in our recursive calls 112 // (ie when we have as many go routines running as we have available CPUs) 113 maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) 114 if opt.nbTasks == 1 { 115 maxSplits = -1 116 } 117 118 twiddlesInv := domain.twiddlesInv 119 twiddlesStartStage := 0 120 if !domain.withPrecompute { 121 twiddlesStartStage = 3 122 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) 123 if nbStages - twiddlesStartStage > 0 { 124 twiddlesInv = make([][]fr.Element, nbStages - twiddlesStartStage) 125 w := domain.GeneratorInv 126 w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) 127 buildTwiddles(twiddlesInv, w, uint64(nbStages - twiddlesStartStage)) 128 } // else, we don't need twiddles 129 } 130 131 switch decimation { 132 case DIF: 133 difFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 134 case DIT: 135 ditFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 136 default: 137 panic("not implemented") 138 } 139 140 // scale by CardinalityInv 141 if !opt.coset { 142 parallel.Execute(len(a), func(start, end int) { 143 for i := start; i < end; i++ { 144 a[i].Mul(&a[i], &domain.CardinalityInv) 145 } 146 }, opt.nbTasks) 147 return 148 } 149 150 151 if decimation == DIT { 152 if domain.withPrecompute { 153 parallel.Execute(len(a), func(start, end int) { 154 for i := start; i < end; i++ { 155 a[i].Mul(&a[i], &domain.cosetTableInv[i]). 156 Mul(&a[i], &domain.CardinalityInv) 157 } 158 }, opt.nbTasks) 159 } else { 160 c := domain.FrMultiplicativeGenInv 161 parallel.Execute(len(a), func(start, end int) { 162 var at fr.Element 163 at.Exp(c, big.NewInt(int64(start))) 164 at.Mul(&at, &domain.CardinalityInv) 165 for i := start; i < end; i++ { 166 a[i].Mul(&a[i], &at) 167 at.Mul(&at, &c) 168 } 169 }, opt.nbTasks) 170 } 171 return 172 } 173 174 // decimation == DIF, need to access coset table in bit reversed order. 175 cosetTableInv := domain.cosetTableInv 176 if !domain.withPrecompute { 177 // we need to build the full table or do a bit reverse dance. 178 cosetTableInv = make([]fr.Element, len(a)) 179 BuildExpTable(domain.FrMultiplicativeGenInv, cosetTableInv) 180 } 181 parallel.Execute(len(a), func(start, end int) { 182 n := uint64(len(a)) 183 nn := uint64(64 - bits.TrailingZeros64(n)) 184 for i := start; i < end; i++ { 185 irev := int(bits.Reverse64(uint64(i)) >> nn) 186 a[i].Mul(&a[i], &cosetTableInv[irev]). 187 Mul(&a[i], &domain.CardinalityInv) 188 } 189 }, opt.nbTasks) 190 191 } 192 193 func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { 194 if chDone != nil { 195 defer close(chDone) 196 } 197 198 n := len(a) 199 if n == 1 { 200 return 201 } else if n == {{$sizeKernel}} && stage >= twiddlesStartStage { 202 kerDIFNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage) 203 return 204 } 205 m := n >> 1 206 207 parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) 208 209 if stage < twiddlesStartStage { 210 if parallelButterfly { 211 w := w 212 parallel.Execute(m, func(start, end int) { 213 if start == 0 { 214 fr.Butterfly(&a[0], &a[m]) 215 start++ 216 } 217 var at fr.Element 218 at.Exp(w, big.NewInt(int64(start))) 219 innerDIFWithoutTwiddles(a, at,w, start, end, m) 220 }, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs 221 } else { 222 innerDIFWithoutTwiddles(a, w, w, 0, m, m) 223 } 224 // compute next twiddle 225 w.Square(&w) 226 } else { 227 if parallelButterfly { 228 parallel.Execute(m, func(start, end int) { 229 innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) 230 }, nbTasks / (1 << (stage))) 231 } else { 232 innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) 233 } 234 } 235 236 if m == 1 { 237 return 238 } 239 240 nextStage := stage + 1 241 if stage < maxSplits { 242 chDone := make(chan struct{}, 1) 243 go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) 244 difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 245 <-chDone 246 } else { 247 difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 248 difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 249 } 250 251 } 252 253 254 func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { 255 if start == 0 { 256 fr.Butterfly(&a[0], &a[m]) 257 start++ 258 } 259 for i := start; i < end; i++ { 260 fr.Butterfly(&a[i], &a[i+m]) 261 a[i+m].Mul(&a[i+m], &twiddles[i]) 262 } 263 } 264 265 func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { 266 if start == 0 { 267 fr.Butterfly(&a[0], &a[m]) 268 start++ 269 } 270 for i := start; i < end; i++ { 271 fr.Butterfly(&a[i], &a[i+m]) 272 a[i+m].Mul(&a[i+m], &at) 273 at.Mul(&at, &w) 274 } 275 } 276 277 278 func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { 279 if chDone != nil { 280 defer close(chDone) 281 } 282 n := len(a) 283 if n == 1 { 284 return 285 } else if n == {{$sizeKernel}} && stage >= twiddlesStartStage { 286 kerDITNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage) 287 return 288 } 289 m := n >> 1 290 291 nextStage := stage + 1 292 nextW := w 293 nextW.Square(&nextW) 294 295 if stage < maxSplits { 296 // that's the only time we fire go routines 297 chDone := make(chan struct{}, 1) 298 go ditFFT(a[m:],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) 299 ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 300 <-chDone 301 } else { 302 ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 303 ditFFT(a[m:n],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 304 } 305 306 parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) 307 308 if stage < twiddlesStartStage { 309 // we need to compute the twiddles for this stage on the fly. 310 if parallelButterfly { 311 w := w 312 parallel.Execute(m, func(start, end int) { 313 if start == 0 { 314 fr.Butterfly(&a[0], &a[m]) 315 start++ 316 } 317 var at fr.Element 318 at.Exp(w, big.NewInt(int64(start))) 319 innerDITWithoutTwiddles(a, at,w, start, end, m) 320 }, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs 321 322 } else { 323 innerDITWithoutTwiddles(a, w,w, 0, m, m) 324 } 325 return 326 } 327 if parallelButterfly { 328 parallel.Execute(m, func(start, end int) { 329 innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) 330 }, nbTasks / (1 << (stage))) 331 } else { 332 innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) 333 } 334 } 335 336 337 func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { 338 if start == 0 { 339 fr.Butterfly(&a[0], &a[m]) 340 start++ 341 } 342 for i := start; i < end; i++ { 343 a[i+m].Mul(&a[i+m], &twiddles[i]) 344 fr.Butterfly(&a[i], &a[i+m]) 345 } 346 } 347 348 func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { 349 if start == 0 { 350 fr.Butterfly(&a[0], &a[m]) 351 start++ 352 } 353 for i := start; i < end; i++ { 354 a[i+m].Mul(&a[i+m], &at) 355 fr.Butterfly(&a[i], &a[i+m]) 356 at.Mul(&at, &w) 357 } 358 } 359 360 361 362 func kerDIFNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) { 363 // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl 364 365 {{ $n := shl 1 $sizeKernelLog2}} 366 {{ $m := div $n 2}} 367 {{ $split := 1}} 368 {{- range $step := iterate 0 $sizeKernelLog2}} 369 {{- $offset := 0}} 370 371 {{- $bound := mul $split $n}} 372 {{- if eq $bound $n}} 373 innerDIFWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) 374 {{- else}} 375 for offset := 0; offset < {{$bound}}; offset += {{$n}} { 376 {{- if eq $m 1}} 377 fr.Butterfly(&a[offset], &a[offset+1]) 378 {{- else}} 379 innerDIFWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) 380 {{- end}} 381 } 382 {{- end}} 383 384 {{- $n = div $n 2}} 385 {{- $m = div $n 2}} 386 {{- $split = mul $split 2}} 387 {{- end}} 388 } 389 390 391 func kerDITNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) { 392 // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl 393 394 {{ $n := 2}} 395 {{ $m := div $n 2}} 396 {{ $split := div (shl 1 $sizeKernelLog2) 2}} 397 {{- range $step := reverse (iterate 0 $sizeKernelLog2)}} 398 {{- $offset := 0}} 399 400 {{- $bound := mul $split $n}} 401 {{- if eq $bound $n}} 402 innerDITWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) 403 {{- else}} 404 for offset := 0; offset < {{$bound}}; offset += {{$n}} { 405 {{- if eq $m 1}} 406 fr.Butterfly(&a[offset], &a[offset+1]) 407 {{- else}} 408 innerDITWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) 409 {{- end}} 410 } 411 {{- end}} 412 413 {{- $n = mul $n 2}} 414 {{- $m = div $n 2}} 415 {{- $split = div $split 2}} 416 {{- end}} 417 }