github.com/emmansun/gmsm@v0.29.1/internal/bigmod/nat.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 "encoding/binary" 9 "errors" 10 "math/big" 11 "math/bits" 12 ) 13 14 const ( 15 // _W is the size in bits of our limbs. 16 _W = bits.UintSize 17 // _S is the size in bytes of our limbs. 18 _S = _W / 8 19 ) 20 21 // choice represents a constant-time boolean. The value of choice is always 22 // either 1 or 0. We use an int instead of bool in order to make decisions in 23 // constant time by turning it into a mask. 24 type choice uint 25 26 func not(c choice) choice { return 1 ^ c } 27 28 const yes = choice(1) 29 const no = choice(0) 30 31 // ctMask is all 1s if on is yes, and all 0s otherwise. 32 func ctMask(on choice) uint { return -uint(on) } 33 34 // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this 35 // function does not depend on its inputs. 36 func ctEq(x, y uint) choice { 37 // If x != y, then either x - y or y - x will generate a carry. 38 _, c1 := bits.Sub(x, y, 0) 39 _, c2 := bits.Sub(y, x, 0) 40 return not(choice(c1 | c2)) 41 } 42 43 // ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this 44 // function does not depend on its inputs. 45 func ctGeq(x, y uint) choice { 46 // If x < y, then x - y generates a carry. 47 _, carry := bits.Sub(x, y, 0) 48 return not(choice(carry)) 49 } 50 51 // Nat represents an arbitrary natural number 52 // 53 // Each Nat has an announced length, which is the number of limbs it has stored. 54 // Operations on this number are allowed to leak this length, but will not leak 55 // any information about the values contained in those limbs. 56 type Nat struct { 57 // limbs is little-endian in base 2^W with W = bits.UintSize. 58 limbs []uint 59 } 60 61 // preallocTarget is the size in bits of the numbers used to implement the most 62 // common and most performant RSA key size. It's also enough to cover some of 63 // the operations of key sizes up to 4096. 64 const preallocTarget = 2048 65 const preallocLimbs = (preallocTarget + _W - 1) / _W 66 67 // NewNat returns a new nat with a size of zero, just like new(Nat), but with 68 // the preallocated capacity to hold a number of up to preallocTarget bits. 69 // NewNat inlines, so the allocation can live on the stack. 70 func NewNat() *Nat { 71 limbs := make([]uint, 0, preallocLimbs) 72 return &Nat{limbs} 73 } 74 75 // expand expands x to n limbs, leaving its value unchanged. 76 func (x *Nat) expand(n int) *Nat { 77 if len(x.limbs) > n { 78 panic("bigmod: internal error: shrinking nat") 79 } 80 if cap(x.limbs) < n { 81 newLimbs := make([]uint, n) 82 copy(newLimbs, x.limbs) 83 x.limbs = newLimbs 84 return x 85 } 86 extraLimbs := x.limbs[len(x.limbs):n] 87 for i := range extraLimbs { 88 extraLimbs[i] = 0 89 } 90 x.limbs = x.limbs[:n] 91 return x 92 } 93 94 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). 95 func (x *Nat) reset(n int) *Nat { 96 if cap(x.limbs) < n { 97 x.limbs = make([]uint, n) 98 return x 99 } 100 for i := range x.limbs { 101 x.limbs[i] = 0 102 } 103 x.limbs = x.limbs[:n] 104 return x 105 } 106 107 // set assigns x = y, optionally resizing x to the appropriate size. 108 func (x *Nat) Set(y *Nat) *Nat { 109 x.reset(len(y.limbs)) 110 copy(x.limbs, y.limbs) 111 return x 112 } 113 114 // SetBig assigns x = n, optionally resizing n to the appropriate size. 115 // 116 // The announced length of x is set based on the actual bit size of the input, 117 // ignoring leading zeroes. 118 func (x *Nat) SetBig(n *big.Int) *Nat { 119 limbs := n.Bits() 120 x.reset(len(limbs)) 121 for i := range limbs { 122 x.limbs[i] = uint(limbs[i]) 123 } 124 return x 125 } 126 127 // Bytes returns x as a zero-extended big-endian byte slice. The size of the 128 // slice will match the size of m. 129 // 130 // x must have the same size as m and it must be reduced modulo m. 131 func (x *Nat) Bytes(m *Modulus) []byte { 132 i := m.Size() 133 bytes := make([]byte, i) 134 for _, limb := range x.limbs { 135 for j := 0; j < _S; j++ { 136 i-- 137 if i < 0 { 138 if limb == 0 { 139 break 140 } 141 panic("bigmod: modulus is smaller than nat") 142 } 143 bytes[i] = byte(limb) 144 limb >>= 8 145 } 146 } 147 return bytes 148 } 149 150 // SetBytes assigns x = b, where b is a slice of big-endian bytes. 151 // SetBytes returns an error if b >= m. 152 // 153 // The output will be resized to the size of m and overwritten. 154 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { 155 if err := x.setBytes(b, m); err != nil { 156 return nil, err 157 } 158 if x.CmpGeq(m.nat) == yes { 159 return nil, errors.New("input overflows the modulus") 160 } 161 return x, nil 162 } 163 164 // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. 165 // SetOverflowingBytes returns an error if b has a longer bit length than m, but 166 // reduces overflowing values up to 2^⌈log2(m)⌉ - 1. 167 // 168 // The output will be resized to the size of m and overwritten. 169 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { 170 if err := x.setBytes(b, m); err != nil { 171 return nil, err 172 } 173 leading := _W - bitLen(x.limbs[len(x.limbs)-1]) 174 if leading < m.leading { 175 return nil, errors.New("input overflows the modulus size") 176 } 177 x.maybeSubtractModulus(no, m) 178 return x, nil 179 } 180 181 // bigEndianUint returns the contents of buf interpreted as a 182 // big-endian encoded uint value. 183 func bigEndianUint(buf []byte) uint { 184 if _W == 64 { 185 return uint(binary.BigEndian.Uint64(buf)) 186 } 187 return uint(binary.BigEndian.Uint32(buf)) 188 } 189 190 func (x *Nat) setBytes(b []byte, m *Modulus) error { 191 x.resetFor(m) 192 i, k := len(b), 0 193 for k < len(x.limbs) && i >= _S { 194 x.limbs[k] = bigEndianUint(b[i-_S : i]) 195 i -= _S 196 k++ 197 } 198 for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { 199 x.limbs[k] |= uint(b[i-1]) << s 200 i-- 201 } 202 if i > 0 { 203 return errors.New("input overflows the modulus size") 204 } 205 return nil 206 } 207 208 // Equal returns 1 if x == y, and 0 otherwise. 209 // 210 // Both operands must have the same announced length. 211 func (x *Nat) Equal(y *Nat) choice { 212 // Eliminate bounds checks in the loop. 213 size := len(x.limbs) 214 xLimbs := x.limbs[:size] 215 yLimbs := y.limbs[:size] 216 217 equal := yes 218 for i := 0; i < size; i++ { 219 equal &= ctEq(xLimbs[i], yLimbs[i]) 220 } 221 return equal 222 } 223 224 // IsZero returns 1 if x == 0, and 0 otherwise. 225 func (x *Nat) IsZero() choice { 226 // Eliminate bounds checks in the loop. 227 size := len(x.limbs) 228 xLimbs := x.limbs[:size] 229 230 zero := yes 231 for i := 0; i < size; i++ { 232 zero &= ctEq(xLimbs[i], 0) 233 } 234 return zero 235 } 236 237 // CmpGeq returns 1 if x >= y, and 0 otherwise. 238 // 239 // Both operands must have the same announced length. 240 func (x *Nat) CmpGeq(y *Nat) choice { 241 // Eliminate bounds checks in the loop. 242 size := len(x.limbs) 243 xLimbs := x.limbs[:size] 244 yLimbs := y.limbs[:size] 245 246 var c uint 247 for i := 0; i < size; i++ { 248 _, c = bits.Sub(xLimbs[i], yLimbs[i], c) 249 } 250 // If there was a carry, then subtracting y underflowed, so 251 // x is not greater than or equal to y. 252 return not(choice(c)) 253 } 254 255 // assign sets x <- y if on == 1, and does nothing otherwise. 256 // 257 // Both operands must have the same announced length. 258 func (x *Nat) assign(on choice, y *Nat) *Nat { 259 // Eliminate bounds checks in the loop. 260 size := len(x.limbs) 261 xLimbs := x.limbs[:size] 262 yLimbs := y.limbs[:size] 263 264 mask := ctMask(on) 265 for i := 0; i < size; i++ { 266 xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) 267 } 268 return x 269 } 270 271 // add computes x += y and returns the carry. 272 // 273 // Both operands must have the same announced length. 274 func (x *Nat) add(y *Nat) (c uint) { 275 // Eliminate bounds checks in the loop. 276 size := len(x.limbs) 277 xLimbs := x.limbs[:size] 278 yLimbs := y.limbs[:size] 279 280 for i := 0; i < size; i++ { 281 xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) 282 } 283 return 284 } 285 286 // sub computes x -= y. It returns the borrow of the subtraction. 287 // 288 // Both operands must have the same announced length. 289 func (x *Nat) sub(y *Nat) (c uint) { 290 // Eliminate bounds checks in the loop. 291 size := len(x.limbs) 292 xLimbs := x.limbs[:size] 293 yLimbs := y.limbs[:size] 294 295 for i := 0; i < size; i++ { 296 xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) 297 } 298 return 299 } 300 301 // Modulus is used for modular arithmetic, precomputing relevant constants. 302 // 303 // Moduli are assumed to be odd numbers. Moduli can also leak the exact 304 // number of bits needed to store their value, and are stored without padding. 305 // 306 // Their actual value is still kept secret. 307 type Modulus struct { 308 // The underlying natural number for this modulus. 309 // 310 // This will be stored without any padding, and shouldn't alias with any 311 // other natural number being used. 312 nat *Nat 313 leading int // number of leading zeros in the modulus 314 m0inv uint // -nat.limbs[0]⁻¹ mod _W 315 rr *Nat // R*R for montgomeryRepresentation 316 } 317 318 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). 319 func rr(m *Modulus) *Nat { 320 rr := NewNat().ExpandFor(m) 321 n := uint(len(rr.limbs)) 322 mLen := uint(m.BitLen()) 323 logR := _W * n 324 325 // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to 326 // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce. 327 rr.limbs[n-1] = 1 << ((mLen - 1) % _W) 328 // Then we double until we reach 2^(_W * n). 329 for i := mLen - 1; i < logR; i++ { 330 rr.Add(rr, m) 331 } 332 333 // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in 334 // the Montgomery domain, meaning we can use Montgomery multiplication now). 335 // We could do that by doubling _W * n times, or with a square-and-double 336 // chain log2(_W * n) long. Turns out the fastest thing is to start out with 337 // doublings, and switch to square-and-double once the exponent is large 338 // enough to justify the cost of the multiplications. 339 340 // The threshold is selected experimentally as a linear function of n. 341 threshold := n / 4 342 343 // We calculate how many of the most-significant bits of the exponent we can 344 // compute before crossing the threshold, and we do it with doublings. 345 i := bits.UintSize 346 for logR>>i <= threshold { 347 i-- 348 } 349 for k := uint(0); k < logR>>i; k++ { 350 rr.Add(rr, m) 351 } 352 353 // Then we process the remaining bits of the exponent with a 354 // square-and-double chain. 355 for i > 0 { 356 rr.montgomeryMul(rr, rr, m) 357 i-- 358 if logR>>i&1 != 0 { 359 rr.Add(rr, m) 360 } 361 } 362 363 return rr 364 } 365 366 // minusInverseModW computes -x⁻¹ mod _W with x odd. 367 // 368 // This operation is used to precompute a constant involved in Montgomery 369 // multiplication. 370 func minusInverseModW(x uint) uint { 371 // Every iteration of this loop doubles the least-significant bits of 372 // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, 373 // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough 374 // for 64 bits (and wastes only one iteration for 32 bits). 375 // 376 // See https://crypto.stackexchange.com/a/47496. 377 y := x 378 for i := 0; i < 5; i++ { 379 y = y * (2 - x*y) 380 } 381 return -y 382 } 383 384 // NewModulusFromBig creates a new Modulus from a [big.Int]. 385 // 386 // The Int must be odd. The number of significant bits (and nothing else) is 387 // leaked through timing side-channels. 388 func NewModulusFromBig(n *big.Int) (*Modulus, error) { 389 if b := n.Bits(); len(b) == 0 { 390 return nil, errors.New("modulus must be >= 0") 391 } else if b[0]&1 != 1 { 392 return nil, errors.New("modulus must be odd") 393 } 394 m := &Modulus{} 395 m.nat = NewNat().SetBig(n) 396 m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) 397 m.m0inv = minusInverseModW(m.nat.limbs[0]) 398 m.rr = rr(m) 399 return m, nil 400 } 401 402 // bitLen is a version of bits.Len that only leaks the bit length of n, but not 403 // its value. bits.Len and bits.LeadingZeros use a lookup table for the 404 // low-order bits on some architectures. 405 func bitLen(n uint) int { 406 var len int 407 // We assume, here and elsewhere, that comparison to zero is constant time 408 // with respect to different non-zero values. 409 for n != 0 { 410 len++ 411 n >>= 1 412 } 413 return len 414 } 415 416 // Size returns the size of m in bytes. 417 func (m *Modulus) Size() int { 418 return (m.BitLen() + 7) / 8 419 } 420 421 // BitLen returns the size of m in bits. 422 func (m *Modulus) BitLen() int { 423 return len(m.nat.limbs)*_W - int(m.leading) 424 } 425 426 // Nat returns m as a Nat. The return value must not be written to. 427 func (m *Modulus) Nat() *Nat { 428 return m.nat 429 } 430 431 // shiftIn calculates x = x << _W + y mod m. 432 // 433 // This assumes that x is already reduced mod m, and that y < 2^_W. 434 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { 435 return x.shiftInNat(y, m.nat) 436 } 437 438 // shiftIn calculates x = x << _W + y mod m. 439 // 440 // This assumes that x is already reduced mod m, and that y < 2^_W. 441 func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { 442 d := NewNat().reset(len(m.limbs)) 443 444 // Eliminate bounds checks in the loop. 445 size := len(m.limbs) 446 xLimbs := x.limbs[:size] 447 dLimbs := d.limbs[:size] 448 mLimbs := m.limbs[:size] 449 450 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit 451 // from y. Effectively, it left-shifts x and adds y one bit at a time, 452 // reducing it every time. 453 // 454 // To do the reduction, each iteration computes both 2x + b and 2x + b - m. 455 // The next iteration (and finally the return line) will use either result 456 // based on whether 2x + b overflows m. 457 needSubtraction := no 458 for i := _W - 1; i >= 0; i-- { 459 carry := (y >> i) & 1 460 var borrow uint 461 mask := ctMask(needSubtraction) 462 for i := 0; i < size; i++ { 463 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) 464 xLimbs[i], carry = bits.Add(l, l, carry) 465 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) 466 } 467 // Like in maybeSubtractModulus, we need the subtraction if either it 468 // didn't underflow (meaning 2x + b > m) or if computing 2x + b 469 // overflowed (meaning 2x + b > 2^_W*n > m). 470 needSubtraction = not(choice(borrow)) | choice(carry) 471 } 472 return x.assign(needSubtraction, d) 473 } 474 475 // Mod calculates out = x mod m. 476 // 477 // This works regardless how large the value of x is. 478 // 479 // The output will be resized to the size of m and overwritten. 480 func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { 481 return out.ModNat(x, m.nat) 482 } 483 484 // Mod calculates out = x mod m. 485 // 486 // This works regardless how large the value of x is. 487 // 488 // The output will be resized to the size of m and overwritten. 489 func (out *Nat) ModNat(x *Nat, m *Nat) *Nat { 490 out.reset(len(m.limbs)) 491 // Working our way from the most significant to the least significant limb, 492 // we can insert each limb at the least significant position, shifting all 493 // previous limbs left by _W. This way each limb will get shifted by the 494 // correct number of bits. We can insert at least N - 1 limbs without 495 // overflowing m. After that, we need to reduce every time we shift. 496 i := len(x.limbs) - 1 497 // For the first N - 1 limbs we can skip the actual shifting and position 498 // them at the shifted position, which starts at min(N - 2, i). 499 start := len(m.limbs) - 2 500 if i < start { 501 start = i 502 } 503 for j := start; j >= 0; j-- { 504 out.limbs[j] = x.limbs[i] 505 i-- 506 } 507 // We shift in the remaining limbs, reducing modulo m each time. 508 for i >= 0 { 509 out.shiftInNat(x.limbs[i], m) 510 i-- 511 } 512 return out 513 } 514 515 // ExpandFor ensures out has the right size to work with operations modulo m. 516 // 517 // The announced size of out must be smaller than or equal to that of m. 518 func (out *Nat) ExpandFor(m *Modulus) *Nat { 519 return out.expand(len(m.nat.limbs)) 520 } 521 522 // resetFor ensures out has the right size to work with operations modulo m. 523 // 524 // out is zeroed and may start at any size. 525 func (out *Nat) resetFor(m *Modulus) *Nat { 526 return out.reset(len(m.nat.limbs)) 527 } 528 529 // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. 530 // 531 // It can be used to reduce modulo m a value up to 2m - 1, which is a common 532 // range for results computed by higher level operations. 533 // 534 // always is usually a carry that indicates that the operation that produced x 535 // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. 536 // 537 // x and m operands must have the same announced length. 538 func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { 539 t := NewNat().Set(x) 540 underflow := t.sub(m.nat) 541 // We keep the result if x - m didn't underflow (meaning x >= m) 542 // or if always was set. 543 keep := not(choice(underflow)) | choice(always) 544 x.assign(keep, t) 545 } 546 547 // Sub computes x = x - y mod m. 548 // 549 // The length of both operands must be the same as the modulus. Both operands 550 // must already be reduced modulo m. 551 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { 552 underflow := x.sub(y) 553 // If the subtraction underflowed, add m. 554 t := NewNat().Set(x) 555 t.add(m.nat) 556 x.assign(choice(underflow), t) 557 return x 558 } 559 560 // Add computes x = x + y mod m. 561 // 562 // The length of both operands must be the same as the modulus. Both operands 563 // must already be reduced modulo m. 564 func (x *Nat) Add(y *Nat, m *Modulus) *Nat { 565 overflow := x.add(y) 566 x.maybeSubtractModulus(choice(overflow), m) 567 return x 568 } 569 570 // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and 571 // n = len(m.nat.limbs). 572 // 573 // Faster Montgomery multiplication replaces standard modular multiplication for 574 // numbers in this representation. 575 // 576 // This assumes that x is already reduced mod m. 577 func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat { 578 // A Montgomery multiplication (which computes a * b / R) by R * R works out 579 // to a multiplication by R, which takes the value out of the Montgomery domain. 580 return x.montgomeryMul(x, m.rr, m) 581 } 582 583 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and 584 // n = len(m.nat.limbs). 585 // 586 // This assumes that x is already reduced mod m. 587 func (x *Nat) montgomeryReduction(m *Modulus) *Nat { 588 // By Montgomery multiplying with 1 not in Montgomery representation, we 589 // convert out back from Montgomery representation, because it works out to 590 // dividing by R. 591 one := NewNat().ExpandFor(m) 592 one.limbs[0] = 1 593 return x.montgomeryMul(x, one, m) 594 } 595 596 // montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and 597 // n = len(m.nat.limbs), also known as a Montgomery multiplication. 598 // 599 // All inputs should be the same length and already reduced modulo m. 600 // x will be resized to the size of m and overwritten. 601 func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { 602 n := len(m.nat.limbs) 603 mLimbs := m.nat.limbs[:n] 604 aLimbs := a.limbs[:n] 605 bLimbs := b.limbs[:n] 606 607 switch n { 608 default: 609 // Attempt to use a stack-allocated backing array. 610 T := make([]uint, 0, preallocLimbs*2) 611 if cap(T) < n*2 { 612 T = make([]uint, 0, n*2) 613 } 614 T = T[:n*2] 615 616 // This loop implements Word-by-Word Montgomery Multiplication, as 617 // described in Algorithm 4 (Fig. 3) of "Efficient Software 618 // Implementations of Modular Exponentiation" by Shay Gueron 619 // [https://eprint.iacr.org/2011/239.pdf]. 620 var c uint 621 for i := 0; i < n; i++ { 622 _ = T[n+i] // bounds check elimination hint 623 624 // Step 1 (T = a × b) is computed as a large pen-and-paper column 625 // multiplication of two numbers with n base-2^_W digits. If we just 626 // wanted to produce 2n-wide T, we would do 627 // 628 // for i := 0; i < n; i++ { 629 // d := bLimbs[i] 630 // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d) 631 // } 632 // 633 // where d is a digit of the multiplier, T[i:n+i] is the shifted 634 // position of the product of that digit, and T[n+i] is the final carry. 635 // Note that T[i] isn't modified after processing the i-th digit. 636 // 637 // Instead of running two loops, one for Step 1 and one for Steps 2–6, 638 // the result of Step 1 is computed during the next loop. This is 639 // possible because each iteration only uses T[i] in Step 2 and then 640 // discards it in Step 6. 641 d := bLimbs[i] 642 c1 := addMulVVW(T[i:n+i], aLimbs, d) 643 644 // Step 6 is replaced by shifting the virtual window we operate 645 // over: T of the algorithm is T[i:] for us. That means that T1 in 646 // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv. 647 Y := T[i] * m.m0inv 648 649 // Step 4 and 5 add Y × m to T, which as mentioned above is stored 650 // at T[i:]. The two carries (from a × d and Y × m) are added up in 651 // the next word T[n+i], and the carry bit from that addition is 652 // brought forward to the next iteration. 653 c2 := addMulVVW(T[i:n+i], mLimbs, Y) 654 T[n+i], c = bits.Add(c1, c2, c) 655 } 656 657 // Finally for Step 7 we copy the final T window into x, and subtract m 658 // if necessary (which as explained in maybeSubtractModulus can be the 659 // case both if x >= m, or if x overflowed). 660 // 661 // The paper suggests in Section 4 that we can do an "Almost Montgomery 662 // Multiplication" by subtracting only in the overflow case, but the 663 // cost is very similar since the constant time subtraction tells us if 664 // x >= m as a side effect, and taking care of the broken invariant is 665 // highly undesirable (see https://go.dev/issue/13907). 666 copy(x.reset(n).limbs, T[n:]) 667 x.maybeSubtractModulus(choice(c), m) 668 669 // The following specialized cases follow the exact same algorithm, but 670 // optimized for the sizes most used in RSA. addMulVVW is implemented in 671 // assembly with loop unrolling depending on the architecture and bounds 672 // checks are removed by the compiler thanks to the constant size. 673 case 256 / _W: // optimization for 256 bits nat 674 const n = 256 / _W // compiler hint 675 T := make([]uint, n*2) 676 var c uint 677 for i := 0; i < n; i++ { 678 d := bLimbs[i] 679 c1 := addMulVVW256(&T[i], &aLimbs[0], d) 680 Y := T[i] * m.m0inv 681 c2 := addMulVVW256(&T[i], &mLimbs[0], Y) 682 T[n+i], c = bits.Add(c1, c2, c) 683 } 684 copy(x.reset(n).limbs, T[n:]) 685 x.maybeSubtractModulus(choice(c), m) 686 687 case 1024 / _W: 688 const n = 1024 / _W // compiler hint 689 T := make([]uint, n*2) 690 var c uint 691 for i := 0; i < n; i++ { 692 d := bLimbs[i] 693 c1 := addMulVVW1024(&T[i], &aLimbs[0], d) 694 Y := T[i] * m.m0inv 695 c2 := addMulVVW1024(&T[i], &mLimbs[0], Y) 696 T[n+i], c = bits.Add(c1, c2, c) 697 } 698 copy(x.reset(n).limbs, T[n:]) 699 x.maybeSubtractModulus(choice(c), m) 700 701 case 1536 / _W: 702 const n = 1536 / _W // compiler hint 703 T := make([]uint, n*2) 704 var c uint 705 for i := 0; i < n; i++ { 706 d := bLimbs[i] 707 c1 := addMulVVW1536(&T[i], &aLimbs[0], d) 708 Y := T[i] * m.m0inv 709 c2 := addMulVVW1536(&T[i], &mLimbs[0], Y) 710 T[n+i], c = bits.Add(c1, c2, c) 711 } 712 copy(x.reset(n).limbs, T[n:]) 713 x.maybeSubtractModulus(choice(c), m) 714 715 case 2048 / _W: 716 const n = 2048 / _W // compiler hint 717 T := make([]uint, n*2) 718 var c uint 719 for i := 0; i < n; i++ { 720 d := bLimbs[i] 721 c1 := addMulVVW2048(&T[i], &aLimbs[0], d) 722 Y := T[i] * m.m0inv 723 c2 := addMulVVW2048(&T[i], &mLimbs[0], Y) 724 T[n+i], c = bits.Add(c1, c2, c) 725 } 726 copy(x.reset(n).limbs, T[n:]) 727 x.maybeSubtractModulus(choice(c), m) 728 } 729 730 return x 731 } 732 733 // addMulVVW multiplies the multi-word value x by the single-word value y, 734 // adding the result to the multi-word value z and returning the final carry. 735 // It can be thought of as one row of a pen-and-paper column multiplication. 736 func addMulVVW(z, x []uint, y uint) (carry uint) { 737 _ = x[len(z)-1] // bounds check elimination hint 738 for i := range z { 739 hi, lo := bits.Mul(x[i], y) 740 lo, c := bits.Add(lo, z[i], 0) 741 // We use bits.Add with zero to get an add-with-carry instruction that 742 // absorbs the carry from the previous bits.Add. 743 hi, _ = bits.Add(hi, 0, c) 744 lo, c = bits.Add(lo, carry, 0) 745 hi, _ = bits.Add(hi, 0, c) 746 carry = hi 747 z[i] = lo 748 } 749 return carry 750 } 751 752 // Mul calculates x = x * y mod m. 753 // 754 // The length of both operands must be the same as the modulus. Both operands 755 // must already be reduced modulo m. 756 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { 757 // A Montgomery multiplication by a value out of the Montgomery domain 758 // takes the result out of Montgomery representation. 759 xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m 760 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m 761 } 762 763 // Exp calculates out = x^e mod m. 764 // 765 // The exponent e is represented in big-endian order. The output will be resized 766 // to the size of m and overwritten. x must already be reduced modulo m. 767 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { 768 // We use a 4 bit window. For our RSA workload, 4 bit windows are faster 769 // than 2 bit windows, but use an extra 12 nats worth of scratch space. 770 // Using bit sizes that don't divide 8 are more complex to implement, but 771 // are likely to be more efficient if necessary. 772 773 table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) 774 // newNat calls are unrolled so they are allocated on the stack. 775 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 776 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 777 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 778 } 779 table[0].Set(x).montgomeryRepresentation(m) 780 for i := 1; i < len(table); i++ { 781 table[i].montgomeryMul(table[i-1], table[0], m) 782 } 783 784 out.resetFor(m) 785 out.limbs[0] = 1 786 out.montgomeryRepresentation(m) 787 tmp := NewNat().ExpandFor(m) 788 for _, b := range e { 789 for _, j := range []int{4, 0} { 790 // Square four times. Optimization note: this can be implemented 791 // more efficiently than with generic Montgomery multiplication. 792 out.montgomeryMul(out, out, m) 793 out.montgomeryMul(out, out, m) 794 out.montgomeryMul(out, out, m) 795 out.montgomeryMul(out, out, m) 796 797 // Select x^k in constant time from the table. 798 k := uint((b >> j) & 0b1111) 799 for i := range table { 800 tmp.assign(ctEq(k, uint(i+1)), table[i]) 801 } 802 803 // Multiply by x^k, discarding the result if k = 0. 804 tmp.montgomeryMul(out, tmp, m) 805 out.assign(not(ctEq(k, 0)), tmp) 806 } 807 } 808 809 return out.montgomeryReduction(m) 810 } 811 812 // ExpShortVarTime calculates out = x^e mod m. 813 // 814 // The output will be resized to the size of m and overwritten. x must already 815 // be reduced modulo m. This leaks the exponent through timing side-channels. 816 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { 817 // For short exponents, precomputing a table and using a window like in Exp 818 // doesn't pay off. Instead, we do a simple conditional square-and-multiply 819 // chain, skipping the initial run of zeroes. 820 xR := NewNat().Set(x).montgomeryRepresentation(m) 821 out.Set(xR) 822 for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ { 823 out.montgomeryMul(out, out, m) 824 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { 825 out.montgomeryMul(out, xR, m) 826 } 827 } 828 return out.montgomeryReduction(m) 829 }