github.com/consensys/gnark-crypto@v0.14.0/internal/generator/ecc/generate.go (about)

     1  package ecc
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path/filepath"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  	"text/template"
    11  
    12  	"github.com/consensys/bavard"
    13  	"github.com/consensys/gnark-crypto/internal/generator/config"
    14  )
    15  
    16  func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error {
    17  
    18  	packageName := strings.ReplaceAll(conf.Name, "-", "")
    19  
    20  	var entries []bavard.Entry
    21  
    22  	// hash To curve
    23  	genHashToCurve := func(point *config.Point, suite config.HashSuite) error {
    24  		if suite == nil { //Nothing to generate. Bypass
    25  			return nil
    26  		}
    27  
    28  		entries = []bavard.Entry{
    29  			{File: filepath.Join(baseDir, fmt.Sprintf("hash_to_%s.go", point.PointName)), Templates: []string{"hash_to_curve.go.tmpl", "sswu.go.tmpl", "svdw.go.tmpl"}},
    30  			{File: filepath.Join(baseDir, fmt.Sprintf("hash_to_%s_test.go", point.PointName)), Templates: []string{"tests/hash_to_curve.go.tmpl"}}}
    31  
    32  		hashConf := suite.GetInfo(conf.Fp, point, conf.Name)
    33  
    34  		funcs := make(template.FuncMap)
    35  		funcs["asElement"] = hashConf.Field.Base.WriteElement
    36  		bavardOpts := []func(*bavard.Bavard) error{bavard.Funcs(funcs)}
    37  
    38  		return bgen.GenerateWithOptions(hashConf, packageName, "./ecc/template", bavardOpts, entries...)
    39  	}
    40  
    41  	if err := genHashToCurve(&conf.G1, conf.HashE1); err != nil {
    42  		return err
    43  	}
    44  	if err := genHashToCurve(&conf.G2, conf.HashE2); err != nil {
    45  		return err
    46  	}
    47  
    48  	// G1
    49  	entries = []bavard.Entry{
    50  		{File: filepath.Join(baseDir, "g1.go"), Templates: []string{"point.go.tmpl"}},
    51  		{File: filepath.Join(baseDir, "g1_test.go"), Templates: []string{"tests/point.go.tmpl"}},
    52  	}
    53  	// if not secp256k1, generate the lagrange transform
    54  	if conf.Name != config.SECP256K1.Name {
    55  		os.Remove(filepath.Join(baseDir, "g1_lagrange.go"))
    56  		os.Remove(filepath.Join(baseDir, "g1_lagrange_test.go"))
    57  		// entries = append(entries, bavard.Entry{File: filepath.Join(baseDir, "g1_lagrange.go"), Templates: []string{"lagrange.go.tmpl"}})
    58  		// entries = append(entries, bavard.Entry{File: filepath.Join(baseDir, "g1_lagrange_test.go"), Templates: []string{"tests/lagrange.go.tmpl"}})
    59  	}
    60  
    61  	g1 := pconf{conf, conf.G1}
    62  	if err := bgen.Generate(g1, packageName, "./ecc/template", entries...); err != nil {
    63  		return err
    64  	}
    65  
    66  	// MSM
    67  	entries = []bavard.Entry{
    68  		{File: filepath.Join(baseDir, "multiexp.go"), Templates: []string{"multiexp.go.tmpl"}},
    69  		{File: filepath.Join(baseDir, "multiexp_affine.go"), Templates: []string{"multiexp_affine.go.tmpl"}},
    70  		{File: filepath.Join(baseDir, "multiexp_jacobian.go"), Templates: []string{"multiexp_jacobian.go.tmpl"}},
    71  		{File: filepath.Join(baseDir, "multiexp_test.go"), Templates: []string{"tests/multiexp.go.tmpl"}},
    72  	}
    73  	conf.Package = packageName
    74  	funcs := make(template.FuncMap)
    75  	funcs["last"] = func(x int, a interface{}) bool {
    76  		return x == reflect.ValueOf(a).Len()-1
    77  	}
    78  
    79  	// return the last window size for a scalar;
    80  	// this last window should accommodate a carry (from the NAF decomposition)
    81  	// it can be == c if we have 1 available bit
    82  	// it can be > c if we have 0 available bit
    83  	// it can be < c if we have 2+ available bits
    84  	lastC := func(c int) int {
    85  		nbChunks := (conf.Fr.NbBits + c - 1) / c
    86  		nbAvailableBits := (nbChunks * c) - conf.Fr.NbBits
    87  		lc := c + 1 - nbAvailableBits
    88  		if lc > 16 {
    89  			panic("we have a problem since we are using uint16 to store digits")
    90  		}
    91  		return lc
    92  	}
    93  	batchSize := func(c int) int {
    94  		// nbBuckets := (1 << (c - 1))
    95  		// if c <= 12 {
    96  		// 	return nbBuckets/10 + 3*c
    97  		// }
    98  		// if c <= 14 {
    99  		// 	return nbBuckets/15
   100  		// }
   101  		// return nbBuckets / 20
   102  		// TODO @gbotrel / @yelhousni this need a better heuristic
   103  		// in theory, larger batch size == less inversions
   104  		// but if nbBuckets is small, then a large batch size will produce lots of collisions
   105  		// and queue ops.
   106  		// there is probably a cache-friendliness factor at play here too.
   107  		switch c {
   108  		case 10:
   109  			return 80
   110  		case 11:
   111  			return 150
   112  		case 12:
   113  			return 200
   114  		case 13:
   115  			return 350
   116  		case 14:
   117  			return 400
   118  		case 15:
   119  			return 500
   120  		default:
   121  			return 640
   122  		}
   123  	}
   124  	funcs["lastC"] = lastC
   125  	funcs["batchSize"] = batchSize
   126  
   127  	funcs["nbBuckets"] = func(c int) int {
   128  		return 1 << (c - 1)
   129  	}
   130  
   131  	funcs["contains"] = func(v int, s []int) bool {
   132  		for _, sv := range s {
   133  			if v == sv {
   134  				return true
   135  			}
   136  		}
   137  		return false
   138  	}
   139  	lastCG1 := make([]int, 0)
   140  	for {
   141  		for i := 0; i < len(conf.G1.CRange); i++ {
   142  			lc := lastC(conf.G1.CRange[i])
   143  			if !contains(conf.G1.CRange, lc) && !contains(lastCG1, lc) {
   144  				lastCG1 = append(lastCG1, lc)
   145  			}
   146  		}
   147  		if len(lastCG1) == 0 {
   148  			break
   149  		}
   150  		conf.G1.CRange = append(conf.G1.CRange, lastCG1...)
   151  		sort.Ints(conf.G1.CRange)
   152  		lastCG1 = lastCG1[:0]
   153  	}
   154  
   155  	lastCG2 := make([]int, 0)
   156  	for {
   157  		for i := 0; i < len(conf.G2.CRange); i++ {
   158  			lc := lastC(conf.G2.CRange[i])
   159  			if !contains(conf.G2.CRange, lc) && !contains(lastCG2, lc) {
   160  				lastCG2 = append(lastCG2, lc)
   161  			}
   162  		}
   163  		if len(lastCG2) == 0 {
   164  			break
   165  		}
   166  		conf.G2.CRange = append(conf.G2.CRange, lastCG2...)
   167  		sort.Ints(conf.G2.CRange)
   168  		lastCG2 = lastCG2[:0]
   169  	}
   170  
   171  	bavardOpts := []func(*bavard.Bavard) error{bavard.Funcs(funcs)}
   172  	if err := bgen.GenerateWithOptions(conf, packageName, "./ecc/template", bavardOpts, entries...); err != nil {
   173  		return err
   174  	}
   175  
   176  	// No G2 for secp256k1
   177  	if conf.Equal(config.SECP256K1) {
   178  		return nil
   179  	}
   180  
   181  	// marshal
   182  	entries = []bavard.Entry{
   183  		{File: filepath.Join(baseDir, "marshal.go"), Templates: []string{"marshal.go.tmpl"}},
   184  		{File: filepath.Join(baseDir, "marshal_test.go"), Templates: []string{"tests/marshal.go.tmpl"}},
   185  	}
   186  
   187  	marshal := []func(*bavard.Bavard) error{bavard.Funcs(funcs)}
   188  	if err := bgen.GenerateWithOptions(conf, packageName, "./ecc/template", marshal, entries...); err != nil {
   189  		return err
   190  	}
   191  
   192  	// G2
   193  	entries = []bavard.Entry{
   194  		{File: filepath.Join(baseDir, "g2.go"), Templates: []string{"point.go.tmpl"}},
   195  		{File: filepath.Join(baseDir, "g2_test.go"), Templates: []string{"tests/point.go.tmpl"}},
   196  	}
   197  	g2 := pconf{conf, conf.G2}
   198  	return bgen.Generate(g2, packageName, "./ecc/template", entries...)
   199  }
   200  
   201  type pconf struct {
   202  	config.Curve
   203  	config.Point
   204  }
   205  
   206  func contains(slice []int, v int) bool {
   207  	for i := 0; i < len(slice); i++ {
   208  		if slice[i] == v {
   209  			return true
   210  		}
   211  	}
   212  	return false
   213  }