github.com/rakyll/go@v0.0.0-20170216000551-64c02460d703/src/math/big/nat.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // This file implements unsigned multi-precision integers (natural 6 // numbers). They are the building blocks for the implementation 7 // of signed integers, rationals, and floating-point numbers. 8 9 package big 10 11 import ( 12 "math/rand" 13 "sync" 14 ) 15 16 // An unsigned integer x of the form 17 // 18 // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 19 // 20 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 21 // with the digits x[i] as the slice elements. 22 // 23 // A number is normalized if the slice contains no leading 0 digits. 24 // During arithmetic operations, denormalized values may occur but are 25 // always normalized before returning the final result. The normalized 26 // representation of 0 is the empty or nil slice (length = 0). 27 // 28 type nat []Word 29 30 var ( 31 natOne = nat{1} 32 natTwo = nat{2} 33 natTen = nat{10} 34 ) 35 36 func (z nat) clear() { 37 for i := range z { 38 z[i] = 0 39 } 40 } 41 42 func (z nat) norm() nat { 43 i := len(z) 44 for i > 0 && z[i-1] == 0 { 45 i-- 46 } 47 return z[0:i] 48 } 49 50 func (z nat) make(n int) nat { 51 if n <= cap(z) { 52 return z[:n] // reuse z 53 } 54 // Choosing a good value for e has significant performance impact 55 // because it increases the chance that a value can be reused. 
56 const e = 4 // extra capacity 57 return make(nat, n, n+e) 58 } 59 60 func (z nat) setWord(x Word) nat { 61 if x == 0 { 62 return z[:0] 63 } 64 z = z.make(1) 65 z[0] = x 66 return z 67 } 68 69 func (z nat) setUint64(x uint64) nat { 70 // single-digit values 71 if w := Word(x); uint64(w) == x { 72 return z.setWord(w) 73 } 74 75 // compute number of words n required to represent x 76 n := 0 77 for t := x; t > 0; t >>= _W { 78 n++ 79 } 80 81 // split x into n words 82 z = z.make(n) 83 for i := range z { 84 z[i] = Word(x & _M) 85 x >>= _W 86 } 87 88 return z 89 } 90 91 func (z nat) set(x nat) nat { 92 z = z.make(len(x)) 93 copy(z, x) 94 return z 95 } 96 97 func (z nat) add(x, y nat) nat { 98 m := len(x) 99 n := len(y) 100 101 switch { 102 case m < n: 103 return z.add(y, x) 104 case m == 0: 105 // n == 0 because m >= n; result is 0 106 return z[:0] 107 case n == 0: 108 // result is x 109 return z.set(x) 110 } 111 // m > 0 112 113 z = z.make(m + 1) 114 c := addVV(z[0:n], x, y) 115 if m > n { 116 c = addVW(z[n:m], x[n:], c) 117 } 118 z[m] = c 119 120 return z.norm() 121 } 122 123 func (z nat) sub(x, y nat) nat { 124 m := len(x) 125 n := len(y) 126 127 switch { 128 case m < n: 129 panic("underflow") 130 case m == 0: 131 // n == 0 because m >= n; result is 0 132 return z[:0] 133 case n == 0: 134 // result is x 135 return z.set(x) 136 } 137 // m > 0 138 139 z = z.make(m) 140 c := subVV(z[0:n], x, y) 141 if m > n { 142 c = subVW(z[n:], x[n:], c) 143 } 144 if c != 0 { 145 panic("underflow") 146 } 147 148 return z.norm() 149 } 150 151 func (x nat) cmp(y nat) (r int) { 152 m := len(x) 153 n := len(y) 154 if m != n || m == 0 { 155 switch { 156 case m < n: 157 r = -1 158 case m > n: 159 r = 1 160 } 161 return 162 } 163 164 i := m - 1 165 for i > 0 && x[i] == y[i] { 166 i-- 167 } 168 169 switch { 170 case x[i] < y[i]: 171 r = -1 172 case x[i] > y[i]: 173 r = 1 174 } 175 return 176 } 177 178 func (z nat) mulAddWW(x nat, y, r Word) nat { 179 m := len(x) 180 if m == 0 || y == 0 { 
181 return z.setWord(r) // result is r 182 } 183 // m > 0 184 185 z = z.make(m + 1) 186 z[m] = mulAddVWW(z[0:m], x, y, r) 187 188 return z.norm() 189 } 190 191 // basicMul multiplies x and y and leaves the result in z. 192 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 193 func basicMul(z, x, y nat) { 194 z[0 : len(x)+len(y)].clear() // initialize z 195 for i, d := range y { 196 if d != 0 { 197 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 198 } 199 } 200 } 201 202 // montgomery computes z mod m = x*y*2**(-n*_W) mod m, 203 // assuming k = -1/m mod 2**_W. 204 // z is used for storing the result which is returned; 205 // z must not alias x, y or m. 206 // See Gueron, "Efficient Software Implementations of Modular Exponentiation". 207 // https://eprint.iacr.org/2011/239.pdf 208 // In the terminology of that paper, this is an "Almost Montgomery Multiplication": 209 // x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result 210 // z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m. 211 func (z nat) montgomery(x, y, m nat, k Word, n int) nat { 212 // This code assumes x, y, m are all the same length, n. 213 // (required by addMulVVW and the for loop). 214 // It also assumes that x, y are already reduced mod m, 215 // or else the result will not be properly reduced. 216 if len(x) != n || len(y) != n || len(m) != n { 217 panic("math/big: mismatched montgomery number lengths") 218 } 219 z = z.make(n) 220 z.clear() 221 var c Word 222 for i := 0; i < n; i++ { 223 d := y[i] 224 c2 := addMulVVW(z, x, d) 225 t := z[0] * k 226 c3 := addMulVVW(z, m, t) 227 copy(z, z[1:]) 228 cx := c + c2 229 cy := cx + c3 230 z[n-1] = cy 231 if cx < c2 || cy < c3 { 232 c = 1 233 } else { 234 c = 0 235 } 236 } 237 if c != 0 { 238 subVV(z, z, m) 239 } 240 return z 241 } 242 243 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 244 // Factored out for readability - do not use outside karatsuba. 
245 func karatsubaAdd(z, x nat, n int) { 246 if c := addVV(z[0:n], z, x); c != 0 { 247 addVW(z[n:n+n>>1], z[n:], c) 248 } 249 } 250 251 // Like karatsubaAdd, but does subtract. 252 func karatsubaSub(z, x nat, n int) { 253 if c := subVV(z[0:n], z, x); c != 0 { 254 subVW(z[n:n+n>>1], z[n:], c) 255 } 256 } 257 258 // Operands that are shorter than karatsubaThreshold are multiplied using 259 // "grade school" multiplication; for longer operands the Karatsuba algorithm 260 // is used. 261 var karatsubaThreshold int = 40 // computed by calibrate.go 262 263 // karatsuba multiplies x and y and leaves the result in z. 264 // Both x and y must have the same length n and n must be a 265 // power of 2. The result vector z must have len(z) >= 6*n. 266 // The (non-normalized) result is placed in z[0 : 2*n]. 267 func karatsuba(z, x, y nat) { 268 n := len(y) 269 270 // Switch to basic multiplication if numbers are odd or small. 271 // (n is always even if karatsubaThreshold is even, but be 272 // conservative) 273 if n&1 != 0 || n < karatsubaThreshold || n < 2 { 274 basicMul(z, x, y) 275 return 276 } 277 // n&1 == 0 && n >= karatsubaThreshold && n >= 2 278 279 // Karatsuba multiplication is based on the observation that 280 // for two numbers x and y with: 281 // 282 // x = x1*b + x0 283 // y = y1*b + y0 284 // 285 // the product x*y can be obtained with 3 products z2, z1, z0 286 // instead of 4: 287 // 288 // x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0 289 // = z2*b*b + z1*b + z0 290 // 291 // with: 292 // 293 // xd = x1 - x0 294 // yd = y0 - y1 295 // 296 // z1 = xd*yd + z2 + z0 297 // = (x1-x0)*(y0 - y1) + z2 + z0 298 // = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0 299 // = x1*y0 - z2 - z0 + x0*y1 + z2 + z0 300 // = x1*y0 + x0*y1 301 302 // split x, y into "digits" 303 n2 := n >> 1 // n2 >= 1 304 x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 305 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 306 307 // z is used for the result and temporary storage: 308 // 309 // 6*n 5*n 4*n 3*n 2*n 1*n 
0*n 310 // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] 311 // 312 // For each recursive call of karatsuba, an unused slice of 313 // z is passed in that has (at least) half the length of the 314 // caller's z. 315 316 // compute z0 and z2 with the result "in place" in z 317 karatsuba(z, x0, y0) // z0 = x0*y0 318 karatsuba(z[n:], x1, y1) // z2 = x1*y1 319 320 // compute xd (or the negative value if underflow occurs) 321 s := 1 // sign of product xd*yd 322 xd := z[2*n : 2*n+n2] 323 if subVV(xd, x1, x0) != 0 { // x1-x0 324 s = -s 325 subVV(xd, x0, x1) // x0-x1 326 } 327 328 // compute yd (or the negative value if underflow occurs) 329 yd := z[2*n+n2 : 3*n] 330 if subVV(yd, y0, y1) != 0 { // y0-y1 331 s = -s 332 subVV(yd, y1, y0) // y1-y0 333 } 334 335 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 336 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 337 p := z[n*3:] 338 karatsuba(p, xd, yd) 339 340 // save original z2:z0 341 // (ok to use upper half of z since we're done recursing) 342 r := z[n*4:] 343 copy(r, z[:n*2]) 344 345 // add up all partial products 346 // 347 // 2*n n 0 348 // z = [ z2 | z0 ] 349 // + [ z0 ] 350 // + [ z2 ] 351 // + [ p ] 352 // 353 karatsubaAdd(z[n2:], r, n) 354 karatsubaAdd(z[n2:], r[n:], n) 355 if s > 0 { 356 karatsubaAdd(z[n2:], p, n) 357 } else { 358 karatsubaSub(z[n2:], p, n) 359 } 360 } 361 362 // alias reports whether x and y share the same base array. 363 func alias(x, y nat) bool { 364 return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] 365 } 366 367 // addAt implements z += x<<(_W*i); z must be long enough. 
// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
// A carry out of the top word of z is dropped.
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// max returns the larger of x and y.
func max(x, y int) int {
	if x > y {
		return x
	}
	return y
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of karatsubaThreshold.
func karatsubaLen(n int) int {
	i := uint(0)
	for n > karatsubaThreshold {
		n >>= 1
		i++
	}
	return n << i
}

// mul sets z = x*y and returns the normalized result. z's storage is
// reused where capacity allows, unless z aliases x or y.
// Short operands use basicMul; longer ones multiply the low k words with
// karatsuba and add the remaining partial products with addAt.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b  (0 <= xi < b)
	//   yh =                         y1*b  (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1<<(i+k) (shift counts in words, as for addAt)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
487 func (z nat) mulRange(a, b uint64) nat { 488 switch { 489 case a == 0: 490 // cut long ranges short (optimization) 491 return z.setUint64(0) 492 case a > b: 493 return z.setUint64(1) 494 case a == b: 495 return z.setUint64(a) 496 case a+1 == b: 497 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 498 } 499 m := (a + b) / 2 500 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 501 } 502 503 // q = (x-r)/y, with 0 <= r < y 504 func (z nat) divW(x nat, y Word) (q nat, r Word) { 505 m := len(x) 506 switch { 507 case y == 0: 508 panic("division by zero") 509 case y == 1: 510 q = z.set(x) // result is x 511 return 512 case m == 0: 513 q = z[:0] // result is 0 514 return 515 } 516 // m > 0 517 z = z.make(m) 518 r = divWVW(z, 0, x, y) 519 q = z.norm() 520 return 521 } 522 523 func (z nat) div(z2, u, v nat) (q, r nat) { 524 if len(v) == 0 { 525 panic("division by zero") 526 } 527 528 if u.cmp(v) < 0 { 529 q = z[:0] 530 r = z2.set(u) 531 return 532 } 533 534 if len(v) == 1 { 535 var r2 Word 536 q, r2 = z.divW(u, v[0]) 537 r = z2.setWord(r2) 538 return 539 } 540 541 q, r = z.divLarge(z2, u, v) 542 return 543 } 544 545 // getNat returns a *nat of len n. The contents may not be zero. 546 // The pool holds *nat to avoid allocation when converting to interface{}. 547 func getNat(n int) *nat { 548 var z *nat 549 if v := natPool.Get(); v != nil { 550 z = v.(*nat) 551 } 552 if z == nil { 553 z = new(nat) 554 } 555 *z = z.make(n) 556 return z 557 } 558 559 func putNat(x *nat) { 560 natPool.Put(x) 561 } 562 563 var natPool sync.Pool 564 565 // q = (uIn-r)/v, with 0 <= r < y 566 // Uses z as storage for q, and u as storage for r if possible. 567 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 
568 // Preconditions: 569 // len(v) >= 2 570 // len(uIn) >= len(v) 571 func (z nat) divLarge(u, uIn, v nat) (q, r nat) { 572 n := len(v) 573 m := len(uIn) - n 574 575 // determine if z can be reused 576 // TODO(gri) should find a better solution - this if statement 577 // is very costly (see e.g. time pidigits -s -n 10000) 578 if alias(z, uIn) || alias(z, v) { 579 z = nil // z is an alias for uIn or v - cannot reuse 580 } 581 q = z.make(m + 1) 582 583 qhatvp := getNat(n + 1) 584 qhatv := *qhatvp 585 if alias(u, uIn) || alias(u, v) { 586 u = nil // u is an alias for uIn or v - cannot reuse 587 } 588 u = u.make(len(uIn) + 1) 589 u.clear() // TODO(gri) no need to clear if we allocated a new u 590 591 // D1. 592 var v1p *nat 593 shift := nlz(v[n-1]) 594 if shift > 0 { 595 // do not modify v, it may be used by another goroutine simultaneously 596 v1p = getNat(n) 597 v1 := *v1p 598 shlVU(v1, v, shift) 599 v = v1 600 } 601 u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift) 602 603 // D2. 604 vn1 := v[n-1] 605 for j := m; j >= 0; j-- { 606 // D3. 607 qhat := Word(_M) 608 if ujn := u[j+n]; ujn != vn1 { 609 var rhat Word 610 qhat, rhat = divWW(ujn, u[j+n-1], vn1) 611 612 // x1 | x2 = q̂v_{n-2} 613 vn2 := v[n-2] 614 x1, x2 := mulWW(qhat, vn2) 615 // test if q̂v_{n-2} > br̂ + u_{j+n-2} 616 ujn2 := u[j+n-2] 617 for greaterThan(x1, x2, rhat, ujn2) { 618 qhat-- 619 prevRhat := rhat 620 rhat += vn1 621 // v[n-1] >= 0, so this tests for overflow. 622 if rhat < prevRhat { 623 break 624 } 625 x1, x2 = mulWW(qhat, vn2) 626 } 627 } 628 629 // D4. 630 qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0) 631 632 c := subVV(u[j:j+len(qhatv)], u[j:], qhatv) 633 if c != 0 { 634 c := addVV(u[j:j+n], u[j:], v) 635 u[j+n] += c 636 qhat-- 637 } 638 639 q[j] = qhat 640 } 641 if v1p != nil { 642 putNat(v1p) 643 } 644 putNat(qhatvp) 645 646 q = q.norm() 647 shrVU(u, u, shift) 648 r = u.norm() 649 650 return q, r 651 } 652 653 // Length of x in bits. x must be normalized. 
654 func (x nat) bitLen() int { 655 if i := len(x) - 1; i >= 0 { 656 return i*_W + bitLen(x[i]) 657 } 658 return 0 659 } 660 661 const deBruijn32 = 0x077CB531 662 663 var deBruijn32Lookup = [...]byte{ 664 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 665 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9, 666 } 667 668 const deBruijn64 = 0x03f79d71b4ca8b09 669 670 var deBruijn64Lookup = [...]byte{ 671 0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4, 672 62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5, 673 63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11, 674 54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6, 675 } 676 677 // trailingZeroBits returns the number of consecutive least significant zero 678 // bits of x. 679 func trailingZeroBits(x Word) uint { 680 // x & -x leaves only the right-most bit set in the word. Let k be the 681 // index of that bit. Since only a single bit is set, the value is two 682 // to the power of k. Multiplying by a power of two is equivalent to 683 // left shifting, in this case by k bits. The de Bruijn constant is 684 // such that all six bit, consecutive substrings are distinct. 685 // Therefore, if we have a left shifted version of this constant we can 686 // find by how many bits it was shifted by looking at which six bit 687 // substring ended up at the top of the word. 688 // (Knuth, volume 4, section 7.3.1) 689 switch _W { 690 case 32: 691 return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27]) 692 case 64: 693 return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 694 default: 695 panic("unknown word size") 696 } 697 } 698 699 // trailingZeroBits returns the number of consecutive least significant zero 700 // bits of x. 
701 func (x nat) trailingZeroBits() uint { 702 if len(x) == 0 { 703 return 0 704 } 705 var i uint 706 for x[i] == 0 { 707 i++ 708 } 709 // x[i] != 0 710 return i*_W + trailingZeroBits(x[i]) 711 } 712 713 // z = x << s 714 func (z nat) shl(x nat, s uint) nat { 715 m := len(x) 716 if m == 0 { 717 return z[:0] 718 } 719 // m > 0 720 721 n := m + int(s/_W) 722 z = z.make(n + 1) 723 z[n] = shlVU(z[n-m:n], x, s%_W) 724 z[0 : n-m].clear() 725 726 return z.norm() 727 } 728 729 // z = x >> s 730 func (z nat) shr(x nat, s uint) nat { 731 m := len(x) 732 n := m - int(s/_W) 733 if n <= 0 { 734 return z[:0] 735 } 736 // n > 0 737 738 z = z.make(n) 739 shrVU(z, x[m-n:], s%_W) 740 741 return z.norm() 742 } 743 744 func (z nat) setBit(x nat, i uint, b uint) nat { 745 j := int(i / _W) 746 m := Word(1) << (i % _W) 747 n := len(x) 748 switch b { 749 case 0: 750 z = z.make(n) 751 copy(z, x) 752 if j >= n { 753 // no need to grow 754 return z 755 } 756 z[j] &^= m 757 return z.norm() 758 case 1: 759 if j >= n { 760 z = z.make(j + 1) 761 z[n:].clear() 762 } else { 763 z = z.make(n) 764 } 765 copy(z, x) 766 z[j] |= m 767 // no need to normalize 768 return z 769 } 770 panic("set bit is not 0 or 1") 771 } 772 773 // bit returns the value of the i'th bit, with lsb == bit 0. 774 func (x nat) bit(i uint) uint { 775 j := i / _W 776 if j >= uint(len(x)) { 777 return 0 778 } 779 // 0 <= j < len(x) 780 return uint(x[j] >> (i % _W) & 1) 781 } 782 783 // sticky returns 1 if there's a 1 bit within the 784 // i least significant bits, otherwise it returns 0. 
785 func (x nat) sticky(i uint) uint { 786 j := i / _W 787 if j >= uint(len(x)) { 788 if len(x) == 0 { 789 return 0 790 } 791 return 1 792 } 793 // 0 <= j < len(x) 794 for _, x := range x[:j] { 795 if x != 0 { 796 return 1 797 } 798 } 799 if x[j]<<(_W-i%_W) != 0 { 800 return 1 801 } 802 return 0 803 } 804 805 func (z nat) and(x, y nat) nat { 806 m := len(x) 807 n := len(y) 808 if m > n { 809 m = n 810 } 811 // m <= n 812 813 z = z.make(m) 814 for i := 0; i < m; i++ { 815 z[i] = x[i] & y[i] 816 } 817 818 return z.norm() 819 } 820 821 func (z nat) andNot(x, y nat) nat { 822 m := len(x) 823 n := len(y) 824 if n > m { 825 n = m 826 } 827 // m >= n 828 829 z = z.make(m) 830 for i := 0; i < n; i++ { 831 z[i] = x[i] &^ y[i] 832 } 833 copy(z[n:m], x[n:m]) 834 835 return z.norm() 836 } 837 838 func (z nat) or(x, y nat) nat { 839 m := len(x) 840 n := len(y) 841 s := x 842 if m < n { 843 n, m = m, n 844 s = y 845 } 846 // m >= n 847 848 z = z.make(m) 849 for i := 0; i < n; i++ { 850 z[i] = x[i] | y[i] 851 } 852 copy(z[n:m], s[n:m]) 853 854 return z.norm() 855 } 856 857 func (z nat) xor(x, y nat) nat { 858 m := len(x) 859 n := len(y) 860 s := x 861 if m < n { 862 n, m = m, n 863 s = y 864 } 865 // m >= n 866 867 z = z.make(m) 868 for i := 0; i < n; i++ { 869 z[i] = x[i] ^ y[i] 870 } 871 copy(z[n:m], s[n:m]) 872 873 return z.norm() 874 } 875 876 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) 877 func greaterThan(x1, x2, y1, y2 Word) bool { 878 return x1 > y1 || x1 == y1 && x2 > y2 879 } 880 881 // modW returns x % d. 882 func (x nat) modW(d Word) (r Word) { 883 // TODO(agl): we don't actually need to store the q value. 884 var q nat 885 q = q.make(len(x)) 886 return divWVW(q, 0, x, d) 887 } 888 889 // random creates a random integer in [0..limit), using the space in z if 890 // possible. n is the bit length of limit. 
891 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 892 if alias(z, limit) { 893 z = nil // z is an alias for limit - cannot reuse 894 } 895 z = z.make(len(limit)) 896 897 bitLengthOfMSW := uint(n % _W) 898 if bitLengthOfMSW == 0 { 899 bitLengthOfMSW = _W 900 } 901 mask := Word((1 << bitLengthOfMSW) - 1) 902 903 for { 904 switch _W { 905 case 32: 906 for i := range z { 907 z[i] = Word(rand.Uint32()) 908 } 909 case 64: 910 for i := range z { 911 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 912 } 913 default: 914 panic("unknown word size") 915 } 916 z[len(limit)-1] &= mask 917 if z.cmp(limit) < 0 { 918 break 919 } 920 } 921 922 return z.norm() 923 } 924 925 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 926 // otherwise it sets z to x**y. The result is the value of z. 927 func (z nat) expNN(x, y, m nat) nat { 928 if alias(z, x) || alias(z, y) { 929 // We cannot allow in-place modification of x or y. 930 z = nil 931 } 932 933 // x**y mod 1 == 0 934 if len(m) == 1 && m[0] == 1 { 935 return z.setWord(0) 936 } 937 // m == 0 || m > 1 938 939 // x**0 == 1 940 if len(y) == 0 { 941 return z.setWord(1) 942 } 943 // y > 0 944 945 // x**1 mod m == x mod m 946 if len(y) == 1 && y[0] == 1 && len(m) != 0 { 947 _, z = z.div(z, x, m) 948 return z 949 } 950 // y > 1 951 952 if len(m) != 0 { 953 // We likely end up being as long as the modulus. 954 z = z.make(len(m)) 955 } 956 z = z.set(x) 957 958 // If the base is non-trivial and the exponent is large, we use 959 // 4-bit, windowed exponentiation. This involves precomputing 14 values 960 // (x^2...x^15) but then reduces the number of multiply-reduces by a 961 // third. Even for a 32-bit exponent, this reduces the number of 962 // operations. Uses Montgomery method for odd moduli. 
963 if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 { 964 if m[0]&1 == 1 { 965 return z.expNNMontgomery(x, y, m) 966 } 967 return z.expNNWindowed(x, y, m) 968 } 969 970 v := y[len(y)-1] // v > 0 because y is normalized and y > 0 971 shift := nlz(v) + 1 972 v <<= shift 973 var q nat 974 975 const mask = 1 << (_W - 1) 976 977 // We walk through the bits of the exponent one by one. Each time we 978 // see a bit, we square, thus doubling the power. If the bit is a one, 979 // we also multiply by x, thus adding one to the power. 980 981 w := _W - int(shift) 982 // zz and r are used to avoid allocating in mul and div as 983 // otherwise the arguments would alias. 984 var zz, r nat 985 for j := 0; j < w; j++ { 986 zz = zz.mul(z, z) 987 zz, z = z, zz 988 989 if v&mask != 0 { 990 zz = zz.mul(z, x) 991 zz, z = z, zz 992 } 993 994 if len(m) != 0 { 995 zz, r = zz.div(r, z, m) 996 zz, r, q, z = q, z, zz, r 997 } 998 999 v <<= 1 1000 } 1001 1002 for i := len(y) - 2; i >= 0; i-- { 1003 v = y[i] 1004 1005 for j := 0; j < _W; j++ { 1006 zz = zz.mul(z, z) 1007 zz, z = z, zz 1008 1009 if v&mask != 0 { 1010 zz = zz.mul(z, x) 1011 zz, z = z, zz 1012 } 1013 1014 if len(m) != 0 { 1015 zz, r = zz.div(r, z, m) 1016 zz, r, q, z = q, z, zz, r 1017 } 1018 1019 v <<= 1 1020 } 1021 } 1022 1023 return z.norm() 1024 } 1025 1026 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window. 1027 func (z nat) expNNWindowed(x, y, m nat) nat { 1028 // zz and r are used to avoid allocating in mul and div as otherwise 1029 // the arguments would alias. 1030 var zz, r nat 1031 1032 const n = 4 1033 // powers[i] contains x^i. 
1034 var powers [1 << n]nat 1035 powers[0] = natOne 1036 powers[1] = x 1037 for i := 2; i < 1<<n; i += 2 { 1038 p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1] 1039 *p = p.mul(*p2, *p2) 1040 zz, r = zz.div(r, *p, m) 1041 *p, r = r, *p 1042 *p1 = p1.mul(*p, x) 1043 zz, r = zz.div(r, *p1, m) 1044 *p1, r = r, *p1 1045 } 1046 1047 z = z.setWord(1) 1048 1049 for i := len(y) - 1; i >= 0; i-- { 1050 yi := y[i] 1051 for j := 0; j < _W; j += n { 1052 if i != len(y)-1 || j != 0 { 1053 // Unrolled loop for significant performance 1054 // gain. Use go test -bench=".*" in crypto/rsa 1055 // to check performance before making changes. 1056 zz = zz.mul(z, z) 1057 zz, z = z, zz 1058 zz, r = zz.div(r, z, m) 1059 z, r = r, z 1060 1061 zz = zz.mul(z, z) 1062 zz, z = z, zz 1063 zz, r = zz.div(r, z, m) 1064 z, r = r, z 1065 1066 zz = zz.mul(z, z) 1067 zz, z = z, zz 1068 zz, r = zz.div(r, z, m) 1069 z, r = r, z 1070 1071 zz = zz.mul(z, z) 1072 zz, z = z, zz 1073 zz, r = zz.div(r, z, m) 1074 z, r = r, z 1075 } 1076 1077 zz = zz.mul(z, powers[yi>>(_W-n)]) 1078 zz, z = z, zz 1079 zz, r = zz.div(r, z, m) 1080 z, r = r, z 1081 1082 yi <<= n 1083 } 1084 } 1085 1086 return z.norm() 1087 } 1088 1089 // expNNMontgomery calculates x**y mod m using a fixed, 4-bit window. 1090 // Uses Montgomery representation. 1091 func (z nat) expNNMontgomery(x, y, m nat) nat { 1092 numWords := len(m) 1093 1094 // We want the lengths of x and m to be equal. 1095 // It is OK if x >= m as long as len(x) == len(m). 1096 if len(x) > numWords { 1097 _, x = nat(nil).div(nil, x, m) 1098 // Note: now len(x) <= numWords, not guaranteed ==. 1099 } 1100 if len(x) < numWords { 1101 rr := make(nat, numWords) 1102 copy(rr, x) 1103 x = rr 1104 } 1105 1106 // Ideally the precomputations would be performed outside, and reused 1107 // k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson 1108 // Iteration for Multiplicative Inverses Modulo Prime Powers". 
1109 k0 := 2 - m[0] 1110 t := m[0] - 1 1111 for i := 1; i < _W; i <<= 1 { 1112 t *= t 1113 k0 *= (t + 1) 1114 } 1115 k0 = -k0 1116 1117 // RR = 2**(2*_W*len(m)) mod m 1118 RR := nat(nil).setWord(1) 1119 zz := nat(nil).shl(RR, uint(2*numWords*_W)) 1120 _, RR = RR.div(RR, zz, m) 1121 if len(RR) < numWords { 1122 zz = zz.make(numWords) 1123 copy(zz, RR) 1124 RR = zz 1125 } 1126 // one = 1, with equal length to that of m 1127 one := make(nat, numWords) 1128 one[0] = 1 1129 1130 const n = 4 1131 // powers[i] contains x^i 1132 var powers [1 << n]nat 1133 powers[0] = powers[0].montgomery(one, RR, m, k0, numWords) 1134 powers[1] = powers[1].montgomery(x, RR, m, k0, numWords) 1135 for i := 2; i < 1<<n; i++ { 1136 powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords) 1137 } 1138 1139 // initialize z = 1 (Montgomery 1) 1140 z = z.make(numWords) 1141 copy(z, powers[0]) 1142 1143 zz = zz.make(numWords) 1144 1145 // same windowed exponent, but with Montgomery multiplications 1146 for i := len(y) - 1; i >= 0; i-- { 1147 yi := y[i] 1148 for j := 0; j < _W; j += n { 1149 if i != len(y)-1 || j != 0 { 1150 zz = zz.montgomery(z, z, m, k0, numWords) 1151 z = z.montgomery(zz, zz, m, k0, numWords) 1152 zz = zz.montgomery(z, z, m, k0, numWords) 1153 z = z.montgomery(zz, zz, m, k0, numWords) 1154 } 1155 zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords) 1156 z, zz = zz, z 1157 yi <<= n 1158 } 1159 } 1160 // convert to regular number 1161 zz = zz.montgomery(z, one, m, k0, numWords) 1162 1163 // One last reduction, just in case. 1164 // See golang.org/issue/13907. 1165 if zz.cmp(m) >= 0 { 1166 // Common case is m has high bit set; in that case, 1167 // since zz is the same length as m, there can be just 1168 // one multiple of m to remove. Just subtract. 1169 // We think that the subtract should be sufficient in general, 1170 // so do that unconditionally, but double-check, 1171 // in case our beliefs are wrong. 1172 // The div is not expected to be reached. 
1173 zz = zz.sub(zz, m) 1174 if zz.cmp(m) >= 0 { 1175 _, zz = nat(nil).div(nil, zz, m) 1176 } 1177 } 1178 1179 return zz.norm() 1180 } 1181 1182 // bytes writes the value of z into buf using big-endian encoding. 1183 // len(buf) must be >= len(z)*_S. The value of z is encoded in the 1184 // slice buf[i:]. The number i of unused bytes at the beginning of 1185 // buf is returned as result. 1186 func (z nat) bytes(buf []byte) (i int) { 1187 i = len(buf) 1188 for _, d := range z { 1189 for j := 0; j < _S; j++ { 1190 i-- 1191 buf[i] = byte(d) 1192 d >>= 8 1193 } 1194 } 1195 1196 for i < len(buf) && buf[i] == 0 { 1197 i++ 1198 } 1199 1200 return 1201 } 1202 1203 // setBytes interprets buf as the bytes of a big-endian unsigned 1204 // integer, sets z to that value, and returns z. 1205 func (z nat) setBytes(buf []byte) nat { 1206 z = z.make((len(buf) + _S - 1) / _S) 1207 1208 k := 0 1209 s := uint(0) 1210 var d Word 1211 for i := len(buf); i > 0; i-- { 1212 d |= Word(buf[i-1]) << s 1213 if s += 8; s == _S*8 { 1214 z[k] = d 1215 k++ 1216 s = 0 1217 d = 0 1218 } 1219 } 1220 if k < len(z) { 1221 z[k] = d 1222 } 1223 1224 return z.norm() 1225 } 1226 1227 // sqrt sets z = ⌊√x⌋ 1228 func (z nat) sqrt(x nat) nat { 1229 if x.cmp(natOne) <= 0 { 1230 return z.set(x) 1231 } 1232 if alias(z, x) { 1233 z = nil 1234 } 1235 1236 // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. 1237 // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). 1238 // https://members.loria.fr/PZimmermann/mca/pub226.html 1239 // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; 1240 // otherwise it converges to the correct z and stays there. 
1241 var z1, z2 nat 1242 z1 = z 1243 z1 = z1.setUint64(1) 1244 z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x 1245 for n := 0; ; n++ { 1246 z2, _ = z2.div(nil, x, z1) 1247 z2 = z2.add(z2, z1) 1248 z2 = z2.shr(z2, 1) 1249 if z2.cmp(z1) >= 0 { 1250 // z1 is answer. 1251 // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. 1252 if n&1 == 0 { 1253 return z1 1254 } 1255 return z.set(z1) 1256 } 1257 z1, z2 = z2, z1 1258 } 1259 }