github.com/emmansun/gmsm@v0.29.1/internal/bigmod/nat.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigmod
     6  
     7  import (
     8  	"encoding/binary"
     9  	"errors"
    10  	"math/big"
    11  	"math/bits"
    12  )
    13  
    14  const (
    15  	// _W is the size in bits of our limbs.
    16  	_W = bits.UintSize
    17  	// _S is the size in bytes of our limbs.
    18  	_S = _W / 8
    19  )
    20  
    21  // choice represents a constant-time boolean. The value of choice is always
    22  // either 1 or 0. We use an int instead of bool in order to make decisions in
    23  // constant time by turning it into a mask.
    24  type choice uint
    25  
    26  func not(c choice) choice { return 1 ^ c }
    27  
    28  const yes = choice(1)
    29  const no = choice(0)
    30  
    31  // ctMask is all 1s if on is yes, and all 0s otherwise.
    32  func ctMask(on choice) uint { return -uint(on) }
    33  
    34  // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
    35  // function does not depend on its inputs.
    36  func ctEq(x, y uint) choice {
    37  	// If x != y, then either x - y or y - x will generate a carry.
    38  	_, c1 := bits.Sub(x, y, 0)
    39  	_, c2 := bits.Sub(y, x, 0)
    40  	return not(choice(c1 | c2))
    41  }
    42  
    43  // ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
    44  // function does not depend on its inputs.
    45  func ctGeq(x, y uint) choice {
    46  	// If x < y, then x - y generates a carry.
    47  	_, carry := bits.Sub(x, y, 0)
    48  	return not(choice(carry))
    49  }
    50  
// Nat represents an arbitrary natural number.
//
// Each Nat has an announced length, which is the number of limbs it has stored.
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
	// limbs is little-endian in base 2^W with W = bits.UintSize: limbs[0] is
	// the least significant word. len(limbs) is the announced length.
	limbs []uint
}
    60  
    61  // preallocTarget is the size in bits of the numbers used to implement the most
    62  // common and most performant RSA key size. It's also enough to cover some of
    63  // the operations of key sizes up to 4096.
    64  const preallocTarget = 2048
    65  const preallocLimbs = (preallocTarget + _W - 1) / _W
    66  
    67  // NewNat returns a new nat with a size of zero, just like new(Nat), but with
    68  // the preallocated capacity to hold a number of up to preallocTarget bits.
    69  // NewNat inlines, so the allocation can live on the stack.
    70  func NewNat() *Nat {
    71  	limbs := make([]uint, 0, preallocLimbs)
    72  	return &Nat{limbs}
    73  }
    74  
    75  // expand expands x to n limbs, leaving its value unchanged.
    76  func (x *Nat) expand(n int) *Nat {
    77  	if len(x.limbs) > n {
    78  		panic("bigmod: internal error: shrinking nat")
    79  	}
    80  	if cap(x.limbs) < n {
    81  		newLimbs := make([]uint, n)
    82  		copy(newLimbs, x.limbs)
    83  		x.limbs = newLimbs
    84  		return x
    85  	}
    86  	extraLimbs := x.limbs[len(x.limbs):n]
    87  	for i := range extraLimbs {
    88  		extraLimbs[i] = 0
    89  	}
    90  	x.limbs = x.limbs[:n]
    91  	return x
    92  }
    93  
    94  // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
    95  func (x *Nat) reset(n int) *Nat {
    96  	if cap(x.limbs) < n {
    97  		x.limbs = make([]uint, n)
    98  		return x
    99  	}
   100  	for i := range x.limbs {
   101  		x.limbs[i] = 0
   102  	}
   103  	x.limbs = x.limbs[:n]
   104  	return x
   105  }
   106  
   107  // set assigns x = y, optionally resizing x to the appropriate size.
   108  func (x *Nat) Set(y *Nat) *Nat {
   109  	x.reset(len(y.limbs))
   110  	copy(x.limbs, y.limbs)
   111  	return x
   112  }
   113  
   114  // SetBig assigns x = n, optionally resizing n to the appropriate size.
   115  //
   116  // The announced length of x is set based on the actual bit size of the input,
   117  // ignoring leading zeroes.
   118  func (x *Nat) SetBig(n *big.Int) *Nat {
   119  	limbs := n.Bits()
   120  	x.reset(len(limbs))
   121  	for i := range limbs {
   122  		x.limbs[i] = uint(limbs[i])
   123  	}
   124  	return x
   125  }
   126  
   127  // Bytes returns x as a zero-extended big-endian byte slice. The size of the
   128  // slice will match the size of m.
   129  //
   130  // x must have the same size as m and it must be reduced modulo m.
   131  func (x *Nat) Bytes(m *Modulus) []byte {
   132  	i := m.Size()
   133  	bytes := make([]byte, i)
   134  	for _, limb := range x.limbs {
   135  		for j := 0; j < _S; j++ {
   136  			i--
   137  			if i < 0 {
   138  				if limb == 0 {
   139  					break
   140  				}
   141  				panic("bigmod: modulus is smaller than nat")
   142  			}
   143  			bytes[i] = byte(limb)
   144  			limb >>= 8
   145  		}
   146  	}
   147  	return bytes
   148  }
   149  
   150  // SetBytes assigns x = b, where b is a slice of big-endian bytes.
   151  // SetBytes returns an error if b >= m.
   152  //
   153  // The output will be resized to the size of m and overwritten.
   154  func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
   155  	if err := x.setBytes(b, m); err != nil {
   156  		return nil, err
   157  	}
   158  	if x.CmpGeq(m.nat) == yes {
   159  		return nil, errors.New("input overflows the modulus")
   160  	}
   161  	return x, nil
   162  }
   163  
   164  // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
   165  // SetOverflowingBytes returns an error if b has a longer bit length than m, but
   166  // reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
   167  //
   168  // The output will be resized to the size of m and overwritten.
   169  func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
   170  	if err := x.setBytes(b, m); err != nil {
   171  		return nil, err
   172  	}
   173  	leading := _W - bitLen(x.limbs[len(x.limbs)-1])
   174  	if leading < m.leading {
   175  		return nil, errors.New("input overflows the modulus size")
   176  	}
   177  	x.maybeSubtractModulus(no, m)
   178  	return x, nil
   179  }
   180  
   181  // bigEndianUint returns the contents of buf interpreted as a
   182  // big-endian encoded uint value.
   183  func bigEndianUint(buf []byte) uint {
   184  	if _W == 64 {
   185  		return uint(binary.BigEndian.Uint64(buf))
   186  	}
   187  	return uint(binary.BigEndian.Uint32(buf))
   188  }
   189  
   190  func (x *Nat) setBytes(b []byte, m *Modulus) error {
   191  	x.resetFor(m)
   192  	i, k := len(b), 0
   193  	for k < len(x.limbs) && i >= _S {
   194  		x.limbs[k] = bigEndianUint(b[i-_S : i])
   195  		i -= _S
   196  		k++
   197  	}
   198  	for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
   199  		x.limbs[k] |= uint(b[i-1]) << s
   200  		i--
   201  	}
   202  	if i > 0 {
   203  		return errors.New("input overflows the modulus size")
   204  	}
   205  	return nil
   206  }
   207  
   208  // Equal returns 1 if x == y, and 0 otherwise.
   209  //
   210  // Both operands must have the same announced length.
   211  func (x *Nat) Equal(y *Nat) choice {
   212  	// Eliminate bounds checks in the loop.
   213  	size := len(x.limbs)
   214  	xLimbs := x.limbs[:size]
   215  	yLimbs := y.limbs[:size]
   216  
   217  	equal := yes
   218  	for i := 0; i < size; i++ {
   219  		equal &= ctEq(xLimbs[i], yLimbs[i])
   220  	}
   221  	return equal
   222  }
   223  
   224  // IsZero returns 1 if x == 0, and 0 otherwise.
   225  func (x *Nat) IsZero() choice {
   226  	// Eliminate bounds checks in the loop.
   227  	size := len(x.limbs)
   228  	xLimbs := x.limbs[:size]
   229  
   230  	zero := yes
   231  	for i := 0; i < size; i++ {
   232  		zero &= ctEq(xLimbs[i], 0)
   233  	}
   234  	return zero
   235  }
   236  
   237  // CmpGeq returns 1 if x >= y, and 0 otherwise.
   238  //
   239  // Both operands must have the same announced length.
   240  func (x *Nat) CmpGeq(y *Nat) choice {
   241  	// Eliminate bounds checks in the loop.
   242  	size := len(x.limbs)
   243  	xLimbs := x.limbs[:size]
   244  	yLimbs := y.limbs[:size]
   245  
   246  	var c uint
   247  	for i := 0; i < size; i++ {
   248  		_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
   249  	}
   250  	// If there was a carry, then subtracting y underflowed, so
   251  	// x is not greater than or equal to y.
   252  	return not(choice(c))
   253  }
   254  
   255  // assign sets x <- y if on == 1, and does nothing otherwise.
   256  //
   257  // Both operands must have the same announced length.
   258  func (x *Nat) assign(on choice, y *Nat) *Nat {
   259  	// Eliminate bounds checks in the loop.
   260  	size := len(x.limbs)
   261  	xLimbs := x.limbs[:size]
   262  	yLimbs := y.limbs[:size]
   263  
   264  	mask := ctMask(on)
   265  	for i := 0; i < size; i++ {
   266  		xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
   267  	}
   268  	return x
   269  }
   270  
   271  // add computes x += y and returns the carry.
   272  //
   273  // Both operands must have the same announced length.
   274  func (x *Nat) add(y *Nat) (c uint) {
   275  	// Eliminate bounds checks in the loop.
   276  	size := len(x.limbs)
   277  	xLimbs := x.limbs[:size]
   278  	yLimbs := y.limbs[:size]
   279  
   280  	for i := 0; i < size; i++ {
   281  		xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
   282  	}
   283  	return
   284  }
   285  
   286  // sub computes x -= y. It returns the borrow of the subtraction.
   287  //
   288  // Both operands must have the same announced length.
   289  func (x *Nat) sub(y *Nat) (c uint) {
   290  	// Eliminate bounds checks in the loop.
   291  	size := len(x.limbs)
   292  	xLimbs := x.limbs[:size]
   293  	yLimbs := y.limbs[:size]
   294  
   295  	for i := 0; i < size; i++ {
   296  		xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
   297  	}
   298  	return
   299  }
   300  
// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
// number of bits needed to store their value, and are stored without padding.
//
// Their actual value is still kept secret.
type Modulus struct {
	// The underlying natural number for this modulus.
	//
	// This will be stored without any padding, and shouldn't alias with any
	// other natural number being used.
	nat     *Nat
	leading int  // number of leading zero bits in the most significant limb of nat
	m0inv   uint // -nat.limbs[0]⁻¹ mod 2^_W (used by montgomeryMul)
	rr      *Nat // R*R mod nat, with R = 2^(_W * len(nat.limbs)); used by montgomeryRepresentation
}
   317  
// rr returns R*R mod m, with R = 2^(_W * n) and n = len(m.nat.limbs).
//
// NewModulusFromBig stores the result in m.rr, where montgomeryRepresentation
// uses it to move values into the Montgomery domain. m.nat, m.leading and
// m.m0inv must already be set, because Add and montgomeryMul read them.
func rr(m *Modulus) *Nat {
	rr := NewNat().ExpandFor(m)
	n := uint(len(rr.limbs))
	mLen := uint(m.BitLen())
	logR := _W * n

	// We start by computing R = 2^(_W * n) mod m. We can get pretty close, to
	// 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce.
	rr.limbs[n-1] = 1 << ((mLen - 1) % _W)
	// Then we double until we reach 2^(_W * n).
	// rr.Add(rr, m) computes rr + rr mod m, i.e. one modular doubling.
	for i := mLen - 1; i < logR; i++ {
		rr.Add(rr, m)
	}

	// Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in
	// the Montgomery domain, meaning we can use Montgomery multiplication now).
	// We could do that by doubling _W * n times, or with a square-and-double
	// chain log2(_W * n) long. Turns out the fastest thing is to start out with
	// doublings, and switch to square-and-double once the exponent is large
	// enough to justify the cost of the multiplications.

	// The threshold is selected experimentally as a linear function of n.
	threshold := n / 4

	// We calculate how many of the most-significant bits of the exponent we can
	// compute before crossing the threshold, and we do it with doublings.
	// The loop stops at the largest shift i for which logR>>i exceeds the
	// threshold.
	i := bits.UintSize
	for logR>>i <= threshold {
		i--
	}
	// rr currently holds R, the Montgomery form of 1. After logR>>i doublings
	// it holds the Montgomery form of 2^(logR>>i).
	for k := uint(0); k < logR>>i; k++ {
		rr.Add(rr, m)
	}

	// Then we process the remaining bits of the exponent with a
	// square-and-double chain.
	// Each step squares (doubling the exponent) and then doubles once more if
	// the next bit of logR is set. When i reaches 0 the exponent is logR, so
	// rr is the Montgomery form of 2^logR = R, i.e. R*R mod m.
	for i > 0 {
		rr.montgomeryMul(rr, rr, m)
		i--
		if logR>>i&1 != 0 {
			rr.Add(rr, m)
		}
	}

	return rr
}
   365  
   366  // minusInverseModW computes -x⁻¹ mod _W with x odd.
   367  //
   368  // This operation is used to precompute a constant involved in Montgomery
   369  // multiplication.
   370  func minusInverseModW(x uint) uint {
   371  	// Every iteration of this loop doubles the least-significant bits of
   372  	// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
   373  	// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
   374  	// for 64 bits (and wastes only one iteration for 32 bits).
   375  	//
   376  	// See https://crypto.stackexchange.com/a/47496.
   377  	y := x
   378  	for i := 0; i < 5; i++ {
   379  		y = y * (2 - x*y)
   380  	}
   381  	return -y
   382  }
   383  
   384  // NewModulusFromBig creates a new Modulus from a [big.Int].
   385  //
   386  // The Int must be odd. The number of significant bits (and nothing else) is
   387  // leaked through timing side-channels.
   388  func NewModulusFromBig(n *big.Int) (*Modulus, error) {
   389  	if b := n.Bits(); len(b) == 0 {
   390  		return nil, errors.New("modulus must be >= 0")
   391  	} else if b[0]&1 != 1 {
   392  		return nil, errors.New("modulus must be odd")
   393  	}
   394  	m := &Modulus{}
   395  	m.nat = NewNat().SetBig(n)
   396  	m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
   397  	m.m0inv = minusInverseModW(m.nat.limbs[0])
   398  	m.rr = rr(m)
   399  	return m, nil
   400  }
   401  
   402  // bitLen is a version of bits.Len that only leaks the bit length of n, but not
   403  // its value. bits.Len and bits.LeadingZeros use a lookup table for the
   404  // low-order bits on some architectures.
   405  func bitLen(n uint) int {
   406  	var len int
   407  	// We assume, here and elsewhere, that comparison to zero is constant time
   408  	// with respect to different non-zero values.
   409  	for n != 0 {
   410  		len++
   411  		n >>= 1
   412  	}
   413  	return len
   414  }
   415  
   416  // Size returns the size of m in bytes.
   417  func (m *Modulus) Size() int {
   418  	return (m.BitLen() + 7) / 8
   419  }
   420  
   421  // BitLen returns the size of m in bits.
   422  func (m *Modulus) BitLen() int {
   423  	return len(m.nat.limbs)*_W - int(m.leading)
   424  }
   425  
   426  // Nat returns m as a Nat. The return value must not be written to.
   427  func (m *Modulus) Nat() *Nat {
   428  	return m.nat
   429  }
   430  
   431  // shiftIn calculates x = x << _W + y mod m.
   432  //
   433  // This assumes that x is already reduced mod m, and that y < 2^_W.
   434  func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
   435  	return x.shiftInNat(y, m.nat)
   436  }
   437  
// shiftInNat calculates x = x << _W + y mod m and returns x, for a modulus
// given as a plain Nat.
//
// This assumes that x is already reduced mod m, and that y < 2^_W. x must have
// at least len(m.limbs) limbs, since it is resliced to that length below.
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
	d := NewNat().reset(len(m.limbs))

	// Eliminate bounds checks in the loop.
	size := len(m.limbs)
	xLimbs := x.limbs[:size]
	dLimbs := d.limbs[:size]
	mLimbs := m.limbs[:size]

	// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
	// from y. Effectively, it left-shifts x and adds y one bit at a time,
	// reducing it every time.
	//
	// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
	// The next iteration (and finally the return line) will use either result
	// based on whether 2x + b overflows m.
	needSubtraction := no
	for i := _W - 1; i >= 0; i-- {
		// carry starts as the next bit of y (most significant first) and is
		// shifted in as the low bit of 2x.
		carry := (y >> i) & 1
		var borrow uint
		mask := ctMask(needSubtraction)
		// Note: this inner i shadows the outer bit index.
		for i := 0; i < size; i++ {
			// Select in constant time between the previous 2x + b and 2x + b - m.
			l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
			xLimbs[i], carry = bits.Add(l, l, carry)
			dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
		}
		// Like in maybeSubtractModulus, we need the subtraction if either it
		// didn't underflow (meaning 2x + b > m) or if computing 2x + b
		// overflowed (meaning 2x + b > 2^_W*n > m).
		needSubtraction = not(choice(borrow)) | choice(carry)
	}
	return x.assign(needSubtraction, d)
}
   474  
   475  // Mod calculates out = x mod m.
   476  //
   477  // This works regardless how large the value of x is.
   478  //
   479  // The output will be resized to the size of m and overwritten.
   480  func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
   481  	return out.ModNat(x, m.nat)
   482  }
   483  
   484  // Mod calculates out = x mod m.
   485  //
   486  // This works regardless how large the value of x is.
   487  //
   488  // The output will be resized to the size of m and overwritten.
   489  func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
   490  	out.reset(len(m.limbs))
   491  	// Working our way from the most significant to the least significant limb,
   492  	// we can insert each limb at the least significant position, shifting all
   493  	// previous limbs left by _W. This way each limb will get shifted by the
   494  	// correct number of bits. We can insert at least N - 1 limbs without
   495  	// overflowing m. After that, we need to reduce every time we shift.
   496  	i := len(x.limbs) - 1
   497  	// For the first N - 1 limbs we can skip the actual shifting and position
   498  	// them at the shifted position, which starts at min(N - 2, i).
   499  	start := len(m.limbs) - 2
   500  	if i < start {
   501  		start = i
   502  	}
   503  	for j := start; j >= 0; j-- {
   504  		out.limbs[j] = x.limbs[i]
   505  		i--
   506  	}
   507  	// We shift in the remaining limbs, reducing modulo m each time.
   508  	for i >= 0 {
   509  		out.shiftInNat(x.limbs[i], m)
   510  		i--
   511  	}
   512  	return out
   513  }
   514  
   515  // ExpandFor ensures out has the right size to work with operations modulo m.
   516  //
   517  // The announced size of out must be smaller than or equal to that of m.
   518  func (out *Nat) ExpandFor(m *Modulus) *Nat {
   519  	return out.expand(len(m.nat.limbs))
   520  }
   521  
   522  // resetFor ensures out has the right size to work with operations modulo m.
   523  //
   524  // out is zeroed and may start at any size.
   525  func (out *Nat) resetFor(m *Modulus) *Nat {
   526  	return out.reset(len(m.nat.limbs))
   527  }
   528  
   529  // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
   530  //
   531  // It can be used to reduce modulo m a value up to 2m - 1, which is a common
   532  // range for results computed by higher level operations.
   533  //
   534  // always is usually a carry that indicates that the operation that produced x
   535  // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
   536  //
   537  // x and m operands must have the same announced length.
   538  func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
   539  	t := NewNat().Set(x)
   540  	underflow := t.sub(m.nat)
   541  	// We keep the result if x - m didn't underflow (meaning x >= m)
   542  	// or if always was set.
   543  	keep := not(choice(underflow)) | choice(always)
   544  	x.assign(keep, t)
   545  }
   546  
   547  // Sub computes x = x - y mod m.
   548  //
   549  // The length of both operands must be the same as the modulus. Both operands
   550  // must already be reduced modulo m.
   551  func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
   552  	underflow := x.sub(y)
   553  	// If the subtraction underflowed, add m.
   554  	t := NewNat().Set(x)
   555  	t.add(m.nat)
   556  	x.assign(choice(underflow), t)
   557  	return x
   558  }
   559  
   560  // Add computes x = x + y mod m.
   561  //
   562  // The length of both operands must be the same as the modulus. Both operands
   563  // must already be reduced modulo m.
   564  func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
   565  	overflow := x.add(y)
   566  	x.maybeSubtractModulus(choice(overflow), m)
   567  	return x
   568  }
   569  
   570  // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
   571  // n = len(m.nat.limbs).
   572  //
   573  // Faster Montgomery multiplication replaces standard modular multiplication for
   574  // numbers in this representation.
   575  //
   576  // This assumes that x is already reduced mod m.
   577  func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
   578  	// A Montgomery multiplication (which computes a * b / R) by R * R works out
   579  	// to a multiplication by R, which takes the value out of the Montgomery domain.
   580  	return x.montgomeryMul(x, m.rr, m)
   581  }
   582  
   583  // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
   584  // n = len(m.nat.limbs).
   585  //
   586  // This assumes that x is already reduced mod m.
   587  func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
   588  	// By Montgomery multiplying with 1 not in Montgomery representation, we
   589  	// convert out back from Montgomery representation, because it works out to
   590  	// dividing by R.
   591  	one := NewNat().ExpandFor(m)
   592  	one.limbs[0] = 1
   593  	return x.montgomeryMul(x, one, m)
   594  }
   595  
// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten. The product is
// accumulated in a separate scratch slice T and only copied into x at the end,
// so x may alias a and/or b.
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
	n := len(m.nat.limbs)
	mLimbs := m.nat.limbs[:n]
	aLimbs := a.limbs[:n]
	bLimbs := b.limbs[:n]

	switch n {
	default:
		// Attempt to use a stack-allocated backing array.
		T := make([]uint, 0, preallocLimbs*2)
		if cap(T) < n*2 {
			T = make([]uint, 0, n*2)
		}
		T = T[:n*2]

		// This loop implements Word-by-Word Montgomery Multiplication, as
		// described in Algorithm 4 (Fig. 3) of "Efficient Software
		// Implementations of Modular Exponentiation" by Shay Gueron
		// [https://eprint.iacr.org/2011/239.pdf].
		var c uint
		for i := 0; i < n; i++ {
			_ = T[n+i] // bounds check elimination hint

			// Step 1 (T = a × b) is computed as a large pen-and-paper column
			// multiplication of two numbers with n base-2^_W digits. If we just
			// wanted to produce 2n-wide T, we would do
			//
			//   for i := 0; i < n; i++ {
			//       d := bLimbs[i]
			//       T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
			//   }
			//
			// where d is a digit of the multiplier, T[i:n+i] is the shifted
			// position of the product of that digit, and T[n+i] is the final carry.
			// Note that T[i] isn't modified after processing the i-th digit.
			//
			// Instead of running two loops, one for Step 1 and one for Steps 2–6,
			// the result of Step 1 is computed during the next loop. This is
			// possible because each iteration only uses T[i] in Step 2 and then
			// discards it in Step 6.
			d := bLimbs[i]
			c1 := addMulVVW(T[i:n+i], aLimbs, d)

			// Step 6 is replaced by shifting the virtual window we operate
			// over: T of the algorithm is T[i:] for us. That means that T1 in
			// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
			Y := T[i] * m.m0inv

			// Step 4 and 5 add Y × m to T, which as mentioned above is stored
			// at T[i:]. The two carries (from a × d and Y × m) are added up in
			// the next word T[n+i], and the carry bit from that addition is
			// brought forward to the next iteration.
			c2 := addMulVVW(T[i:n+i], mLimbs, Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}

		// Finally for Step 7 we copy the final T window into x, and subtract m
		// if necessary (which as explained in maybeSubtractModulus can be the
		// case both if x >= m, or if x overflowed).
		//
		// The paper suggests in Section 4 that we can do an "Almost Montgomery
		// Multiplication" by subtracting only in the overflow case, but the
		// cost is very similar since the constant time subtraction tells us if
		// x >= m as a side effect, and taking care of the broken invariant is
		// highly undesirable (see https://go.dev/issue/13907).
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	// The following specialized cases follow the exact same algorithm, but
	// optimized for common fixed sizes (the larger ones are typical RSA sizes).
	// addMulVVW256/1024/1536/2048 are defined elsewhere in this package (not
	// visible here). Judging by the call sites, they take pointers to the first
	// limb of fixed-size windows. They are presumably the assembly/unrolled
	// variants mentioned here — confirm against their definitions. Bounds
	// checks are removed by the compiler thanks to the constant size.
	case 256 / _W: // optimization for 256 bits nat (presumably 256-bit curves such as SM2 — confirm)
		const n = 256 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW256(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW256(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)
		
	case 1024 / _W:
		const n = 1024 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 1536 / _W:
		const n = 1536 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)

	case 2048 / _W:
		const n = 2048 / _W // compiler hint
		T := make([]uint, n*2)
		var c uint
		for i := 0; i < n; i++ {
			d := bLimbs[i]
			c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
			Y := T[i] * m.m0inv
			c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
			T[n+i], c = bits.Add(c1, c2, c)
		}
		copy(x.reset(n).limbs, T[n:])
		x.maybeSubtractModulus(choice(c), m)
	}

	return x
}
   732  
   733  // addMulVVW multiplies the multi-word value x by the single-word value y,
   734  // adding the result to the multi-word value z and returning the final carry.
   735  // It can be thought of as one row of a pen-and-paper column multiplication.
   736  func addMulVVW(z, x []uint, y uint) (carry uint) {
   737  	_ = x[len(z)-1] // bounds check elimination hint
   738  	for i := range z {
   739  		hi, lo := bits.Mul(x[i], y)
   740  		lo, c := bits.Add(lo, z[i], 0)
   741  		// We use bits.Add with zero to get an add-with-carry instruction that
   742  		// absorbs the carry from the previous bits.Add.
   743  		hi, _ = bits.Add(hi, 0, c)
   744  		lo, c = bits.Add(lo, carry, 0)
   745  		hi, _ = bits.Add(hi, 0, c)
   746  		carry = hi
   747  		z[i] = lo
   748  	}
   749  	return carry
   750  }
   751  
   752  // Mul calculates x = x * y mod m.
   753  //
   754  // The length of both operands must be the same as the modulus. Both operands
   755  // must already be reduced modulo m.
   756  func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
   757  	// A Montgomery multiplication by a value out of the Montgomery domain
   758  	// takes the result out of Montgomery representation.
   759  	xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
   760  	return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
   761  }
   762  
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
// x is copied into the table before out is reset, so out may alias x. Timing
// depends on len(e) and the size of m, not on the values of x or e's bits.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
	// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
	// than 2 bit windows, but use an extra 12 nats worth of scratch space.
	// Using bit sizes that don't divide 8 are more complex to implement, but
	// are likely to be more efficient if necessary.

	table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1), in Montgomery form
		// NewNat calls are unrolled so they are allocated on the stack.
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
	}
	table[0].Set(x).montgomeryRepresentation(m)
	for i := 1; i < len(table); i++ {
		table[i].montgomeryMul(table[i-1], table[0], m)
	}

	// out starts as 1 in Montgomery form (i.e. R mod m).
	out.resetFor(m)
	out.limbs[0] = 1
	out.montgomeryRepresentation(m)
	tmp := NewNat().ExpandFor(m)
	for _, b := range e {
		// Process each exponent byte as two 4-bit windows, high nibble first.
		for _, j := range []int{4, 0} {
			// Square four times. Optimization note: this can be implemented
			// more efficiently than with generic Montgomery multiplication.
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)

			// Select x^k in constant time from the table.
			// Every entry is visited so the access pattern does not depend on k.
			k := uint((b >> j) & 0b1111)
			for i := range table {
				tmp.assign(ctEq(k, uint(i+1)), table[i])
			}

			// Multiply by x^k, discarding the result if k = 0.
			tmp.montgomeryMul(out, tmp, m)
			out.assign(not(ctEq(k, 0)), tmp)
		}
	}

	// Convert out of Montgomery form.
	return out.montgomeryReduction(m)
}
   811  
   812  // ExpShortVarTime calculates out = x^e mod m.
   813  //
   814  // The output will be resized to the size of m and overwritten. x must already
   815  // be reduced modulo m. This leaks the exponent through timing side-channels.
   816  func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
   817  	// For short exponents, precomputing a table and using a window like in Exp
   818  	// doesn't pay off. Instead, we do a simple conditional square-and-multiply
   819  	// chain, skipping the initial run of zeroes.
   820  	xR := NewNat().Set(x).montgomeryRepresentation(m)
   821  	out.Set(xR)
   822  	for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
   823  		out.montgomeryMul(out, out, m)
   824  		if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
   825  			out.montgomeryMul(out, xR, m)
   826  		}
   827  	}
   828  	return out.montgomeryReduction(m)
   829  }