github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/fr/fft/domain.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package fft
    18  
    19  import (
    20  	"errors"
    21  	"io"
    22  	"math/big"
    23  	"math/bits"
    24  	"runtime"
    25  	"sync"
    26  
    27  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    28  
    29  	curve "github.com/consensys/gnark-crypto/ecc/bn254"
    30  
    31  	"github.com/consensys/gnark-crypto/ecc"
    32  )
    33  
    34  // Domain with a power of 2 cardinality
    35  // compute a field element of order 2x and store it in FinerGenerator
    36  // all other values can be derived from x, GeneratorSqrt
    37  type Domain struct {
    38  	Cardinality            uint64
    39  	CardinalityInv         fr.Element
    40  	Generator              fr.Element
    41  	GeneratorInv           fr.Element
    42  	FrMultiplicativeGen    fr.Element // generator of Fr*
    43  	FrMultiplicativeGenInv fr.Element
    44  
    45  	// this is set with the WithoutPrecompute option;
    46  	// if true, the domain does some pre-computation and stores it.
    47  	// if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory)
    48  	withPrecompute bool
    49  
    50  	// the following slices are not serialized and are (re)computed through domain.preComputeTwiddles()
    51  
    52  	// twiddles factor for the FFT using Generator for each stage of the recursive FFT
    53  	twiddles [][]fr.Element
    54  
    55  	// twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
    56  	twiddlesInv [][]fr.Element
    57  
    58  	// we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover
    59  
    60  	// cosetTable u*<1,g,..,g^(n-1)>
    61  	cosetTable []fr.Element
    62  
    63  	// cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j
    64  	cosetTableInv []fr.Element
    65  }
    66  
    67  // GeneratorFullMultiplicativeGroup returns a generator of 𝔽ᵣˣ
    68  func GeneratorFullMultiplicativeGroup() fr.Element {
    69  	var res fr.Element
    70  
    71  	res.SetUint64(5)
    72  
    73  	return res
    74  }
    75  
    76  // NewDomain returns a subgroup with a power of 2 cardinality
    77  // cardinality >= m
    78  // shift: when specified, it's the element by which the set of root of unity is shifted.
    79  func NewDomain(m uint64, opts ...DomainOption) *Domain {
    80  	opt := domainOptions(opts...)
    81  	domain := &Domain{}
    82  	x := ecc.NextPowerOfTwo(m)
    83  	domain.Cardinality = uint64(x)
    84  	domain.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup()
    85  
    86  	if opt.shift != nil {
    87  		domain.FrMultiplicativeGen.Set(opt.shift)
    88  	}
    89  	domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen)
    90  
    91  	var err error
    92  	domain.Generator, err = Generator(m)
    93  	if err != nil {
    94  		panic(err)
    95  	}
    96  	domain.GeneratorInv.Inverse(&domain.Generator)
    97  	domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv)
    98  
    99  	// twiddle factors
   100  	domain.withPrecompute = opt.withPrecompute
   101  	if domain.withPrecompute {
   102  		domain.preComputeTwiddles()
   103  	}
   104  
   105  	return domain
   106  }
   107  
   108  // Generator returns a generator for Z/2^(log(m))Z
   109  // or an error if m is too big (required root of unity doesn't exist)
   110  func Generator(m uint64) (fr.Element, error) {
   111  	return fr.Generator(m)
   112  }
   113  
   114  // Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT
   115  // or an error if the domain was created with the WithoutPrecompute option
   116  func (d *Domain) Twiddles() ([][]fr.Element, error) {
   117  	if d.twiddles == nil {
   118  		return nil, errors.New("twiddles not precomputed")
   119  	}
   120  	return d.twiddles, nil
   121  }
   122  
   123  // TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
   124  // or an error if the domain was created with the WithoutPrecompute option
   125  func (d *Domain) TwiddlesInv() ([][]fr.Element, error) {
   126  	if d.twiddlesInv == nil {
   127  		return nil, errors.New("twiddles not precomputed")
   128  	}
   129  	return d.twiddlesInv, nil
   130  }
   131  
   132  // CosetTable returns the cosetTable u*<1,g,..,g^(n-1)>
   133  // or an error if the domain was created with the WithoutPrecompute option
   134  func (d *Domain) CosetTable() ([]fr.Element, error) {
   135  	if d.cosetTable == nil {
   136  		return nil, errors.New("cosetTable not precomputed")
   137  	}
   138  	return d.cosetTable, nil
   139  }
   140  
   141  // CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)>
   142  // or an error if the domain was created with the WithoutPrecompute option
   143  func (d *Domain) CosetTableInv() ([]fr.Element, error) {
   144  	if d.cosetTableInv == nil {
   145  		return nil, errors.New("cosetTableInv not precomputed")
   146  	}
   147  	return d.cosetTableInv, nil
   148  }
   149  
   150  func (d *Domain) preComputeTwiddles() {
   151  
   152  	// nb fft stages
   153  	nbStages := uint64(bits.TrailingZeros64(d.Cardinality))
   154  
   155  	d.twiddles = make([][]fr.Element, nbStages)
   156  	d.twiddlesInv = make([][]fr.Element, nbStages)
   157  	d.cosetTable = make([]fr.Element, d.Cardinality)
   158  	d.cosetTableInv = make([]fr.Element, d.Cardinality)
   159  
   160  	var wg sync.WaitGroup
   161  
   162  	expTable := func(sqrt fr.Element, t []fr.Element) {
   163  		BuildExpTable(sqrt, t)
   164  		wg.Done()
   165  	}
   166  
   167  	wg.Add(4)
   168  	go func() {
   169  		buildTwiddles(d.twiddles, d.Generator, nbStages)
   170  		wg.Done()
   171  	}()
   172  	go func() {
   173  		buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages)
   174  		wg.Done()
   175  	}()
   176  	go expTable(d.FrMultiplicativeGen, d.cosetTable)
   177  	go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv)
   178  
   179  	wg.Wait()
   180  
   181  }
   182  
   183  func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) {
   184  	if nbStages == 0 {
   185  		return
   186  	}
   187  	if len(t) != int(nbStages) {
   188  		panic("invalid twiddle table")
   189  	}
   190  	// we just compute the first stage
   191  	t[0] = make([]fr.Element, 1+(1<<(nbStages-1)))
   192  	BuildExpTable(omega, t[0])
   193  
   194  	// for the next stages, we just iterate on the first stage with larger stride
   195  	for i := uint64(1); i < nbStages; i++ {
   196  		t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1)))
   197  		k := 0
   198  		for j := 0; j < len(t[i]); j++ {
   199  			t[i][j] = t[0][k]
   200  			k += 1 << i
   201  		}
   202  	}
   203  
   204  }
   205  
   206  // BuildExpTable precomputes the first n powers of w in parallel
   207  // table[0] = w^0
   208  // table[1] = w^1
   209  // ...
   210  func BuildExpTable(w fr.Element, table []fr.Element) {
   211  	table[0].SetOne()
   212  	n := len(table)
   213  
   214  	// see if it makes sense to parallelize exp tables pre-computation
   215  	interval := 0
   216  	if runtime.NumCPU() >= 4 {
   217  		interval = (n - 1) / (runtime.NumCPU() / 4)
   218  	}
   219  
   220  	// this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation
   221  	// TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio.
   222  	const ratioExpMul = 6000 / 17
   223  
   224  	if interval < ratioExpMul {
   225  		precomputeExpTableChunk(w, 1, table[1:])
   226  		return
   227  	}
   228  
   229  	// we parallelize
   230  	var wg sync.WaitGroup
   231  	for i := 1; i < n; i += interval {
   232  		start := i
   233  		end := i + interval
   234  		if end > n {
   235  			end = n
   236  		}
   237  		wg.Add(1)
   238  		go func() {
   239  			precomputeExpTableChunk(w, uint64(start), table[start:end])
   240  			wg.Done()
   241  		}()
   242  	}
   243  	wg.Wait()
   244  }
   245  
   246  func precomputeExpTableChunk(w fr.Element, power uint64, table []fr.Element) {
   247  
   248  	// this condition ensures that creating a domain of size 1 with cosets don't fail
   249  	if len(table) > 0 {
   250  		table[0].Exp(w, new(big.Int).SetUint64(power))
   251  		for i := 1; i < len(table); i++ {
   252  			table[i].Mul(&table[i-1], &w)
   253  		}
   254  	}
   255  }
   256  
   257  // WriteTo writes a binary representation of the domain (without the precomputed twiddle factors)
   258  // to the provided writer
   259  func (d *Domain) WriteTo(w io.Writer) (int64, error) {
   260  
   261  	enc := curve.NewEncoder(w)
   262  
   263  	toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute}
   264  
   265  	for _, v := range toEncode {
   266  		if err := enc.Encode(v); err != nil {
   267  			return enc.BytesWritten(), err
   268  		}
   269  	}
   270  
   271  	return enc.BytesWritten(), nil
   272  }
   273  
   274  // ReadFrom attempts to decode a domain from Reader
   275  func (d *Domain) ReadFrom(r io.Reader) (int64, error) {
   276  
   277  	dec := curve.NewDecoder(r)
   278  
   279  	toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute}
   280  
   281  	for _, v := range toDecode {
   282  		if err := dec.Decode(v); err != nil {
   283  			return dec.BytesRead(), err
   284  		}
   285  	}
   286  
   287  	if d.withPrecompute {
   288  		d.preComputeTwiddles()
   289  	}
   290  
   291  	return dec.BytesRead(), nil
   292  }