github.com/aloncn/graphics-go@v0.0.1/src/math/big/nat.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file implements unsigned multi-precision integers (natural
     6  // numbers). They are the building blocks for the implementation
     7  // of signed integers, rationals, and floating-point numbers.
     8  
     9  package big
    10  
    11  import "math/rand"
    12  
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements (least significant first).
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
// Most methods use the receiver z as storage for their result: z's
// underlying array is reused when its capacity suffices (see nat.make)
// and the possibly reallocated result is returned, so callers must
// always use the returned value.
type nat []Word
    26  
    27  var (
    28  	natOne = nat{1}
    29  	natTwo = nat{2}
    30  	natTen = nat{10}
    31  )
    32  
    33  func (z nat) clear() {
    34  	for i := range z {
    35  		z[i] = 0
    36  	}
    37  }
    38  
    39  func (z nat) norm() nat {
    40  	i := len(z)
    41  	for i > 0 && z[i-1] == 0 {
    42  		i--
    43  	}
    44  	return z[0:i]
    45  }
    46  
    47  func (z nat) make(n int) nat {
    48  	if n <= cap(z) {
    49  		return z[:n] // reuse z
    50  	}
    51  	// Choosing a good value for e has significant performance impact
    52  	// because it increases the chance that a value can be reused.
    53  	const e = 4 // extra capacity
    54  	return make(nat, n, n+e)
    55  }
    56  
    57  func (z nat) setWord(x Word) nat {
    58  	if x == 0 {
    59  		return z[:0]
    60  	}
    61  	z = z.make(1)
    62  	z[0] = x
    63  	return z
    64  }
    65  
    66  func (z nat) setUint64(x uint64) nat {
    67  	// single-digit values
    68  	if w := Word(x); uint64(w) == x {
    69  		return z.setWord(w)
    70  	}
    71  
    72  	// compute number of words n required to represent x
    73  	n := 0
    74  	for t := x; t > 0; t >>= _W {
    75  		n++
    76  	}
    77  
    78  	// split x into n words
    79  	z = z.make(n)
    80  	for i := range z {
    81  		z[i] = Word(x & _M)
    82  		x >>= _W
    83  	}
    84  
    85  	return z
    86  }
    87  
    88  func (z nat) set(x nat) nat {
    89  	z = z.make(len(x))
    90  	copy(z, x)
    91  	return z
    92  }
    93  
    94  func (z nat) add(x, y nat) nat {
    95  	m := len(x)
    96  	n := len(y)
    97  
    98  	switch {
    99  	case m < n:
   100  		return z.add(y, x)
   101  	case m == 0:
   102  		// n == 0 because m >= n; result is 0
   103  		return z[:0]
   104  	case n == 0:
   105  		// result is x
   106  		return z.set(x)
   107  	}
   108  	// m > 0
   109  
   110  	z = z.make(m + 1)
   111  	c := addVV(z[0:n], x, y)
   112  	if m > n {
   113  		c = addVW(z[n:m], x[n:], c)
   114  	}
   115  	z[m] = c
   116  
   117  	return z.norm()
   118  }
   119  
   120  func (z nat) sub(x, y nat) nat {
   121  	m := len(x)
   122  	n := len(y)
   123  
   124  	switch {
   125  	case m < n:
   126  		panic("underflow")
   127  	case m == 0:
   128  		// n == 0 because m >= n; result is 0
   129  		return z[:0]
   130  	case n == 0:
   131  		// result is x
   132  		return z.set(x)
   133  	}
   134  	// m > 0
   135  
   136  	z = z.make(m)
   137  	c := subVV(z[0:n], x, y)
   138  	if m > n {
   139  		c = subVW(z[n:], x[n:], c)
   140  	}
   141  	if c != 0 {
   142  		panic("underflow")
   143  	}
   144  
   145  	return z.norm()
   146  }
   147  
   148  func (x nat) cmp(y nat) (r int) {
   149  	m := len(x)
   150  	n := len(y)
   151  	if m != n || m == 0 {
   152  		switch {
   153  		case m < n:
   154  			r = -1
   155  		case m > n:
   156  			r = 1
   157  		}
   158  		return
   159  	}
   160  
   161  	i := m - 1
   162  	for i > 0 && x[i] == y[i] {
   163  		i--
   164  	}
   165  
   166  	switch {
   167  	case x[i] < y[i]:
   168  		r = -1
   169  	case x[i] > y[i]:
   170  		r = 1
   171  	}
   172  	return
   173  }
   174  
   175  func (z nat) mulAddWW(x nat, y, r Word) nat {
   176  	m := len(x)
   177  	if m == 0 || y == 0 {
   178  		return z.setWord(r) // result is r
   179  	}
   180  	// m > 0
   181  
   182  	z = z.make(m + 1)
   183  	z[m] = mulAddVWW(z[0:m], x, y, r)
   184  
   185  	return z.norm()
   186  }
   187  
   188  // basicMul multiplies x and y and leaves the result in z.
   189  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   190  func basicMul(z, x, y nat) {
   191  	z[0 : len(x)+len(y)].clear() // initialize z
   192  	for i, d := range y {
   193  		if d != 0 {
   194  			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   195  		}
   196  	}
   197  }
   198  
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	// c is the carry digit sitting just above z[n-1]; it is always 0 or 1.
	var c Word
	for i := 0; i < n; i++ {
		d := y[i]
		// z += x*y[i]; c2 is the digit carried out of the top of z.
		c2 := addMulVVW(z, x, d)
		// Pick t so that z + m*t is divisible by 2**_W (uses k = -1/m).
		t := z[0] * k
		c3 := addMulVVW(z, m, t)
		// z[0] is now zero: divide by 2**_W by shifting down one digit.
		copy(z, z[1:])
		// The new top digit is the sum of the three carries; any overflow
		// of that sum becomes the next carry c.
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// A remaining carry means the value is >= 2**(n*_W); subtracting m once
	// (modulo 2**(n*_W)) brings it back into range.
	if c != 0 {
		subVV(z, z, m)
	}
	return z
}
   239  
   240  // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   241  // Factored out for readability - do not use outside karatsuba.
   242  func karatsubaAdd(z, x nat, n int) {
   243  	if c := addVV(z[0:n], z, x); c != 0 {
   244  		addVW(z[n:n+n>>1], z[n:], c)
   245  	}
   246  }
   247  
   248  // Like karatsubaAdd, but does subtract.
   249  func karatsubaSub(z, x nat, n int) {
   250  	if c := subVV(z[0:n], z, x); c != 0 {
   251  		subVW(z[n:n+n>>1], z[n:], c)
   252  	}
   253  }
   254  
   255  // Operands that are shorter than karatsubaThreshold are multiplied using
   256  // "grade school" multiplication; for longer operands the Karatsuba algorithm
   257  // is used.
   258  var karatsubaThreshold int = 40 // computed by calibrate.go
   259  
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
// (mul passes n = karatsubaLen(len(y)), which need not be an exact power
// of 2; whenever a level's length is odd or below the threshold, the
// code below falls back to basicMul.)
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits" (b = 1<<(_W*n2))
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
   358  
   359  // alias reports whether x and y share the same base array.
   360  func alias(x, y nat) bool {
   361  	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   362  }
   363  
   364  // addAt implements z += x<<(_W*i); z must be long enough.
   365  // (we don't use nat.add because we need z to stay the same
   366  // slice, and we don't need to normalize z after each addition)
   367  func addAt(z, x nat, i int) {
   368  	if n := len(x); n > 0 {
   369  		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   370  			j := i + n
   371  			if j < len(z) {
   372  				addVW(z[j:], z[j:], c)
   373  			}
   374  		}
   375  	}
   376  }
   377  
   378  func max(x, y int) int {
   379  	if x > y {
   380  		return x
   381  	}
   382  	return y
   383  }
   384  
   385  // karatsubaLen computes an approximation to the maximum k <= n such that
   386  // k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   387  // result is the largest number that can be divided repeatedly by 2 before
   388  // becoming about the value of karatsubaThreshold.
   389  func karatsubaLen(n int) int {
   390  	i := uint(0)
   391  	for n > karatsubaThreshold {
   392  		n >>= 1
   393  		i++
   394  	}
   395  	return n << i
   396  }
   397  
// mul sets z = x*y and returns the normalized product. z is used as
// storage for the result unless it aliases x or y, in which case fresh
// storage is allocated.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is scratch space reused for each partial product.
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next k-word chunk of x (shorter at the top).
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
   481  
   482  // mulRange computes the product of all the unsigned integers in the
   483  // range [a, b] inclusively. If a > b (empty range), the result is 1.
   484  func (z nat) mulRange(a, b uint64) nat {
   485  	switch {
   486  	case a == 0:
   487  		// cut long ranges short (optimization)
   488  		return z.setUint64(0)
   489  	case a > b:
   490  		return z.setUint64(1)
   491  	case a == b:
   492  		return z.setUint64(a)
   493  	case a+1 == b:
   494  		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   495  	}
   496  	m := (a + b) / 2
   497  	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   498  }
   499  
   500  // q = (x-r)/y, with 0 <= r < y
   501  func (z nat) divW(x nat, y Word) (q nat, r Word) {
   502  	m := len(x)
   503  	switch {
   504  	case y == 0:
   505  		panic("division by zero")
   506  	case y == 1:
   507  		q = z.set(x) // result is x
   508  		return
   509  	case m == 0:
   510  		q = z[:0] // result is 0
   511  		return
   512  	}
   513  	// m > 0
   514  	z = z.make(m)
   515  	r = divWVW(z, 0, x, y)
   516  	q = z.norm()
   517  	return
   518  }
   519  
   520  func (z nat) div(z2, u, v nat) (q, r nat) {
   521  	if len(v) == 0 {
   522  		panic("division by zero")
   523  	}
   524  
   525  	if u.cmp(v) < 0 {
   526  		q = z[:0]
   527  		r = z2.set(u)
   528  		return
   529  	}
   530  
   531  	if len(v) == 1 {
   532  		var r2 Word
   533  		q, r2 = z.divW(u, v[0])
   534  		r = z2.setWord(r2)
   535  		return
   536  	}
   537  
   538  	q, r = z.divLarge(z2, u, v)
   539  	return
   540  }
   541  
// q = (uIn-r)/v, with 0 <= r < v
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
//
// NOTE(review): z and u are only checked against uIn and v, not against
// each other. If z and u share an underlying array, q and the working
// remainder overlap (q[j] is written while u[j-1:] is still in use), so
// callers must pass distinct storage for them.
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv holds qhat*v for the current step (n+1 digits).
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Normalize: shift v (and u by the same amount) so that the top bit
	// of v's most significant digit is set.
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient digit qhat from the top digits of u and v.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Multiply and subtract: u[j:j+n+1] -= qhat*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so qhat was one too large:
			// add v back and decrement qhat (Knuth's D5/D6).
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// Unnormalize: undo the D1 shift to obtain the remainder.
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
   619  
   620  // Length of x in bits. x must be normalized.
   621  func (x nat) bitLen() int {
   622  	if i := len(x) - 1; i >= 0 {
   623  		return i*_W + bitLen(x[i])
   624  	}
   625  	return 0
   626  }
   627  
// deBruijn32 is the multiplier trailingZeroBits uses when _W == 32:
// multiplying it by a single set bit 1<<k and keeping the top 5 bits of
// the 32-bit product yields a value that is distinct for each k.
const deBruijn32 = 0x077CB531

// deBruijn32Lookup maps the 5-bit value extracted with deBruijn32 back to
// the bit index k.
var deBruijn32Lookup = []byte{
	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}

// deBruijn64 is the 64-bit counterpart of deBruijn32, used when
// _W == 64. The top 6 bits of (1<<k)*deBruijn64 identify k.
const deBruijn64 = 0x03f79d71b4ca8b09

// deBruijn64Lookup maps the 6-bit value extracted with deBruijn64 back to
// the bit index k.
var deBruijn64Lookup = []byte{
	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
   643  
   644  // trailingZeroBits returns the number of consecutive least significant zero
   645  // bits of x.
   646  func trailingZeroBits(x Word) uint {
   647  	// x & -x leaves only the right-most bit set in the word. Let k be the
   648  	// index of that bit. Since only a single bit is set, the value is two
   649  	// to the power of k. Multiplying by a power of two is equivalent to
   650  	// left shifting, in this case by k bits.  The de Bruijn constant is
   651  	// such that all six bit, consecutive substrings are distinct.
   652  	// Therefore, if we have a left shifted version of this constant we can
   653  	// find by how many bits it was shifted by looking at which six bit
   654  	// substring ended up at the top of the word.
   655  	// (Knuth, volume 4, section 7.3.1)
   656  	switch _W {
   657  	case 32:
   658  		return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   659  	case 64:
   660  		return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   661  	default:
   662  		panic("unknown word size")
   663  	}
   664  }
   665  
   666  // trailingZeroBits returns the number of consecutive least significant zero
   667  // bits of x.
   668  func (x nat) trailingZeroBits() uint {
   669  	if len(x) == 0 {
   670  		return 0
   671  	}
   672  	var i uint
   673  	for x[i] == 0 {
   674  		i++
   675  	}
   676  	// x[i] != 0
   677  	return i*_W + trailingZeroBits(x[i])
   678  }
   679  
   680  // z = x << s
   681  func (z nat) shl(x nat, s uint) nat {
   682  	m := len(x)
   683  	if m == 0 {
   684  		return z[:0]
   685  	}
   686  	// m > 0
   687  
   688  	n := m + int(s/_W)
   689  	z = z.make(n + 1)
   690  	z[n] = shlVU(z[n-m:n], x, s%_W)
   691  	z[0 : n-m].clear()
   692  
   693  	return z.norm()
   694  }
   695  
   696  // z = x >> s
   697  func (z nat) shr(x nat, s uint) nat {
   698  	m := len(x)
   699  	n := m - int(s/_W)
   700  	if n <= 0 {
   701  		return z[:0]
   702  	}
   703  	// n > 0
   704  
   705  	z = z.make(n)
   706  	shrVU(z, x[m-n:], s%_W)
   707  
   708  	return z.norm()
   709  }
   710  
   711  func (z nat) setBit(x nat, i uint, b uint) nat {
   712  	j := int(i / _W)
   713  	m := Word(1) << (i % _W)
   714  	n := len(x)
   715  	switch b {
   716  	case 0:
   717  		z = z.make(n)
   718  		copy(z, x)
   719  		if j >= n {
   720  			// no need to grow
   721  			return z
   722  		}
   723  		z[j] &^= m
   724  		return z.norm()
   725  	case 1:
   726  		if j >= n {
   727  			z = z.make(j + 1)
   728  			z[n:].clear()
   729  		} else {
   730  			z = z.make(n)
   731  		}
   732  		copy(z, x)
   733  		z[j] |= m
   734  		// no need to normalize
   735  		return z
   736  	}
   737  	panic("set bit is not 0 or 1")
   738  }
   739  
   740  // bit returns the value of the i'th bit, with lsb == bit 0.
   741  func (x nat) bit(i uint) uint {
   742  	j := i / _W
   743  	if j >= uint(len(x)) {
   744  		return 0
   745  	}
   746  	// 0 <= j < len(x)
   747  	return uint(x[j] >> (i % _W) & 1)
   748  }
   749  
   750  // sticky returns 1 if there's a 1 bit within the
   751  // i least significant bits, otherwise it returns 0.
   752  func (x nat) sticky(i uint) uint {
   753  	j := i / _W
   754  	if j >= uint(len(x)) {
   755  		if len(x) == 0 {
   756  			return 0
   757  		}
   758  		return 1
   759  	}
   760  	// 0 <= j < len(x)
   761  	for _, x := range x[:j] {
   762  		if x != 0 {
   763  			return 1
   764  		}
   765  	}
   766  	if x[j]<<(_W-i%_W) != 0 {
   767  		return 1
   768  	}
   769  	return 0
   770  }
   771  
   772  func (z nat) and(x, y nat) nat {
   773  	m := len(x)
   774  	n := len(y)
   775  	if m > n {
   776  		m = n
   777  	}
   778  	// m <= n
   779  
   780  	z = z.make(m)
   781  	for i := 0; i < m; i++ {
   782  		z[i] = x[i] & y[i]
   783  	}
   784  
   785  	return z.norm()
   786  }
   787  
   788  func (z nat) andNot(x, y nat) nat {
   789  	m := len(x)
   790  	n := len(y)
   791  	if n > m {
   792  		n = m
   793  	}
   794  	// m >= n
   795  
   796  	z = z.make(m)
   797  	for i := 0; i < n; i++ {
   798  		z[i] = x[i] &^ y[i]
   799  	}
   800  	copy(z[n:m], x[n:m])
   801  
   802  	return z.norm()
   803  }
   804  
   805  func (z nat) or(x, y nat) nat {
   806  	m := len(x)
   807  	n := len(y)
   808  	s := x
   809  	if m < n {
   810  		n, m = m, n
   811  		s = y
   812  	}
   813  	// m >= n
   814  
   815  	z = z.make(m)
   816  	for i := 0; i < n; i++ {
   817  		z[i] = x[i] | y[i]
   818  	}
   819  	copy(z[n:m], s[n:m])
   820  
   821  	return z.norm()
   822  }
   823  
   824  func (z nat) xor(x, y nat) nat {
   825  	m := len(x)
   826  	n := len(y)
   827  	s := x
   828  	if m < n {
   829  		n, m = m, n
   830  		s = y
   831  	}
   832  	// m >= n
   833  
   834  	z = z.make(m)
   835  	for i := 0; i < n; i++ {
   836  		z[i] = x[i] ^ y[i]
   837  	}
   838  	copy(z[n:m], s[n:m])
   839  
   840  	return z.norm()
   841  }
   842  
   843  // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   844  func greaterThan(x1, x2, y1, y2 Word) bool {
   845  	return x1 > y1 || x1 == y1 && x2 > y2
   846  }
   847  
   848  // modW returns x % d.
   849  func (x nat) modW(d Word) (r Word) {
   850  	// TODO(agl): we don't actually need to store the q value.
   851  	var q nat
   852  	q = q.make(len(x))
   853  	return divWVW(q, 0, x, d)
   854  }
   855  
   856  // random creates a random integer in [0..limit), using the space in z if
   857  // possible. n is the bit length of limit.
   858  func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   859  	if alias(z, limit) {
   860  		z = nil // z is an alias for limit - cannot reuse
   861  	}
   862  	z = z.make(len(limit))
   863  
   864  	bitLengthOfMSW := uint(n % _W)
   865  	if bitLengthOfMSW == 0 {
   866  		bitLengthOfMSW = _W
   867  	}
   868  	mask := Word((1 << bitLengthOfMSW) - 1)
   869  
   870  	for {
   871  		switch _W {
   872  		case 32:
   873  			for i := range z {
   874  				z[i] = Word(rand.Uint32())
   875  			}
   876  		case 64:
   877  			for i := range z {
   878  				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   879  			}
   880  		default:
   881  			panic("unknown word size")
   882  		}
   883  		z[len(limit)-1] &= mask
   884  		if z.cmp(limit) < 0 {
   885  			break
   886  		}
   887  	}
   888  
   889  	return z.norm()
   890  }
   891  
   892  // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   893  // otherwise it sets z to x**y. The result is the value of z.
   894  func (z nat) expNN(x, y, m nat) nat {
   895  	if alias(z, x) || alias(z, y) {
   896  		// We cannot allow in-place modification of x or y.
   897  		z = nil
   898  	}
   899  
   900  	// x**y mod 1 == 0
   901  	if len(m) == 1 && m[0] == 1 {
   902  		return z.setWord(0)
   903  	}
   904  	// m == 0 || m > 1
   905  
   906  	// x**0 == 1
   907  	if len(y) == 0 {
   908  		return z.setWord(1)
   909  	}
   910  	// y > 0
   911  
   912  	// x**1 mod m == x mod m
   913  	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   914  		_, z = z.div(z, x, m)
   915  		return z
   916  	}
   917  	// y > 1
   918  
   919  	if len(m) != 0 {
   920  		// We likely end up being as long as the modulus.
   921  		z = z.make(len(m))
   922  	}
   923  	z = z.set(x)
   924  
   925  	// If the base is non-trivial and the exponent is large, we use
   926  	// 4-bit, windowed exponentiation. This involves precomputing 14 values
   927  	// (x^2...x^15) but then reduces the number of multiply-reduces by a
   928  	// third. Even for a 32-bit exponent, this reduces the number of
   929  	// operations. Uses Montgomery method for odd moduli.
   930  	if len(x) > 1 && len(y) > 1 && len(m) > 0 {
   931  		if m[0]&1 == 1 {
   932  			return z.expNNMontgomery(x, y, m)
   933  		}
   934  		return z.expNNWindowed(x, y, m)
   935  	}
   936  
   937  	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   938  	shift := nlz(v) + 1
   939  	v <<= shift
   940  	var q nat
   941  
   942  	const mask = 1 << (_W - 1)
   943  
   944  	// We walk through the bits of the exponent one by one. Each time we
   945  	// see a bit, we square, thus doubling the power. If the bit is a one,
   946  	// we also multiply by x, thus adding one to the power.
   947  
   948  	w := _W - int(shift)
   949  	// zz and r are used to avoid allocating in mul and div as
   950  	// otherwise the arguments would alias.
   951  	var zz, r nat
   952  	for j := 0; j < w; j++ {
   953  		zz = zz.mul(z, z)
   954  		zz, z = z, zz
   955  
   956  		if v&mask != 0 {
   957  			zz = zz.mul(z, x)
   958  			zz, z = z, zz
   959  		}
   960  
   961  		if len(m) != 0 {
   962  			zz, r = zz.div(r, z, m)
   963  			zz, r, q, z = q, z, zz, r
   964  		}
   965  
   966  		v <<= 1
   967  	}
   968  
   969  	for i := len(y) - 2; i >= 0; i-- {
   970  		v = y[i]
   971  
   972  		for j := 0; j < _W; j++ {
   973  			zz = zz.mul(z, z)
   974  			zz, z = z, zz
   975  
   976  			if v&mask != 0 {
   977  				zz = zz.mul(z, x)
   978  				zz, z = z, zz
   979  			}
   980  
   981  			if len(m) != 0 {
   982  				zz, r = zz.div(r, z, m)
   983  				zz, r, q, z = q, z, zz, r
   984  			}
   985  
   986  			v <<= 1
   987  		}
   988  	}
   989  
   990  	return z.norm()
   991  }
   992  
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// It precomputes x^0 ... x^15 mod m and then, for each 4-bit chunk of y
// from the most significant end, squares the accumulator four times and
// multiplies by the table entry for that chunk, reducing mod m after
// every product. m is expected to be nonzero (every step divides by m).
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Fill the table two entries at a time: x^i = (x^(i/2))^2 and
	// x^(i+1) = x^i * x, each reduced mod m.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Skip the four squarings for the very first window: z is
			// still 1 there.
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain.  Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// Multiply by x^(top 4 bits of yi), then reduce.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}
  1055  
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
//
// The exponent y is consumed from its most significant word downwards,
// n = 4 bits at a time. Each window costs four Montgomery squarings, then
// one Montgomery multiplication by a precomputed power of x. The squarings
// are skipped for the very first window, where z still holds 1. z's storage
// is reused for the running result where its capacity allows; the returned
// value is normalized.
//
// NOTE(review): m[0] is read unconditionally, so m must be non-empty. The
// k0 computation inverts m[0] modulo 2**_W, which only exists for odd m.
// Presumably the caller dispatches only odd m here; confirm against expNN.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// Zero-extend x to exactly numWords words (fresh storage, so the
		// caller's x is not modified).
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused.
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	// Word arithmetic wraps, so every product below is implicitly taken
	// mod 2**_W. The loop runs log2(_W) times (i = 1, 2, 4, ..., _W/2).
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	// RR is the factor used below to move values into Montgomery form.
	// NOTE(review): RR's storage is passed both as the (discarded) quotient
	// receiver and as the remainder buffer z2; this is only safe if div
	// tolerates that aliasing. Confirm.
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = RR.div(RR, zz, m)
	if len(RR) < numWords {
		// Zero-extend RR to numWords words, reusing zz's storage.
		// NOTE(review): this relies on zz[:numWords] still being all zero.
		// zz held 1 << (2*numWords*_W), whose only non-zero word is at
		// index 2*numWords. It also assumes div left its dividend untouched.
		// Confirm.
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i in Montgomery form. Multiplying by RR converts
	// one and x into that form; powers[0] is therefore Montgomery 1.
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	// zz becomes scratch for the main loop. If the zero-extension branch
	// above ran, zz shares storage with RR here. RR is not read after the
	// powers table is built, so clobbering it is harmless.
	zz = zz.make(numWords)

	// Same windowed exponentiation as the plain path, but with Montgomery
	// multiplications.
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// Four squarings shift the accumulated exponent left by n
				// bits. Results alternate between z and zz so a receiver
				// never doubles as an input (presumably montgomery requires
				// that; confirm).
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// yi>>(_W-n) is the top n bits of what remains of the word.
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			// Swap so z holds the running result and zz is free scratch.
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	// (a Montgomery multiplication by plain 1 removes the Montgomery factor).
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
  1148  
  1149  // probablyPrime performs n Miller-Rabin tests to check whether x is prime.
  1150  // If x is prime, it returns true.
  1151  // If x is not prime, it returns false with probability at least 1 - ¼ⁿ.
  1152  //
  1153  // It is not suitable for judging primes that an adversary may have crafted
  1154  // to fool this test.
  1155  func (n nat) probablyPrime(reps int) bool {
  1156  	if len(n) == 0 {
  1157  		return false
  1158  	}
  1159  
  1160  	if len(n) == 1 {
  1161  		if n[0] < 2 {
  1162  			return false
  1163  		}
  1164  
  1165  		if n[0]%2 == 0 {
  1166  			return n[0] == 2
  1167  		}
  1168  
  1169  		// We have to exclude these cases because we reject all
  1170  		// multiples of these numbers below.
  1171  		switch n[0] {
  1172  		case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1173  			return true
  1174  		}
  1175  	}
  1176  
  1177  	if n[0]&1 == 0 {
  1178  		return false // n is even
  1179  	}
  1180  
  1181  	const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1182  	const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1183  
  1184  	var r Word
  1185  	switch _W {
  1186  	case 32:
  1187  		r = n.modW(primesProduct32)
  1188  	case 64:
  1189  		r = n.modW(primesProduct64 & _M)
  1190  	default:
  1191  		panic("Unknown word size")
  1192  	}
  1193  
  1194  	if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1195  		r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1196  		return false
  1197  	}
  1198  
  1199  	if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1200  		r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1201  		return false
  1202  	}
  1203  
  1204  	nm1 := nat(nil).sub(n, natOne)
  1205  	// determine q, k such that nm1 = q << k
  1206  	k := nm1.trailingZeroBits()
  1207  	q := nat(nil).shr(nm1, k)
  1208  
  1209  	nm3 := nat(nil).sub(nm1, natTwo)
  1210  	rand := rand.New(rand.NewSource(int64(n[0])))
  1211  
  1212  	var x, y, quotient nat
  1213  	nm3Len := nm3.bitLen()
  1214  
  1215  NextRandom:
  1216  	for i := 0; i < reps; i++ {
  1217  		x = x.random(rand, nm3, nm3Len)
  1218  		x = x.add(x, natTwo)
  1219  		y = y.expNN(x, q, n)
  1220  		if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1221  			continue
  1222  		}
  1223  		for j := uint(1); j < k; j++ {
  1224  			y = y.mul(y, y)
  1225  			quotient, y = quotient.div(y, y, n)
  1226  			if y.cmp(nm1) == 0 {
  1227  				continue NextRandom
  1228  			}
  1229  			if y.cmp(natOne) == 0 {
  1230  				return false
  1231  			}
  1232  		}
  1233  		return false
  1234  	}
  1235  
  1236  	return true
  1237  }
  1238  
  1239  // bytes writes the value of z into buf using big-endian encoding.
  1240  // len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1241  // slice buf[i:]. The number i of unused bytes at the beginning of
  1242  // buf is returned as result.
  1243  func (z nat) bytes(buf []byte) (i int) {
  1244  	i = len(buf)
  1245  	for _, d := range z {
  1246  		for j := 0; j < _S; j++ {
  1247  			i--
  1248  			buf[i] = byte(d)
  1249  			d >>= 8
  1250  		}
  1251  	}
  1252  
  1253  	for i < len(buf) && buf[i] == 0 {
  1254  		i++
  1255  	}
  1256  
  1257  	return
  1258  }
  1259  
  1260  // setBytes interprets buf as the bytes of a big-endian unsigned
  1261  // integer, sets z to that value, and returns z.
  1262  func (z nat) setBytes(buf []byte) nat {
  1263  	z = z.make((len(buf) + _S - 1) / _S)
  1264  
  1265  	k := 0
  1266  	s := uint(0)
  1267  	var d Word
  1268  	for i := len(buf); i > 0; i-- {
  1269  		d |= Word(buf[i-1]) << s
  1270  		if s += 8; s == _S*8 {
  1271  			z[k] = d
  1272  			k++
  1273  			s = 0
  1274  			d = 0
  1275  		}
  1276  	}
  1277  	if k < len(z) {
  1278  		z[k] = d
  1279  	}
  1280  
  1281  	return z.norm()
  1282  }