github.com/consensys/gnark-crypto@v0.14.0/field/generator/config/field_config.go (about)

     1  // Copyright 2020 ConsenSys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package config provides Golang code generation for efficient field arithmetic operations.
    16  package config
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  	"math"
    22  	"math/big"
    23  	"strconv"
    24  	"strings"
    25  
    26  	"github.com/consensys/bavard"
    27  	"github.com/consensys/gnark-crypto/field/generator/internal/addchain"
    28  )
    29  
    30  var (
    31  	errParseModulus = errors.New("can't parse modulus")
    32  )
    33  
    34  // FieldConfig precomputed values used in template for code generation of field element APIs
    35  type FieldConfig struct {
    36  	PackageName               string
    37  	ElementName               string
    38  	ModulusBig                *big.Int
    39  	Modulus                   string
    40  	ModulusHex                string
    41  	NbWords                   int
    42  	NbBits                    int
    43  	NbBytes                   int
    44  	NbWordsLastIndex          int
    45  	NbWordsIndexesNoZero      []int
    46  	NbWordsIndexesFull        []int
    47  	P20InversionCorrectiveFac []uint64
    48  	P20InversionNbIterations  int
    49  	UsingP20Inverse           bool
    50  	IsMSWSaturated            bool // indicates if the most significant word is 0xFFFFF...FFFF
    51  	Q                         []uint64
    52  	QInverse                  []uint64
    53  	QMinusOneHalvedP          []uint64 // ((q-1) / 2 ) + 1
    54  	ASM                       bool
    55  	RSquare                   []uint64
    56  	One, Thirteen             []uint64
    57  	LegendreExponent          string // big.Int to base16 string
    58  	NoCarry                   bool
    59  	NoCarrySquare             bool // used if NoCarry is set, but some op may overflow in square optimization
    60  	SqrtQ3Mod4                bool
    61  	SqrtAtkin                 bool
    62  	SqrtTonelliShanks         bool
    63  	SqrtE                     uint64
    64  	SqrtS                     []uint64
    65  	SqrtAtkinExponent         string   // big.Int to base16 string
    66  	SqrtSMinusOneOver2        string   // big.Int to base16 string
    67  	SqrtQ3Mod4Exponent        string   // big.Int to base16 string
    68  	SqrtG                     []uint64 // NonResidue ^  SqrtR (montgomery form)
    69  	NonResidue                big.Int  // (montgomery form)
    70  	LegendreExponentData      *addchain.AddChainData
    71  	SqrtAtkinExponentData     *addchain.AddChainData
    72  	SqrtSMinusOneOver2Data    *addchain.AddChainData
    73  	SqrtQ3Mod4ExponentData    *addchain.AddChainData
    74  	UseAddChain               bool
    75  }
    76  
    77  // NewFieldConfig returns a data structure with needed information to generate apis for field element
    78  //
    79  // See field/generator package
    80  func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) (*FieldConfig, error) {
    81  	// parse modulus
    82  	var bModulus big.Int
    83  	if _, ok := bModulus.SetString(modulus, 0); !ok {
    84  		return nil, errParseModulus
    85  	}
    86  
    87  	// field info
    88  	F := &FieldConfig{
    89  		PackageName: packageName,
    90  		ElementName: elementName,
    91  		Modulus:     bModulus.Text(10),
    92  		ModulusHex:  bModulus.Text(16),
    93  		ModulusBig:  new(big.Int).Set(&bModulus),
    94  		UseAddChain: useAddChain,
    95  	}
    96  	// pre compute field constants
    97  	F.NbBits = bModulus.BitLen()
    98  	F.NbWords = len(bModulus.Bits())
    99  	F.NbBytes = F.NbWords * 8 // (F.NbBits + 7) / 8
   100  
   101  	F.NbWordsLastIndex = F.NbWords - 1
   102  
   103  	// set q from big int repr
   104  	F.Q = toUint64Slice(&bModulus)
   105  	F.IsMSWSaturated = F.Q[len(F.Q)-1] == math.MaxUint64
   106  	_qHalved := big.NewInt(0)
   107  	bOne := new(big.Int).SetUint64(1)
   108  	_qHalved.Sub(&bModulus, bOne).Rsh(_qHalved, 1).Add(_qHalved, bOne)
   109  	F.QMinusOneHalvedP = toUint64Slice(_qHalved, F.NbWords)
   110  
   111  	//  setting qInverse
   112  	_r := big.NewInt(1)
   113  	_r.Lsh(_r, uint(F.NbWords)*64)
   114  	_rInv := big.NewInt(1)
   115  	_qInv := big.NewInt(0)
   116  	extendedEuclideanAlgo(_r, &bModulus, _rInv, _qInv)
   117  	_qInv.Mod(_qInv, _r)
   118  	F.QInverse = toUint64Slice(_qInv, F.NbWords)
   119  
   120  	// Pornin20 inversion correction factors
   121  	k := 32 // Optimized for 64 bit machines, still works for 32
   122  
   123  	p20InvInnerLoopNbIterations := 2*F.NbBits - 1
   124  	// if constant time inversion then p20InvInnerLoopNbIterations-- (among other changes)
   125  	F.P20InversionNbIterations = (p20InvInnerLoopNbIterations-1)/(k-1) + 1 // ⌈ (2 * field size - 1) / (k-1) ⌉
   126  	F.P20InversionNbIterations += F.P20InversionNbIterations % 2           // "round up" to a multiple of 2
   127  
   128  	kLimbs := k * F.NbWords
   129  	p20InversionCorrectiveFacPower := kLimbs*6 + F.P20InversionNbIterations*(kLimbs-k+1)
   130  	p20InversionCorrectiveFac := big.NewInt(1)
   131  	p20InversionCorrectiveFac.Lsh(p20InversionCorrectiveFac, uint(p20InversionCorrectiveFacPower))
   132  	p20InversionCorrectiveFac.Mod(p20InversionCorrectiveFac, &bModulus)
   133  	F.P20InversionCorrectiveFac = toUint64Slice(p20InversionCorrectiveFac, F.NbWords)
   134  
   135  	{
   136  		c := F.NbWords * 64
   137  		F.UsingP20Inverse = F.NbWords > 1 && F.NbBits < c
   138  	}
   139  
   140  	// rsquare
   141  	_rSquare := big.NewInt(2)
   142  	exponent := big.NewInt(int64(F.NbWords) * 64 * 2)
   143  	_rSquare.Exp(_rSquare, exponent, &bModulus)
   144  	F.RSquare = toUint64Slice(_rSquare, F.NbWords)
   145  
   146  	var one big.Int
   147  	one.SetUint64(1)
   148  	one.Lsh(&one, uint(F.NbWords)*64).Mod(&one, &bModulus)
   149  	F.One = toUint64Slice(&one, F.NbWords)
   150  
   151  	{
   152  		var n big.Int
   153  		n.SetUint64(13)
   154  		n.Lsh(&n, uint(F.NbWords)*64).Mod(&n, &bModulus)
   155  		F.Thirteen = toUint64Slice(&n, F.NbWords)
   156  	}
   157  
   158  	// indexes (template helpers)
   159  	F.NbWordsIndexesFull = make([]int, F.NbWords)
   160  	F.NbWordsIndexesNoZero = make([]int, F.NbWords-1)
   161  	for i := 0; i < F.NbWords; i++ {
   162  		F.NbWordsIndexesFull[i] = i
   163  		if i > 0 {
   164  			F.NbWordsIndexesNoZero[i-1] = i
   165  		}
   166  	}
   167  
   168  	// See https://hackmd.io/@gnark/modular_multiplication
   169  	// if the last word of the modulus is smaller or equal to B,
   170  	// we can simplify the montgomery multiplication
   171  	const B = (^uint64(0) >> 1) - 1
   172  	F.NoCarry = (F.Q[len(F.Q)-1] <= B) && F.NbWords <= 12
   173  	const BSquare = ^uint64(0) >> 2
   174  	F.NoCarrySquare = F.Q[len(F.Q)-1] <= BSquare
   175  
   176  	// Legendre exponent (p-1)/2
   177  	var legendreExponent big.Int
   178  	legendreExponent.SetUint64(1)
   179  	legendreExponent.Sub(&bModulus, &legendreExponent)
   180  	legendreExponent.Rsh(&legendreExponent, 1)
   181  	F.LegendreExponent = legendreExponent.Text(16)
   182  	if F.UseAddChain {
   183  		F.LegendreExponentData = addchain.GetAddChain(&legendreExponent)
   184  	}
   185  
   186  	// Sqrt pre computes
   187  	var qMod big.Int
   188  	qMod.SetUint64(4)
   189  	if qMod.Mod(&bModulus, &qMod).Cmp(new(big.Int).SetUint64(3)) == 0 {
   190  		// q ≡ 3 (mod 4)
   191  		// using  z ≡ ± x^((p+1)/4) (mod q)
   192  		F.SqrtQ3Mod4 = true
   193  		var sqrtExponent big.Int
   194  		sqrtExponent.SetUint64(1)
   195  		sqrtExponent.Add(&bModulus, &sqrtExponent)
   196  		sqrtExponent.Rsh(&sqrtExponent, 2)
   197  		F.SqrtQ3Mod4Exponent = sqrtExponent.Text(16)
   198  
   199  		// add chain stuff
   200  		if F.UseAddChain {
   201  			F.SqrtQ3Mod4ExponentData = addchain.GetAddChain(&sqrtExponent)
   202  		}
   203  
   204  	} else {
   205  		// q ≡ 1 (mod 4)
   206  		qMod.SetUint64(8)
   207  		if qMod.Mod(&bModulus, &qMod).Cmp(new(big.Int).SetUint64(5)) == 0 {
   208  			// q ≡ 5 (mod 8)
   209  			// use Atkin's algorithm
   210  			// see modSqrt5Mod8Prime in math/big/int.go
   211  			F.SqrtAtkin = true
   212  			e := new(big.Int).Rsh(&bModulus, 3) // e = (q - 5) / 8
   213  			F.SqrtAtkinExponent = e.Text(16)
   214  			if F.UseAddChain {
   215  				F.SqrtAtkinExponentData = addchain.GetAddChain(e)
   216  			}
   217  		} else {
   218  			// use Tonelli-Shanks
   219  			F.SqrtTonelliShanks = true
   220  
   221  			// Write q-1 =2ᵉ * s , s odd
   222  			var s big.Int
   223  			one.SetUint64(1)
   224  			s.Sub(&bModulus, &one)
   225  
   226  			e := s.TrailingZeroBits()
   227  			s.Rsh(&s, e)
   228  			F.SqrtE = uint64(e)
   229  			F.SqrtS = toUint64Slice(&s)
   230  
   231  			// find non residue
   232  			var nonResidue big.Int
   233  			nonResidue.SetInt64(2)
   234  			one.SetUint64(1)
   235  			for big.Jacobi(&nonResidue, &bModulus) != -1 {
   236  				nonResidue.Add(&nonResidue, &one)
   237  			}
   238  
   239  			// g = nonresidue ^ s
   240  			var g big.Int
   241  			g.Exp(&nonResidue, &s, &bModulus)
   242  			// store g in montgomery form
   243  			g.Lsh(&g, uint(F.NbWords)*64).Mod(&g, &bModulus)
   244  			F.SqrtG = toUint64Slice(&g, F.NbWords)
   245  
   246  			// store non residue in montgomery form
   247  			F.NonResidue = F.ToMont(nonResidue)
   248  
   249  			// (s+1) /2
   250  			s.Sub(&s, &one).Rsh(&s, 1)
   251  			F.SqrtSMinusOneOver2 = s.Text(16)
   252  
   253  			if F.UseAddChain {
   254  				F.SqrtSMinusOneOver2Data = addchain.GetAddChain(&s)
   255  			}
   256  		}
   257  	}
   258  
   259  	// note: to simplify output files generated, we generated ASM code only for
   260  	// moduli that meet the condition F.NoCarry
   261  	// asm code generation for moduli with more than 6 words can be optimized further
   262  	F.ASM = F.NoCarry && F.NbWords <= 12 && F.NbWords > 1
   263  
   264  	return F, nil
   265  }
   266  
   267  func toUint64Slice(b *big.Int, nbWords ...int) (s []uint64) {
   268  	if len(nbWords) > 0 && nbWords[0] > len(b.Bits()) {
   269  		s = make([]uint64, nbWords[0])
   270  	} else {
   271  		s = make([]uint64, len(b.Bits()))
   272  	}
   273  
   274  	for i, v := range b.Bits() {
   275  		s[i] = (uint64)(v)
   276  	}
   277  	return
   278  }
   279  
   280  // https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
   281  // r > q, modifies rinv and qinv such that rinv.r - qinv.q = 1
   282  func extendedEuclideanAlgo(r, q, rInv, qInv *big.Int) {
   283  	var s1, s2, t1, t2, qi, tmpMuls, riPlusOne, tmpMult, a, b big.Int
   284  	t1.SetUint64(1)
   285  	rInv.Set(big.NewInt(1))
   286  	qInv.Set(big.NewInt(0))
   287  	a.Set(r)
   288  	b.Set(q)
   289  
   290  	// r_i+1 = r_i-1 - q_i.r_i
   291  	// s_i+1 = s_i-1 - q_i.s_i
   292  	// t_i+1 = t_i-1 - q_i.s_i
   293  	for b.Sign() > 0 {
   294  		qi.Div(&a, &b)
   295  		riPlusOne.Mod(&a, &b)
   296  
   297  		tmpMuls.Mul(&s1, &qi)
   298  		tmpMult.Mul(&t1, &qi)
   299  
   300  		s2.Set(&s1)
   301  		t2.Set(&t1)
   302  
   303  		s1.Sub(rInv, &tmpMuls)
   304  		t1.Sub(qInv, &tmpMult)
   305  		rInv.Set(&s2)
   306  		qInv.Set(&t2)
   307  
   308  		a.Set(&b)
   309  		b.Set(&riPlusOne)
   310  	}
   311  	qInv.Neg(qInv)
   312  }
   313  
   314  // StringToMont takes an element written in string form, and returns it in Montgomery form
   315  // Useful for hard-coding in implementation field elements from standards documents
   316  func (f *FieldConfig) StringToMont(str string) big.Int {
   317  
   318  	var i big.Int
   319  	i.SetString(str, 0)
   320  	i = f.ToMont(i)
   321  
   322  	return i
   323  }
   324  
   325  func (f *FieldConfig) ToMont(nonMont big.Int) big.Int {
   326  	var mont big.Int
   327  	mont.Lsh(&nonMont, uint(f.NbWords)*64)
   328  	mont.Mod(&mont, f.ModulusBig)
   329  	return mont
   330  }
   331  
   332  func (f *FieldConfig) FromMont(nonMont *big.Int, mont *big.Int) *FieldConfig {
   333  
   334  	if f.NbWords == 0 {
   335  		nonMont.SetInt64(0)
   336  		return f
   337  	}
   338  	f.halve(nonMont, mont)
   339  	for i := 1; i < f.NbWords*64; i++ {
   340  		f.halve(nonMont, nonMont)
   341  	}
   342  
   343  	return f
   344  }
   345  
   346  func (f *FieldConfig) Exp(res *big.Int, x *big.Int, pow *big.Int) *FieldConfig {
   347  	res.SetInt64(1)
   348  
   349  	for i := pow.BitLen() - 1; ; {
   350  
   351  		if pow.Bit(i) == 1 {
   352  			res.Mul(res, x)
   353  		}
   354  
   355  		if i == 0 {
   356  			break
   357  		}
   358  		i--
   359  
   360  		res.Mul(res, res).Mod(res, f.ModulusBig)
   361  	}
   362  
   363  	res.Mod(res, f.ModulusBig)
   364  	return f
   365  }
   366  
   367  func (f *FieldConfig) halve(res *big.Int, x *big.Int) {
   368  	var z big.Int
   369  	if x.Bit(0) == 0 {
   370  		z.Set(x)
   371  	} else {
   372  		z.Add(x, f.ModulusBig)
   373  	}
   374  	res.Rsh(&z, 1)
   375  }
   376  
   377  func (f *FieldConfig) Mul(z *big.Int, x *big.Int, y *big.Int) *FieldConfig {
   378  	z.Mul(x, y).Mod(z, f.ModulusBig)
   379  	return f
   380  }
   381  
   382  func (f *FieldConfig) Add(z *big.Int, x *big.Int, y *big.Int) *FieldConfig {
   383  	z.Add(x, y).Mod(z, f.ModulusBig)
   384  	return f
   385  }
   386  
   387  func (f *FieldConfig) ToMontSlice(x []big.Int) []big.Int {
   388  	z := make(Element, len(x))
   389  	for i := 0; i < len(x); i++ {
   390  		z[i] = f.ToMont(x[i])
   391  	}
   392  	return z
   393  }
   394  
   395  // TODO: Spaghetti Alert: Okay to have codegen functions here?
   396  func CoordNameForExtensionDegree(degree uint8) string {
   397  	switch degree {
   398  	case 1:
   399  		return ""
   400  	case 2:
   401  		return "A"
   402  	case 6:
   403  		return "B"
   404  	case 12:
   405  		return "C"
   406  	}
   407  	panic(fmt.Sprint("unknown extension degree", degree))
   408  }
   409  
   410  func (f *FieldConfig) WriteElement(element Element) string {
   411  	var builder strings.Builder
   412  
   413  	builder.WriteString("{")
   414  	length := len(element)
   415  	var subElementNames string
   416  	if length > 1 {
   417  		builder.WriteString("\n")
   418  		subElementNames = CoordNameForExtensionDegree(uint8(length))
   419  	}
   420  	for i, e := range element {
   421  		if length > 1 {
   422  			builder.WriteString(subElementNames)
   423  			builder.WriteString(strconv.Itoa(i))
   424  			builder.WriteString(": fp.Element{")
   425  		}
   426  		mont := f.ToMont(e)
   427  		bavard.WriteBigIntAsUint64Slice(&builder, &mont)
   428  		if length > 1 {
   429  			builder.WriteString("},\n")
   430  		}
   431  	}
   432  	builder.WriteString("}")
   433  	return builder.String()
   434  }