github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/kzg/utils.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package kzg
    18  
    19  import (
    20  	"fmt"
    21  	"math/big"
    22  	"math/bits"
    23  	"runtime"
    24  
    25  	"github.com/consensys/gnark-crypto/ecc"
    26  	curve "github.com/consensys/gnark-crypto/ecc/bn254"
    27  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    28  	"github.com/consensys/gnark-crypto/internal/parallel"
    29  )
    30  
    31  // ToLagrangeG1 in place transform of coeffs canonical form into Lagrange form.
    32  // From the formula Lᵢ(τ) = 1/n∑_{j<n}(τ/ωⁱ)ʲ we
    33  // see that [L₁(τ),..,Lₙ(τ)] = FFT_inv(∑_{j<n}τʲXʲ), so it suffices to apply the inverse
    34  // fft on the vector consisting of the original SRS.
    35  // Size of coeffs must be a power of 2.
    36  func ToLagrangeG1(coeffs []curve.G1Affine) ([]curve.G1Affine, error) {
    37  	if bits.OnesCount64(uint64(len(coeffs))) != 1 {
    38  		return nil, fmt.Errorf("len(coeffs) must be a power of 2")
    39  	}
    40  	size := len(coeffs)
    41  
    42  	numCPU := uint64(runtime.NumCPU())
    43  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) << 1
    44  
    45  	twiddlesInv, err := computeTwiddlesInv(size)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	// batch convert to Jacobian
    51  	jCoeffs := make([]curve.G1Jac, len(coeffs))
    52  	for i := 0; i < len(coeffs); i++ {
    53  		jCoeffs[i].FromAffine(&coeffs[i])
    54  	}
    55  
    56  	difFFTG1(jCoeffs, twiddlesInv, 0, maxSplits, nil)
    57  
    58  	// TODO @gbotrel generify the cobra bitreverse function, benchmark it and use it everywhere
    59  	bitReverse(jCoeffs)
    60  
    61  	var invBigint big.Int
    62  	var frCardinality fr.Element
    63  	frCardinality.SetUint64(uint64(size))
    64  	frCardinality.Inverse(&frCardinality)
    65  	frCardinality.BigInt(&invBigint)
    66  
    67  	parallel.Execute(size, func(start, end int) {
    68  		for i := start; i < end; i++ {
    69  			jCoeffs[i].ScalarMultiplication(&jCoeffs[i], &invBigint)
    70  		}
    71  	})
    72  
    73  	// batch convert to affine
    74  	return curve.BatchJacobianToAffineG1(jCoeffs), nil
    75  }
    76  
    77  func computeTwiddlesInv(cardinality int) ([]*big.Int, error) {
    78  	generator, err := fr.Generator(uint64(cardinality))
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	// inverse the generator
    84  	generator.Inverse(&generator)
    85  
    86  	// nb fft stages
    87  	nbStages := uint64(bits.TrailingZeros64(uint64(cardinality)))
    88  
    89  	r := make([]*big.Int, 1+(1<<(nbStages-1)))
    90  
    91  	w := generator
    92  	r[0] = new(big.Int).SetUint64(1)
    93  	if len(r) == 1 {
    94  		return r, nil
    95  	}
    96  	r[1] = new(big.Int)
    97  	w.BigInt(r[1])
    98  	for j := 2; j < len(r); j++ {
    99  		w.Mul(&w, &generator)
   100  		r[j] = new(big.Int)
   101  		w.BigInt(r[j])
   102  	}
   103  
   104  	return r, nil
   105  }
   106  
   107  func bitReverse[T any](a []T) {
   108  	n := uint64(len(a))
   109  	nn := uint64(64 - bits.TrailingZeros64(n))
   110  
   111  	for i := uint64(0); i < n; i++ {
   112  		irev := bits.Reverse64(i) >> nn
   113  		if irev > i {
   114  			a[i], a[irev] = a[irev], a[i]
   115  		}
   116  	}
   117  }
   118  
   119  func butterflyG1(a *curve.G1Jac, b *curve.G1Jac) {
   120  	t := *a
   121  	a.AddAssign(b)
   122  	t.SubAssign(b)
   123  	b.Set(&t)
   124  }
   125  
   126  func difFFTG1(a []curve.G1Jac, twiddles []*big.Int, stage, maxSplits int, chDone chan struct{}) {
   127  	if chDone != nil {
   128  		defer close(chDone)
   129  	}
   130  
   131  	n := len(a)
   132  	if n == 1 {
   133  		return
   134  	}
   135  	m := n >> 1
   136  
   137  	butterflyG1(&a[0], &a[m])
   138  	// stage determines the stride
   139  	// if stage == 0, then we use 1, w, w**2, w**3, w**4, w**5, w**6, ...
   140  	// if stage == 1, then we use 1, w**2, w**4, w**6, ... that is, indexes 0, 2, 4, 6, ... of stage 0
   141  	// if stage == 2, then we use 1, w**4, w**8, w**12, ... that is indexes 0, 4, 8, 12, ... of stage 0
   142  	stride := 1 << stage
   143  
   144  	const butterflyThreshold = 8
   145  	if m >= butterflyThreshold {
   146  		// 1 << stage == estimated used CPUs
   147  		numCPU := runtime.NumCPU() / (1 << (stage))
   148  		parallel.Execute(m, func(start, end int) {
   149  			if start == 0 {
   150  				start = 1
   151  			}
   152  			j := start * stride
   153  			for i := start; i < end; i++ {
   154  				butterflyG1(&a[i], &a[i+m])
   155  				a[i+m].ScalarMultiplication(&a[i+m], twiddles[j])
   156  				j += stride
   157  			}
   158  		}, numCPU)
   159  	} else {
   160  		j := stride
   161  		for i := 1; i < m; i++ {
   162  			butterflyG1(&a[i], &a[i+m])
   163  			a[i+m].ScalarMultiplication(&a[i+m], twiddles[j])
   164  			j += stride
   165  		}
   166  	}
   167  
   168  	if m == 1 {
   169  		return
   170  	}
   171  
   172  	nextStage := stage + 1
   173  	if stage < maxSplits {
   174  		chDone := make(chan struct{}, 1)
   175  		go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone)
   176  		difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil)
   177  		<-chDone
   178  	} else {
   179  		difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil)
   180  		difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil)
   181  	}
   182  }