github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/crypto/internal/bigmod/nat.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 "encoding/binary" 9 "errors" 10 "math/big" 11 "math/bits" 12 ) 13 14 const ( 15 // _W is the size in bits of our limbs. 16 _W = bits.UintSize 17 // _S is the size in bytes of our limbs. 18 _S = _W / 8 19 ) 20 21 // choice represents a constant-time boolean. The value of choice is always 22 // either 1 or 0. We use an int instead of bool in order to make decisions in 23 // constant time by turning it into a mask. 24 type choice uint 25 26 func not(c choice) choice { return 1 ^ c } 27 28 const yes = choice(1) 29 const no = choice(0) 30 31 // ctMask is all 1s if on is yes, and all 0s otherwise. 32 func ctMask(on choice) uint { return -uint(on) } 33 34 // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this 35 // function does not depend on its inputs. 36 func ctEq(x, y uint) choice { 37 // If x != y, then either x - y or y - x will generate a carry. 38 _, c1 := bits.Sub(x, y, 0) 39 _, c2 := bits.Sub(y, x, 0) 40 return not(choice(c1 | c2)) 41 } 42 43 // Nat represents an arbitrary natural number 44 // 45 // Each Nat has an announced length, which is the number of limbs it has stored. 46 // Operations on this number are allowed to leak this length, but will not leak 47 // any information about the values contained in those limbs. 48 type Nat struct { 49 // limbs is little-endian in base 2^W with W = bits.UintSize. 50 limbs []uint 51 } 52 53 // preallocTarget is the size in bits of the numbers used to implement the most 54 // common and most performant RSA key size. It's also enough to cover some of 55 // the operations of key sizes up to 4096. 56 const preallocTarget = 2048 57 const preallocLimbs = (preallocTarget + _W - 1) / _W 58 59 // NewNat returns a new nat with a size of zero, just like new(Nat), but with 60 // the preallocated capacity to hold a number of up to preallocTarget bits. 61 // NewNat inlines, so the allocation can live on the stack. 62 func NewNat() *Nat { 63 limbs := make([]uint, 0, preallocLimbs) 64 return &Nat{limbs} 65 } 66 67 // expand expands x to n limbs, leaving its value unchanged. 68 func (x *Nat) expand(n int) *Nat { 69 if len(x.limbs) > n { 70 panic("bigmod: internal error: shrinking nat") 71 } 72 if cap(x.limbs) < n { 73 newLimbs := make([]uint, n) 74 copy(newLimbs, x.limbs) 75 x.limbs = newLimbs 76 return x 77 } 78 extraLimbs := x.limbs[len(x.limbs):n] 79 for i := range extraLimbs { 80 extraLimbs[i] = 0 81 } 82 x.limbs = x.limbs[:n] 83 return x 84 } 85 86 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). 87 func (x *Nat) reset(n int) *Nat { 88 if cap(x.limbs) < n { 89 x.limbs = make([]uint, n) 90 return x 91 } 92 for i := range x.limbs { 93 x.limbs[i] = 0 94 } 95 x.limbs = x.limbs[:n] 96 return x 97 } 98 99 // set assigns x = y, optionally resizing x to the appropriate size. 100 func (x *Nat) set(y *Nat) *Nat { 101 x.reset(len(y.limbs)) 102 copy(x.limbs, y.limbs) 103 return x 104 } 105 106 // setBig assigns x = n, optionally resizing n to the appropriate size. 107 // 108 // The announced length of x is set based on the actual bit size of the input, 109 // ignoring leading zeroes. 110 func (x *Nat) setBig(n *big.Int) *Nat { 111 limbs := n.Bits() 112 x.reset(len(limbs)) 113 for i := range limbs { 114 x.limbs[i] = uint(limbs[i]) 115 } 116 return x 117 } 118 119 // Bytes returns x as a zero-extended big-endian byte slice. The size of the 120 // slice will match the size of m. 121 // 122 // x must have the same size as m and it must be reduced modulo m. 123 func (x *Nat) Bytes(m *Modulus) []byte { 124 i := m.Size() 125 bytes := make([]byte, i) 126 for _, limb := range x.limbs { 127 for j := 0; j < _S; j++ { 128 i-- 129 if i < 0 { 130 if limb == 0 { 131 break 132 } 133 panic("bigmod: modulus is smaller than nat") 134 } 135 bytes[i] = byte(limb) 136 limb >>= 8 137 } 138 } 139 return bytes 140 } 141 142 // SetBytes assigns x = b, where b is a slice of big-endian bytes. 143 // SetBytes returns an error if b >= m. 144 // 145 // The output will be resized to the size of m and overwritten. 146 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { 147 if err := x.setBytes(b, m); err != nil { 148 return nil, err 149 } 150 if x.cmpGeq(m.nat) == yes { 151 return nil, errors.New("input overflows the modulus") 152 } 153 return x, nil 154 } 155 156 // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. 157 // SetOverflowingBytes returns an error if b has a longer bit length than m, but 158 // reduces overflowing values up to 2^⌈log2(m)⌉ - 1. 159 // 160 // The output will be resized to the size of m and overwritten. 161 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { 162 if err := x.setBytes(b, m); err != nil { 163 return nil, err 164 } 165 leading := _W - bitLen(x.limbs[len(x.limbs)-1]) 166 if leading < m.leading { 167 return nil, errors.New("input overflows the modulus size") 168 } 169 x.maybeSubtractModulus(no, m) 170 return x, nil 171 } 172 173 // bigEndianUint returns the contents of buf interpreted as a 174 // big-endian encoded uint value. 175 func bigEndianUint(buf []byte) uint { 176 if _W == 64 { 177 return uint(binary.BigEndian.Uint64(buf)) 178 } 179 return uint(binary.BigEndian.Uint32(buf)) 180 } 181 182 func (x *Nat) setBytes(b []byte, m *Modulus) error { 183 x.resetFor(m) 184 i, k := len(b), 0 185 for k < len(x.limbs) && i >= _S { 186 x.limbs[k] = bigEndianUint(b[i-_S : i]) 187 i -= _S 188 k++ 189 } 190 for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { 191 x.limbs[k] |= uint(b[i-1]) << s 192 i-- 193 } 194 if i > 0 { 195 return errors.New("input overflows the modulus size") 196 } 197 return nil 198 } 199 200 // Equal returns 1 if x == y, and 0 otherwise. 201 // 202 // Both operands must have the same announced length. 203 func (x *Nat) Equal(y *Nat) choice { 204 // Eliminate bounds checks in the loop. 205 size := len(x.limbs) 206 xLimbs := x.limbs[:size] 207 yLimbs := y.limbs[:size] 208 209 equal := yes 210 for i := 0; i < size; i++ { 211 equal &= ctEq(xLimbs[i], yLimbs[i]) 212 } 213 return equal 214 } 215 216 // IsZero returns 1 if x == 0, and 0 otherwise. 217 func (x *Nat) IsZero() choice { 218 // Eliminate bounds checks in the loop. 219 size := len(x.limbs) 220 xLimbs := x.limbs[:size] 221 222 zero := yes 223 for i := 0; i < size; i++ { 224 zero &= ctEq(xLimbs[i], 0) 225 } 226 return zero 227 } 228 229 // cmpGeq returns 1 if x >= y, and 0 otherwise. 230 // 231 // Both operands must have the same announced length. 232 func (x *Nat) cmpGeq(y *Nat) choice { 233 // Eliminate bounds checks in the loop. 234 size := len(x.limbs) 235 xLimbs := x.limbs[:size] 236 yLimbs := y.limbs[:size] 237 238 var c uint 239 for i := 0; i < size; i++ { 240 _, c = bits.Sub(xLimbs[i], yLimbs[i], c) 241 } 242 // If there was a carry, then subtracting y underflowed, so 243 // x is not greater than or equal to y. 244 return not(choice(c)) 245 } 246 247 // assign sets x <- y if on == 1, and does nothing otherwise. 248 // 249 // Both operands must have the same announced length. 250 func (x *Nat) assign(on choice, y *Nat) *Nat { 251 // Eliminate bounds checks in the loop. 252 size := len(x.limbs) 253 xLimbs := x.limbs[:size] 254 yLimbs := y.limbs[:size] 255 256 mask := ctMask(on) 257 for i := 0; i < size; i++ { 258 xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) 259 } 260 return x 261 } 262 263 // add computes x += y and returns the carry. 264 // 265 // Both operands must have the same announced length. 266 func (x *Nat) add(y *Nat) (c uint) { 267 // Eliminate bounds checks in the loop. 268 size := len(x.limbs) 269 xLimbs := x.limbs[:size] 270 yLimbs := y.limbs[:size] 271 272 for i := 0; i < size; i++ { 273 xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) 274 } 275 return 276 } 277 278 // sub computes x -= y. It returns the borrow of the subtraction. 279 // 280 // Both operands must have the same announced length. 281 func (x *Nat) sub(y *Nat) (c uint) { 282 // Eliminate bounds checks in the loop. 283 size := len(x.limbs) 284 xLimbs := x.limbs[:size] 285 yLimbs := y.limbs[:size] 286 287 for i := 0; i < size; i++ { 288 xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) 289 } 290 return 291 } 292 293 // Modulus is used for modular arithmetic, precomputing relevant constants. 294 // 295 // Moduli are assumed to be odd numbers. Moduli can also leak the exact 296 // number of bits needed to store their value, and are stored without padding. 297 // 298 // Their actual value is still kept secret. 299 type Modulus struct { 300 // The underlying natural number for this modulus. 301 // 302 // This will be stored without any padding, and shouldn't alias with any 303 // other natural number being used. 304 nat *Nat 305 leading int // number of leading zeros in the modulus 306 m0inv uint // -nat.limbs[0]⁻¹ mod _W 307 rr *Nat // R*R for montgomeryRepresentation 308 } 309 310 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). 311 func rr(m *Modulus) *Nat { 312 rr := NewNat().ExpandFor(m) 313 n := uint(len(rr.limbs)) 314 mLen := uint(m.BitLen()) 315 logR := _W * n 316 317 // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to 318 // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce. 319 rr.limbs[n-1] = 1 << ((mLen - 1) % _W) 320 // Then we double until we reach 2^(_W * n). 321 for i := mLen - 1; i < logR; i++ { 322 rr.Add(rr, m) 323 } 324 325 // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in 326 // the Montgomery domain, meaning we can use Montgomery multiplication now). 327 // We could do that by doubling _W * n times, or with a square-and-double 328 // chain log2(_W * n) long. Turns out the fastest thing is to start out with 329 // doublings, and switch to square-and-double once the exponent is large 330 // enough to justify the cost of the multiplications. 331 332 // The threshold is selected experimentally as a linear function of n. 333 threshold := n / 4 334 335 // We calculate how many of the most-significant bits of the exponent we can 336 // compute before crossing the threshold, and we do it with doublings. 337 i := bits.UintSize 338 for logR>>i <= threshold { 339 i-- 340 } 341 for k := uint(0); k < logR>>i; k++ { 342 rr.Add(rr, m) 343 } 344 345 // Then we process the remaining bits of the exponent with a 346 // square-and-double chain. 347 for i > 0 { 348 rr.montgomeryMul(rr, rr, m) 349 i-- 350 if logR>>i&1 != 0 { 351 rr.Add(rr, m) 352 } 353 } 354 355 return rr 356 } 357 358 // minusInverseModW computes -x⁻¹ mod _W with x odd. 359 // 360 // This operation is used to precompute a constant involved in Montgomery 361 // multiplication. 362 func minusInverseModW(x uint) uint { 363 // Every iteration of this loop doubles the least-significant bits of 364 // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, 365 // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough 366 // for 64 bits (and wastes only one iteration for 32 bits). 367 // 368 // See https://crypto.stackexchange.com/a/47496. 369 y := x 370 for i := 0; i < 5; i++ { 371 y = y * (2 - x*y) 372 } 373 return -y 374 } 375 376 // NewModulusFromBig creates a new Modulus from a [big.Int]. 377 // 378 // The Int must be odd. The number of significant bits (and nothing else) is 379 // leaked through timing side-channels. 380 func NewModulusFromBig(n *big.Int) (*Modulus, error) { 381 if b := n.Bits(); len(b) == 0 { 382 return nil, errors.New("modulus must be >= 0") 383 } else if b[0]&1 != 1 { 384 return nil, errors.New("modulus must be odd") 385 } 386 m := &Modulus{} 387 m.nat = NewNat().setBig(n) 388 m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) 389 m.m0inv = minusInverseModW(m.nat.limbs[0]) 390 m.rr = rr(m) 391 return m, nil 392 } 393 394 // bitLen is a version of bits.Len that only leaks the bit length of n, but not 395 // its value. bits.Len and bits.LeadingZeros use a lookup table for the 396 // low-order bits on some architectures. 397 func bitLen(n uint) int { 398 var len int 399 // We assume, here and elsewhere, that comparison to zero is constant time 400 // with respect to different non-zero values. 401 for n != 0 { 402 len++ 403 n >>= 1 404 } 405 return len 406 } 407 408 // Size returns the size of m in bytes. 409 func (m *Modulus) Size() int { 410 return (m.BitLen() + 7) / 8 411 } 412 413 // BitLen returns the size of m in bits. 414 func (m *Modulus) BitLen() int { 415 return len(m.nat.limbs)*_W - int(m.leading) 416 } 417 418 // Nat returns m as a Nat. The return value must not be written to. 419 func (m *Modulus) Nat() *Nat { 420 return m.nat 421 } 422 423 // shiftIn calculates x = x << _W + y mod m. 424 // 425 // This assumes that x is already reduced mod m. 426 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { 427 d := NewNat().resetFor(m) 428 429 // Eliminate bounds checks in the loop. 430 size := len(m.nat.limbs) 431 xLimbs := x.limbs[:size] 432 dLimbs := d.limbs[:size] 433 mLimbs := m.nat.limbs[:size] 434 435 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit 436 // from y. Effectively, it left-shifts x and adds y one bit at a time, 437 // reducing it every time. 438 // 439 // To do the reduction, each iteration computes both 2x + b and 2x + b - m. 440 // The next iteration (and finally the return line) will use either result 441 // based on whether 2x + b overflows m. 442 needSubtraction := no 443 for i := _W - 1; i >= 0; i-- { 444 carry := (y >> i) & 1 445 var borrow uint 446 mask := ctMask(needSubtraction) 447 for i := 0; i < size; i++ { 448 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) 449 xLimbs[i], carry = bits.Add(l, l, carry) 450 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) 451 } 452 // Like in maybeSubtractModulus, we need the subtraction if either it 453 // didn't underflow (meaning 2x + b > m) or if computing 2x + b 454 // overflowed (meaning 2x + b > 2^_W*n > m). 455 needSubtraction = not(choice(borrow)) | choice(carry) 456 } 457 return x.assign(needSubtraction, d) 458 } 459 460 // Mod calculates out = x mod m. 461 // 462 // This works regardless how large the value of x is. 463 // 464 // The output will be resized to the size of m and overwritten. 465 func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { 466 out.resetFor(m) 467 // Working our way from the most significant to the least significant limb, 468 // we can insert each limb at the least significant position, shifting all 469 // previous limbs left by _W. This way each limb will get shifted by the 470 // correct number of bits. We can insert at least N - 1 limbs without 471 // overflowing m. After that, we need to reduce every time we shift. 472 i := len(x.limbs) - 1 473 // For the first N - 1 limbs we can skip the actual shifting and position 474 // them at the shifted position, which starts at min(N - 2, i). 475 start := len(m.nat.limbs) - 2 476 if i < start { 477 start = i 478 } 479 for j := start; j >= 0; j-- { 480 out.limbs[j] = x.limbs[i] 481 i-- 482 } 483 // We shift in the remaining limbs, reducing modulo m each time. 484 for i >= 0 { 485 out.shiftIn(x.limbs[i], m) 486 i-- 487 } 488 return out 489 } 490 491 // ExpandFor ensures x has the right size to work with operations modulo m. 492 // 493 // The announced size of x must be smaller than or equal to that of m. 494 func (x *Nat) ExpandFor(m *Modulus) *Nat { 495 return x.expand(len(m.nat.limbs)) 496 } 497 498 // resetFor ensures out has the right size to work with operations modulo m. 499 // 500 // out is zeroed and may start at any size. 501 func (out *Nat) resetFor(m *Modulus) *Nat { 502 return out.reset(len(m.nat.limbs)) 503 } 504 505 // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. 506 // 507 // It can be used to reduce modulo m a value up to 2m - 1, which is a common 508 // range for results computed by higher level operations. 509 // 510 // always is usually a carry that indicates that the operation that produced x 511 // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. 512 // 513 // x and m operands must have the same announced length. 514 func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { 515 t := NewNat().set(x) 516 underflow := t.sub(m.nat) 517 // We keep the result if x - m didn't underflow (meaning x >= m) 518 // or if always was set. 519 keep := not(choice(underflow)) | choice(always) 520 x.assign(keep, t) 521 } 522 523 // Sub computes x = x - y mod m. 524 // 525 // The length of both operands must be the same as the modulus. Both operands 526 // must already be reduced modulo m. 527 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { 528 underflow := x.sub(y) 529 // If the subtraction underflowed, add m. 530 t := NewNat().set(x) 531 t.add(m.nat) 532 x.assign(choice(underflow), t) 533 return x 534 } 535 536 // Add computes x = x + y mod m. 537 // 538 // The length of both operands must be the same as the modulus. Both operands 539 // must already be reduced modulo m. 540 func (x *Nat) Add(y *Nat, m *Modulus) *Nat { 541 overflow := x.add(y) 542 x.maybeSubtractModulus(choice(overflow), m) 543 return x 544 } 545 546 // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and 547 // n = len(m.nat.limbs). 548 // 549 // Faster Montgomery multiplication replaces standard modular multiplication for 550 // numbers in this representation. 551 // 552 // This assumes that x is already reduced mod m. 553 func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat { 554 // A Montgomery multiplication (which computes a * b / R) by R * R works out 555 // to a multiplication by R, which takes the value out of the Montgomery domain. 556 return x.montgomeryMul(x, m.rr, m) 557 } 558 559 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and 560 // n = len(m.nat.limbs). 561 // 562 // This assumes that x is already reduced mod m. 563 func (x *Nat) montgomeryReduction(m *Modulus) *Nat { 564 // By Montgomery multiplying with 1 not in Montgomery representation, we 565 // convert out back from Montgomery representation, because it works out to 566 // dividing by R. 567 one := NewNat().ExpandFor(m) 568 one.limbs[0] = 1 569 return x.montgomeryMul(x, one, m) 570 } 571 572 // montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and 573 // n = len(m.nat.limbs), also known as a Montgomery multiplication. 574 // 575 // All inputs should be the same length and already reduced modulo m. 576 // x will be resized to the size of m and overwritten. 577 func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { 578 n := len(m.nat.limbs) 579 mLimbs := m.nat.limbs[:n] 580 aLimbs := a.limbs[:n] 581 bLimbs := b.limbs[:n] 582 583 switch n { 584 default: 585 // Attempt to use a stack-allocated backing array. 586 T := make([]uint, 0, preallocLimbs*2) 587 if cap(T) < n*2 { 588 T = make([]uint, 0, n*2) 589 } 590 T = T[:n*2] 591 592 // This loop implements Word-by-Word Montgomery Multiplication, as 593 // described in Algorithm 4 (Fig. 3) of "Efficient Software 594 // Implementations of Modular Exponentiation" by Shay Gueron 595 // [https://eprint.iacr.org/2011/239.pdf]. 596 var c uint 597 for i := 0; i < n; i++ { 598 _ = T[n+i] // bounds check elimination hint 599 600 // Step 1 (T = a × b) is computed as a large pen-and-paper column 601 // multiplication of two numbers with n base-2^_W digits. If we just 602 // wanted to produce 2n-wide T, we would do 603 // 604 // for i := 0; i < n; i++ { 605 // d := bLimbs[i] 606 // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d) 607 // } 608 // 609 // where d is a digit of the multiplier, T[i:n+i] is the shifted 610 // position of the product of that digit, and T[n+i] is the final carry. 611 // Note that T[i] isn't modified after processing the i-th digit. 612 // 613 // Instead of running two loops, one for Step 1 and one for Steps 2–6, 614 // the result of Step 1 is computed during the next loop. This is 615 // possible because each iteration only uses T[i] in Step 2 and then 616 // discards it in Step 6. 617 d := bLimbs[i] 618 c1 := addMulVVW(T[i:n+i], aLimbs, d) 619 620 // Step 6 is replaced by shifting the virtual window we operate 621 // over: T of the algorithm is T[i:] for us. That means that T1 in 622 // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv. 623 Y := T[i] * m.m0inv 624 625 // Step 4 and 5 add Y × m to T, which as mentioned above is stored 626 // at T[i:]. The two carries (from a × d and Y × m) are added up in 627 // the next word T[n+i], and the carry bit from that addition is 628 // brought forward to the next iteration. 629 c2 := addMulVVW(T[i:n+i], mLimbs, Y) 630 T[n+i], c = bits.Add(c1, c2, c) 631 } 632 633 // Finally for Step 7 we copy the final T window into x, and subtract m 634 // if necessary (which as explained in maybeSubtractModulus can be the 635 // case both if x >= m, or if x overflowed). 636 // 637 // The paper suggests in Section 4 that we can do an "Almost Montgomery 638 // Multiplication" by subtracting only in the overflow case, but the 639 // cost is very similar since the constant time subtraction tells us if 640 // x >= m as a side effect, and taking care of the broken invariant is 641 // highly undesirable (see https://go.dev/issue/13907). 642 copy(x.reset(n).limbs, T[n:]) 643 x.maybeSubtractModulus(choice(c), m) 644 645 // The following specialized cases follow the exact same algorithm, but 646 // optimized for the sizes most used in RSA. addMulVVW is implemented in 647 // assembly with loop unrolling depending on the architecture and bounds 648 // checks are removed by the compiler thanks to the constant size. 649 case 1024 / _W: 650 const n = 1024 / _W // compiler hint 651 T := make([]uint, n*2) 652 var c uint 653 for i := 0; i < n; i++ { 654 d := bLimbs[i] 655 c1 := addMulVVW1024(&T[i], &aLimbs[0], d) 656 Y := T[i] * m.m0inv 657 c2 := addMulVVW1024(&T[i], &mLimbs[0], Y) 658 T[n+i], c = bits.Add(c1, c2, c) 659 } 660 copy(x.reset(n).limbs, T[n:]) 661 x.maybeSubtractModulus(choice(c), m) 662 663 case 1536 / _W: 664 const n = 1536 / _W // compiler hint 665 T := make([]uint, n*2) 666 var c uint 667 for i := 0; i < n; i++ { 668 d := bLimbs[i] 669 c1 := addMulVVW1536(&T[i], &aLimbs[0], d) 670 Y := T[i] * m.m0inv 671 c2 := addMulVVW1536(&T[i], &mLimbs[0], Y) 672 T[n+i], c = bits.Add(c1, c2, c) 673 } 674 copy(x.reset(n).limbs, T[n:]) 675 x.maybeSubtractModulus(choice(c), m) 676 677 case 2048 / _W: 678 const n = 2048 / _W // compiler hint 679 T := make([]uint, n*2) 680 var c uint 681 for i := 0; i < n; i++ { 682 d := bLimbs[i] 683 c1 := addMulVVW2048(&T[i], &aLimbs[0], d) 684 Y := T[i] * m.m0inv 685 c2 := addMulVVW2048(&T[i], &mLimbs[0], Y) 686 T[n+i], c = bits.Add(c1, c2, c) 687 } 688 copy(x.reset(n).limbs, T[n:]) 689 x.maybeSubtractModulus(choice(c), m) 690 } 691 692 return x 693 } 694 695 // addMulVVW multiplies the multi-word value x by the single-word value y, 696 // adding the result to the multi-word value z and returning the final carry. 697 // It can be thought of as one row of a pen-and-paper column multiplication. 698 func addMulVVW(z, x []uint, y uint) (carry uint) { 699 _ = x[len(z)-1] // bounds check elimination hint 700 for i := range z { 701 hi, lo := bits.Mul(x[i], y) 702 lo, c := bits.Add(lo, z[i], 0) 703 // We use bits.Add with zero to get an add-with-carry instruction that 704 // absorbs the carry from the previous bits.Add. 705 hi, _ = bits.Add(hi, 0, c) 706 lo, c = bits.Add(lo, carry, 0) 707 hi, _ = bits.Add(hi, 0, c) 708 carry = hi 709 z[i] = lo 710 } 711 return carry 712 } 713 714 // Mul calculates x = x * y mod m. 715 // 716 // The length of both operands must be the same as the modulus. Both operands 717 // must already be reduced modulo m. 718 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { 719 // A Montgomery multiplication by a value out of the Montgomery domain 720 // takes the result out of Montgomery representation. 721 xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m 722 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m 723 } 724 725 // Exp calculates out = x^e mod m. 726 // 727 // The exponent e is represented in big-endian order. The output will be resized 728 // to the size of m and overwritten. x must already be reduced modulo m. 729 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { 730 // We use a 4 bit window. For our RSA workload, 4 bit windows are faster 731 // than 2 bit windows, but use an extra 12 nats worth of scratch space. 732 // Using bit sizes that don't divide 8 are more complex to implement, but 733 // are likely to be more efficient if necessary. 734 735 table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) 736 // newNat calls are unrolled so they are allocated on the stack. 737 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 738 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 739 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 740 } 741 table[0].set(x).montgomeryRepresentation(m) 742 for i := 1; i < len(table); i++ { 743 table[i].montgomeryMul(table[i-1], table[0], m) 744 } 745 746 out.resetFor(m) 747 out.limbs[0] = 1 748 out.montgomeryRepresentation(m) 749 tmp := NewNat().ExpandFor(m) 750 for _, b := range e { 751 for _, j := range []int{4, 0} { 752 // Square four times. Optimization note: this can be implemented 753 // more efficiently than with generic Montgomery multiplication. 754 out.montgomeryMul(out, out, m) 755 out.montgomeryMul(out, out, m) 756 out.montgomeryMul(out, out, m) 757 out.montgomeryMul(out, out, m) 758 759 // Select x^k in constant time from the table. 760 k := uint((b >> j) & 0b1111) 761 for i := range table { 762 tmp.assign(ctEq(k, uint(i+1)), table[i]) 763 } 764 765 // Multiply by x^k, discarding the result if k = 0. 766 tmp.montgomeryMul(out, tmp, m) 767 out.assign(not(ctEq(k, 0)), tmp) 768 } 769 } 770 771 return out.montgomeryReduction(m) 772 } 773 774 // ExpShortVarTime calculates out = x^e mod m. 775 // 776 // The output will be resized to the size of m and overwritten. x must already 777 // be reduced modulo m. This leaks the exponent through timing side-channels. 778 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { 779 // For short exponents, precomputing a table and using a window like in Exp 780 // doesn't pay off. Instead, we do a simple conditional square-and-multiply 781 // chain, skipping the initial run of zeroes. 782 xR := NewNat().set(x).montgomeryRepresentation(m) 783 out.set(xR) 784 for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ { 785 out.montgomeryMul(out, out, m) 786 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { 787 out.montgomeryMul(out, xR, m) 788 } 789 } 790 return out.montgomeryReduction(m) 791 }