// github.com/dorkamotorka/go/src@v0.0.0-20230614113921-187095f0e316/math/big/nat.go

// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements unsigned multi-precision integers (natural
// numbers). They are the building blocks for the implementation
// of signed integers, rationals, and floating-point numbers.
//
// Caution: This implementation relies on the function "alias"
// which assumes that (nat) slice capacities are never
// changed (no 3-operand slice expressions). If that
// changes, alias needs to be updated for correctness.

package big

import (
	"encoding/binary"
	"math/bits"
	"math/rand"
	"sync"
)

// An unsigned integer x of the form
//
//	x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements.
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word

var (
	natOne  = nat{1}
	natTwo  = nat{2}
	natFive = nat{5}
	natTen  = nat{10}
)

// String returns z formatted in base 16 with a "0x" prefix.
func (z nat) String() string {
	return "0x" + string(z.itoa(false, 16))
}

// clear sets every word of z (indices 0 through len(z)-1) to 0.
// The length of z is not changed.
func (z nat) clear() {
	for i := range z {
		z[i] = 0
	}
}

// norm returns z with its leading (most significant) zero words removed.
// The result shares z's underlying array; an all-zero z yields a
// length-0 slice, the normalized representation of 0.
func (z nat) norm() nat {
	i := len(z)
	for i > 0 && z[i-1] == 0 {
		i--
	}
	return z[0:i]
}

// make returns a nat of length n. If z has enough capacity, its
// underlying array is reused and the returned words keep whatever
// values they held before (they are NOT cleared); otherwise a new,
// zeroed slice is allocated.
func (z nat) make(n int) nat {
	if n <= cap(z) {
		return z[:n] // reuse z
	}
	if n == 1 {
		// Most nats start small and stay that way; don't over-allocate.
		return make(nat, 1)
	}
	// Choosing a good value for e has significant performance impact
	// because it increases the chance that a value can be reused.
	const e = 4 // extra capacity
	return make(nat, n, n+e)
}

// setWord sets z = x and returns z. For x == 0 the result is the
// normalized (length 0) representation.
func (z nat) setWord(x Word) nat {
	if x == 0 {
		return z[:0]
	}
	z = z.make(1)
	z[0] = x
	return z
}

// setUint64 sets z = x and returns z.
func (z nat) setUint64(x uint64) nat {
	// single-word value
	if w := Word(x); uint64(w) == x {
		return z.setWord(w)
	}
	// 2-word value. Only reachable when a single Word cannot hold x,
	// i.e. when Word is narrower than 64 bits; the split at bit 32
	// assumes 32-bit Words.
	z = z.make(2)
	z[1] = Word(x >> 32)
	z[0] = Word(x)
	return z
}

// set sets z = x by copying x's words into z (reusing z's storage when
// possible) and returns z.
func (z nat) set(x nat) nat {
	z = z.make(len(x))
	copy(z, x)
	return z
}

// add sets z = x + y and returns the normalized result.
// The operands are swapped when necessary so that len(x) >= len(y).
func (z nat) add(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		return z.add(y, x)
	case m == 0:
		// n == 0 because m >= n; result is 0
		return z[:0]
	case n == 0:
		// result is x
		return z.set(x)
	}
	// m > 0

	z = z.make(m + 1) // one extra word for the final carry
	c := addVV(z[0:n], x, y)
	if m > n {
		// propagate the carry through the words of x that y does not cover
		c = addVW(z[n:m], x[n:], c)
	}
	z[m] = c

	return z.norm()
}

// sub sets z = x - y and returns the normalized result.
// It panics with "underflow" if x < y.
func (z nat) sub(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		panic("underflow")
	case m == 0:
		// n == 0 because m >= n; result is 0
		return z[:0]
	case n == 0:
		// result is x
		return z.set(x)
	}
	// m > 0

	z = z.make(m)
	c := subVV(z[0:n], x, y)
	if m > n {
		// propagate the borrow through the words of x that y does not cover
		c = subVW(z[n:], x[n:], c)
	}
	if c != 0 {
		// a borrow out of the top word means y > x
		panic("underflow")
	}

	return z.norm()
}

// cmp compares x and y and returns:
//
//	-1 if x <  y
//	 0 if x == y
//	+1 if x >  y
//
// Lengths are compared first, so x and y are expected to be normalized.
func (x nat) cmp(y nat) (r int) {
	m := len(x)
	n := len(y)
	if m != n || m == 0 {
		switch {
		case m < n:
			r = -1
		case m > n:
			r = 1
		}
		return
	}

	// Same non-zero length: find the most significant differing word.
	i := m - 1
	for i > 0 && x[i] == y[i] {
		i--
	}

	switch {
	case x[i] < y[i]:
		r = -1
	case x[i] > y[i]:
		r = 1
	}
	return
}

// mulAddWW sets z = x*y + r and returns the normalized result.
func (z nat) mulAddWW(x nat, y, r Word) nat {
	m := len(x)
	if m == 0 || y == 0 {
		return z.setWord(r) // result is r
	}
	// m > 0

	z = z.make(m + 1)
	z[m] = mulAddVWW(z[0:m], x, y, r)

	return z.norm()
}

// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)],
// so z must have length at least len(x)+len(y).
func basicMul(z, x, y nat) {
	z[0 : len(x)+len(y)].clear() // initialize z
	for i, d := range y {
		if d != 0 {
			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
		}
	}
}

// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n * 2)
	z.clear()
	var c Word // carry into z[n+i] left over from the previous iteration
	for i := 0; i < n; i++ {
		d := y[i]
		c2 := addMulVVW(z[i:n+i], x, d)
		t := z[i] * k
		c3 := addMulVVW(z[i:n+i], m, t)
		cx := c + c2
		cy := cx + c3
		z[n+i] = cy
		// If either of the two additions above wrapped around,
		// one more carry must be added in the next iteration.
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// The result is the upper half z[n:2n], plus a possible carry bit
	// above it; if that carry is set, subtract m once to bring it back
	// into n words.
	if c != 0 {
		subVV(z[:n], z[n:], m)
	} else {
		copy(z[:n], z[n:])
	}
	return z[:n]
}

// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
// Factored out for readability - do not use outside karatsuba.
// Any carry out of z[0:n] is propagated into z[n:n+n>>1].
func karatsubaAdd(z, x nat, n int) {
	if c := addVV(z[0:n], z, x); c != 0 {
		addVW(z[n:n+n>>1], z[n:], c)
	}
}

// Like karatsubaAdd, but does subtract.
// Any borrow out of z[0:n] is propagated into z[n:n+n>>1].
func karatsubaSub(z, x nat, n int) {
	if c := subVV(z[0:n], z, x); c != 0 {
		subVW(z[n:n+n>>1], z[n:], c)
	}
}

// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
var karatsubaThreshold = 40 // computed by calibrate_test.go

// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
//
// Strictly, the recursion only halves n while n is even and at least
// karatsubaThreshold; any odd or small length falls back to basicMul.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//	x = x1*b + x0
	//	y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//	x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//	    =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//	xd = x1 - x0
	//	yd = y0 - y1
	//
	//	z1 =      xd*yd                    + z2 + z0
	//	   = (x1-x0)*(y0 - y1)             + z2 + z0
	//	   = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//	   = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//	   = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//	      6*n     5*n     4*n     3*n     2*n     1*n     0*n
	//	z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recurring)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//	  2*n     n     0
	//	z = [ z2  | z0  ]
	//	  +    [ z0  ]
	//	  +    [ z2  ]
	//	  +    [  p  ]
	//
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}

// alias reports whether x and y share the same base array.
//
// Note: alias assumes that the capacity of underlying arrays
// is never changed for nat values; i.e. that there are
// no 3-operand slice expressions in this code (or worse,
// reflect-based operations to the same effect).
//
// It compares the addresses of the last element within each slice's
// capacity, which coincide exactly when both slices end at the same
// array element.
func alias(x, y nat) bool {
	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}

// addAt implements z += x<<(_W*i); z must be long enough.
379 // (we don't use nat.add because we need z to stay the same 380 // slice, and we don't need to normalize z after each addition) 381 func addAt(z, x nat, i int) { 382 if n := len(x); n > 0 { 383 if c := addVV(z[i:i+n], z[i:], x); c != 0 { 384 j := i + n 385 if j < len(z) { 386 addVW(z[j:], z[j:], c) 387 } 388 } 389 } 390 } 391 392 func max(x, y int) int { 393 if x > y { 394 return x 395 } 396 return y 397 } 398 399 // karatsubaLen computes an approximation to the maximum k <= n such that 400 // k = p<<i for a number p <= threshold and an i >= 0. Thus, the 401 // result is the largest number that can be divided repeatedly by 2 before 402 // becoming about the value of threshold. 403 func karatsubaLen(n, threshold int) int { 404 i := uint(0) 405 for n > threshold { 406 n >>= 1 407 i++ 408 } 409 return n << i 410 } 411 412 func (z nat) mul(x, y nat) nat { 413 m := len(x) 414 n := len(y) 415 416 switch { 417 case m < n: 418 return z.mul(y, x) 419 case m == 0 || n == 0: 420 return z[:0] 421 case n == 1: 422 return z.mulAddWW(x, y[0], 0) 423 } 424 // m >= n > 1 425 426 // determine if z can be reused 427 if alias(z, x) || alias(z, y) { 428 z = nil // z is an alias for x or y - cannot reuse 429 } 430 431 // use basic multiplication if the numbers are small 432 if n < karatsubaThreshold { 433 z = z.make(m + n) 434 basicMul(z, x, y) 435 return z.norm() 436 } 437 // m >= n && n >= karatsubaThreshold && n >= 2 438 439 // determine Karatsuba length k such that 440 // 441 // x = xh*b + x0 (0 <= x0 < b) 442 // y = yh*b + y0 (0 <= y0 < b) 443 // b = 1<<(_W*k) ("base" of digits xi, yi) 444 // 445 k := karatsubaLen(n, karatsubaThreshold) 446 // k <= n 447 448 // multiply x0 and y0 via Karatsuba 449 x0 := x[0:k] // x0 is not normalized 450 y0 := y[0:k] // y0 is not normalized 451 z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y 452 karatsuba(z, x0, y0) 453 z = z[0 : m+n] // z has final length but may be incomplete 454 z[2*k:].clear() // 
upper portion of z is garbage (and 2*k <= m+n since k <= n <= m) 455 456 // If xh != 0 or yh != 0, add the missing terms to z. For 457 // 458 // xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b) 459 // yh = y1*b (0 <= y1 < b) 460 // 461 // the missing terms are 462 // 463 // x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0 464 // 465 // since all the yi for i > 1 are 0 by choice of k: If any of them 466 // were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would 467 // be a larger valid threshold contradicting the assumption about k. 468 // 469 if k < n || m != n { 470 tp := getNat(3 * k) 471 t := *tp 472 473 // add x0*y1*b 474 x0 := x0.norm() 475 y1 := y[k:] // y1 is normalized because y is 476 t = t.mul(x0, y1) // update t so we don't lose t's underlying array 477 addAt(z, t, k) 478 479 // add xi*y0<<i, xi*y1*b<<(i+k) 480 y0 := y0.norm() 481 for i := k; i < len(x); i += k { 482 xi := x[i:] 483 if len(xi) > k { 484 xi = xi[:k] 485 } 486 xi = xi.norm() 487 t = t.mul(xi, y0) 488 addAt(z, t, i) 489 t = t.mul(xi, y1) 490 addAt(z, t, i+k) 491 } 492 493 putNat(tp) 494 } 495 496 return z.norm() 497 } 498 499 // basicSqr sets z = x*x and is asymptotically faster than basicMul 500 // by about a factor of 2, but slower for small arguments due to overhead. 501 // Requirements: len(x) > 0, len(z) == 2*len(x) 502 // The (non-normalized) result is placed in z. 503 func basicSqr(z, x nat) { 504 n := len(x) 505 tp := getNat(2 * n) 506 t := *tp // temporary variable to hold the products 507 t.clear() 508 z[1], z[0] = mulWW(x[0], x[0]) // the initial square 509 for i := 1; i < n; i++ { 510 d := x[i] 511 // z collects the squares x[i] * x[i] 512 z[2*i+1], z[2*i] = mulWW(d, d) 513 // t collects the products x[i] * x[j] where j < i 514 t[2*i] = addMulVVW(t[i:2*i], x[0:i], d) 515 } 516 t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products 517 addVV(z, z, t) // combine the result 518 putNat(tp) 519 } 520 521 // karatsubaSqr squares x and leaves the result in z. 
522 // len(x) must be a power of 2 and len(z) >= 6*len(x). 523 // The (non-normalized) result is placed in z[0 : 2*len(x)]. 524 // 525 // The algorithm and the layout of z are the same as for karatsuba. 526 func karatsubaSqr(z, x nat) { 527 n := len(x) 528 529 if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 { 530 basicSqr(z[:2*n], x) 531 return 532 } 533 534 n2 := n >> 1 535 x1, x0 := x[n2:], x[0:n2] 536 537 karatsubaSqr(z, x0) 538 karatsubaSqr(z[n:], x1) 539 540 // s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0 541 xd := z[2*n : 2*n+n2] 542 if subVV(xd, x1, x0) != 0 { 543 subVV(xd, x0, x1) 544 } 545 546 p := z[n*3:] 547 karatsubaSqr(p, xd) 548 549 r := z[n*4:] 550 copy(r, z[:n*2]) 551 552 karatsubaAdd(z[n2:], r, n) 553 karatsubaAdd(z[n2:], r[n:], n) 554 karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0 555 } 556 557 // Operands that are shorter than basicSqrThreshold are squared using 558 // "grade school" multiplication; for operands longer than karatsubaSqrThreshold 559 // we use the Karatsuba algorithm optimized for x == y. 560 var basicSqrThreshold = 20 // computed by calibrate_test.go 561 var karatsubaSqrThreshold = 260 // computed by calibrate_test.go 562 563 // z = x*x 564 func (z nat) sqr(x nat) nat { 565 n := len(x) 566 switch { 567 case n == 0: 568 return z[:0] 569 case n == 1: 570 d := x[0] 571 z = z.make(2) 572 z[1], z[0] = mulWW(d, d) 573 return z.norm() 574 } 575 576 if alias(z, x) { 577 z = nil // z is an alias for x - cannot reuse 578 } 579 580 if n < basicSqrThreshold { 581 z = z.make(2 * n) 582 basicMul(z, x, x) 583 return z.norm() 584 } 585 if n < karatsubaSqrThreshold { 586 z = z.make(2 * n) 587 basicSqr(z, x) 588 return z.norm() 589 } 590 591 // Use Karatsuba multiplication optimized for x == y. 592 // The algorithm and layout of z are the same as for mul. 
593 594 // z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2 595 596 k := karatsubaLen(n, karatsubaSqrThreshold) 597 598 x0 := x[0:k] 599 z = z.make(max(6*k, 2*n)) 600 karatsubaSqr(z, x0) // z = x0^2 601 z = z[0 : 2*n] 602 z[2*k:].clear() 603 604 if k < n { 605 tp := getNat(2 * k) 606 t := *tp 607 x0 := x0.norm() 608 x1 := x[k:] 609 t = t.mul(x0, x1) 610 addAt(z, t, k) 611 addAt(z, t, k) // z = 2*x1*x0*b + x0^2 612 t = t.sqr(x1) 613 addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2 614 putNat(tp) 615 } 616 617 return z.norm() 618 } 619 620 // mulRange computes the product of all the unsigned integers in the 621 // range [a, b] inclusively. If a > b (empty range), the result is 1. 622 func (z nat) mulRange(a, b uint64) nat { 623 switch { 624 case a == 0: 625 // cut long ranges short (optimization) 626 return z.setUint64(0) 627 case a > b: 628 return z.setUint64(1) 629 case a == b: 630 return z.setUint64(a) 631 case a+1 == b: 632 return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) 633 } 634 m := (a + b) / 2 635 return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) 636 } 637 638 // getNat returns a *nat of len n. The contents may not be zero. 639 // The pool holds *nat to avoid allocation when converting to interface{}. 640 func getNat(n int) *nat { 641 var z *nat 642 if v := natPool.Get(); v != nil { 643 z = v.(*nat) 644 } 645 if z == nil { 646 z = new(nat) 647 } 648 *z = z.make(n) 649 if n > 0 { 650 (*z)[0] = 0xfedcb // break code expecting zero 651 } 652 return z 653 } 654 655 func putNat(x *nat) { 656 natPool.Put(x) 657 } 658 659 var natPool sync.Pool 660 661 // bitLen returns the length of x in bits. 662 // Unlike most methods, it works even if x is not normalized. 663 func (x nat) bitLen() int { 664 // This function is used in cryptographic operations. It must not leak 665 // anything but the Int's sign and bit size through side-channels. Any 666 // changes must be reviewed by a security expert. 
667 if i := len(x) - 1; i >= 0 { 668 // bits.Len uses a lookup table for the low-order bits on some 669 // architectures. Neutralize any input-dependent behavior by setting all 670 // bits after the first one bit. 671 top := uint(x[i]) 672 top |= top >> 1 673 top |= top >> 2 674 top |= top >> 4 675 top |= top >> 8 676 top |= top >> 16 677 top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures 678 return i*_W + bits.Len(top) 679 } 680 return 0 681 } 682 683 // trailingZeroBits returns the number of consecutive least significant zero 684 // bits of x. 685 func (x nat) trailingZeroBits() uint { 686 if len(x) == 0 { 687 return 0 688 } 689 var i uint 690 for x[i] == 0 { 691 i++ 692 } 693 // x[i] != 0 694 return i*_W + uint(bits.TrailingZeros(uint(x[i]))) 695 } 696 697 // isPow2 returns i, true when x == 2**i and 0, false otherwise. 698 func (x nat) isPow2() (uint, bool) { 699 var i uint 700 for x[i] == 0 { 701 i++ 702 } 703 if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 { 704 return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true 705 } 706 return 0, false 707 } 708 709 func same(x, y nat) bool { 710 return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0] 711 } 712 713 // z = x << s 714 func (z nat) shl(x nat, s uint) nat { 715 if s == 0 { 716 if same(z, x) { 717 return z 718 } 719 if !alias(z, x) { 720 return z.set(x) 721 } 722 } 723 724 m := len(x) 725 if m == 0 { 726 return z[:0] 727 } 728 // m > 0 729 730 n := m + int(s/_W) 731 z = z.make(n + 1) 732 z[n] = shlVU(z[n-m:n], x, s%_W) 733 z[0 : n-m].clear() 734 735 return z.norm() 736 } 737 738 // z = x >> s 739 func (z nat) shr(x nat, s uint) nat { 740 if s == 0 { 741 if same(z, x) { 742 return z 743 } 744 if !alias(z, x) { 745 return z.set(x) 746 } 747 } 748 749 m := len(x) 750 n := m - int(s/_W) 751 if n <= 0 { 752 return z[:0] 753 } 754 // n > 0 755 756 z = z.make(n) 757 shrVU(z, x[m-n:], s%_W) 758 759 return z.norm() 760 } 761 762 func (z nat) setBit(x nat, i uint, b uint) nat { 763 j := int(i / 
_W) 764 m := Word(1) << (i % _W) 765 n := len(x) 766 switch b { 767 case 0: 768 z = z.make(n) 769 copy(z, x) 770 if j >= n { 771 // no need to grow 772 return z 773 } 774 z[j] &^= m 775 return z.norm() 776 case 1: 777 if j >= n { 778 z = z.make(j + 1) 779 z[n:].clear() 780 } else { 781 z = z.make(n) 782 } 783 copy(z, x) 784 z[j] |= m 785 // no need to normalize 786 return z 787 } 788 panic("set bit is not 0 or 1") 789 } 790 791 // bit returns the value of the i'th bit, with lsb == bit 0. 792 func (x nat) bit(i uint) uint { 793 j := i / _W 794 if j >= uint(len(x)) { 795 return 0 796 } 797 // 0 <= j < len(x) 798 return uint(x[j] >> (i % _W) & 1) 799 } 800 801 // sticky returns 1 if there's a 1 bit within the 802 // i least significant bits, otherwise it returns 0. 803 func (x nat) sticky(i uint) uint { 804 j := i / _W 805 if j >= uint(len(x)) { 806 if len(x) == 0 { 807 return 0 808 } 809 return 1 810 } 811 // 0 <= j < len(x) 812 for _, x := range x[:j] { 813 if x != 0 { 814 return 1 815 } 816 } 817 if x[j]<<(_W-i%_W) != 0 { 818 return 1 819 } 820 return 0 821 } 822 823 func (z nat) and(x, y nat) nat { 824 m := len(x) 825 n := len(y) 826 if m > n { 827 m = n 828 } 829 // m <= n 830 831 z = z.make(m) 832 for i := 0; i < m; i++ { 833 z[i] = x[i] & y[i] 834 } 835 836 return z.norm() 837 } 838 839 // trunc returns z = x mod 2ⁿ. 
840 func (z nat) trunc(x nat, n uint) nat { 841 w := (n + _W - 1) / _W 842 if uint(len(x)) < w { 843 return z.set(x) 844 } 845 z = z.make(int(w)) 846 copy(z, x) 847 if n%_W != 0 { 848 z[len(z)-1] &= 1<<(n%_W) - 1 849 } 850 return z.norm() 851 } 852 853 func (z nat) andNot(x, y nat) nat { 854 m := len(x) 855 n := len(y) 856 if n > m { 857 n = m 858 } 859 // m >= n 860 861 z = z.make(m) 862 for i := 0; i < n; i++ { 863 z[i] = x[i] &^ y[i] 864 } 865 copy(z[n:m], x[n:m]) 866 867 return z.norm() 868 } 869 870 func (z nat) or(x, y nat) nat { 871 m := len(x) 872 n := len(y) 873 s := x 874 if m < n { 875 n, m = m, n 876 s = y 877 } 878 // m >= n 879 880 z = z.make(m) 881 for i := 0; i < n; i++ { 882 z[i] = x[i] | y[i] 883 } 884 copy(z[n:m], s[n:m]) 885 886 return z.norm() 887 } 888 889 func (z nat) xor(x, y nat) nat { 890 m := len(x) 891 n := len(y) 892 s := x 893 if m < n { 894 n, m = m, n 895 s = y 896 } 897 // m >= n 898 899 z = z.make(m) 900 for i := 0; i < n; i++ { 901 z[i] = x[i] ^ y[i] 902 } 903 copy(z[n:m], s[n:m]) 904 905 return z.norm() 906 } 907 908 // random creates a random integer in [0..limit), using the space in z if 909 // possible. n is the bit length of limit. 910 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 911 if alias(z, limit) { 912 z = nil // z is an alias for limit - cannot reuse 913 } 914 z = z.make(len(limit)) 915 916 bitLengthOfMSW := uint(n % _W) 917 if bitLengthOfMSW == 0 { 918 bitLengthOfMSW = _W 919 } 920 mask := Word((1 << bitLengthOfMSW) - 1) 921 922 for { 923 switch _W { 924 case 32: 925 for i := range z { 926 z[i] = Word(rand.Uint32()) 927 } 928 case 64: 929 for i := range z { 930 z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 931 } 932 default: 933 panic("unknown word size") 934 } 935 z[len(limit)-1] &= mask 936 if z.cmp(limit) < 0 { 937 break 938 } 939 } 940 941 return z.norm() 942 } 943 944 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; 945 // otherwise it sets z to x**y. 
// The result is the value of z.
//
// If slow is true, the generic bit-by-bit square-and-multiply loop is
// used even when the exponent spans more than one word; otherwise a
// non-zero modulus with len(y) > 1 is dispatched to expNNMontgomery
// (odd m), expNNWindowed (m a power of two) or expNNMontgomeryEven.
func (z nat) expNN(x, y, m nat, slow bool) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// 0**y = 0
	if len(x) == 0 {
		return z.setWord(0)
	}
	// x > 0

	// 1**y = 1
	if len(x) == 1 && x[0] == 1 {
		return z.setWord(1)
	}
	// x > 1

	// x**1 == x
	if len(y) == 1 && y[0] == 1 {
		if len(m) != 0 {
			return z.rem(x, m)
		}
		return z.set(x)
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))

		// If the exponent is large, we use the Montgomery method for odd values,
		// and a 4-bit, windowed exponentiation for powers of two,
		// and a CRT-decomposed Montgomery method for the remaining values
		// (even values times non-trivial odd values, which decompose into one
		// instance of each of the first two cases).
		if len(y) > 1 && !slow {
			if m[0]&1 == 1 {
				return z.expNNMontgomery(x, y, m)
			}
			if logM, ok := m.isPow2(); ok {
				return z.expNNWindowed(x, y, logM)
			}
			return z.expNNMontgomeryEven(x, y, m)
		}
	}

	// z = x accounts for the leading 1 bit of y; shifting v by
	// nlz(v)+1 discards that bit so the loops below start with the
	// next exponent bit in the top position of v.
	z = z.set(x)
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	w := _W - int(shift) // remaining bits of the top word of y
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.sqr(z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			zz, r = zz.div(r, z, m)
			// Rotate the buffers: z takes the remainder, the others
			// are recycled as scratch space for the next step.
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Process the remaining (full) words of y, most significant first.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.sqr(z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				// Same buffer rotation as above.
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}

// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
// and then uses the Chinese Remainder Theorem to combine the results.
// The recursive call using m1 will use expNNWindowed,
// while the recursive call using m2 will use expNNMontgomery.
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
	// Split m = m₁ × m₂ where m₁ = 2ⁿ
	n := m.trailingZeroBits()
	m1 := nat(nil).shl(natOne, n)
	m2 := nat(nil).shr(m, n)

	// We want z = x**y mod m.
	// z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
	// z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
	// (We are using the math/big convention for names here,
	// where the computation is z = x**y mod m, so its parts are z1 and z2.
	// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
	z1 := nat(nil).expNN(x, y, m1, false)
	z2 := nat(nil).expNN(x, y, m2, false)

	// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
	// which uses only a single modInverse (and an easy one at that).
	//	p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
	//	z = z₂ + p × m₂
	// The final addition is in range because:
	//	z = z₂ + p × m₂
	//	  ≤ z₂ + (m₁-1) × m₂
	//	  < m₂ + (m₁-1) × m₂
	//	  = m₁ × m₂
	//	  = m.
	z = z.set(z2)

	// Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
	z1 = z1.subMod2N(z1, z2, n)

	// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
	m2inv := nat(nil).modInverse(m2, m1)
	z2 = z2.mul(z1, m2inv)
	z2 = z2.trunc(z2, n)

	// Reuse z1 for p * m2.
	z = z.add(z, z1.mul(z2, m2))

	return z
}

// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
// It panics if len(y) <= 1.
func (z nat) expNNWindowed(x, y nat, logM uint) nat {
	if len(y) <= 1 {
		panic("big: misuse of expNNWindowed")
	}
	if x[0]&1 == 0 {
		// len(y) > 1, so y  > logM.
		// x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
		return z.setWord(0)
	}
	if logM == 1 {
		return z.setWord(1)
	}

	// zz is used to avoid allocating in mul as otherwise
	// the arguments would alias.
	w := int((logM + _W - 1) / _W)
	zzp := getNat(w)
	zz := *zzp

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]*nat
	for i := range powers {
		powers[i] = getNat(w)
	}
	*powers[0] = powers[0].set(natOne)
	*powers[1] = powers[1].trunc(x, logM)
	// Fill the table two entries at a time: x^i = (x^(i/2))^2 and
	// x^(i+1) = x^i * x, each reduced mod 2**logM.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := powers[i/2], powers[i], powers[i+1]
		*p = p.sqr(*p2)
		*p = p.trunc(*p, logM)
		*p1 = p1.mul(*p, x)
		*p1 = p1.trunc(*p1, logM)
	}

	// Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
	// so we can compute x**(y mod 2**(logM-1)) instead of x**y.
	// That is, we can throw away all but the bottom logM-1 bits of y.
	// Instead of allocating a new y, we start reading y at the right word
	// and truncate it appropriately at the start of the loop.
	i := len(y) - 1
	mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
	mmask := ^Word(0)
	if mbits := (logM - 1) & (_W - 1); mbits != 0 {
		mmask = (1 << mbits) - 1
	}
	if i > mtop {
		i = mtop
	}
	advance := false
	z = z.setWord(1)
	for ; i >= 0; i-- {
		yi := y[i]
		if i == mtop {
			yi &= mmask
		}
		for j := 0; j < _W; j += n {
			if advance {
				// Account for use of 4 bits in previous iteration.
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(z)
				zz, z = z, zz
				z = z.trunc(z, logM)
			}

			// Multiply by x^(top n bits of yi).
			zz = zz.mul(z, *powers[yi>>(_W-n)])
			zz, z = z, zz
			z = z.trunc(z, logM)

			yi <<= n
			advance = true
		}
	}

	// Hand the (possibly reallocated) scratch buffer back to the pool.
	*zzp = zz
	putNat(zzp)
	for i := range powers {
		putNat(powers[i])
	}

	return z.norm()
}

// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = nat(nil).div(RR, zz, m)
	if len(RR) < numWords {
		// pad RR to numWords, as montgomery requires equal lengths
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			if i != len(y)-1 || j != 0 {
				// four squarings (alternating between zz and z so the
				// receiver never aliases its arguments)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}

// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	i = len(buf)
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			if i >= 0 {
				buf[i] = byte(d)
			} else if byte(d) != 0 {
				panic("math/big: buffer too small to fit value")
			}
			d >>= 8
		}
	}

	if i < 0 {
		i = 0
	}
	// skip leading zero bytes so buf[i:] starts at the first non-zero byte
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}

// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
// buf must hold at least _S bytes; only its first _S bytes are read.
func bigEndianWord(buf []byte) Word {
	if _W == 64 {
		return Word(binary.BigEndian.Uint64(buf))
	}
	return Word(binary.BigEndian.Uint32(buf))
}

// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
1344 func (z nat) setBytes(buf []byte) nat { 1345 z = z.make((len(buf) + _S - 1) / _S) 1346 1347 i := len(buf) 1348 for k := 0; i >= _S; k++ { 1349 z[k] = bigEndianWord(buf[i-_S : i]) 1350 i -= _S 1351 } 1352 if i > 0 { 1353 var d Word 1354 for s := uint(0); i > 0; s += 8 { 1355 d |= Word(buf[i-1]) << s 1356 i-- 1357 } 1358 z[len(z)-1] = d 1359 } 1360 1361 return z.norm() 1362 } 1363 1364 // sqrt sets z = ⌊√x⌋ 1365 func (z nat) sqrt(x nat) nat { 1366 if x.cmp(natOne) <= 0 { 1367 return z.set(x) 1368 } 1369 if alias(z, x) { 1370 z = nil 1371 } 1372 1373 // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. 1374 // See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt). 1375 // https://members.loria.fr/PZimmermann/mca/pub226.html 1376 // If x is one less than a perfect square, the sequence oscillates between the correct z and z+1; 1377 // otherwise it converges to the correct z and stays there. 1378 var z1, z2 nat 1379 z1 = z 1380 z1 = z1.setUint64(1) 1381 z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x 1382 for n := 0; ; n++ { 1383 z2, _ = z2.div(nil, x, z1) 1384 z2 = z2.add(z2, z1) 1385 z2 = z2.shr(z2, 1) 1386 if z2.cmp(z1) >= 0 { 1387 // z1 is answer. 1388 // Figure out whether z1 or z2 is currently aliased to z by looking at loop count. 1389 if n&1 == 0 { 1390 return z1 1391 } 1392 return z.set(z1) 1393 } 1394 z1, z2 = z2, z1 1395 } 1396 } 1397 1398 // subMod2N returns z = (x - y) mod 2ⁿ. 
1399 func (z nat) subMod2N(x, y nat, n uint) nat { 1400 if uint(x.bitLen()) > n { 1401 if alias(z, x) { 1402 // ok to overwrite x in place 1403 x = x.trunc(x, n) 1404 } else { 1405 x = nat(nil).trunc(x, n) 1406 } 1407 } 1408 if uint(y.bitLen()) > n { 1409 if alias(z, y) { 1410 // ok to overwrite y in place 1411 y = y.trunc(y, n) 1412 } else { 1413 y = nat(nil).trunc(y, n) 1414 } 1415 } 1416 if x.cmp(y) >= 0 { 1417 return z.sub(x, y) 1418 } 1419 // x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x). 1420 z = z.sub(y, x) 1421 for uint(len(z))*_W < n { 1422 z = append(z, 0) 1423 } 1424 for i := range z { 1425 z[i] = ^z[i] 1426 } 1427 z = z.trunc(z, n) 1428 return z.add(z, natOne) 1429 }