github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/internal/bigmod/nat.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigmod
     6  
     7  import (
     8  	"errors"
     9  	"math/big"
    10  	"math/bits"
    11  )
    12  
    13  const (
    14  	// _W is the number of bits we use for our limbs.
    15  	_W = bits.UintSize - 1
    16  	// _MASK selects _W bits from a full machine word.
    17  	_MASK = (1 << _W) - 1
    18  )
    19  
    20  // choice represents a constant-time boolean. The value of choice is always
    21  // either 1 or 0. We use an int instead of bool in order to make decisions in
    22  // constant time by turning it into a mask.
    23  type choice uint
    24  
    25  func not(c choice) choice { return 1 ^ c }
    26  
    27  const yes = choice(1)
    28  const no = choice(0)
    29  
    30  // ctSelect returns x if on == 1, and y if on == 0. The execution time of this
    31  // function does not depend on its inputs. If on is any value besides 1 or 0,
    32  // the result is undefined.
    33  func ctSelect(on choice, x, y uint) uint {
    34  	// When on == 1, mask is 0b111..., otherwise mask is 0b000...
    35  	mask := -uint(on)
    36  	// When mask is all zeros, we just have y, otherwise, y cancels with itself.
    37  	return y ^ (mask & (y ^ x))
    38  }
    39  
    40  // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
    41  // function does not depend on its inputs.
    42  func ctEq(x, y uint) choice {
    43  	// If x != y, then either x - y or y - x will generate a carry.
    44  	_, c1 := bits.Sub(x, y, 0)
    45  	_, c2 := bits.Sub(y, x, 0)
    46  	return not(choice(c1 | c2))
    47  }
    48  
    49  // ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
    50  // function does not depend on its inputs.
    51  func ctGeq(x, y uint) choice {
    52  	// If x < y, then x - y generates a carry.
    53  	_, carry := bits.Sub(x, y, 0)
    54  	return not(choice(carry))
    55  }
    56  
    57  // Nat represents an arbitrary natural number
    58  //
    59  // Each Nat has an announced length, which is the number of limbs it has stored.
    60  // Operations on this number are allowed to leak this length, but will not leak
    61  // any information about the values contained in those limbs.
    62  type Nat struct {
    63  	// limbs is a little-endian representation in base 2^W with
    64  	// W = bits.UintSize - 1. The top bit is always unset between operations.
    65  	//
    66  	// The top bit is left unset to optimize Montgomery multiplication, in the
    67  	// inner loop of exponentiation. Using fully saturated limbs would leave us
    68  	// working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
    69  	// and thus time.
    70  	limbs []uint
    71  }
    72  
    73  // preallocTarget is the size in bits of the numbers used to implement the most
    74  // common and most performant RSA key size. It's also enough to cover some of
    75  // the operations of key sizes up to 4096.
    76  const preallocTarget = 2048
    77  const preallocLimbs = (preallocTarget + _W - 1) / _W
    78  
    79  // NewNat returns a new nat with a size of zero, just like new(Nat), but with
    80  // the preallocated capacity to hold a number of up to preallocTarget bits.
    81  // NewNat inlines, so the allocation can live on the stack.
    82  func NewNat() *Nat {
    83  	limbs := make([]uint, 0, preallocLimbs)
    84  	return &Nat{limbs}
    85  }
    86  
    87  // expand expands x to n limbs, leaving its value unchanged.
    88  func (x *Nat) expand(n int) *Nat {
    89  	if len(x.limbs) > n {
    90  		panic("bigmod: internal error: shrinking nat")
    91  	}
    92  	if cap(x.limbs) < n {
    93  		newLimbs := make([]uint, n)
    94  		copy(newLimbs, x.limbs)
    95  		x.limbs = newLimbs
    96  		return x
    97  	}
    98  	extraLimbs := x.limbs[len(x.limbs):n]
    99  	for i := range extraLimbs {
   100  		extraLimbs[i] = 0
   101  	}
   102  	x.limbs = x.limbs[:n]
   103  	return x
   104  }
   105  
   106  // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
   107  func (x *Nat) reset(n int) *Nat {
   108  	if cap(x.limbs) < n {
   109  		x.limbs = make([]uint, n)
   110  		return x
   111  	}
   112  	for i := range x.limbs {
   113  		x.limbs[i] = 0
   114  	}
   115  	x.limbs = x.limbs[:n]
   116  	return x
   117  }
   118  
   119  // set assigns x = y, optionally resizing x to the appropriate size.
   120  func (x *Nat) set(y *Nat) *Nat {
   121  	x.reset(len(y.limbs))
   122  	copy(x.limbs, y.limbs)
   123  	return x
   124  }
   125  
// setBig assigns x = n, resizing x to the appropriate size.
//
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes. Only n.Bits (the magnitude) is read.
//
// Each input word carries bits.UintSize bits while each limb holds only _W,
// so every word spills into the next limb. The spill grows by one bit per
// word until it fills a whole limb.
func (x *Nat) setBig(n *big.Int) *Nat {
	requiredLimbs := (n.BitLen() + _W - 1) / _W
	x.reset(requiredLimbs)

	// outI is the limb currently being filled; shift is how many of its low
	// bits were already filled by the previous word's spill.
	outI := 0
	shift := 0
	limbs := n.Bits()
	for i := range limbs {
		xi := uint(limbs[i])
		// The low _W - shift bits of xi complete the current limb.
		x.limbs[outI] |= (xi << shift) & _MASK
		outI++
		if outI == requiredLimbs {
			return x
		}
		// The remaining high shift + 1 bits of xi start the next limb.
		x.limbs[outI] = xi >> (_W - shift)
		shift++ // this assumes bits.UintSize - _W = 1
		if shift == _W {
			// The spill filled an entire limb; continue with a fresh one.
			shift = 0
			outI++
		}
	}
	return x
}
   153  
// Bytes returns x as a zero-extended big-endian byte slice. The size of the
// slice will match the size of m.
//
// x must have the same size as m and it must be reduced modulo m.
func (x *Nat) Bytes(m *Modulus) []byte {
	bytes := make([]byte, m.Size())
	// shift is how many low bits of bytes[outI] were already filled by the
	// previous limb. Bytes are filled from the least significant end.
	shift := 0
	outI := len(bytes) - 1
	for _, limb := range x.limbs {
		remainingBits := _W
		for remainingBits >= 8 {
			// Top up bytes[outI] with the next 8 - shift bits of the limb.
			bytes[outI] |= byte(limb) << shift
			consumed := 8 - shift
			limb >>= consumed
			remainingBits -= consumed
			shift = 0
			outI--
			if outI < 0 {
				// The output is full. Any bits left in limb are not written;
				// for a reduced x of m's size they are expected to be zero.
				return bytes
			}
		}
		// Fewer than 8 bits remain: they become the low bits of the next byte.
		bytes[outI] = byte(limb)
		shift = remainingBits
	}
	return bytes
}
   180  
   181  // SetBytes assigns x = b, where b is a slice of big-endian bytes.
   182  // SetBytes returns an error if b >= m.
   183  //
   184  // The output will be resized to the size of m and overwritten.
   185  func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
   186  	if err := x.setBytes(b, m); err != nil {
   187  		return nil, err
   188  	}
   189  	if x.cmpGeq(m.nat) == yes {
   190  		return nil, errors.New("input overflows the modulus")
   191  	}
   192  	return x, nil
   193  }
   194  
   195  // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
   196  // returns an error if b has a longer bit length than m, but reduces overflowing
   197  // values up to 2^⌈log2(m)⌉ - 1.
   198  //
   199  // The output will be resized to the size of m and overwritten.
   200  func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
   201  	if err := x.setBytes(b, m); err != nil {
   202  		return nil, err
   203  	}
   204  	leading := _W - bitLen(x.limbs[len(x.limbs)-1])
   205  	if leading < m.leading {
   206  		return nil, errors.New("input overflows the modulus")
   207  	}
   208  	x.sub(x.cmpGeq(m.nat), m.nat)
   209  	return x, nil
   210  }
   211  
// setBytes resets x to the size of m and loads the big-endian bytes of b
// into its limbs, least significant byte first. It does not compare the
// result against m.
//
// It returns an error once the limbs are exhausted while input remains:
// either bits carried out of the last limb are non-zero, or more bytes of b
// are left (even zero bytes). On error x is left partially written.
func (x *Nat) setBytes(b []byte, m *Modulus) error {
	outI := 0
	shift := 0
	x.resetFor(m)
	for i := len(b) - 1; i >= 0; i-- {
		bi := b[i]
		x.limbs[outI] |= uint(bi) << shift
		shift += 8
		if shift >= _W {
			// limbs[outI] is full. Keep only its low _W bits, and carry the
			// high bits of bi that did not fit into the next limb.
			shift -= _W
			x.limbs[outI] &= _MASK
			overflow := bi >> (8 - shift)
			outI++
			if outI >= len(x.limbs) {
				if overflow > 0 || i > 0 {
					return errors.New("input overflows the modulus")
				}
				break
			}
			x.limbs[outI] = uint(overflow)
		}
	}
	return nil
}
   236  
   237  // Equal returns 1 if x == y, and 0 otherwise.
   238  //
   239  // Both operands must have the same announced length.
   240  func (x *Nat) Equal(y *Nat) choice {
   241  	// Eliminate bounds checks in the loop.
   242  	size := len(x.limbs)
   243  	xLimbs := x.limbs[:size]
   244  	yLimbs := y.limbs[:size]
   245  
   246  	equal := yes
   247  	for i := 0; i < size; i++ {
   248  		equal &= ctEq(xLimbs[i], yLimbs[i])
   249  	}
   250  	return equal
   251  }
   252  
   253  // IsZero returns 1 if x == 0, and 0 otherwise.
   254  func (x *Nat) IsZero() choice {
   255  	// Eliminate bounds checks in the loop.
   256  	size := len(x.limbs)
   257  	xLimbs := x.limbs[:size]
   258  
   259  	zero := yes
   260  	for i := 0; i < size; i++ {
   261  		zero &= ctEq(xLimbs[i], 0)
   262  	}
   263  	return zero
   264  }
   265  
   266  // cmpGeq returns 1 if x >= y, and 0 otherwise.
   267  //
   268  // Both operands must have the same announced length.
   269  func (x *Nat) cmpGeq(y *Nat) choice {
   270  	// Eliminate bounds checks in the loop.
   271  	size := len(x.limbs)
   272  	xLimbs := x.limbs[:size]
   273  	yLimbs := y.limbs[:size]
   274  
   275  	var c uint
   276  	for i := 0; i < size; i++ {
   277  		c = (xLimbs[i] - yLimbs[i] - c) >> _W
   278  	}
   279  	// If there was a carry, then subtracting y underflowed, so
   280  	// x is not greater than or equal to y.
   281  	return not(choice(c))
   282  }
   283  
   284  // assign sets x <- y if on == 1, and does nothing otherwise.
   285  //
   286  // Both operands must have the same announced length.
   287  func (x *Nat) assign(on choice, y *Nat) *Nat {
   288  	// Eliminate bounds checks in the loop.
   289  	size := len(x.limbs)
   290  	xLimbs := x.limbs[:size]
   291  	yLimbs := y.limbs[:size]
   292  
   293  	for i := 0; i < size; i++ {
   294  		xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
   295  	}
   296  	return x
   297  }
   298  
   299  // add computes x += y if on == 1, and does nothing otherwise. It returns the
   300  // carry of the addition regardless of on.
   301  //
   302  // Both operands must have the same announced length.
   303  func (x *Nat) add(on choice, y *Nat) (c uint) {
   304  	// Eliminate bounds checks in the loop.
   305  	size := len(x.limbs)
   306  	xLimbs := x.limbs[:size]
   307  	yLimbs := y.limbs[:size]
   308  
   309  	for i := 0; i < size; i++ {
   310  		res := xLimbs[i] + yLimbs[i] + c
   311  		xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
   312  		c = res >> _W
   313  	}
   314  	return
   315  }
   316  
   317  // sub computes x -= y if on == 1, and does nothing otherwise. It returns the
   318  // borrow of the subtraction regardless of on.
   319  //
   320  // Both operands must have the same announced length.
   321  func (x *Nat) sub(on choice, y *Nat) (c uint) {
   322  	// Eliminate bounds checks in the loop.
   323  	size := len(x.limbs)
   324  	xLimbs := x.limbs[:size]
   325  	yLimbs := y.limbs[:size]
   326  
   327  	for i := 0; i < size; i++ {
   328  		res := xLimbs[i] - yLimbs[i] - c
   329  		xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
   330  		c = res >> _W
   331  	}
   332  	return
   333  }
   334  
   335  // Modulus is used for modular arithmetic, precomputing relevant constants.
   336  //
   337  // Moduli are assumed to be odd numbers. Moduli can also leak the exact
   338  // number of bits needed to store their value, and are stored without padding.
   339  //
   340  // Their actual value is still kept secret.
   341  type Modulus struct {
   342  	// The underlying natural number for this modulus.
   343  	//
   344  	// This will be stored without any padding, and shouldn't alias with any
   345  	// other natural number being used.
   346  	nat     *Nat
   347  	leading int  // number of leading zeros in the modulus
   348  	m0inv   uint // -nat.limbs[0]⁻¹ mod _W
   349  	rr      *Nat // R*R for montgomeryRepresentation
   350  }
   351  
   352  // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
   353  func rr(m *Modulus) *Nat {
   354  	rr := NewNat().ExpandFor(m)
   355  	// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
   356  	// most significant limb to 1. We then get to R*R by shifting left by _W
   357  	// n + 1 times.
   358  	n := len(rr.limbs)
   359  	rr.limbs[n-1] = 1
   360  	for i := n - 1; i < 2*n; i++ {
   361  		rr.shiftIn(0, m) // x = x * 2^_W mod m
   362  	}
   363  	return rr
   364  }
   365  
   366  // minusInverseModW computes -x⁻¹ mod _W with x odd.
   367  //
   368  // This operation is used to precompute a constant involved in Montgomery
   369  // multiplication.
   370  func minusInverseModW(x uint) uint {
   371  	// Every iteration of this loop doubles the least-significant bits of
   372  	// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
   373  	// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
   374  	// for 61 bits (and wastes only one iteration for 31 bits).
   375  	//
   376  	// See https://crypto.stackexchange.com/a/47496.
   377  	y := x
   378  	for i := 0; i < 5; i++ {
   379  		y = y * (2 - x*y)
   380  	}
   381  	return (1 << _W) - (y & _MASK)
   382  }
   383  
   384  // NewModulusFromBig creates a new Modulus from a [big.Int].
   385  //
   386  // The Int must be odd. The number of significant bits must be leakable.
   387  func NewModulusFromBig(n *big.Int) *Modulus {
   388  	m := &Modulus{}
   389  	m.nat = NewNat().setBig(n)
   390  	m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
   391  	m.m0inv = minusInverseModW(m.nat.limbs[0])
   392  	m.rr = rr(m)
   393  	return m
   394  }
   395  
   396  // bitLen is a version of bits.Len that only leaks the bit length of n, but not
   397  // its value. bits.Len and bits.LeadingZeros use a lookup table for the
   398  // low-order bits on some architectures.
   399  func bitLen(n uint) int {
   400  	var len int
   401  	// We assume, here and elsewhere, that comparison to zero is constant time
   402  	// with respect to different non-zero values.
   403  	for n != 0 {
   404  		len++
   405  		n >>= 1
   406  	}
   407  	return len
   408  }
   409  
   410  // Size returns the size of m in bytes.
   411  func (m *Modulus) Size() int {
   412  	return (m.BitLen() + 7) / 8
   413  }
   414  
   415  // BitLen returns the size of m in bits.
   416  func (m *Modulus) BitLen() int {
   417  	return len(m.nat.limbs)*_W - int(m.leading)
   418  }
   419  
   420  // Nat returns m as a Nat. The return value must not be written to.
   421  func (m *Modulus) Nat() *Nat {
   422  	return m.nat
   423  }
   424  
   425  // shiftIn calculates x = x << _W + y mod m.
   426  //
   427  // This assumes that x is already reduced mod m, and that y < 2^_W.
   428  func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
   429  	d := NewNat().resetFor(m)
   430  
   431  	// Eliminate bounds checks in the loop.
   432  	size := len(m.nat.limbs)
   433  	xLimbs := x.limbs[:size]
   434  	dLimbs := d.limbs[:size]
   435  	mLimbs := m.nat.limbs[:size]
   436  
   437  	// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
   438  	// from y. Effectively, it left-shifts x and adds y one bit at a time,
   439  	// reducing it every time.
   440  	//
   441  	// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
   442  	// The next iteration (and finally the return line) will use either result
   443  	// based on whether the subtraction underflowed.
   444  	needSubtraction := no
   445  	for i := _W - 1; i >= 0; i-- {
   446  		carry := (y >> i) & 1
   447  		var borrow uint
   448  		for i := 0; i < size; i++ {
   449  			l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
   450  
   451  			res := l<<1 + carry
   452  			xLimbs[i] = res & _MASK
   453  			carry = res >> _W
   454  
   455  			res = xLimbs[i] - mLimbs[i] - borrow
   456  			dLimbs[i] = res & _MASK
   457  			borrow = res >> _W
   458  		}
   459  		// See Add for how carry (aka overflow), borrow (aka underflow), and
   460  		// needSubtraction relate.
   461  		needSubtraction = ctEq(carry, borrow)
   462  	}
   463  	return x.assign(needSubtraction, d)
   464  }
   465  
   466  // Mod calculates out = x mod m.
   467  //
   468  // This works regardless how large the value of x is.
   469  //
   470  // The output will be resized to the size of m and overwritten.
   471  func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
   472  	out.resetFor(m)
   473  	// Working our way from the most significant to the least significant limb,
   474  	// we can insert each limb at the least significant position, shifting all
   475  	// previous limbs left by _W. This way each limb will get shifted by the
   476  	// correct number of bits. We can insert at least N - 1 limbs without
   477  	// overflowing m. After that, we need to reduce every time we shift.
   478  	i := len(x.limbs) - 1
   479  	// For the first N - 1 limbs we can skip the actual shifting and position
   480  	// them at the shifted position, which starts at min(N - 2, i).
   481  	start := len(m.nat.limbs) - 2
   482  	if i < start {
   483  		start = i
   484  	}
   485  	for j := start; j >= 0; j-- {
   486  		out.limbs[j] = x.limbs[i]
   487  		i--
   488  	}
   489  	// We shift in the remaining limbs, reducing modulo m each time.
   490  	for i >= 0 {
   491  		out.shiftIn(x.limbs[i], m)
   492  		i--
   493  	}
   494  	return out
   495  }
   496  
   497  // ExpandFor ensures out has the right size to work with operations modulo m.
   498  //
   499  // The announced size of out must be smaller than or equal to that of m.
   500  func (out *Nat) ExpandFor(m *Modulus) *Nat {
   501  	return out.expand(len(m.nat.limbs))
   502  }
   503  
   504  // resetFor ensures out has the right size to work with operations modulo m.
   505  //
   506  // out is zeroed and may start at any size.
   507  func (out *Nat) resetFor(m *Modulus) *Nat {
   508  	return out.reset(len(m.nat.limbs))
   509  }
   510  
   511  // Sub computes x = x - y mod m.
   512  //
   513  // The length of both operands must be the same as the modulus. Both operands
   514  // must already be reduced modulo m.
   515  func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
   516  	underflow := x.sub(yes, y)
   517  	// If the subtraction underflowed, add m.
   518  	x.add(choice(underflow), m.nat)
   519  	return x
   520  }
   521  
   522  // Add computes x = x + y mod m.
   523  //
   524  // The length of both operands must be the same as the modulus. Both operands
   525  // must already be reduced modulo m.
   526  func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
   527  	overflow := x.add(yes, y)
   528  	underflow := not(x.cmpGeq(m.nat)) // x < m
   529  
   530  	// Three cases are possible:
   531  	//
   532  	//   - overflow = 0, underflow = 0
   533  	//
   534  	// In this case, addition fits in our limbs, but we can still subtract away
   535  	// m without an underflow, so we need to perform the subtraction to reduce
   536  	// our result.
   537  	//
   538  	//   - overflow = 0, underflow = 1
   539  	//
   540  	// The addition fits in our limbs, but we can't subtract m without
   541  	// underflowing. The result is already reduced.
   542  	//
   543  	//   - overflow = 1, underflow = 1
   544  	//
   545  	// The addition does not fit in our limbs, and the subtraction's borrow
   546  	// would cancel out with the addition's carry. We need to subtract m to
   547  	// reduce our result.
   548  	//
   549  	// The overflow = 1, underflow = 0 case is not possible, because y is at
   550  	// most m - 1, and if adding m - 1 overflows, then subtracting m must
   551  	// necessarily underflow.
   552  	needSubtraction := ctEq(overflow, uint(underflow))
   553  
   554  	x.sub(needSubtraction, m.nat)
   555  	return x
   556  }
   557  
   558  // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
   559  // n = len(m.nat.limbs).
   560  //
   561  // Faster Montgomery multiplication replaces standard modular multiplication for
   562  // numbers in this representation.
   563  //
   564  // This assumes that x is already reduced mod m.
   565  func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
   566  	// A Montgomery multiplication (which computes a * b / R) by R * R works out
   567  	// to a multiplication by R, which takes the value out of the Montgomery domain.
   568  	return x.montgomeryMul(NewNat().set(x), m.rr, m)
   569  }
   570  
   571  // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
   572  // n = len(m.nat.limbs).
   573  //
   574  // This assumes that x is already reduced mod m.
   575  func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
   576  	// By Montgomery multiplying with 1 not in Montgomery representation, we
   577  	// convert out back from Montgomery representation, because it works out to
   578  	// dividing by R.
   579  	t0 := NewNat().set(x)
   580  	t1 := NewNat().ExpandFor(m)
   581  	t1.limbs[0] = 1
   582  	return x.montgomeryMul(t0, t1, m)
   583  }
   584  
   585  // montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
   586  // n = len(m.nat.limbs), using the Montgomery Multiplication technique.
   587  //
   588  // All inputs should be the same length, not aliasing d, and already
   589  // reduced modulo m. d will be resized to the size of m and overwritten.
   590  func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
   591  	d.resetFor(m)
   592  	if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
   593  		panic("bigmod: invalid montgomeryMul input")
   594  	}
   595  
   596  	// See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
   597  	// for a description of the algorithm implemented mostly in montgomeryLoop.
   598  	// See Add for how overflow, underflow, and needSubtraction relate.
   599  	overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
   600  	underflow := not(d.cmpGeq(m.nat)) // d < m
   601  	needSubtraction := ctEq(overflow, uint(underflow))
   602  	d.sub(needSubtraction, m.nat)
   603  
   604  	return d
   605  }
   606  
// montgomeryLoopGeneric is a pure Go implementation of the Montgomery
// multiplication loop used by montgomeryMul.
//
// For each limb ai of a (least significant first) it updates
// d = (d + ai*b + f*m) / 2^_W. Here f = (d[0] + ai*b[0]) * m0inv mod 2^_W is
// chosen so the lowest limb of the sum vanishes and the division is exact,
// given m0inv = -m[0]⁻¹ mod 2^_W. montgomeryMul passes a zeroed d. The
// value left in d together with the returned carry may still need one
// conditional subtraction of m, which montgomeryMul performs.
//
// a, b, and m must have at least len(d) limbs, each below 2^_W. The return
// value is the carry out of d's top limb (0 or 1).
func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
	// Eliminate bounds checks in the loop.
	size := len(d)
	a = a[:size]
	b = b[:size]
	m = m[:size]

	for _, ai := range a {
		// This is an unrolled iteration of the loop below with j = 0.
		hi, lo := bits.Mul(ai, b[0])
		z_lo, c := bits.Add(d[0], lo, 0)
		f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
		z_hi, _ := bits.Add(0, hi, c)
		hi, lo = bits.Mul(f, m[0])
		z_lo, c = bits.Add(z_lo, lo, 0)
		z_hi, _ = bits.Add(z_hi, hi, c)
		// The low _W bits of z are zero by the choice of f and are dropped
		// (the division by 2^_W). Only z >> _W moves on as the carry. It is
		// assembled from (z_hi, z_lo) this way because _W = bits.UintSize - 1.
		carry := z_hi<<1 | z_lo>>_W

		for j := 1; j < size; j++ {
			// z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
			hi, lo := bits.Mul(ai, b[j])
			z_lo, c := bits.Add(d[j], lo, 0)
			z_hi, _ := bits.Add(0, hi, c)
			hi, lo = bits.Mul(f, m[j])
			z_lo, c = bits.Add(z_lo, lo, 0)
			z_hi, _ = bits.Add(z_hi, hi, c)
			z_lo, c = bits.Add(z_lo, carry, 0)
			z_hi, _ = bits.Add(z_hi, 0, c)
			// Storing limb j into d[j-1] shifts d down by one limb, which
			// is the division by 2^_W.
			d[j-1] = z_lo & _MASK
			carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
		}

		// The final carry plus the previous iteration's overflow becomes the
		// new top limb; anything beyond _W bits is the new overflow.
		z := overflow + carry // z <= 2^(W+1) - 1
		d[size-1] = z & _MASK
		overflow = z >> _W // overflow <= 1
	}
	return
}
   645  
   646  // Mul calculates x *= y mod m.
   647  //
   648  // x and y must already be reduced modulo m, they must share its announced
   649  // length, and they may not alias.
   650  func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
   651  	// A Montgomery multiplication by a value out of the Montgomery domain
   652  	// takes the result out of Montgomery representation.
   653  	xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
   654  	return x.montgomeryMul(xR, y, m)                  // x = xR * y / R mod m
   655  }
   656  
// Exp calculates out = x^e mod m and returns out.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
//
// e is processed in fixed 4-bit windows (two per byte). Each window performs
// four squarings and one multiplication, so the number of operations depends
// only on len(e), not on its value.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
	// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
	// than 2 bit windows, but use an extra 12 nats worth of scratch space.
	// Window sizes that don't divide 8 are more complex to implement.

	table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
		// NewNat calls are unrolled so they are allocated on the stack.
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
	}
	// Table entries are kept in Montgomery representation: table[0] = x*R,
	// and each Montgomery product with table[0] raises the power by one.
	table[0].set(x).montgomeryRepresentation(m)
	for i := 1; i < len(table); i++ {
		table[i].montgomeryMul(table[i-1], table[0], m)
	}

	// out starts as 1, converted into Montgomery representation.
	out.resetFor(m)
	out.limbs[0] = 1
	out.montgomeryRepresentation(m)
	t0 := NewNat().ExpandFor(m)
	t1 := NewNat().ExpandFor(m)
	for _, b := range e {
		// Process the high nibble of each byte, then the low nibble.
		for _, j := range []int{4, 0} {
			// Square four times.
			t1.montgomeryMul(out, out, m)
			out.montgomeryMul(t1, t1, m)
			t1.montgomeryMul(out, out, m)
			out.montgomeryMul(t1, t1, m)

			// Select x^k in constant time from the table. Every entry is
			// visited, so the access pattern does not depend on k.
			k := uint((b >> j) & 0b1111)
			for i := range table {
				t0.assign(ctEq(k, uint(i+1)), table[i])
			}

			// Multiply by x^k, discarding the result if k = 0.
			t1.montgomeryMul(out, t0, m)
			out.assign(not(ctEq(k, 0)), t1)
		}
	}

	// Convert the result back out of Montgomery representation.
	return out.montgomeryReduction(m)
}