github.com/rsc/go@v0.0.0-20150416155037-e040fd465409/src/cmd/internal/gc/big/nat.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package big implements multi-precision arithmetic (big numbers). 6 // The following numeric types are supported: 7 // 8 // Int signed integers 9 // Rat rational numbers 10 // Float floating-point numbers 11 // 12 // Methods are typically of the form: 13 // 14 // func (z *T) Unary(x *T) *T // z = op x 15 // func (z *T) Binary(x, y *T) *T // z = x op y 16 // func (x *T) M() T1 // v = x.M() 17 // 18 // with T one of Int, Rat, or Float. For unary and binary operations, the 19 // result is the receiver (usually named z in that case); if it is one of 20 // the operands x or y it may be overwritten (and its memory reused). 21 // To enable chaining of operations, the result is also returned. Methods 22 // returning a result other than *Int, *Rat, or *Float take an operand as 23 // the receiver (usually named x in that case). 24 // 25 package big 26 27 // This file contains operations on unsigned multi-precision integers. 28 // These are the building blocks for the operations on signed integers 29 // and rationals. 30 31 import "math/rand" 32 33 // An unsigned integer x of the form 34 // 35 // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 36 // 37 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 38 // with the digits x[i] as the slice elements. 39 // 40 // A number is normalized if the slice contains no leading 0 digits. 41 // During arithmetic operations, denormalized values may occur but are 42 // always normalized before returning the final result. The normalized 43 // representation of 0 is the empty or nil slice (length = 0). 
44 // 45 type nat []Word 46 47 var ( 48 natOne = nat{1} 49 natTwo = nat{2} 50 natTen = nat{10} 51 ) 52 53 func (z nat) clear() { 54 for i := range z { 55 z[i] = 0 56 } 57 } 58 59 func (z nat) norm() nat { 60 i := len(z) 61 for i > 0 && z[i-1] == 0 { 62 i-- 63 } 64 return z[0:i] 65 } 66 67 func (z nat) make(n int) nat { 68 if n <= cap(z) { 69 return z[:n] // reuse z 70 } 71 // Choosing a good value for e has significant performance impact 72 // because it increases the chance that a value can be reused. 73 const e = 4 // extra capacity 74 return make(nat, n, n+e) 75 } 76 77 func (z nat) setWord(x Word) nat { 78 if x == 0 { 79 return z[:0] 80 } 81 z = z.make(1) 82 z[0] = x 83 return z 84 } 85 86 func (z nat) setUint64(x uint64) nat { 87 // single-digit values 88 if w := Word(x); uint64(w) == x { 89 return z.setWord(w) 90 } 91 92 // compute number of words n required to represent x 93 n := 0 94 for t := x; t > 0; t >>= _W { 95 n++ 96 } 97 98 // split x into n words 99 z = z.make(n) 100 for i := range z { 101 z[i] = Word(x & _M) 102 x >>= _W 103 } 104 105 return z 106 } 107 108 func (z nat) set(x nat) nat { 109 z = z.make(len(x)) 110 copy(z, x) 111 return z 112 } 113 114 func (z nat) add(x, y nat) nat { 115 m := len(x) 116 n := len(y) 117 118 switch { 119 case m < n: 120 return z.add(y, x) 121 case m == 0: 122 // n == 0 because m >= n; result is 0 123 return z[:0] 124 case n == 0: 125 // result is x 126 return z.set(x) 127 } 128 // m > 0 129 130 z = z.make(m + 1) 131 c := addVV(z[0:n], x, y) 132 if m > n { 133 c = addVW(z[n:m], x[n:], c) 134 } 135 z[m] = c 136 137 return z.norm() 138 } 139 140 func (z nat) sub(x, y nat) nat { 141 m := len(x) 142 n := len(y) 143 144 switch { 145 case m < n: 146 panic("underflow") 147 case m == 0: 148 // n == 0 because m >= n; result is 0 149 return z[:0] 150 case n == 0: 151 // result is x 152 return z.set(x) 153 } 154 // m > 0 155 156 z = z.make(m) 157 c := subVV(z[0:n], x, y) 158 if m > n { 159 c = subVW(z[n:], x[n:], c) 160 } 161 if 
c != 0 { 162 panic("underflow") 163 } 164 165 return z.norm() 166 } 167 168 func (x nat) cmp(y nat) (r int) { 169 m := len(x) 170 n := len(y) 171 if m != n || m == 0 { 172 switch { 173 case m < n: 174 r = -1 175 case m > n: 176 r = 1 177 } 178 return 179 } 180 181 i := m - 1 182 for i > 0 && x[i] == y[i] { 183 i-- 184 } 185 186 switch { 187 case x[i] < y[i]: 188 r = -1 189 case x[i] > y[i]: 190 r = 1 191 } 192 return 193 } 194 195 func (z nat) mulAddWW(x nat, y, r Word) nat { 196 m := len(x) 197 if m == 0 || y == 0 { 198 return z.setWord(r) // result is r 199 } 200 // m > 0 201 202 z = z.make(m + 1) 203 z[m] = mulAddVWW(z[0:m], x, y, r) 204 205 return z.norm() 206 } 207 208 // basicMul multiplies x and y and leaves the result in z. 209 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 210 func basicMul(z, x, y nat) { 211 z[0 : len(x)+len(y)].clear() // initialize z 212 for i, d := range y { 213 if d != 0 { 214 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 215 } 216 } 217 } 218 219 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 220 // Factored out for readability - do not use outside karatsuba. 221 func karatsubaAdd(z, x nat, n int) { 222 if c := addVV(z[0:n], z, x); c != 0 { 223 addVW(z[n:n+n>>1], z[n:], c) 224 } 225 } 226 227 // Like karatsubaAdd, but does subtract. 228 func karatsubaSub(z, x nat, n int) { 229 if c := subVV(z[0:n], z, x); c != 0 { 230 subVW(z[n:n+n>>1], z[n:], c) 231 } 232 } 233 234 // Operands that are shorter than karatsubaThreshold are multiplied using 235 // "grade school" multiplication; for longer operands the Karatsuba algorithm 236 // is used. 237 var karatsubaThreshold int = 40 // computed by calibrate.go 238 239 // karatsuba multiplies x and y and leaves the result in z. 240 // Both x and y must have the same length n and n must be a 241 // power of 2. The result vector z must have len(z) >= 6*n. 242 // The (non-normalized) result is placed in z[0 : 2*n]. 
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	// A non-zero borrow from subVV means x1 < x0: flip the sign and
	// recompute the difference with the operands swapped so xd holds |x1-x0|.
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	// Same scheme as for xd: yd ends up holding |y0-y1|.
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// All three terms are accumulated at offset n2 (i.e. multiplied by b);
	// p is added or subtracted according to the sign s computed above.
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
// It does so by comparing the addresses of the last element within each
// slice's capacity; slices with zero capacity never alias.
func alias(x, y nat) bool {
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			// propagate the carry into the digits above position i+n, if any
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// max returns the larger of x and y.
func max(x, y int) int {
	if x > y {
		return x
	}
	return y
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0.
Thus, the 366 // result is the largest number that can be divided repeatedly by 2 before 367 // becoming about the value of karatsubaThreshold. 368 func karatsubaLen(n int) int { 369 i := uint(0) 370 for n > karatsubaThreshold { 371 n >>= 1 372 i++ 373 } 374 return n << i 375 } 376 377 func (z nat) mul(x, y nat) nat { 378 m := len(x) 379 n := len(y) 380 381 switch { 382 case m < n: 383 return z.mul(y, x) 384 case m == 0 || n == 0: 385 return z[:0] 386 case n == 1: 387 return z.mulAddWW(x, y[0], 0) 388 } 389 // m >= n > 1 390 391 // determine if z can be reused 392 if alias(z, x) || alias(z, y) { 393 z = nil // z is an alias for x or y - cannot reuse 394 } 395 396 // use basic multiplication if the numbers are small 397 if n < karatsubaThreshold { 398 z = z.make(m + n) 399 basicMul(z, x, y) 400 return z.norm() 401 } 402 // m >= n && n >= karatsubaThreshold && n >= 2 403 404 // determine Karatsuba length k such that 405 // 406 // x = xh*b + x0 (0 <= x0 < b) 407 // y = yh*b + y0 (0 <= y0 < b) 408 // b = 1<<(_W*k) ("base" of digits xi, yi) 409 // 410 k := karatsubaLen(n) 411 // k <= n 412 413 // multiply x0 and y0 via Karatsuba 414 x0 := x[0:k] // x0 is not normalized 415 y0 := y[0:k] // y0 is not normalized 416 z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y 417 karatsuba(z, x0, y0) 418 z = z[0 : m+n] // z has final length but may be incomplete 419 z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) 420 421 // If xh != 0 or yh != 0, add the missing terms to z. For 422 // 423 // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) 424 // yh = y1*b (0 <= y1 < b) 425 // 426 // the missing terms are 427 // 428 // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 429 // 430 // since all the yi for i > 1 are 0 by choice of k: If any of them 431 // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would 432 // be a larger valid threshold contradicting the assumption about k. 
433 // 434 if k < n || m != n { 435 var t nat 436 437 // add x0*y1*b 438 x0 := x0.norm() 439 y1 := y[k:] // y1 is normalized because y is 440 t = t.mul(x0, y1) // update t so we don't lose t's underlying array 441 addAt(z, t, k) 442 443 // add xi*y0<<i, xi*y1*b<<(i+k) 444 y0 := y0.norm() 445 for i := k; i < len(x); i += k { 446 xi := x[i:] 447 if len(xi) > k { 448 xi = xi[:k] 449 } 450 xi = xi.norm() 451 t = t.mul(xi, y0) 452 addAt(z, t, i) 453 t = t.mul(xi, y1) 454 addAt(z, t, i+k) 455 } 456 } 457 458 return z.norm() 459 } 460 461 // mulRange computes the product of all the unsigned integers in the 462 // range [a, b] inclusively. If a > b (empty range), the result is 1. 463 func (z nat) mulRange(a, b uint64) nat { 464 switch { 465 case a == 0: 466 // cut long ranges short (optimization) 467 return z.setUint64(0) 468 case a > b: 469 return z.setUint64(1) 470 case a == b: 471 return z.setUint64(a) 472 case a+1 == b: 473 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 474 } 475 m := (a + b) / 2 476 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 477 } 478 479 // q = (x-r)/y, with 0 <= r < y 480 func (z nat) divW(x nat, y Word) (q nat, r Word) { 481 m := len(x) 482 switch { 483 case y == 0: 484 panic("division by zero") 485 case y == 1: 486 q = z.set(x) // result is x 487 return 488 case m == 0: 489 q = z[:0] // result is 0 490 return 491 } 492 // m > 0 493 z = z.make(m) 494 r = divWVW(z, 0, x, y) 495 q = z.norm() 496 return 497 } 498 499 func (z nat) div(z2, u, v nat) (q, r nat) { 500 if len(v) == 0 { 501 panic("division by zero") 502 } 503 504 if u.cmp(v) < 0 { 505 q = z[:0] 506 r = z2.set(u) 507 return 508 } 509 510 if len(v) == 1 { 511 var r2 Word 512 q, r2 = z.divW(u, v[0]) 513 r = z2.setWord(r2) 514 return 515 } 516 517 q, r = z.divLarge(z2, u, v) 518 return 519 } 520 521 // q = (uIn-r)/v, with 0 <= r < y 522 // Uses z as storage for q, and u as storage for r if possible. 523 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 
// Preconditions:
//	len(v) >= 2
//	len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	// NOTE(review): z is checked against uIn and v only, not against u;
	// callers presumably must not pass the same storage for z and u - confirm.
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv is scratch space for the product qhat*v (n+1 digits).
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Shift v and uIn left by the number of leading zero bits of v's top
	// digit; the shifted dividend goes into u, with one extra top digit
	// receiving the bits shifted out.
	shift := leadingZeros(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	// Produce one quotient digit per iteration, most significant first.
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient digit qhat from the top two digits of the
		// current remainder and the top digit of v; start from the maximum
		// digit value, which is kept when u[j+n] == v[n-1].
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4.
		// Subtract qhat*v from the current window of u.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so qhat was one too large:
			// add v back once and decrement qhat.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// Undo the D1 shift; what is left in u is the remainder.
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}

// Length of x in bits. x must be normalized.
600 func (x nat) bitLen() int { 601 if i := len(x) - 1; i >= 0 { 602 return i*_W + bitLen(x[i]) 603 } 604 return 0 605 } 606 607 const deBruijn32 = 0x077CB531 608 609 var deBruijn32Lookup = []byte{ 610 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 611 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9, 612 } 613 614 const deBruijn64 = 0x03f79d71b4ca8b09 615 616 var deBruijn64Lookup = []byte{ 617 0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4, 618 62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5, 619 63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11, 620 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, 621 } 622 623 // trailingZeroBits returns the number of consecutive least significant zero 624 // bits of x. 625 func trailingZeroBits(x Word) uint { 626 // x & -x leaves only the right-most bit set in the word. Let k be the 627 // index of that bit. Since only a single bit is set, the value is two 628 // to the power of k. Multiplying by a power of two is equivalent to 629 // left shifting, in this case by k bits. The de Bruijn constant is 630 // such that all six bit, consecutive substrings are distinct. 631 // Therefore, if we have a left shifted version of this constant we can 632 // find by how many bits it was shifted by looking at which six bit 633 // substring ended up at the top of the word. 634 // (Knuth, volume 4, section 7.3.1) 635 switch _W { 636 case 32: 637 return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) 638 case 64: 639 return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 640 default: 641 panic("unknown word size") 642 } 643 } 644 645 // trailingZeroBits returns the number of consecutive least significant zero 646 // bits of x. 
647 func (x nat) trailingZeroBits() uint { 648 if len(x) == 0 { 649 return 0 650 } 651 var i uint 652 for x[i] == 0 { 653 i++ 654 } 655 // x[i] != 0 656 return i*_W + trailingZeroBits(x[i]) 657 } 658 659 // z = x << s 660 func (z nat) shl(x nat, s uint) nat { 661 m := len(x) 662 if m == 0 { 663 return z[:0] 664 } 665 // m > 0 666 667 n := m + int(s/_W) 668 z = z.make(n + 1) 669 z[n] = shlVU(z[n-m:n], x, s%_W) 670 z[0 : n-m].clear() 671 672 return z.norm() 673 } 674 675 // z = x >> s 676 func (z nat) shr(x nat, s uint) nat { 677 m := len(x) 678 n := m - int(s/_W) 679 if n <= 0 { 680 return z[:0] 681 } 682 // n > 0 683 684 z = z.make(n) 685 shrVU(z, x[m-n:], s%_W) 686 687 return z.norm() 688 } 689 690 func (z nat) setBit(x nat, i uint, b uint) nat { 691 j := int(i / _W) 692 m := Word(1) << (i % _W) 693 n := len(x) 694 switch b { 695 case 0: 696 z = z.make(n) 697 copy(z, x) 698 if j >= n { 699 // no need to grow 700 return z 701 } 702 z[j] &^= m 703 return z.norm() 704 case 1: 705 if j >= n { 706 z = z.make(j + 1) 707 z[n:].clear() 708 } else { 709 z = z.make(n) 710 } 711 copy(z, x) 712 z[j] |= m 713 // no need to normalize 714 return z 715 } 716 panic("set bit is not 0 or 1") 717 } 718 719 // bit returns the value of the i'th bit, with lsb == bit 0. 720 func (x nat) bit(i uint) uint { 721 j := i / _W 722 if j >= uint(len(x)) { 723 return 0 724 } 725 // 0 <= j < len(x) 726 return uint(x[j] >> (i % _W) & 1) 727 } 728 729 // sticky returns 1 if there's a 1 bit within the 730 // i least significant bits, otherwise it returns 0. 
731 func (x nat) sticky(i uint) uint { 732 j := i / _W 733 if j >= uint(len(x)) { 734 if len(x) == 0 { 735 return 0 736 } 737 return 1 738 } 739 // 0 <= j < len(x) 740 for _, x := range x[:j] { 741 if x != 0 { 742 return 1 743 } 744 } 745 if x[j]<<(_W-i%_W) != 0 { 746 return 1 747 } 748 return 0 749 } 750 751 func (z nat) and(x, y nat) nat { 752 m := len(x) 753 n := len(y) 754 if m > n { 755 m = n 756 } 757 // m <= n 758 759 z = z.make(m) 760 for i := 0; i < m; i++ { 761 z[i] = x[i] & y[i] 762 } 763 764 return z.norm() 765 } 766 767 func (z nat) andNot(x, y nat) nat { 768 m := len(x) 769 n := len(y) 770 if n > m { 771 n = m 772 } 773 // m >= n 774 775 z = z.make(m) 776 for i := 0; i < n; i++ { 777 z[i] = x[i] &^ y[i] 778 } 779 copy(z[n:m], x[n:m]) 780 781 return z.norm() 782 } 783 784 func (z nat) or(x, y nat) nat { 785 m := len(x) 786 n := len(y) 787 s := x 788 if m < n { 789 n, m = m, n 790 s = y 791 } 792 // m >= n 793 794 z = z.make(m) 795 for i := 0; i < n; i++ { 796 z[i] = x[i] | y[i] 797 } 798 copy(z[n:m], s[n:m]) 799 800 return z.norm() 801 } 802 803 func (z nat) xor(x, y nat) nat { 804 m := len(x) 805 n := len(y) 806 s := x 807 if m < n { 808 n, m = m, n 809 s = y 810 } 811 // m >= n 812 813 z = z.make(m) 814 for i := 0; i < n; i++ { 815 z[i] = x[i] ^ y[i] 816 } 817 copy(z[n:m], s[n:m]) 818 819 return z.norm() 820 } 821 822 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) 823 func greaterThan(x1, x2, y1, y2 Word) bool { 824 return x1 > y1 || x1 == y1 && x2 > y2 825 } 826 827 // modW returns x % d. 828 func (x nat) modW(d Word) (r Word) { 829 // TODO(agl): we don't actually need to store the q value. 830 var q nat 831 q = q.make(len(x)) 832 return divWVW(q, 0, x, d) 833 } 834 835 // random creates a random integer in [0..limit), using the space in z if 836 // possible. n is the bit length of limit. 
837 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 838 if alias(z, limit) { 839 z = nil // z is an alias for limit - cannot reuse 840 } 841 z = z.make(len(limit)) 842 843 bitLengthOfMSW := uint(n % _W) 844 if bitLengthOfMSW == 0 { 845 bitLengthOfMSW = _W 846 } 847 mask := Word((1 << bitLengthOfMSW) - 1) 848 849 for { 850 switch _W { 851 case 32: 852 for i := range z { 853 z[i] = Word(rand.Uint32()) 854 } 855 case 64: 856 for i := range z { 857 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 858 } 859 default: 860 panic("unknown word size") 861 } 862 z[len(limit)-1] &= mask 863 if z.cmp(limit) < 0 { 864 break 865 } 866 } 867 868 return z.norm() 869 } 870 871 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 872 // otherwise it sets z to x**y. The result is the value of z. 873 func (z nat) expNN(x, y, m nat) nat { 874 if alias(z, x) || alias(z, y) { 875 // We cannot allow in-place modification of x or y. 876 z = nil 877 } 878 879 // x**y mod 1 == 0 880 if len(m) == 1 && m[0] == 1 { 881 return z.setWord(0) 882 } 883 // m == 0 || m > 1 884 885 // x**0 == 1 886 if len(y) == 0 { 887 return z.setWord(1) 888 } 889 // y > 0 890 891 if len(m) != 0 { 892 // We likely end up being as long as the modulus. 893 z = z.make(len(m)) 894 } 895 z = z.set(x) 896 897 // If the base is non-trivial and the exponent is large, we use 898 // 4-bit, windowed exponentiation. This involves precomputing 14 values 899 // (x^2...x^15) but then reduces the number of multiply-reduces by a 900 // third. Even for a 32-bit exponent, this reduces the number of 901 // operations. 902 if len(x) > 1 && len(y) > 1 && len(m) > 0 { 903 return z.expNNWindowed(x, y, m) 904 } 905 906 v := y[len(y)-1] // v > 0 because y is normalized and y > 0 907 shift := leadingZeros(v) + 1 908 v <<= shift 909 var q nat 910 911 const mask = 1 << (_W - 1) 912 913 // We walk through the bits of the exponent one by one. Each time we 914 // see a bit, we square, thus doubling the power. 
If the bit is a one, 915 // we also multiply by x, thus adding one to the power. 916 917 w := _W - int(shift) 918 // zz and r are used to avoid allocating in mul and div as 919 // otherwise the arguments would alias. 920 var zz, r nat 921 for j := 0; j < w; j++ { 922 zz = zz.mul(z, z) 923 zz, z = z, zz 924 925 if v&mask != 0 { 926 zz = zz.mul(z, x) 927 zz, z = z, zz 928 } 929 930 if len(m) != 0 { 931 zz, r = zz.div(r, z, m) 932 zz, r, q, z = q, z, zz, r 933 } 934 935 v <<= 1 936 } 937 938 for i := len(y) - 2; i >= 0; i-- { 939 v = y[i] 940 941 for j := 0; j < _W; j++ { 942 zz = zz.mul(z, z) 943 zz, z = z, zz 944 945 if v&mask != 0 { 946 zz = zz.mul(z, x) 947 zz, z = z, zz 948 } 949 950 if len(m) != 0 { 951 zz, r = zz.div(r, z, m) 952 zz, r, q, z = q, z, zz, r 953 } 954 955 v <<= 1 956 } 957 } 958 959 return z.norm() 960 } 961 962 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window. 963 func (z nat) expNNWindowed(x, y, m nat) nat { 964 // zz and r are used to avoid allocating in mul and div as otherwise 965 // the arguments would alias. 966 var zz, r nat 967 968 const n = 4 969 // powers[i] contains x^i. 970 var powers [1 << n]nat 971 powers[0] = natOne 972 powers[1] = x 973 for i := 2; i < 1<<n; i += 2 { 974 p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] 975 *p = p.mul(*p2, *p2) 976 zz, r = zz.div(r, *p, m) 977 *p, r = r, *p 978 *p1 = p1.mul(*p, x) 979 zz, r = zz.div(r, *p1, m) 980 *p1, r = r, *p1 981 } 982 983 z = z.setWord(1) 984 985 for i := len(y) - 1; i >= 0; i-- { 986 yi := y[i] 987 for j := 0; j < _W; j += n { 988 if i != len(y)-1 || j != 0 { 989 // Unrolled loop for significant performance 990 // gain. Use go test -bench=".*" in crypto/rsa 991 // to check performance before making changes. 
992 zz = zz.mul(z, z) 993 zz, z = z, zz 994 zz, r = zz.div(r, z, m) 995 z, r = r, z 996 997 zz = zz.mul(z, z) 998 zz, z = z, zz 999 zz, r = zz.div(r, z, m) 1000 z, r = r, z 1001 1002 zz = zz.mul(z, z) 1003 zz, z = z, zz 1004 zz, r = zz.div(r, z, m) 1005 z, r = r, z 1006 1007 zz = zz.mul(z, z) 1008 zz, z = z, zz 1009 zz, r = zz.div(r, z, m) 1010 z, r = r, z 1011 } 1012 1013 zz = zz.mul(z, powers[yi>>(_W-n)]) 1014 zz, z = z, zz 1015 zz, r = zz.div(r, z, m) 1016 z, r = r, z 1017 1018 yi <<= n 1019 } 1020 } 1021 1022 return z.norm() 1023 } 1024 1025 // probablyPrime performs reps Miller-Rabin tests to check whether n is prime. 1026 // If it returns true, n is prime with probability 1 - 1/4^reps. 1027 // If it returns false, n is not prime. 1028 func (n nat) probablyPrime(reps int) bool { 1029 if len(n) == 0 { 1030 return false 1031 } 1032 1033 if len(n) == 1 { 1034 if n[0] < 2 { 1035 return false 1036 } 1037 1038 if n[0]%2 == 0 { 1039 return n[0] == 2 1040 } 1041 1042 // We have to exclude these cases because we reject all 1043 // multiples of these numbers below. 
1044 switch n[0] { 1045 case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53: 1046 return true 1047 } 1048 } 1049 1050 if n[0]&1 == 0 { 1051 return false // n is even 1052 } 1053 1054 const primesProduct32 = 0xC0CFD797 // Π {p ∈ primes, 2 < p <= 29} 1055 const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53} 1056 1057 var r Word 1058 switch _W { 1059 case 32: 1060 r = n.modW(primesProduct32) 1061 case 64: 1062 r = n.modW(primesProduct64 & _M) 1063 default: 1064 panic("Unknown word size") 1065 } 1066 1067 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || 1068 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { 1069 return false 1070 } 1071 1072 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || 1073 r%43 == 0 || r%47 == 0 || r%53 == 0) { 1074 return false 1075 } 1076 1077 nm1 := nat(nil).sub(n, natOne) 1078 // determine q, k such that nm1 = q << k 1079 k := nm1.trailingZeroBits() 1080 q := nat(nil).shr(nm1, k) 1081 1082 nm3 := nat(nil).sub(nm1, natTwo) 1083 rand := rand.New(rand.NewSource(int64(n[0]))) 1084 1085 var x, y, quotient nat 1086 nm3Len := nm3.bitLen() 1087 1088 NextRandom: 1089 for i := 0; i < reps; i++ { 1090 x = x.random(rand, nm3, nm3Len) 1091 x = x.add(x, natTwo) 1092 y = y.expNN(x, q, n) 1093 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { 1094 continue 1095 } 1096 for j := uint(1); j < k; j++ { 1097 y = y.mul(y, y) 1098 quotient, y = quotient.div(y, y, n) 1099 if y.cmp(nm1) == 0 { 1100 continue NextRandom 1101 } 1102 if y.cmp(natOne) == 0 { 1103 return false 1104 } 1105 } 1106 return false 1107 } 1108 1109 return true 1110 } 1111 1112 // bytes writes the value of z into buf using big-endian encoding. 1113 // len(buf) must be >= len(z)*_S. The value of z is encoded in the 1114 // slice buf[i:]. The number i of unused bytes at the beginning of 1115 // buf is returned as result. 
1116 func (z nat) bytes(buf []byte) (i int) { 1117 i = len(buf) 1118 for _, d := range z { 1119 for j := 0; j < _S; j++ { 1120 i-- 1121 buf[i] = byte(d) 1122 d >>= 8 1123 } 1124 } 1125 1126 for i < len(buf) && buf[i] == 0 { 1127 i++ 1128 } 1129 1130 return 1131 } 1132 1133 // setBytes interprets buf as the bytes of a big-endian unsigned 1134 // integer, sets z to that value, and returns z. 1135 func (z nat) setBytes(buf []byte) nat { 1136 z = z.make((len(buf) + _S - 1) / _S) 1137 1138 k := 0 1139 s := uint(0) 1140 var d Word 1141 for i := len(buf); i > 0; i-- { 1142 d |= Word(buf[i-1]) << s 1143 if s += 8; s == _S*8 { 1144 z[k] = d 1145 k++ 1146 s = 0 1147 d = 0 1148 } 1149 } 1150 if k < len(z) { 1151 z[k] = d 1152 } 1153 1154 return z.norm() 1155 }