/* Source: github.com/aergoio/aergo@v1.3.1/libtool/src/gmp-6.1.2/mpn/generic/mul_fft.c */

     1  /* Schoenhage's fast multiplication modulo 2^N+1.
     2  
     3     Contributed by Paul Zimmermann.
     4  
     5     THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
     6     SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
     7     GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
     8  
     9  Copyright 1998-2010, 2012, 2013 Free Software Foundation, Inc.
    10  
    11  This file is part of the GNU MP Library.
    12  
    13  The GNU MP Library is free software; you can redistribute it and/or modify
    14  it under the terms of either:
    15  
    16    * the GNU Lesser General Public License as published by the Free
    17      Software Foundation; either version 3 of the License, or (at your
    18      option) any later version.
    19  
    20  or
    21  
    22    * the GNU General Public License as published by the Free Software
    23      Foundation; either version 2 of the License, or (at your option) any
    24      later version.
    25  
    26  or both in parallel, as here.
    27  
    28  The GNU MP Library is distributed in the hope that it will be useful, but
    29  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    30  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
    31  for more details.
    32  
    33  You should have received copies of the GNU General Public License and the
    34  GNU Lesser General Public License along with the GNU MP Library.  If not,
    35  see https://www.gnu.org/licenses/.  */
    36  
    37  
    38  /* References:
    39  
    40     Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
    41     Strassen, Computing 7, p. 281-292, 1971.
    42  
    43     Asymptotically fast algorithms for the numerical multiplication and division
    44     of polynomials with complex coefficients, by Arnold Schoenhage, Computer
    45     Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.
    46  
    47     Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
    48     Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.
    49  
    50     TODO:
    51  
    52     Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
    53     Zimmermann.
    54  
    55     It might be possible to avoid a small number of MPN_COPYs by using a
    56     rotating temporary or two.
    57  
    58     Cleanup and simplify the code!
    59  */
    60  
    61  #ifdef TRACE
    62  #undef TRACE
    63  #define TRACE(x) x
    64  #include <stdio.h>
    65  #else
    66  #define TRACE(x)
    67  #endif
    68  
    69  #include "gmp.h"
    70  #include "gmp-impl.h"
    71  
    72  #ifdef WANT_ADDSUB
    73  #include "generic/add_n_sub_n.c"
    74  #define HAVE_NATIVE_mpn_add_n_sub_n 1
    75  #endif
    76  
    77  static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
    78  				       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
    79  				       mp_size_t, mp_size_t, int **, mp_ptr, int);
    80  static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
    81  				   mp_size_t, mp_size_t, mp_size_t, mp_ptr);
    82  
    83  
/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   sqr is 0 for a multiplication and 1 for a squaring.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h files have not been updated.  */
    88  
    89  
    90  /*****************************************************************************/
    91  
    92  #if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))
    93  
    94  #ifndef FFT_TABLE3_SIZE		/* When tuning this is defined in gmp-impl.h */
    95  #if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
    96  #if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
    97  #define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
    98  #else
    99  #define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
   100  #endif
   101  #endif
   102  #endif
   103  
   104  #ifndef FFT_TABLE3_SIZE
   105  #define FFT_TABLE3_SIZE 200
   106  #endif
   107  
   108  FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
   109  {
   110    MUL_FFT_TABLE3,
   111    SQR_FFT_TABLE3
   112  };
   113  
   114  int
   115  mpn_fft_best_k (mp_size_t n, int sqr)
   116  {
   117    const struct fft_table_nk *fft_tab, *tab;
   118    mp_size_t tab_n, thres;
   119    int last_k;
   120  
   121    fft_tab = mpn_fft_table3[sqr];
   122    last_k = fft_tab->k;
   123    for (tab = fft_tab + 1; ; tab++)
   124      {
   125        tab_n = tab->n;
   126        thres = tab_n << last_k;
   127        if (n <= thres)
   128  	break;
   129        last_k = tab->k;
   130      }
   131    return last_k;
   132  }
   133  
   134  #define MPN_FFT_BEST_READY 1
   135  #endif
   136  
   137  /*****************************************************************************/
   138  
   139  #if ! defined (MPN_FFT_BEST_READY)
   140  FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
   141  {
   142    MUL_FFT_TABLE,
   143    SQR_FFT_TABLE
   144  };
   145  
   146  int
   147  mpn_fft_best_k (mp_size_t n, int sqr)
   148  {
   149    int i;
   150  
   151    for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
   152      if (n < mpn_fft_table[sqr][i])
   153        return i + FFT_FIRST_K;
   154  
   155    /* treat 4*last as one further entry */
   156    if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
   157      return i + FFT_FIRST_K;
   158    else
   159      return i + FFT_FIRST_K + 1;
   160  }
   161  #endif
   162  
   163  /*****************************************************************************/
   164  
   165  
   166  /* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
   167     i.e. smallest multiple of 2^k >= pl.
   168  
   169     Don't declare static: needed by tuneup.
   170  */
   171  
   172  mp_size_t
   173  mpn_fft_next_size (mp_size_t pl, int k)
   174  {
   175    pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
   176    return pl << k;
   177  }
   178  
   179  
   180  /* Initialize l[i][j] with bitrev(j) */
   181  static void
   182  mpn_fft_initl (int **l, int k)
   183  {
   184    int i, j, K;
   185    int *li;
   186  
   187    l[0][0] = 0;
   188    for (i = 1, K = 1; i <= k; i++, K *= 2)
   189      {
   190        li = l[i];
   191        for (j = 0; j < K; j++)
   192  	{
   193  	  li[j] = 2 * l[i - 1][j];
   194  	  li[K + j] = 1 + li[j];
   195  	}
   196      }
   197  }
   198  
   199  
/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.

   Write d = m*GMP_NUMB_BITS + sh with 0 <= sh < GMP_NUMB_BITS.  Because
   2^(n*GMP_NUMB_BITS) == -1, the shifted value splits into limbs of a that
   stay below position n and limbs that wrap around past it; each wrap
   flips the sign.  Negated pieces are first produced as one's complements
   (mpn_lshiftc / mpn_com) and then corrected by adding/subtracting 1 at
   the appropriate positions.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n)			/* negate */
    {
      /* d >= n*GMP_NUMB_BITS, so 2^d = -2^(d - n*GMP_NUMB_BITS).  After
	 m -= n:
	 r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */

      m -= n;
      if (sh != 0)
	{
	  /* no out shift below since a[n] <= 1 */
	  /* {r, m+1} = {a+n-m, m+1} << sh.  r[m] is saved in rd because the
	     next call overwrites it.  */
	  mpn_lshift (r, a + n - m, m + 1, sh);
	  rd = r[m];
	  /* {r+m, n-m} = ~({a, n-m} << sh); cc = the bits shifted out.  */
	  cc = mpn_lshiftc (r + m, a, n - m, sh);
	}
      else
	{
	  MPN_COPY (r, a + n - m, m);
	  rd = a[n];
	  mpn_com (r + m, a, n - m);
	  cc = 0;
	}

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0]
	 (these two corrections turn the one's complement of {r+m, n-m}
	 into a true negation) */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1; if it wrapped to 0,
	 add 1 one limb higher (at r[m+1]) instead */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* d < n*GMP_NUMB_BITS:
	 r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
	 r[m..n-1]  <-- lshift(a[0]..a[n-m-1],  sh)  */
      if (sh != 0)
	{
	  /* no out bits below since a[n] <= 1 */
	  /* {r, m+1} = ~({a+n-m, m+1} << sh).  rd keeps the uncomplemented
	     top limb because the next call overwrites r[m].  */
	  mpn_lshiftc (r, a + n - m, m + 1, sh);
	  rd = ~r[m];
	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
	}
      else
	{
	  /* r[m] is not used below, but we save a test for m=0 */
	  mpn_com (r, a + n - m, m + 1);
	  rd = a[n];
	  MPN_COPY (r + m, a, n - m);
	  cc = 0;
	}

      /* {r, m} already holds a one's complement.  It remains to:
	 subtract cc (out-shifted bits at position n, weight -1) from r[0],
	 subtract rd from r[m], and add 1 at r[0] / subtract 1 at r[m] to
	 turn the one's complement into a negation.  */

      /* if m=0 the lshift into r+m covered all of {r, n}; only cc and rd
	 remain to be subtracted below */
      if (m != 0)
	{
	  /* now add 1 in r[0], subtract 1 in r[m] */
	  if (cc-- == 0) /* then add 1 to r[0] */
	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
	  cc = mpn_sub_1 (r, r, m, cc) + 1;
	  /* add 1 to cc instead of rd since rd might overflow */
	}

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      /* a negative r[n] (a borrow past position n) is worth +1 at r[0] */
      if (r[n] & GMP_LIMB_HIGHBIT)
	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
   291  
   292  
   293  /* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   294     Assumes a and b are semi-normalized.
   295  */
   296  static inline void
   297  mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
   298  {
   299    mp_limb_t c, x;
   300  
   301    c = a[n] + b[n] + mpn_add_n (r, a, b, n);
   302    /* 0 <= c <= 3 */
   303  
   304  #if 1
   305    /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
   306       result is slower code, of course.  But the following outsmarts GCC.  */
   307    x = (c - 1) & -(c != 0);
   308    r[n] = c - x;
   309    MPN_DECR_U (r, n + 1, x);
   310  #endif
   311  #if 0
   312    if (c > 1)
   313      {
   314        r[n] = 1;                       /* r[n] - c = 1 */
   315        MPN_DECR_U (r, n + 1, c - 1);
   316      }
   317    else
   318      {
   319        r[n] = c;
   320      }
   321  #endif
   322  }
   323  
   324  /* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   325     Assumes a and b are semi-normalized.
   326  */
   327  static inline void
   328  mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
   329  {
   330    mp_limb_t c, x;
   331  
   332    c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
   333    /* -2 <= c <= 1 */
   334  
   335  #if 1
   336    /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
   337       result is slower code, of course.  But the following outsmarts GCC.  */
   338    x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
   339    r[n] = x + c;
   340    MPN_INCR_U (r, n + 1, x);
   341  #endif
   342  #if 0
   343    if ((c & GMP_LIMB_HIGHBIT) != 0)
   344      {
   345        r[n] = 0;
   346        MPN_INCR_U (r, n + 1, -c);
   347      }
   348    else
   349      {
   350        r[n] = c;
   351      }
   352  #endif
   353  }
   354  
   355  /* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
   356  	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   357     output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */
   358  
   359  static void
   360  mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
   361  	     mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
   362  {
   363    if (K == 2)
   364      {
   365        mp_limb_t cy;
   366  #if HAVE_NATIVE_mpn_add_n_sub_n
   367        cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
   368  #else
   369        MPN_COPY (tp, Ap[0], n + 1);
   370        mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
   371        cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
   372  #endif
   373        if (Ap[0][n] > 1) /* can be 2 or 3 */
   374  	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
   375        if (cy) /* Ap[inc][n] can be -1 or -2 */
   376  	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
   377      }
   378    else
   379      {
   380        mp_size_t j, K2 = K >> 1;
   381        int *lk = *ll;
   382  
   383        mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
   384        mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
   385        /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
   386  	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
   387        for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
   388  	{
   389  	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
   390  	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
   391  	  mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
   392  	  mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
   393  	  mpn_fft_add_modF (Ap[0],   Ap[0], tp, n);
   394  	}
   395      }
   396  }
   397  
/* (Detached copy of the mpn_fft_fft contract, with its scratch requirement.)
   input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
	  N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/
   403  
   404  
   405  /* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   406     by subtracting that modulus if necessary.
   407  
   408     If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   409     borrow and the limbs must be zeroed out again.  This will occur very
   410     infrequently.  */
   411  
   412  static inline void
   413  mpn_fft_normalize (mp_ptr ap, mp_size_t n)
   414  {
   415    if (ap[n] != 0)
   416      {
   417        MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
   418        if (ap[n] == 0)
   419  	{
   420  	  /* This happens with very low probability; we have yet to trigger it,
   421  	     and thereby make sure this code is correct.  */
   422  	  MPN_ZERO (ap, n);
   423  	  ap[n] = 1;
   424  	}
   425        else
   426  	ap[n] = 0;
   427      }
   428  }
   429  
   430  /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
   431  static void
   432  mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
   433  {
   434    int i;
   435    int sqr = (ap == bp);
   436    TMP_DECL;
   437  
   438    TMP_MARK;
   439  
   440    if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
   441      {
   442        mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
   443        int k;
   444        int **fft_l, *tmp;
   445        mp_ptr *Ap, *Bp, A, B, T;
   446  
   447        k = mpn_fft_best_k (n, sqr);
   448        K2 = (mp_size_t) 1 << k;
   449        ASSERT_ALWAYS((n & (K2 - 1)) == 0);
   450        maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
   451        M2 = n * GMP_NUMB_BITS >> k;
   452        l = n >> k;
   453        Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
   454        /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
   455        nprime2 = Nprime2 / GMP_NUMB_BITS;
   456  
   457        /* we should ensure that nprime2 is a multiple of the next K */
   458        if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
   459  	{
   460  	  mp_size_t K3;
   461  	  for (;;)
   462  	    {
   463  	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
   464  	      if ((nprime2 & (K3 - 1)) == 0)
   465  		break;
   466  	      nprime2 = (nprime2 + K3 - 1) & -K3;
   467  	      Nprime2 = nprime2 * GMP_LIMB_BITS;
   468  	      /* warning: since nprime2 changed, K3 may change too! */
   469  	    }
   470  	}
   471        ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
   472  
   473        Mp2 = Nprime2 >> k;
   474  
   475        Ap = TMP_BALLOC_MP_PTRS (K2);
   476        Bp = TMP_BALLOC_MP_PTRS (K2);
   477        A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
   478        T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
   479        B = A + ((nprime2 + 1) << k);
   480        fft_l = TMP_BALLOC_TYPE (k + 1, int *);
   481        tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
   482        for (i = 0; i <= k; i++)
   483  	{
   484  	  fft_l[i] = tmp;
   485  	  tmp += (mp_size_t) 1 << i;
   486  	}
   487  
   488        mpn_fft_initl (fft_l, k);
   489  
   490        TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
   491  		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
   492        for (i = 0; i < K; i++, ap++, bp++)
   493  	{
   494  	  mp_limb_t cy;
   495  	  mpn_fft_normalize (*ap, n);
   496  	  if (!sqr)
   497  	    mpn_fft_normalize (*bp, n);
   498  
   499  	  mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
   500  	  if (!sqr)
   501  	    mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);
   502  
   503  	  cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
   504  				     l, Mp2, fft_l, T, sqr);
   505  	  (*ap)[n] = cy;
   506  	}
   507      }
   508    else
   509      {
   510        mp_ptr a, b, tp, tpn;
   511        mp_limb_t cc;
   512        mp_size_t n2 = 2 * n;
   513        tp = TMP_BALLOC_LIMBS (n2);
   514        tpn = tp + n;
   515        TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
   516        for (i = 0; i < K; i++)
   517  	{
   518  	  a = *ap++;
   519  	  b = *bp++;
   520  	  if (sqr)
   521  	    mpn_sqr (tp, a, n);
   522  	  else
   523  	    mpn_mul_n (tp, b, a, n);
   524  	  if (a[n] != 0)
   525  	    cc = mpn_add_n (tpn, tpn, b, n);
   526  	  else
   527  	    cc = 0;
   528  	  if (b[n] != 0)
   529  	    cc += mpn_add_n (tpn, tpn, a, n) + a[n];
   530  	  if (cc != 0)
   531  	    {
   532  	      /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
   533  	      cc = mpn_add_1 (tp, tp, n2, cc);
   534  	      ASSERT (cc == 0);
   535  	    }
   536  	  a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
   537  	}
   538      }
   539    TMP_FREE;
   540  }
   541  
   542  
   543  /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   544     output: K*A[0] K*A[K-1] ... K*A[1].
   545     Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   546     This condition is also fulfilled at exit.
   547  */
   548  static void
   549  mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
   550  {
   551    if (K == 2)
   552      {
   553        mp_limb_t cy;
   554  #if HAVE_NATIVE_mpn_add_n_sub_n
   555        cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
   556  #else
   557        MPN_COPY (tp, Ap[0], n + 1);
   558        mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
   559        cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
   560  #endif
   561        if (Ap[0][n] > 1) /* can be 2 or 3 */
   562  	Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
   563        if (cy) /* Ap[1][n] can be -1 or -2 */
   564  	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
   565      }
   566    else
   567      {
   568        mp_size_t j, K2 = K >> 1;
   569  
   570        mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
   571        mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
   572        /* A[j]     <- A[j] + omega^j A[j+K/2]
   573  	 A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
   574        for (j = 0; j < K2; j++, Ap++)
   575  	{
   576  	  /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
   577  	     Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
   578  	  mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
   579  	  mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
   580  	  mpn_fft_add_modF (Ap[0],  Ap[0], tp, n);
   581  	}
   582      }
   583  }
   584  
   585  
   586  /* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
   587  static void
   588  mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
   589  {
   590    mp_bitcnt_t i;
   591  
   592    ASSERT (r != a);
   593    i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
   594    mpn_fft_mul_2exp_modF (r, a, i, n);
   595    /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
   596    /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
   597    mpn_fft_normalize (r, n);
   598  }
   599  
   600  
   601  /* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   602     Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   603     then {rp,n}=0.
   604  */
   605  static mp_size_t
   606  mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
   607  {
   608    mp_size_t l, m, rpn;
   609    mp_limb_t cc;
   610  
   611    ASSERT ((n <= an) && (an <= 3 * n));
   612    m = an - 2 * n;
   613    if (m > 0)
   614      {
   615        l = n;
   616        /* add {ap, m} and {ap+2n, m} in {rp, m} */
   617        cc = mpn_add_n (rp, ap, ap + 2 * n, m);
   618        /* copy {ap+m, n-m} to {rp+m, n-m} */
   619        rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
   620      }
   621    else
   622      {
   623        l = an - n; /* l <= n */
   624        MPN_COPY (rp, ap, n);
   625        rpn = 0;
   626      }
   627  
   628    /* remains to subtract {ap+n, l} from {rp, n+1} */
   629    cc = mpn_sub_n (rp, rp, ap + n, l);
   630    rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
   631    if (rpn < 0) /* necessarily rpn = -1 */
   632      rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
   633    return rpn;
   634  }
   635  
   636  /* store in A[0..nprime] the first M bits from {n, nl},
   637     in A[nprime+1..] the following M bits, ...
   638     Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   639     T must have space for at least (nprime + 1) limbs.
   640     We must have nl <= 2*K*l.
   641  */
   642  static void
   643  mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
   644  		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
   645  		       mp_ptr T)
   646  {
   647    mp_size_t i, j;
   648    mp_ptr tmp;
   649    mp_size_t Kl = K * l;
   650    TMP_DECL;
   651    TMP_MARK;
   652  
   653    if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
   654      {
   655        mp_size_t dif = nl - Kl;
   656        mp_limb_signed_t cy;
   657  
   658        tmp = TMP_BALLOC_LIMBS(Kl + 1);
   659  
   660        if (dif > Kl)
   661  	{
   662  	  int subp = 0;
   663  
   664  	  cy = mpn_sub_n (tmp, n, n + Kl, Kl);
   665  	  n += 2 * Kl;
   666  	  dif -= Kl;
   667  
   668  	  /* now dif > 0 */
   669  	  while (dif > Kl)
   670  	    {
   671  	      if (subp)
   672  		cy += mpn_sub_n (tmp, tmp, n, Kl);
   673  	      else
   674  		cy -= mpn_add_n (tmp, tmp, n, Kl);
   675  	      subp ^= 1;
   676  	      n += Kl;
   677  	      dif -= Kl;
   678  	    }
   679  	  /* now dif <= Kl */
   680  	  if (subp)
   681  	    cy += mpn_sub (tmp, tmp, Kl, n, dif);
   682  	  else
   683  	    cy -= mpn_add (tmp, tmp, Kl, n, dif);
   684  	  if (cy >= 0)
   685  	    cy = mpn_add_1 (tmp, tmp, Kl, cy);
   686  	  else
   687  	    cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
   688  	}
   689        else /* dif <= Kl, i.e. nl <= 2 * Kl */
   690  	{
   691  	  cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
   692  	  cy = mpn_add_1 (tmp, tmp, Kl, cy);
   693  	}
   694        tmp[Kl] = cy;
   695        nl = Kl + 1;
   696        n = tmp;
   697      }
   698    for (i = 0; i < K; i++)
   699      {
   700        Ap[i] = A;
   701        /* store the next M bits of n into A[0..nprime] */
   702        if (nl > 0) /* nl is the number of remaining limbs */
   703  	{
   704  	  j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
   705  	  nl -= j;
   706  	  MPN_COPY (T, n, j);
   707  	  MPN_ZERO (T + j, nprime + 1 - j);
   708  	  n += l;
   709  	  mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
   710  	}
   711        else
   712  	MPN_ZERO (A, nprime + 1);
   713        A += nprime + 1;
   714      }
   715    ASSERT_ALWAYS (nl == 0);
   716    TMP_FREE;
   717  }
   718  
/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.

   Ap[0..K-1] (stored in A) and, unless sqr, Bp[0..K-1] (stored in B) are
   the decomposed operands, each piece nprime+1 limbs.  Both arrays are
   clobbered: Bp[] is repointed at scratch residues, and B is reused as the
   accumulator p for the final recomposition.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
		      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
		      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  /* fftinv leaves K*c[0], K*c[K-1], ..., K*c[1] in Ap[0..K-1].  Divide by
     K = 2^k and, for i >= 1, by the decomposition weight 2^((K-i)*Mp).
     Results go to distinct buffers (div_2exp_modF requires r != a): Bp[0]
     uses the upper half of T, and Bp[i] reuses the already-consumed
     Ap[i-1].  */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  /* Coefficient i is Bp[(K - i) & (K - 1)]; it is added at limb offset
     sh = l*i.  A coefficient larger than (i+1)*2^(2M) stands for a negative
     value, so 2^N'+1 is subtracted: 1 at offset sh and 2^N' at offset
     lo = sh + nprime.  */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
	cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
			  pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
	{ /* subtract 2^N'+1 */
	  cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
	  cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
	}
    }
  /* Fold the signed carry cc (weight 2^(pla*GMP_NUMB_BITS)) back into p,
     using 2^(pl*GMP_NUMB_BITS) == -1 and 2^(2*pl*GMP_NUMB_BITS) == 1.  */
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
	{
	  /* p[pla-pl]...p[pla-1] are all zero */
	  mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
	  mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
	}
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
	{
	  while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
	    ;
	}
      else
	{
	  cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
	  ASSERT (cc == 0);
	}
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}
   809  
   810  /* return the lcm of a and 2^k */
   811  static mp_bitcnt_t
   812  mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
   813  {
   814    mp_bitcnt_t l = k;
   815  
   816    while (a % 2 == 0 && k > 0)
   817      {
   818        a >>= 1;
   819        k --;
   820      }
   821    return a << l;
   822  }
   823  
   824  
   825  mp_limb_t
   826  mpn_mul_fft (mp_ptr op, mp_size_t pl,
   827  	     mp_srcptr n, mp_size_t nl,
   828  	     mp_srcptr m, mp_size_t ml,
   829  	     int k)
   830  {
   831    int i;
   832    mp_size_t K, maxLK;
   833    mp_size_t N, Nprime, nprime, M, Mp, l;
   834    mp_ptr *Ap, *Bp, A, T, B;
   835    int **fft_l, *tmp;
   836    int sqr = (n == m && nl == ml);
   837    mp_limb_t h;
   838    TMP_DECL;
   839  
   840    TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
   841    ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
   842  
   843    TMP_MARK;
   844    N = pl * GMP_NUMB_BITS;
   845    fft_l = TMP_BALLOC_TYPE (k + 1, int *);
   846    tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
   847    for (i = 0; i <= k; i++)
   848      {
   849        fft_l[i] = tmp;
   850        tmp += (mp_size_t) 1 << i;
   851      }
   852  
   853    mpn_fft_initl (fft_l, k);
   854    K = (mp_size_t) 1 << k;
   855    M = N >> k;	/* N = 2^k M */
   856    l = 1 + (M - 1) / GMP_NUMB_BITS;
   857    maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
   858  
   859    Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
   860    /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
   861    nprime = Nprime / GMP_NUMB_BITS;
   862    TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
   863  		 N, K, M, l, maxLK, Nprime, nprime));
   864    /* we should ensure that recursively, nprime is a multiple of the next K */
   865    if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
   866      {
   867        mp_size_t K2;
   868        for (;;)
   869  	{
   870  	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
   871  	  if ((nprime & (K2 - 1)) == 0)
   872  	    break;
   873  	  nprime = (nprime + K2 - 1) & -K2;
   874  	  Nprime = nprime * GMP_LIMB_BITS;
   875  	  /* warning: since nprime changed, K2 may change too! */
   876  	}
   877        TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
   878      }
   879    ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
   880  
   881    T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
   882    Mp = Nprime >> k;
   883  
   884    TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
   885  		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
   886  	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
   887  
   888    A = TMP_BALLOC_LIMBS (K * (nprime + 1));
   889    Ap = TMP_BALLOC_MP_PTRS (K);
   890    mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
   891    if (sqr)
   892      {
   893        mp_size_t pla;
   894        pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
   895        B = TMP_BALLOC_LIMBS (pla);
   896        Bp = TMP_BALLOC_MP_PTRS (K);
   897      }
   898    else
   899      {
   900        B = TMP_BALLOC_LIMBS (K * (nprime + 1));
   901        Bp = TMP_BALLOC_MP_PTRS (K);
   902        mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
   903      }
   904    h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);
   905  
   906    TMP_FREE;
   907    return h;
   908  }
   909  
   910  #if WANT_OLD_FFT_FULL
   911  /* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml}
           (the full product, not a residue).

           Method: two FFT products are computed,
             mu     = {n,nl}*{m,ml} mod 2^(pl3*GMP_NUMB_BITS)+1, into {op, pl3},
             lambda = {n,nl}*{m,ml} mod 2^(pl2*GMP_NUMB_BITS)+1, into {pad_op, pl2},
           with pl2 = 2*l and pl3 = 3*l limbs.  Writing N = l*GMP_NUMB_BITS,
           2^(3N)+1 == 1-2^N mod 2^(2N)+1, so the product is
             P = mu + x*(2^(3N)+1),  x = (lambda-mu)/(1-2^N) mod 2^(2N)+1,
           which determines P uniquely since P < 2^(pl*GMP_NUMB_BITS) and
           pl <= pl2+pl3.  x is computed in place in pad_op, a temporary of
           pl2 limbs obtained from __GMP_ALLOCATE_FUNC_LIMBS and released with
           __GMP_FREE_FUNC_LIMBS before returning.  */
   912  void
   913  mpn_mul_fft_full (mp_ptr op,
   914  		  mp_srcptr n, mp_size_t nl,
   915  		  mp_srcptr m, mp_size_t ml)
   916  {
   917    mp_ptr pad_op;              /* lambda, later overwritten by x; pl2 limbs */
   918    mp_size_t pl, pl2, pl3, l;  /* result size, the two FFT sizes, l = pl3 - pl2 */
   919    mp_size_t cc, c2, oldcc;    /* signed carry/borrow bookkeeping (may be -1) */
   920    int k2, k3;                 /* FFT parameters k chosen for pl2 and pl3 */
   921    int sqr = (n == m && nl == ml);  /* squaring flag, passed to mpn_fft_best_k */
   922
   923    pl = nl + ml; /* total number of limbs of the result */
   924
   925    /* perform a fft mod 2^(2N)+1 and one mod 2^(3N)+1.
   926       We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
   927       pl3 a multiple of 2^k3. Since k3 >= k2, both are multiples of 2^k2,
   928       and pl2 must be an even multiple of 2^k2. Thus (pl2,pl3) =
   929       (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
   930       We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
   931       which requires j>=2. Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */
   932
   933    /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
   934
   935    pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
          /* Search upward from ceil(2pl/5) (so that pl2 + pl3 = 5*pl2/2 >= pl)
             for a pl2 that, once rounded up by mpn_fft_next_size for its best
             k2, gives a pl3 = 3*pl2/2 that is already a valid size for its own
             best k3.  The rounding may move pl2 forward by more than one. */
   936    do
   937      {
   938        pl2++;
   939        k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
   940        pl2 = mpn_fft_next_size (pl2, k2);
   941        pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
   942  			    thus pl2 / 2 is exact */
   943        k3 = mpn_fft_best_k (pl3, sqr);
   944      }
   945    while (mpn_fft_next_size (pl3, k3) != pl3);
   946
   947    TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
   948  		 nl, ml, pl2, pl3, k2));
   949
   950    ASSERT_ALWAYS(pl3 <= pl); /* mu is written to {op, pl3}, which must fit in op */
   951    cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
   952    ASSERT(cc == 0);  /* the returned carry of the mu product is expected to be 0 */
   953    pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
   954    cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
   955    cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);    /* lambda - low(mu) */
   956    /* 0 <= cc <= 1 */
   957    ASSERT(0 <= cc && cc <= 1);
   958    l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
          /* mu = low(mu) + high(mu)*2^(pl2*GMP_NUMB_BITS) == low(mu) - high(mu)
             mod 2^(pl2*GMP_NUMB_BITS)+1, where high(mu) = {op + pl2, l};
             so lambda - mu is completed by adding high(mu). */
   959    c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
   960    cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
   961    ASSERT(-1 <= cc && cc <= 1);
   962    if (cc < 0)
   963      cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
   964    ASSERT(0 <= cc && cc <= 1);
   965    /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
          /* Since 2^(pl2*GMP_NUMB_BITS) == -1 modulo 2^(pl2*GMP_NUMB_BITS)+1,
             "-cc at limb 0" and "+cc carried out of limb pl2 - 1" are the same
             correction; the comments below use both views.

             Next, divide by (1 - 2^N), N = l*GMP_NUMB_BITS.  Because
             (1-2^N)*(1+2^N) = 1-2^(2N) == 2 mod 2^(2N)+1, this is a multiply by
             (1+2^N) followed by a halving.  With y = y0 + y1*2^N (l-limb halves),
             y*(1+2^N) == (y0-y1) + (y0+y1)*2^N, hence low <- low - high and
             high <- low + high below.  The pending -oldcc is multiplied too:
             its 2^N part is subtracted at pad_op + l (folded into c2), and its
             constant part stays in cc. */
   966    oldcc = cc;
   967  #if HAVE_NATIVE_mpn_add_n_sub_n
   968    c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
   969    /* c2 & 1 is the borrow, c2 & 2 is the carry */
   970    cc += c2 >> 1; /* carry out from high <- low + high */
   971    c2 = c2 & 1; /* borrow out from low <- low - high */
   972  #else
   973    {
   974      mp_ptr tmp;
   975      TMP_DECL;
   976
   977      TMP_MARK;
   978      tmp = TMP_BALLOC_LIMBS (l);
   979      MPN_COPY (tmp, pad_op, l);  /* keep the original low half for the high sum */
   980      c2 = mpn_sub_n (pad_op,      pad_op, pad_op + l, l);
   981      cc += mpn_add_n (pad_op + l, tmp,    pad_op + l, l);
   982      TMP_FREE;
   983    }
   984  #endif
   985    c2 += oldcc;
   986    /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
   987       at pad_op + l, cc is the carry at pad_op + pl2 */
   988    /* 0 <= cc <= 2 */
   989    cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
   990    /* -1 <= cc <= 2 */
   991    if (cc > 0)
   992      cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
   993    /* now -1 <= cc <= 0 */
   994    if (cc < 0)
   995      cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
   996    /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
          /* Halve modulo 2^(2N)+1: the modulus is odd, so adding it to an odd
             value yields an even value with the same residue.  The bit carried
             into position pl2 (cc) is shifted back in as the top bit. */
   997    if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
   998      cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
   999    /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
  1000       out below */
  1001    mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  1002    if (cc) /* then cc=1 */
  1003      pad_op [pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  1004    /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
  1005       mod 2^(pl2*GMP_NUMB_BITS) + 1 */
          /* Assemble P = mu + x + x*2^(pl3*GMP_NUMB_BITS): add x into the low
             pl2 limbs of op, copy the low pl - pl3 limbs of x above mu at
             op + pl3, then propagate the carry c2 upward from op + pl2. */
  1006    c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  1007    /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  1008    MPN_COPY (op + pl3, pad_op, pl - pl3);
  1009    ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  1010    __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  1011    /* since the final result has at most pl limbs, no carry out below */
  1012    mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
  1013  }
  1014  #endif