github.com/remyoudompheng/bigfft@v0.0.0-20230129092748-24d4a6f8daec/fermat.go (about) 1 package bigfft 2 3 import ( 4 "math/big" 5 ) 6 7 // Arithmetic modulo 2^n+1. 8 9 // A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last 10 // word is zero or one. A number has at most two representatives satisfying the 11 // 0-1 last word constraint. 12 type fermat nat 13 14 func (n fermat) String() string { return nat(n).String() } 15 16 func (z fermat) norm() { 17 n := len(z) - 1 18 c := z[n] 19 if c == 0 { 20 return 21 } 22 if z[0] >= c { 23 z[n] = 0 24 z[0] -= c 25 return 26 } 27 // z[0] < z[n]. 28 subVW(z, z, c) // Substract c 29 if c > 1 { 30 z[n] -= c - 1 31 c = 1 32 } 33 // Add back c. 34 if z[n] == 1 { 35 z[n] = 0 36 return 37 } else { 38 addVW(z, z, 1) 39 } 40 } 41 42 // Shift computes (x << k) mod (2^n+1). 43 func (z fermat) Shift(x fermat, k int) { 44 if len(z) != len(x) { 45 panic("len(z) != len(x) in Shift") 46 } 47 n := len(x) - 1 48 // Shift by n*_W is taking the opposite. 49 k %= 2 * n * _W 50 if k < 0 { 51 k += 2 * n * _W 52 } 53 neg := false 54 if k >= n*_W { 55 k -= n * _W 56 neg = true 57 } 58 59 kw, kb := k/_W, k%_W 60 61 z[n] = 1 // Add (-1) 62 if !neg { 63 for i := 0; i < kw; i++ { 64 z[i] = 0 65 } 66 // Shift left by kw words. 67 // x = a·2^(n-k) + b 68 // x<<k = (b<<k) - a 69 copy(z[kw:], x[:n-kw]) 70 b := subVV(z[:kw+1], z[:kw+1], x[n-kw:]) 71 if z[kw+1] > 0 { 72 z[kw+1] -= b 73 } else { 74 subVW(z[kw+1:], z[kw+1:], b) 75 } 76 } else { 77 for i := kw + 1; i < n; i++ { 78 z[i] = 0 79 } 80 // Shift left and negate, by kw words. 81 copy(z[:kw+1], x[n-kw:n+1]) // z_low = x_high 82 b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low 83 z[n] -= b 84 } 85 // Add back 1. 86 if z[n] > 0 { 87 z[n]-- 88 } else if z[0] < ^big.Word(0) { 89 z[0]++ 90 } else { 91 addVW(z, z, 1) 92 } 93 // Shift left by kb bits 94 shlVU(z, z, uint(kb)) 95 z.norm() 96 } 97 98 // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit 99 // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4). 100 // A temporary buffer must be provided in tmp. 101 func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) { 102 n := len(z) - 1 103 if k%2 == 0 { 104 z.Shift(x, k/2) 105 return 106 } 107 u := (k - 1) / 2 108 a := u + (3*_W/4)*n 109 b := u + (_W/4)*n 110 z.Shift(x, a) 111 tmp.Shift(x, b) 112 z.Sub(z, tmp) 113 } 114 115 // Add computes addition mod 2^n+1. 116 func (z fermat) Add(x, y fermat) fermat { 117 if len(z) != len(x) { 118 panic("Add: len(z) != len(x)") 119 } 120 addVV(z, x, y) // there cannot be a carry here. 121 z.norm() 122 return z 123 } 124 125 // Sub computes substraction mod 2^n+1. 126 func (z fermat) Sub(x, y fermat) fermat { 127 if len(z) != len(x) { 128 panic("Add: len(z) != len(x)") 129 } 130 n := len(y) - 1 131 b := subVV(z[:n], x[:n], y[:n]) 132 b += y[n] 133 // If b > 0, we need to subtract b<<n, which is the same as adding b. 134 z[n] = x[n] 135 if z[0] <= ^big.Word(0)-b { 136 z[0] += b 137 } else { 138 addVW(z, z, b) 139 } 140 z.norm() 141 return z 142 } 143 144 func (z fermat) Mul(x, y fermat) fermat { 145 if len(x) != len(y) { 146 panic("Mul: len(x) != len(y)") 147 } 148 n := len(x) - 1 149 if n < 30 { 150 z = z[:2*n+2] 151 basicMul(z, x, y) 152 z = z[:2*n+1] 153 } else { 154 var xi, yi, zi big.Int 155 xi.SetBits(x) 156 yi.SetBits(y) 157 zi.SetBits(z) 158 zb := zi.Mul(&xi, &yi).Bits() 159 if len(zb) <= n { 160 // Short product. 161 copy(z, zb) 162 for i := len(zb); i < len(z); i++ { 163 z[i] = 0 164 } 165 return z 166 } 167 z = zb 168 } 169 // len(z) is at most 2n+1. 170 if len(z) > 2*n+1 { 171 panic("len(z) > 2n+1") 172 } 173 // We now have 174 // z = z[:n] + 1<<(n*W) * z[n:2n+1] 175 // which normalizes to: 176 // z = z[:n] - z[n:2n] + z[2n] 177 c1 := big.Word(0) 178 if len(z) > 2*n { 179 c1 = addVW(z[:n], z[:n], z[2*n]) 180 } 181 c2 := big.Word(0) 182 if len(z) >= 2*n { 183 c2 = subVV(z[:n], z[:n], z[n:2*n]) 184 } else { 185 m := len(z) - n 186 c2 = subVV(z[:m], z[:m], z[n:]) 187 c2 = subVW(z[m:n], z[m:n], c2) 188 } 189 // Restore carries. 190 // Substracting z[n] -= c2 is the same 191 // as z[0] += c2 192 z = z[:n+1] 193 z[n] = c1 194 c := addVW(z, z, c2) 195 if c != 0 { 196 panic("impossible") 197 } 198 z.norm() 199 return z 200 } 201 202 // copied from math/big 203 // 204 // basicMul multiplies x and y and leaves the result in z. 205 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. 206 func basicMul(z, x, y fermat) { 207 // initialize z 208 for i := 0; i < len(z); i++ { 209 z[i] = 0 210 } 211 for i, d := range y { 212 if d != 0 { 213 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d) 214 } 215 } 216 }