github.com/flyinox/gosm@v0.0.0-20171117061539-16768cb62077/src/math/big/nat.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // This file implements unsigned multi-precision integers (natural 6 // numbers). They are the building blocks for the implementation 7 // of signed integers, rationals, and floating-point numbers. 8 9 package big 10 11 import ( 12 "math/bits" 13 "math/rand" 14 "sync" 15 ) 16 17 // An unsigned integer x of the form 18 // 19 // x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0] 20 // 21 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n, 22 // with the digits x[i] as the slice elements. 23 // 24 // A number is normalized if the slice contains no leading 0 digits. 25 // During arithmetic operations, denormalized values may occur but are 26 // always normalized before returning the final result. The normalized 27 // representation of 0 is the empty or nil slice (length = 0). 28 // 29 type nat []Word 30 31 var ( 32 natOne = nat{1} 33 natTwo = nat{2} 34 natTen = nat{10} 35 ) 36 37 func (z nat) clear() { 38 for i := range z { 39 z[i] = 0 40 } 41 } 42 43 func (z nat) norm() nat { 44 i := len(z) 45 for i > 0 && z[i-1] == 0 { 46 i-- 47 } 48 return z[0:i] 49 } 50 51 func (z nat) make(n int) nat { 52 if n <= cap(z) { 53 return z[:n] // reuse z 54 } 55 // Choosing a good value for e has significant performance impact 56 // because it increases the chance that a value can be reused. 
57 const e = 4 // extra capacity 58 return make(nat, n, n+e) 59 } 60 61 func (z nat) setWord(x Word) nat { 62 if x == 0 { 63 return z[:0] 64 } 65 z = z.make(1) 66 z[0] = x 67 return z 68 } 69 70 func (z nat) setUint64(x uint64) nat { 71 // single-word value 72 if w := Word(x); uint64(w) == x { 73 return z.setWord(w) 74 } 75 // 2-word value 76 z = z.make(2) 77 z[1] = Word(x >> 32) 78 z[0] = Word(x) 79 return z 80 } 81 82 func (z nat) set(x nat) nat { 83 z = z.make(len(x)) 84 copy(z, x) 85 return z 86 } 87 88 func (z nat) add(x, y nat) nat { 89 m := len(x) 90 n := len(y) 91 92 switch { 93 case m < n: 94 return z.add(y, x) 95 case m == 0: 96 // n == 0 because m >= n; result is 0 97 return z[:0] 98 case n == 0: 99 // result is x 100 return z.set(x) 101 } 102 // m > 0 103 104 z = z.make(m + 1) 105 c := addVV(z[0:n], x, y) 106 if m > n { 107 c = addVW(z[n:m], x[n:], c) 108 } 109 z[m] = c 110 111 return z.norm() 112 } 113 114 func (z nat) sub(x, y nat) nat { 115 m := len(x) 116 n := len(y) 117 118 switch { 119 case m < n: 120 panic("underflow") 121 case m == 0: 122 // n == 0 because m >= n; result is 0 123 return z[:0] 124 case n == 0: 125 // result is x 126 return z.set(x) 127 } 128 // m > 0 129 130 z = z.make(m) 131 c := subVV(z[0:n], x, y) 132 if m > n { 133 c = subVW(z[n:], x[n:], c) 134 } 135 if c != 0 { 136 panic("underflow") 137 } 138 139 return z.norm() 140 } 141 142 func (x nat) cmp(y nat) (r int) { 143 m := len(x) 144 n := len(y) 145 if m != n || m == 0 { 146 switch { 147 case m < n: 148 r = -1 149 case m > n: 150 r = 1 151 } 152 return 153 } 154 155 i := m - 1 156 for i > 0 && x[i] == y[i] { 157 i-- 158 } 159 160 switch { 161 case x[i] < y[i]: 162 r = -1 163 case x[i] > y[i]: 164 r = 1 165 } 166 return 167 } 168 169 func (z nat) mulAddWW(x nat, y, r Word) nat { 170 m := len(x) 171 if m == 0 || y == 0 { 172 return z.setWord(r) // result is r 173 } 174 // m > 0 175 176 z = z.make(m + 1) 177 z[m] = mulAddVWW(z[0:m], x, y, r) 178 179 return z.norm() 180 } 181 182 
// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
// For every non-zero word d of y, x*d is added into z at the word
// offset of d.
func basicMul(z, x, y nat) {
	z[0 : len(x)+len(y)].clear() // initialize z
	for i, d := range y {
		if d != 0 {
			// Add x*d into the window z[i:i+len(x)]; the word returned by
			// addMulVVW (presumably the carry out of the window) is stored
			// in the word just above it.
			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
		}
	}
}

// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= z < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
//
// It panics if x, y and m do not all have length n.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	var c Word // overflow (0 or 1) carried from the previous iteration's top word
	for i := 0; i < n; i++ {
		d := y[i]
		c2 := addMulVVW(z, x, d) // z += x*y[i]
		t := z[0] * k
		c3 := addMulVVW(z, m, t) // z += m*t
		copy(z, z[1:])           // drop the lowest word: shift z down by one word
		// The new top word is c + c2 + c3; record whether that sum wrapped.
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// A pending overflow means the value does not fit in n words;
	// subtract m once to bring it back into range.
	if c != 0 {
		subVV(z, z, m)
	}
	return z
}

// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
// Factored out for readability - do not use outside karatsuba.
func karatsubaAdd(z, x nat, n int) {
	// Add the low n words; only touch the upper n>>1 words of z
	// if there is a carry to propagate into them.
	if c := addVV(z[0:n], z, x); c != 0 {
		addVW(z[n:n+n>>1], z[n:], c)
	}
}

// Like karatsubaAdd, but does subtract.
func karatsubaSub(z, x nat, n int) {
	// Subtract the low n words; only touch the upper n>>1 words of z
	// if there is a borrow to propagate into them.
	if c := subVV(z[0:n], z, x); c != 0 {
		subVW(z[n:n+n>>1], z[n:], c)
	}
}

// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
var karatsubaThreshold int = 40 // computed by calibrate.go

// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
// It compares the addresses of the last elements of the two slices
// when each is extended to its full capacity; slices with zero
// capacity never alias.
func alias(x, y nat) bool {
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
	if n := len(x); n > 0 {
		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
			// propagate the carry into the words above x's window,
			// if there are any
			j := i + n
			if j < len(z) {
				addVW(z[j:], z[j:], c)
			}
		}
	}
}

// max returns the larger of x and y.
func max(x, y int) int {
	if x > y {
		return x
	}
	return y
}

// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of karatsubaThreshold.
func karatsubaLen(n int) int {
	i := uint(0)
	// halve n until it no longer exceeds the threshold, counting the halvings
	for n > karatsubaThreshold {
		n >>= 1
		i++
	}
	return n << i
}

// mul sets z = x*y and returns the normalized result.
// z's storage is reused unless it aliases x or y.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		var t nat // scratch product, reused across the additions below

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next chunk of at most k words of x
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}

// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
478 func (z nat) mulRange(a, b uint64) nat { 479 switch { 480 case a == 0: 481 // cut long ranges short (optimization) 482 return z.setUint64(0) 483 case a > b: 484 return z.setUint64(1) 485 case a == b: 486 return z.setUint64(a) 487 case a+1 == b: 488 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 489 } 490 m := (a + b) / 2 491 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 492 } 493 494 // q = (x-r)/y, with 0 <= r < y 495 func (z nat) divW(x nat, y Word) (q nat, r Word) { 496 m := len(x) 497 switch { 498 case y == 0: 499 panic("division by zero") 500 case y == 1: 501 q = z.set(x) // result is x 502 return 503 case m == 0: 504 q = z[:0] // result is 0 505 return 506 } 507 // m > 0 508 z = z.make(m) 509 r = divWVW(z, 0, x, y) 510 q = z.norm() 511 return 512 } 513 514 func (z nat) div(z2, u, v nat) (q, r nat) { 515 if len(v) == 0 { 516 panic("division by zero") 517 } 518 519 if u.cmp(v) < 0 { 520 q = z[:0] 521 r = z2.set(u) 522 return 523 } 524 525 if len(v) == 1 { 526 var r2 Word 527 q, r2 = z.divW(u, v[0]) 528 r = z2.setWord(r2) 529 return 530 } 531 532 q, r = z.divLarge(z2, u, v) 533 return 534 } 535 536 // getNat returns a *nat of len n. The contents may not be zero. 537 // The pool holds *nat to avoid allocation when converting to interface{}. 538 func getNat(n int) *nat { 539 var z *nat 540 if v := natPool.Get(); v != nil { 541 z = v.(*nat) 542 } 543 if z == nil { 544 z = new(nat) 545 } 546 *z = z.make(n) 547 return z 548 } 549 550 func putNat(x *nat) { 551 natPool.Put(x) 552 } 553 554 var natPool sync.Pool 555 556 // q = (uIn-r)/v, with 0 <= r < y 557 // Uses z as storage for q, and u as storage for r if possible. 558 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv is pooled scratch space of n+1 words for the product qhat*v.
	qhatvp := getNat(n + 1)
	qhatv := *qhatvp
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Shift v left by its number of leading zero bits, and shift uIn
	// by the same amount into u (which has one extra word for the
	// bits shifted out).
	var v1p *nat
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1p = getNat(n)
		v1 := *v1p
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	vn1 := v[n-1]
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient word qhat from the top words of u and v,
		// then decrease it while the two-word test shows it is too large.
		qhat := Word(_M)
		if ujn := u[j+n]; ujn != vn1 {
			var rhat Word
			qhat, rhat = divWW(ujn, u[j+n-1], vn1)

			// x1 | x2 = q̂v_{n-2}
			vn2 := v[n-2]
			x1, x2 := mulWW(qhat, vn2)
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			ujn2 := u[j+n-2]
			for greaterThan(x1, x2, rhat, ujn2) {
				qhat--
				prevRhat := rhat
				rhat += vn1
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, vn2)
			}
		}

		// D4.
		// Subtract qhat*v from u[j:]; if that borrows, qhat was one
		// too large: add v back and decrement qhat.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}
	// return the pooled scratch values
	if v1p != nil {
		putNat(v1p)
	}
	putNat(qhatvp)

	q = q.norm()
	// undo the shift from D1 to obtain the remainder
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}

// Length of x in bits. x must be normalized.
645 func (x nat) bitLen() int { 646 if i := len(x) - 1; i >= 0 { 647 return i*_W + bits.Len(uint(x[i])) 648 } 649 return 0 650 } 651 652 // trailingZeroBits returns the number of consecutive least significant zero 653 // bits of x. 654 func (x nat) trailingZeroBits() uint { 655 if len(x) == 0 { 656 return 0 657 } 658 var i uint 659 for x[i] == 0 { 660 i++ 661 } 662 // x[i] != 0 663 return i*_W + uint(bits.TrailingZeros(uint(x[i]))) 664 } 665 666 // z = x << s 667 func (z nat) shl(x nat, s uint) nat { 668 m := len(x) 669 if m == 0 { 670 return z[:0] 671 } 672 // m > 0 673 674 n := m + int(s/_W) 675 z = z.make(n + 1) 676 z[n] = shlVU(z[n-m:n], x, s%_W) 677 z[0 : n-m].clear() 678 679 return z.norm() 680 } 681 682 // z = x >> s 683 func (z nat) shr(x nat, s uint) nat { 684 m := len(x) 685 n := m - int(s/_W) 686 if n <= 0 { 687 return z[:0] 688 } 689 // n > 0 690 691 z = z.make(n) 692 shrVU(z, x[m-n:], s%_W) 693 694 return z.norm() 695 } 696 697 func (z nat) setBit(x nat, i uint, b uint) nat { 698 j := int(i / _W) 699 m := Word(1) << (i % _W) 700 n := len(x) 701 switch b { 702 case 0: 703 z = z.make(n) 704 copy(z, x) 705 if j >= n { 706 // no need to grow 707 return z 708 } 709 z[j] &^= m 710 return z.norm() 711 case 1: 712 if j >= n { 713 z = z.make(j + 1) 714 z[n:].clear() 715 } else { 716 z = z.make(n) 717 } 718 copy(z, x) 719 z[j] |= m 720 // no need to normalize 721 return z 722 } 723 panic("set bit is not 0 or 1") 724 } 725 726 // bit returns the value of the i'th bit, with lsb == bit 0. 727 func (x nat) bit(i uint) uint { 728 j := i / _W 729 if j >= uint(len(x)) { 730 return 0 731 } 732 // 0 <= j < len(x) 733 return uint(x[j] >> (i % _W) & 1) 734 } 735 736 // sticky returns 1 if there's a 1 bit within the 737 // i least significant bits, otherwise it returns 0. 
738 func (x nat) sticky(i uint) uint { 739 j := i / _W 740 if j >= uint(len(x)) { 741 if len(x) == 0 { 742 return 0 743 } 744 return 1 745 } 746 // 0 <= j < len(x) 747 for _, x := range x[:j] { 748 if x != 0 { 749 return 1 750 } 751 } 752 if x[j]<<(_W-i%_W) != 0 { 753 return 1 754 } 755 return 0 756 } 757 758 func (z nat) and(x, y nat) nat { 759 m := len(x) 760 n := len(y) 761 if m > n { 762 m = n 763 } 764 // m <= n 765 766 z = z.make(m) 767 for i := 0; i < m; i++ { 768 z[i] = x[i] & y[i] 769 } 770 771 return z.norm() 772 } 773 774 func (z nat) andNot(x, y nat) nat { 775 m := len(x) 776 n := len(y) 777 if n > m { 778 n = m 779 } 780 // m >= n 781 782 z = z.make(m) 783 for i := 0; i < n; i++ { 784 z[i] = x[i] &^ y[i] 785 } 786 copy(z[n:m], x[n:m]) 787 788 return z.norm() 789 } 790 791 func (z nat) or(x, y nat) nat { 792 m := len(x) 793 n := len(y) 794 s := x 795 if m < n { 796 n, m = m, n 797 s = y 798 } 799 // m >= n 800 801 z = z.make(m) 802 for i := 0; i < n; i++ { 803 z[i] = x[i] | y[i] 804 } 805 copy(z[n:m], s[n:m]) 806 807 return z.norm() 808 } 809 810 func (z nat) xor(x, y nat) nat { 811 m := len(x) 812 n := len(y) 813 s := x 814 if m < n { 815 n, m = m, n 816 s = y 817 } 818 // m >= n 819 820 z = z.make(m) 821 for i := 0; i < n; i++ { 822 z[i] = x[i] ^ y[i] 823 } 824 copy(z[n:m], s[n:m]) 825 826 return z.norm() 827 } 828 829 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2) 830 func greaterThan(x1, x2, y1, y2 Word) bool { 831 return x1 > y1 || x1 == y1 && x2 > y2 832 } 833 834 // modW returns x % d. 835 func (x nat) modW(d Word) (r Word) { 836 // TODO(agl): we don't actually need to store the q value. 837 var q nat 838 q = q.make(len(x)) 839 return divWVW(q, 0, x, d) 840 } 841 842 // random creates a random integer in [0..limit), using the space in z if 843 // possible. n is the bit length of limit. 
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
	// NOTE: the parameter rand shadows the math/rand package inside this function.
	if alias(z, limit) {
		z = nil // z is an alias for limit - cannot reuse
	}
	z = z.make(len(limit))

	// Build a mask that keeps only the bits of the most significant
	// word that lie within the n-bit length of limit.
	bitLengthOfMSW := uint(n % _W)
	if bitLengthOfMSW == 0 {
		bitLengthOfMSW = _W
	}
	mask := Word((1 << bitLengthOfMSW) - 1)

	// Rejection sampling: fill z with random words, mask the top word,
	// and retry until the candidate is below limit.
	for {
		switch _W {
		case 32:
			for i := range z {
				z[i] = Word(rand.Uint32())
			}
		case 64:
			for i := range z {
				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
			}
		default:
			panic("unknown word size")
		}
		z[len(limit)-1] &= mask
		if z.cmp(limit) < 0 {
			break
		}
	}

	return z.norm()
}

// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
func (z nat) expNN(x, y, m nat) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// x**1 mod m == x mod m
	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
		_, z = z.div(z, x, m)
		return z
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))
	}
	z = z.set(x)

	// If the base is non-trivial and the exponent is large, we use
	// 4-bit, windowed exponentiation. This involves precomputing 14 values
	// (x^2...x^15) but then reduces the number of multiply-reduces by a
	// third. Even for a 32-bit exponent, this reduces the number of
	// operations. Uses Montgomery method for odd moduli.
	if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
		if m[0]&1 == 1 {
			return z.expNNMontgomery(x, y, m)
		}
		return z.expNNWindowed(x, y, m)
	}

	// Position v so that the bit just below the leading 1 bit of the
	// top exponent word is at the mask position; the leading bit itself
	// is accounted for by the initial z = x above.
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.mul(z, z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			// reduce mod m; the rotation below makes the remainder the
			// new z and recycles the other buffers
			zz, r = zz.div(r, z, m)
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// process the remaining (lower) exponent words in full
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.mul(z, z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}

// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Fill in the table two entries at a time: powers[i] is the square
	// of powers[i/2] and powers[i+1] is powers[i]*x, each reduced mod m.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	// Walk the exponent from the most significant word down, n bits at
	// a time taken from the top of yi.
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Square n times, except before the very first window
			// (where z is still 1).
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// multiply by the table entry selected by the top n bits of yi
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			yi <<= n
		}
	}

	return z.norm()
}

// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// pad x with zero words up to the length of m
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = RR.div(RR, zz, m)
	if len(RR) < numWords {
		// pad RR up to the length of m (montgomery requires equal lengths)
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// four squarings, alternating between z and zz because
				// montgomery's result must not alias its operands
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}

// bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of
// buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
	i = len(buf)
	// Fill buf from the end: least significant word first, and within
	// each word the least significant byte first.
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			buf[i] = byte(d)
			d >>= 8
		}
	}

	// skip leading zero bytes
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}

// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {
	z = z.make((len(buf) + _S - 1) / _S)

	k := 0       // index of the word currently being assembled
	s := uint(0) // bit offset of the next byte within that word
	var d Word   // the word being assembled
	// Consume buf from its last (least significant) byte backwards.
	for i := len(buf); i > 0; i-- {
		d |= Word(buf[i-1]) << s
		if s += 8; s == _S*8 {
			z[k] = d
			k++
			s = 0
			d = 0
		}
	}
	// store a final, partially filled word
	if k < len(z) {
		z[k] = d
	}

	return z.norm()
}

// sqrt sets z = ⌊√x⌋
func (z nat) sqrt(x nat) nat {
	// 0 and 1 are their own square roots
	if x.cmp(natOne) <= 0 {
		return z.set(x)
	}
	if alias(z, x) {
		z = nil
	}

	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
	// https://members.loria.fr/PZimmermann/mca/pub226.html
	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
	// otherwise it converges to the correct z and stays there.
	var z1, z2 nat
	z1 = z
	z1 = z1.setUint64(1)
	z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x
	for n := 0; ; n++ {
		// z2 = ⌊(z1 + ⌊x/z1⌋)/2⌋
		z2, _ = z2.div(nil, x, z1)
		z2 = z2.add(z2, z1)
		z2 = z2.shr(z2, 1)
		if z2.cmp(z1) >= 0 {
			// z1 is answer.
			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
			if n&1 == 0 {
				return z1
			}
			return z.set(z1)
		}
		z1, z2 = z2, z1
	}
}