github.com/platonnetwork/platon-go@v0.7.6/cases/tool/win/bls_win/include/mcl/gmp_util.hpp (about) 1 #pragma once 2 /** 3 @file 4 @brief util function for gmp 5 @author MITSUNARI Shigeo(@herumi) 6 @license modified new BSD license 7 http://opensource.org/licenses/BSD-3-Clause 8 */ 9 #include <stdio.h> 10 #include <stdlib.h> 11 #include <assert.h> 12 #include <stdint.h> 13 #ifndef CYBOZU_DONT_USE_EXCEPTION 14 #include <cybozu/exception.hpp> 15 #endif 16 #include <mcl/randgen.hpp> 17 #ifdef _MSC_VER 18 #pragma warning(push) 19 #pragma warning(disable : 4616) 20 #pragma warning(disable : 4800) 21 #pragma warning(disable : 4244) 22 #pragma warning(disable : 4127) 23 #pragma warning(disable : 4512) 24 #pragma warning(disable : 4146) 25 #endif 26 #if defined(__EMSCRIPTEN__) || defined(__wasm__) 27 #define MCL_USE_VINT 28 #endif 29 #ifndef MCL_MAX_BIT_SIZE 30 #define MCL_MAX_BIT_SIZE 521 31 #endif 32 #ifdef MCL_USE_VINT 33 #include <mcl/vint.hpp> 34 typedef mcl::Vint mpz_class; 35 #else 36 #include <gmpxx.h> 37 #ifdef _MSC_VER 38 #pragma warning(pop) 39 #include <cybozu/link_mpir.hpp> 40 #endif 41 #endif 42 43 #ifndef MCL_SIZEOF_UNIT 44 #if defined(CYBOZU_OS_BIT) && (CYBOZU_OS_BIT == 32) 45 #define MCL_SIZEOF_UNIT 4 46 #else 47 #define MCL_SIZEOF_UNIT 8 48 #endif 49 #endif 50 51 namespace mcl { 52 53 namespace fp { 54 55 #if MCL_SIZEOF_UNIT == 8 56 typedef uint64_t Unit; 57 #else 58 typedef uint32_t Unit; 59 #endif 60 #define MCL_UNIT_BIT_SIZE (MCL_SIZEOF_UNIT * 8) 61 62 } // mcl::fp 63 64 namespace gmp { 65 66 typedef mpz_class ImplType; 67 68 // z = [buf[n-1]:..:buf[1]:buf[0]] 69 // eg. buf[] = {0x12345678, 0xaabbccdd}; => z = 0xaabbccdd12345678; 70 template<class T> 71 void setArray(bool *pb, mpz_class& z, const T *buf, size_t n) 72 { 73 #ifdef MCL_USE_VINT 74 z.setArray(pb, buf, n); 75 #else 76 mpz_import(z.get_mpz_t(), n, -1, sizeof(*buf), 0, 0, buf); 77 *pb = true; 78 #endif 79 } 80 /* 81 buf[0, size) = x 82 buf[size, maxSize) with zero 83 */ 84 template<class T, class U> 85 bool getArray_(T *buf, size_t maxSize, const U *x, int xn)//const mpz_srcptr x) 86 { 87 const size_t bufByteSize = sizeof(T) * maxSize; 88 if (xn < 0) return false; 89 size_t xByteSize = sizeof(*x) * xn; 90 if (xByteSize > bufByteSize) return false; 91 memcpy(buf, x, xByteSize); 92 memset((char*)buf + xByteSize, 0, bufByteSize - xByteSize); 93 return true; 94 } 95 template<class T> 96 void getArray(bool *pb, T *buf, size_t maxSize, const mpz_class& x) 97 { 98 #ifdef MCL_USE_VINT 99 *pb = getArray_(buf, maxSize, x.getUnit(), x.getUnitSize()); 100 #else 101 *pb = getArray_(buf, maxSize, x.get_mpz_t()->_mp_d, x.get_mpz_t()->_mp_size); 102 #endif 103 } 104 inline void set(mpz_class& z, uint64_t x) 105 { 106 bool b; 107 setArray(&b, z, &x, 1); 108 assert(b); 109 (void)b; 110 } 111 inline void setStr(bool *pb, mpz_class& z, const char *str, int base = 0) 112 { 113 #ifdef MCL_USE_VINT 114 z.setStr(pb, str, base); 115 #else 116 *pb = z.set_str(str, base) == 0; 117 #endif 118 } 119 120 /* 121 set buf with string terminated by '\0' 122 return strlen(buf) if success else 0 123 */ 124 inline size_t getStr(char *buf, size_t bufSize, const mpz_class& z, int base = 10) 125 { 126 #ifdef MCL_USE_VINT 127 return z.getStr(buf, bufSize, base); 128 #else 129 __gmp_alloc_cstring tmp(mpz_get_str(0, base, z.get_mpz_t())); 130 size_t n = strlen(tmp.str); 131 if (n + 1 > bufSize) return 0; 132 memcpy(buf, tmp.str, n + 1); 133 return n; 134 #endif 135 } 136 137 #ifndef CYBOZU_DONT_USE_STRING 138 inline void getStr(std::string& str, const mpz_class& z, int base = 10) 139 { 140 #ifdef MCL_USE_VINT 141 z.getStr(str, base); 142 #else 143 str = z.get_str(base); 144 #endif 145 } 146 inline std::string getStr(const mpz_class& z, int base = 10) 147 { 148 std::string s; 149 gmp::getStr(s, z, base); 150 return s; 151 } 152 #endif 153 154 inline void add(mpz_class& z, const mpz_class& x, const mpz_class& y) 155 { 156 #ifdef MCL_USE_VINT 157 Vint::add(z, x, y); 158 #else 159 mpz_add(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 160 #endif 161 } 162 #ifndef MCL_USE_VINT 163 inline void add(mpz_class& z, const mpz_class& x, unsigned int y) 164 { 165 mpz_add_ui(z.get_mpz_t(), x.get_mpz_t(), y); 166 } 167 inline void sub(mpz_class& z, const mpz_class& x, unsigned int y) 168 { 169 mpz_sub_ui(z.get_mpz_t(), x.get_mpz_t(), y); 170 } 171 inline void mul(mpz_class& z, const mpz_class& x, unsigned int y) 172 { 173 mpz_mul_ui(z.get_mpz_t(), x.get_mpz_t(), y); 174 } 175 inline void div(mpz_class& q, const mpz_class& x, unsigned int y) 176 { 177 mpz_div_ui(q.get_mpz_t(), x.get_mpz_t(), y); 178 } 179 inline void mod(mpz_class& r, const mpz_class& x, unsigned int m) 180 { 181 mpz_mod_ui(r.get_mpz_t(), x.get_mpz_t(), m); 182 } 183 inline int compare(const mpz_class& x, int y) 184 { 185 return mpz_cmp_si(x.get_mpz_t(), y); 186 } 187 #endif 188 inline void sub(mpz_class& z, const mpz_class& x, const mpz_class& y) 189 { 190 #ifdef MCL_USE_VINT 191 Vint::sub(z, x, y); 192 #else 193 mpz_sub(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 194 #endif 195 } 196 inline void mul(mpz_class& z, const mpz_class& x, const mpz_class& y) 197 { 198 #ifdef MCL_USE_VINT 199 Vint::mul(z, x, y); 200 #else 201 mpz_mul(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 202 #endif 203 } 204 inline void sqr(mpz_class& z, const mpz_class& x) 205 { 206 #ifdef MCL_USE_VINT 207 Vint::mul(z, x, x); 208 #else 209 mpz_mul(z.get_mpz_t(), x.get_mpz_t(), x.get_mpz_t()); 210 #endif 211 } 212 inline void divmod(mpz_class& q, mpz_class& r, const mpz_class& x, const mpz_class& y) 213 { 214 #ifdef MCL_USE_VINT 215 Vint::divMod(&q, r, x, y); 216 #else 217 mpz_divmod(q.get_mpz_t(), r.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 218 #endif 219 } 220 inline void div(mpz_class& q, const mpz_class& x, const mpz_class& y) 221 { 222 #ifdef MCL_USE_VINT 223 Vint::div(q, x, y); 224 #else 225 mpz_div(q.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 226 #endif 227 } 228 inline void mod(mpz_class& r, const mpz_class& x, const mpz_class& m) 229 { 230 #ifdef MCL_USE_VINT 231 Vint::mod(r, x, m); 232 #else 233 mpz_mod(r.get_mpz_t(), x.get_mpz_t(), m.get_mpz_t()); 234 #endif 235 } 236 inline void clear(mpz_class& z) 237 { 238 #ifdef MCL_USE_VINT 239 z.clear(); 240 #else 241 mpz_set_ui(z.get_mpz_t(), 0); 242 #endif 243 } 244 inline bool isZero(const mpz_class& z) 245 { 246 #ifdef MCL_USE_VINT 247 return z.isZero(); 248 #else 249 return mpz_sgn(z.get_mpz_t()) == 0; 250 #endif 251 } 252 inline bool isNegative(const mpz_class& z) 253 { 254 #ifdef MCL_USE_VINT 255 return z.isNegative(); 256 #else 257 return mpz_sgn(z.get_mpz_t()) < 0; 258 #endif 259 } 260 inline void neg(mpz_class& z, const mpz_class& x) 261 { 262 #ifdef MCL_USE_VINT 263 Vint::neg(z, x); 264 #else 265 mpz_neg(z.get_mpz_t(), x.get_mpz_t()); 266 #endif 267 } 268 inline int compare(const mpz_class& x, const mpz_class & y) 269 { 270 #ifdef MCL_USE_VINT 271 return Vint::compare(x, y); 272 #else 273 return mpz_cmp(x.get_mpz_t(), y.get_mpz_t()); 274 #endif 275 } 276 template<class T> 277 void addMod(mpz_class& z, const mpz_class& x, const T& y, const mpz_class& m) 278 { 279 add(z, x, y); 280 if (compare(z, m) >= 0) { 281 sub(z, z, m); 282 } 283 } 284 template<class T> 285 void subMod(mpz_class& z, const mpz_class& x, const T& y, const mpz_class& m) 286 { 287 sub(z, x, y); 288 if (!isNegative(z)) return; 289 add(z, z, m); 290 } 291 template<class T> 292 void mulMod(mpz_class& z, const mpz_class& x, const T& y, const mpz_class& m) 293 { 294 mul(z, x, y); 295 mod(z, z, m); 296 } 297 inline void sqrMod(mpz_class& z, const mpz_class& x, const mpz_class& m) 298 { 299 sqr(z, x); 300 mod(z, z, m); 301 } 302 // z = x^y (y >= 0) 303 inline void pow(mpz_class& z, const mpz_class& x, unsigned int y) 304 { 305 #ifdef MCL_USE_VINT 306 Vint::pow(z, x, y); 307 #else 308 mpz_pow_ui(z.get_mpz_t(), x.get_mpz_t(), y); 309 #endif 310 } 311 // z = x^y mod m (y >=0) 312 inline void powMod(mpz_class& z, const mpz_class& x, const mpz_class& y, const mpz_class& m) 313 { 314 #ifdef MCL_USE_VINT 315 Vint::powMod(z, x, y, m); 316 #else 317 mpz_powm(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t(), m.get_mpz_t()); 318 #endif 319 } 320 // z = 1/x mod m 321 inline void invMod(mpz_class& z, const mpz_class& x, const mpz_class& m) 322 { 323 #ifdef MCL_USE_VINT 324 Vint::invMod(z, x, m); 325 #else 326 mpz_invert(z.get_mpz_t(), x.get_mpz_t(), m.get_mpz_t()); 327 #endif 328 } 329 // z = lcm(x, y) 330 inline void lcm(mpz_class& z, const mpz_class& x, const mpz_class& y) 331 { 332 #ifdef MCL_USE_VINT 333 Vint::lcm(z, x, y); 334 #else 335 mpz_lcm(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 336 #endif 337 } 338 inline mpz_class lcm(const mpz_class& x, const mpz_class& y) 339 { 340 mpz_class z; 341 lcm(z, x, y); 342 return z; 343 } 344 // z = gcd(x, y) 345 inline void gcd(mpz_class& z, const mpz_class& x, const mpz_class& y) 346 { 347 #ifdef MCL_USE_VINT 348 Vint::gcd(z, x, y); 349 #else 350 mpz_gcd(z.get_mpz_t(), x.get_mpz_t(), y.get_mpz_t()); 351 #endif 352 } 353 inline mpz_class gcd(const mpz_class& x, const mpz_class& y) 354 { 355 mpz_class z; 356 gcd(z, x, y); 357 return z; 358 } 359 /* 360 assume p : odd prime 361 return 1 if x^2 = a mod p for some x 362 return -1 if x^2 != a mod p for any x 363 */ 364 inline int legendre(const mpz_class& a, const mpz_class& p) 365 { 366 #ifdef MCL_USE_VINT 367 return Vint::jacobi(a, p); 368 #else 369 return mpz_legendre(a.get_mpz_t(), p.get_mpz_t()); 370 #endif 371 } 372 inline bool isPrime(bool *pb, const mpz_class& x) 373 { 374 #ifdef MCL_USE_VINT 375 return x.isPrime(pb, 32); 376 #else 377 *pb = true; 378 return mpz_probab_prime_p(x.get_mpz_t(), 32) != 0; 379 #endif 380 } 381 inline size_t getBitSize(const mpz_class& x) 382 { 383 #ifdef MCL_USE_VINT 384 return x.getBitSize(); 385 #else 386 return mpz_sizeinbase(x.get_mpz_t(), 2); 387 #endif 388 } 389 inline bool testBit(const mpz_class& x, size_t pos) 390 { 391 #ifdef MCL_USE_VINT 392 return x.testBit(pos); 393 #else 394 return mpz_tstbit(x.get_mpz_t(), pos) != 0; 395 #endif 396 } 397 inline void resetBit(mpz_class& x, size_t pos) 398 { 399 #ifdef MCL_USE_VINT 400 x.setBit(pos, false); 401 #else 402 mpz_clrbit(x.get_mpz_t(), pos); 403 #endif 404 } 405 inline void setBit(mpz_class& x, size_t pos, bool v = true) 406 { 407 #ifdef MCL_USE_VINT 408 x.setBit(pos, v); 409 #else 410 if (v) { 411 mpz_setbit(x.get_mpz_t(), pos); 412 } else { 413 resetBit(x, pos); 414 } 415 #endif 416 } 417 inline const fp::Unit *getUnit(const mpz_class& x) 418 { 419 #ifdef MCL_USE_VINT 420 return x.getUnit(); 421 #else 422 return reinterpret_cast<const fp::Unit*>(x.get_mpz_t()->_mp_d); 423 #endif 424 } 425 inline fp::Unit getUnit(const mpz_class& x, size_t i) 426 { 427 return getUnit(x)[i]; 428 } 429 inline size_t getUnitSize(const mpz_class& x) 430 { 431 #ifdef MCL_USE_VINT 432 return x.getUnitSize(); 433 #else 434 return std::abs(x.get_mpz_t()->_mp_size); 435 #endif 436 } 437 inline mpz_class abs(const mpz_class& x) 438 { 439 #ifdef MCL_USE_VINT 440 return Vint::abs(x); 441 #else 442 return ::abs(x); 443 #endif 444 } 445 446 inline void getRand(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) 447 { 448 if (rg.isZero()) rg = fp::RandGen::get(); 449 assert(bitSize > 1); 450 const size_t rem = bitSize & 31; 451 const size_t n = (bitSize + 31) / 32; 452 uint32_t buf[128]; 453 assert(n <= CYBOZU_NUM_OF_ARRAY(buf)); 454 if (n > CYBOZU_NUM_OF_ARRAY(buf)) { 455 *pb = false; 456 return; 457 } 458 rg.read(pb, buf, n * sizeof(buf[0])); 459 if (!*pb) return; 460 uint32_t v = buf[n - 1]; 461 if (rem == 0) { 462 v |= 1U << 31; 463 } else { 464 v &= (1U << rem) - 1; 465 v |= 1U << (rem - 1); 466 } 467 buf[n - 1] = v; 468 setArray(pb, z, buf, n); 469 } 470 471 inline void getRandPrime(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) 472 { 473 if (rg.isZero()) rg = fp::RandGen::get(); 474 assert(bitSize > 2); 475 for (;;) { 476 getRand(pb, z, bitSize, rg); 477 if (!*pb) return; 478 if (setSecondBit) { 479 z |= mpz_class(1) << (bitSize - 2); 480 } 481 if (mustBe3mod4) { 482 z |= 3; 483 } 484 bool ret = isPrime(pb, z); 485 if (!*pb) return; 486 if (ret) return; 487 } 488 } 489 inline mpz_class getQuadraticNonResidue(const mpz_class& p) 490 { 491 mpz_class g = 2; 492 while (legendre(g, p) > 0) { 493 ++g; 494 } 495 return g; 496 } 497 498 namespace impl { 499 500 template<class Vec> 501 void convertToBinary(Vec& v, const mpz_class& x) 502 { 503 const size_t len = gmp::getBitSize(x); 504 v.resize(len); 505 for (size_t i = 0; i < len; i++) { 506 v[i] = gmp::testBit(x, len - 1 - i) ? 1 : 0; 507 } 508 } 509 510 template<class Vec> 511 size_t getContinuousVal(const Vec& v, size_t pos, int val) 512 { 513 while (pos >= 2) { 514 if (v[pos] != val) break; 515 pos--; 516 } 517 return pos; 518 } 519 520 template<class Vec> 521 void convertToNAF(Vec& v, const Vec& in) 522 { 523 v.copy(in); 524 size_t pos = v.size() - 1; 525 for (;;) { 526 size_t p = getContinuousVal(v, pos, 0); 527 if (p == 1) return; 528 assert(v[p] == 1); 529 size_t q = getContinuousVal(v, p, 1); 530 if (q == 1) return; 531 assert(v[q] == 0); 532 if (p - q <= 1) { 533 pos = p - 1; 534 continue; 535 } 536 v[q] = 1; 537 for (size_t i = q + 1; i < p; i++) { 538 v[i] = 0; 539 } 540 v[p] = -1; 541 pos = q; 542 } 543 } 544 545 template<class Vec> 546 size_t getNumOfNonZeroElement(const Vec& v) 547 { 548 size_t w = 0; 549 for (size_t i = 0; i < v.size(); i++) { 550 if (v[i]) w++; 551 } 552 return w; 553 } 554 555 } // impl 556 557 /* 558 compute a repl of x which has smaller Hamming weights. 559 return true if naf is selected 560 */ 561 template<class Vec> 562 bool getNAF(Vec& v, const mpz_class& x) 563 { 564 Vec bin; 565 impl::convertToBinary(bin, x); 566 Vec naf; 567 impl::convertToNAF(naf, bin); 568 const size_t binW = impl::getNumOfNonZeroElement(bin); 569 const size_t nafW = impl::getNumOfNonZeroElement(naf); 570 if (nafW < binW) { 571 v.swap(naf); 572 return true; 573 } else { 574 v.swap(bin); 575 return false; 576 } 577 } 578 579 #ifndef CYBOZU_DONT_USE_EXCEPTION 580 inline void setStr(mpz_class& z, const std::string& str, int base = 0) 581 { 582 bool b; 583 setStr(&b, z, str.c_str(), base); 584 if (!b) throw cybozu::Exception("gmp:setStr"); 585 } 586 template<class T> 587 void setArray(mpz_class& z, const T *buf, size_t n) 588 { 589 bool b; 590 setArray(&b, z, buf, n); 591 if (!b) throw cybozu::Exception("gmp:setArray"); 592 } 593 template<class T> 594 void getArray(T *buf, size_t maxSize, const mpz_class& x) 595 { 596 bool b; 597 getArray(&b, buf, maxSize, x); 598 if (!b) throw cybozu::Exception("gmp:getArray"); 599 } 600 inline bool isPrime(const mpz_class& x) 601 { 602 bool b; 603 bool ret = isPrime(&b, x); 604 if (!b) throw cybozu::Exception("gmp:isPrime"); 605 return ret; 606 } 607 inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) 608 { 609 bool b; 610 getRand(&b, z, bitSize, rg); 611 if (!b) throw cybozu::Exception("gmp:getRand"); 612 } 613 inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) 614 { 615 bool b; 616 getRandPrime(&b, z, bitSize, rg, setSecondBit, mustBe3mod4); 617 if (!b) throw cybozu::Exception("gmp:getRandPrime"); 618 } 619 #endif 620 621 622 } // mcl::gmp 623 624 /* 625 Tonelli-Shanks 626 */ 627 class SquareRoot { 628 bool isPrecomputed_; 629 bool isPrime; 630 mpz_class p; 631 mpz_class g; 632 int r; 633 mpz_class q; // p - 1 = 2^r q 634 mpz_class s; // s = g^q 635 mpz_class q_add_1_div_2; 636 struct Tbl { 637 const char *p; 638 const char *g; 639 int r; 640 const char *q; 641 const char *s; 642 const char *q_add_1_div_2; 643 }; 644 bool setIfPrecomputed(const mpz_class& p_) 645 { 646 static const Tbl tbl[] = { 647 { // BN254.p 648 "2523648240000001ba344d80000000086121000000000013a700000000000013", 649 "2", 650 1, 651 "1291b24120000000dd1a26c0000000043090800000000009d380000000000009", 652 "2523648240000001ba344d80000000086121000000000013a700000000000012", 653 "948d920900000006e8d1360000000021848400000000004e9c0000000000005", 654 }, 655 { // BN254.r 656 "2523648240000001ba344d8000000007ff9f800000000010a10000000000000d", 657 "2", 658 2, 659 "948d920900000006e8d136000000001ffe7e000000000042840000000000003", 660 "9366c4800000000555150000000000122400000000000015", 661 "4a46c9048000000374689b000000000fff3f000000000021420000000000002", 662 }, 663 { // BLS12_381,p 664 "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", 665 "2", 666 1, 667 "d0088f51cbff34d258dd3db21a5d66bb23ba5c279c2895fb39869507b587b120f55ffff58a9ffffdcff7fffffffd555", 668 "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaaa", 669 "680447a8e5ff9a692c6e9ed90d2eb35d91dd2e13ce144afd9cc34a83dac3d8907aaffffac54ffffee7fbfffffffeaab", 670 }, 671 { // BLS12_381.r 672 "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", 673 "5", 674 32, 675 "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff", 676 "212d79e5b416b6f0fd56dc8d168d6c0c4024ff270b3e0941b788f500b912f1f", 677 "39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff80000000", 678 }, 679 }; 680 for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { 681 mpz_class targetPrime; 682 bool b; 683 mcl::gmp::setStr(&b, targetPrime, tbl[i].p, 16); 684 if (!b) continue; 685 if (targetPrime != p_) continue; 686 isPrime = true; 687 p = p_; 688 mcl::gmp::setStr(&b, g, tbl[i].g, 16); 689 if (!b) continue; 690 r = tbl[i].r; 691 mcl::gmp::setStr(&b, q, tbl[i].q, 16); 692 if (!b) continue; 693 mcl::gmp::setStr(&b, s, tbl[i].s, 16); 694 if (!b) continue; 695 mcl::gmp::setStr(&b, q_add_1_div_2, tbl[i].q_add_1_div_2, 16); 696 if (!b) continue; 697 isPrecomputed_ = true; 698 return true; 699 } 700 return false; 701 } 702 public: 703 SquareRoot() { clear(); } 704 bool isPrecomputed() const { return isPrecomputed_; } 705 void clear() 706 { 707 isPrecomputed_ = false; 708 isPrime = false; 709 p = 0; 710 g = 0; 711 r = 0; 712 q = 0; 713 s = 0; 714 q_add_1_div_2 = 0; 715 } 716 #if !defined(CYBOZU_DONT_USE_USE_STRING) && !defined(CYBOZU_DONT_USE_EXCEPTION) 717 void dump() const 718 { 719 printf("\"%s\",\n", mcl::gmp::getStr(p, 16).c_str()); 720 printf("\"%s\",\n", mcl::gmp::getStr(g, 16).c_str()); 721 printf("%d,\n", r); 722 printf("\"%s\",\n", mcl::gmp::getStr(q, 16).c_str()); 723 printf("\"%s\",\n", mcl::gmp::getStr(s, 16).c_str()); 724 printf("\"%s\",\n", mcl::gmp::getStr(q_add_1_div_2, 16).c_str()); 725 } 726 #endif 727 void set(bool *pb, const mpz_class& _p, bool usePrecomputedTable = true) 728 { 729 if (usePrecomputedTable && setIfPrecomputed(_p)) { 730 *pb = true; 731 return; 732 } 733 p = _p; 734 if (p <= 2) { 735 *pb = false; 736 return; 737 } 738 isPrime = gmp::isPrime(pb, p); 739 if (!*pb) return; 740 if (!isPrime) { 741 *pb = false; 742 return; 743 } 744 g = gmp::getQuadraticNonResidue(p); 745 // p - 1 = 2^r q, q is odd 746 r = 0; 747 q = p - 1; 748 while ((q & 1) == 0) { 749 r++; 750 q /= 2; 751 } 752 gmp::powMod(s, g, q, p); 753 q_add_1_div_2 = (q + 1) / 2; 754 *pb = true; 755 } 756 /* 757 solve x^2 = a mod p 758 */ 759 bool get(mpz_class& x, const mpz_class& a) const 760 { 761 if (!isPrime) { 762 return false; 763 } 764 if (a == 0) { 765 x = 0; 766 return true; 767 } 768 if (gmp::legendre(a, p) < 0) return false; 769 if (r == 1) { 770 // (p + 1) / 4 = (q + 1) / 2 771 gmp::powMod(x, a, q_add_1_div_2, p); 772 return true; 773 } 774 mpz_class c = s, d; 775 int e = r; 776 gmp::powMod(d, a, q, p); 777 gmp::powMod(x, a, q_add_1_div_2, p); // destroy a if &x == &a 778 mpz_class dd; 779 mpz_class b; 780 while (d != 1) { 781 int i = 1; 782 dd = d * d; dd %= p; 783 while (dd != 1) { 784 dd *= dd; dd %= p; 785 i++; 786 } 787 b = 1; 788 b <<= e - i - 1; 789 gmp::powMod(b, c, b, p); 790 x *= b; x %= p; 791 c = b * b; c %= p; 792 d *= c; d %= p; 793 e = i; 794 } 795 return true; 796 } 797 /* 798 solve x^2 = a in Fp 799 */ 800 template<class Fp> 801 bool get(Fp& x, const Fp& a) const 802 { 803 assert(Fp::getOp().mp == p); 804 if (a == 0) { 805 x = 0; 806 return true; 807 } 808 { 809 bool b; 810 mpz_class aa; 811 a.getMpz(&b, aa); 812 assert(b); 813 if (gmp::legendre(aa, p) < 0) return false; 814 } 815 if (r == 1) { 816 // (p + 1) / 4 = (q + 1) / 2 817 Fp::pow(x, a, q_add_1_div_2); 818 return true; 819 } 820 Fp c, d; 821 { 822 bool b; 823 c.setMpz(&b, s); 824 assert(b); 825 } 826 int e = r; 827 Fp::pow(d, a, q); 828 Fp::pow(x, a, q_add_1_div_2); // destroy a if &x == &a 829 Fp dd; 830 Fp b; 831 while (!d.isOne()) { 832 int i = 1; 833 Fp::sqr(dd, d); 834 while (!dd.isOne()) { 835 dd *= dd; 836 i++; 837 } 838 b = 1; 839 // b <<= e - i - 1; 840 for (int j = 0; j < e - i - 1; j++) { 841 b += b; 842 } 843 Fp::pow(b, c, b); 844 x *= b; 845 Fp::sqr(c, b); 846 d *= c; 847 e = i; 848 } 849 return true; 850 } 851 bool operator==(const SquareRoot& rhs) const 852 { 853 return isPrime == rhs.isPrime && p == rhs.p && g == rhs.g && r == rhs.r 854 && q == rhs.q && s == rhs.s && q_add_1_div_2 == rhs.q_add_1_div_2; 855 } 856 bool operator!=(const SquareRoot& rhs) const { return !operator==(rhs); } 857 #ifndef CYBOZU_DONT_USE_EXCEPTION 858 void set(const mpz_class& _p) 859 { 860 bool b; 861 set(&b, _p); 862 if (!b) throw cybozu::Exception("gmp:SquareRoot:set"); 863 } 864 #endif 865 }; 866 867 /* 868 Barrett Reduction 869 for non GMP version 870 mod of GMP is faster than Modp 871 */ 872 struct Modp { 873 static const size_t unitBitSize = sizeof(mcl::fp::Unit) * 8; 874 mpz_class p_; 875 mpz_class u_; 876 mpz_class a_; 877 size_t pBitSize_; 878 size_t N_; 879 bool initU_; // Is u_ initialized? 880 Modp() 881 : pBitSize_(0) 882 , N_(0) 883 , initU_(false) 884 { 885 } 886 // x &= 1 << (unitBitSize * unitSize) 887 void shrinkSize(mpz_class &x, size_t unitSize) const 888 { 889 size_t u = gmp::getUnitSize(x); 890 if (u < unitSize) return; 891 bool b; 892 gmp::setArray(&b, x, gmp::getUnit(x), unitSize); 893 (void)b; 894 assert(b); 895 } 896 // p_ is set by p and compute (u_, a_) if possible 897 void init(const mpz_class& p) 898 { 899 p_ = p; 900 pBitSize_ = gmp::getBitSize(p); 901 N_ = (pBitSize_ + unitBitSize - 1) / unitBitSize; 902 initU_ = false; 903 #if 0 904 u_ = (mpz_class(1) << (unitBitSize * 2 * N_)) / p_; 905 #else 906 /* 907 1 << (unitBitSize * 2 * N_) may be overflow, 908 so use (1 << (unitBitSize * 2 * N_)) - 1 because u_ is same. 909 */ 910 uint8_t buf[48 * 2]; 911 const size_t byteSize = unitBitSize / 8 * 2 * N_; 912 if (byteSize > sizeof(buf)) return; 913 memset(buf, 0xff, byteSize); 914 bool b; 915 gmp::setArray(&b, u_, buf, byteSize); 916 if (!b) return; 917 #endif 918 u_ /= p_; 919 a_ = mpz_class(1) << (unitBitSize * (N_ + 1)); 920 initU_ = true; 921 } 922 void modp(mpz_class& r, const mpz_class& t) const 923 { 924 assert(p_ > 0); 925 const size_t tBitSize = gmp::getBitSize(t); 926 // use gmp::mod if init() fails or t is too large 927 if (tBitSize > unitBitSize * 2 * N_ || !initU_) { 928 gmp::mod(r, t, p_); 929 return; 930 } 931 if (tBitSize < pBitSize_) { 932 r = t; 933 return; 934 } 935 // mod is faster than modp if t is small 936 if (tBitSize <= unitBitSize * N_) { 937 gmp::mod(r, t, p_); 938 return; 939 } 940 mpz_class q; 941 q = t; 942 q >>= unitBitSize * (N_ - 1); 943 q *= u_; 944 q >>= unitBitSize * (N_ + 1); 945 q *= p_; 946 shrinkSize(q, N_ + 1); 947 r = t; 948 shrinkSize(r, N_ + 1); 949 r -= q; 950 if (r < 0) { 951 r += a_; 952 } 953 if (r >= p_) { 954 r -= p_; 955 } 956 } 957 }; 958 959 } // mcl