github.com/rsc/go@v0.0.0-20150416155037-e040fd465409/src/math/big/nat.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package big implements multi-precision arithmetic (big numbers).
     6  // The following numeric types are supported:
     7  //
     8  //   Int    signed integers
     9  //   Rat    rational numbers
    10  //   Float  floating-point numbers
    11  //
    12  // Methods are typically of the form:
    13  //
    14  //   func (z *T) Unary(x *T) *T        // z = op x
    15  //   func (z *T) Binary(x, y *T) *T    // z = x op y
    16  //   func (x *T) M() T1                // v = x.M()
    17  //
    18  // with T one of Int, Rat, or Float. For unary and binary operations, the
    19  // result is the receiver (usually named z in that case); if it is one of
    20  // the operands x or y it may be overwritten (and its memory reused).
    21  // To enable chaining of operations, the result is also returned. Methods
    22  // returning a result other than *Int, *Rat, or *Float take an operand as
    23  // the receiver (usually named x in that case).
    24  //
    25  package big
    26  
    27  // This file contains operations on unsigned multi-precision integers.
    28  // These are the building blocks for the operations on signed integers
    29  // and rationals.
    30  
    31  import "math/rand"
    32  
    33  // An unsigned integer x of the form
    34  //
    35  //   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
    36  //
    37  // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
    38  // with the digits x[i] as the slice elements.
    39  //
    40  // A number is normalized if the slice contains no leading 0 digits.
    41  // During arithmetic operations, denormalized values may occur but are
    42  // always normalized before returning the final result. The normalized
    43  // representation of 0 is the empty or nil slice (length = 0).
    44  //
    45  type nat []Word
    46  
    47  var (
    48  	natOne = nat{1}
    49  	natTwo = nat{2}
    50  	natTen = nat{10}
    51  )
    52  
    53  func (z nat) clear() {
    54  	for i := range z {
    55  		z[i] = 0
    56  	}
    57  }
    58  
    59  func (z nat) norm() nat {
    60  	i := len(z)
    61  	for i > 0 && z[i-1] == 0 {
    62  		i--
    63  	}
    64  	return z[0:i]
    65  }
    66  
    67  func (z nat) make(n int) nat {
    68  	if n <= cap(z) {
    69  		return z[:n] // reuse z
    70  	}
    71  	// Choosing a good value for e has significant performance impact
    72  	// because it increases the chance that a value can be reused.
    73  	const e = 4 // extra capacity
    74  	return make(nat, n, n+e)
    75  }
    76  
    77  func (z nat) setWord(x Word) nat {
    78  	if x == 0 {
    79  		return z[:0]
    80  	}
    81  	z = z.make(1)
    82  	z[0] = x
    83  	return z
    84  }
    85  
    86  func (z nat) setUint64(x uint64) nat {
    87  	// single-digit values
    88  	if w := Word(x); uint64(w) == x {
    89  		return z.setWord(w)
    90  	}
    91  
    92  	// compute number of words n required to represent x
    93  	n := 0
    94  	for t := x; t > 0; t >>= _W {
    95  		n++
    96  	}
    97  
    98  	// split x into n words
    99  	z = z.make(n)
   100  	for i := range z {
   101  		z[i] = Word(x & _M)
   102  		x >>= _W
   103  	}
   104  
   105  	return z
   106  }
   107  
   108  func (z nat) set(x nat) nat {
   109  	z = z.make(len(x))
   110  	copy(z, x)
   111  	return z
   112  }
   113  
   114  func (z nat) add(x, y nat) nat {
   115  	m := len(x)
   116  	n := len(y)
   117  
   118  	switch {
   119  	case m < n:
   120  		return z.add(y, x)
   121  	case m == 0:
   122  		// n == 0 because m >= n; result is 0
   123  		return z[:0]
   124  	case n == 0:
   125  		// result is x
   126  		return z.set(x)
   127  	}
   128  	// m > 0
   129  
   130  	z = z.make(m + 1)
   131  	c := addVV(z[0:n], x, y)
   132  	if m > n {
   133  		c = addVW(z[n:m], x[n:], c)
   134  	}
   135  	z[m] = c
   136  
   137  	return z.norm()
   138  }
   139  
   140  func (z nat) sub(x, y nat) nat {
   141  	m := len(x)
   142  	n := len(y)
   143  
   144  	switch {
   145  	case m < n:
   146  		panic("underflow")
   147  	case m == 0:
   148  		// n == 0 because m >= n; result is 0
   149  		return z[:0]
   150  	case n == 0:
   151  		// result is x
   152  		return z.set(x)
   153  	}
   154  	// m > 0
   155  
   156  	z = z.make(m)
   157  	c := subVV(z[0:n], x, y)
   158  	if m > n {
   159  		c = subVW(z[n:], x[n:], c)
   160  	}
   161  	if c != 0 {
   162  		panic("underflow")
   163  	}
   164  
   165  	return z.norm()
   166  }
   167  
   168  func (x nat) cmp(y nat) (r int) {
   169  	m := len(x)
   170  	n := len(y)
   171  	if m != n || m == 0 {
   172  		switch {
   173  		case m < n:
   174  			r = -1
   175  		case m > n:
   176  			r = 1
   177  		}
   178  		return
   179  	}
   180  
   181  	i := m - 1
   182  	for i > 0 && x[i] == y[i] {
   183  		i--
   184  	}
   185  
   186  	switch {
   187  	case x[i] < y[i]:
   188  		r = -1
   189  	case x[i] > y[i]:
   190  		r = 1
   191  	}
   192  	return
   193  }
   194  
   195  func (z nat) mulAddWW(x nat, y, r Word) nat {
   196  	m := len(x)
   197  	if m == 0 || y == 0 {
   198  		return z.setWord(r) // result is r
   199  	}
   200  	// m > 0
   201  
   202  	z = z.make(m + 1)
   203  	z[m] = mulAddVWW(z[0:m], x, y, r)
   204  
   205  	return z.norm()
   206  }
   207  
   208  // basicMul multiplies x and y and leaves the result in z.
   209  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   210  func basicMul(z, x, y nat) {
   211  	z[0 : len(x)+len(y)].clear() // initialize z
   212  	for i, d := range y {
   213  		if d != 0 {
   214  			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   215  		}
   216  	}
   217  }
   218  
   219  // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
   220  // Factored out for readability - do not use outside karatsuba.
   221  func karatsubaAdd(z, x nat, n int) {
   222  	if c := addVV(z[0:n], z, x); c != 0 {
   223  		addVW(z[n:n+n>>1], z[n:], c)
   224  	}
   225  }
   226  
   227  // Like karatsubaAdd, but does subtract.
   228  func karatsubaSub(z, x nat, n int) {
   229  	if c := subVV(z[0:n], z, x); c != 0 {
   230  		subVW(z[n:n+n>>1], z[n:], c)
   231  	}
   232  }
   233  
   234  // Operands that are shorter than karatsubaThreshold are multiplied using
   235  // "grade school" multiplication; for longer operands the Karatsuba algorithm
   236  // is used.
   237  var karatsubaThreshold int = 40 // computed by calibrate.go
   238  
   239  // karatsuba multiplies x and y and leaves the result in z.
   240  // Both x and y must have the same length n and n must be a
   241  // power of 2. The result vector z must have len(z) >= 6*n.
   242  // The (non-normalized) result is placed in z[0 : 2*n].
   243  func karatsuba(z, x, y nat) {
   244  	n := len(y)
   245  
   246  	// Switch to basic multiplication if numbers are odd or small.
   247  	// (n is always even if karatsubaThreshold is even, but be
   248  	// conservative)
   249  	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
   250  		basicMul(z, x, y)
   251  		return
   252  	}
   253  	// n&1 == 0 && n >= karatsubaThreshold && n >= 2
   254  
   255  	// Karatsuba multiplication is based on the observation that
   256  	// for two numbers x and y with:
   257  	//
   258  	//   x = x1*b + x0
   259  	//   y = y1*b + y0
   260  	//
   261  	// the product x*y can be obtained with 3 products z2, z1, z0
   262  	// instead of 4:
   263  	//
   264  	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
   265  	//       =    z2*b*b +              z1*b +    z0
   266  	//
   267  	// with:
   268  	//
   269  	//   xd = x1 - x0
   270  	//   yd = y0 - y1
   271  	//
   272  	//   z1 =      xd*yd                    + z2 + z0
   273  	//      = (x1-x0)*(y0 - y1)             + z2 + z0
   274  	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
   275  	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
   276  	//      = x1*y0                 + x0*y1
   277  
   278  	// split x, y into "digits"
   279  	n2 := n >> 1              // n2 >= 1
   280  	x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
   281  	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
   282  
   283  	// z is used for the result and temporary storage:
   284  	//
   285  	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
   286  	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
   287  	//
   288  	// For each recursive call of karatsuba, an unused slice of
   289  	// z is passed in that has (at least) half the length of the
   290  	// caller's z.
   291  
   292  	// compute z0 and z2 with the result "in place" in z
   293  	karatsuba(z, x0, y0)     // z0 = x0*y0
   294  	karatsuba(z[n:], x1, y1) // z2 = x1*y1
   295  
   296  	// compute xd (or the negative value if underflow occurs)
   297  	s := 1 // sign of product xd*yd
   298  	xd := z[2*n : 2*n+n2]
   299  	if subVV(xd, x1, x0) != 0 { // x1-x0
   300  		s = -s
   301  		subVV(xd, x0, x1) // x0-x1
   302  	}
   303  
   304  	// compute yd (or the negative value if underflow occurs)
   305  	yd := z[2*n+n2 : 3*n]
   306  	if subVV(yd, y0, y1) != 0 { // y0-y1
   307  		s = -s
   308  		subVV(yd, y1, y0) // y1-y0
   309  	}
   310  
   311  	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
   312  	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
   313  	p := z[n*3:]
   314  	karatsuba(p, xd, yd)
   315  
   316  	// save original z2:z0
   317  	// (ok to use upper half of z since we're done recursing)
   318  	r := z[n*4:]
   319  	copy(r, z[:n*2])
   320  
   321  	// add up all partial products
   322  	//
   323  	//   2*n     n     0
   324  	// z = [ z2  | z0  ]
   325  	//   +    [ z0  ]
   326  	//   +    [ z2  ]
   327  	//   +    [  p  ]
   328  	//
   329  	karatsubaAdd(z[n2:], r, n)
   330  	karatsubaAdd(z[n2:], r[n:], n)
   331  	if s > 0 {
   332  		karatsubaAdd(z[n2:], p, n)
   333  	} else {
   334  		karatsubaSub(z[n2:], p, n)
   335  	}
   336  }
   337  
   338  // alias reports whether x and y share the same base array.
   339  func alias(x, y nat) bool {
   340  	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
   341  }
   342  
   343  // addAt implements z += x<<(_W*i); z must be long enough.
   344  // (we don't use nat.add because we need z to stay the same
   345  // slice, and we don't need to normalize z after each addition)
   346  func addAt(z, x nat, i int) {
   347  	if n := len(x); n > 0 {
   348  		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
   349  			j := i + n
   350  			if j < len(z) {
   351  				addVW(z[j:], z[j:], c)
   352  			}
   353  		}
   354  	}
   355  }
   356  
   357  func max(x, y int) int {
   358  	if x > y {
   359  		return x
   360  	}
   361  	return y
   362  }
   363  
   364  // karatsubaLen computes an approximation to the maximum k <= n such that
   365  // k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
   366  // result is the largest number that can be divided repeatedly by 2 before
   367  // becoming about the value of karatsubaThreshold.
   368  func karatsubaLen(n int) int {
   369  	i := uint(0)
   370  	for n > karatsubaThreshold {
   371  		n >>= 1
   372  		i++
   373  	}
   374  	return n << i
   375  }
   376  
   377  func (z nat) mul(x, y nat) nat {
   378  	m := len(x)
   379  	n := len(y)
   380  
   381  	switch {
   382  	case m < n:
   383  		return z.mul(y, x)
   384  	case m == 0 || n == 0:
   385  		return z[:0]
   386  	case n == 1:
   387  		return z.mulAddWW(x, y[0], 0)
   388  	}
   389  	// m >= n > 1
   390  
   391  	// determine if z can be reused
   392  	if alias(z, x) || alias(z, y) {
   393  		z = nil // z is an alias for x or y - cannot reuse
   394  	}
   395  
   396  	// use basic multiplication if the numbers are small
   397  	if n < karatsubaThreshold {
   398  		z = z.make(m + n)
   399  		basicMul(z, x, y)
   400  		return z.norm()
   401  	}
   402  	// m >= n && n >= karatsubaThreshold && n >= 2
   403  
   404  	// determine Karatsuba length k such that
   405  	//
   406  	//   x = xh*b + x0  (0 <= x0 < b)
   407  	//   y = yh*b + y0  (0 <= y0 < b)
   408  	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
   409  	//
   410  	k := karatsubaLen(n)
   411  	// k <= n
   412  
   413  	// multiply x0 and y0 via Karatsuba
   414  	x0 := x[0:k]              // x0 is not normalized
   415  	y0 := y[0:k]              // y0 is not normalized
   416  	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
   417  	karatsuba(z, x0, y0)
   418  	z = z[0 : m+n]  // z has final length but may be incomplete
   419  	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
   420  
   421  	// If xh != 0 or yh != 0, add the missing terms to z. For
   422  	//
   423  	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
   424  	//   yh =                         y1*b (0 <= y1 < b)
   425  	//
   426  	// the missing terms are
   427  	//
   428  	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
   429  	//
   430  	// since all the yi for i > 1 are 0 by choice of k: If any of them
   431  	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
   432  	// be a larger valid threshold contradicting the assumption about k.
   433  	//
   434  	if k < n || m != n {
   435  		var t nat
   436  
   437  		// add x0*y1*b
   438  		x0 := x0.norm()
   439  		y1 := y[k:]       // y1 is normalized because y is
   440  		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
   441  		addAt(z, t, k)
   442  
   443  		// add xi*y0<<i, xi*y1*b<<(i+k)
   444  		y0 := y0.norm()
   445  		for i := k; i < len(x); i += k {
   446  			xi := x[i:]
   447  			if len(xi) > k {
   448  				xi = xi[:k]
   449  			}
   450  			xi = xi.norm()
   451  			t = t.mul(xi, y0)
   452  			addAt(z, t, i)
   453  			t = t.mul(xi, y1)
   454  			addAt(z, t, i+k)
   455  		}
   456  	}
   457  
   458  	return z.norm()
   459  }
   460  
   461  // mulRange computes the product of all the unsigned integers in the
   462  // range [a, b] inclusively. If a > b (empty range), the result is 1.
   463  func (z nat) mulRange(a, b uint64) nat {
   464  	switch {
   465  	case a == 0:
   466  		// cut long ranges short (optimization)
   467  		return z.setUint64(0)
   468  	case a > b:
   469  		return z.setUint64(1)
   470  	case a == b:
   471  		return z.setUint64(a)
   472  	case a+1 == b:
   473  		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
   474  	}
   475  	m := (a + b) / 2
   476  	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
   477  }
   478  
   479  // q = (x-r)/y, with 0 <= r < y
   480  func (z nat) divW(x nat, y Word) (q nat, r Word) {
   481  	m := len(x)
   482  	switch {
   483  	case y == 0:
   484  		panic("division by zero")
   485  	case y == 1:
   486  		q = z.set(x) // result is x
   487  		return
   488  	case m == 0:
   489  		q = z[:0] // result is 0
   490  		return
   491  	}
   492  	// m > 0
   493  	z = z.make(m)
   494  	r = divWVW(z, 0, x, y)
   495  	q = z.norm()
   496  	return
   497  }
   498  
   499  func (z nat) div(z2, u, v nat) (q, r nat) {
   500  	if len(v) == 0 {
   501  		panic("division by zero")
   502  	}
   503  
   504  	if u.cmp(v) < 0 {
   505  		q = z[:0]
   506  		r = z2.set(u)
   507  		return
   508  	}
   509  
   510  	if len(v) == 1 {
   511  		var r2 Word
   512  		q, r2 = z.divW(u, v[0])
   513  		r = z2.setWord(r2)
   514  		return
   515  	}
   516  
   517  	q, r = z.divLarge(z2, u, v)
   518  	return
   519  }
   520  
   521  // q = (uIn-r)/v, with 0 <= r < y
   522  // Uses z as storage for q, and u as storage for r if possible.
   523  // See Knuth, Volume 2, section 4.3.1, Algorithm D.
   524  // Preconditions:
   525  //    len(v) >= 2
   526  //    len(uIn) >= len(v)
   527  func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
   528  	n := len(v)
   529  	m := len(uIn) - n
   530  
   531  	// determine if z can be reused
   532  	// TODO(gri) should find a better solution - this if statement
   533  	//           is very costly (see e.g. time pidigits -s -n 10000)
   534  	if alias(z, uIn) || alias(z, v) {
   535  		z = nil // z is an alias for uIn or v - cannot reuse
   536  	}
   537  	q = z.make(m + 1)
   538  
   539  	qhatv := make(nat, n+1)
   540  	if alias(u, uIn) || alias(u, v) {
   541  		u = nil // u is an alias for uIn or v - cannot reuse
   542  	}
   543  	u = u.make(len(uIn) + 1)
   544  	u.clear() // TODO(gri) no need to clear if we allocated a new u
   545  
   546  	// D1.
   547  	shift := leadingZeros(v[n-1])
   548  	if shift > 0 {
   549  		// do not modify v, it may be used by another goroutine simultaneously
   550  		v1 := make(nat, n)
   551  		shlVU(v1, v, shift)
   552  		v = v1
   553  	}
   554  	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
   555  
   556  	// D2.
   557  	for j := m; j >= 0; j-- {
   558  		// D3.
   559  		qhat := Word(_M)
   560  		if u[j+n] != v[n-1] {
   561  			var rhat Word
   562  			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])
   563  
   564  			// x1 | x2 = q̂v_{n-2}
   565  			x1, x2 := mulWW(qhat, v[n-2])
   566  			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
   567  			for greaterThan(x1, x2, rhat, u[j+n-2]) {
   568  				qhat--
   569  				prevRhat := rhat
   570  				rhat += v[n-1]
   571  				// v[n-1] >= 0, so this tests for overflow.
   572  				if rhat < prevRhat {
   573  					break
   574  				}
   575  				x1, x2 = mulWW(qhat, v[n-2])
   576  			}
   577  		}
   578  
   579  		// D4.
   580  		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
   581  
   582  		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
   583  		if c != 0 {
   584  			c := addVV(u[j:j+n], u[j:], v)
   585  			u[j+n] += c
   586  			qhat--
   587  		}
   588  
   589  		q[j] = qhat
   590  	}
   591  
   592  	q = q.norm()
   593  	shrVU(u, u, shift)
   594  	r = u.norm()
   595  
   596  	return q, r
   597  }
   598  
   599  // Length of x in bits. x must be normalized.
   600  func (x nat) bitLen() int {
   601  	if i := len(x) - 1; i >= 0 {
   602  		return i*_W + bitLen(x[i])
   603  	}
   604  	return 0
   605  }
   606  
   607  const deBruijn32 = 0x077CB531
   608  
   609  var deBruijn32Lookup = []byte{
   610  	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
   611  	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
   612  }
   613  
   614  const deBruijn64 = 0x03f79d71b4ca8b09
   615  
   616  var deBruijn64Lookup = []byte{
   617  	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
   618  	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
   619  	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
   620  	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
   621  }
   622  
   623  // trailingZeroBits returns the number of consecutive least significant zero
   624  // bits of x.
   625  func trailingZeroBits(x Word) uint {
   626  	// x & -x leaves only the right-most bit set in the word. Let k be the
   627  	// index of that bit. Since only a single bit is set, the value is two
   628  	// to the power of k. Multiplying by a power of two is equivalent to
   629  	// left shifting, in this case by k bits.  The de Bruijn constant is
   630  	// such that all six bit, consecutive substrings are distinct.
   631  	// Therefore, if we have a left shifted version of this constant we can
   632  	// find by how many bits it was shifted by looking at which six bit
   633  	// substring ended up at the top of the word.
   634  	// (Knuth, volume 4, section 7.3.1)
   635  	switch _W {
   636  	case 32:
   637  		return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
   638  	case 64:
   639  		return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
   640  	default:
   641  		panic("unknown word size")
   642  	}
   643  }
   644  
   645  // trailingZeroBits returns the number of consecutive least significant zero
   646  // bits of x.
   647  func (x nat) trailingZeroBits() uint {
   648  	if len(x) == 0 {
   649  		return 0
   650  	}
   651  	var i uint
   652  	for x[i] == 0 {
   653  		i++
   654  	}
   655  	// x[i] != 0
   656  	return i*_W + trailingZeroBits(x[i])
   657  }
   658  
   659  // z = x << s
   660  func (z nat) shl(x nat, s uint) nat {
   661  	m := len(x)
   662  	if m == 0 {
   663  		return z[:0]
   664  	}
   665  	// m > 0
   666  
   667  	n := m + int(s/_W)
   668  	z = z.make(n + 1)
   669  	z[n] = shlVU(z[n-m:n], x, s%_W)
   670  	z[0 : n-m].clear()
   671  
   672  	return z.norm()
   673  }
   674  
   675  // z = x >> s
   676  func (z nat) shr(x nat, s uint) nat {
   677  	m := len(x)
   678  	n := m - int(s/_W)
   679  	if n <= 0 {
   680  		return z[:0]
   681  	}
   682  	// n > 0
   683  
   684  	z = z.make(n)
   685  	shrVU(z, x[m-n:], s%_W)
   686  
   687  	return z.norm()
   688  }
   689  
   690  func (z nat) setBit(x nat, i uint, b uint) nat {
   691  	j := int(i / _W)
   692  	m := Word(1) << (i % _W)
   693  	n := len(x)
   694  	switch b {
   695  	case 0:
   696  		z = z.make(n)
   697  		copy(z, x)
   698  		if j >= n {
   699  			// no need to grow
   700  			return z
   701  		}
   702  		z[j] &^= m
   703  		return z.norm()
   704  	case 1:
   705  		if j >= n {
   706  			z = z.make(j + 1)
   707  			z[n:].clear()
   708  		} else {
   709  			z = z.make(n)
   710  		}
   711  		copy(z, x)
   712  		z[j] |= m
   713  		// no need to normalize
   714  		return z
   715  	}
   716  	panic("set bit is not 0 or 1")
   717  }
   718  
   719  // bit returns the value of the i'th bit, with lsb == bit 0.
   720  func (x nat) bit(i uint) uint {
   721  	j := i / _W
   722  	if j >= uint(len(x)) {
   723  		return 0
   724  	}
   725  	// 0 <= j < len(x)
   726  	return uint(x[j] >> (i % _W) & 1)
   727  }
   728  
   729  // sticky returns 1 if there's a 1 bit within the
   730  // i least significant bits, otherwise it returns 0.
   731  func (x nat) sticky(i uint) uint {
   732  	j := i / _W
   733  	if j >= uint(len(x)) {
   734  		if len(x) == 0 {
   735  			return 0
   736  		}
   737  		return 1
   738  	}
   739  	// 0 <= j < len(x)
   740  	for _, x := range x[:j] {
   741  		if x != 0 {
   742  			return 1
   743  		}
   744  	}
   745  	if x[j]<<(_W-i%_W) != 0 {
   746  		return 1
   747  	}
   748  	return 0
   749  }
   750  
   751  func (z nat) and(x, y nat) nat {
   752  	m := len(x)
   753  	n := len(y)
   754  	if m > n {
   755  		m = n
   756  	}
   757  	// m <= n
   758  
   759  	z = z.make(m)
   760  	for i := 0; i < m; i++ {
   761  		z[i] = x[i] & y[i]
   762  	}
   763  
   764  	return z.norm()
   765  }
   766  
   767  func (z nat) andNot(x, y nat) nat {
   768  	m := len(x)
   769  	n := len(y)
   770  	if n > m {
   771  		n = m
   772  	}
   773  	// m >= n
   774  
   775  	z = z.make(m)
   776  	for i := 0; i < n; i++ {
   777  		z[i] = x[i] &^ y[i]
   778  	}
   779  	copy(z[n:m], x[n:m])
   780  
   781  	return z.norm()
   782  }
   783  
   784  func (z nat) or(x, y nat) nat {
   785  	m := len(x)
   786  	n := len(y)
   787  	s := x
   788  	if m < n {
   789  		n, m = m, n
   790  		s = y
   791  	}
   792  	// m >= n
   793  
   794  	z = z.make(m)
   795  	for i := 0; i < n; i++ {
   796  		z[i] = x[i] | y[i]
   797  	}
   798  	copy(z[n:m], s[n:m])
   799  
   800  	return z.norm()
   801  }
   802  
   803  func (z nat) xor(x, y nat) nat {
   804  	m := len(x)
   805  	n := len(y)
   806  	s := x
   807  	if m < n {
   808  		n, m = m, n
   809  		s = y
   810  	}
   811  	// m >= n
   812  
   813  	z = z.make(m)
   814  	for i := 0; i < n; i++ {
   815  		z[i] = x[i] ^ y[i]
   816  	}
   817  	copy(z[n:m], s[n:m])
   818  
   819  	return z.norm()
   820  }
   821  
   822  // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
   823  func greaterThan(x1, x2, y1, y2 Word) bool {
   824  	return x1 > y1 || x1 == y1 && x2 > y2
   825  }
   826  
   827  // modW returns x % d.
   828  func (x nat) modW(d Word) (r Word) {
   829  	// TODO(agl): we don't actually need to store the q value.
   830  	var q nat
   831  	q = q.make(len(x))
   832  	return divWVW(q, 0, x, d)
   833  }
   834  
   835  // random creates a random integer in [0..limit), using the space in z if
   836  // possible. n is the bit length of limit.
   837  func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
   838  	if alias(z, limit) {
   839  		z = nil // z is an alias for limit - cannot reuse
   840  	}
   841  	z = z.make(len(limit))
   842  
   843  	bitLengthOfMSW := uint(n % _W)
   844  	if bitLengthOfMSW == 0 {
   845  		bitLengthOfMSW = _W
   846  	}
   847  	mask := Word((1 << bitLengthOfMSW) - 1)
   848  
   849  	for {
   850  		switch _W {
   851  		case 32:
   852  			for i := range z {
   853  				z[i] = Word(rand.Uint32())
   854  			}
   855  		case 64:
   856  			for i := range z {
   857  				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
   858  			}
   859  		default:
   860  			panic("unknown word size")
   861  		}
   862  		z[len(limit)-1] &= mask
   863  		if z.cmp(limit) < 0 {
   864  			break
   865  		}
   866  	}
   867  
   868  	return z.norm()
   869  }
   870  
   871  // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
   872  // otherwise it sets z to x**y. The result is the value of z.
   873  func (z nat) expNN(x, y, m nat) nat {
   874  	if alias(z, x) || alias(z, y) {
   875  		// We cannot allow in-place modification of x or y.
   876  		z = nil
   877  	}
   878  
   879  	// x**y mod 1 == 0
   880  	if len(m) == 1 && m[0] == 1 {
   881  		return z.setWord(0)
   882  	}
   883  	// m == 0 || m > 1
   884  
   885  	// x**0 == 1
   886  	if len(y) == 0 {
   887  		return z.setWord(1)
   888  	}
   889  	// y > 0
   890  
   891  	// x**1 mod m == x mod m
   892  	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
   893  		_, z = z.div(z, x, m)
   894  		return z
   895  	}
   896  	// y > 1
   897  
   898  	if len(m) != 0 {
   899  		// We likely end up being as long as the modulus.
   900  		z = z.make(len(m))
   901  	}
   902  	z = z.set(x)
   903  
   904  	// If the base is non-trivial and the exponent is large, we use
   905  	// 4-bit, windowed exponentiation. This involves precomputing 14 values
   906  	// (x^2...x^15) but then reduces the number of multiply-reduces by a
   907  	// third. Even for a 32-bit exponent, this reduces the number of
   908  	// operations.
   909  	if len(x) > 1 && len(y) > 1 && len(m) > 0 {
   910  		return z.expNNWindowed(x, y, m)
   911  	}
   912  
   913  	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
   914  	shift := leadingZeros(v) + 1
   915  	v <<= shift
   916  	var q nat
   917  
   918  	const mask = 1 << (_W - 1)
   919  
   920  	// We walk through the bits of the exponent one by one. Each time we
   921  	// see a bit, we square, thus doubling the power. If the bit is a one,
   922  	// we also multiply by x, thus adding one to the power.
   923  
   924  	w := _W - int(shift)
   925  	// zz and r are used to avoid allocating in mul and div as
   926  	// otherwise the arguments would alias.
   927  	var zz, r nat
   928  	for j := 0; j < w; j++ {
   929  		zz = zz.mul(z, z)
   930  		zz, z = z, zz
   931  
   932  		if v&mask != 0 {
   933  			zz = zz.mul(z, x)
   934  			zz, z = z, zz
   935  		}
   936  
   937  		if len(m) != 0 {
   938  			zz, r = zz.div(r, z, m)
   939  			zz, r, q, z = q, z, zz, r
   940  		}
   941  
   942  		v <<= 1
   943  	}
   944  
   945  	for i := len(y) - 2; i >= 0; i-- {
   946  		v = y[i]
   947  
   948  		for j := 0; j < _W; j++ {
   949  			zz = zz.mul(z, z)
   950  			zz, z = z, zz
   951  
   952  			if v&mask != 0 {
   953  				zz = zz.mul(z, x)
   954  				zz, z = z, zz
   955  			}
   956  
   957  			if len(m) != 0 {
   958  				zz, r = zz.div(r, z, m)
   959  				zz, r, q, z = q, z, zz, r
   960  			}
   961  
   962  			v <<= 1
   963  		}
   964  	}
   965  
   966  	return z.norm()
   967  }
   968  
   969  // expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
   970  func (z nat) expNNWindowed(x, y, m nat) nat {
   971  	// zz and r are used to avoid allocating in mul and div as otherwise
   972  	// the arguments would alias.
   973  	var zz, r nat
   974  
   975  	const n = 4
   976  	// powers[i] contains x^i.
   977  	var powers [1 << n]nat
   978  	powers[0] = natOne
   979  	powers[1] = x
   980  	for i := 2; i < 1<<n; i += 2 {
   981  		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
   982  		*p = p.mul(*p2, *p2)
   983  		zz, r = zz.div(r, *p, m)
   984  		*p, r = r, *p
   985  		*p1 = p1.mul(*p, x)
   986  		zz, r = zz.div(r, *p1, m)
   987  		*p1, r = r, *p1
   988  	}
   989  
   990  	z = z.setWord(1)
   991  
   992  	for i := len(y) - 1; i >= 0; i-- {
   993  		yi := y[i]
   994  		for j := 0; j < _W; j += n {
   995  			if i != len(y)-1 || j != 0 {
   996  				// Unrolled loop for significant performance
   997  				// gain.  Use go test -bench=".*" in crypto/rsa
   998  				// to check performance before making changes.
   999  				zz = zz.mul(z, z)
  1000  				zz, z = z, zz
  1001  				zz, r = zz.div(r, z, m)
  1002  				z, r = r, z
  1003  
  1004  				zz = zz.mul(z, z)
  1005  				zz, z = z, zz
  1006  				zz, r = zz.div(r, z, m)
  1007  				z, r = r, z
  1008  
  1009  				zz = zz.mul(z, z)
  1010  				zz, z = z, zz
  1011  				zz, r = zz.div(r, z, m)
  1012  				z, r = r, z
  1013  
  1014  				zz = zz.mul(z, z)
  1015  				zz, z = z, zz
  1016  				zz, r = zz.div(r, z, m)
  1017  				z, r = r, z
  1018  			}
  1019  
  1020  			zz = zz.mul(z, powers[yi>>(_W-n)])
  1021  			zz, z = z, zz
  1022  			zz, r = zz.div(r, z, m)
  1023  			z, r = r, z
  1024  
  1025  			yi <<= n
  1026  		}
  1027  	}
  1028  
  1029  	return z.norm()
  1030  }
  1031  
  1032  // probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
  1033  // If it returns true, n is prime with probability 1 - 1/4^reps.
  1034  // If it returns false, n is not prime.
  1035  func (n nat) probablyPrime(reps int) bool {
  1036  	if len(n) == 0 {
  1037  		return false
  1038  	}
  1039  
  1040  	if len(n) == 1 {
  1041  		if n[0] < 2 {
  1042  			return false
  1043  		}
  1044  
  1045  		if n[0]%2 == 0 {
  1046  			return n[0] == 2
  1047  		}
  1048  
  1049  		// We have to exclude these cases because we reject all
  1050  		// multiples of these numbers below.
  1051  		switch n[0] {
  1052  		case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
  1053  			return true
  1054  		}
  1055  	}
  1056  
  1057  	if n[0]&1 == 0 {
  1058  		return false // n is even
  1059  	}
  1060  
  1061  	const primesProduct32 = 0xC0CFD797         // Π {p ∈ primes, 2 < p <= 29}
  1062  	const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53}
  1063  
  1064  	var r Word
  1065  	switch _W {
  1066  	case 32:
  1067  		r = n.modW(primesProduct32)
  1068  	case 64:
  1069  		r = n.modW(primesProduct64 & _M)
  1070  	default:
  1071  		panic("Unknown word size")
  1072  	}
  1073  
  1074  	if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
  1075  		r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
  1076  		return false
  1077  	}
  1078  
  1079  	if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
  1080  		r%43 == 0 || r%47 == 0 || r%53 == 0) {
  1081  		return false
  1082  	}
  1083  
  1084  	nm1 := nat(nil).sub(n, natOne)
  1085  	// determine q, k such that nm1 = q << k
  1086  	k := nm1.trailingZeroBits()
  1087  	q := nat(nil).shr(nm1, k)
  1088  
  1089  	nm3 := nat(nil).sub(nm1, natTwo)
  1090  	rand := rand.New(rand.NewSource(int64(n[0])))
  1091  
  1092  	var x, y, quotient nat
  1093  	nm3Len := nm3.bitLen()
  1094  
  1095  NextRandom:
  1096  	for i := 0; i < reps; i++ {
  1097  		x = x.random(rand, nm3, nm3Len)
  1098  		x = x.add(x, natTwo)
  1099  		y = y.expNN(x, q, n)
  1100  		if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
  1101  			continue
  1102  		}
  1103  		for j := uint(1); j < k; j++ {
  1104  			y = y.mul(y, y)
  1105  			quotient, y = quotient.div(y, y, n)
  1106  			if y.cmp(nm1) == 0 {
  1107  				continue NextRandom
  1108  			}
  1109  			if y.cmp(natOne) == 0 {
  1110  				return false
  1111  			}
  1112  		}
  1113  		return false
  1114  	}
  1115  
  1116  	return true
  1117  }
  1118  
  1119  // bytes writes the value of z into buf using big-endian encoding.
  1120  // len(buf) must be >= len(z)*_S. The value of z is encoded in the
  1121  // slice buf[i:]. The number i of unused bytes at the beginning of
  1122  // buf is returned as result.
  1123  func (z nat) bytes(buf []byte) (i int) {
  1124  	i = len(buf)
  1125  	for _, d := range z {
  1126  		for j := 0; j < _S; j++ {
  1127  			i--
  1128  			buf[i] = byte(d)
  1129  			d >>= 8
  1130  		}
  1131  	}
  1132  
  1133  	for i < len(buf) && buf[i] == 0 {
  1134  		i++
  1135  	}
  1136  
  1137  	return
  1138  }
  1139  
  1140  // setBytes interprets buf as the bytes of a big-endian unsigned
  1141  // integer, sets z to that value, and returns z.
  1142  func (z nat) setBytes(buf []byte) nat {
  1143  	z = z.make((len(buf) + _S - 1) / _S)
  1144  
  1145  	k := 0
  1146  	s := uint(0)
  1147  	var d Word
  1148  	for i := len(buf); i > 0; i-- {
  1149  		d |= Word(buf[i-1]) << s
  1150  		if s += 8; s == _S*8 {
  1151  			z[k] = d
  1152  			k++
  1153  			s = 0
  1154  			d = 0
  1155  		}
  1156  	}
  1157  	if k < len(z) {
  1158  		z[k] = d
  1159  	}
  1160  
  1161  	return z.norm()
  1162  }