github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/simd/count_amd64.go (about) 1 // Copyright 2021 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 //go:build amd64 && !appengine 6 // +build amd64,!appengine 7 8 // This is derived from github.com/willf/bitset . 9 10 package simd 11 12 import ( 13 "math/bits" 14 "reflect" 15 "unsafe" 16 ) 17 18 // *** the following function is defined in count_amd64.s 19 20 //go:noescape 21 func popcntWordArraySSE42Asm(bytes unsafe.Pointer, nWord int) int 22 23 // Although mask and val are really byte parameters, actually declaring them as 24 // bytes instead of ints in the function signature produces a *massive* 25 // performance penalty. 26 27 //go:noescape 28 func maskThenCountByteSSE41Asm(src unsafe.Pointer, mask, val, nByte int) int 29 30 //go:noescape 31 func count2BytesSSE41Asm(src unsafe.Pointer, val1, val2, nByte int) int 32 33 //go:noescape 34 func count3BytesSSE41Asm(src unsafe.Pointer, val1, val2, val3, nByte int) int 35 36 //go:noescape 37 func countNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int 38 39 func countNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int 40 41 //go:noescape 42 func countUnpackedNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int 43 44 //go:noescape 45 func countUnpackedNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int 46 47 //go:noescape 48 func accumulate8SSE41Asm(src unsafe.Pointer, nByte int) int 49 50 //go:noescape 51 func accumulate8GreaterSSE41Asm(src unsafe.Pointer, val, nByte int) int 52 53 // *** end assembly function signature(s) 54 55 // PopcntUnsafe returns the number of set bits in the given []byte, assuming 56 // that any trailing bytes up to the next multiple of BytesPerWord are zeroed 57 // out. 58 func PopcntUnsafe(bytes []byte) int { 59 // Get the base-pointer for the slice, in a way that doesn't trigger a 60 // bounds-check and fail when length == 0. (Yes, I found out during testing 61 // that the &bytes[0] idiom doesn't actually work in the length-0 62 // case...) 63 bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&bytes)) 64 65 return popcntWordArraySSE42Asm(unsafe.Pointer(bytesHeader.Data), DivUpPow2(len(bytes), BytesPerWord, Log2BytesPerWord)) 66 } 67 68 // Popcnt returns the number of set bits in the given []byte. 69 // 70 // Some effort has been made to make this run acceptably fast on relatively 71 // short arrays, since I expect knowing how to do so to be helpful when working 72 // with hundreds of millions of .bam reads with ~75 bytes of base data and ~150 73 // bytes of quality data. Interestingly, moving the leading-byte handling code 74 // to assembly didn't make a difference. 75 // 76 // Some single-threaded benchmark results calling Popcnt 99999999 times on a 77 // 14-byte unaligned array: 78 // C implementation: 0.219-0.232s 79 // This code: 0.606-0.620s 80 // C implementation using memcpy for trailing bytes: 0.964-0.983s 81 // So Go's extra looping and function call overhead can almost triple runtime 82 // in the short-array limit, but that's actually not as bad as the 4.5x 83 // overhead of trusting memcpy to handle trailing bytes. 84 func Popcnt(bytes []byte) int { 85 bytesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&bytes)) 86 nByte := len(bytes) 87 88 bytearr := unsafe.Pointer(bytesHeader.Data) 89 tot := 0 90 nLeadingByte := nByte & (BytesPerWord - 1) 91 if nLeadingByte != 0 { 92 leadingWord := uint64(0) 93 if (nLeadingByte & 1) != 0 { 94 leadingWord = (uint64)(*(*byte)(bytearr)) 95 bytearr = unsafe.Add(bytearr, 1) 96 } 97 if (nLeadingByte & 2) != 0 { 98 // Note that this does not keep the bytes in the original little-endian 99 // order, since that's irrelevant for popcount, and probably everything 100 // else we need to do in this package. See ProperSubwordLoad() in 101 // plink2_base.h for code which does keep the bytes in order. 102 leadingWord <<= 16 103 leadingWord |= (uint64)(*(*uint16)(bytearr)) 104 bytearr = unsafe.Add(bytearr, 2) 105 } 106 if (nLeadingByte & 4) != 0 { 107 leadingWord <<= 32 108 leadingWord |= (uint64)(*(*uint32)(bytearr)) 109 bytearr = unsafe.Add(bytearr, 4) 110 } 111 tot = bits.OnesCount64(leadingWord) 112 } 113 tot += popcntWordArraySSE42Asm(bytearr, nByte>>Log2BytesPerWord) 114 return tot 115 } 116 117 // We may want a PopcntW function in the future which operates on a []uintptr, 118 // along with AndW, OrW, XorW, InvmaskW, etc. This would amount to a 119 // lower-overhead version of willf/bitset (which also uses []uintptr 120 // internally). 121 // The main thing I would want to benchmark before making that decision is 122 // bitset.NextSetMany() vs. a loop of the form 123 // uidx_base := 0 124 // cur_bits := bitarr[0] 125 // for idx := 0; idx != nSetBit; idx++ { 126 // // see plink2_base.h BitIter1() 127 // if cur_bits == 0 { 128 // widx := uidx_base >> (3 + Log2BytesPerWord) 129 // for { 130 // widx++ 131 // cur_bits = bitarr[widx] 132 // if cur_bits != 0 { 133 // break 134 // } 135 // } 136 // uidx_base = widx << (3 + Log2BytesPerWord) 137 // } 138 // uidx := uidx_base + bits.TrailingZeros(uint(cur_bits)) 139 // cur_bits = cur_bits & (cur_bits - 1) 140 // // (do something with uidx, possibly very simple) 141 // } 142 // (Note that there are *hundreds* of loops of this form in plink2.) 143 // If bitset.NextSetMany() does not impose a large performance penalty, we may 144 // just want to write a version of it which takes a []byte as input. 145 // (update: https://go-review.googlesource.com/c/go/+/109716 suggests that 146 // bitset.NextSetMany() is not good enough.) 147 148 // todo: add ZeroTrailingBits, etc. once we need it 149 150 // MaskThenCountByte counts the number of bytes in src[] satisfying 151 // src[pos] & mask == val. 152 func MaskThenCountByte(src []byte, mask, val byte) int { 153 // This is especially useful for CG counting: 154 // - Count 'C'/'G' ASCII characters: mask = 0xfb (only capital) or 0xdb 155 // (either capital or lowercase), val = 'C' 156 // - Count C/G bytes in .bam unpacked seq8 data, assuming '=' is not in 157 // input: mask = 0x9, val = 0 158 // It can also be used to ignore capitalization when counting instances of a 159 // single letter. 160 nByte := len(src) 161 if nByte < 16 { 162 cnt := 0 163 for _, srcByte := range src { 164 if (srcByte & mask) == val { 165 cnt++ 166 } 167 } 168 return cnt 169 } 170 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 171 return maskThenCountByteSSE41Asm(unsafe.Pointer(srcHeader.Data), int(mask), int(val), nByte) 172 } 173 174 // Count2Bytes counts the number of bytes in src[] which are equal to either 175 // val1 or val2. 176 // (bytes.Count() should be good enough for a single byte.) 177 func Count2Bytes(src []byte, val1, val2 byte) int { 178 nByte := len(src) 179 if nByte < 16 { 180 cnt := 0 181 for _, srcByte := range src { 182 if (srcByte == val1) || (srcByte == val2) { 183 cnt++ 184 } 185 } 186 return cnt 187 } 188 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 189 return count2BytesSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val1), int(val2), nByte) 190 } 191 192 // Count3Bytes counts the number of bytes in src[] which are equal to val1, 193 // val2, or val3. 194 func Count3Bytes(src []byte, val1, val2, val3 byte) int { 195 nByte := len(src) 196 if nByte < 16 { 197 cnt := 0 198 for _, srcByte := range src { 199 if (srcByte == val1) || (srcByte == val2) || (srcByte == val3) { 200 cnt++ 201 } 202 } 203 return cnt 204 } 205 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 206 return count3BytesSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val1), int(val2), int(val3), nByte) 207 } 208 209 // CountNibblesInSet counts the number of nibbles in src[] which are in the 210 // given set. The set must be represented as table[x] == 1 when value x is in 211 // the set, and table[x] == 0 when x isn't. 212 // 213 // WARNING: This function does not validate the table. It may return a garbage 214 // result on invalid input. (However, it won't corrupt memory.) 215 func CountNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { 216 nSrcByte := len(src) 217 if nSrcByte < 16 { 218 cnt := 0 219 for _, srcByte := range src { 220 cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4)) 221 } 222 return cnt 223 } 224 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 225 return countNibblesInSetSSE41Asm(unsafe.Pointer(srcHeader.Data), tablePtr, nSrcByte) 226 } 227 228 // CountNibblesInTwoSets counts the number of bytes in src[] which are in the 229 // given two sets, assuming all bytes are <16. The sets must be represented as 230 // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't. 231 // 232 // WARNING: This function does not validate the tables. It may crash or return 233 // garbage results on invalid input. (However, it won't corrupt memory.) 234 func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { 235 nSrcByte := len(src) 236 cnt2 := 0 237 if nSrcByte < 16 { 238 cnt1 := 0 239 for _, srcByte := range src { 240 lowBits := srcByte & 15 241 highBits := srcByte >> 4 242 cnt1 += int(table1Ptr.Get(lowBits) + table1Ptr.Get(highBits)) 243 cnt2 += int(table2Ptr.Get(lowBits) + table2Ptr.Get(highBits)) 244 } 245 return cnt1, cnt2 246 } 247 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 248 cnt1 := countNibblesInTwoSetsSSE41Asm(&cnt2, unsafe.Pointer(srcHeader.Data), table1Ptr, table2Ptr, nSrcByte) 249 return cnt1, cnt2 250 } 251 252 // CountUnpackedNibblesInSet counts the number of bytes in src[] which are in 253 // the given set, assuming all bytes are <16. The set must be represented as 254 // table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't. 255 // 256 // WARNING: This function does not validate the table. It may crash or return 257 // a garbage result on invalid input. (However, it won't corrupt memory.) 258 func CountUnpackedNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { 259 nSrcByte := len(src) 260 if nSrcByte < 16 { 261 cnt := 0 262 for _, srcByte := range src { 263 cnt += int(tablePtr.Get(srcByte)) 264 } 265 return cnt 266 } 267 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 268 return countUnpackedNibblesInSetSSE41Asm(unsafe.Pointer(srcHeader.Data), tablePtr, nSrcByte) 269 } 270 271 // CountUnpackedNibblesInTwoSets counts the number of bytes in src[] which are 272 // in the given two sets, assuming all bytes are <16. The sets must be 273 // represented as table[x] == 1 when value x is in the set, and table[x] == 0 274 // when x isn't. 275 // 276 // WARNING: This function does not validate the tables. It may crash or return 277 // garbage results on invalid input. (However, it won't corrupt memory.) 278 func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { 279 // Building this out now so that biosimd.PackedSeqCountTwo is not a valid 280 // reason to stick to packed .bam seq[] representation. 281 nSrcByte := len(src) 282 cnt2 := 0 283 if nSrcByte < 16 { 284 cnt1 := 0 285 for _, srcByte := range src { 286 cnt1 += int(table1Ptr.Get(srcByte)) 287 cnt2 += int(table2Ptr.Get(srcByte)) 288 } 289 return cnt1, cnt2 290 } 291 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 292 cnt1 := countUnpackedNibblesInTwoSetsSSE41Asm(&cnt2, unsafe.Pointer(srcHeader.Data), table1Ptr, table2Ptr, nSrcByte) 293 return cnt1, cnt2 294 } 295 296 // (could rename Popcnt to Accumulate1 for consistency...) 297 298 // Accumulate8 returns the sum of the (unsigned) bytes in src[]. 299 func Accumulate8(src []byte) int { 300 nSrcByte := len(src) 301 if nSrcByte < 16 { 302 cnt := 0 303 for _, srcByte := range src { 304 cnt += int(srcByte) 305 } 306 return cnt 307 } 308 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 309 return accumulate8SSE41Asm(unsafe.Pointer(srcHeader.Data), nSrcByte) 310 } 311 312 // Accumulate8Greater returns the sum of all bytes in src[] greater than the 313 // given value. 314 func Accumulate8Greater(src []byte, val byte) int { 315 nSrcByte := len(src) 316 if nSrcByte < 16 { 317 cnt := 0 318 for _, srcByte := range src { 319 if srcByte > val { 320 cnt += int(srcByte) 321 } 322 } 323 return cnt 324 } 325 srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) 326 return accumulate8GreaterSSE41Asm(unsafe.Pointer(srcHeader.Data), int(val), nSrcByte) 327 }