github.com/rsc/tmp@v0.0.0-20240517235954-6deaab19748b/bootstrap/internal/gc/big/nat.go (about) 1 // Do not edit. Bootstrap copy of /Users/rsc/g/go/src/cmd/internal/gc/big/nat.go 2 3 // Copyright 2009 The Go Authors. All rights reserved. 4 // Use of this source code is governed by a BSD-style 5 // license that can be found in the LICENSE file. 6 7 // Package big implements multi-precision arithmetic (big numbers). 8 // The following numeric types are supported: 9 // 10 // Int signed integers 11 // Rat rational numbers 12 // Float floating-point numbers 13 // 14 // Methods are typically of the form: 15 // 16 // func (z *T) Unary(x *T) *T // z = op x 17 // func (z *T) Binary(x, y *T) *T // z = x op y 18 // func (x *T) M() T1 // v = x.M() 19 // 20 // with T one of Int, Rat, or Float. For unary and binary operations, the 21 // result is the receiver (usually named z in that case); if it is one of 22 // the operands x or y it may be overwritten (and its memory reused). 23 // To enable chaining of operations, the result is also returned. Methods 24 // returning a result other than *Int, *Rat, or *Float take an operand as 25 // the receiver (usually named x in that case). 26 // 27 package big 28 29 // This file contains operations on unsigned multi-precision integers. 30 // These are the building blocks for the operations on signed integers 31 // and rationals. 32 33 import "math/rand" 34 35 // An unsigned integer x of the form 36 // 37 // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 38 // 39 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 40 // with the digits x[i] as the slice elements. 41 // 42 // A number is normalized if the slice contains no leading 0 digits. 43 // During arithmetic operations, denormalized values may occur but are 44 // always normalized before returning the final result. The normalized 45 // representation of 0 is the empty or nil slice (length = 0). 
46 // 47 type nat []Word 48 49 var ( 50 natOne = nat{1} 51 natTwo = nat{2} 52 natTen = nat{10} 53 ) 54 55 func (z nat) clear() { 56 for i := range z { 57 z[i] = 0 58 } 59 } 60 61 func (z nat) norm() nat { 62 i := len(z) 63 for i > 0 && z[i-1] == 0 { 64 i-- 65 } 66 return z[0:i] 67 } 68 69 func (z nat) make(n int) nat { 70 if n <= cap(z) { 71 return z[:n] // reuse z 72 } 73 // Choosing a good value for e has significant performance impact 74 // because it increases the chance that a value can be reused. 75 const e = 4 // extra capacity 76 return make(nat, n, n+e) 77 } 78 79 func (z nat) setWord(x Word) nat { 80 if x == 0 { 81 return z[:0] 82 } 83 z = z.make(1) 84 z[0] = x 85 return z 86 } 87 88 func (z nat) setUint64(x uint64) nat { 89 // single-digit values 90 if w := Word(x); uint64(w) == x { 91 return z.setWord(w) 92 } 93 94 // compute number of words n required to represent x 95 n := 0 96 for t := x; t > 0; t >>= _W { 97 n++ 98 } 99 100 // split x into n words 101 z = z.make(n) 102 for i := range z { 103 z[i] = Word(x & _M) 104 x >>= _W 105 } 106 107 return z 108 } 109 110 func (z nat) set(x nat) nat { 111 z = z.make(len(x)) 112 copy(z, x) 113 return z 114 } 115 116 func (z nat) add(x, y nat) nat { 117 m := len(x) 118 n := len(y) 119 120 switch { 121 case m < n: 122 return z.add(y, x) 123 case m == 0: 124 // n == 0 because m >= n; result is 0 125 return z[:0] 126 case n == 0: 127 // result is x 128 return z.set(x) 129 } 130 // m > 0 131 132 z = z.make(m + 1) 133 c := addVV(z[0:n], x, y) 134 if m > n { 135 c = addVW(z[n:m], x[n:], c) 136 } 137 z[m] = c 138 139 return z.norm() 140 } 141 142 func (z nat) sub(x, y nat) nat { 143 m := len(x) 144 n := len(y) 145 146 switch { 147 case m < n: 148 panic("underflow") 149 case m == 0: 150 // n == 0 because m >= n; result is 0 151 return z[:0] 152 case n == 0: 153 // result is x 154 return z.set(x) 155 } 156 // m > 0 157 158 z = z.make(m) 159 c := subVV(z[0:n], x, y) 160 if m > n { 161 c = subVW(z[n:], x[n:], c) 162 } 163 
if c != 0 { 164 panic("underflow") 165 } 166 167 return z.norm() 168 } 169 170 func (x nat) cmp(y nat) (r int) { 171 m := len(x) 172 n := len(y) 173 if m != n || m == 0 { 174 switch { 175 case m < n: 176 r = -1 177 case m > n: 178 r = 1 179 } 180 return 181 } 182 183 i := m - 1 184 for i > 0 && x[i] == y[i] { 185 i-- 186 } 187 188 switch { 189 case x[i] < y[i]: 190 r = -1 191 case x[i] > y[i]: 192 r = 1 193 } 194 return 195 } 196 197 func (z nat) mulAddWW(x nat, y, r Word) nat { 198 m := len(x) 199 if m == 0 || y == 0 { 200 return z.setWord(r) // result is r 201 } 202 // m > 0 203 204 z = z.make(m + 1) 205 z[m] = mulAddVWW(z[0:m], x, y, r) 206 207 return z.norm() 208 } 209 210 // basicMul multiplies x and y and leaves the result in z. 211 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 212 func basicMul(z, x, y nat) { 213 z[0 : len(x)+len(y)].clear() // initialize z 214 for i, d := range y { 215 if d != 0 { 216 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 217 } 218 } 219 } 220 221 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 222 // Factored out for readability - do not use outside karatsuba. 223 func karatsubaAdd(z, x nat, n int) { 224 if c := addVV(z[0:n], z, x); c != 0 { 225 addVW(z[n:n+n>>1], z[n:], c) 226 } 227 } 228 229 // Like karatsubaAdd, but does subtract. 230 func karatsubaSub(z, x nat, n int) { 231 if c := subVV(z[0:n], z, x); c != 0 { 232 subVW(z[n:n+n>>1], z[n:], c) 233 } 234 } 235 236 // Operands that are shorter than karatsubaThreshold are multiplied using 237 // "grade school" multiplication; for longer operands the Karatsuba algorithm 238 // is used. 239 var karatsubaThreshold int = 40 // computed by calibrate.go 240 241 // karatsuba multiplies x and y and leaves the result in z. 242 // Both x and y must have the same length n and n must be a 243 // power of 2. The result vector z must have len(z) >= 6*n. 244 // The (non-normalized) result is placed in z[0 : 2*n]. 
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits" of n2 words each, i.e. b = 1<<(_W*n2)
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs), so that
	// xd always holds the magnitude |x1-x0|; s tracks the sign of xd*yd
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs), so that
	// yd always holds the magnitude |y0-y1|
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = xd*yd = |(x1-x0)*(y0-y1)| (n words, stored at z[3*n:4*n]):
	//   p ==   x1*y0 - x1*y1 - x0*y0 + x0*y1   for s > 0
	//   p == -(x1*y0 - x1*y1 - x0*y0 + x0*y1)  for s < 0
	// which is why p is added for s > 0 and subtracted for s < 0 below.
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
// It compares the addresses of the last elements reachable through each
// slice's capacity, so slices of one array whose capacities were cut short
// with a full slice expression (a[i:j:k]) are not detected as aliases.
func alias(x, y nat) bool {
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
// A carry out of the top word of z is dropped, so callers must size z
// such that the sum fits.
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// max returns the larger of x and y.
func max(x, y int) int {
	if x > y {
		return x
	}
	return y
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0.
Thus, the 368 // result is the largest number that can be divided repeatedly by 2 before 369 // becoming about the value of karatsubaThreshold. 370 func karatsubaLen(n int) int { 371 i := uint(0) 372 for n > karatsubaThreshold { 373 n >>= 1 374 i++ 375 } 376 return n << i 377 } 378 379 func (z nat) mul(x, y nat) nat { 380 m := len(x) 381 n := len(y) 382 383 switch { 384 case m < n: 385 return z.mul(y, x) 386 case m == 0 || n == 0: 387 return z[:0] 388 case n == 1: 389 return z.mulAddWW(x, y[0], 0) 390 } 391 // m >= n > 1 392 393 // determine if z can be reused 394 if alias(z, x) || alias(z, y) { 395 z = nil // z is an alias for x or y - cannot reuse 396 } 397 398 // use basic multiplication if the numbers are small 399 if n < karatsubaThreshold { 400 z = z.make(m + n) 401 basicMul(z, x, y) 402 return z.norm() 403 } 404 // m >= n && n >= karatsubaThreshold && n >= 2 405 406 // determine Karatsuba length k such that 407 // 408 // x = xh*b + x0 (0 <= x0 < b) 409 // y = yh*b + y0 (0 <= y0 < b) 410 // b = 1<<(_W*k) ("base" of digits xi, yi) 411 // 412 k := karatsubaLen(n) 413 // k <= n 414 415 // multiply x0 and y0 via Karatsuba 416 x0 := x[0:k] // x0 is not normalized 417 y0 := y[0:k] // y0 is not normalized 418 z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y 419 karatsuba(z, x0, y0) 420 z = z[0 : m+n] // z has final length but may be incomplete 421 z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) 422 423 // If xh != 0 or yh != 0, add the missing terms to z. For 424 // 425 // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) 426 // yh = y1*b (0 <= y1 < b) 427 // 428 // the missing terms are 429 // 430 // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 431 // 432 // since all the yi for i > 1 are 0 by choice of k: If any of them 433 // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would 434 // be a larger valid threshold contradicting the assumption about k. 
435 // 436 if k < n || m != n { 437 var t nat 438 439 // add x0*y1*b 440 x0 := x0.norm() 441 y1 := y[k:] // y1 is normalized because y is 442 t = t.mul(x0, y1) // update t so we don't lose t's underlying array 443 addAt(z, t, k) 444 445 // add xi*y0<<i, xi*y1*b<<(i+k) 446 y0 := y0.norm() 447 for i := k; i < len(x); i += k { 448 xi := x[i:] 449 if len(xi) > k { 450 xi = xi[:k] 451 } 452 xi = xi.norm() 453 t = t.mul(xi, y0) 454 addAt(z, t, i) 455 t = t.mul(xi, y1) 456 addAt(z, t, i+k) 457 } 458 } 459 460 return z.norm() 461 } 462 463 // mulRange computes the product of all the unsigned integers in the 464 // range [a, b] inclusively. If a > b (empty range), the result is 1. 465 func (z nat) mulRange(a, b uint64) nat { 466 switch { 467 case a == 0: 468 // cut long ranges short (optimization) 469 return z.setUint64(0) 470 case a > b: 471 return z.setUint64(1) 472 case a == b: 473 return z.setUint64(a) 474 case a+1 == b: 475 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 476 } 477 m := (a + b) / 2 478 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 479 } 480 481 // q = (x-r)/y, with 0 <= r < y 482 func (z nat) divW(x nat, y Word) (q nat, r Word) { 483 m := len(x) 484 switch { 485 case y == 0: 486 panic("division by zero") 487 case y == 1: 488 q = z.set(x) // result is x 489 return 490 case m == 0: 491 q = z[:0] // result is 0 492 return 493 } 494 // m > 0 495 z = z.make(m) 496 r = divWVW(z, 0, x, y) 497 q = z.norm() 498 return 499 } 500 501 func (z nat) div(z2, u, v nat) (q, r nat) { 502 if len(v) == 0 { 503 panic("division by zero") 504 } 505 506 if u.cmp(v) < 0 { 507 q = z[:0] 508 r = z2.set(u) 509 return 510 } 511 512 if len(v) == 1 { 513 var r2 Word 514 q, r2 = z.divW(u, v[0]) 515 r = z2.setWord(r2) 516 return 517 } 518 519 q, r = z.divLarge(z2, u, v) 520 return 521 } 522 523 // q = (uIn-r)/v, with 0 <= r < y 524 // Uses z as storage for q, and u as storage for r if possible. 525 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
//
// z and u are only reused as storage when they do not alias uIn or v.
// v itself is never modified; a shifted copy is made when normalization
// requires one.
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	// the quotient has at most m+1 words
	q = z.make(m + 1)

	// qhatv holds the n+1 word trial product q̂*v for each quotient digit
	qhatv := make(nat, n+1)
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	// u holds the shifted dividend plus one extra top word; the running
	// remainder is kept in it and becomes r at the end
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1. Normalize: shift v (and uIn by the same amount) left so that
	// the top bit of v[n-1] is set.
	shift := leadingZeros(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1 := make(nat, n)
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2. Produce quotient digits from the most significant (j = m)
	// down to the least significant (j = 0).
	for j := m; j >= 0; j-- {
		// D3. Estimate q̂ from the top two words of the current remainder
		// and the top word of v. If u[j+n] == v[n-1] the two-word
		// quotient would not fit in a Word, so start from the maximal
		// digit _M instead.
		qhat := Word(_M)
		if u[j+n] != v[n-1] {
			var rhat Word
			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])

			// x1 | x2 = q̂v_{n-2}
			x1, x2 := mulWW(qhat, v[n-2])
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			for greaterThan(x1, x2, rhat, u[j+n-2]) {
				qhat--
				prevRhat := rhat
				rhat += v[n-1]
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, v[n-2])
			}
		}

		// D4. Multiply and subtract: u[j:j+n+1] -= q̂*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// A borrow means q̂ was one too large: add v back and
			// decrement q̂.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}

	q = q.norm()
	// undo the D1 normalization shift so that u holds the true remainder
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}

// bitLen returns the length of x in bits. x must be normalized.
602 func (x nat) bitLen() int { 603 if i := len(x) - 1; i >= 0 { 604 return i*_W + bitLen(x[i]) 605 } 606 return 0 607 } 608 609 const deBruijn32 = 0x077CB531 610 611 var deBruijn32Lookup = []byte{ 612 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 613 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9, 614 } 615 616 const deBruijn64 = 0x03f79d71b4ca8b09 617 618 var deBruijn64Lookup = []byte{ 619 0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4, 620 62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5, 621 63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11, 622 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, 623 } 624 625 // trailingZeroBits returns the number of consecutive least significant zero 626 // bits of x. 627 func trailingZeroBits(x Word) uint { 628 // x & -x leaves only the right-most bit set in the word. Let k be the 629 // index of that bit. Since only a single bit is set, the value is two 630 // to the power of k. Multiplying by a power of two is equivalent to 631 // left shifting, in this case by k bits. The de Bruijn constant is 632 // such that all six bit, consecutive substrings are distinct. 633 // Therefore, if we have a left shifted version of this constant we can 634 // find by how many bits it was shifted by looking at which six bit 635 // substring ended up at the top of the word. 636 // (Knuth, volume 4, section 7.3.1) 637 switch _W { 638 case 32: 639 return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) 640 case 64: 641 return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 642 default: 643 panic("unknown word size") 644 } 645 } 646 647 // trailingZeroBits returns the number of consecutive least significant zero 648 // bits of x. 
649 func (x nat) trailingZeroBits() uint { 650 if len(x) == 0 { 651 return 0 652 } 653 var i uint 654 for x[i] == 0 { 655 i++ 656 } 657 // x[i] != 0 658 return i*_W + trailingZeroBits(x[i]) 659 } 660 661 // z = x << s 662 func (z nat) shl(x nat, s uint) nat { 663 m := len(x) 664 if m == 0 { 665 return z[:0] 666 } 667 // m > 0 668 669 n := m + int(s/_W) 670 z = z.make(n + 1) 671 z[n] = shlVU(z[n-m:n], x, s%_W) 672 z[0 : n-m].clear() 673 674 return z.norm() 675 } 676 677 // z = x >> s 678 func (z nat) shr(x nat, s uint) nat { 679 m := len(x) 680 n := m - int(s/_W) 681 if n <= 0 { 682 return z[:0] 683 } 684 // n > 0 685 686 z = z.make(n) 687 shrVU(z, x[m-n:], s%_W) 688 689 return z.norm() 690 } 691 692 func (z nat) setBit(x nat, i uint, b uint) nat { 693 j := int(i / _W) 694 m := Word(1) << (i % _W) 695 n := len(x) 696 switch b { 697 case 0: 698 z = z.make(n) 699 copy(z, x) 700 if j >= n { 701 // no need to grow 702 return z 703 } 704 z[j] &^= m 705 return z.norm() 706 case 1: 707 if j >= n { 708 z = z.make(j + 1) 709 z[n:].clear() 710 } else { 711 z = z.make(n) 712 } 713 copy(z, x) 714 z[j] |= m 715 // no need to normalize 716 return z 717 } 718 panic("set bit is not 0 or 1") 719 } 720 721 // bit returns the value of the i'th bit, with lsb == bit 0. 722 func (x nat) bit(i uint) uint { 723 j := i / _W 724 if j >= uint(len(x)) { 725 return 0 726 } 727 // 0 <= j < len(x) 728 return uint(x[j] >> (i % _W) & 1) 729 } 730 731 // sticky returns 1 if there's a 1 bit within the 732 // i least significant bits, otherwise it returns 0. 
733 func (x nat) sticky(i uint) uint { 734 j := i / _W 735 if j >= uint(len(x)) { 736 if len(x) == 0 { 737 return 0 738 } 739 return 1 740 } 741 // 0 <= j < len(x) 742 for _, x := range x[:j] { 743 if x != 0 { 744 return 1 745 } 746 } 747 if x[j]<<(_W-i%_W) != 0 { 748 return 1 749 } 750 return 0 751 } 752 753 func (z nat) and(x, y nat) nat { 754 m := len(x) 755 n := len(y) 756 if m > n { 757 m = n 758 } 759 // m <= n 760 761 z = z.make(m) 762 for i := 0; i < m; i++ { 763 z[i] = x[i] & y[i] 764 } 765 766 return z.norm() 767 } 768 769 func (z nat) andNot(x, y nat) nat { 770 m := len(x) 771 n := len(y) 772 if n > m { 773 n = m 774 } 775 // m >= n 776 777 z = z.make(m) 778 for i := 0; i < n; i++ { 779 z[i] = x[i] &^ y[i] 780 } 781 copy(z[n:m], x[n:m]) 782 783 return z.norm() 784 } 785 786 func (z nat) or(x, y nat) nat { 787 m := len(x) 788 n := len(y) 789 s := x 790 if m < n { 791 n, m = m, n 792 s = y 793 } 794 // m >= n 795 796 z = z.make(m) 797 for i := 0; i < n; i++ { 798 z[i] = x[i] | y[i] 799 } 800 copy(z[n:m], s[n:m]) 801 802 return z.norm() 803 } 804 805 func (z nat) xor(x, y nat) nat { 806 m := len(x) 807 n := len(y) 808 s := x 809 if m < n { 810 n, m = m, n 811 s = y 812 } 813 // m >= n 814 815 z = z.make(m) 816 for i := 0; i < n; i++ { 817 z[i] = x[i] ^ y[i] 818 } 819 copy(z[n:m], s[n:m]) 820 821 return z.norm() 822 } 823 824 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) 825 func greaterThan(x1, x2, y1, y2 Word) bool { 826 return x1 > y1 || x1 == y1 && x2 > y2 827 } 828 829 // modW returns x % d. 830 func (x nat) modW(d Word) (r Word) { 831 // TODO(agl): we don't actually need to store the q value. 832 var q nat 833 q = q.make(len(x)) 834 return divWVW(q, 0, x, d) 835 } 836 837 // random creates a random integer in [0..limit), using the space in z if 838 // possible. n is the bit length of limit. 
839 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 840 if alias(z, limit) { 841 z = nil // z is an alias for limit - cannot reuse 842 } 843 z = z.make(len(limit)) 844 845 bitLengthOfMSW := uint(n % _W) 846 if bitLengthOfMSW == 0 { 847 bitLengthOfMSW = _W 848 } 849 mask := Word((1 << bitLengthOfMSW) - 1) 850 851 for { 852 switch _W { 853 case 32: 854 for i := range z { 855 z[i] = Word(rand.Uint32()) 856 } 857 case 64: 858 for i := range z { 859 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 860 } 861 default: 862 panic("unknown word size") 863 } 864 z[len(limit)-1] &= mask 865 if z.cmp(limit) < 0 { 866 break 867 } 868 } 869 870 return z.norm() 871 } 872 873 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 874 // otherwise it sets z to x**y. The result is the value of z. 875 func (z nat) expNN(x, y, m nat) nat { 876 if alias(z, x) || alias(z, y) { 877 // We cannot allow in-place modification of x or y. 878 z = nil 879 } 880 881 // x**y mod 1 == 0 882 if len(m) == 1 && m[0] == 1 { 883 return z.setWord(0) 884 } 885 // m == 0 || m > 1 886 887 // x**0 == 1 888 if len(y) == 0 { 889 return z.setWord(1) 890 } 891 // y > 0 892 893 if len(m) != 0 { 894 // We likely end up being as long as the modulus. 895 z = z.make(len(m)) 896 } 897 z = z.set(x) 898 899 // If the base is non-trivial and the exponent is large, we use 900 // 4-bit, windowed exponentiation. This involves precomputing 14 values 901 // (x^2...x^15) but then reduces the number of multiply-reduces by a 902 // third. Even for a 32-bit exponent, this reduces the number of 903 // operations. 904 if len(x) > 1 && len(y) > 1 && len(m) > 0 { 905 return z.expNNWindowed(x, y, m) 906 } 907 908 v := y[len(y)-1] // v > 0 because y is normalized and y > 0 909 shift := leadingZeros(v) + 1 910 v <<= shift 911 var q nat 912 913 const mask = 1 << (_W - 1) 914 915 // We walk through the bits of the exponent one by one. Each time we 916 // see a bit, we square, thus doubling the power. 
If the bit is a one, 917 // we also multiply by x, thus adding one to the power. 918 919 w := _W - int(shift) 920 // zz and r are used to avoid allocating in mul and div as 921 // otherwise the arguments would alias. 922 var zz, r nat 923 for j := 0; j < w; j++ { 924 zz = zz.mul(z, z) 925 zz, z = z, zz 926 927 if v&mask != 0 { 928 zz = zz.mul(z, x) 929 zz, z = z, zz 930 } 931 932 if len(m) != 0 { 933 zz, r = zz.div(r, z, m) 934 zz, r, q, z = q, z, zz, r 935 } 936 937 v <<= 1 938 } 939 940 for i := len(y) - 2; i >= 0; i-- { 941 v = y[i] 942 943 for j := 0; j < _W; j++ { 944 zz = zz.mul(z, z) 945 zz, z = z, zz 946 947 if v&mask != 0 { 948 zz = zz.mul(z, x) 949 zz, z = z, zz 950 } 951 952 if len(m) != 0 { 953 zz, r = zz.div(r, z, m) 954 zz, r, q, z = q, z, zz, r 955 } 956 957 v <<= 1 958 } 959 } 960 961 return z.norm() 962 } 963 964 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window. 965 func (z nat) expNNWindowed(x, y, m nat) nat { 966 // zz and r are used to avoid allocating in mul and div as otherwise 967 // the arguments would alias. 968 var zz, r nat 969 970 const n = 4 971 // powers[i] contains x^i. 972 var powers [1 << n]nat 973 powers[0] = natOne 974 powers[1] = x 975 for i := 2; i < 1<<n; i += 2 { 976 p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] 977 *p = p.mul(*p2, *p2) 978 zz, r = zz.div(r, *p, m) 979 *p, r = r, *p 980 *p1 = p1.mul(*p, x) 981 zz, r = zz.div(r, *p1, m) 982 *p1, r = r, *p1 983 } 984 985 z = z.setWord(1) 986 987 for i := len(y) - 1; i >= 0; i-- { 988 yi := y[i] 989 for j := 0; j < _W; j += n { 990 if i != len(y)-1 || j != 0 { 991 // Unrolled loop for significant performance 992 // gain. Use go test -bench=".*" in crypto/rsa 993 // to check performance before making changes. 
994 zz = zz.mul(z, z) 995 zz, z = z, zz 996 zz, r = zz.div(r, z, m) 997 z, r = r, z 998 999 zz = zz.mul(z, z) 1000 zz, z = z, zz 1001 zz, r = zz.div(r, z, m) 1002 z, r = r, z 1003 1004 zz = zz.mul(z, z) 1005 zz, z = z, zz 1006 zz, r = zz.div(r, z, m) 1007 z, r = r, z 1008 1009 zz = zz.mul(z, z) 1010 zz, z = z, zz 1011 zz, r = zz.div(r, z, m) 1012 z, r = r, z 1013 } 1014 1015 zz = zz.mul(z, powers[yi>>(_W-n)]) 1016 zz, z = z, zz 1017 zz, r = zz.div(r, z, m) 1018 z, r = r, z 1019 1020 yi <<= n 1021 } 1022 } 1023 1024 return z.norm() 1025 } 1026 1027 // probablyPrime performs reps Miller-Rabin tests to check whether n is prime. 1028 // If it returns true, n is prime with probability 1 - 1/4^reps. 1029 // If it returns false, n is not prime. 1030 func (n nat) probablyPrime(reps int) bool { 1031 if len(n) == 0 { 1032 return false 1033 } 1034 1035 if len(n) == 1 { 1036 if n[0] < 2 { 1037 return false 1038 } 1039 1040 if n[0]%2 == 0 { 1041 return n[0] == 2 1042 } 1043 1044 // We have to exclude these cases because we reject all 1045 // multiples of these numbers below. 
1046 switch n[0] { 1047 case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53: 1048 return true 1049 } 1050 } 1051 1052 if n[0]&1 == 0 { 1053 return false // n is even 1054 } 1055 1056 const primesProduct32 = 0xC0CFD797 // Π {p ∈ primes, 2 < p <= 29} 1057 const primesProduct64 = 0xE221F97C30E94E1D // Π {p ∈ primes, 2 < p <= 53} 1058 1059 var r Word 1060 switch _W { 1061 case 32: 1062 r = n.modW(primesProduct32) 1063 case 64: 1064 r = n.modW(primesProduct64 & _M) 1065 default: 1066 panic("Unknown word size") 1067 } 1068 1069 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || 1070 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { 1071 return false 1072 } 1073 1074 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || 1075 r%43 == 0 || r%47 == 0 || r%53 == 0) { 1076 return false 1077 } 1078 1079 nm1 := nat(nil).sub(n, natOne) 1080 // determine q, k such that nm1 = q << k 1081 k := nm1.trailingZeroBits() 1082 q := nat(nil).shr(nm1, k) 1083 1084 nm3 := nat(nil).sub(nm1, natTwo) 1085 rand := rand.New(rand.NewSource(int64(n[0]))) 1086 1087 var x, y, quotient nat 1088 nm3Len := nm3.bitLen() 1089 1090 NextRandom: 1091 for i := 0; i < reps; i++ { 1092 x = x.random(rand, nm3, nm3Len) 1093 x = x.add(x, natTwo) 1094 y = y.expNN(x, q, n) 1095 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { 1096 continue 1097 } 1098 for j := uint(1); j < k; j++ { 1099 y = y.mul(y, y) 1100 quotient, y = quotient.div(y, y, n) 1101 if y.cmp(nm1) == 0 { 1102 continue NextRandom 1103 } 1104 if y.cmp(natOne) == 0 { 1105 return false 1106 } 1107 } 1108 return false 1109 } 1110 1111 return true 1112 } 1113 1114 // bytes writes the value of z into buf using big-endian encoding. 1115 // len(buf) must be >= len(z)*_S. The value of z is encoded in the 1116 // slice buf[i:]. The number i of unused bytes at the beginning of 1117 // buf is returned as result. 
1118 func (z nat) bytes(buf []byte) (i int) { 1119 i = len(buf) 1120 for _, d := range z { 1121 for j := 0; j < _S; j++ { 1122 i-- 1123 buf[i] = byte(d) 1124 d >>= 8 1125 } 1126 } 1127 1128 for i < len(buf) && buf[i] == 0 { 1129 i++ 1130 } 1131 1132 return 1133 } 1134 1135 // setBytes interprets buf as the bytes of a big-endian unsigned 1136 // integer, sets z to that value, and returns z. 1137 func (z nat) setBytes(buf []byte) nat { 1138 z = z.make((len(buf) + _S - 1) / _S) 1139 1140 k := 0 1141 s := uint(0) 1142 var d Word 1143 for i := len(buf); i > 0; i-- { 1144 d |= Word(buf[i-1]) << s 1145 if s += 8; s == _S*8 { 1146 z[k] = d 1147 k++ 1148 s = 0 1149 d = 0 1150 } 1151 } 1152 if k < len(z) { 1153 z[k] = d 1154 } 1155 1156 return z.norm() 1157 }