github.com/grailbio/base@v0.0.11/simd/count_generic.go (about)

     1  // Copyright 2018 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build !amd64 appengine
     6  
     7  // This is derived from github.com/willf/bitset .
     8  
     9  package simd
    10  
    11  import "math/bits"
    12  
    13  // PopcntUnsafe returns the number of set bits in the given []byte, assuming
    14  // that any trailing bytes up to the next multiple of BytesPerWord are zeroed
    15  // out.
    16  func PopcntUnsafe(bytes []byte) int {
    17  	// Get the base-pointer for the slice, in a way that doesn't trigger a
    18  	// bounds-check and fail when length == 0.  (Yes, I found out during testing
    19  	// that the &bytes[0] idiom doesn't actually work in the length-0
    20  	// case...)
    21  	cnt := 0
    22  	for _, b := range bytes {
    23  		cnt += bits.OnesCount8(uint8(b))
    24  	}
    25  	return cnt
    26  }
    27  
    28  // Popcnt returns the number of set bits in the given []byte.
    29  //
    30  // Some effort has been made to make this run acceptably fast on relatively
    31  // short arrays, since I expect knowing how to do so to be helpful when working
    32  // with hundreds of millions of .bam reads with ~75 bytes of base data and ~150
    33  // bytes of quality data.  Interestingly, moving the leading-byte handling code
    34  // to assembly didn't make a difference.
    35  //
    36  // Some single-threaded benchmark results calling Popcnt 99999999 times on a
    37  // 14-byte unaligned array:
    38  //   C implementation: 0.219-0.232s
    39  //   This code: 0.606-0.620s
    40  //   C implementation using memcpy for trailing bytes: 0.964-0.983s
    41  // So Go's extra looping and function call overhead can almost triple runtime
    42  // in the short-array limit, but that's actually not as bad as the 4.5x
    43  // overhead of trusting memcpy to handle trailing bytes.
    44  func Popcnt(bytes []byte) int {
    45  	cnt := 0
    46  	for _, b := range bytes {
    47  		cnt += bits.OnesCount8(uint8(b))
    48  	}
    49  	return cnt
    50  }
    51  
    52  // We may want a PopcntW function in the future which operates on a []uintptr,
    53  // along with AndW, OrW, XorW, InvmaskW, etc.  This would amount to a
    54  // lower-overhead version of willf/bitset (which also uses []uintptr
    55  // internally).
    56  // The main thing I would want to benchmark before making that decision is
    57  // bitset.NextSetMany() vs. a loop of the form
    58  //   uidx_base := 0
    59  //   cur_bits := bitarr[0]
    60  //   for idx := 0; idx != nSetBit; idx++ {
    61  //     // see plink2_base.h BitIter1()
    62  //     if cur_bits == 0 {
    63  //       widx := uidx_base >> (3 + Log2BytesPerWord)
    64  //       for {
    65  //         widx++
    66  //         cur_bits = bitarr[widx]
    67  //         if cur_bits != 0 {
    68  //           break
    69  //         }
    70  //       }
    71  //       uidx_base = widx << (3 + Log2BytesPerWord)
    72  //     }
    73  //     uidx := uidx_base + bits.TrailingZeros(uint(cur_bits))
    74  //     cur_bits = cur_bits & (cur_bits - 1)
    75  //     // (do something with uidx, possibly very simple)
    76  //   }
    77  // (Note that there are *hundreds* of loops of this form in plink2.)
    78  // If bitset.NextSetMany() does not impose a large performance penalty, we may
    79  // just want to write a version of it which takes a []byte as input.
    80  // (update: https://go-review.googlesource.com/c/go/+/109716 suggests that
    81  // bitset.NextSetMany() is not good enough.)
    82  
    83  // todo: add ZeroTrailingBits, etc. once we need it
    84  
    85  // MaskThenCountByte counts the number of bytes in src[] satisfying
    86  //   src[pos] & mask == val.
    87  func MaskThenCountByte(src []byte, mask, val byte) int {
    88  	// This is especially useful for CG counting:
    89  	// - Count 'C'/'G' ASCII characters: mask = 0xfb (only capital) or 0xdb
    90  	//   (either capital or lowercase), val = 'C'
    91  	// - Count C/G bytes in .bam unpacked seq8 data, assuming '=' is not in
    92  	//   input: mask = 0x9, val = 0
    93  	// It can also be used to ignore capitalization when counting instances of a
    94  	// single letter.
    95  	cnt := 0
    96  	for _, srcByte := range src {
    97  		if (srcByte & mask) == val {
    98  			cnt++
    99  		}
   100  	}
   101  	return cnt
   102  }
   103  
   104  // Count2Bytes counts the number of bytes in src[] which are equal to either
   105  // val1 or val2.
   106  // (bytes.Count() should be good enough for a single byte.)
   107  func Count2Bytes(src []byte, val1, val2 byte) int {
   108  	cnt := 0
   109  	for _, srcByte := range src {
   110  		if (srcByte == val1) || (srcByte == val2) {
   111  			cnt++
   112  		}
   113  	}
   114  	return cnt
   115  }
   116  
   117  // Count3Bytes counts the number of bytes in src[] which are equal to val1,
   118  // val2, or val3.
   119  func Count3Bytes(src []byte, val1, val2, val3 byte) int {
   120  	cnt := 0
   121  	for _, srcByte := range src {
   122  		if (srcByte == val1) || (srcByte == val2) || (srcByte == val3) {
   123  			cnt++
   124  		}
   125  	}
   126  	return cnt
   127  }
   128  
   129  // CountNibblesInSet counts the number of nibbles in src[] which are in the
   130  // given set.  The set must be represented as table[x] == 1 when value x is in
   131  // the set, and table[x] == 0 when x isn't.
   132  //
   133  // WARNING: This function does not validate the table.  It may return a garbage
   134  // result on invalid input.  (However, it won't corrupt memory.)
   135  func CountNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int {
   136  	cnt := 0
   137  	for _, srcByte := range src {
   138  		cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4))
   139  	}
   140  	return cnt
   141  }
   142  
   143  // CountNibblesInTwoSets counts the number of bytes in src[] which are in the
   144  // given two sets, assuming all bytes are <16.  The sets must be represented as
   145  // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't.
   146  //
   147  // WARNING: This function does not validate the tables.  It may crash or return
   148  // garbage results on invalid input.  (However, it won't corrupt memory.)
   149  func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) {
   150  	cnt1 := 0
   151  	cnt2 := 0
   152  	for _, srcByte := range src {
   153  		lowBits := srcByte & 15
   154  		highBits := srcByte >> 4
   155  		cnt1 += int(table1Ptr.Get(lowBits) + table1Ptr.Get(highBits))
   156  		cnt2 += int(table2Ptr.Get(lowBits) + table2Ptr.Get(highBits))
   157  	}
   158  	return cnt1, cnt2
   159  }
   160  
   161  // CountUnpackedNibblesInSet counts the number of bytes in src[] which are in
   162  // the given set, assuming all bytes are <16.  The set must be represented as
   163  // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't.
   164  //
   165  // WARNING: This function does not validate the table.  It may crash or return
   166  // a garbage result on invalid input.  (However, it won't corrupt memory.)
   167  func CountUnpackedNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int {
   168  	cnt := 0
   169  	for _, srcByte := range src {
   170  		cnt += int(tablePtr.Get(srcByte))
   171  	}
   172  	return cnt
   173  }
   174  
   175  // CountUnpackedNibblesInTwoSets counts the number of bytes in src[] which are
   176  // in the given two sets, assuming all bytes are <16.  The sets must be
   177  // represented as table[x] == 1 when value x is in the set, and table[x] == 0
   178  // when x isn't.
   179  //
   180  // WARNING: This function does not validate the tables.  It may crash or return
   181  // garbage results on invalid input.  (However, it won't corrupt memory.)
   182  func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) {
   183  	cnt1 := 0
   184  	cnt2 := 0
   185  	for _, srcByte := range src {
   186  		cnt1 += int(table1Ptr.Get(srcByte))
   187  		cnt2 += int(table2Ptr.Get(srcByte))
   188  	}
   189  	return cnt1, cnt2
   190  }
   191  
   192  // (could rename Popcnt to Accumulate1 for consistency...)
   193  
   194  // Accumulate8 returns the sum of the (unsigned) bytes in src[].
   195  func Accumulate8(src []byte) int {
   196  	cnt := 0
   197  	for _, srcByte := range src {
   198  		cnt += int(srcByte)
   199  	}
   200  	return cnt
   201  }
   202  
   203  // Accumulate8Greater returns the sum of all bytes in src[] greater than the
   204  // given value.
   205  func Accumulate8Greater(src []byte, val byte) int {
   206  	cnt := 0
   207  	for _, srcByte := range src {
   208  		if srcByte > val {
   209  			cnt += int(srcByte)
   210  		}
   211  	}
   212  	return cnt
   213  }