github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/fr/fft/fft.go (about) 1 // Copyright 2020 Consensys Software Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Code generated by consensys/gnark-crypto DO NOT EDIT 16 17 package fft 18 19 import ( 20 "github.com/consensys/gnark-crypto/ecc" 21 "github.com/consensys/gnark-crypto/internal/parallel" 22 "math/big" 23 "math/bits" 24 25 "github.com/consensys/gnark-crypto/ecc/bn254/fr" 26 ) 27 28 // Decimation is used in the FFT call to select decimation in time or in frequency 29 type Decimation uint8 30 31 const ( 32 DIT Decimation = iota 33 DIF 34 ) 35 36 // parallelize threshold for a single butterfly op, if the fft stage is not parallelized already 37 const butterflyThreshold = 16 38 39 // FFT computes (recursively) the discrete Fourier transform of a and stores the result in a 40 // if decimation == DIT (decimation in time), the input must be in bit-reversed order 41 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 42 func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { 43 44 opt := fftOptions(opts...) 45 46 // find the stage where we should stop spawning go routines in our recursive calls 47 // (ie when we have as many go routines running as we have available CPUs) 48 maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) 49 if opt.nbTasks == 1 { 50 maxSplits = -1 51 } 52 53 // if coset != 0, scale by coset table 54 if opt.coset { 55 if decimation == DIT { 56 // scale by coset table (in bit reversed order) 57 cosetTable := domain.cosetTable 58 if !domain.withPrecompute { 59 // we need to build the full table or do a bit reverse dance. 60 cosetTable = make([]fr.Element, len(a)) 61 BuildExpTable(domain.FrMultiplicativeGen, cosetTable) 62 } 63 parallel.Execute(len(a), func(start, end int) { 64 n := uint64(len(a)) 65 nn := uint64(64 - bits.TrailingZeros64(n)) 66 for i := start; i < end; i++ { 67 irev := int(bits.Reverse64(uint64(i)) >> nn) 68 a[i].Mul(&a[i], &cosetTable[irev]) 69 } 70 }, opt.nbTasks) 71 } else { 72 if domain.withPrecompute { 73 parallel.Execute(len(a), func(start, end int) { 74 for i := start; i < end; i++ { 75 a[i].Mul(&a[i], &domain.cosetTable[i]) 76 } 77 }, opt.nbTasks) 78 } else { 79 c := domain.FrMultiplicativeGen 80 parallel.Execute(len(a), func(start, end int) { 81 var at fr.Element 82 at.Exp(c, big.NewInt(int64(start))) 83 for i := start; i < end; i++ { 84 a[i].Mul(&a[i], &at) 85 at.Mul(&at, &c) 86 } 87 }, opt.nbTasks) 88 } 89 90 } 91 } 92 93 twiddles := domain.twiddles 94 twiddlesStartStage := 0 95 if !domain.withPrecompute { 96 twiddlesStartStage = 3 97 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) 98 if nbStages-twiddlesStartStage > 0 { 99 twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) 100 w := domain.Generator 101 w.Exp(w, big.NewInt(int64(1<<twiddlesStartStage))) 102 buildTwiddles(twiddles, w, uint64(nbStages-twiddlesStartStage)) 103 } // else, we don't need twiddles 104 } 105 106 switch decimation { 107 case DIF: 108 difFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 109 case DIT: 110 ditFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 111 default: 112 panic("not implemented") 113 } 114 } 115 116 // FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a 117 // if decimation == DIT (decimation in time), the input must be in bit-reversed order 118 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order 119 // coset sets the shift of the fft (0 = no shift, standard fft) 120 // len(a) must be a power of 2, and w must be a len(a)th root of unity in field F. 121 func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ...Option) { 122 opt := fftOptions(opts...) 123 124 // find the stage where we should stop spawning go routines in our recursive calls 125 // (ie when we have as many go routines running as we have available CPUs) 126 maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) 127 if opt.nbTasks == 1 { 128 maxSplits = -1 129 } 130 131 twiddlesInv := domain.twiddlesInv 132 twiddlesStartStage := 0 133 if !domain.withPrecompute { 134 twiddlesStartStage = 3 135 nbStages := int(bits.TrailingZeros64(domain.Cardinality)) 136 if nbStages-twiddlesStartStage > 0 { 137 twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage) 138 w := domain.GeneratorInv 139 w.Exp(w, big.NewInt(int64(1<<twiddlesStartStage))) 140 buildTwiddles(twiddlesInv, w, uint64(nbStages-twiddlesStartStage)) 141 } // else, we don't need twiddles 142 } 143 144 switch decimation { 145 case DIF: 146 difFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 147 case DIT: 148 ditFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) 149 default: 150 panic("not implemented") 151 } 152 153 // scale by CardinalityInv 154 if !opt.coset { 155 parallel.Execute(len(a), func(start, end int) { 156 for i := start; i < end; i++ { 157 a[i].Mul(&a[i], &domain.CardinalityInv) 158 } 159 }, opt.nbTasks) 160 return 161 } 162 163 if decimation == DIT { 164 if domain.withPrecompute { 165 parallel.Execute(len(a), func(start, end int) { 166 for i := start; i < end; i++ { 167 a[i].Mul(&a[i], &domain.cosetTableInv[i]). 168 Mul(&a[i], &domain.CardinalityInv) 169 } 170 }, opt.nbTasks) 171 } else { 172 c := domain.FrMultiplicativeGenInv 173 parallel.Execute(len(a), func(start, end int) { 174 var at fr.Element 175 at.Exp(c, big.NewInt(int64(start))) 176 at.Mul(&at, &domain.CardinalityInv) 177 for i := start; i < end; i++ { 178 a[i].Mul(&a[i], &at) 179 at.Mul(&at, &c) 180 } 181 }, opt.nbTasks) 182 } 183 return 184 } 185 186 // decimation == DIF, need to access coset table in bit reversed order. 187 cosetTableInv := domain.cosetTableInv 188 if !domain.withPrecompute { 189 // we need to build the full table or do a bit reverse dance. 190 cosetTableInv = make([]fr.Element, len(a)) 191 BuildExpTable(domain.FrMultiplicativeGenInv, cosetTableInv) 192 } 193 parallel.Execute(len(a), func(start, end int) { 194 n := uint64(len(a)) 195 nn := uint64(64 - bits.TrailingZeros64(n)) 196 for i := start; i < end; i++ { 197 irev := int(bits.Reverse64(uint64(i)) >> nn) 198 a[i].Mul(&a[i], &cosetTableInv[irev]). 199 Mul(&a[i], &domain.CardinalityInv) 200 } 201 }, opt.nbTasks) 202 203 } 204 205 func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { 206 if chDone != nil { 207 defer close(chDone) 208 } 209 210 n := len(a) 211 if n == 1 { 212 return 213 } else if n == 256 && stage >= twiddlesStartStage { 214 kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) 215 return 216 } 217 m := n >> 1 218 219 parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) 220 221 if stage < twiddlesStartStage { 222 if parallelButterfly { 223 w := w 224 parallel.Execute(m, func(start, end int) { 225 if start == 0 { 226 fr.Butterfly(&a[0], &a[m]) 227 start++ 228 } 229 var at fr.Element 230 at.Exp(w, big.NewInt(int64(start))) 231 innerDIFWithoutTwiddles(a, at, w, start, end, m) 232 }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs 233 } else { 234 innerDIFWithoutTwiddles(a, w, w, 0, m, m) 235 } 236 // compute next twiddle 237 w.Square(&w) 238 } else { 239 if parallelButterfly { 240 parallel.Execute(m, func(start, end int) { 241 innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) 242 }, nbTasks/(1<<(stage))) 243 } else { 244 innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) 245 } 246 } 247 248 if m == 1 { 249 return 250 } 251 252 nextStage := stage + 1 253 if stage < maxSplits { 254 chDone := make(chan struct{}, 1) 255 go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) 256 difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 257 <-chDone 258 } else { 259 difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 260 difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 261 } 262 263 } 264 265 func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { 266 if start == 0 { 267 fr.Butterfly(&a[0], &a[m]) 268 start++ 269 } 270 for i := start; i < end; i++ { 271 fr.Butterfly(&a[i], &a[i+m]) 272 a[i+m].Mul(&a[i+m], &twiddles[i]) 273 } 274 } 275 276 func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { 277 if start == 0 { 278 fr.Butterfly(&a[0], &a[m]) 279 start++ 280 } 281 for i := start; i < end; i++ { 282 fr.Butterfly(&a[i], &a[i+m]) 283 a[i+m].Mul(&a[i+m], &at) 284 at.Mul(&at, &w) 285 } 286 } 287 288 func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { 289 if chDone != nil { 290 defer close(chDone) 291 } 292 n := len(a) 293 if n == 1 { 294 return 295 } else if n == 256 && stage >= twiddlesStartStage { 296 kerDITNP_256(a, twiddles, stage-twiddlesStartStage) 297 return 298 } 299 m := n >> 1 300 301 nextStage := stage + 1 302 nextW := w 303 nextW.Square(&nextW) 304 305 if stage < maxSplits { 306 // that's the only time we fire go routines 307 chDone := make(chan struct{}, 1) 308 go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) 309 ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 310 <-chDone 311 } else { 312 ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 313 ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) 314 } 315 316 parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) 317 318 if stage < twiddlesStartStage { 319 // we need to compute the twiddles for this stage on the fly. 320 if parallelButterfly { 321 w := w 322 parallel.Execute(m, func(start, end int) { 323 if start == 0 { 324 fr.Butterfly(&a[0], &a[m]) 325 start++ 326 } 327 var at fr.Element 328 at.Exp(w, big.NewInt(int64(start))) 329 innerDITWithoutTwiddles(a, at, w, start, end, m) 330 }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs 331 332 } else { 333 innerDITWithoutTwiddles(a, w, w, 0, m, m) 334 } 335 return 336 } 337 if parallelButterfly { 338 parallel.Execute(m, func(start, end int) { 339 innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) 340 }, nbTasks/(1<<(stage))) 341 } else { 342 innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) 343 } 344 } 345 346 func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { 347 if start == 0 { 348 fr.Butterfly(&a[0], &a[m]) 349 start++ 350 } 351 for i := start; i < end; i++ { 352 a[i+m].Mul(&a[i+m], &twiddles[i]) 353 fr.Butterfly(&a[i], &a[i+m]) 354 } 355 } 356 357 func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { 358 if start == 0 { 359 fr.Butterfly(&a[0], &a[m]) 360 start++ 361 } 362 for i := start; i < end; i++ { 363 a[i+m].Mul(&a[i+m], &at) 364 fr.Butterfly(&a[i], &a[i+m]) 365 at.Mul(&at, &w) 366 } 367 } 368 369 func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { 370 // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl 371 372 innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) 373 for offset := 0; offset < 256; offset += 128 { 374 innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) 375 } 376 for offset := 0; offset < 256; offset += 64 { 377 innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) 378 } 379 for offset := 0; offset < 256; offset += 32 { 380 innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) 381 } 382 for offset := 0; offset < 256; offset += 16 { 383 innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) 384 } 385 for offset := 0; offset < 256; offset += 8 { 386 innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) 387 } 388 for offset := 0; offset < 256; offset += 4 { 389 innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) 390 } 391 for offset := 0; offset < 256; offset += 2 { 392 fr.Butterfly(&a[offset], &a[offset+1]) 393 } 394 } 395 396 func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { 397 // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl 398 399 for offset := 0; offset < 256; offset += 2 { 400 fr.Butterfly(&a[offset], &a[offset+1]) 401 } 402 for offset := 0; offset < 256; offset += 4 { 403 innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) 404 } 405 for offset := 0; offset < 256; offset += 8 { 406 innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) 407 } 408 for offset := 0; offset < 256; offset += 16 { 409 innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) 410 } 411 for offset := 0; offset < 256; offset += 32 { 412 innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) 413 } 414 for offset := 0; offset < 256; offset += 64 { 415 innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) 416 } 417 for offset := 0; offset < 256; offset += 128 { 418 innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) 419 } 420 innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) 421 }