github.com/grailbio/base@v0.0.11/simd/simd_generic.go (about)

     1  // Copyright 2018 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build !amd64 || appengine
     6  
     7  package simd
     8  
     9  // amd64 compile-time constants.
    10  
    11  // BytesPerWord is the number of bytes in a machine word.
    12  // We don't use unsafe.Sizeof(uintptr(1)) since there are advantages to having
    13  // this as an untyped constant, and there's essentially no drawback since this
    14  // is an _amd64-specific file.
    15  const BytesPerWord = 8
    16  
    17  // Log2BytesPerWord is log2(BytesPerWord).  This is relevant for manual
    18  // bit-shifting when we know that's a safe way to divide and the compiler does
    19  // not (e.g. dividend is of signed int type).
    20  const Log2BytesPerWord = uint(3)
    21  
    22  // BitsPerWord is the number of bits in a machine word.
    23  const BitsPerWord = BytesPerWord * 8
    24  
// nibbleLookupDup is the number of adjacent copies of the 16-byte lookup table
// stored in a NibbleLookupTable.
// This must be at least <maximum supported vector size> / 16.
const nibbleLookupDup = 1
    27  
// NibbleLookupTable represents a parallel-byte-substitution operation f, where
// every byte b in a byte-slice is replaced with
//   f(b) := shuffle[0][b & 15] for b <= 127, and
//   f(b) := 0 for b > 127.
// (The second part is usually irrelevant in practice, but must be defined this
// way to allow _mm_shuffle_epi8()/_mm256_shuffle_epi8()/_mm512_shuffle_epi8()
// to be used to implement the operation efficiently.)
// It's named NibbleLookupTable rather than ByteLookupTable since only the
// bottom nibble of each byte can be used for table lookup.
// It potentially stores multiple adjacent copies of the lookup table since
// that speeds up the AVX2 and AVX-512 use cases (the table can be loaded with
// a single _mm256_loadu_si256 operation, instead of e.g. _mm_loadu_si128
// followed by _mm256_set_m128i with the same argument twice), and the typical
// use case involves initializing very few tables and using them many, many
// times.
//
// In this file nibbleLookupDup is 1, so exactly one copy is stored, and the
// pure-Go functions in this file only ever read shuffle[0].
type NibbleLookupTable struct {
	// shuffle holds nibbleLookupDup identical copies of the 16-entry table;
	// MakeNibbleLookupTable fills every copy with the same contents.
	shuffle [nibbleLookupDup][16]byte
}
    46  
// Get returns entry b of the first stored copy of the lookup table.
// b is used directly as an index into a 16-entry array and is NOT masked
// here, so callers must pass a value in [0, 15] (e.g. someByte & 15); a larger
// value is out of range for the array.
func (t *NibbleLookupTable) Get(b byte) byte {
	return t.shuffle[0][b]
}
    50  
    51  // const minPageSize = 4096  may be relevant for safe functions soon.
    52  
    53  // These could be compile-time constants for now, but not after AVX2
    54  // autodetection is added.
    55  
    56  // bytesPerVec is the size of the maximum-width vector that may be used.  It is
    57  // currently always 16, but it will be set to larger values at runtime in the
    58  // future when AVX2/AVX-512/etc. is detected.
    59  var bytesPerVec int
    60  
    61  // log2BytesPerVec supports efficient division by bytesPerVec.
    62  var log2BytesPerVec uint
    63  
    64  func init() {
    65  	bytesPerVec = 16
    66  	log2BytesPerVec = 4
    67  }
    68  
// BytesPerVec is an accessor for the bytesPerVec package variable, i.e. the
// size in bytes of the maximum-width vector that may be used.
func BytesPerVec() int {
	return bytesPerVec
}
    73  
    74  // RoundUpPow2 returns val rounded up to a multiple of alignment, assuming
    75  // alignment is a power of 2.
    76  func RoundUpPow2(val, alignment int) int {
    77  	return (val + alignment - 1) & (^(alignment - 1))
    78  }
    79  
    80  // DivUpPow2 efficiently divides a number by a power-of-2 divisor.  (This works
    81  // for negative dividends since the language specifies arithmetic right-shifts
    82  // of signed numbers.  I'm pretty sure this doesn't have a performance
    83  // penalty.)
    84  func DivUpPow2(dividend, divisor int, log2Divisor uint) int {
    85  	return (dividend + divisor - 1) >> log2Divisor
    86  }
    87  
    88  // MakeUnsafe returns a byte slice of the given length which is guaranteed to
    89  // have enough capacity for all Unsafe functions in this package to work.  (It
    90  // is not itself an unsafe function: allocated memory is zero-initialized.)
    91  // Note that Unsafe functions occasionally have other caveats: e.g.
    92  // PopcntUnsafe also requires relevant bytes past the end of the slice to be
    93  // zeroed out.
    94  func MakeUnsafe(len int) []byte {
    95  	// Although no planned function requires more than
    96  	// RoundUpPow2(len+1, bytesPerVec) capacity, it is necessary to add
    97  	// bytesPerVec instead to make subslicing safe.
    98  	return make([]byte, len, len+bytesPerVec)
    99  }
   100  
   101  // RemakeUnsafe reuses the given buffer if it has sufficient capacity;
   102  // otherwise it does the same thing as MakeUnsafe.  It does NOT preserve
   103  // existing contents of buf[]; use ResizeUnsafe() for that.
   104  func RemakeUnsafe(bufptr *[]byte, len int) {
   105  	minCap := len + bytesPerVec
   106  	// This is likely to be called in an inner loop processing variable-size
   107  	// inputs, so mild exponential growth is appropriate.
   108  	*bufptr = make([]byte, len, RoundUpPow2(minCap+(minCap/8), bytesPerVec))
   109  }
   110  
   111  // ResizeUnsafe changes the length of buf and ensures it has enough extra
   112  // capacity to be passed to this package's Unsafe functions.  Existing buf[]
   113  // contents are preserved (with possible truncation), though when length is
   114  // increased, new bytes might not be zero-initialized.
   115  func ResizeUnsafe(bufptr *[]byte, len int) {
   116  	minCap := len + bytesPerVec
   117  	dst := make([]byte, len, RoundUpPow2(minCap+(minCap/8), bytesPerVec))
   118  	copy(dst, *bufptr)
   119  	*bufptr = dst
   120  }
   121  
   122  // XcapUnsafe is shorthand for ResizeUnsafe's most common use case (no length
   123  // change, just want to ensure sufficient capacity).
   124  func XcapUnsafe(bufptr *[]byte) {
   125  	// mid-stack inlining isn't yet working as I write this, but it should be
   126  	// available soon enough:
   127  	//   https://github.com/golang/go/issues/19348
   128  	ResizeUnsafe(bufptr, len(*bufptr))
   129  }
   130  
   131  // Memset8Unsafe sets all values of dst[] to the given byte.  (This is intended
   132  // for val != 0.  It is better to use a range-for loop for val == 0 since the
   133  // compiler has a hardcoded optimization for that case; see
   134  // https://github.com/golang/go/issues/5373 .)
   135  //
   136  // WARNING: This is a function designed to be used in inner loops, which
   137  // assumes without checking that capacity is at least RoundUpPow2(len(dst),
   138  // bytesPerVec).  It also assumes that the caller does not care if a few bytes
   139  // past the end of dst[] are changed.  Use the safe version of this function if
   140  // any of these properties are problematic.
   141  // These assumptions are always satisfied when the last
   142  // potentially-size-increasing operation on dst[] is {Re}makeUnsafe(),
   143  // ResizeUnsafe(), or XcapUnsafe().
   144  func Memset8Unsafe(dst []byte, val byte) {
   145  	for pos := range dst {
   146  		dst[pos] = val
   147  	}
   148  }
   149  
   150  // Memset8 sets all values of dst[] to the given byte.  (This is intended for
   151  // val != 0.  It is better to use a range-for loop for val == 0 since the
   152  // compiler has a hardcoded optimization for that case.)
   153  func Memset8(dst []byte, val byte) {
   154  	for pos := range dst {
   155  		dst[pos] = val
   156  	}
   157  }
   158  
   159  // MakeNibbleLookupTable generates a NibbleLookupTable from a [16]byte.
   160  func MakeNibbleLookupTable(table [16]byte) (t NibbleLookupTable) {
   161  	for i := range t.shuffle {
   162  		t.shuffle[i] = table
   163  	}
   164  	return
   165  }
   166  
   167  // UnpackedNibbleLookupUnsafeInplace replaces the bytes in main[] as follows:
   168  //   if value < 128, set to table[value & 15]
   169  //   otherwise, set to 0
   170  //
   171  // WARNING: This is a function designed to be used in inner loops, which makes
   172  // assumptions about capacity which aren't checked at runtime.  Use the safe
   173  // version of this function when that's a problem.
   174  // These assumptions are always satisfied when the last
   175  // potentially-size-increasing operation on main[] is {Re}makeUnsafe(),
   176  // ResizeUnsafe(), or XcapUnsafe().
   177  //
   178  // 1. cap(main) must be at least RoundUpPow2(len(main) + 1, bytesPerVec).
   179  //
   180  // 2. The caller does not care if a few bytes past the end of main[] are
   181  // changed.
   182  func UnpackedNibbleLookupUnsafeInplace(main []byte, tablePtr *NibbleLookupTable) {
   183  	for pos, curByte := range main {
   184  		if curByte < 128 {
   185  			curByte = tablePtr.shuffle[0][curByte&15]
   186  		} else {
   187  			curByte = 0
   188  		}
   189  		main[pos] = curByte
   190  	}
   191  }
   192  
   193  // UnpackedNibbleLookupInplace replaces the bytes in main[] as follows:
   194  //   if value < 128, set to table[value & 15]
   195  //   otherwise, set to 0
   196  func UnpackedNibbleLookupInplace(main []byte, tablePtr *NibbleLookupTable) {
   197  	for pos, curByte := range main {
   198  		if curByte < 128 {
   199  			curByte = tablePtr.shuffle[0][curByte&15]
   200  		} else {
   201  			curByte = 0
   202  		}
   203  		main[pos] = curByte
   204  	}
   205  }
   206  
   207  // UnpackedNibbleLookupUnsafe sets the bytes in dst[] as follows:
   208  //   if src[pos] < 128, set dst[pos] := table[src[pos] & 15]
   209  //   otherwise, set dst[pos] := 0
   210  //
   211  // WARNING: This is a function designed to be used in inner loops, which makes
   212  // assumptions about length and capacity which aren't checked at runtime.  Use
   213  // the safe version of this function when that's a problem.
   214  // Assumptions #2-3 are always satisfied when the last
   215  // potentially-size-increasing operation on src[] is {Re}makeUnsafe(),
   216  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[].
   217  //
   218  // 1. len(src) and len(dst) are equal.
   219  //
   220  // 2. Capacities are at least RoundUpPow2(len(src) + 1, bytesPerVec).
   221  //
   222  // 3. The caller does not care if a few bytes past the end of dst[] are
   223  // changed.
   224  func UnpackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) {
   225  	for pos, curByte := range src {
   226  		if curByte < 128 {
   227  			curByte = tablePtr.shuffle[0][curByte&15]
   228  		} else {
   229  			curByte = 0
   230  		}
   231  		dst[pos] = curByte
   232  	}
   233  }
   234  
   235  // UnpackedNibbleLookup sets the bytes in dst[] as follows:
   236  //   if src[pos] < 128, set dst[pos] := table[src[pos] & 15]
   237  //   otherwise, set dst[pos] := 0
   238  // It panics if len(src) != len(dst).
   239  func UnpackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) {
   240  	if len(dst) != len(src) {
   241  		panic("UnpackedNibbleLookup() requires len(src) == len(dst).")
   242  	}
   243  	for pos, curByte := range src {
   244  		if curByte < 128 {
   245  			curByte = tablePtr.shuffle[0][curByte&15]
   246  		} else {
   247  			curByte = 0
   248  		}
   249  		dst[pos] = curByte
   250  	}
   251  }
   252  
   253  // UnpackedNibbleLookupS is a variant of UnpackedNibbleLookup() that takes
   254  // string src.
   255  func UnpackedNibbleLookupS(dst []byte, src string, tablePtr *NibbleLookupTable) {
   256  	srcLen := len(src)
   257  	if len(dst) != srcLen {
   258  		panic("UnpackedNibbleLookupS() requires len(src) == len(dst).")
   259  	}
   260  	for pos := range src {
   261  		curByte := src[pos]
   262  		if curByte < 128 {
   263  			curByte = tablePtr.Get(curByte & 15)
   264  		} else {
   265  			curByte = 0
   266  		}
   267  		dst[pos] = curByte
   268  	}
   269  	return
   270  }
   271  
   272  // PackedNibbleLookupUnsafe sets the bytes in dst[] as follows:
   273  //   if pos is even, dst[pos] := table[src[pos / 2] & 15]
   274  //   if pos is odd, dst[pos] := table[src[pos / 2] >> 4]
   275  //
   276  // WARNING: This is a function designed to be used in inner loops, which makes
   277  // assumptions about length and capacity which aren't checked at runtime.  Use
   278  // the safe version of this function when that's a problem.
   279  // Assumptions #2-#3 are always satisfied when the last
   280  // potentially-size-increasing operation on src[] is {Re}makeUnsafe(),
   281  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[].
   282  //
   283  // 1. len(src) == (len(dst) + 1) / 2.
   284  //
   285  // 2. Capacity of src is at least RoundUpPow2(len(src) + 1, bytesPerVec), and
   286  // the same is true for dst.
   287  //
   288  // 3. The caller does not care if a few bytes past the end of dst[] are
   289  // changed.
   290  func PackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) {
   291  	dstLen := len(dst)
   292  	nSrcFullByte := dstLen >> 1
   293  	srcOdd := dstLen & 1
   294  	for srcPos := 0; srcPos < nSrcFullByte; srcPos++ {
   295  		srcByte := src[srcPos]
   296  		dst[2*srcPos] = tablePtr.shuffle[0][srcByte&15]
   297  		dst[2*srcPos+1] = tablePtr.shuffle[0][srcByte>>4]
   298  	}
   299  	if srcOdd == 1 {
   300  		srcByte := src[nSrcFullByte]
   301  		dst[2*nSrcFullByte] = tablePtr.shuffle[0][srcByte&15]
   302  	}
   303  }
   304  
   305  // PackedNibbleLookup sets the bytes in dst[] as follows:
   306  //   if pos is even, dst[pos] := table[src[pos / 2] & 15]
   307  //   if pos is odd, dst[pos] := table[src[pos / 2] >> 4]
   308  // It panics if len(src) != (len(dst) + 1) / 2.
   309  //
   310  // Nothing bad happens if len(dst) is odd and some high bits in the last src[]
   311  // byte are set, though it's generally good practice to ensure that case
   312  // doesn't come up.
   313  func PackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) {
   314  	dstLen := len(dst)
   315  	nSrcFullByte := dstLen >> 1
   316  	srcOdd := dstLen & 1
   317  	if len(src) != nSrcFullByte+srcOdd {
   318  		panic("PackedNibbleLookup() requires len(src) == (len(dst) + 1) / 2.")
   319  	}
   320  	for srcPos := 0; srcPos < nSrcFullByte; srcPos++ {
   321  		srcByte := src[srcPos]
   322  		dst[2*srcPos] = tablePtr.shuffle[0][srcByte&15]
   323  		dst[2*srcPos+1] = tablePtr.shuffle[0][srcByte>>4]
   324  	}
   325  	if srcOdd == 1 {
   326  		srcByte := src[nSrcFullByte]
   327  		dst[2*nSrcFullByte] = tablePtr.shuffle[0][srcByte&15]
   328  	}
   329  }
   330  
   331  // Interleave8Unsafe sets the bytes in dst[] as follows:
   332  //   if pos is even, dst[pos] := even[pos/2]
   333  //   if pos is odd, dst[pos] := odd[pos/2]
   334  //
   335  // WARNING: This is a function designed to be used in inner loops, which makes
   336  // assumptions about length and capacity which aren't checked at runtime.  Use
   337  // the safe version of this function when that's a problem.
   338  // Assumptions #2-3 are always satisfied when the last
   339  // potentially-size-increasing operation on dst[] is {Re}makeUnsafe(),
   340  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for even[] and odd[].
   341  //
   342  // 1. len(even) = (len(dst) + 1) / 2, and len(odd) = len(dst) / 2.
   343  //
   344  // 2. cap(dst) >= RoundUpPow2(len(dst) + 1, bytesPerVec),
   345  // cap(even) >= RoundUpPow2(len(even) + 1, bytesPerVec), and
   346  // cap(odd) >= RoundUpPow2(len(odd) + 1, bytesPerVec).
   347  //
   348  // 3. The caller does not care if a few bytes past the end of dst[] are
   349  // changed.
   350  func Interleave8Unsafe(dst, even, odd []byte) {
   351  	dstLen := len(dst)
   352  	evenLen := (dstLen + 1) >> 1
   353  	oddLen := dstLen >> 1
   354  	for idx, oddByte := range odd {
   355  		dst[2*idx] = even[idx]
   356  		dst[2*idx+1] = oddByte
   357  	}
   358  	if oddLen != evenLen {
   359  		dst[oddLen*2] = even[oddLen]
   360  	}
   361  }
   362  
   363  // Interleave8 sets the bytes in dst[] as follows:
   364  //   if pos is even, dst[pos] := even[pos/2]
   365  //   if pos is odd, dst[pos] := odd[pos/2]
   366  // It panics if ((len(dst) + 1) / 2) != len(even), or (len(dst) / 2) !=
   367  // len(odd).
   368  func Interleave8(dst, even, odd []byte) {
   369  	// This is ~6-20% slower than the unsafe function on the short-array
   370  	// benchmark.
   371  	dstLen := len(dst)
   372  	evenLen := (dstLen + 1) >> 1
   373  	oddLen := dstLen >> 1
   374  	if (len(even) != evenLen) || (len(odd) != oddLen) {
   375  		panic("Interleave8() requires len(even) == len(dst) + 1) / 2, and len(odd) == len(dst) / 2.")
   376  	}
   377  	for idx, oddByte := range odd {
   378  		dst[2*idx] = even[idx]
   379  		dst[2*idx+1] = oddByte
   380  	}
   381  	if oddLen != evenLen {
   382  		dst[oddLen*2] = even[oddLen]
   383  	}
   384  }
   385  
   386  // Reverse8Inplace reverses the bytes in main[].  (There is no unsafe version
   387  // of this function.)
   388  func Reverse8Inplace(main []byte) {
   389  	nByte := len(main)
   390  	nByteDiv2 := nByte >> 1
   391  	for idx, invIdx := 0, nByte-1; idx != nByteDiv2; idx, invIdx = idx+1, invIdx-1 {
   392  		main[idx], main[invIdx] = main[invIdx], main[idx]
   393  	}
   394  }
   395  
   396  // Reverse8Unsafe sets dst[pos] := src[len(src) - 1 - pos] for every position
   397  // in src.
   398  //
   399  // WARNING: This does not verify len(dst) == len(src); call the safe version of
   400  // this function if you want that.
   401  func Reverse8Unsafe(dst, src []byte) {
   402  	nByte := len(src)
   403  	nByteMinus1 := nByte - 1
   404  	for idx := 0; idx != nByte; idx++ {
   405  		dst[nByteMinus1-idx] = src[idx]
   406  	}
   407  }
   408  
   409  // Reverse8 sets dst[pos] := src[len(src) - 1 - pos] for every position in src.
   410  // It panics if len(src) != len(dst).
   411  func Reverse8(dst, src []byte) {
   412  	nByte := len(src)
   413  	if nByte != len(dst) {
   414  		panic("Reverse8() requires len(src) == len(dst).")
   415  	}
   416  	nByteMinus1 := nByte - 1
   417  	for idx := 0; idx != nByte; idx++ {
   418  		dst[nByteMinus1-idx] = src[idx]
   419  	}
   420  }
   421  
   422  // BitFromEveryByte fills dst[] with a bitarray containing every 8th bit from
   423  // src[], starting with bitIdx, where bitIdx is in [0,7].  If len(src) is not
   424  // divisible by 8, extra bits in the last filled byte of dst are set to zero.
   425  // For example, if src[] is
   426  //   0x1f 0x33 0x0d 0x00 0x51 0xcc 0x34 0x59 0x44
   427  // and bitIdx is 2, bit 2 from every byte is
   428  //      1    0    1    0    0    1    1    0    1
   429  // so dst[] is filled with
   430  //   0x65 0x01.
   431  //
   432  // - It panics if len(dst) < (len(src) + 7) / 8, or bitIdx isn't in [0,7].
   433  // - If dst is larger than necessary, the extra bytes are not changed.
   434  func BitFromEveryByte(dst, src []byte, bitIdx int) {
   435  	requiredDstLen := (len(src) + 7) >> 3
   436  	if (len(dst) < requiredDstLen) || (uint(bitIdx) > 7) {
   437  		panic("BitFromEveryByte requires len(dst) >= (len(src) + 7) / 8 and 0 <= bitIdx < 8.")
   438  	}
   439  	dst = dst[:requiredDstLen]
   440  	for i := range dst {
   441  		dst[i] = 0
   442  	}
   443  	for i, b := range src {
   444  		dst[i>>3] |= ((b >> uint32(bitIdx)) & 1) << uint32(i&7)
   445  	}
   446  }