github.com/ledgerwatch/erigon-lib@v1.0.0/pedersen_hash/big_int.inl (about) 1 #include <iomanip> 2 #include <ios> 3 #include <limits> 4 #include <sstream> 5 #include <tuple> 6 7 #include "math.h" 8 9 namespace starkware { 10 11 template <size_t N> 12 BigInt<N> BigInt<N>::RandomBigInt(Prng* prng) { 13 std::array<uint64_t, N> value{}; 14 for (size_t i = 0; i < N; ++i) { 15 gsl::at(value, i) = prng->RandomUint64(); 16 } 17 return BigInt(value); 18 } 19 20 template <size_t N> 21 template <size_t K> 22 constexpr BigInt<N>::BigInt(const BigInt<K>& v) noexcept : value_{} { 23 static_assert(N > K, "trimming is not supported"); 24 for (size_t i = 0; i < K; ++i) { 25 gsl::at(value_, i) = v[i]; 26 } 27 28 for (size_t i = K; i < N; ++i) { 29 gsl::at(value_, i) = 0; 30 } 31 } 32 33 template <size_t N> 34 constexpr std::pair<BigInt<N>, bool> BigInt<N>::Add(const BigInt& a, const BigInt& b) { 35 bool carry{}; 36 BigInt r{0}; 37 38 for (size_t i = 0; i < N; ++i) { 39 __uint128_t res = static_cast<__uint128_t>(a[i]) + b[i] + carry; 40 carry = (res >> 64) != static_cast<__uint128_t>(0); 41 r[i] = static_cast<uint64_t>(res); 42 } 43 44 return std::make_pair(r, carry); 45 } 46 47 template <size_t N> 48 constexpr BigInt<2 * N> BigInt<N>::operator*(const BigInt<N>& other) const { 49 constexpr auto kResSize = 2 * N; 50 BigInt<kResSize> final_res = BigInt<kResSize>::Zero(); 51 // Multiply this by other using long multiplication algorithm. 52 for (size_t i = 0; i < N; ++i) { 53 uint64_t carry = static_cast<uint64_t>(0U); 54 for (size_t j = 0; j < N; ++j) { 55 // For M == UINT64_MAX, we have: a*b+c+d <= M*M + 2M = (M+1)^2 - 1 == 56 // UINT128_MAX. So we can do a multiplication and an addition without an 57 // overflow. 58 __uint128_t res = Umul128((*this)[j], other[i]) + final_res[i + j] + carry; 59 carry = gsl::narrow_cast<uint64_t>(res >> 64); 60 final_res[i + j] = gsl::narrow_cast<uint64_t>(res); 61 } 62 final_res[i + N] = static_cast<uint64_t>(carry); 63 } 64 return final_res; 65 } 66 67 template <size_t N> 68 BigInt<N> BigInt<N>::MulMod(const BigInt& a, const BigInt& b, const BigInt& modulus) { 69 const BigInt<2 * N> mul_res = a * b; 70 const BigInt<2 * N> mul_res_mod = mul_res.Div(BigInt<2 * N>(modulus)).second; 71 72 BigInt<N> res = Zero(); 73 74 // Trim mul_res_mod to the N lower limbs (this is possible since it must be smaller than modulus). 75 for (size_t i = 0; i < N; ++i) { 76 res[i] = mul_res_mod[i]; 77 } 78 79 return res; 80 } 81 82 template <size_t N> 83 BigInt<N> BigInt<N>::InvModPrime(const BigInt& prime) const { 84 ASSERT(*this != BigInt::Zero(), "Inverse of 0 is not defined."); 85 return GenericPow( 86 *this, (prime - BigInt(2)).ToBoolVector(), BigInt::One(), 87 [&prime](const BigInt& multiplier, BigInt* dst) { *dst = MulMod(*dst, multiplier, prime); }); 88 } 89 90 template <size_t N> 91 constexpr std::pair<BigInt<N>, bool> BigInt<N>::Sub(const BigInt& a, const BigInt& b) { 92 bool carry{}; 93 BigInt r{}; 94 95 for (size_t i = 0; i < N; ++i) { 96 __uint128_t res = static_cast<__uint128_t>(a[i]) - b[i] - carry; 97 carry = (res >> 127) != static_cast<__uint128_t>(0); 98 r[i] = static_cast<uint64_t>(res); 99 } 100 101 return std::make_pair(r, carry); 102 } 103 104 template <size_t N> 105 constexpr bool BigInt<N>::operator<(const BigInt& b) const { 106 return Sub(*this, b).second; 107 } 108 109 template <size_t N> 110 std::pair<BigInt<N>, BigInt<N>> BigInt<N>::Div(const BigInt& divisor) const { 111 // This is a simple long-division implementation. It is not very efficient and can be improved 112 // if this function becomes a bottleneck. 113 ASSERT(divisor != BigInt::Zero(), "Divisor must not be zero."); 114 115 bool carry{}; 116 BigInt res{}; 117 BigInt shifted_divisor{}, tmp{}; 118 BigInt a = *this; 119 120 while (a >= divisor) { 121 tmp = divisor; 122 int shift = -1; 123 do { 124 shifted_divisor = tmp; 125 shift++; 126 std::tie(tmp, carry) = Add(shifted_divisor, shifted_divisor); 127 } while (!carry && tmp <= a); 128 129 a = Sub(a, shifted_divisor).first; 130 res[shift / 64] |= Pow2(shift % 64); 131 } 132 133 return {res, a}; 134 } 135 136 template <size_t N> 137 std::string BigInt<N>::ToString() const { 138 std::ostringstream res; 139 res << "0x"; 140 for (int i = N - 1; i >= 0; --i) { 141 res << std::setfill('0') << std::setw(16) << std::hex << (*this)[i]; 142 } 143 return res.str(); 144 } 145 146 template <size_t N> 147 std::vector<bool> BigInt<N>::ToBoolVector() const { 148 std::vector<bool> res; 149 for (uint64_t value : value_) { 150 for (int i = 0; i < std::numeric_limits<uint64_t>::digits; ++i) { 151 res.push_back((value & 1) != 0); 152 value >>= 1; 153 } 154 } 155 return res; 156 } 157 158 template <size_t N> 159 constexpr bool BigInt<N>::operator==(const BigInt<N>& other) const { 160 for (size_t i = 0; i < N; ++i) { 161 if (gsl::at(value_, i) != gsl::at(other.value_, i)) { 162 return false; 163 } 164 } 165 return true; 166 } 167 168 template <size_t N> 169 constexpr BigInt<N> BigInt<N>::ReduceIfNeeded(const BigInt<N>& x, const BigInt<N>& target) { 170 ASSERT(target.NumLeadingZeros() > 0, "target must have at least one leading zero."); 171 return (x >= target) ? x - target : x; 172 } 173 174 template <size_t N> 175 constexpr BigInt<N> BigInt<N>::MontMul( 176 const BigInt& x, const BigInt& y, const BigInt& modulus, uint64_t montgomery_mprime) { 177 BigInt<N> res{}; 178 ASSERT(modulus.NumLeadingZeros() > 0, "We require at least one leading zero in the modulus"); 179 ASSERT(y < modulus, "y is supposed to be smaller then the modulus"); 180 ASSERT(x < modulus, "x is supposed to be smaller then the modulus."); 181 for (size_t i = 0; i < N; ++i) { 182 __uint128_t temp = Umul128(x[i], y[0]) + res[0]; 183 uint64_t u_i = gsl::narrow_cast<uint64_t>(temp) * montgomery_mprime; 184 uint64_t carry1 = 0, carry2 = 0; 185 186 for (size_t j = 0; j < N; ++j) { 187 if (j != 0) { 188 temp = Umul128(x[i], y[j]) + res[j]; 189 } 190 uint64_t low = carry1 + gsl::narrow_cast<uint64_t>(temp); 191 carry1 = gsl::narrow_cast<uint64_t>(temp >> 64) + static_cast<uint64_t>(low < carry1); 192 temp = Umul128(modulus[j], u_i) + carry2; 193 res[j] = low + gsl::narrow_cast<uint64_t>(temp); 194 carry2 = gsl::narrow_cast<uint64_t>(temp >> 64) + static_cast<uint64_t>(res[j] < low); 195 } 196 for (size_t j = 0; j < N - 1; ++j) { 197 res[j] = res[j + 1]; 198 } 199 res[N - 1] = carry1 + carry2; 200 ASSERT(res[N - 1] >= carry1, "There shouldn't be a carry here."); 201 } 202 return ReduceIfNeeded(res, modulus); 203 } 204 205 template <size_t N> 206 constexpr size_t BigInt<N>::NumLeadingZeros() const { 207 int i = value_.size() - 1; 208 size_t res = 0; 209 210 while (i >= 0 && (gsl::at(value_, i) == 0)) { 211 i--; 212 res += std::numeric_limits<uint64_t>::digits; 213 } 214 215 if (i >= 0) { 216 res += __builtin_clzll(gsl::at(value_, i)); 217 } 218 219 return res; 220 } 221 222 template <size_t N> 223 std::ostream& operator<<(std::ostream& os, const BigInt<N>& bigint) { 224 return os << bigint.ToString(); 225 } 226 227 namespace bigint { 228 namespace details { 229 /* 230 Converts an hex digit ASCII char to the corresponding int. 231 Assumes the input is an hex digit. 232 */ 233 inline constexpr uint64_t HexCharToUint64(char c) { 234 if ('0' <= c && c <= '9') { 235 return c - '0'; 236 } 237 238 if ('A' <= c && c <= 'F') { 239 return c - 'A' + 10; 240 } 241 242 // The function assumes that the input is an hex digit, so we can assume 'a' 243 // <= c && c <= 'f' here. 244 return c - 'a' + 10; 245 } 246 247 template <char... Chars> 248 constexpr auto HexCharArrayToBigInt() { 249 constexpr size_t kLen = sizeof...(Chars); 250 constexpr std::array<char, kLen> kDigits{Chars...}; 251 static_assert(kDigits[0] == '0' && kDigits[1] == 'x', "Only hex input is currently supported"); 252 253 constexpr size_t kNibblesPerUint64 = 2 * sizeof(uint64_t); 254 constexpr size_t kResLen = (kLen - 2 + kNibblesPerUint64 - 1) / (kNibblesPerUint64); 255 std::array<uint64_t, kResLen> res{}; 256 257 for (size_t i = 0; i < kDigits.size() - 2; ++i) { 258 const size_t limb = i / kNibblesPerUint64; 259 const size_t nibble_offset = i % kNibblesPerUint64; 260 const uint64_t nibble = HexCharToUint64(gsl::at(kDigits, kDigits.size() - i - 1)); 261 262 gsl::at(res, limb) |= nibble << (4 * nibble_offset); 263 } 264 265 return BigInt<res.size()>(res); 266 } 267 } // namespace details 268 } // namespace bigint 269 270 template <char... Chars> 271 static constexpr auto operator"" _Z() { 272 // This function is implemented as wrapper that calls the actual 273 // implementation and stores it in a constexpr variable as we want to force 274 // the evaluation to be done in compile time. We need to have the function 275 // call because "constexpr auto kRes = BigInt<res.size()>(res);" won't work 276 // unless res is constexpr. 277 278 // Note that the compiler allows HEX and decimal literals but in any case 279 // it enforces that Chars... contains only HEX (or decimal) characters. 280 constexpr auto kRes = bigint::details::HexCharArrayToBigInt<Chars...>(); 281 return kRes; 282 } 283 284 } // namespace starkware