github.com/remyoudompheng/bigfft@v0.0.0-20230129092748-24d4a6f8daec/fermat.go (about)

     1  package bigfft
     2  
     3  import (
     4  	"math/big"
     5  )
     6  
     7  // Arithmetic modulo 2^n+1.
     8  
     9  // A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
    10  // word is zero or one. A number has at most two representatives satisfying the
    11  // 0-1 last word constraint.
    12  type fermat nat
    13  
    14  func (n fermat) String() string { return nat(n).String() }
    15  
    16  func (z fermat) norm() {
    17  	n := len(z) - 1
    18  	c := z[n]
    19  	if c == 0 {
    20  		return
    21  	}
    22  	if z[0] >= c {
    23  		z[n] = 0
    24  		z[0] -= c
    25  		return
    26  	}
    27  	// z[0] < z[n].
    28  	subVW(z, z, c) // Substract c
    29  	if c > 1 {
    30  		z[n] -= c - 1
    31  		c = 1
    32  	}
    33  	// Add back c.
    34  	if z[n] == 1 {
    35  		z[n] = 0
    36  		return
    37  	} else {
    38  		addVW(z, z, 1)
    39  	}
    40  }
    41  
// Shift sets z to (x << k) mod (2^(n*_W)+1), where n = len(x)-1; that is, it
// multiplies x by 2^k. k may be any int, including a negative one.
//
// z and x must have the same length (Shift panics otherwise). z must not
// alias x: words of x are read after words of z have been written.
func (z fermat) Shift(x fermat, k int) {
	if len(z) != len(x) {
		panic("len(z) != len(x) in Shift")
	}
	n := len(x) - 1
	// Shift by n*_W is taking the opposite: 2^(n*_W) ≡ -1, so a shift by
	// 2*n*_W is the identity. Reduce k into [0, 2*n*_W) and turn a shift of
	// at least n*_W into a negation.
	k %= 2 * n * _W
	if k < 0 {
		k += 2 * n * _W
	}
	neg := false
	if k >= n*_W {
		k -= n * _W
		neg = true
	}

	// Split the shift into whole words (kw) and leftover bits (kb).
	kw, kb := k/_W, k%_W

	// Setting the top word to 1 adds 2^(n*_W) to the integer being built.
	// This gives the borrows of the word subtractions below a word to land
	// in, and represents an extra -1 that is removed after the word shift.
	z[n] = 1 // Add (-1)
	if !neg {
		// The kw lowest words of x<<(kw*_W) are zero.
		for i := 0; i < kw; i++ {
			z[i] = 0
		}
		// Shift left by kw words.
		// Write x = hi·2^((n-kw)*_W) + lo, with lo = x[:n-kw] and
		// hi = x[n-kw:] (kw+1 words, including the top word). Then
		// x<<(kw*_W) = hi·2^(n*_W) + lo<<(kw*_W) ≡ (lo<<(kw*_W)) - hi.
		copy(z[kw:], x[:n-kw])
		b := subVV(z[:kw+1], z[:kw+1], x[n-kw:])
		// Propagate the borrow b upward; z[n] = 1 absorbs it at the latest.
		if z[kw+1] > 0 {
			z[kw+1] -= b
		} else {
			subVW(z[kw+1:], z[kw+1:], b)
		}
	} else {
		// Negated case: -(x<<(kw*_W)) ≡ hi - (lo<<(kw*_W)), with hi and lo
		// as above. Words above hi start out zero (z[n] is already set).
		for i := kw + 1; i < n; i++ {
			z[i] = 0
		}
		// Shift left and negate, by kw words.
		copy(z[:kw+1], x[n-kw:n+1])            // z_low = x_high
		b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low
		z[n] -= b
	}
	// Add back 1, undoing the -1 introduced by z[n] = 1. Decrementing z[n]
	// adds 1 because its weight 2^(n*_W) is ≡ -1; otherwise increment the
	// low words, taking the cheap path when z[0] cannot overflow.
	if z[n] > 0 {
		z[n]--
	} else if z[0] < ^big.Word(0) {
		z[0]++
	} else {
		addVW(z, z, 1)
	}
	// Shift the whole (n+1)-word array left by the remaining kb bits. The
	// top word may exceed 1 afterwards; norm folds it back into range.
	shlVU(z, z, uint(kb))
	z.norm()
}
    97  
    98  // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
    99  // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
   100  // A temporary buffer must be provided in tmp.
   101  func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) {
   102  	n := len(z) - 1
   103  	if k%2 == 0 {
   104  		z.Shift(x, k/2)
   105  		return
   106  	}
   107  	u := (k - 1) / 2
   108  	a := u + (3*_W/4)*n
   109  	b := u + (_W/4)*n
   110  	z.Shift(x, a)
   111  	tmp.Shift(x, b)
   112  	z.Sub(z, tmp)
   113  }
   114  
   115  // Add computes addition mod 2^n+1.
   116  func (z fermat) Add(x, y fermat) fermat {
   117  	if len(z) != len(x) {
   118  		panic("Add: len(z) != len(x)")
   119  	}
   120  	addVV(z, x, y) // there cannot be a carry here.
   121  	z.norm()
   122  	return z
   123  }
   124  
   125  // Sub computes substraction mod 2^n+1.
   126  func (z fermat) Sub(x, y fermat) fermat {
   127  	if len(z) != len(x) {
   128  		panic("Add: len(z) != len(x)")
   129  	}
   130  	n := len(y) - 1
   131  	b := subVV(z[:n], x[:n], y[:n])
   132  	b += y[n]
   133  	// If b > 0, we need to subtract b<<n, which is the same as adding b.
   134  	z[n] = x[n]
   135  	if z[0] <= ^big.Word(0)-b {
   136  		z[0] += b
   137  	} else {
   138  		addVW(z, z, b)
   139  	}
   140  	z.norm()
   141  	return z
   142  }
   143  
   144  func (z fermat) Mul(x, y fermat) fermat {
   145  	if len(x) != len(y) {
   146  		panic("Mul: len(x) != len(y)")
   147  	}
   148  	n := len(x) - 1
   149  	if n < 30 {
   150  		z = z[:2*n+2]
   151  		basicMul(z, x, y)
   152  		z = z[:2*n+1]
   153  	} else {
   154  		var xi, yi, zi big.Int
   155  		xi.SetBits(x)
   156  		yi.SetBits(y)
   157  		zi.SetBits(z)
   158  		zb := zi.Mul(&xi, &yi).Bits()
   159  		if len(zb) <= n {
   160  			// Short product.
   161  			copy(z, zb)
   162  			for i := len(zb); i < len(z); i++ {
   163  				z[i] = 0
   164  			}
   165  			return z
   166  		}
   167  		z = zb
   168  	}
   169  	// len(z) is at most 2n+1.
   170  	if len(z) > 2*n+1 {
   171  		panic("len(z) > 2n+1")
   172  	}
   173  	// We now have
   174  	// z = z[:n] + 1<<(n*W) * z[n:2n+1]
   175  	// which normalizes to:
   176  	// z = z[:n] - z[n:2n] + z[2n]
   177  	c1 := big.Word(0)
   178  	if len(z) > 2*n {
   179  		c1 = addVW(z[:n], z[:n], z[2*n])
   180  	}
   181  	c2 := big.Word(0)
   182  	if len(z) >= 2*n {
   183  		c2 = subVV(z[:n], z[:n], z[n:2*n])
   184  	} else {
   185  		m := len(z) - n
   186  		c2 = subVV(z[:m], z[:m], z[n:])
   187  		c2 = subVW(z[m:n], z[m:n], c2)
   188  	}
   189  	// Restore carries.
   190  	// Substracting z[n] -= c2 is the same
   191  	// as z[0] += c2
   192  	z = z[:n+1]
   193  	z[n] = c1
   194  	c := addVW(z, z, c2)
   195  	if c != 0 {
   196  		panic("impossible")
   197  	}
   198  	z.norm()
   199  	return z
   200  }
   201  
   202  // copied from math/big
   203  //
   204  // basicMul multiplies x and y and leaves the result in z.
   205  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   206  func basicMul(z, x, y fermat) {
   207  	// initialize z
   208  	for i := 0; i < len(z); i++ {
   209  		z[i] = 0
   210  	}
   211  	for i, d := range y {
   212  		if d != 0 {
   213  			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   214  		}
   215  	}
   216  }