github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/bitreverse.go.tmpl (about)

     1  import (
     2  	"math/bits"
     3  	"runtime"
     4  	{{ template "import_fr" . }}
     5  )
     6  
     7  // BitReverse applies the bit-reversal permutation to v.
     8  // len(v) must be a power of 2
     9  func BitReverse(v []fr.Element) {
    10  	n := uint64(len(v))
    11  	if bits.OnesCount64(n) != 1 {
    12  		panic("len(a) must be a power of 2")
    13  	}
    14  
    15  	if runtime.GOARCH == "arm64" {
    16  		bitReverseNaive(v)
    17  	} else {
    18  		bitReverseCobra(v)
    19  	}
    20  }
    21  
    22  // bitReverseNaive applies the bit-reversal permutation to v.
    23  // len(v) must be a power of 2
    24  func bitReverseNaive(v []fr.Element) {
    25  	n := uint64(len(v))
    26  	nn := uint64(64 - bits.TrailingZeros64(n))
    27  
    28  	for i := uint64(0); i < n; i++ {
    29  		iRev := bits.Reverse64(i) >> nn
    30  		if iRev > i {
    31  			v[i], v[iRev] = v[iRev], v[i]
    32  		}
    33  	}
    34  }
    35  
    36  
// bitReverseCobraInPlace applies the bit-reversal permutation to v, in place,
// using a cache-blocked ("COBRA") algorithm.
// len(v) must be a power of 2.
//
// Each index of v is split into three bit fields:
//
//	idx    = a<<bShift | b<<logTileSize | c
//	rev(i) = rev(c)<<bShift | rev(b)<<logTileSize | rev(a)
//
// where a and c are logTileSize bits wide and b is logBLen bits wide.
// The outer loop walks b; for each b a tileSize x tileSize scratch buffer t
// is used so that most memory traffic hits t instead of strided positions of v.
//
// A scratch buffer of tileSize*tileSize elements is allocated on every call.
//
// This is derived from:
//
//   - Towards an Optimal Bit-Reversal Permutation Program,
//     Larry Carter and Kang Su Gatlin, 1998.
//     https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf
//
//   - Practically efficient methods for performing bit-reversed
//     permutation in C++11 on the x86-64 architecture,
//     Knauth, Adas, Whitfield, Wang, Ickler, Conrad, Serang, 2017.
//     https://arxiv.org/pdf/1708.01873.pdf
//
//   - and more specifically, the constantine implementation by
//     Mamy Ratsimbazafy (@mratsim):
//     https://github.com/mratsim/constantine/blob/d51699248db04e29c7b1ad97e0bafa1499db00b5/constantine/math/polynomials/fft.nim#L205
func bitReverseCobraInPlace(v []fr.Element) {
	logN := uint64(bits.Len64(uint64(len(v))) - 1)
	logTileSize := deriveLogTileSize(logN)
	logBLen := logN - 2*logTileSize
	bLen := uint64(1) << logBLen
	bShift := logBLen + logTileSize
	tileSize := uint64(1) << logTileSize

	// Rough idea:
	// a naive bit-reversal permutation may suffer from cache associativity
	// issues, since it accesses elements by strides of powers of 2.
	// On large inputs this is noticeable and can be improved by using a
	// buffer t that is small enough to fit in cache.
	// In the first inner loop we copy elements of v into t in bit-reversed
	// order; the subsequent inner loops then have much better cache locality
	// than the naive implementation. So even though we apparently do more work
	// (swaps / copies), we end up faster.
	//
	// On arm64 (and particularly on M1 Macs) this is not noticeable, and the
	// naive implementation is faster in most cases.
	// On x86 (and particularly on AWS hpc6a) it is noticeable, and the
	// t-buffer implementation is faster (up to 3x).
	//
	// The optimal tile size is cache dependent. In theory we want t to fit in
	// the L1 cache; a common L1 size is 64kB and a field element is 32 bytes
	// or more, so about 2k elements fit, which corresponds to a tile size of
	// 2**5 (t = 2**10 elements) with some margin for cache conflicts.
	//
	// In practice this choice does not give good results for most sizes of
	// interest; a tile size of 2**9 gives the best results for input sizes
	// from 2**21 up to 2**27 and beyond.
	t := make([]fr.Element, tileSize*tileSize)

	// see https://csaws.cs.technion.ac.il/~itai/Courses/Cache/bit.pdf
	// for a detailed explanation of the algorithm.
	for b := uint64(0); b < bLen; b++ {

		// Phase 1: gather the whole tile with middle field b into t, storing
		// v[a, b, c] at row rev(a), column c.
		for a := uint64(0); a < tileSize; a++ {
			aRev :=( bits.Reverse64(a) >> (64 - logTileSize)) << logTileSize
			for c := uint64(0); c < tileSize; c++ {
				idx := (a << bShift) | (b << logTileSize) | c
				t[aRev | c] = v[idx]
			}
		}

		// Middle field of every reversed index produced for this tile.
		bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize

		// Phase 2: for every pair (idx, idxRev) with idx in this tile and
		// idx < idxRev, move the saved v[idx] (held in t) to position idxRev,
		// and stash the displaced v[idxRev] in the same slot of t.
		// Iterating aRev in the inner loop makes the writes to v contiguous.
		for c := uint64(0); c < tileSize; c++ {
			cRev := ((bits.Reverse64(c) >> (64 - logTileSize))  << bShift) | bRev
			for aRev := uint64(0); aRev < tileSize; aRev++ {
				a := bits.Reverse64(aRev) >> (64 - logTileSize)
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idxRev], t[tIdx] = t[tIdx], v[idxRev]
				}
			}
		}

		// Phase 3: for the same pairs, write the stashed original v[idxRev]
		// back into position idx, completing each swap.
		for a := uint64(0); a < tileSize; a++ {
			aRev := bits.Reverse64(a) >> (64 - logTileSize)
			for c := uint64(0); c < tileSize; c++ {
				cRev := (bits.Reverse64(c) >> (64 - logTileSize)) << bShift
				idx := (a << bShift) | (b << logTileSize) | c
				idxRev := cRev | bRev | aRev
				if idx < idxRev {
					tIdx := (aRev << logTileSize) | c
					v[idx], t[tIdx] = t[tIdx], v[idx]
				}
			}
		}
	}
}
   125  
   126  
   127  func bitReverseCobra(v []fr.Element) {
   128  	switch len(v) {
   129  	case 1 << 21:
   130  		bitReverseCobraInPlace_9_21(v)
   131  	case 1 << 22:
   132  		bitReverseCobraInPlace_9_22(v)
   133  	case 1 << 23:
   134  		bitReverseCobraInPlace_9_23(v)
   135  	case 1 << 24:
   136  		bitReverseCobraInPlace_9_24(v)
   137  	case 1 << 25:
   138  		bitReverseCobraInPlace_9_25(v)
   139  	case 1 << 26:
   140  		bitReverseCobraInPlace_9_26(v)
   141  	case 1 << 27:
   142  		bitReverseCobraInPlace_9_27(v)
   143  	default:
   144  		if len(v) > 1<<27 {
   145  			bitReverseCobraInPlace(v)
   146  		} else {
   147  			bitReverseNaive(v)
   148  		}
   149  	}
   150  }
   151  
   152  
   153  func deriveLogTileSize(logN uint64) uint64 {
   154  	q := uint64(9) // see bitReverseCobraInPlace for more details
   155  
   156  	for int(logN)-int(2*q) <= 0 {
   157  		q--
   158  	}
   159  
   160  	return q
   161  }
   162  
   163  
   164  {{bitReverseCobraInPlace 9 21}}
   165  {{bitReverseCobraInPlace 9 22}}
   166  {{bitReverseCobraInPlace 9 23}}
   167  {{bitReverseCobraInPlace 9 24}}
   168  {{bitReverseCobraInPlace 9 25}}
   169  {{bitReverseCobraInPlace 9 26}}
   170  {{bitReverseCobraInPlace 9 27}}
   171  
   172  
   173  {{define "bitReverseCobraInPlace logTileSize logN"}}
   174  
   175  // bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}} applies the bit-reversal permutation to v.
   176  // len(v) must be 1 << {{.logN}}.
   177  // see bitReverseCobraInPlace for more details; this function is specialized for {{.logTileSize}},
   178  // as it declares the t buffer and various constants statically for performance.
   179  func bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}}(v []fr.Element) {
   180  	const (
   181  		logTileSize = uint64({{.logTileSize}})
   182  		tileSize = uint64(1) << logTileSize
   183  		logN = {{.logN}}
   184  		logBLen = logN - 2*logTileSize
   185  		bShift = logBLen + logTileSize
   186  		bLen = uint64(1) << logBLen
   187  	)
   188  
   189  	var t [tileSize * tileSize]fr.Element
   190  	{{$k := sub 64  .logTileSize}}
   191  	{{$l := .logTileSize}}
   192  	{{$tileSize := shl 1 .logTileSize}}
   193  	
   194  	for b := uint64(0); b < bLen; b++ {
   195  		
   196  		for a := uint64(0); a < tileSize; a++ {
   197  			aRev := (bits.Reverse64(a) >> {{$k}}) << logTileSize
   198  			for c := uint64(0); c < tileSize; c++ {
   199  				idx := (a << bShift) | (b << logTileSize) | c
   200  				t[aRev | c] = v[idx]
   201  			}
   202  		}
   203  
   204  		bRev := (bits.Reverse64(b) >> (64 - logBLen)) << logTileSize
   205  
   206  		for c := uint64(0); c < tileSize; c++ {
   207  			cRev := ((bits.Reverse64(c) >> {{$k}}) << bShift) | bRev
   208  			for aRev := uint64(0); aRev < tileSize; aRev++ {
   209  				a := bits.Reverse64(aRev) >> {{$k}}
   210  				idx := (a << bShift) | (b << logTileSize) | c
   211  				idxRev := cRev | aRev
   212  				if idx < idxRev {
   213  					tIdx := (aRev << logTileSize) | c
   214  					v[idxRev], t[tIdx] = t[tIdx], v[idxRev]
   215  				}
   216  			}
   217  		}
   218  
   219  		for a := uint64(0); a < tileSize; a++ {
   220  			aRev := bits.Reverse64(a) >> {{$k}}
   221  			for c := uint64(0); c < tileSize; c++ {
   222  				cRev := (bits.Reverse64(c) >> {{$k}}) << bShift
   223  				idx := (a << bShift) | (b << logTileSize) | c
   224  				idxRev := cRev | bRev  | aRev
   225  				if idx < idxRev {
   226  					tIdx := (aRev << logTileSize) | c
   227  					v[idx], t[tIdx] = t[tIdx], v[idx]
   228  				}
   229  			}
   230  		}
   231  	}
   232  
   233  	
   234  }
   235  
   236  {{- end}}