github.com/aergoio/aergo@v1.3.1/libtool/src/gmp-6.1.2/mpn/generic/mul_fft.c

/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */


/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
                                       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
                                       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
                                   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */


/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE /* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/


/* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
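#if 0
/* Illustration only, not GMP source: what mpn_fft_next_size computes.  An
   FFT of size 2^k cuts the operands into 2^k pieces, so the product size pl
   must first be rounded up to a multiple of 2^k.  fft_next_size_demo is a
   hypothetical stand-in using plain long arithmetic. */
static long
fft_next_size_demo (long pl, int k)
{
  long step = 1L << k;                     /* the FFT works on 2^k pieces */
  return ((pl + step - 1) / step) * step;  /* smallest multiple of 2^k >= pl */
}
/* e.g. fft_next_size_demo (1000, 4) == 1008, matching mpn_fft_next_size:
   1 + ((1000 - 1) >> 4) = 63, and 63 << 4 = 1008. */
#endif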
/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
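#if 0
/* Illustration only: each level doubles the previous table, so l[i][j] ends
   up holding the i-bit reversal of j.  For k = 3:
       j       : 0 1 2 3 4 5 6 7
       l[3][j] : 0 4 2 6 1 5 3 7
   e.g. j = 3 = 011b reverses to 110b = 6.  These are the indices in which
   mpn_fft_fft delivers the transform coefficients. */
#endif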
/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n) /* negate */
    {
      /* r[0..m-1]  <--  lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- -lshift(a[0]..a[n-m-1], sh) */

      m -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - m, m + 1, sh);
          rd = r[m];
          cc = mpn_lshiftc (r + m, a, n - m, sh);
        }
      else
        {
          MPN_COPY (r, a + n - m, m);
          rd = a[n];
          mpn_com (r + m, a, n - m);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <--  lshift(a[0]..a[n-m-1], sh) */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - m, m + 1, sh);
          rd = ~r[m];
          /* {r, m+1} = {a+n-m, m+1} << sh */
          cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
        }
      else
        {
          /* r[m] is not used below, but we save a test for m=0 */
          mpn_com (r, a + n - m, m + 1);
          rd = a[n];
          MPN_COPY (r + m, a, n - m);
          cc = 0;
        }

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[m] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, m, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
        r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}


/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
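#if 0
/* Worked example, illustration only: with n*GMP_NUMB_BITS = 4 the modulus
   is 2^4+1 = 17, and 2^4 = -1 (mod 17).  A shift past the top therefore
   wraps around with a sign change: 3 * 2^5 = 3 * 2 * 2^4 = -6 = 11 (mod 17).
   This sign flip is why the m >= n branch of mpn_fft_mul_2exp_modF
   complements the wrapped part (mpn_lshiftc / mpn_com) instead of merely
   shifting it. */
#endif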
/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
        }
    }
}
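#if 0
/* Illustration only: each pass of the loop above is a radix-2
   decimation-in-frequency butterfly.  With t = Ap[inc] * 2^(lk[0]*omega):
       Ap[inc] <- Ap[0] - t
       Ap[0]   <- Ap[0] + t
   The subtraction realizes the second twiddle factor for free: the two
   exponents lk[0]*omega and lk[1]*omega differ by N, and 2^N = -1 mod
   2^N+1.  Every root of unity here is a power of two, so a "multiplication"
   by a root is just a (wrapped) shift. */
#endif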
/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/


/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          mp_size_t K3;
          for (;;)
            {
              K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
        {
          fft_l[i] = tmp;
          tmp += (mp_size_t) 1 << i;
        }

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}
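#if 0
/* Illustration only: the mpn_mul_n/mpn_sqr branch above relies on the
   identity
       x = x0 + x1 * 2^(n*GMP_NUMB_BITS)
         = x0 - x1  (mod 2^(n*GMP_NUMB_BITS)+1).
   The 2n-limb product {tp, 2n} is split into the low half {tp, n} and the
   high half {tpn, n}, and the final mpn_sub_n forms tp - tpn; when that
   subtraction borrows, adding 1 yields the correct residue, since
   -1 = 2^(n*GMP_NUMB_BITS) modulo the modulus.  The a[n]/b[n] tests patch
   in the cross terms contributed by semi-normalized inputs. */
#endif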
/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
        }
    }
}


/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}
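#if 0
/* Worked example, illustration only: with n*GMP_NUMB_BITS = 4 the modulus
   is 2^4+1 = 17, and dividing by 2 is multiplying by 2^(2*4-1) = 2^7 =
   128 = 9 (mod 17); indeed 2 * 9 = 18 = 1 (mod 17).  In general
   2^(2*n*GMP_NUMB_BITS) = (-1)^2 = 1, so 2^(2*n*GMP_NUMB_BITS - k) is the
   inverse of 2^k, which is exactly what mpn_fft_div_2exp_modF uses. */
#endif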
/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
                       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
                       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
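#if 0
/* Illustration only: for K = 4 pieces of l limbs each, the input splits as
       n = n0 + n1*2^(l*B) + n2*2^(2l*B) + n3*2^(3l*B),   B = GMP_NUMB_BITS,
   and piece i is stored as Ap[i] = ni * 2^(i*Mp) mod 2^(nprime*B)+1.
   Since 2^(K*Mp) = 2^Nprime = -1 modulo that modulus, the weight 2^Mp is a
   2K-th root of unity; weighting the pieces this way makes the cyclic
   convolution computed by the FFT act as the negacyclic convolution needed
   for arithmetic mod 2^N+1 (the classical weighted-transform trick). */
#endif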
/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                         pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
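#if 0
/* Illustration only: on machines where GMP_NUMB_BITS is a power of two,
   say 64 = 2^6, mpn_mul_fft_lcm (64, k) returns 64 for k <= 6 and 2^k for
   k > 6, i.e. lcm (64, 2^k) = 2^max(6,k).  The loop shifts out only the
   factors of 2 common to a and 2^k, so it is also correct for an odd a,
   e.g. mpn_mul_fft_lcm (3, 2) = 3 << 2 = 12 = lcm (3, 4). */
#endif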
mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
        {
          K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf ("   temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
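#if 0
/* Hedged usage sketch, not GMP source: how a caller typically drives
   mpn_mul_fft.  mul_fft_demo is a hypothetical helper; op must provide room
   for the rounded pl limbs.  The product is returned reduced mod
   2^(pl*GMP_NUMB_BITS)+1, with the high limb as the return value. */
static mp_limb_t
mul_fft_demo (mp_ptr op, mp_size_t pl,
              mp_srcptr n, mp_size_t nl, mp_srcptr m, mp_size_t ml)
{
  int k = mpn_fft_best_k (pl, 0);  /* 0 for a multiply, 1 for a square */
  pl = mpn_fft_next_size (pl, k);  /* mpn_mul_fft requires this rounding */
  return mpn_mul_fft (op, pl, n, nl, m, ml, k);
}
#endif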
#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus
     (pl2,pl3) = (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j>=2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */

  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
                            thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1;   /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif