github.com/grailbio/base@v0.0.11/simd/cmp_amd64.go (about) 1 // Copyright 2021 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 //go:build amd64 && !appengine 6 // +build amd64,!appengine 7 8 package simd 9 10 import ( 11 "math/bits" 12 "reflect" 13 "unsafe" 14 ) 15 16 //go:noescape 17 func firstGreater8SSSE3Asm(arg unsafe.Pointer, val, startPos, endPos int) int 18 19 //go:noescape 20 func firstLeq8SSSE3Asm(arg unsafe.Pointer, val, startPos, endPos int) int 21 22 // FirstUnequal8Unsafe scans arg1[startPos:] and arg2[startPos:] for the first 23 // mismatching byte, returning its position if one is found, or the common 24 // length if all bytes match (or startPos >= len). This has essentially the 25 // same speed as bytes.Compare(). 26 // 27 // WARNING: This is a function designed to be used in inner loops, which makes 28 // assumptions about length and capacity which aren't checked at runtime. Use 29 // the safe version of this function when that's a problem. 30 // The second assumption is always satisfied when the last 31 // potentially-size-increasing operation on arg1[] is {Re}makeUnsafe(), 32 // ResizeUnsafe(), or XcapUnsafe(), and the same is true for arg2[]. 33 // 34 // 1. len(arg1) == len(arg2). 35 // 36 // 2. Capacities are at least RoundUpPow2(len, bytesPerVec). 37 func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int { 38 // Possible alternative interface for these functions: fill a bitarray, with 39 // set bits for each mismatching position. Can return popcount. 40 endPos := len(arg1) 41 nByte := endPos - startPos 42 if nByte <= 0 { 43 return endPos 44 } 45 arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data) 46 arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data) 47 nWordMinus1 := (nByte - 1) >> Log2BytesPerWord 48 arg1Iter := unsafe.Add(arg1Data, startPos) 49 arg2Iter := unsafe.Add(arg2Data, startPos) 50 // Tried replacing this with simple (non-unrolled) vector-based loops very 51 // similar to the main runtime's go/src/internal/bytealg/compare_amd64.s, but 52 // they were actually worse than the safe function on the short-array 53 // benchmark. Can eventually look at what clang/LLVM actually generates for 54 // plink2_string.cc FirstUnequal4()--I confirmed that the word-based loop is 55 // slower than the SSE2 vector-based loop there--but that can wait until AVX2 56 // support is added. 57 for widx := 0; widx < nWordMinus1; widx++ { 58 xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter))) 59 if xorWord != 0 { 60 // Unfortunately, in a primarily-signed-int codebase, ">> 3" should 61 // generally be written over the more readable "/ 8", because the latter 62 // requires additional code to handle negative numerators. In this 63 // specific case, I'd hope that the compiler is smart enough to prove 64 // that bits.TrailingZeros64() returns nonnegative values, and would then 65 // optimize "/ 8" appropriately, but it is better to not worry about the 66 // matter at all. 67 return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3) 68 } 69 arg1Iter = unsafe.Add(arg1Iter, BytesPerWord) 70 arg2Iter = unsafe.Add(arg2Iter, BytesPerWord) 71 } 72 xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter))) 73 if xorWord == 0 { 74 return endPos 75 } 76 unequalPos := startPos + nWordMinus1*BytesPerWord + (bits.TrailingZeros64(uint64(xorWord)) >> 3) 77 if unequalPos > endPos { 78 return endPos 79 } 80 return unequalPos 81 } 82 83 // FirstUnequal8 scans arg1[startPos:] and arg2[startPos:] for the first 84 // mismatching byte, returning its position if one is found, or the common 85 // length if all bytes match (or startPos >= len). It panics if the lengths 86 // are not identical, or startPos is negative. 87 // 88 // This is essentially an extension of bytes.Compare(). 89 func FirstUnequal8(arg1, arg2 []byte, startPos int) int { 90 // This takes ~10% longer on the short-array benchmark. 91 endPos := len(arg1) 92 if endPos != len(arg2) || (startPos < 0) { 93 // The startPos < 0 check is kind of paranoid. It's here because 94 // unsafe.Add(arg1Data, startPos) does not automatically error out on 95 // negative startPos, and it also doesn't hurt to protect against (endPos - 96 // startPos) integer overflow; but feel free to request its removal if you 97 // are using this function in a time-critical loop. 98 panic("FirstUnequal8() requires len(arg1) == len(arg2) and nonnegative startPos.") 99 } 100 nByte := endPos - startPos 101 if nByte < BytesPerWord { 102 for pos := startPos; pos < endPos; pos++ { 103 if arg1[pos] != arg2[pos] { 104 return pos 105 } 106 } 107 return endPos 108 } 109 arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data) 110 arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data) 111 nWordMinus1 := (nByte - 1) >> Log2BytesPerWord 112 arg1Iter := unsafe.Add(arg1Data, startPos) 113 arg2Iter := unsafe.Add(arg2Data, startPos) 114 for widx := 0; widx < nWordMinus1; widx++ { 115 xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter))) 116 if xorWord != 0 { 117 return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3) 118 } 119 arg1Iter = unsafe.Add(arg1Iter, BytesPerWord) 120 arg2Iter = unsafe.Add(arg2Iter, BytesPerWord) 121 } 122 finalOffset := uintptr(endPos - BytesPerWord) 123 arg1FinalPtr := unsafe.Add(arg1Data, finalOffset) 124 arg2FinalPtr := unsafe.Add(arg2Data, finalOffset) 125 xorWord := (*((*uintptr)(arg1FinalPtr))) ^ (*((*uintptr)(arg2FinalPtr))) 126 if xorWord == 0 { 127 return endPos 128 } 129 return int(finalOffset) + (bits.TrailingZeros64(uint64(xorWord)) >> 3) 130 } 131 132 // FirstGreater8Unsafe scans arg[startPos:] for the first value larger than the 133 // given constant, returning its position if one is found, or len(arg) if all 134 // bytes are <= (or startPos >= len). 135 // 136 // This should only be used when greater values are usually present at ~5% or 137 // lower frequency. Above that, use a simple for loop. 138 // 139 // WARNING: This is a function designed to be used in inner loops, which makes 140 // assumptions about length and capacity which aren't checked at runtime. Use 141 // the safe version of this function when that's a problem. 142 // The second assumption is always satisfied when the last 143 // potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), 144 // ResizeUnsafe(), or XcapUnsafe(). 145 // 146 // 1. startPos is nonnegative. 147 // 148 // 2. cap(arg) >= RoundUpPow2(len, bytesPerVec). 149 func FirstGreater8Unsafe(arg []byte, val byte, startPos int) int { 150 endPos := len(arg) 151 nByte := endPos - startPos 152 if nByte <= bytesPerVec { 153 // Main loop setup overhead is pretty high. Crossover point in benchmarks 154 // is in the 16-32 byte range (depending on sparsity). 155 for pos := startPos; pos < endPos; pos++ { 156 if arg[pos] > val { 157 return pos 158 } 159 } 160 return endPos 161 } 162 argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) 163 return firstGreater8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos) 164 } 165 166 // FirstGreater8 scans arg[startPos:] for the first value larger than the given 167 // constant, returning its position if one is found, or len(arg) if all bytes 168 // are <= (or startPos >= len). 169 // 170 // This should only be used when greater values are usually present at ~5% or 171 // lower frequency. Above that, use a simple for loop. 172 func FirstGreater8(arg []byte, val byte, startPos int) int { 173 if startPos < 0 { 174 panic("FirstGreater8() requires nonnegative startPos.") 175 } 176 endPos := len(arg) 177 nByte := endPos - startPos 178 if nByte <= bytesPerVec { 179 for pos := startPos; pos < endPos; pos++ { 180 if arg[pos] > val { 181 return pos 182 } 183 } 184 return endPos 185 } 186 argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) 187 return firstGreater8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos) 188 } 189 190 // FirstLeq8Unsafe scans arg[startPos:] for the first value <= the given 191 // constant, returning its position if one is found, or len(arg) if all bytes 192 // are greater (or startPos >= len). 193 // 194 // This should only be used when <= values are usually present at ~5% or 195 // lower frequency. Above that, use a simple for loop. 196 // 197 // See warning for FirstGreater8Unsafe. 198 func FirstLeq8Unsafe(arg []byte, val byte, startPos int) int { 199 endPos := len(arg) 200 nByte := endPos - startPos 201 if nByte <= bytesPerVec { 202 for pos := startPos; pos < endPos; pos++ { 203 if arg[pos] <= val { 204 return pos 205 } 206 } 207 return endPos 208 } 209 argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) 210 return firstLeq8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos) 211 } 212 213 // FirstLeq8 scans arg[startPos:] for the first value <= the given constant, 214 // returning its position if one is found, or len(arg) if all bytes are greater 215 // (or startPos >= len). 216 // 217 // This should only be used when <= values are usually present at ~5% or lower 218 // frequency. Above that, use a simple for loop. 219 func FirstLeq8(arg []byte, val byte, startPos int) int { 220 // This currently has practically no performance penalty relative to the 221 // Unsafe version, since the implementation is identical except for the 222 // startPos check. 223 if startPos < 0 { 224 panic("FirstLeq8() requires nonnegative startPos.") 225 } 226 endPos := len(arg) 227 nByte := endPos - startPos 228 if nByte <= bytesPerVec { 229 for pos := startPos; pos < endPos; pos++ { 230 if arg[pos] <= val { 231 return pos 232 } 233 } 234 return endPos 235 } 236 argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) 237 return firstLeq8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos) 238 }