github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/domain.go.tmpl (about)

     1  import (
     2  	"io"
     3  	"math/big"
     4  	"math/bits"
     5  	"runtime"
     6  	"sync"
     7  	"errors"
     8  
     9  	{{ template "import_fr" . }}
    10  	{{ template "import_curve" . }}
    11  
    12  	"github.com/consensys/gnark-crypto/ecc"
    13  )
    14  
    15  // Domain with a power of 2 cardinality
    16  // compute a field element of order 2x and store it in FinerGenerator
    17  // all other values can be derived from x, GeneratorSqrt
    18  type Domain struct {
    19  	Cardinality             uint64
    20  	CardinalityInv          fr.Element
    21  	Generator               fr.Element
    22  	GeneratorInv            fr.Element
    23  	FrMultiplicativeGen     fr.Element // generator of Fr*
    24  	FrMultiplicativeGenInv  fr.Element
    25  
    26  	// this is set with the WithoutPrecompute option;
    27  	// if true, the domain does some pre-computation and stores it.
    28  	// if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory)
    29  	withPrecompute bool 
    30  
    31  	// the following slices are not serialized and are (re)computed through domain.preComputeTwiddles()
    32  
    33  	// twiddles factor for the FFT using Generator for each stage of the recursive FFT
    34  	twiddles [][]fr.Element
    35  
    36  	// twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
    37  	twiddlesInv [][]fr.Element
    38  
    39  	// we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover
    40  
    41  	// cosetTable u*<1,g,..,g^(n-1)>
    42  	cosetTable         []fr.Element
    43  
    44  	// cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j
    45  	cosetTableInv         []fr.Element
    46  }
    47  
    48  // GeneratorFullMultiplicativeGroup returns a generator of 𝔽ᵣˣ
    49  func GeneratorFullMultiplicativeGroup() fr.Element {
    50  	var res fr.Element
    51  	{{if eq .Name "bls12-378"}}
    52          res.SetUint64(22)
    53  	{{else if eq .Name "bls12-377"}}
    54          res.SetUint64(22)
    55  	{{else if eq .Name "bls12-381"}}
    56          res.SetUint64(7)
    57  	{{else if eq .Name "bn254"}}
    58          res.SetUint64(5)
    59  	{{else if eq .Name "bw6-761"}}
    60          res.SetUint64(15)
    61  	{{else if eq .Name "bw6-756"}}
    62          res.SetUint64(5)
    63      {{else if eq .Name "bw6-633"}}
    64          res.SetUint64(13)
    65  	{{else if eq .Name "bls24-315"}}
    66          res.SetUint64(7)
    67  	{{else if eq .Name "bls24-317"}}
    68          res.SetUint64(7)
    69  	{{end}}
    70  	return res
    71  }
    72  
    73  // NewDomain returns a subgroup with a power of 2 cardinality
    74  // cardinality >= m
    75  // shift: when specified, it's the element by which the set of root of unity is shifted.
    76  func NewDomain(m uint64, opts ...DomainOption) *Domain {
    77  	opt := domainOptions(opts...)
    78  	domain := &Domain{}
    79  	x := ecc.NextPowerOfTwo(m)
    80  	domain.Cardinality = uint64(x)
    81  	domain.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup()
    82  
    83  	if opt.shift != nil{
    84  		domain.FrMultiplicativeGen.Set(opt.shift)
    85  	}
    86  	domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen)
    87  
    88  	var err error 
    89  	domain.Generator, err = Generator(m)
    90  	if err != nil {
    91  		panic(err)
    92  	}
    93  	domain.GeneratorInv.Inverse(&domain.Generator)
    94  	domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv)
    95  
    96  	// twiddle factors
    97  	domain.withPrecompute = opt.withPrecompute
    98  	if domain.withPrecompute {
    99  		domain.preComputeTwiddles()
   100  	}
   101  
   102  	return domain
   103  }
   104  
   105  // Generator returns a generator for Z/2^(log(m))Z
   106  // or an error if m is too big (required root of unity doesn't exist)
   107  func Generator(m uint64) (fr.Element, error) {
   108  	return fr.Generator(m)
   109  }
   110  
   111  // Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT
   112  // or an error if the domain was created with the WithoutPrecompute option
   113  func (d *Domain) Twiddles() ([][]fr.Element, error) {
   114  	if d.twiddles == nil {
   115  		return nil, errors.New("twiddles not precomputed")
   116  	}
   117  	return d.twiddles, nil
   118  }
   119  
   120  // TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT
   121  // or an error if the domain was created with the WithoutPrecompute option
   122  func (d *Domain) TwiddlesInv() ([][]fr.Element, error) {
   123  	if d.twiddlesInv == nil {
   124  		return nil, errors.New("twiddles not precomputed")
   125  	}
   126  	return d.twiddlesInv, nil
   127  }
   128  
   129  // CosetTable returns the cosetTable u*<1,g,..,g^(n-1)>
   130  // or an error if the domain was created with the WithoutPrecompute option
   131  func (d *Domain) CosetTable() ([]fr.Element, error) {
   132  	if d.cosetTable == nil {
   133  		return nil, errors.New("cosetTable not precomputed")
   134  	}
   135  	return d.cosetTable, nil
   136  }
   137  
   138  // CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)>
   139  // or an error if the domain was created with the WithoutPrecompute option
   140  func (d *Domain) CosetTableInv() ([]fr.Element, error) {
   141  	if d.cosetTableInv == nil {
   142  		return nil, errors.New("cosetTableInv not precomputed")
   143  	}
   144  	return d.cosetTableInv, nil
   145  }
   146  
   147  
   148  
   149  func (d *Domain) preComputeTwiddles() {
   150  
   151  	// nb fft stages
   152  	nbStages := uint64(bits.TrailingZeros64(d.Cardinality))
   153  
   154  	d.twiddles = make([][]fr.Element, nbStages)
   155  	d.twiddlesInv = make([][]fr.Element, nbStages)
   156  	d.cosetTable = make([]fr.Element, d.Cardinality)
   157  	d.cosetTableInv = make([]fr.Element, d.Cardinality)
   158  
   159  	var wg sync.WaitGroup
   160  
   161  	expTable := func(sqrt fr.Element, t []fr.Element) {
   162  		BuildExpTable(sqrt, t)
   163  		wg.Done()
   164  	}
   165  
   166  	wg.Add(4)
   167  	go func() {
   168  		buildTwiddles(d.twiddles, d.Generator, nbStages)
   169  		wg.Done()
   170  	}()
   171  	go func() {
   172  		buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages)
   173  		wg.Done()
   174  	}()
   175  	go expTable(d.FrMultiplicativeGen, d.cosetTable)
   176  	go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv)
   177  
   178  	wg.Wait()
   179  
   180  }
   181  
   182  func buildTwiddles(t [][]fr.Element, omega fr.Element,nbStages uint64) {
   183  	if nbStages == 0 {
   184  		return
   185  	}
   186  	if len(t) != int(nbStages) {
   187  		panic("invalid twiddle table")
   188  	}
   189  	// we just compute the first stage
   190  	t[0] = make([]fr.Element, 1+(1<<(nbStages-1)))
   191  	BuildExpTable(omega, t[0])
   192  
   193  	// for the next stages, we just iterate on the first stage with larger stride
   194  	for i := uint64(1); i < nbStages; i++ {
   195  		t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1)))
   196  		k := 0
   197  		for j := 0; j < len(t[i]); j++ {
   198  			t[i][j] = t[0][k]
   199  			k += 1 << i
   200  		}
   201  	}
   202  
   203  }
   204  
   205  // BuildExpTable precomputes the first n powers of w in parallel
   206  // table[0] = w^0
   207  // table[1] = w^1
   208  // ...
   209  func BuildExpTable(w fr.Element, table []fr.Element) {
   210  	table[0].SetOne()
   211  	n := len(table)
   212  
   213  	// see if it makes sense to parallelize exp tables pre-computation
   214  	interval := 0
   215  	if runtime.NumCPU() >= 4 {
   216  		interval = (n - 1) / (runtime.NumCPU() / 4)
   217  	}
   218  
   219  	// this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation
   220  	// TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio.
   221  	const ratioExpMul = 6000 / 17
   222  
   223  	if interval < ratioExpMul {
   224  		precomputeExpTableChunk(w, 1, table[1:])
   225  		return
   226  	}
   227  
   228  	// we parallelize
   229  	var wg sync.WaitGroup
   230  	for i := 1; i < n; i += interval {
   231  		start := i
   232  		end := i + interval
   233  		if end > n {
   234  			end = n
   235  		}
   236  		wg.Add(1)
   237  		go func() {
   238  			precomputeExpTableChunk(w, uint64(start), table[start:end])
   239  			wg.Done()
   240  		}()
   241  	}
   242  	wg.Wait()
   243  }
   244  
   245  func precomputeExpTableChunk(w fr.Element, power uint64, table []fr.Element) {
   246  
   247  	// this condition ensures that creating a domain of size 1 with cosets don't fail
   248  	if len(table) > 0 {
   249  		table[0].Exp(w, new(big.Int).SetUint64(power))
   250  		for i := 1; i < len(table); i++ {
   251  			table[i].Mul(&table[i-1], &w)
   252  		}
   253  	}
   254  }
   255  
   256  // WriteTo writes a binary representation of the domain (without the precomputed twiddle factors)
   257  // to the provided writer
   258  func (d *Domain) WriteTo(w io.Writer) (int64, error) {
   259  
   260  	enc := curve.NewEncoder(w)
   261  
   262  	toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute}
   263  
   264  	for _, v := range toEncode {
   265  		if err := enc.Encode(v); err != nil {
   266  			return enc.BytesWritten(), err
   267  		}
   268  	}
   269  
   270  	return enc.BytesWritten(), nil
   271  }
   272  
   273  // ReadFrom attempts to decode a domain from Reader
   274  func (d *Domain) ReadFrom(r io.Reader) (int64, error) {
   275  
   276  	dec := curve.NewDecoder(r)
   277  
   278  	toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute}
   279  
   280  	for _, v := range toDecode {
   281  		if err := dec.Decode(v); err != nil {
   282  			return dec.BytesRead(), err
   283  		}
   284  	}
   285  
   286  	if d.withPrecompute {
   287  		d.preComputeTwiddles()
   288  	} 
   289  
   290  	return dec.BytesRead(), nil
   291  }