github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/domain.go.tmpl (about) 1 import ( 2 "io" 3 "math/big" 4 "math/bits" 5 "runtime" 6 "sync" 7 "errors" 8 9 {{ template "import_fr" . }} 10 {{ template "import_curve" . }} 11 12 "github.com/consensys/gnark-crypto/ecc" 13 ) 14 15 // Domain with a power of 2 cardinality 16 // compute a field element of order 2x and store it in FinerGenerator 17 // all other values can be derived from x, GeneratorSqrt 18 type Domain struct { 19 Cardinality uint64 20 CardinalityInv fr.Element 21 Generator fr.Element 22 GeneratorInv fr.Element 23 FrMultiplicativeGen fr.Element // generator of Fr* 24 FrMultiplicativeGenInv fr.Element 25 26 // this is set with the WithoutPrecompute option; 27 // if true, the domain does some pre-computation and stores it. 28 // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) 29 withPrecompute bool 30 31 // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() 32 33 // twiddles factor for the FFT using Generator for each stage of the recursive FFT 34 twiddles [][]fr.Element 35 36 // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT 37 twiddlesInv [][]fr.Element 38 39 // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover 40 41 // cosetTable u*<1,g,..,g^(n-1)> 42 cosetTable []fr.Element 43 44 // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j 45 cosetTableInv []fr.Element 46 } 47 48 // GeneratorFullMultiplicativeGroup returns a generator of 𝔽ᵣˣ 49 func GeneratorFullMultiplicativeGroup() fr.Element { 50 var res fr.Element 51 {{if eq .Name "bls12-378"}} 52 res.SetUint64(22) 53 {{else if eq .Name "bls12-377"}} 54 res.SetUint64(22) 55 {{else if eq .Name "bls12-381"}} 56 res.SetUint64(7) 57 {{else if eq .Name "bn254"}} 58 res.SetUint64(5) 59 {{else if eq .Name "bw6-761"}} 60 res.SetUint64(15) 61 {{else if eq .Name "bw6-756"}} 62 res.SetUint64(5) 63 {{else if eq .Name "bw6-633"}} 64 res.SetUint64(13) 65 {{else if eq .Name "bls24-315"}} 66 res.SetUint64(7) 67 {{else if eq .Name "bls24-317"}} 68 res.SetUint64(7) 69 {{end}} 70 return res 71 } 72 73 // NewDomain returns a subgroup with a power of 2 cardinality 74 // cardinality >= m 75 // shift: when specified, it's the element by which the set of root of unity is shifted. 76 func NewDomain(m uint64, opts ...DomainOption) *Domain { 77 opt := domainOptions(opts...) 78 domain := &Domain{} 79 x := ecc.NextPowerOfTwo(m) 80 domain.Cardinality = uint64(x) 81 domain.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup() 82 83 if opt.shift != nil{ 84 domain.FrMultiplicativeGen.Set(opt.shift) 85 } 86 domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) 87 88 var err error 89 domain.Generator, err = Generator(m) 90 if err != nil { 91 panic(err) 92 } 93 domain.GeneratorInv.Inverse(&domain.Generator) 94 domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) 95 96 // twiddle factors 97 domain.withPrecompute = opt.withPrecompute 98 if domain.withPrecompute { 99 domain.preComputeTwiddles() 100 } 101 102 return domain 103 } 104 105 // Generator returns a generator for Z/2^(log(m))Z 106 // or an error if m is too big (required root of unity doesn't exist) 107 func Generator(m uint64) (fr.Element, error) { 108 return fr.Generator(m) 109 } 110 111 // Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT 112 // or an error if the domain was created with the WithoutPrecompute option 113 func (d *Domain) Twiddles() ([][]fr.Element, error) { 114 if d.twiddles == nil { 115 return nil, errors.New("twiddles not precomputed") 116 } 117 return d.twiddles, nil 118 } 119 120 // TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT 121 // or an error if the domain was created with the WithoutPrecompute option 122 func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { 123 if d.twiddlesInv == nil { 124 return nil, errors.New("twiddles not precomputed") 125 } 126 return d.twiddlesInv, nil 127 } 128 129 // CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> 130 // or an error if the domain was created with the WithoutPrecompute option 131 func (d *Domain) CosetTable() ([]fr.Element, error) { 132 if d.cosetTable == nil { 133 return nil, errors.New("cosetTable not precomputed") 134 } 135 return d.cosetTable, nil 136 } 137 138 // CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> 139 // or an error if the domain was created with the WithoutPrecompute option 140 func (d *Domain) CosetTableInv() ([]fr.Element, error) { 141 if d.cosetTableInv == nil { 142 return nil, errors.New("cosetTableInv not precomputed") 143 } 144 return d.cosetTableInv, nil 145 } 146 147 148 149 func (d *Domain) preComputeTwiddles() { 150 151 // nb fft stages 152 nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) 153 154 d.twiddles = make([][]fr.Element, nbStages) 155 d.twiddlesInv = make([][]fr.Element, nbStages) 156 d.cosetTable = make([]fr.Element, d.Cardinality) 157 d.cosetTableInv = make([]fr.Element, d.Cardinality) 158 159 var wg sync.WaitGroup 160 161 expTable := func(sqrt fr.Element, t []fr.Element) { 162 BuildExpTable(sqrt, t) 163 wg.Done() 164 } 165 166 wg.Add(4) 167 go func() { 168 buildTwiddles(d.twiddles, d.Generator, nbStages) 169 wg.Done() 170 }() 171 go func() { 172 buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) 173 wg.Done() 174 }() 175 go expTable(d.FrMultiplicativeGen, d.cosetTable) 176 go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) 177 178 wg.Wait() 179 180 } 181 182 func buildTwiddles(t [][]fr.Element, omega fr.Element,nbStages uint64) { 183 if nbStages == 0 { 184 return 185 } 186 if len(t) != int(nbStages) { 187 panic("invalid twiddle table") 188 } 189 // we just compute the first stage 190 t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) 191 BuildExpTable(omega, t[0]) 192 193 // for the next stages, we just iterate on the first stage with larger stride 194 for i := uint64(1); i < nbStages; i++ { 195 t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) 196 k := 0 197 for j := 0; j < len(t[i]); j++ { 198 t[i][j] = t[0][k] 199 k += 1 << i 200 } 201 } 202 203 } 204 205 // BuildExpTable precomputes the first n powers of w in parallel 206 // table[0] = w^0 207 // table[1] = w^1 208 // ... 209 func BuildExpTable(w fr.Element, table []fr.Element) { 210 table[0].SetOne() 211 n := len(table) 212 213 // see if it makes sense to parallelize exp tables pre-computation 214 interval := 0 215 if runtime.NumCPU() >= 4 { 216 interval = (n - 1) / (runtime.NumCPU() / 4) 217 } 218 219 // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation 220 // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. 221 const ratioExpMul = 6000 / 17 222 223 if interval < ratioExpMul { 224 precomputeExpTableChunk(w, 1, table[1:]) 225 return 226 } 227 228 // we parallelize 229 var wg sync.WaitGroup 230 for i := 1; i < n; i += interval { 231 start := i 232 end := i + interval 233 if end > n { 234 end = n 235 } 236 wg.Add(1) 237 go func() { 238 precomputeExpTableChunk(w, uint64(start), table[start:end]) 239 wg.Done() 240 }() 241 } 242 wg.Wait() 243 } 244 245 func precomputeExpTableChunk(w fr.Element, power uint64, table []fr.Element) { 246 247 // this condition ensures that creating a domain of size 1 with cosets don't fail 248 if len(table) > 0 { 249 table[0].Exp(w, new(big.Int).SetUint64(power)) 250 for i := 1; i < len(table); i++ { 251 table[i].Mul(&table[i-1], &w) 252 } 253 } 254 } 255 256 // WriteTo writes a binary representation of the domain (without the precomputed twiddle factors) 257 // to the provided writer 258 func (d *Domain) WriteTo(w io.Writer) (int64, error) { 259 260 enc := curve.NewEncoder(w) 261 262 toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} 263 264 for _, v := range toEncode { 265 if err := enc.Encode(v); err != nil { 266 return enc.BytesWritten(), err 267 } 268 } 269 270 return enc.BytesWritten(), nil 271 } 272 273 // ReadFrom attempts to decode a domain from Reader 274 func (d *Domain) ReadFrom(r io.Reader) (int64, error) { 275 276 dec := curve.NewDecoder(r) 277 278 toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} 279 280 for _, v := range toDecode { 281 if err := dec.Decode(v); err != nil { 282 return dec.BytesRead(), err 283 } 284 } 285 286 if d.withPrecompute { 287 d.preComputeTwiddles() 288 } 289 290 return dec.BytesRead(), nil 291 }