github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/fft.go.tmpl (about)

     1  import (
     2  	"math/bits"
     3  	"github.com/consensys/gnark-crypto/ecc"
     4  	"github.com/consensys/gnark-crypto/internal/parallel"
     5  	"math/big"
     6  	{{ template "import_fr" . }}
     7  )
     8  
     9  {{- /* these params set the size of the kernel we generate & unroll */}}
    10  {{ $sizeKernelLog2 := 8}}
    11  {{ $sizeKernel := shl 1 $sizeKernelLog2}}
    12  
// Decimation is used in the FFT call to select decimation in time or in frequency
type Decimation uint8

const (
	// DIT selects decimation in time: FFT and FFTInverse then expect
	// their input in bit-reversed order.
	DIT Decimation = iota
	// DIF selects decimation in frequency: FFT and FFTInverse then leave
	// their output in bit-reversed order.
	DIF
)

// butterflyThreshold is the parallelization threshold for the butterflies of a
// single stage: difFFT and ditFFT only spread a stage's butterflies over
// parallel.Execute when the half-size m of the current slice is strictly
// greater than this value and the stage is still below maxSplits.
const butterflyThreshold = 16
    23  
    24  // FFT computes (recursively) the discrete Fourier transform of a and stores the result in a
    25  // if decimation == DIT (decimation in time), the input must be in bit-reversed order
    26  // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
    27  func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) {
    28  
    29  	opt := fftOptions(opts...)
    30  
    31  	// find the stage where we should stop spawning go routines in our recursive calls
    32  	// (ie when we have as many go routines running as we have available CPUs)
    33  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks)))
    34  	if opt.nbTasks == 1 {
    35  		maxSplits = -1
    36  	}
    37  
    38  	// if coset != 0, scale by coset table
    39  	if opt.coset {
    40  		if decimation == DIT {
    41  			// scale by coset table (in bit reversed order)
    42  			cosetTable := domain.cosetTable
    43  			if !domain.withPrecompute {
    44  				// we need to build the full table or do a bit reverse dance.
    45  				cosetTable = make([]fr.Element, len(a))
    46  				BuildExpTable(domain.FrMultiplicativeGen, cosetTable)
    47  			}
    48  			parallel.Execute(len(a), func(start, end int) {
    49  				n := uint64(len(a))
    50  				nn := uint64(64 - bits.TrailingZeros64(n))
    51  				for i := start; i < end; i++ {
    52  					irev := int(bits.Reverse64(uint64(i)) >> nn)
    53  					a[i].Mul(&a[i], &cosetTable[irev])
    54  				}
    55  			}, opt.nbTasks)
    56  		} else {
    57  			if domain.withPrecompute {
    58  				parallel.Execute(len(a), func(start, end int) {
    59  					for i := start; i < end; i++ {
    60  						a[i].Mul(&a[i], &domain.cosetTable[i])
    61  					}
    62  				}, opt.nbTasks)
    63  			} else {
    64  				c := domain.FrMultiplicativeGen
    65  				parallel.Execute(len(a), func(start, end int) {
    66  					var at fr.Element
    67  					at.Exp(c, big.NewInt(int64(start)))
    68  					for i := start; i < end; i++ {
    69  						a[i].Mul(&a[i], &at)
    70  						at.Mul(&at, &c)
    71  					}
    72  				}, opt.nbTasks)
    73  			}
    74  	
    75  		}
    76  	}
    77  
    78  	twiddles := domain.twiddles
    79  	twiddlesStartStage := 0
    80  	if !domain.withPrecompute {
    81  		twiddlesStartStage = 3
    82  		nbStages := int(bits.TrailingZeros64(domain.Cardinality))
    83  		if nbStages - twiddlesStartStage > 0 {
    84  			twiddles = make([][]fr.Element, nbStages - twiddlesStartStage)
    85  			w := domain.Generator
    86  			w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage)))
    87  			buildTwiddles(twiddles, w, uint64(nbStages - twiddlesStartStage))
    88  		} // else, we don't need twiddles
    89  	}
    90  
    91  	switch decimation {
    92  	case DIF:
    93  		difFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
    94  	case DIT:
    95  		ditFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
    96  	default:
    97  		panic("not implemented")
    98  	}
    99  }
   100  
   101  
   102  
   103  // FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a
   104  // if decimation == DIT (decimation in time), the input must be in bit-reversed order
   105  // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order
   106  // coset sets the shift of the fft (0 = no shift, standard fft)
   107  // len(a) must be a power of 2, and w must be a len(a)th root of unity in field F.
   108  func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ...Option) {
   109  	opt := fftOptions(opts...)
   110  
   111  	// find the stage where we should stop spawning go routines in our recursive calls
   112  	// (ie when we have as many go routines running as we have available CPUs)
   113  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks)))
   114  	if opt.nbTasks == 1 {
   115  		maxSplits = -1
   116  	}
   117  
   118  	twiddlesInv := domain.twiddlesInv
   119  	twiddlesStartStage := 0
   120  	if !domain.withPrecompute {
   121  		twiddlesStartStage = 3
   122  		nbStages := int(bits.TrailingZeros64(domain.Cardinality))
   123  		if nbStages - twiddlesStartStage > 0 {
   124  			twiddlesInv = make([][]fr.Element, nbStages - twiddlesStartStage)
   125  			w := domain.GeneratorInv
   126  			w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage)))
   127  			buildTwiddles(twiddlesInv, w, uint64(nbStages - twiddlesStartStage))
   128  		} // else, we don't need twiddles
   129  	}
   130  
   131  	switch decimation {
   132  	case DIF:
   133  		difFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   134  	case DIT:
   135  		ditFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks)
   136  	default:
   137  		panic("not implemented")
   138  	}
   139  
   140  	// scale by CardinalityInv
   141  	if !opt.coset {
   142  		parallel.Execute(len(a), func(start, end int) {
   143  			for i := start; i < end; i++ {
   144  				a[i].Mul(&a[i], &domain.CardinalityInv)
   145  			}
   146  		}, opt.nbTasks)
   147  		return
   148  	}
   149  
   150  
   151  	if decimation == DIT {
   152  		if domain.withPrecompute {
   153  			parallel.Execute(len(a), func(start, end int) {
   154  				for i := start; i < end; i++ {
   155  					a[i].Mul(&a[i], &domain.cosetTableInv[i]).
   156  						Mul(&a[i], &domain.CardinalityInv)
   157  				}
   158  			}, opt.nbTasks)
   159  		} else {
   160  			c := domain.FrMultiplicativeGenInv
   161  			parallel.Execute(len(a), func(start, end int) {
   162  				var at fr.Element
   163  				at.Exp(c, big.NewInt(int64(start)))
   164  				at.Mul(&at, &domain.CardinalityInv)
   165  				for i := start; i < end; i++ {
   166  					a[i].Mul(&a[i], &at)
   167  					at.Mul(&at, &c)
   168  				}
   169  			}, opt.nbTasks)
   170  		}
   171  		return
   172  	}
   173  
   174  	// decimation == DIF, need to access coset table in bit reversed order.
   175  	cosetTableInv := domain.cosetTableInv
   176  	if !domain.withPrecompute {
   177  		// we need to build the full table or do a bit reverse dance.
   178  		cosetTableInv = make([]fr.Element, len(a))
   179  		BuildExpTable(domain.FrMultiplicativeGenInv, cosetTableInv)
   180  	}
   181  	parallel.Execute(len(a), func(start, end int) {
   182  		n := uint64(len(a))
   183  		nn := uint64(64 - bits.TrailingZeros64(n))
   184  		for i := start; i < end; i++ {
   185  			irev := int(bits.Reverse64(uint64(i)) >> nn)
   186  			a[i].Mul(&a[i], &cosetTableInv[irev]).
   187  				Mul(&a[i], &domain.CardinalityInv)
   188  		}
   189  	}, opt.nbTasks)
   190  
   191  }
   192  
   193  func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) {
   194  	if chDone != nil {
   195  		defer close(chDone)
   196  	}
   197  
   198  	n := len(a)
   199  	if n == 1 {
   200  		return
   201  	} else if n == {{$sizeKernel}} && stage >= twiddlesStartStage  {
   202  		kerDIFNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage)
   203  		return
   204  	}
   205  	m := n >> 1
   206  
   207  	parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)
   208  
   209  	if stage < twiddlesStartStage {
   210  		if parallelButterfly {
   211  			w := w
   212  			parallel.Execute(m, func(start, end int) {
   213  				if start == 0 {
   214  					fr.Butterfly(&a[0], &a[m])
   215  					start++
   216  				}
   217  				var at fr.Element
   218  				at.Exp(w, big.NewInt(int64(start)))
   219  				innerDIFWithoutTwiddles(a, at,w, start, end, m)
   220  			}, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs
   221  		} else {
   222  			innerDIFWithoutTwiddles(a, w, w, 0, m, m)
   223  		}
   224  		// compute next twiddle
   225  		w.Square(&w) 
   226  	} else {
   227  		if parallelButterfly {
   228  			parallel.Execute(m, func(start, end int) {
   229  				innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m)
   230  			}, nbTasks / (1 << (stage)))
   231  		} else {
   232  			innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m)
   233  		}
   234  	}
   235  
   236  	if m == 1 {
   237  		return
   238  	}
   239  	
   240  	nextStage := stage + 1
   241  	if stage < maxSplits {
   242  		chDone := make(chan struct{}, 1)
   243  		go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks)
   244  		difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   245  		<-chDone
   246  	} else {
   247  		difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   248  		difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   249  	}
   250  
   251  }
   252  
   253  
   254  func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) {
   255  	if start == 0 {
   256  		fr.Butterfly(&a[0], &a[m])
   257  		start++
   258  	}
   259  	for i := start; i < end; i++ {
   260  		fr.Butterfly(&a[i], &a[i+m])
   261  		a[i+m].Mul(&a[i+m], &twiddles[i])
   262  	}
   263  }
   264  
   265  func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) {
   266  	if start == 0 {
   267  		fr.Butterfly(&a[0], &a[m])
   268  		start++
   269  	}
   270  	for i := start; i < end; i++ {
   271  		fr.Butterfly(&a[i], &a[i+m])
   272  		a[i+m].Mul(&a[i+m], &at)
   273  		at.Mul(&at, &w)
   274  	}
   275  }
   276  
   277  
   278  func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) {
   279  	if chDone != nil {
   280  		defer close(chDone)
   281  	}
   282  	n := len(a)
   283  	if n == 1 {
   284  		return
   285  	} else if n == {{$sizeKernel}} && stage >= twiddlesStartStage  {
   286  		kerDITNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage)
   287  		return
   288  	}
   289  	m := n >> 1
   290  
   291  	nextStage := stage + 1
   292  	nextW := w 
   293  	nextW.Square(&nextW)
   294  
   295  	if stage < maxSplits {
   296  		// that's the only time we fire go routines
   297  		chDone := make(chan struct{}, 1)
   298  		go ditFFT(a[m:],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks)
   299  		ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   300  		<-chDone
   301  	} else {
   302  		ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   303  		ditFFT(a[m:n],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks)
   304  	}
   305  
   306  	parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits)
   307  
   308  	if stage < twiddlesStartStage {
   309  		// we need to compute the twiddles for this stage on the fly.
   310  		if  parallelButterfly {
   311  			w := w
   312  			parallel.Execute(m, func(start, end int) {
   313  				if start == 0 {
   314  					fr.Butterfly(&a[0], &a[m])
   315  					start++
   316  				}
   317  				var at fr.Element
   318  				at.Exp(w, big.NewInt(int64(start)))
   319  				innerDITWithoutTwiddles(a, at,w, start, end, m)
   320  			}, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs
   321  
   322  		} else {
   323  			innerDITWithoutTwiddles(a, w,w, 0, m, m)
   324  		}
   325  		return
   326  	}
   327  	if parallelButterfly {
   328  		parallel.Execute(m, func(start, end int) {
   329  			innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m)
   330  		}, nbTasks / (1 << (stage)))
   331  	} else {
   332  		innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m)
   333  	}
   334  }
   335  
   336  
   337  func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) {
   338  	if start == 0 {
   339  		fr.Butterfly(&a[0], &a[m])
   340  		start++
   341  	}
   342  	for i := start; i < end; i++ {
   343  		a[i+m].Mul(&a[i+m], &twiddles[i])
   344  		fr.Butterfly(&a[i], &a[i+m])
   345  	}
   346  }
   347  
   348  func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) {
   349  	if start == 0 {
   350  		fr.Butterfly(&a[0], &a[m])
   351  		start++
   352  	}
   353  	for i := start; i < end; i++ {
   354  		a[i+m].Mul(&a[i+m], &at)
   355  		fr.Butterfly(&a[i], &a[i+m])
   356  		at.Mul(&at, &w)
   357  	}
   358  }
   359  
   360  
   361  
// kerDIFNP_{{$sizeKernel}} is the unrolled decimation-in-frequency kernel difFFT
// switches to once a slice has exactly {{$sizeKernel}} elements and twiddle tables
// are available for the current stage.
//
// The template below emits one group of statements per remaining stage: first a
// single innerDIFWithTwiddles call over the whole slice, then loops over
// sub-slices whose size is halved at every step, and finally plain
// fr.Butterfly calls on adjacent pairs once the half-size reaches 1.
// The statements emitted for step k read twiddles[stage+k]; difFFT passes
// stage already offset by twiddlesStartStage.
func kerDIFNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) {
	// code unrolled & generated by internal/generator/fft/template/fft.go.tmpl
	// template variables: $n is the sub-slice size of the current step, $m its
	// half, $split the number of sub-slices of that size.
	// NOTE(review): assumes the iterate helper yields 0 .. $sizeKernelLog2-1 — confirm in the generator.

	{{ $n := shl 1 $sizeKernelLog2}}
	{{ $m := div $n 2}}
	{{ $split := 1}}
	{{- range $step := iterate 0 $sizeKernelLog2}} 
		{{- $offset := 0}}

		{{- $bound := mul $split $n}}
		{{- if eq $bound $n}}
			innerDIFWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}})
		{{- else}}
			for offset := 0; offset < {{$bound}}; offset += {{$n}} {
				{{- if eq $m 1}}
					fr.Butterfly(&a[offset], &a[offset+1])
				{{- else}}
					innerDIFWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}})
				{{- end}}
			}
		{{- end}}

		{{- $n = div $n 2}}
		{{- $m = div $n 2}}
		{{- $split = mul $split 2}}
	{{- end}}
}
   389  
   390  
// kerDITNP_{{$sizeKernel}} is the unrolled decimation-in-time kernel ditFFT
// switches to once a slice has exactly {{$sizeKernel}} elements and twiddle tables
// are available for the current stage.
//
// The template below emits one group of statements per remaining stage, in the
// opposite order of the DIF kernel: it starts with plain fr.Butterfly calls on
// adjacent pairs, then loops over sub-slices whose size doubles at every step,
// and ends with a single innerDITWithTwiddles call over the whole slice.
// The statements emitted for step k read twiddles[stage+k]; ditFFT passes
// stage already offset by twiddlesStartStage.
func kerDITNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) {
	// code unrolled & generated by internal/generator/fft/template/fft.go.tmpl
	// template variables: $n is the sub-slice size of the current step, $m its
	// half, $split the number of sub-slices of that size.
	// NOTE(review): assumes reverse (iterate 0 $sizeKernelLog2) yields $sizeKernelLog2-1 .. 0 — confirm in the generator.

	{{ $n := 2}}
	{{ $m := div $n 2}}
	{{ $split := div (shl 1 $sizeKernelLog2) 2}}
	{{- range $step := reverse (iterate 0 $sizeKernelLog2)}} 
		{{- $offset := 0}}

		{{- $bound := mul $split $n}}
		{{- if eq $bound $n}}
			innerDITWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}})
		{{- else}}
			for offset := 0; offset < {{$bound}}; offset += {{$n}} {
				{{- if eq $m 1}}
					fr.Butterfly(&a[offset], &a[offset+1])
				{{- else}}
					innerDITWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}})
				{{- end}}
			}
		{{- end}}

		{{- $n = mul $n 2}}
		{{- $m = div $n 2}}
		{{- $split = div $split 2}}
	{{- end}}
}