github.com/consensys/gnark@v0.11.0/test/unsafekzg/kzgsrs.go (about)

     1  // Package unsafekzg is a convenience package (to be used for test purposes only)
     2  // to generate and cache SRS for the kzg scheme (and indirectly for PlonK setup).
     3  //
     4  // Functions in this package are thread safe.
     5  package unsafekzg
     6  
     7  import (
     8  	"bufio"
     9  	"crypto/rand"
    10  	"fmt"
    11  	"math/big"
    12  	"os"
    13  	"path/filepath"
    14  	"regexp"
    15  	"sync"
    16  
    17  	"github.com/consensys/gnark-crypto/ecc"
    18  	"github.com/consensys/gnark-crypto/kzg"
    19  	"github.com/consensys/gnark/constraint"
    20  	"github.com/consensys/gnark/internal/utils"
    21  	"github.com/consensys/gnark/logger"
    22  
    23  	kzg_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/kzg"
    24  	kzg_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/kzg"
    25  	kzg_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/kzg"
    26  	kzg_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/kzg"
    27  	kzg_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/kzg"
    28  	kzg_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/kzg"
    29  	kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/kzg"
    30  
    31  	fft_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
    32  	fft_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft"
    33  	fft_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft"
    34  	fft_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft"
    35  	fft_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
    36  	fft_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft"
    37  	fft_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft"
    38  
    39  	fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
    40  	fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
    41  	fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
    42  	fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr"
    43  	fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr"
    44  	fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr"
    45  	fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
    46  
    47  	bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
    48  	bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
    49  	bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315"
    50  	bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317"
    51  	"github.com/consensys/gnark-crypto/ecc/bn254"
    52  	bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633"
    53  	bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761"
    54  )
    55  
    56  var (
    57  	cache           = make(map[string]cacheEntry)
    58  	reCacheKey      = regexp.MustCompile(`kzgsrs-(.*?)-\d+`)
    59  	memLock, fsLock sync.RWMutex
    60  )
    61  
    62  // NewSRS returns a pair of kzg.SRS; one in canonical form, the other in lagrange form.
    63  // Default options use a memory cache, see Option for more details & options.
    64  func NewSRS(ccs constraint.ConstraintSystem, opts ...Option) (canonical kzg.SRS, lagrange kzg.SRS, err error) {
    65  
    66  	nbConstraints := ccs.GetNbConstraints()
    67  	sizeSystem := nbConstraints + ccs.GetNbPublicVariables()
    68  
    69  	sizeLagrange := ecc.NextPowerOfTwo(uint64(sizeSystem))
    70  	sizeCanonical := sizeLagrange + 3
    71  
    72  	curveID := utils.FieldToCurve(ccs.Field())
    73  
    74  	log := logger.Logger().With().Str("package", "kzgsrs").Int("size", int(sizeCanonical)).Str("curve", curveID.String()).Logger()
    75  
    76  	cfg, err := options(opts...)
    77  	if err != nil {
    78  		return nil, nil, err
    79  	}
    80  
    81  	key := cacheKey(curveID, sizeCanonical)
    82  	log.Debug().Str("key", key).Msg("fetching SRS from mem cache")
    83  	memLock.RLock()
    84  	entry, ok := cache[key]
    85  	memLock.RUnlock()
    86  	if ok {
    87  		log.Debug().Msg("SRS found in mem cache")
    88  		return entry.canonical, entry.lagrange, nil
    89  	}
    90  	log.Debug().Msg("SRS not found in mem cache")
    91  
    92  	if cfg.fsCache {
    93  		log.Debug().Str("key", key).Str("cacheDir", cfg.cacheDir).Msg("fetching SRS from fs cache")
    94  		fsLock.RLock()
    95  		entry, err = fsRead(key, cfg.cacheDir)
    96  		fsLock.RUnlock()
    97  		if err == nil {
    98  			log.Debug().Str("key", key).Msg("SRS found in fs cache")
    99  			canonical, lagrange = entry.canonical, entry.lagrange
   100  			memLock.Lock()
   101  			cache[key] = cacheEntry{canonical, lagrange}
   102  			memLock.Unlock()
   103  			return
   104  		} else {
   105  			log.Debug().Str("key", key).Err(err).Msg("SRS not found in fs cache")
   106  		}
   107  	}
   108  
   109  	log.Debug().Msg("SRS not found in cache, generating")
   110  
   111  	// not in cache, generate
   112  	canonical, lagrange, err = newSRS(curveID, sizeCanonical)
   113  	if err != nil {
   114  		return nil, nil, err
   115  	}
   116  
   117  	// cache it
   118  	memLock.Lock()
   119  	cache[key] = cacheEntry{canonical, lagrange}
   120  	memLock.Unlock()
   121  
   122  	if cfg.fsCache {
   123  		log.Debug().Str("key", key).Str("cacheDir", cfg.cacheDir).Msg("writing SRS to fs cache")
   124  		fsLock.Lock()
   125  		fsWrite(key, cfg.cacheDir, canonical, lagrange)
   126  		fsLock.Unlock()
   127  	}
   128  
   129  	return canonical, lagrange, nil
   130  }
   131  
   132  type cacheEntry struct {
   133  	canonical kzg.SRS
   134  	lagrange  kzg.SRS
   135  }
   136  
   137  func cacheKey(curveID ecc.ID, size uint64) string {
   138  	return fmt.Sprintf("kzgsrs-%s-%d", curveID.String(), size)
   139  }
   140  
   141  func extractCurveID(key string) (ecc.ID, error) {
   142  	matches := reCacheKey.FindStringSubmatch(key)
   143  
   144  	if len(matches) < 2 {
   145  		return ecc.UNKNOWN, fmt.Errorf("no curveID found in key")
   146  	}
   147  	return ecc.IDFromString(matches[1])
   148  }
   149  
   150  func newSRS(curveID ecc.ID, size uint64) (kzg.SRS, kzg.SRS, error) {
   151  
   152  	tau, err := rand.Int(rand.Reader, curveID.ScalarField())
   153  	if err != nil {
   154  		return nil, nil, err
   155  	}
   156  
   157  	var srs kzg.SRS
   158  
   159  	switch curveID {
   160  	case ecc.BN254:
   161  		srs, err = kzg_bn254.NewSRS(size, tau)
   162  	case ecc.BLS12_381:
   163  		srs, err = kzg_bls12381.NewSRS(size, tau)
   164  	case ecc.BLS12_377:
   165  		srs, err = kzg_bls12377.NewSRS(size, tau)
   166  	case ecc.BW6_761:
   167  		srs, err = kzg_bw6761.NewSRS(size, tau)
   168  	case ecc.BLS24_317:
   169  		srs, err = kzg_bls24317.NewSRS(size, tau)
   170  	case ecc.BLS24_315:
   171  		srs, err = kzg_bls24315.NewSRS(size, tau)
   172  	case ecc.BW6_633:
   173  		srs, err = kzg_bw6633.NewSRS(size, tau)
   174  	default:
   175  		panic("unrecognized R1CS curve type")
   176  	}
   177  
   178  	if err != nil {
   179  		return nil, nil, err
   180  	}
   181  
   182  	return srs, toLagrange(srs, tau), nil
   183  }
   184  
   185  func toLagrange(canonicalSRS kzg.SRS, tau *big.Int) kzg.SRS {
   186  
   187  	var lagrangeSRS kzg.SRS
   188  
   189  	switch srs := canonicalSRS.(type) {
   190  	case *kzg_bn254.SRS:
   191  		newSRS := &kzg_bn254.SRS{Vk: srs.Vk}
   192  		size := uint64(len(srs.Pk.G1)) - 3
   193  
   194  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   195  		// since we know the randomness in test.
   196  		pAlpha := make([]fr_bn254.Element, size)
   197  		pAlpha[0].SetUint64(1)
   198  		pAlpha[1].SetBigInt(tau)
   199  		for i := 2; i < len(pAlpha); i++ {
   200  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   201  		}
   202  		// do a fft on this.
   203  		d := fft_bn254.NewDomain(size)
   204  		d.FFTInverse(pAlpha, fft_bn254.DIF)
   205  		fft_bn254.BitReverse(pAlpha)
   206  
   207  		// bath scalar mul
   208  		_, _, g1gen, _ := bn254.Generators()
   209  		newSRS.Pk.G1 = bn254.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   210  
   211  		lagrangeSRS = newSRS
   212  	case *kzg_bls12381.SRS:
   213  		newSRS := &kzg_bls12381.SRS{Vk: srs.Vk}
   214  		size := uint64(len(srs.Pk.G1)) - 3
   215  
   216  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   217  		// since we know the randomness in test.
   218  		pAlpha := make([]fr_bls12381.Element, size)
   219  		pAlpha[0].SetUint64(1)
   220  		pAlpha[1].SetBigInt(tau)
   221  		for i := 2; i < len(pAlpha); i++ {
   222  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   223  		}
   224  		// do a fft on this.
   225  		d := fft_bls12381.NewDomain(size)
   226  		d.FFTInverse(pAlpha, fft_bls12381.DIF)
   227  		fft_bls12381.BitReverse(pAlpha)
   228  
   229  		// bath scalar mul
   230  		_, _, g1gen, _ := bls12381.Generators()
   231  		newSRS.Pk.G1 = bls12381.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   232  
   233  		lagrangeSRS = newSRS
   234  	case *kzg_bls12377.SRS:
   235  		newSRS := &kzg_bls12377.SRS{Vk: srs.Vk}
   236  		size := uint64(len(srs.Pk.G1)) - 3
   237  
   238  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   239  		// since we know the randomness in test.
   240  		pAlpha := make([]fr_bls12377.Element, size)
   241  		pAlpha[0].SetUint64(1)
   242  		pAlpha[1].SetBigInt(tau)
   243  		for i := 2; i < len(pAlpha); i++ {
   244  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   245  		}
   246  		// do a fft on this.
   247  		d := fft_bls12377.NewDomain(size)
   248  		d.FFTInverse(pAlpha, fft_bls12377.DIF)
   249  		fft_bls12377.BitReverse(pAlpha)
   250  
   251  		// bath scalar mul
   252  		_, _, g1gen, _ := bls12377.Generators()
   253  		newSRS.Pk.G1 = bls12377.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   254  
   255  		lagrangeSRS = newSRS
   256  	case *kzg_bw6761.SRS:
   257  		newSRS := &kzg_bw6761.SRS{Vk: srs.Vk}
   258  		size := uint64(len(srs.Pk.G1)) - 3
   259  
   260  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   261  		// since we know the randomness in test.
   262  		pAlpha := make([]fr_bw6761.Element, size)
   263  		pAlpha[0].SetUint64(1)
   264  		pAlpha[1].SetBigInt(tau)
   265  		for i := 2; i < len(pAlpha); i++ {
   266  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   267  		}
   268  
   269  		// do a fft on this.
   270  		d := fft_bw6761.NewDomain(size)
   271  		d.FFTInverse(pAlpha, fft_bw6761.DIF)
   272  		fft_bw6761.BitReverse(pAlpha)
   273  
   274  		// bath scalar mul
   275  		_, _, g1gen, _ := bw6761.Generators()
   276  		newSRS.Pk.G1 = bw6761.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   277  
   278  		lagrangeSRS = newSRS
   279  	case *kzg_bls24317.SRS:
   280  		newSRS := &kzg_bls24317.SRS{Vk: srs.Vk}
   281  		size := uint64(len(srs.Pk.G1)) - 3
   282  
   283  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   284  		// since we know the randomness in test.
   285  		pAlpha := make([]fr_bls24317.Element, size)
   286  		pAlpha[0].SetUint64(1)
   287  		pAlpha[1].SetBigInt(tau)
   288  		for i := 2; i < len(pAlpha); i++ {
   289  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   290  		}
   291  
   292  		// do a fft on this.
   293  		d := fft_bls24317.NewDomain(size)
   294  		d.FFTInverse(pAlpha, fft_bls24317.DIF)
   295  		fft_bls24317.BitReverse(pAlpha)
   296  
   297  		// bath scalar mul
   298  		_, _, g1gen, _ := bls24317.Generators()
   299  		newSRS.Pk.G1 = bls24317.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   300  
   301  		lagrangeSRS = newSRS
   302  	case *kzg_bls24315.SRS:
   303  		newSRS := &kzg_bls24315.SRS{Vk: srs.Vk}
   304  		size := uint64(len(srs.Pk.G1)) - 3
   305  
   306  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   307  		// since we know the randomness in test.
   308  		pAlpha := make([]fr_bls24315.Element, size)
   309  		pAlpha[0].SetUint64(1)
   310  		pAlpha[1].SetBigInt(tau)
   311  		for i := 2; i < len(pAlpha); i++ {
   312  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   313  		}
   314  
   315  		// do a fft on this.
   316  		d := fft_bls24315.NewDomain(size)
   317  		d.FFTInverse(pAlpha, fft_bls24315.DIF)
   318  		fft_bls24315.BitReverse(pAlpha)
   319  
   320  		// bath scalar mul
   321  		_, _, g1gen, _ := bls24315.Generators()
   322  		newSRS.Pk.G1 = bls24315.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   323  
   324  		lagrangeSRS = newSRS
   325  	case *kzg_bw6633.SRS:
   326  		newSRS := &kzg_bw6633.SRS{Vk: srs.Vk}
   327  		size := uint64(len(srs.Pk.G1)) - 3
   328  
   329  		// instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha
   330  		// since we know the randomness in test.
   331  		pAlpha := make([]fr_bw6633.Element, size)
   332  		pAlpha[0].SetUint64(1)
   333  		pAlpha[1].SetBigInt(tau)
   334  		for i := 2; i < len(pAlpha); i++ {
   335  			pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1])
   336  		}
   337  
   338  		// do a fft on this.
   339  		d := fft_bw6633.NewDomain(size)
   340  		d.FFTInverse(pAlpha, fft_bw6633.DIF)
   341  		fft_bw6633.BitReverse(pAlpha)
   342  
   343  		// bath scalar mul
   344  		_, _, g1gen, _ := bw6633.Generators()
   345  		newSRS.Pk.G1 = bw6633.BatchScalarMultiplicationG1(&g1gen, pAlpha)
   346  
   347  		lagrangeSRS = newSRS
   348  	default:
   349  		panic("unrecognized curve")
   350  	}
   351  
   352  	return lagrangeSRS
   353  }
   354  
   355  func fsRead(key string, cacheDir string) (cacheEntry, error) {
   356  	filePath := filepath.Join(cacheDir, key)
   357  
   358  	// if file does not exist, return false
   359  	if _, err := os.Stat(filePath); os.IsNotExist(err) {
   360  		return cacheEntry{}, fmt.Errorf("file %s does not exist", filePath)
   361  	}
   362  
   363  	// else open file and read the srs.
   364  	f, err := os.Open(filePath)
   365  	if err != nil {
   366  		return cacheEntry{}, err
   367  	}
   368  	defer f.Close()
   369  
   370  	r := bufio.NewReaderSize(f, 1<<20)
   371  
   372  	curveID, err := extractCurveID(key)
   373  	if err != nil {
   374  		return cacheEntry{}, err
   375  	}
   376  	cacheEntry := cacheEntry{
   377  		canonical: kzg.NewSRS(curveID),
   378  		lagrange:  kzg.NewSRS(curveID),
   379  	}
   380  	_, err = cacheEntry.canonical.UnsafeReadFrom(r)
   381  	if err != nil {
   382  		return cacheEntry, err
   383  	}
   384  	_, err = cacheEntry.lagrange.UnsafeReadFrom(r)
   385  	if err != nil {
   386  		return cacheEntry, err
   387  	}
   388  
   389  	return cacheEntry, nil
   390  }
   391  
   392  func fsWrite(key string, cacheDir string, canonical kzg.SRS, lagrange kzg.SRS) {
   393  	// if file exist, return.
   394  	filePath := filepath.Join(cacheDir, key)
   395  	if _, err := os.Stat(filePath); err == nil {
   396  		return
   397  	}
   398  
   399  	// else open file and write the srs.
   400  	f, err := os.Create(filePath)
   401  	if err != nil {
   402  		return
   403  	}
   404  	defer f.Close()
   405  
   406  	w := bufio.NewWriterSize(f, 1<<20)
   407  
   408  	if _, err = canonical.WriteRawTo(w); err != nil {
   409  		return
   410  	}
   411  
   412  	if _, err = lagrange.WriteRawTo(w); err != nil {
   413  		return
   414  	}
   415  
   416  	w.Flush()
   417  }