github.com/ledgerwatch/erigon-lib@v1.0.0/pedersen_hash/big_int.inl (about)

     1  #include <iomanip>
     2  #include <ios>
     3  #include <limits>
     4  #include <sstream>
     5  #include <tuple>
     6  
     7  #include "math.h"
     8  
     9  namespace starkware {
    10  
    11  template <size_t N>
    12  BigInt<N> BigInt<N>::RandomBigInt(Prng* prng) {
    13    std::array<uint64_t, N> value{};
    14    for (size_t i = 0; i < N; ++i) {
    15      gsl::at(value, i) = prng->RandomUint64();
    16    }
    17    return BigInt(value);
    18  }
    19  
    20  template <size_t N>
    21  template <size_t K>
    22  constexpr BigInt<N>::BigInt(const BigInt<K>& v) noexcept : value_{} {
    23    static_assert(N > K, "trimming is not supported");
    24    for (size_t i = 0; i < K; ++i) {
    25      gsl::at(value_, i) = v[i];
    26    }
    27  
    28    for (size_t i = K; i < N; ++i) {
    29      gsl::at(value_, i) = 0;
    30    }
    31  }
    32  
    33  template <size_t N>
    34  constexpr std::pair<BigInt<N>, bool> BigInt<N>::Add(const BigInt& a, const BigInt& b) {
    35    bool carry{};
    36    BigInt r{0};
    37  
    38    for (size_t i = 0; i < N; ++i) {
    39      __uint128_t res = static_cast<__uint128_t>(a[i]) + b[i] + carry;
    40      carry = (res >> 64) != static_cast<__uint128_t>(0);
    41      r[i] = static_cast<uint64_t>(res);
    42    }
    43  
    44    return std::make_pair(r, carry);
    45  }
    46  
    47  template <size_t N>
    48  constexpr BigInt<2 * N> BigInt<N>::operator*(const BigInt<N>& other) const {
    49    constexpr auto kResSize = 2 * N;
    50    BigInt<kResSize> final_res = BigInt<kResSize>::Zero();
    51    // Multiply this by other using long multiplication algorithm.
    52    for (size_t i = 0; i < N; ++i) {
    53      uint64_t carry = static_cast<uint64_t>(0U);
    54      for (size_t j = 0; j < N; ++j) {
    55        // For M == UINT64_MAX, we have: a*b+c+d <= M*M + 2M = (M+1)^2 - 1 ==
    56        // UINT128_MAX. So we can do a multiplication and an addition without an
    57        // overflow.
    58        __uint128_t res = Umul128((*this)[j], other[i]) + final_res[i + j] + carry;
    59        carry = gsl::narrow_cast<uint64_t>(res >> 64);
    60        final_res[i + j] = gsl::narrow_cast<uint64_t>(res);
    61      }
    62      final_res[i + N] = static_cast<uint64_t>(carry);
    63    }
    64    return final_res;
    65  }
    66  
    67  template <size_t N>
    68  BigInt<N> BigInt<N>::MulMod(const BigInt& a, const BigInt& b, const BigInt& modulus) {
    69    const BigInt<2 * N> mul_res = a * b;
    70    const BigInt<2 * N> mul_res_mod = mul_res.Div(BigInt<2 * N>(modulus)).second;
    71  
    72    BigInt<N> res = Zero();
    73  
    74    // Trim mul_res_mod to the N lower limbs (this is possible since it must be smaller than modulus).
    75    for (size_t i = 0; i < N; ++i) {
    76      res[i] = mul_res_mod[i];
    77    }
    78  
    79    return res;
    80  }
    81  
    82  template <size_t N>
    83  BigInt<N> BigInt<N>::InvModPrime(const BigInt& prime) const {
    84    ASSERT(*this != BigInt::Zero(), "Inverse of 0 is not defined.");
    85    return GenericPow(
    86        *this, (prime - BigInt(2)).ToBoolVector(), BigInt::One(),
    87        [&prime](const BigInt& multiplier, BigInt* dst) { *dst = MulMod(*dst, multiplier, prime); });
    88  }
    89  
    90  template <size_t N>
    91  constexpr std::pair<BigInt<N>, bool> BigInt<N>::Sub(const BigInt& a, const BigInt& b) {
    92    bool carry{};
    93    BigInt r{};
    94  
    95    for (size_t i = 0; i < N; ++i) {
    96      __uint128_t res = static_cast<__uint128_t>(a[i]) - b[i] - carry;
    97      carry = (res >> 127) != static_cast<__uint128_t>(0);
    98      r[i] = static_cast<uint64_t>(res);
    99    }
   100  
   101    return std::make_pair(r, carry);
   102  }
   103  
   104  template <size_t N>
   105  constexpr bool BigInt<N>::operator<(const BigInt& b) const {
   106    return Sub(*this, b).second;
   107  }
   108  
   109  template <size_t N>
   110  std::pair<BigInt<N>, BigInt<N>> BigInt<N>::Div(const BigInt& divisor) const {
   111    // This is a simple long-division implementation. It is not very efficient and can be improved
   112    // if this function becomes a bottleneck.
   113    ASSERT(divisor != BigInt::Zero(), "Divisor must not be zero.");
   114  
   115    bool carry{};
   116    BigInt res{};
   117    BigInt shifted_divisor{}, tmp{};
   118    BigInt a = *this;
   119  
   120    while (a >= divisor) {
   121      tmp = divisor;
   122      int shift = -1;
   123      do {
   124        shifted_divisor = tmp;
   125        shift++;
   126        std::tie(tmp, carry) = Add(shifted_divisor, shifted_divisor);
   127      } while (!carry && tmp <= a);
   128  
   129      a = Sub(a, shifted_divisor).first;
   130      res[shift / 64] |= Pow2(shift % 64);
   131    }
   132  
   133    return {res, a};
   134  }
   135  
   136  template <size_t N>
   137  std::string BigInt<N>::ToString() const {
   138    std::ostringstream res;
   139    res << "0x";
   140    for (int i = N - 1; i >= 0; --i) {
   141      res << std::setfill('0') << std::setw(16) << std::hex << (*this)[i];
   142    }
   143    return res.str();
   144  }
   145  
   146  template <size_t N>
   147  std::vector<bool> BigInt<N>::ToBoolVector() const {
   148    std::vector<bool> res;
   149    for (uint64_t value : value_) {
   150      for (int i = 0; i < std::numeric_limits<uint64_t>::digits; ++i) {
   151        res.push_back((value & 1) != 0);
   152        value >>= 1;
   153      }
   154    }
   155    return res;
   156  }
   157  
   158  template <size_t N>
   159  constexpr bool BigInt<N>::operator==(const BigInt<N>& other) const {
   160    for (size_t i = 0; i < N; ++i) {
   161      if (gsl::at(value_, i) != gsl::at(other.value_, i)) {
   162        return false;
   163      }
   164    }
   165    return true;
   166  }
   167  
   168  template <size_t N>
   169  constexpr BigInt<N> BigInt<N>::ReduceIfNeeded(const BigInt<N>& x, const BigInt<N>& target) {
   170    ASSERT(target.NumLeadingZeros() > 0, "target must have at least one leading zero.");
   171    return (x >= target) ? x - target : x;
   172  }
   173  
/*
  Word-by-word Montgomery multiplication of x and y modulo `modulus` (interleaved multiply and
  reduce, one 64-bit limb of x per outer iteration).

  NOTE(review): each outer iteration shifts the accumulator down by one limb and drops res[0]. That
  is exact only if res[0] became zero after adding u_i * modulus, which requires
  montgomery_mprime == -modulus^{-1} mod 2^64. This is not checked here. Under that assumption
  the result is x * y * 2^(-64N) mod modulus.

  Preconditions (ASSERTed): the modulus has a free top bit, and x and y are both below it. The final
  ReduceIfNeeded performs one conditional subtraction of `modulus`.
  Umul128 is declared elsewhere (presumably math.h) and is used as a 64x64 -> 128-bit product.
*/
template <size_t N>
constexpr BigInt<N> BigInt<N>::MontMul(
    const BigInt& x, const BigInt& y, const BigInt& modulus, uint64_t montgomery_mprime) {
  BigInt<N> res{};
  ASSERT(modulus.NumLeadingZeros() > 0, "We require at least one leading zero in the modulus");
  ASSERT(y < modulus, "y is supposed to be smaller then the modulus");
  ASSERT(x < modulus, "x is supposed to be smaller then the modulus.");
  for (size_t i = 0; i < N; ++i) {
    // u_i is chosen from the low limb of (res + x[i] * y) so that adding u_i * modulus clears that
    // limb (see the montgomery_mprime note above).
    __uint128_t temp = Umul128(x[i], y[0]) + res[0];
    uint64_t u_i = gsl::narrow_cast<uint64_t>(temp) * montgomery_mprime;
    // carry1 carries the x[i] * y chain; carry2 carries the u_i * modulus chain.
    uint64_t carry1 = 0, carry2 = 0;

    for (size_t j = 0; j < N; ++j) {
      // For j == 0, temp still holds x[i] * y[0] + res[0] from above.
      if (j != 0) {
        temp = Umul128(x[i], y[j]) + res[j];
      }
      // low = low64(temp) + carry1; (low < carry1) detects the 64-bit wrap of that addition.
      uint64_t low = carry1 + gsl::narrow_cast<uint64_t>(temp);
      carry1 = gsl::narrow_cast<uint64_t>(temp >> 64) + static_cast<uint64_t>(low < carry1);
      // Add modulus[j] * u_i (plus its running carry) on top of low.
      temp = Umul128(modulus[j], u_i) + carry2;
      res[j] = low + gsl::narrow_cast<uint64_t>(temp);
      carry2 = gsl::narrow_cast<uint64_t>(temp >> 64) + static_cast<uint64_t>(res[j] < low);
    }
    // Divide the accumulator by 2^64: drop res[0] and move every limb down by one.
    for (size_t j = 0; j < N - 1; ++j) {
      res[j] = res[j + 1];
    }
    // Both carry chains land in the new top limb; the modulus leading-zero precondition is what
    // is relied on for this sum not to wrap.
    res[N - 1] = carry1 + carry2;
    ASSERT(res[N - 1] >= carry1, "There shouldn't be a carry here.");
  }
  return ReduceIfNeeded(res, modulus);
}
   204  
   205  template <size_t N>
   206  constexpr size_t BigInt<N>::NumLeadingZeros() const {
   207    int i = value_.size() - 1;
   208    size_t res = 0;
   209  
   210    while (i >= 0 && (gsl::at(value_, i) == 0)) {
   211      i--;
   212      res += std::numeric_limits<uint64_t>::digits;
   213    }
   214  
   215    if (i >= 0) {
   216      res += __builtin_clzll(gsl::at(value_, i));
   217    }
   218  
   219    return res;
   220  }
   221  
   222  template <size_t N>
   223  std::ostream& operator<<(std::ostream& os, const BigInt<N>& bigint) {
   224    return os << bigint.ToString();
   225  }
   226  
   227  namespace bigint {
   228  namespace details {
   229  /*
   230    Converts an hex digit ASCII char to the corresponding int.
   231    Assumes the input is an hex digit.
   232  */
   233  inline constexpr uint64_t HexCharToUint64(char c) {
   234    if ('0' <= c && c <= '9') {
   235      return c - '0';
   236    }
   237  
   238    if ('A' <= c && c <= 'F') {
   239      return c - 'A' + 10;
   240    }
   241  
   242    // The function assumes that the input is an hex digit, so we can assume 'a'
   243    // <= c && c <= 'f' here.
   244    return c - 'a' + 10;
   245  }
   246  
   247  template <char... Chars>
   248  constexpr auto HexCharArrayToBigInt() {
   249    constexpr size_t kLen = sizeof...(Chars);
   250    constexpr std::array<char, kLen> kDigits{Chars...};
   251    static_assert(kDigits[0] == '0' && kDigits[1] == 'x', "Only hex input is currently supported");
   252  
   253    constexpr size_t kNibblesPerUint64 = 2 * sizeof(uint64_t);
   254    constexpr size_t kResLen = (kLen - 2 + kNibblesPerUint64 - 1) / (kNibblesPerUint64);
   255    std::array<uint64_t, kResLen> res{};
   256  
   257    for (size_t i = 0; i < kDigits.size() - 2; ++i) {
   258      const size_t limb = i / kNibblesPerUint64;
   259      const size_t nibble_offset = i % kNibblesPerUint64;
   260      const uint64_t nibble = HexCharToUint64(gsl::at(kDigits, kDigits.size() - i - 1));
   261  
   262      gsl::at(res, limb) |= nibble << (4 * nibble_offset);
   263    }
   264  
   265    return BigInt<res.size()>(res);
   266  }
   267  }  // namespace details
   268  }  // namespace bigint
   269  
   270  template <char... Chars>
   271  static constexpr auto operator"" _Z() {
   272    // This function is implemented as wrapper that calls the actual
   273    // implementation and stores it in a constexpr variable as we want to force
   274    // the evaluation to be done in compile time. We need to have the function
   275    // call because "constexpr auto kRes = BigInt<res.size()>(res);" won't work
   276    // unless res is constexpr.
   277  
   278    // Note that the compiler allows HEX and decimal literals but in any case
   279    // it enforces that Chars... contains only HEX (or decimal) characters.
   280    constexpr auto kRes = bigint::details::HexCharArrayToBigInt<Chars...>();
   281    return kRes;
   282  }
   283  
   284  }  // namespace starkware