github.com/grailbio/base@v0.0.11/simd/count_amd64.go (about)

     1  // Copyright 2021 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build amd64 && !appengine
     6  // +build amd64,!appengine
     7  
     8  // This is derived from github.com/willf/bitset .
     9  
    10  package simd
    11  
    12  import (
    13  	"math/bits"
    14  	"reflect"
    15  	"unsafe"
    16  )
    17  
// *** the following functions are defined in count_amd64.s
    19  
// popcntWordArraySSE42Asm is implemented in assembly.  Within this file it is
// called by PopcntUnsafe and Popcnt, which treat its result as the number of
// set bits in the nWord words starting at bytes.  Popcnt may pass nWord == 0
// when the input is shorter than one word.
// NOTE(review): the word size is presumably BytesPerWord bytes, going by how
// the callers derive nWord; confirm against count_amd64.s.
//
//go:noescape
func popcntWordArraySSE42Asm(bytes unsafe.Pointer, nWord int) int
    22  
// maskThenCountByteSSE41Asm is the assembly backend for MaskThenCountByte,
// which calls it only when nByte >= 16 and returns its result unchanged.
//
// Although mask and val are really byte parameters, actually declaring them as
// bytes instead of ints in the function signature produces a *massive*
// performance penalty.
//
//go:noescape
func maskThenCountByteSSE41Asm(src unsafe.Pointer, mask, val, nByte int) int
    29  
// count2BytesSSE41Asm is the assembly backend for Count2Bytes, which calls it
// only when nByte >= 16 and returns its result unchanged.  val1 and val2 are
// byte values widened to int by the caller.
//
//go:noescape
func count2BytesSSE41Asm(src unsafe.Pointer, val1, val2, nByte int) int
    32  
// count3BytesSSE41Asm is the assembly backend for Count3Bytes, which calls it
// only when nByte >= 16 and returns its result unchanged.  val1, val2 and val3
// are byte values widened to int by the caller.
//
//go:noescape
func count3BytesSSE41Asm(src unsafe.Pointer, val1, val2, val3, nByte int) int
    35  
// countNibblesInSetSSE41Asm is the assembly backend for CountNibblesInSet,
// which calls it only when nByte >= 16 and returns its result unchanged.
//
//go:noescape
func countNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int
    38  
    39  func countNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int
    40  
// countUnpackedNibblesInSetSSE41Asm is the assembly backend for
// CountUnpackedNibblesInSet, which calls it only when nByte >= 16 and returns
// its result unchanged.
//
//go:noescape
func countUnpackedNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int
    43  
// countUnpackedNibblesInTwoSetsSSE41Asm is the assembly backend for
// CountUnpackedNibblesInTwoSets, which calls it only when nByte >= 16.  The
// caller uses the return value as the first set's count and reads the second
// set's count back through cnt2Ptr.
//
//go:noescape
func countUnpackedNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int
    46  
// accumulate8SSE41Asm is the assembly backend for Accumulate8, which calls it
// only when nByte >= 16 and returns its result unchanged.
//
//go:noescape
func accumulate8SSE41Asm(src unsafe.Pointer, nByte int) int
    49  
// accumulate8GreaterSSE41Asm is the assembly backend for Accumulate8Greater,
// which calls it only when nByte >= 16 and returns its result unchanged.  val
// is a byte value widened to int by the caller.
//
//go:noescape
func accumulate8GreaterSSE41Asm(src unsafe.Pointer, val, nByte int) int
    52  
    53  // *** end assembly function signature(s)
    54  
    55  // PopcntUnsafe returns the number of set bits in the given []byte, assuming
    56  // that any trailing bytes up to the next multiple of BytesPerWord are zeroed
    57  // out.
    58  func PopcntUnsafe(bytes []byte) int {
    59  	// Get the base-pointer for the slice, in a way that doesn't trigger a
    60  	// bounds-check and fail when length == 0.  (Yes, I found out during testing
    61  	// that the &bytes[0] idiom doesn't actually work in the length-0
    62  	// case...)
    63  	bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&bytes))
    64  
    65  	return popcntWordArraySSE42Asm(unsafe.Pointer(bytesHeader.Data), DivUpPow2(len(bytes), BytesPerWord, Log2BytesPerWord))
    66  }
    67  
    68  // Popcnt returns the number of set bits in the given []byte.
    69  //
    70  // Some effort has been made to make this run acceptably fast on relatively
    71  // short arrays, since I expect knowing how to do so to be helpful when working
    72  // with hundreds of millions of .bam reads with ~75 bytes of base data and ~150
    73  // bytes of quality data.  Interestingly, moving the leading-byte handling code
    74  // to assembly didn't make a difference.
    75  //
    76  // Some single-threaded benchmark results calling Popcnt 99999999 times on a
    77  // 14-byte unaligned array:
    78  //   C implementation: 0.219-0.232s
    79  //   This code: 0.606-0.620s
    80  //   C implementation using memcpy for trailing bytes: 0.964-0.983s
    81  // So Go's extra looping and function call overhead can almost triple runtime
    82  // in the short-array limit, but that's actually not as bad as the 4.5x
    83  // overhead of trusting memcpy to handle trailing bytes.
    84  func Popcnt(bytes []byte) int {
    85  	bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&bytes))
    86  	nByte := len(bytes)
    87  
    88  	bytearr := unsafe.Pointer(bytesHeader.Data)
    89  	tot := 0
    90  	nLeadingByte := nByte & (BytesPerWord - 1)
    91  	if nLeadingByte != 0 {
    92  		leadingWord := uint64(0)
    93  		if (nLeadingByte & 1) != 0 {
    94  			leadingWord = (uint64)(*(*byte)(bytearr))
    95  			bytearr = unsafe.Add(bytearr, 1)
    96  		}
    97  		if (nLeadingByte & 2) != 0 {
    98  			// Note that this does not keep the bytes in the original little-endian
    99  			// order, since that's irrelevant for popcount, and probably everything
   100  			// else we need to do in this package.  See ProperSubwordLoad() in
   101  			// plink2_base.h for code which does keep the bytes in order.
   102  			leadingWord <<= 16
   103  			leadingWord |= (uint64)(*(*uint16)(bytearr))
   104  			bytearr = unsafe.Add(bytearr, 2)
   105  		}
   106  		if (nLeadingByte & 4) != 0 {
   107  			leadingWord <<= 32
   108  			leadingWord |= (uint64)(*(*uint32)(bytearr))
   109  			bytearr = unsafe.Add(bytearr, 4)
   110  		}
   111  		tot = bits.OnesCount64(leadingWord)
   112  	}
   113  	tot += popcntWordArraySSE42Asm(bytearr, nByte>>Log2BytesPerWord)
   114  	return tot
   115  }
   116  
   117  // We may want a PopcntW function in the future which operates on a []uintptr,
   118  // along with AndW, OrW, XorW, InvmaskW, etc.  This would amount to a
   119  // lower-overhead version of willf/bitset (which also uses []uintptr
   120  // internally).
   121  // The main thing I would want to benchmark before making that decision is
   122  // bitset.NextSetMany() vs. a loop of the form
   123  //   uidx_base := 0
   124  //   cur_bits := bitarr[0]
   125  //   for idx := 0; idx != nSetBit; idx++ {
   126  //     // see plink2_base.h BitIter1()
   127  //     if cur_bits == 0 {
   128  //       widx := uidx_base >> (3 + Log2BytesPerWord)
   129  //       for {
   130  //         widx++
   131  //         cur_bits = bitarr[widx]
   132  //         if cur_bits != 0 {
   133  //           break
   134  //         }
   135  //       }
   136  //       uidx_base = widx << (3 + Log2BytesPerWord)
   137  //     }
   138  //     uidx := uidx_base + bits.TrailingZeros(uint(cur_bits))
   139  //     cur_bits = cur_bits & (cur_bits - 1)
   140  //     // (do something with uidx, possibly very simple)
   141  //   }
   142  // (Note that there are *hundreds* of loops of this form in plink2.)
   143  // If bitset.NextSetMany() does not impose a large performance penalty, we may
   144  // just want to write a version of it which takes a []byte as input.
   145  // (update: https://go-review.googlesource.com/c/go/+/109716 suggests that
   146  // bitset.NextSetMany() is not good enough.)
   147  
   148  // todo: add ZeroTrailingBits, etc. once we need it
   149  
   150  // MaskThenCountByte counts the number of bytes in src[] satisfying
   151  //   src[pos] & mask == val.
   152  func MaskThenCountByte(src []byte, mask, val byte) int {
   153  	// This is especially useful for CG counting:
   154  	// - Count 'C'/'G' ASCII characters: mask = 0xfb (only capital) or 0xdb
   155  	//   (either capital or lowercase), val = 'C'
   156  	// - Count C/G bytes in .bam unpacked seq8 data, assuming '=' is not in
   157  	//   input: mask = 0x9, val = 0
   158  	// It can also be used to ignore capitalization when counting instances of a
   159  	// single letter.
   160  	nByte := len(src)
   161  	if nByte < 16 {
   162  		cnt := 0
   163  		for _, srcByte := range src {
   164  			if (srcByte & mask) == val {
   165  				cnt++
   166  			}
   167  		}
   168  		return cnt
   169  	}
   170  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   171  	return maskThenCountByteSSE41Asm(unsafe.Pointer(srcHeader.Data), int(mask), int(val), nByte)
   172  }
   173  
   174  // Count2Bytes counts the number of bytes in src[] which are equal to either
   175  // val1 or val2.
   176  // (bytes.Count() should be good enough for a single byte.)
   177  func Count2Bytes(src []byte, val1, val2 byte) int {
   178  	nByte := len(src)
   179  	if nByte < 16 {
   180  		cnt := 0
   181  		for _, srcByte := range src {
   182  			if (srcByte == val1) || (srcByte == val2) {
   183  				cnt++
   184  			}
   185  		}
   186  		return cnt
   187  	}
   188  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   189  	return count2BytesSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val1), int(val2), nByte)
   190  }
   191  
   192  // Count3Bytes counts the number of bytes in src[] which are equal to val1,
   193  // val2, or val3.
   194  func Count3Bytes(src []byte, val1, val2, val3 byte) int {
   195  	nByte := len(src)
   196  	if nByte < 16 {
   197  		cnt := 0
   198  		for _, srcByte := range src {
   199  			if (srcByte == val1) || (srcByte == val2) || (srcByte == val3) {
   200  				cnt++
   201  			}
   202  		}
   203  		return cnt
   204  	}
   205  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   206  	return count3BytesSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val1), int(val2), int(val3), nByte)
   207  }
   208  
   209  // CountNibblesInSet counts the number of nibbles in src[] which are in the
   210  // given set.  The set must be represented as table[x] == 1 when value x is in
   211  // the set, and table[x] == 0 when x isn't.
   212  //
   213  // WARNING: This function does not validate the table.  It may return a garbage
   214  // result on invalid input.  (However, it won't corrupt memory.)
   215  func CountNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int {
   216  	nSrcByte := len(src)
   217  	if nSrcByte < 16 {
   218  		cnt := 0
   219  		for _, srcByte := range src {
   220  			cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4))
   221  		}
   222  		return cnt
   223  	}
   224  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   225  	return countNibblesInSetSSE41Asm(unsafe.Pointer(srcHeader.Data), tablePtr, nSrcByte)
   226  }
   227  
   228  // CountNibblesInTwoSets counts the number of bytes in src[] which are in the
   229  // given two sets, assuming all bytes are <16.  The sets must be represented as
   230  // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't.
   231  //
   232  // WARNING: This function does not validate the tables.  It may crash or return
   233  // garbage results on invalid input.  (However, it won't corrupt memory.)
   234  func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) {
   235  	nSrcByte := len(src)
   236  	cnt2 := 0
   237  	if nSrcByte < 16 {
   238  		cnt1 := 0
   239  		for _, srcByte := range src {
   240  			lowBits := srcByte & 15
   241  			highBits := srcByte >> 4
   242  			cnt1 += int(table1Ptr.Get(lowBits) + table1Ptr.Get(highBits))
   243  			cnt2 += int(table2Ptr.Get(lowBits) + table2Ptr.Get(highBits))
   244  		}
   245  		return cnt1, cnt2
   246  	}
   247  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   248  	cnt1 := countNibblesInTwoSetsSSE41Asm(&cnt2, unsafe.Pointer(srcHeader.Data), table1Ptr, table2Ptr, nSrcByte)
   249  	return cnt1, cnt2
   250  }
   251  
   252  // CountUnpackedNibblesInSet counts the number of bytes in src[] which are in
   253  // the given set, assuming all bytes are <16.  The set must be represented as
   254  // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't.
   255  //
   256  // WARNING: This function does not validate the table.  It may crash or return
   257  // a garbage result on invalid input.  (However, it won't corrupt memory.)
   258  func CountUnpackedNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int {
   259  	nSrcByte := len(src)
   260  	if nSrcByte < 16 {
   261  		cnt := 0
   262  		for _, srcByte := range src {
   263  			cnt += int(tablePtr.Get(srcByte))
   264  		}
   265  		return cnt
   266  	}
   267  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   268  	return countUnpackedNibblesInSetSSE41Asm(unsafe.Pointer(srcHeader.Data), tablePtr, nSrcByte)
   269  }
   270  
   271  // CountUnpackedNibblesInTwoSets counts the number of bytes in src[] which are
   272  // in the given two sets, assuming all bytes are <16.  The sets must be
   273  // represented as table[x] == 1 when value x is in the set, and table[x] == 0
   274  // when x isn't.
   275  //
   276  // WARNING: This function does not validate the tables.  It may crash or return
   277  // garbage results on invalid input.  (However, it won't corrupt memory.)
   278  func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) {
   279  	// Building this out now so that biosimd.PackedSeqCountTwo is not a valid
   280  	// reason to stick to packed .bam seq[] representation.
   281  	nSrcByte := len(src)
   282  	cnt2 := 0
   283  	if nSrcByte < 16 {
   284  		cnt1 := 0
   285  		for _, srcByte := range src {
   286  			cnt1 += int(table1Ptr.Get(srcByte))
   287  			cnt2 += int(table2Ptr.Get(srcByte))
   288  		}
   289  		return cnt1, cnt2
   290  	}
   291  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   292  	cnt1 := countUnpackedNibblesInTwoSetsSSE41Asm(&cnt2, unsafe.Pointer(srcHeader.Data), table1Ptr, table2Ptr, nSrcByte)
   293  	return cnt1, cnt2
   294  }
   295  
   296  // (could rename Popcnt to Accumulate1 for consistency...)
   297  
   298  // Accumulate8 returns the sum of the (unsigned) bytes in src[].
   299  func Accumulate8(src []byte) int {
   300  	nSrcByte := len(src)
   301  	if nSrcByte < 16 {
   302  		cnt := 0
   303  		for _, srcByte := range src {
   304  			cnt += int(srcByte)
   305  		}
   306  		return cnt
   307  	}
   308  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   309  	return accumulate8SSE41Asm(unsafe.Pointer(srcHeader.Data), nSrcByte)
   310  }
   311  
   312  // Accumulate8Greater returns the sum of all bytes in src[] greater than the
   313  // given value.
   314  func Accumulate8Greater(src []byte, val byte) int {
   315  	nSrcByte := len(src)
   316  	if nSrcByte < 16 {
   317  		cnt := 0
   318  		for _, srcByte := range src {
   319  			if srcByte > val {
   320  				cnt += int(srcByte)
   321  			}
   322  		}
   323  		return cnt
   324  	}
   325  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   326  	return accumulate8GreaterSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val), nSrcByte)
   327  }