github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/fr/fft/fft.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package fft
    18  
    19  import (
    20  	"github.com/consensys/gnark-crypto/ecc"
    21  	"github.com/consensys/gnark-crypto/internal/parallel"
    22  	"math/big"
    23  	"math/bits"
    24  
    25  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    26  )
    27  
    28  // Decimation is used in the FFT call to select decimation in time or in frequency
    29  type Decimation uint8
    30  
    31  const (
    32  	DIT Decimation = iota
    33  	DIF
    34  )
    35  
    36  // parallelize threshold for a single butterfly op, if the fft stage is not parallelized already
    37  const butterflyThreshold = 16
    38  
    39  // FFT computes (recursively) the discrete Fourier transform of a and stores the result in a
    40  // if decimation == DIT (decimation in time), the input must be in bit-reversed order
    41  // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
    42  func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) {
    43  
    44  	opt := fftOptions(opts...)
    45  
    46  	// find the stage where we should stop spawning go routines in our recursive calls
    47  	// (ie when we have as many go routines running as we have available CPUs)
    48  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks)))
    49  	if opt.nbTasks == 1 {
    50  		maxSplits = -1
    51  	}
    52  
    53  	// if coset != 0, scale by coset table
    54  	if opt.coset {
    55  		if decimation == DIT {
    56  			// scale by coset table (in bit reversed order)
    57  			cosetTable := domain.cosetTable
    58  			if !domain.withPrecompute {
    59  				// we need to build the full table or do a bit reverse dance.
    60  				cosetTable = make([]fr.Element, len(a))
    61  				BuildExpTable(domain.FrMultiplicativeGen, cosetTable)
    62  			}
    63  			parallel.Execute(len(a), func(start, end int) {
    64  				n := uint64(len(a))
    65  				nn := uint64(64 - bits.TrailingZeros64(n))
    66  				for i := start; i < end; i++ {
    67  					irev := int(bits.Reverse64(uint64(i)) >> nn)
    68  					a[i].Mul(&a[i], &cosetTable[irev])
    69  				}
    70  			}, opt.nbTasks)
    71  		} else {
    72  			if domain.withPrecompute {
    73  				parallel.Execute(len(a), func(start, end int) {
    74  					for i := start; i < end; i++ {
    75  						a[i].Mul(&a[i], &domain.cosetTable[i])
    76  					}
    77  				}, opt.nbTasks)
    78  			} else {
    79  				c := domain.FrMultiplicativeGen
    80  				parallel.Execute(len(a), func(start, end int) {
    81  					var at fr.Element
    82  					at.Exp(c, big.NewInt(int64(start)))
    83  					for i := start; i < end; i++ {
    84  						a[i].Mul(&a[i], &at)
    85  						at.Mul(&at, &c)
    86  					}
    87  				}, opt.nbTasks)
    88  			}
    89  
    90  		}
    91  	}
    92  
    93  	twiddles := domain.twiddles
    94  	twiddlesStartStage := 0
    95  	if !domain.withPrecompute {
    96  		twiddlesStartStage = 3
    97  		nbStages := int(bits.TrailingZeros64(domain.Cardinality))
    98  		if nbStages-twiddlesStartStage > 0 {
    99  			twiddles = make([][]fr.Element, nbStages-twiddlesStartStage)
   100  			w := domain.Generator
   101  			w.Exp(w, big.NewInt(int64(1<<twiddlesStartStage)))
   102  			buildTwiddles(twiddles, w, uint64(nbStages-twiddlesStartStage))
   103  		} // else, we don't need twiddles
   104  	}
   105  
   106  	switch decimation {
   107  	case DIF:
   108  		difFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   109  	case DIT:
   110  		ditFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   111  	default:
   112  		panic("not implemented")
   113  	}
   114  }
   115  
   116  // FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a
   117  // if decimation == DIT (decimation in time), the input must be in bit-reversed order
   118  // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
   119  // coset sets the shift of the fft (0 = no shift, standard fft)
   120  // len(a) must be a power of 2, and w must be a len(a)th root of unity in field F.
   121  func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ...Option) {
   122  	opt := fftOptions(opts...)
   123  
   124  	// find the stage where we should stop spawning go routines in our recursive calls
   125  	// (ie when we have as many go routines running as we have available CPUs)
   126  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks)))
   127  	if opt.nbTasks == 1 {
   128  		maxSplits = -1
   129  	}
   130  
   131  	twiddlesInv := domain.twiddlesInv
   132  	twiddlesStartStage := 0
   133  	if !domain.withPrecompute {
   134  		twiddlesStartStage = 3
   135  		nbStages := int(bits.TrailingZeros64(domain.Cardinality))
   136  		if nbStages-twiddlesStartStage > 0 {
   137  			twiddlesInv = make([][]fr.Element, nbStages-twiddlesStartStage)
   138  			w := domain.GeneratorInv
   139  			w.Exp(w, big.NewInt(int64(1<<twiddlesStartStage)))
   140  			buildTwiddles(twiddlesInv, w, uint64(nbStages-twiddlesStartStage))
   141  		} // else, we don't need twiddles
   142  	}
   143  
   144  	switch decimation {
   145  	case DIF:
   146  		difFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   147  	case DIT:
   148  		ditFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   149  	default:
   150  		panic("not implemented")
   151  	}
   152  
   153  	// scale by CardinalityInv
   154  	if !opt.coset {
   155  		parallel.Execute(len(a), func(start, end int) {
   156  			for i := start; i < end; i++ {
   157  				a[i].Mul(&a[i], &domain.CardinalityInv)
   158  			}
   159  		}, opt.nbTasks)
   160  		return
   161  	}
   162  
   163  	if decimation == DIT {
   164  		if domain.withPrecompute {
   165  			parallel.Execute(len(a), func(start, end int) {
   166  				for i := start; i < end; i++ {
   167  					a[i].Mul(&a[i], &domain.cosetTableInv[i]).
   168  						Mul(&a[i], &domain.CardinalityInv)
   169  				}
   170  			}, opt.nbTasks)
   171  		} else {
   172  			c := domain.FrMultiplicativeGenInv
   173  			parallel.Execute(len(a), func(start, end int) {
   174  				var at fr.Element
   175  				at.Exp(c, big.NewInt(int64(start)))
   176  				at.Mul(&at, &domain.CardinalityInv)
   177  				for i := start; i < end; i++ {
   178  					a[i].Mul(&a[i], &at)
   179  					at.Mul(&at, &c)
   180  				}
   181  			}, opt.nbTasks)
   182  		}
   183  		return
   184  	}
   185  
   186  	// decimation == DIF, need to access coset table in bit reversed order.
   187  	cosetTableInv := domain.cosetTableInv
   188  	if !domain.withPrecompute {
   189  		// we need to build the full table or do a bit reverse dance.
   190  		cosetTableInv = make([]fr.Element, len(a))
   191  		BuildExpTable(domain.FrMultiplicativeGenInv, cosetTableInv)
   192  	}
   193  	parallel.Execute(len(a), func(start, end int) {
   194  		n := uint64(len(a))
   195  		nn := uint64(64 - bits.TrailingZeros64(n))
   196  		for i := start; i < end; i++ {
   197  			irev := int(bits.Reverse64(uint64(i)) >> nn)
   198  			a[i].Mul(&a[i], &cosetTableInv[irev]).
   199  				Mul(&a[i], &domain.CardinalityInv)
   200  		}
   201  	}, opt.nbTasks)
   202  
   203  }
   204  
   205  func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) {
   206  	if chDone != nil {
   207  		defer close(chDone)
   208  	}
   209  
   210  	n := len(a)
   211  	if n == 1 {
   212  		return
   213  	} else if n == 256 && stage >= twiddlesStartStage {
   214  		kerDIFNP_256(a, twiddles, stage-twiddlesStartStage)
   215  		return
   216  	}
   217  	m := n >> 1
   218  
   219  	parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)
   220  
   221  	if stage < twiddlesStartStage {
   222  		if parallelButterfly {
   223  			w := w
   224  			parallel.Execute(m, func(start, end int) {
   225  				if start == 0 {
   226  					fr.Butterfly(&a[0], &a[m])
   227  					start++
   228  				}
   229  				var at fr.Element
   230  				at.Exp(w, big.NewInt(int64(start)))
   231  				innerDIFWithoutTwiddles(a, at, w, start, end, m)
   232  			}, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs
   233  		} else {
   234  			innerDIFWithoutTwiddles(a, w, w, 0, m, m)
   235  		}
   236  		// compute next twiddle
   237  		w.Square(&w)
   238  	} else {
   239  		if parallelButterfly {
   240  			parallel.Execute(m, func(start, end int) {
   241  				innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m)
   242  			}, nbTasks/(1<<(stage)))
   243  		} else {
   244  			innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m)
   245  		}
   246  	}
   247  
   248  	if m == 1 {
   249  		return
   250  	}
   251  
   252  	nextStage := stage + 1
   253  	if stage < maxSplits {
   254  		chDone := make(chan struct{}, 1)
   255  		go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks)
   256  		difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   257  		<-chDone
   258  	} else {
   259  		difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   260  		difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   261  	}
   262  
   263  }
   264  
   265  func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) {
   266  	if start == 0 {
   267  		fr.Butterfly(&a[0], &a[m])
   268  		start++
   269  	}
   270  	for i := start; i < end; i++ {
   271  		fr.Butterfly(&a[i], &a[i+m])
   272  		a[i+m].Mul(&a[i+m], &twiddles[i])
   273  	}
   274  }
   275  
   276  func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) {
   277  	if start == 0 {
   278  		fr.Butterfly(&a[0], &a[m])
   279  		start++
   280  	}
   281  	for i := start; i < end; i++ {
   282  		fr.Butterfly(&a[i], &a[i+m])
   283  		a[i+m].Mul(&a[i+m], &at)
   284  		at.Mul(&at, &w)
   285  	}
   286  }
   287  
   288  func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) {
   289  	if chDone != nil {
   290  		defer close(chDone)
   291  	}
   292  	n := len(a)
   293  	if n == 1 {
   294  		return
   295  	} else if n == 256 && stage >= twiddlesStartStage {
   296  		kerDITNP_256(a, twiddles, stage-twiddlesStartStage)
   297  		return
   298  	}
   299  	m := n >> 1
   300  
   301  	nextStage := stage + 1
   302  	nextW := w
   303  	nextW.Square(&nextW)
   304  
   305  	if stage < maxSplits {
   306  		// that's the only time we fire go routines
   307  		chDone := make(chan struct{}, 1)
   308  		go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks)
   309  		ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   310  		<-chDone
   311  	} else {
   312  		ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   313  		ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   314  	}
   315  
   316  	parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)
   317  
   318  	if stage < twiddlesStartStage {
   319  		// we need to compute the twiddles for this stage on the fly.
   320  		if parallelButterfly {
   321  			w := w
   322  			parallel.Execute(m, func(start, end int) {
   323  				if start == 0 {
   324  					fr.Butterfly(&a[0], &a[m])
   325  					start++
   326  				}
   327  				var at fr.Element
   328  				at.Exp(w, big.NewInt(int64(start)))
   329  				innerDITWithoutTwiddles(a, at, w, start, end, m)
   330  			}, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs
   331  
   332  		} else {
   333  			innerDITWithoutTwiddles(a, w, w, 0, m, m)
   334  		}
   335  		return
   336  	}
   337  	if parallelButterfly {
   338  		parallel.Execute(m, func(start, end int) {
   339  			innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m)
   340  		}, nbTasks/(1<<(stage)))
   341  	} else {
   342  		innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m)
   343  	}
   344  }
   345  
   346  func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) {
   347  	if start == 0 {
   348  		fr.Butterfly(&a[0], &a[m])
   349  		start++
   350  	}
   351  	for i := start; i < end; i++ {
   352  		a[i+m].Mul(&a[i+m], &twiddles[i])
   353  		fr.Butterfly(&a[i], &a[i+m])
   354  	}
   355  }
   356  
   357  func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) {
   358  	if start == 0 {
   359  		fr.Butterfly(&a[0], &a[m])
   360  		start++
   361  	}
   362  	for i := start; i < end; i++ {
   363  		a[i+m].Mul(&a[i+m], &at)
   364  		fr.Butterfly(&a[i], &a[i+m])
   365  		at.Mul(&at, &w)
   366  	}
   367  }
   368  
   369  func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) {
   370  	// code unrolled & generated by internal/generator/fft/template/fft.go.tmpl
   371  
   372  	innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128)
   373  	for offset := 0; offset < 256; offset += 128 {
   374  		innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64)
   375  	}
   376  	for offset := 0; offset < 256; offset += 64 {
   377  		innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32)
   378  	}
   379  	for offset := 0; offset < 256; offset += 32 {
   380  		innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16)
   381  	}
   382  	for offset := 0; offset < 256; offset += 16 {
   383  		innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8)
   384  	}
   385  	for offset := 0; offset < 256; offset += 8 {
   386  		innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4)
   387  	}
   388  	for offset := 0; offset < 256; offset += 4 {
   389  		innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2)
   390  	}
   391  	for offset := 0; offset < 256; offset += 2 {
   392  		fr.Butterfly(&a[offset], &a[offset+1])
   393  	}
   394  }
   395  
   396  func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) {
   397  	// code unrolled & generated by internal/generator/fft/template/fft.go.tmpl
   398  
   399  	for offset := 0; offset < 256; offset += 2 {
   400  		fr.Butterfly(&a[offset], &a[offset+1])
   401  	}
   402  	for offset := 0; offset < 256; offset += 4 {
   403  		innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2)
   404  	}
   405  	for offset := 0; offset < 256; offset += 8 {
   406  		innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4)
   407  	}
   408  	for offset := 0; offset < 256; offset += 16 {
   409  		innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8)
   410  	}
   411  	for offset := 0; offset < 256; offset += 32 {
   412  		innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16)
   413  	}
   414  	for offset := 0; offset < 256; offset += 64 {
   415  		innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32)
   416  	}
   417  	for offset := 0; offset < 256; offset += 128 {
   418  		innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64)
   419  	}
   420  	innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128)
   421  }