github.com/consensys/gnark-crypto@v0.14.0/field/generator/internal/templates/element/inverse.go (about)

     1  package element
     2  
     3  const Inverse = `
     4  
     5  {{ define "addQ" }}
     6  if b != 0 {
     7  	// z[{{.NbWordsLastIndex}}] = -1
     8  	// negative: add q
     9  	const neg1 = 0xFFFFFFFFFFFFFFFF
    10  
    11  	var carry uint64
    12  	{{$lastIndex := sub .NbWords 1}}
    13  	{{- range $i :=  iterate 0 $lastIndex}}
    14  	z[{{$i}}], carry = bits.Add64(z[{{$i}}], q{{$i}}, {{- if eq $i 0}}0{{- else}}carry{{- end}})
    15  	{{- end}}
    16  	z[{{.NbWordsLastIndex}}], _ = bits.Add64(neg1, q{{$.NbWordsLastIndex}}, carry)
    17  }
    18  {{- end}}
    19  
    20  {{/* We use big.Int for Inverse for these type of moduli */}}
    21  {{if not $.UsingP20Inverse}}
    22  
    23  {{- if eq .NbWords 1}}
    24  // Inverse z = x⁻¹ (mod q) 
    25  //
    26  // if x == 0, sets and returns z = x 
    27  func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} {
    28  	// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography"
    29  	const q uint64 = q0
    30  	if x.IsZero() {
    31  		z.SetZero()
    32  		return z
    33  	}
    34  
    35  	var r,s,u,v uint64
    36  	u = q
    37  	s = {{index .RSquare 0}} // s = r²
    38  	r = 0
    39  	v = x[0]
    40  
    41  	var carry, borrow uint64
    42  
    43  	for  (u != 1) && (v != 1){
    44  		for v&1 == 0 {
    45  			v >>= 1
    46  			if s&1 == 0 {
    47  				s >>= 1
    48  			} else {
    49  				s, carry = bits.Add64(s, q, 0)
    50  				s >>= 1
    51  				if carry != 0 {
    52  					s |= (1 << 63)
    53  				}
    54  			}
    55  		} 
    56  		for u&1 == 0 {
    57  			u >>= 1
    58  			if r&1 == 0 {
    59  				r >>= 1
    60  			} else {
    61  				r, carry = bits.Add64(r, q, 0)
    62  				r >>= 1
    63  				if carry != 0 {
    64  					r |= (1 << 63)
    65  				}
    66  			}
    67  		} 
    68  		if v >= u  {
    69  			v -= u
    70  			s, borrow = bits.Sub64(s, r, 0)
    71  			if borrow == 1 {
    72  				s += q
    73  			}
    74  		} else {
    75  			u -= v
    76  			r, borrow = bits.Sub64(r, s, 0)
    77  			if borrow == 1 {
    78  				r += q
    79  			}
    80  		}
    81  	}
    82  
    83  	if u == 1 {
    84  		z[0] = r
    85  	} else {
    86  		z[0] = s
    87  	}
    88  	
    89  	return z
    90  }
    91  {{- else}}
    92  // Inverse z = x⁻¹ (mod q) 
    93  //
    94  // note: allocates a big.Int (math/big)
    95  func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} {
    96  	var _xNonMont big.Int
    97  	x.BigInt(&_xNonMont)
    98  	_xNonMont.ModInverse(&_xNonMont, Modulus())
    99  	z.SetBigInt(&_xNonMont)
   100  	return z
   101  }
   102  {{- end}}
   103  
   104  {{ else }}
   105  
   106  const (
   107  	k = 32 // word size / 2
   108  	signBitSelector = uint64(1) << 63
   109  	approxLowBitsN = k - 1
   110  	approxHighBitsN = k + 1
   111  )
   112  
   113  const (
   114  {{- range $i := .NbWordsIndexesFull}}
   115  inversionCorrectionFactorWord{{$i}} = {{index $.P20InversionCorrectiveFac $i}}
   116  {{- end}}
   117  invIterationsN = {{.P20InversionNbIterations}}
   118  )
   119  
   120  // Inverse z = x⁻¹ (mod q)
   121  //
   122  // if x == 0, sets and returns z = x
   123  func (z *{{.ElementName}}) Inverse(x *{{.ElementName}}) *{{.ElementName}} {
   124  	// Implements "Optimized Binary GCD for Modular Inversion"
   125  	// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf
   126  
   127  	a := *x
   128  	b := {{.ElementName}} {
   129  		{{- range $i := .NbWordsIndexesFull}}
   130  		q{{$i}},{{end}}
   131  	}	// b := q
   132  
   133  	u := {{.ElementName}}{1}
   134  
   135  	// Update factors: we get [u; v] ← [f₀ g₀; f₁ g₁] [u; v]
   136  	// cᵢ = fᵢ + 2³¹ - 1 + 2³² * (gᵢ + 2³¹ - 1)
   137   	var c0, c1 int64
   138  
   139  	// Saved update factors to reduce the number of field multiplications
   140   	var pf0, pf1, pg0, pg1 int64
   141  
   142  	var i uint
   143  
   144  	var v, s {{.ElementName}}
   145  
   146  	// Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations
   147  	// This also lets us get away with half as many updates to u,v
   148  	// To make this constant-time-ish, replace the condition with i < invIterationsN
   149  	for i = 0; i&1 == 1 || !a.IsZero(); i++ {
   150  		n := max(a.BitLen(), b.BitLen())
   151  		aApprox, bApprox := approximate(&a, n), approximate(&b, n)
   152  
   153  		// f₀, g₀, f₁, g₁ = 1, 0, 0, 1
   154   		c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1
   155  
   156  		for j := 0; j < approxLowBitsN; j++ {
   157  
   158  			// -2ʲ < f₀, f₁ ≤ 2ʲ
   159  			// |f₀| + |f₁| < 2ʲ⁺¹
   160  
   161  			if aApprox&1 == 0 {
   162  				aApprox /= 2
   163  			} else {
   164  				s, borrow := bits.Sub64(aApprox, bApprox, 0)
   165  				if borrow == 1 {
   166  					s = bApprox - aApprox
   167  					bApprox = aApprox
   168  					c0, c1 = c1, c0
   169  					// invariants unchanged
   170  				}
   171  
   172  				aApprox = s / 2
   173  				c0 = c0 - c1
   174  
   175  				// Now |f₀| < 2ʲ⁺¹ ≤ 2ʲ⁺¹ (only the weaker inequality is needed, strictly speaking)
   176                  // Started with f₀ > -2ʲ and f₁ ≤ 2ʲ, so f₀ - f₁ > -2ʲ⁺¹
   177                  // Invariants unchanged for f₁
   178  			}
   179  
   180  			c1 *= 2
   181  			// -2ʲ⁺¹ < f₁ ≤ 2ʲ⁺¹
   182              // So now |f₀| + |f₁| < 2ʲ⁺²
   183  		}
   184  
   185  		s = a
   186  
   187  		var g0 int64
   188  		// from this point on c0 aliases for f0
   189  		c0, g0 = updateFactorsDecompose(c0)
   190  		aHi := a.linearCombNonModular(&s, c0, &b, g0)
   191  		if aHi & signBitSelector != 0 {
   192  			// if aHi < 0
   193  			c0, g0 = -c0, -g0
   194  			aHi = negL(&a, aHi)
   195  		}
   196  		// right-shift a by k-1 bits
   197  
   198  		{{- range $i := .NbWordsIndexesFull}}
   199  			{{-  if eq $i $.NbWordsLastIndex}}
   200  				a[{{$i}}] = (a[{{$i}}] >> approxLowBitsN) | (aHi << approxHighBitsN)
   201  			{{-  else  }}
   202  				a[{{$i}}] = (a[{{$i}}] >> approxLowBitsN) | ((a[{{add $i 1}}]) << approxHighBitsN)
   203  			{{- end}}
   204  		{{- end}}
   205  
   206  		var f1 int64
   207  		// from this point on c1 aliases for g0
   208  		f1, c1 = updateFactorsDecompose(c1)
   209  		bHi := b.linearCombNonModular(&s, f1, &b, c1)
   210  		if bHi & signBitSelector != 0 {
   211  			// if bHi < 0
   212  			f1, c1 = -f1, -c1
   213  			bHi = negL(&b, bHi)
   214  		}
   215  		// right-shift b by k-1 bits
   216  
   217  		{{- range $i := .NbWordsIndexesFull}}
   218  			{{-  if eq $i $.NbWordsLastIndex}}
   219  				b[{{$i}}] = (b[{{$i}}] >> approxLowBitsN) | (bHi << approxHighBitsN)
   220  			{{-  else  }}
   221  				b[{{$i}}] = (b[{{$i}}] >> approxLowBitsN) | ((b[{{add $i 1}}]) << approxHighBitsN)
   222  			{{- end}}
   223  		{{- end}}
   224  
   225  		if i&1 == 1 {
   226  			// Combine current update factors with previously stored ones
   227  			// [F₀, G₀; F₁, G₁] ← [f₀, g₀; f₁, g₁] [pf₀, pg₀; pf₁, pg₁], with capital letters denoting new combined values
   228              // We get |F₀| = | f₀pf₀ + g₀pf₁ | ≤ |f₀pf₀| + |g₀pf₁| = |f₀| |pf₀| + |g₀| |pf₁| ≤ 2ᵏ⁻¹|pf₀| + 2ᵏ⁻¹|pf₁|
   229              // = 2ᵏ⁻¹ (|pf₀| + |pf₁|) < 2ᵏ⁻¹ 2ᵏ = 2²ᵏ⁻¹
   230              // So |F₀| < 2²ᵏ⁻¹ meaning it fits in a 2k-bit signed register
   231  
   232  			// c₀ aliases f₀, c₁ aliases g₁
   233  			c0, g0, f1, c1 = c0*pf0+g0*pf1,
   234  				c0*pg0+g0*pg1,
   235  				f1*pf0+c1*pf1,
   236  				f1*pg0+c1*pg1
   237  
   238  			s = u
   239  
   240  			// 0 ≤ u, v < 2²⁵⁵
   241              // |F₀|, |G₀| < 2⁶³
   242              u.linearComb(&u, c0, &v, g0)
   243              // |F₁|, |G₁| < 2⁶³
   244              v.linearComb(&s, f1, &v, c1)
   245  
   246  		} else {
   247  			// Save update factors
   248  			pf0, pg0, pf1, pg1 = c0, g0, f1, c1
   249  		}
   250  	}
   251  
   252  	// For every iteration that we miss, v is not being multiplied by 2ᵏ⁻²
   253  	const pSq uint64 = 1 << (2 * (k - 1))
   254  	a = {{.ElementName}}{pSq}
   255  	// If the function is constant-time ish, this loop will not run (no need to take it out explicitly)
   256  	for ; i < invIterationsN; i += 2 {
   257  		// could optimize further with mul by word routine or by pre-computing a table since with k=26,
   258  		// we would multiply by pSq up to 13times;
   259  		// on x86, the assembly routine outperforms generic code for mul by word
   260  		// on arm64, we may loose up to ~5% for 6 limbs
   261  		v.Mul(&v, &a)
   262  	}
   263  
   264  	u.Set(x) // for correctness check
   265  
   266  	z.Mul(&v, &{{.ElementName}}{
   267  		{{- range $i := .NbWordsIndexesFull }}
   268  		inversionCorrectionFactorWord{{$i}},
   269  		{{- end}}
   270  	})
   271  
   272  	// correctness check
   273      v.Mul(&u, z)
   274      if !v.IsOne() && !u.IsZero() {
   275              return z.inverseExp(u)
   276      }
   277  
   278  	return z
   279  }
   280  
   281  // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) 
   282  func (z *{{.ElementName}}) inverseExp(x {{.ElementName}}) *{{.ElementName}} {
   283  	// e == q-2
   284  	e := Modulus()
   285  	e.Sub(e, big.NewInt(2))
   286  
   287  	z.Set(&x)
   288  
   289  	for i := e.BitLen() - 2; i >= 0; i-- {
   290  		z.Square(z)
   291  		if e.Bit(i) == 1 {
   292  			z.Mul(z, &x)
   293  		}
   294  	}
   295  
   296  	return z
   297  }
   298  
   299  // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits
   300  // if x fits in a word as is, no approximation necessary
   301  func approximate(x *{{.ElementName}}, nBits int) uint64 {
   302  
   303  	if nBits <= 64 {
   304  		return x[0]
   305  	}
   306  
   307  	const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones
   308  	lo := mask & x[0]
   309  
   310  	hiWordIndex := (nBits - 1) / 64
   311  
   312  	hiWordBitsAvailable := nBits - hiWordIndex * 64
   313  	hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN)
   314  
   315  	mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1))
   316  	hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable)
   317  
   318  	mask_ = ^(1<<(approxLowBitsN + hiWordBitsUsed) - 1)
   319  	mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed
   320  
   321  	return lo | mid | hi
   322  }
   323  
   324  // linearComb z = xC * x + yC * y;
   325  // 0 ≤ x, y < 2{{supScr .NbBits}}
   326  // |xC|, |yC| < 2⁶³
   327  func (z *{{.ElementName}}) linearComb(x *{{.ElementName}}, xC int64, y *{{.ElementName}}, yC int64) {
   328  	{{- $elementCapacityNbBits := mul .NbWords 64}}
   329      // | (hi, z) | < 2 * 2⁶³ * 2{{supScr .NbBits}} = 2{{supScr (add 64 .NbBits)}}
   330  	// therefore | hi | < 2{{supScr (sub (add 64 .NbBits) $elementCapacityNbBits)}} ≤ 2⁶³
   331  	hi := z.linearCombNonModular(x, xC, y, yC)
   332  	z.montReduceSigned(z, hi)
   333  }
   334  
   335  // montReduceSigned z = (xHi * r + x) * r⁻¹ using the SOS algorithm
   336  // Requires |xHi| < 2⁶³. Most significant bit of xHi is the sign bit.
   337  func (z *{{.ElementName}}) montReduceSigned(x *{{.ElementName}}, xHi uint64) {
   338  	const signBitRemover = ^signBitSelector
   339  	mustNeg := xHi & signBitSelector != 0
   340  	// the SOS implementation requires that most significant bit is 0
   341  	// Let X be xHi*r + x
   342  	// If X is negative we would have initially stored it as 2⁶⁴ r + X (à la 2's complement)
   343  	xHi &= signBitRemover
   344  	// with this a negative X is now represented as 2⁶³ r + X
   345  
   346  	var t [2*Limbs - 1]uint64
   347  	var C uint64
   348  
   349  	m := x[0] * qInvNeg
   350  
   351  	C = madd0(m, q0, x[0])
   352  	{{- range $i := .NbWordsIndexesNoZero}}
   353  	C, t[{{$i}}] = madd2(m, q{{$i}}, x[{{$i}}], C)
   354  	{{- end}}
   355  
   356  	// m * qElement[{{.NbWordsLastIndex}}] ≤ (2⁶⁴ - 1) * (2⁶³ - 1) = 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1
   357      // x[{{.NbWordsLastIndex}}] + C ≤ 2*(2⁶⁴ - 1) = 2⁶⁵ - 2
   358      // On LHS, (C, t[{{.NbWordsLastIndex}}]) ≤ 2¹²⁷ - 2⁶⁴ - 2⁶³ + 1 + 2⁶⁵ - 2 = 2¹²⁷ + 2⁶³ - 1
   359      // So on LHS, C ≤ 2⁶³
   360  	t[{{.NbWords}}] = xHi + C
   361  	// xHi + C < 2⁶³ + 2⁶³ = 2⁶⁴
   362  
   363  	{{/* $NbWordsIndexesNoZeroInnerLoop := .NbWordsIndexesNoZero*/}}// <standard SOS>
   364  	{{- range $i := iterate 1 $.NbWordsLastIndex}}
   365  	{
   366  		const i = {{$i}}
   367  		m = t[i] * qInvNeg
   368  
   369  		C = madd0(m, q0, t[i+0])
   370  
   371  		{{- range $j := $.NbWordsIndexesNoZero}}
   372  		C, t[i + {{$j}}] = madd2(m, q{{$j}}, t[i +  {{$j}}], C)
   373  		{{- end}}
   374  
   375  		t[i + Limbs] += C
   376  	}
   377  	{{- end}}
   378  	{
   379  		const i = {{.NbWordsLastIndex}}
   380  		m := t[i] * qInvNeg
   381  
   382  		C = madd0(m, q0, t[i+0])
   383  		{{- range $j := iterate 1 $.NbWordsLastIndex}}
   384  		C, z[{{sub $j 1}}] = madd2(m, q{{$j}}, t[i+{{$j}}], C)
   385  		{{- end}}
   386  		z[{{.NbWordsLastIndex}}], z[{{sub .NbWordsLastIndex 1}}] = madd2(m, q{{.NbWordsLastIndex}}, t[i+{{.NbWordsLastIndex}}], C)
   387  	}
   388      {{ template "reduce" . }}
   389  	// </standard SOS>
   390  
   391  	if mustNeg {
   392  		// We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead
   393  		var b uint64
   394  		z[0], b = bits.Sub64(z[0], signBitSelector, 0)
   395  
   396  		{{- range $i := .NbWordsIndexesNoZero}}
   397  		z[{{$i}}], b = bits.Sub64(z[{{$i}}], 0, b)
   398  		{{- end}}
   399  
   400  		// Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0
   401  		{{ template "addQ" .}}
   402  	}
   403  }
   404  
   405  const (
   406  	updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1)
   407  	updateFactorIdentityMatrixRow0 = 1
   408  	updateFactorIdentityMatrixRow1 = 1 << 32
   409  )
   410  
   411  func updateFactorsDecompose(c int64) (int64, int64) {
   412  	c += updateFactorsConversionBias
   413   	const low32BitsFilter int64 = 0xFFFFFFFF
   414   	f := c&low32BitsFilter - 0x7FFFFFFF
   415   	g := c>>32&low32BitsFilter - 0x7FFFFFFF
   416   	return f, g
   417  }
   418  
   419  {{ end }}
   420  
   421  `