github.com/consensys/gnark-crypto@v0.14.0/internal/generator/kzg/template/utils.go.tmpl (about)

     1  import (
     2  	"math/big"
     3  	"math/bits"
     4  	"runtime"
     5  	"fmt"
     6  
     7  	"github.com/consensys/gnark-crypto/ecc"
     8  	curve "github.com/consensys/gnark-crypto/ecc/{{ .Name }}"
     9  	"github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr"
    10  	"github.com/consensys/gnark-crypto/internal/parallel"
    11  )
    12  
    13  // ToLagrangeG1 in place transform of coeffs canonical form into Lagrange form.
    14  // From the formula Lᵢ(τ) = 1/n∑_{j<n}(τ/ωⁱ)ʲ we
    15  // see that [L₁(τ),..,Lₙ(τ)] = FFT_inv(∑_{j<n}τʲXʲ), so it suffices to apply the inverse
    16  // fft on the vector consisting of the original SRS.
    17  // Size of coeffs must be a power of 2.
    18  func ToLagrangeG1(coeffs []curve.G1Affine) ([]curve.G1Affine, error) {
    19  	if bits.OnesCount64(uint64(len(coeffs))) != 1 {
    20  		return nil, fmt.Errorf("len(coeffs) must be a power of 2")
    21  	}
    22  	size := len(coeffs)
    23  
    24  	numCPU := uint64(runtime.NumCPU())
    25  	maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) << 1
    26  
    27  	twiddlesInv, err := computeTwiddlesInv(size)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	// batch convert to Jacobian
    33  	jCoeffs := make([]curve.G1Jac, len(coeffs))
    34  	for i := 0; i < len(coeffs); i++ {
    35  		jCoeffs[i].FromAffine(&coeffs[i])
    36  	}
    37  
    38  	difFFTG1(jCoeffs, twiddlesInv, 0, maxSplits, nil)
    39  
    40  	// TODO @gbotrel generify the cobra bitreverse function, benchmark it and use it everywhere
    41  	bitReverse(jCoeffs)
    42  
    43  	var invBigint big.Int
    44  	var frCardinality fr.Element
    45  	frCardinality.SetUint64(uint64(size))
    46  	frCardinality.Inverse(&frCardinality)
    47  	frCardinality.BigInt(&invBigint)
    48  
    49  	parallel.Execute(size, func(start, end int) {
    50  		for i := start; i < end; i++ {
    51  			jCoeffs[i].ScalarMultiplication(&jCoeffs[i], &invBigint)
    52  		}
    53  	})
    54  
    55  	// batch convert to affine
    56  	return curve.BatchJacobianToAffineG1(jCoeffs), nil
    57  }
    58  
    59  func computeTwiddlesInv(cardinality int) ([]*big.Int, error) {
    60  	generator, err := fr.Generator(uint64(cardinality))
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	// inverse the generator
    66  	generator.Inverse(&generator)
    67  
    68  	// nb fft stages
    69  	nbStages := uint64(bits.TrailingZeros64(uint64(cardinality)))
    70  
    71  	r := make([]*big.Int, 1+(1<<(nbStages-1)))
    72  
    73  	w := generator
    74  	r[0] = new(big.Int).SetUint64(1)
    75  	if len(r) == 1 {
    76  		return r, nil
    77  	}
    78  	r[1] = new(big.Int)
    79  	w.BigInt(r[1])
    80  	for j := 2; j < len(r); j++ {
    81  		w.Mul(&w, &generator)
    82  		r[j] = new(big.Int)
    83  		w.BigInt(r[j])
    84  	}
    85  
    86  	return r, nil
    87  }
    88  
    89  func bitReverse[T any](a []T) {
    90  	n := uint64(len(a))
    91  	nn := uint64(64 - bits.TrailingZeros64(n))
    92  
    93  	for i := uint64(0); i < n; i++ {
    94  		irev := bits.Reverse64(i) >> nn
    95  		if irev > i {
    96  			a[i], a[irev] = a[irev], a[i]
    97  		}
    98  	}
    99  }
   100  
   101  func butterflyG1(a *curve.G1Jac, b *curve.G1Jac) {
   102  	t := *a
   103  	a.AddAssign(b)
   104  	t.SubAssign(b)
   105  	b.Set(&t)
   106  }
   107  
   108  func difFFTG1(a []curve.G1Jac, twiddles []*big.Int, stage, maxSplits int, chDone chan struct{}) {
   109  	if chDone != nil {
   110  		defer close(chDone)
   111  	}
   112  
   113  	n := len(a)
   114  	if n == 1 {
   115  		return
   116  	}
   117  	m := n >> 1
   118  
   119  	butterflyG1(&a[0], &a[m])
   120  	// stage determines the stride
   121  	// if stage == 0, then we use 1, w, w**2, w**3, w**4, w**5, w**6, ...
   122  	// if stage == 1, then we use 1, w**2, w**4, w**6, ... that is, indexes 0, 2, 4, 6, ... of stage 0
   123  	// if stage == 2, then we use 1, w**4, w**8, w**12, ... that is indexes 0, 4, 8, 12, ... of stage 0
   124  	stride := 1 << stage
   125  
   126  	const butterflyThreshold = 8
   127  	if m >= butterflyThreshold {
   128  		// 1 << stage == estimated used CPUs
   129  		numCPU := runtime.NumCPU() / (1 << (stage))
   130  		parallel.Execute(m, func(start, end int) {
   131  			if start == 0 {
   132  				start = 1
   133  			}
   134  			j := start * stride
   135  			for i := start; i < end; i++ {
   136  				butterflyG1(&a[i], &a[i+m])
   137  				a[i+m].ScalarMultiplication(&a[i+m], twiddles[j])
   138  				j += stride
   139  			}
   140  		}, numCPU)
   141  	} else {
   142  		j := stride
   143  		for i := 1; i < m; i++ {
   144  			butterflyG1(&a[i], &a[i+m])
   145  			a[i+m].ScalarMultiplication(&a[i+m], twiddles[j])
   146  			j += stride
   147  		}
   148  	}
   149  
   150  	if m == 1 {
   151  		return
   152  	}
   153  
   154  	nextStage := stage + 1
   155  	if stage < maxSplits {
   156  		chDone := make(chan struct{}, 1)
   157  		go difFFTG1(a[m:n], twiddles, nextStage, maxSplits, chDone)
   158  		difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil)
   159  		<-chDone
   160  	} else {
   161  		difFFTG1(a[0:m], twiddles, nextStage, maxSplits, nil)
   162  		difFFTG1(a[m:n], twiddles, nextStage, maxSplits, nil)
   163  	}
   164  }