github.com/rsc/go@v0.0.0-20150416155037-e040fd465409/src/math/big/nat.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package big implements multi-precision arithmetic (big numbers). 6 // The following numeric types are supported: 7 // 8 // Int signed integers 9 // Rat rational numbers 10 // Float floating-point numbers 11 // 12 // Methods are typically of the form: 13 // 14 // func (z *T) Unary(x *T) *T // z = op x 15 // func (z *T) Binary(x, y *T) *T // z = x op y 16 // func (x *T) M() T1 // v = x.M() 17 // 18 // with T one of Int, Rat, or Float. For unary and binary operations, the 19 // result is the receiver (usually named z in that case); if it is one of 20 // the operands x or y it may be overwritten (and its memory reused). 21 // To enable chaining of operations, the result is also returned. Methods 22 // returning a result other than *Int, *Rat, or *Float take an operand as 23 // the receiver (usually named x in that case). 24 // 25 package big 26 27 // This file contains operations on unsigned multi-precision integers. 28 // These are the building blocks for the operations on signed integers 29 // and rationals. 30 31 import "math/rand" 32 33 // An unsigned integer x of the form 34 // 35 // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 36 // 37 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 38 // with the digits x[i] as the slice elements. 39 // 40 // A number is normalized if the slice contains no leading 0 digits. 41 // During arithmetic operations, denormalized values may occur but are 42 // always normalized before returning the final result. The normalized 43 // representation of 0 is the empty or nil slice (length = 0). 
44 // 45 type nat []Word 46 47 var ( 48 natOne = nat{1} 49 natTwo = nat{2} 50 natTen = nat{10} 51 ) 52 53 func (z nat) clear() { 54 for i := range z { 55 z[i] = 0 56 } 57 } 58 59 func (z nat) norm() nat { 60 i := len(z) 61 for i > 0 && z[i-1] == 0 { 62 i-- 63 } 64 return z[0:i] 65 } 66 67 func (z nat) make(n int) nat { 68 if n <= cap(z) { 69 return z[:n] // reuse z 70 } 71 // Choosing a good value for e has significant performance impact 72 // because it increases the chance that a value can be reused. 73 const e = 4 // extra capacity 74 return make(nat, n, n+e) 75 } 76 77 func (z nat) setWord(x Word) nat { 78 if x == 0 { 79 return z[:0] 80 } 81 z = z.make(1) 82 z[0] = x 83 return z 84 } 85 86 func (z nat) setUint64(x uint64) nat { 87 // single-digit values 88 if w := Word(x); uint64(w) == x { 89 return z.setWord(w) 90 } 91 92 // compute number of words n required to represent x 93 n := 0 94 for t := x; t > 0; t >>= _W { 95 n++ 96 } 97 98 // split x into n words 99 z = z.make(n) 100 for i := range z { 101 z[i] = Word(x & _M) 102 x >>= _W 103 } 104 105 return z 106 } 107 108 func (z nat) set(x nat) nat { 109 z = z.make(len(x)) 110 copy(z, x) 111 return z 112 } 113 114 func (z nat) add(x, y nat) nat { 115 m := len(x) 116 n := len(y) 117 118 switch { 119 case m < n: 120 return z.add(y, x) 121 case m == 0: 122 // n == 0 because m >= n; result is 0 123 return z[:0] 124 case n == 0: 125 // result is x 126 return z.set(x) 127 } 128 // m > 0 129 130 z = z.make(m + 1) 131 c := addVV(z[0:n], x, y) 132 if m > n { 133 c = addVW(z[n:m], x[n:], c) 134 } 135 z[m] = c 136 137 return z.norm() 138 } 139 140 func (z nat) sub(x, y nat) nat { 141 m := len(x) 142 n := len(y) 143 144 switch { 145 case m < n: 146 panic("underflow") 147 case m == 0: 148 // n == 0 because m >= n; result is 0 149 return z[:0] 150 case n == 0: 151 // result is x 152 return z.set(x) 153 } 154 // m > 0 155 156 z = z.make(m) 157 c := subVV(z[0:n], x, y) 158 if m > n { 159 c = subVW(z[n:], x[n:], c) 160 } 161 if 
c != 0 { 162 panic("underflow") 163 } 164 165 return z.norm() 166 } 167 168 func (x nat) cmp(y nat) (r int) { 169 m := len(x) 170 n := len(y) 171 if m != n || m == 0 { 172 switch { 173 case m < n: 174 r = -1 175 case m > n: 176 r = 1 177 } 178 return 179 } 180 181 i := m - 1 182 for i > 0 && x[i] == y[i] { 183 i-- 184 } 185 186 switch { 187 case x[i] < y[i]: 188 r = -1 189 case x[i] > y[i]: 190 r = 1 191 } 192 return 193 } 194 195 func (z nat) mulAddWW(x nat, y, r Word) nat { 196 m := len(x) 197 if m == 0 || y == 0 { 198 return z.setWord(r) // result is r 199 } 200 // m > 0 201 202 z = z.make(m + 1) 203 z[m] = mulAddVWW(z[0:m], x, y, r) 204 205 return z.norm() 206 } 207 208 // basicMul multiplies x and y and leaves the result in z. 209 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 210 func basicMul(z, x, y nat) { 211 z[0 : len(x)+len(y)].clear() // initialize z 212 for i, d := range y { 213 if d != 0 { 214 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 215 } 216 } 217 } 218 219 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 220 // Factored out for readability - do not use outside karatsuba. 221 func karatsubaAdd(z, x nat, n int) { 222 if c := addVV(z[0:n], z, x); c != 0 { 223 addVW(z[n:n+n>>1], z[n:], c) 224 } 225 } 226 227 // Like karatsubaAdd, but does subtract. 228 func karatsubaSub(z, x nat, n int) { 229 if c := subVV(z[0:n], z, x); c != 0 { 230 subVW(z[n:n+n>>1], z[n:], c) 231 } 232 } 233 234 // Operands that are shorter than karatsubaThreshold are multiplied using 235 // "grade school" multiplication; for longer operands the Karatsuba algorithm 236 // is used. 237 var karatsubaThreshold int = 40 // computed by calibrate.go 238 239 // karatsuba multiplies x and y and leaves the result in z. 240 // Both x and y must have the same length n and n must be a 241 // power of 2. The result vector z must have len(z) >= 6*n. 242 // The (non-normalized) result is placed in z[0 : 2*n]. 
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		// x1 < x0: flip the sign and store x0-x1 instead
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		// y0 < y1: flip the sign and store y1-y0 instead
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)     // + z0 (saved copy), shifted by n2 words
	karatsubaAdd(z[n2:], r[n:], n) // + z2 (saved copy), shifted by n2 words
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		// p holds the magnitude of a negative product; subtract it
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
// It compares the addresses of the last elements of both slices
// extended to their full capacity; a slice with zero capacity
// never aliases anything.
func alias(x, y nat) bool {
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		// add x into the words z[i:i+n]
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			// propagate the carry into the remaining upper words, if any
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// max returns the larger of x and y.
func max(x, y int) int {
	if x > y {
		return x
	}
	return y
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0.
Thus, the 366 // result is the largest number that can be divided repeatedly by 2 before 367 // becoming about the value of karatsubaThreshold. 368 func karatsubaLen(n int) int { 369 i := uint(0) 370 for n > karatsubaThreshold { 371 n >>= 1 372 i++ 373 } 374 return n << i 375 } 376 377 func (z nat) mul(x, y nat) nat { 378 m := len(x) 379 n := len(y) 380 381 switch { 382 case m < n: 383 return z.mul(y, x) 384 case m == 0 || n == 0: 385 return z[:0] 386 case n == 1: 387 return z.mulAddWW(x, y[0], 0) 388 } 389 // m >= n > 1 390 391 // determine if z can be reused 392 if alias(z, x) || alias(z, y) { 393 z = nil // z is an alias for x or y - cannot reuse 394 } 395 396 // use basic multiplication if the numbers are small 397 if n < karatsubaThreshold { 398 z = z.make(m + n) 399 basicMul(z, x, y) 400 return z.norm() 401 } 402 // m >= n && n >= karatsubaThreshold && n >= 2 403 404 // determine Karatsuba length k such that 405 // 406 // x = xh*b + x0 (0 <= x0 < b) 407 // y = yh*b + y0 (0 <= y0 < b) 408 // b = 1<<(_W*k) ("base" of digits xi, yi) 409 // 410 k := karatsubaLen(n) 411 // k <= n 412 413 // multiply x0 and y0 via Karatsuba 414 x0 := x[0:k] // x0 is not normalized 415 y0 := y[0:k] // y0 is not normalized 416 z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y 417 karatsuba(z, x0, y0) 418 z = z[0 : m+n] // z has final length but may be incomplete 419 z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) 420 421 // If xh != 0 or yh != 0, add the missing terms to z. For 422 // 423 // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) 424 // yh = y1*b (0 <= y1 < b) 425 // 426 // the missing terms are 427 // 428 // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 429 // 430 // since all the yi for i > 1 are 0 by choice of k: If any of them 431 // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would 432 // be a larger valid threshold contradicting the assumption about k. 
433 // 434 if k < n || m != n { 435 var t nat 436 437 // add x0*y1*b 438 x0 := x0.norm() 439 y1 := y[k:] // y1 is normalized because y is 440 t = t.mul(x0, y1) // update t so we don't lose t's underlying array 441 addAt(z, t, k) 442 443 // add xi*y0<<i, xi*y1*b<<(i+k) 444 y0 := y0.norm() 445 for i := k; i < len(x); i += k { 446 xi := x[i:] 447 if len(xi) > k { 448 xi = xi[:k] 449 } 450 xi = xi.norm() 451 t = t.mul(xi, y0) 452 addAt(z, t, i) 453 t = t.mul(xi, y1) 454 addAt(z, t, i+k) 455 } 456 } 457 458 return z.norm() 459 } 460 461 // mulRange computes the product of all the unsigned integers in the 462 // range [a, b] inclusively. If a > b (empty range), the result is 1. 463 func (z nat) mulRange(a, b uint64) nat { 464 switch { 465 case a == 0: 466 // cut long ranges short (optimization) 467 return z.setUint64(0) 468 case a > b: 469 return z.setUint64(1) 470 case a == b: 471 return z.setUint64(a) 472 case a+1 == b: 473 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 474 } 475 m := (a + b) / 2 476 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 477 } 478 479 // q = (x-r)/y, with 0 <= r < y 480 func (z nat) divW(x nat, y Word) (q nat, r Word) { 481 m := len(x) 482 switch { 483 case y == 0: 484 panic("division by zero") 485 case y == 1: 486 q = z.set(x) // result is x 487 return 488 case m == 0: 489 q = z[:0] // result is 0 490 return 491 } 492 // m > 0 493 z = z.make(m) 494 r = divWVW(z, 0, x, y) 495 q = z.norm() 496 return 497 } 498 499 func (z nat) div(z2, u, v nat) (q, r nat) { 500 if len(v) == 0 { 501 panic("division by zero") 502 } 503 504 if u.cmp(v) < 0 { 505 q = z[:0] 506 r = z2.set(u) 507 return 508 } 509 510 if len(v) == 1 { 511 var r2 Word 512 q, r2 = z.divW(u, v[0]) 513 r = z2.setWord(r2) 514 return 515 } 516 517 q, r = z.divLarge(z2, u, v) 518 return 519 } 520 521 // q = (uIn-r)/v, with 0 <= r < y 522 // Uses z as storage for q, and u as storage for r if possible. 523 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 
// Preconditions:
//	len(v) >= 2
//	len(uIn) >= len(v)
//
// Both results are normalized before they are returned. v itself is
// never modified; a shifted copy is made when normalization is needed.
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv holds the product qhat*v (n+1 words) for each quotient digit
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Shift v left so that its top word has no leading zero bits,
	// and shift the dividend by the same amount into u.
	shift := leadingZeros(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	// Produce one quotient word per iteration, most significant first.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate qhat from the top two words of the current dividend;
		// if the top words are equal, the maximum word value is used.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Multiply and subtract: u[j:j+n+1] -= qhat*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so qhat was one too large:
			// add v back and decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// undo the normalization shift to obtain the remainder
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}

// bitLen returns the length of x in bits. x must be normalized.
600 func (x nat) bitLen() int { 601 if i := len(x) - 1; i >= 0 { 602 return i*_W + bitLen(x[i]) 603 } 604 return 0 605 } 606 607 const deBruijn32 = 0x077CB531 608 609 var deBruijn32Lookup = []byte{ 610 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 611 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9, 612 } 613 614 const deBruijn64 = 0x03f79d71b4ca8b09 615 616 var deBruijn64Lookup = []byte{ 617 0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4, 618 62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5, 619 63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11, 620 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, 621 } 622 623 // trailingZeroBits returns the number of consecutive least significant zero 624 // bits of x. 625 func trailingZeroBits(x Word) uint { 626 // x & -x leaves only the right-most bit set in the word. Let k be the 627 // index of that bit. Since only a single bit is set, the value is two 628 // to the power of k. Multiplying by a power of two is equivalent to 629 // left shifting, in this case by k bits. The de Bruijn constant is 630 // such that all six bit, consecutive substrings are distinct. 631 // Therefore, if we have a left shifted version of this constant we can 632 // find by how many bits it was shifted by looking at which six bit 633 // substring ended up at the top of the word. 634 // (Knuth, volume 4, section 7.3.1) 635 switch _W { 636 case 32: 637 return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) 638 case 64: 639 return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 640 default: 641 panic("unknown word size") 642 } 643 } 644 645 // trailingZeroBits returns the number of consecutive least significant zero 646 // bits of x. 
647 func (x nat) trailingZeroBits() uint { 648 if len(x) == 0 { 649 return 0 650 } 651 var i uint 652 for x[i] == 0 { 653 i++ 654 } 655 // x[i] != 0 656 return i*_W + trailingZeroBits(x[i]) 657 } 658 659 // z = x << s 660 func (z nat) shl(x nat, s uint) nat { 661 m := len(x) 662 if m == 0 { 663 return z[:0] 664 } 665 // m > 0 666 667 n := m + int(s/_W) 668 z = z.make(n + 1) 669 z[n] = shlVU(z[n-m:n], x, s%_W) 670 z[0 : n-m].clear() 671 672 return z.norm() 673 } 674 675 // z = x >> s 676 func (z nat) shr(x nat, s uint) nat { 677 m := len(x) 678 n := m - int(s/_W) 679 if n <= 0 { 680 return z[:0] 681 } 682 // n > 0 683 684 z = z.make(n) 685 shrVU(z, x[m-n:], s%_W) 686 687 return z.norm() 688 } 689 690 func (z nat) setBit(x nat, i uint, b uint) nat { 691 j := int(i / _W) 692 m := Word(1) << (i % _W) 693 n := len(x) 694 switch b { 695 case 0: 696 z = z.make(n) 697 copy(z, x) 698 if j >= n { 699 // no need to grow 700 return z 701 } 702 z[j] &^= m 703 return z.norm() 704 case 1: 705 if j >= n { 706 z = z.make(j + 1) 707 z[n:].clear() 708 } else { 709 z = z.make(n) 710 } 711 copy(z, x) 712 z[j] |= m 713 // no need to normalize 714 return z 715 } 716 panic("set bit is not 0 or 1") 717 } 718 719 // bit returns the value of the i'th bit, with lsb == bit 0. 720 func (x nat) bit(i uint) uint { 721 j := i / _W 722 if j >= uint(len(x)) { 723 return 0 724 } 725 // 0 <= j < len(x) 726 return uint(x[j] >> (i % _W) & 1) 727 } 728 729 // sticky returns 1 if there's a 1 bit within the 730 // i least significant bits, otherwise it returns 0. 
731 func (x nat) sticky(i uint) uint { 732 j := i / _W 733 if j >= uint(len(x)) { 734 if len(x) == 0 { 735 return 0 736 } 737 return 1 738 } 739 // 0 <= j < len(x) 740 for _, x := range x[:j] { 741 if x != 0 { 742 return 1 743 } 744 } 745 if x[j]<<(_W-i%_W) != 0 { 746 return 1 747 } 748 return 0 749 } 750 751 func (z nat) and(x, y nat) nat { 752 m := len(x) 753 n := len(y) 754 if m > n { 755 m = n 756 } 757 // m <= n 758 759 z = z.make(m) 760 for i := 0; i < m; i++ { 761 z[i] = x[i] & y[i] 762 } 763 764 return z.norm() 765 } 766 767 func (z nat) andNot(x, y nat) nat { 768 m := len(x) 769 n := len(y) 770 if n > m { 771 n = m 772 } 773 // m >= n 774 775 z = z.make(m) 776 for i := 0; i < n; i++ { 777 z[i] = x[i] &^ y[i] 778 } 779 copy(z[n:m], x[n:m]) 780 781 return z.norm() 782 } 783 784 func (z nat) or(x, y nat) nat { 785 m := len(x) 786 n := len(y) 787 s := x 788 if m < n { 789 n, m = m, n 790 s = y 791 } 792 // m >= n 793 794 z = z.make(m) 795 for i := 0; i < n; i++ { 796 z[i] = x[i] | y[i] 797 } 798 copy(z[n:m], s[n:m]) 799 800 return z.norm() 801 } 802 803 func (z nat) xor(x, y nat) nat { 804 m := len(x) 805 n := len(y) 806 s := x 807 if m < n { 808 n, m = m, n 809 s = y 810 } 811 // m >= n 812 813 z = z.make(m) 814 for i := 0; i < n; i++ { 815 z[i] = x[i] ^ y[i] 816 } 817 copy(z[n:m], s[n:m]) 818 819 return z.norm() 820 } 821 822 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) 823 func greaterThan(x1, x2, y1, y2 Word) bool { 824 return x1 > y1 || x1 == y1 && x2 > y2 825 } 826 827 // modW returns x % d. 828 func (x nat) modW(d Word) (r Word) { 829 // TODO(agl): we don't actually need to store the q value. 830 var q nat 831 q = q.make(len(x)) 832 return divWVW(q, 0, x, d) 833 } 834 835 // random creates a random integer in [0..limit), using the space in z if 836 // possible. n is the bit length of limit. 
837 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 838 if alias(z, limit) { 839 z = nil // z is an alias for limit - cannot reuse 840 } 841 z = z.make(len(limit)) 842 843 bitLengthOfMSW := uint(n % _W) 844 if bitLengthOfMSW == 0 { 845 bitLengthOfMSW = _W 846 } 847 mask := Word((1 << bitLengthOfMSW) - 1) 848 849 for { 850 switch _W { 851 case 32: 852 for i := range z { 853 z[i] = Word(rand.Uint32()) 854 } 855 case 64: 856 for i := range z { 857 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 858 } 859 default: 860 panic("unknown word size") 861 } 862 z[len(limit)-1] &= mask 863 if z.cmp(limit) < 0 { 864 break 865 } 866 } 867 868 return z.norm() 869 } 870 871 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 872 // otherwise it sets z to x**y. The result is the value of z. 873 func (z nat) expNN(x, y, m nat) nat { 874 if alias(z, x) || alias(z, y) { 875 // We cannot allow in-place modification of x or y. 876 z = nil 877 } 878 879 // x**y mod 1 == 0 880 if len(m) == 1 && m[0] == 1 { 881 return z.setWord(0) 882 } 883 // m == 0 || m > 1 884 885 // x**0 == 1 886 if len(y) == 0 { 887 return z.setWord(1) 888 } 889 // y > 0 890 891 // x**1 mod m == x mod m 892 if len(y) == 1 && y[0] == 1 && len(m) != 0 { 893 _, z = z.div(z, x, m) 894 return z 895 } 896 // y > 1 897 898 if len(m) != 0 { 899 // We likely end up being as long as the modulus. 900 z = z.make(len(m)) 901 } 902 z = z.set(x) 903 904 // If the base is non-trivial and the exponent is large, we use 905 // 4-bit, windowed exponentiation. This involves precomputing 14 values 906 // (x^2...x^15) but then reduces the number of multiply-reduces by a 907 // third. Even for a 32-bit exponent, this reduces the number of 908 // operations. 
909 if len(x) > 1 && len(y) > 1 && len(m) > 0 { 910 return z.expNNWindowed(x, y, m) 911 } 912 913 v := y[len(y)-1] // v > 0 because y is normalized and y > 0 914 shift := leadingZeros(v) + 1 915 v <<= shift 916 var q nat 917 918 const mask = 1 << (_W - 1) 919 920 // We walk through the bits of the exponent one by one. Each time we 921 // see a bit, we square, thus doubling the power. If the bit is a one, 922 // we also multiply by x, thus adding one to the power. 923 924 w := _W - int(shift) 925 // zz and r are used to avoid allocating in mul and div as 926 // otherwise the arguments would alias. 927 var zz, r nat 928 for j := 0; j < w; j++ { 929 zz = zz.mul(z, z) 930 zz, z = z, zz 931 932 if v&mask != 0 { 933 zz = zz.mul(z, x) 934 zz, z = z, zz 935 } 936 937 if len(m) != 0 { 938 zz, r = zz.div(r, z, m) 939 zz, r, q, z = q, z, zz, r 940 } 941 942 v <<= 1 943 } 944 945 for i := len(y) - 2; i >= 0; i-- { 946 v = y[i] 947 948 for j := 0; j < _W; j++ { 949 zz = zz.mul(z, z) 950 zz, z = z, zz 951 952 if v&mask != 0 { 953 zz = zz.mul(z, x) 954 zz, z = z, zz 955 } 956 957 if len(m) != 0 { 958 zz, r = zz.div(r, z, m) 959 zz, r, q, z = q, z, zz, r 960 } 961 962 v <<= 1 963 } 964 } 965 966 return z.norm() 967 } 968 969 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window. 970 func (z nat) expNNWindowed(x, y, m nat) nat { 971 // zz and r are used to avoid allocating in mul and div as otherwise 972 // the arguments would alias. 973 var zz, r nat 974 975 const n = 4 976 // powers[i] contains x^i. 
977 var powers [1 << n]nat 978 powers[0] = natOne 979 powers[1] = x 980 for i := 2; i < 1<<n; i += 2 { 981 p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] 982 *p = p.mul(*p2, *p2) 983 zz, r = zz.div(r, *p, m) 984 *p, r = r, *p 985 *p1 = p1.mul(*p, x) 986 zz, r = zz.div(r, *p1, m) 987 *p1, r = r, *p1 988 } 989 990 z = z.setWord(1) 991 992 for i := len(y) - 1; i >= 0; i-- { 993 yi := y[i] 994 for j := 0; j < _W; j += n { 995 if i != len(y)-1 || j != 0 { 996 // Unrolled loop for significant performance 997 // gain. Use go test -bench=".*" in crypto/rsa 998 // to check performance before making changes. 999 zz = zz.mul(z, z) 1000 zz, z = z, zz 1001 zz, r = zz.div(r, z, m) 1002 z, r = r, z 1003 1004 zz = zz.mul(z, z) 1005 zz, z = z, zz 1006 zz, r = zz.div(r, z, m) 1007 z, r = r, z 1008 1009 zz = zz.mul(z, z) 1010 zz, z = z, zz 1011 zz, r = zz.div(r, z, m) 1012 z, r = r, z 1013 1014 zz = zz.mul(z, z) 1015 zz, z = z, zz 1016 zz, r = zz.div(r, z, m) 1017 z, r = r, z 1018 } 1019 1020 zz = zz.mul(z, powers[yi>>(_W-n)]) 1021 zz, z = z, zz 1022 zz, r = zz.div(r, z, m) 1023 z, r = r, z 1024 1025 yi <<= n 1026 } 1027 } 1028 1029 return z.norm() 1030 } 1031 1032 // probablyPrime performs reps Miller-Rabin tests to check whether n is prime. 1033 // If it returns true, n is prime with probability 1 - 1/4^reps. 1034 // If it returns false, n is not prime. 1035 func (n nat) probablyPrime(reps int) bool { 1036 if len(n) == 0 { 1037 return false 1038 } 1039 1040 if len(n) == 1 { 1041 if n[0] < 2 { 1042 return false 1043 } 1044 1045 if n[0]%2 == 0 { 1046 return n[0] == 2 1047 } 1048 1049 // We have to exclude these cases because we reject all 1050 // multiples of these numbers below. 
1051 switch n[0] { 1052 case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53: 1053 return true 1054 } 1055 } 1056 1057 if n[0]&1 == 0 { 1058 return false // n is even 1059 } 1060 1061 const primesProduct32 = 0xC0CFD797 // Π {p ∈ primes, 2 < p <= 29} 1062 const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53} 1063 1064 var r Word 1065 switch _W { 1066 case 32: 1067 r = n.modW(primesProduct32) 1068 case 64: 1069 r = n.modW(primesProduct64 & _M) 1070 default: 1071 panic("Unknown word size") 1072 } 1073 1074 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || 1075 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { 1076 return false 1077 } 1078 1079 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || 1080 r%43 == 0 || r%47 == 0 || r%53 == 0) { 1081 return false 1082 } 1083 1084 nm1 := nat(nil).sub(n, natOne) 1085 // determine q, k such that nm1 = q << k 1086 k := nm1.trailingZeroBits() 1087 q := nat(nil).shr(nm1, k) 1088 1089 nm3 := nat(nil).sub(nm1, natTwo) 1090 rand := rand.New(rand.NewSource(int64(n[0]))) 1091 1092 var x, y, quotient nat 1093 nm3Len := nm3.bitLen() 1094 1095 NextRandom: 1096 for i := 0; i < reps; i++ { 1097 x = x.random(rand, nm3, nm3Len) 1098 x = x.add(x, natTwo) 1099 y = y.expNN(x, q, n) 1100 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { 1101 continue 1102 } 1103 for j := uint(1); j < k; j++ { 1104 y = y.mul(y, y) 1105 quotient, y = quotient.div(y, y, n) 1106 if y.cmp(nm1) == 0 { 1107 continue NextRandom 1108 } 1109 if y.cmp(natOne) == 0 { 1110 return false 1111 } 1112 } 1113 return false 1114 } 1115 1116 return true 1117 } 1118 1119 // bytes writes the value of z into buf using big-endian encoding. 1120 // len(buf) must be >= len(z)*_S. The value of z is encoded in the 1121 // slice buf[i:]. The number i of unused bytes at the beginning of 1122 // buf is returned as result. 
1123 func (z nat) bytes(buf []byte) (i int) { 1124 i = len(buf) 1125 for _, d := range z { 1126 for j := 0; j < _S; j++ { 1127 i-- 1128 buf[i] = byte(d) 1129 d >>= 8 1130 } 1131 } 1132 1133 for i < len(buf) && buf[i] == 0 { 1134 i++ 1135 } 1136 1137 return 1138 } 1139 1140 // setBytes interprets buf as the bytes of a big-endian unsigned 1141 // integer, sets z to that value, and returns z. 1142 func (z nat) setBytes(buf []byte) nat { 1143 z = z.make((len(buf) + _S - 1) / _S) 1144 1145 k := 0 1146 s := uint(0) 1147 var d Word 1148 for i := len(buf); i > 0; i-- { 1149 d |= Word(buf[i-1]) << s 1150 if s += 8; s == _S*8 { 1151 z[k] = d 1152 k++ 1153 s = 0 1154 d = 0 1155 } 1156 } 1157 if k < len(z) { 1158 z[k] = d 1159 } 1160 1161 return z.norm() 1162 }