github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/bitreverse.go.tmpl (about) 1 import ( 2 "math/bits" 3 "runtime" 4 {{ template "import_fr" . }} 5 ) 6 7 // BitReverse applies the bit-reversal permutation to v. 8 // len(v) must be a power of 2 9 func BitReverse(v []fr.Element) { 10 n := uint64(len(v)) 11 if bits.OnesCount64(n) != 1 { 12 panic("len(a) must be a power of 2") 13 } 14 15 if runtime.GOARCH == "arm64" { 16 bitReverseNaive(v) 17 } else { 18 bitReverseCobra(v) 19 } 20 } 21 22 // bitReverseNaive applies the bit-reversal permutation to v. 23 // len(v) must be a power of 2 24 func bitReverseNaive(v []fr.Element) { 25 n := uint64(len(v)) 26 nn := uint64(64 - bits.TrailingZeros64(n)) 27 28 for i := uint64(0); i < n; i++ { 29 iRev := bits.Reverse64(i) >> nn 30 if iRev > i { 31 v[i], v[iRev] = v[iRev], v[i] 32 } 33 } 34 } 35 36 37 // bitReverseCobraInPlace applies the bit-reversal permutation to v. 38 // len(v) must be a power of 2 39 // This is derived from: 40 // 41 // - Towards an Optimal Bit-Reversal Permutation Program 42 // Larry Carter and Kang Su Gatlin, 1998 43 // https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf 44 // 45 // - Practically efficient methods for performing bit-reversed 46 // permutation in C++11 on the x86-64 architecture 47 // Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017 48 // https://arxiv.org/pdf/1708.01873.pdf 49 // 50 // - and more specifically, constantine implementation: 51 // https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205 52 // by Mamy Ratsimbazafy (@mratsim). 
//
func bitReverseCobraInPlace(v []fr.Element) {
	// len(v) == 1 << logN (power-of-2 length is a documented precondition)
	logN := uint64(bits.Len64(uint64(len(v))) - 1)
	logTileSize := deriveLogTileSize(logN)
	logBLen := logN - 2*logTileSize
	bLen := uint64(1) << logBLen
	bShift := logBLen + logTileSize
	tileSize := uint64(1) << logTileSize

	// rough idea;
	// bit reversal permutation naive implementation may have some cache associativity issues,
	// since we are accessing elements by strides of powers of 2.
	// on large inputs, this is noticeable and can be improved by using a t buffer.
	// idea is for t buffer to be small enough to fit in cache.
	// in the first inner loop, we copy the elements of v into t in a bit-reversed order.
	// in the subsequent inner loops, accesses have much better cache locality than the naive implementation.
	// hence even if we apparently do more work (swaps / copies), we are faster.
	//
	// on arm64 (and particularly on M1 macs), this is not noticeable, and the naive implementation is faster,
	// in most cases.
	// on x86 (and particularly on aws hpc6a) this is noticeable, and the t buffer implementation is faster (up to 3x).
	//
	// optimal choice for the tile size is cache dependent; in theory, we want the t buffer to fit in the L1 cache;
	// in practice, a common size for L1 is 64kb, a field element is 32bytes or more.
	// hence we can fit 2k elements in the L1 cache, which corresponds to a tile size of 2**5 with some margin for cache conflicts.
	//
	// for most sizes of interest, this tile size choice doesn't yield good results;
	// we find that a tile size of 2**9 gives best results for input sizes from 2**21 up to 2**27+.
	t := make([]fr.Element, tileSize*tileSize)

	// see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf
	// for a detailed explanation of the algorithm.
	for b := uint64(0); b < bLen; b++ {

		// phase 1: copy a tileSize x tileSize tile of v into t, with each row
		// stored at its bit-reversed row index (aRev) so later passes read t
		// with good locality.
		for a := uint64(0); a < tileSize; a++ {
			aRev := (bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize
			for c := uint64(0); c < tileSize; c++ {
				idx := (a << bShift) | (b << logTileSize) | c
				t[aRev|c] = v[idx]
			}
		}

		// bit-reversed middle ("block") index, pre-shifted into its target position
		bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize

		// phase 2: exchange v's mirror positions (idxRev) with the buffered
		// tile; the idx < idxRev guard ensures each pair is swapped exactly once.
		for c := uint64(0); c < tileSize; c++ {
			cRev := ((bits.Reverse64(c) >> (64 - logTileSize)) << bShift) | bRev
			for aRev := uint64(0); aRev < tileSize; aRev++ {
				a := bits.Reverse64(aRev) >> (64 - logTileSize)
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idxRev], t[tIdx] = t[tIdx], v[idxRev]
				}
			}
		}

		// phase 3: write the exchanged tile contents back into v at the
		// original (non-reversed) positions, again guarded by idx < idxRev.
		for a := uint64(0); a < tileSize; a++ {
			aRev := bits.Reverse64(a) >> (64 - logTileSize)
			for c := uint64(0); c < tileSize; c++ {
				cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | bRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idx], t[tIdx] = t[tIdx], v[idx]
				}
			}
		}
	}
}

// bitReverseCobra dispatches to a size-specialized COBRA implementation for
// lengths 2**21 .. 2**27, to the generic in-place COBRA above that range, and
// to the naive swap loop below it.
func bitReverseCobra(v []fr.Element) {
	switch len(v) {
	case 1 << 21:
		bitReverseCobraInPlace_9_21(v)
	case 1 << 22:
		bitReverseCobraInPlace_9_22(v)
	case 1 << 23:
		bitReverseCobraInPlace_9_23(v)
	case 1 << 24:
		bitReverseCobraInPlace_9_24(v)
	case 1 << 25:
		bitReverseCobraInPlace_9_25(v)
	case 1 << 26:
		bitReverseCobraInPlace_9_26(v)
	case 1 << 27:
		bitReverseCobraInPlace_9_27(v)
	default:
		if len(v) > 1<<27 {
			bitReverseCobraInPlace(v)
		} else {
			bitReverseNaive(v)
		}
	}
}

// deriveLogTileSize returns log2 of the COBRA tile size for an input of
// length 1 << logN: it starts from 9 and decrements until logN - 2*q >= 1,
// keeping logBLen (= logN - 2*q) strictly positive.
func deriveLogTileSize(logN uint64) uint64 {
	q := uint64(9) // see bitReverseCobraInPlace for more details

	for int(logN)-int(2*q) <= 0 {
		q--
	}

	return q
}

{{bitReverseCobraInPlace 9 21}}
{{bitReverseCobraInPlace 9 22}}
{{bitReverseCobraInPlace 9 23}}
{{bitReverseCobraInPlace 9 24}}
{{bitReverseCobraInPlace 9 25}}
{{bitReverseCobraInPlace 9 26}}
{{bitReverseCobraInPlace 9 27}}

{{define "bitReverseCobraInPlace logTileSize logN"}}

// bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}} applies the bit-reversal permutation to v.
// len(v) must be 1 << {{.logN}}.
// see bitReverseCobraInPlace for more details; this function is specialized for {{.logTileSize}},
// as it declares the t buffer and various constants statically for performance.
func bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}}(v []fr.Element) {
	const (
		logTileSize = uint64({{.logTileSize}})
		tileSize = uint64(1) << logTileSize
		logN = {{.logN}}
		logBLen = logN - 2*logTileSize
		bShift = logBLen + logTileSize
		bLen = uint64(1) << logBLen
	)

	// fixed-size array (size known at generation time), unlike the generic
	// version's heap-allocated slice.
	var t [tileSize * tileSize]fr.Element
	// {{`{{$k}}`}} = 64 - logTileSize, baked in at generation time so the
	// shift amounts below are constants in the generated code.
	// NOTE(review): $l and $tileSize appear unused in this template body — confirm before removing.
	{{$k := sub 64 .logTileSize}}
	{{$l := .logTileSize}}
	{{$tileSize := shl 1 .logTileSize}}

	for b := uint64(0); b < bLen; b++ {

		// phase 1: copy a tile of v into t, rows stored at bit-reversed row index
		for a := uint64(0); a < tileSize; a++ {
			aRev := (bits.Reverse64(a) >> {{$k}}) << logTileSize
			for c := uint64(0); c < tileSize; c++ {
				idx := (a << bShift) | (b << logTileSize) | c
				t[aRev | c] = v[idx]
			}
		}

		bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize

		// phase 2: swap v's mirror positions with the buffered tile
		// (idx < idxRev guarantees each pair is exchanged exactly once)
		for c := uint64(0); c < tileSize; c++ {
			cRev := ((bits.Reverse64(c) >> {{$k}}) << bShift) | bRev
			for aRev := uint64(0); aRev < tileSize; aRev++ {
				a := bits.Reverse64(aRev) >> {{$k}}
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idxRev], t[tIdx] = t[tIdx], v[idxRev]
				}
			}
		}

		// phase 3: write the exchanged tile contents back into v
		for a := uint64(0); a < tileSize; a++ {
			aRev := bits.Reverse64(a) >> {{$k}}
			for c := uint64(0); c < tileSize; c++ {
				cRev := (bits.Reverse64(c) >> {{$k}}) << bShift
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | bRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idx], t[tIdx] = t[tIdx], v[idx]
				}
			}
		}
	}


}

{{- end}}