github.com/grailbio/base@v0.0.11/simd/cmp_amd64.go (about)

     1  // Copyright 2021 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build amd64 && !appengine
     6  // +build amd64,!appengine
     7  
     8  package simd
     9  
    10  import (
    11  	"math/bits"
    12  	"reflect"
    13  	"unsafe"
    14  )
    15  
// firstGreater8SSSE3Asm has no Go body; it is presumably implemented in an
// accompanying amd64 assembly file (not visible here).  FirstGreater8 and
// FirstGreater8Unsafe call it with the slice's data pointer, val widened to
// int, and the scan range [startPos, endPos), only when
// endPos-startPos > bytesPerVec, and return its result unchanged as the index
// of the first byte > val (or endPos if there is none).
// NOTE(review): the name suggests the routine requires SSSE3; confirm how CPU
// support is guaranteed before it is called.
//
// The //go:noescape directive tells the compiler that the pointer argument
// does not escape through this call.
//
//go:noescape
func firstGreater8SSSE3Asm(arg unsafe.Pointer, val, startPos, endPos int) int

// firstLeq8SSSE3Asm is the <= counterpart of firstGreater8SSSE3Asm, likewise
// implemented outside Go.  FirstLeq8 and FirstLeq8Unsafe call it with the same
// argument layout, only when endPos-startPos > bytesPerVec, and return its
// result unchanged as the index of the first byte <= val (or endPos if there
// is none).
//
//go:noescape
func firstLeq8SSSE3Asm(arg unsafe.Pointer, val, startPos, endPos int) int
    21  
    22  // FirstUnequal8Unsafe scans arg1[startPos:] and arg2[startPos:] for the first
    23  // mismatching byte, returning its position if one is found, or the common
    24  // length if all bytes match (or startPos >= len).  This has essentially the
    25  // same speed as bytes.Compare().
    26  //
    27  // WARNING: This is a function designed to be used in inner loops, which makes
    28  // assumptions about length and capacity which aren't checked at runtime.  Use
    29  // the safe version of this function when that's a problem.
    30  // The second assumption is always satisfied when the last
    31  // potentially-size-increasing operation on arg1[] is {Re}makeUnsafe(),
    32  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for arg2[].
    33  //
    34  // 1. len(arg1) == len(arg2).
    35  //
    36  // 2. Capacities are at least RoundUpPow2(len, bytesPerVec).
    37  func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int {
    38  	// Possible alternative interface for these functions: fill a bitarray, with
    39  	// set bits for each mismatching position.  Can return popcount.
    40  	endPos := len(arg1)
    41  	nByte := endPos - startPos
    42  	if nByte <= 0 {
    43  		return endPos
    44  	}
    45  	arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data)
    46  	arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data)
    47  	nWordMinus1 := (nByte - 1) >> Log2BytesPerWord
    48  	arg1Iter := unsafe.Add(arg1Data, startPos)
    49  	arg2Iter := unsafe.Add(arg2Data, startPos)
    50  	// Tried replacing this with simple (non-unrolled) vector-based loops very
    51  	// similar to the main runtime's go/src/internal/bytealg/compare_amd64.s, but
    52  	// they were actually worse than the safe function on the short-array
    53  	// benchmark.  Can eventually look at what clang/LLVM actually generates for
    54  	// plink2_string.cc FirstUnequal4()--I confirmed that the word-based loop is
    55  	// slower than the SSE2 vector-based loop there--but that can wait until AVX2
    56  	// support is added.
    57  	for widx := 0; widx < nWordMinus1; widx++ {
    58  		xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter)))
    59  		if xorWord != 0 {
    60  			// Unfortunately, in a primarily-signed-int codebase, ">> 3" should
    61  			// generally be written over the more readable "/ 8", because the latter
    62  			// requires additional code to handle negative numerators.  In this
    63  			// specific case, I'd hope that the compiler is smart enough to prove
    64  			// that bits.TrailingZeros64() returns nonnegative values, and would then
    65  			// optimize "/ 8" appropriately, but it is better to not worry about the
    66  			// matter at all.
    67  			return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3)
    68  		}
    69  		arg1Iter = unsafe.Add(arg1Iter, BytesPerWord)
    70  		arg2Iter = unsafe.Add(arg2Iter, BytesPerWord)
    71  	}
    72  	xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter)))
    73  	if xorWord == 0 {
    74  		return endPos
    75  	}
    76  	unequalPos := startPos + nWordMinus1*BytesPerWord + (bits.TrailingZeros64(uint64(xorWord)) >> 3)
    77  	if unequalPos > endPos {
    78  		return endPos
    79  	}
    80  	return unequalPos
    81  }
    82  
    83  // FirstUnequal8 scans arg1[startPos:] and arg2[startPos:] for the first
    84  // mismatching byte, returning its position if one is found, or the common
    85  // length if all bytes match (or startPos >= len).  It panics if the lengths
    86  // are not identical, or startPos is negative.
    87  //
    88  // This is essentially an extension of bytes.Compare().
    89  func FirstUnequal8(arg1, arg2 []byte, startPos int) int {
    90  	// This takes ~10% longer on the short-array benchmark.
    91  	endPos := len(arg1)
    92  	if endPos != len(arg2) || (startPos < 0) {
    93  		// The startPos < 0 check is kind of paranoid.  It's here because
    94  		// unsafe.Add(arg1Data, startPos) does not automatically error out on
    95  		// negative startPos, and it also doesn't hurt to protect against (endPos -
    96  		// startPos) integer overflow; but feel free to request its removal if you
    97  		// are using this function in a time-critical loop.
    98  		panic("FirstUnequal8() requires len(arg1) == len(arg2) and nonnegative startPos.")
    99  	}
   100  	nByte := endPos - startPos
   101  	if nByte < BytesPerWord {
   102  		for pos := startPos; pos < endPos; pos++ {
   103  			if arg1[pos] != arg2[pos] {
   104  				return pos
   105  			}
   106  		}
   107  		return endPos
   108  	}
   109  	arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data)
   110  	arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data)
   111  	nWordMinus1 := (nByte - 1) >> Log2BytesPerWord
   112  	arg1Iter := unsafe.Add(arg1Data, startPos)
   113  	arg2Iter := unsafe.Add(arg2Data, startPos)
   114  	for widx := 0; widx < nWordMinus1; widx++ {
   115  		xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter)))
   116  		if xorWord != 0 {
   117  			return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3)
   118  		}
   119  		arg1Iter = unsafe.Add(arg1Iter, BytesPerWord)
   120  		arg2Iter = unsafe.Add(arg2Iter, BytesPerWord)
   121  	}
   122  	finalOffset := uintptr(endPos - BytesPerWord)
   123  	arg1FinalPtr := unsafe.Add(arg1Data, finalOffset)
   124  	arg2FinalPtr := unsafe.Add(arg2Data, finalOffset)
   125  	xorWord := (*((*uintptr)(arg1FinalPtr))) ^ (*((*uintptr)(arg2FinalPtr)))
   126  	if xorWord == 0 {
   127  		return endPos
   128  	}
   129  	return int(finalOffset) + (bits.TrailingZeros64(uint64(xorWord)) >> 3)
   130  }
   131  
   132  // FirstGreater8Unsafe scans arg[startPos:] for the first value larger than the
   133  // given constant, returning its position if one is found, or len(arg) if all
   134  // bytes are <= (or startPos >= len).
   135  //
   136  // This should only be used when greater values are usually present at ~5% or
   137  // lower frequency.  Above that, use a simple for loop.
   138  //
   139  // WARNING: This is a function designed to be used in inner loops, which makes
   140  // assumptions about length and capacity which aren't checked at runtime.  Use
   141  // the safe version of this function when that's a problem.
   142  // The second assumption is always satisfied when the last
   143  // potentially-size-increasing operation on arg[] is {Re}makeUnsafe(),
   144  // ResizeUnsafe(), or XcapUnsafe().
   145  //
   146  // 1. startPos is nonnegative.
   147  //
   148  // 2. cap(arg) >= RoundUpPow2(len, bytesPerVec).
   149  func FirstGreater8Unsafe(arg []byte, val byte, startPos int) int {
   150  	endPos := len(arg)
   151  	nByte := endPos - startPos
   152  	if nByte <= bytesPerVec {
   153  		// Main loop setup overhead is pretty high. Crossover point in benchmarks
   154  		// is in the 16-32 byte range (depending on sparsity).
   155  		for pos := startPos; pos < endPos; pos++ {
   156  			if arg[pos] > val {
   157  				return pos
   158  			}
   159  		}
   160  		return endPos
   161  	}
   162  	argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg))
   163  	return firstGreater8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos)
   164  }
   165  
   166  // FirstGreater8 scans arg[startPos:] for the first value larger than the given
   167  // constant, returning its position if one is found, or len(arg) if all bytes
   168  // are <= (or startPos >= len).
   169  //
   170  // This should only be used when greater values are usually present at ~5% or
   171  // lower frequency.  Above that, use a simple for loop.
   172  func FirstGreater8(arg []byte, val byte, startPos int) int {
   173  	if startPos < 0 {
   174  		panic("FirstGreater8() requires nonnegative startPos.")
   175  	}
   176  	endPos := len(arg)
   177  	nByte := endPos - startPos
   178  	if nByte <= bytesPerVec {
   179  		for pos := startPos; pos < endPos; pos++ {
   180  			if arg[pos] > val {
   181  				return pos
   182  			}
   183  		}
   184  		return endPos
   185  	}
   186  	argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg))
   187  	return firstGreater8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos)
   188  }
   189  
   190  // FirstLeq8Unsafe scans arg[startPos:] for the first value <= the given
   191  // constant, returning its position if one is found, or len(arg) if all bytes
   192  // are greater (or startPos >= len).
   193  //
   194  // This should only be used when <= values are usually present at ~5% or
   195  // lower frequency.  Above that, use a simple for loop.
   196  //
   197  // See warning for FirstGreater8Unsafe.
   198  func FirstLeq8Unsafe(arg []byte, val byte, startPos int) int {
   199  	endPos := len(arg)
   200  	nByte := endPos - startPos
   201  	if nByte <= bytesPerVec {
   202  		for pos := startPos; pos < endPos; pos++ {
   203  			if arg[pos] <= val {
   204  				return pos
   205  			}
   206  		}
   207  		return endPos
   208  	}
   209  	argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg))
   210  	return firstLeq8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos)
   211  }
   212  
   213  // FirstLeq8 scans arg[startPos:] for the first value <= the given constant,
   214  // returning its position if one is found, or len(arg) if all bytes are greater
   215  // (or startPos >= len).
   216  //
   217  // This should only be used when <= values are usually present at ~5% or lower
   218  // frequency.  Above that, use a simple for loop.
   219  func FirstLeq8(arg []byte, val byte, startPos int) int {
   220  	// This currently has practically no performance penalty relative to the
   221  	// Unsafe version, since the implementation is identical except for the
   222  	// startPos check.
   223  	if startPos < 0 {
   224  		panic("FirstLeq8() requires nonnegative startPos.")
   225  	}
   226  	endPos := len(arg)
   227  	nByte := endPos - startPos
   228  	if nByte <= bytesPerVec {
   229  		for pos := startPos; pos < endPos; pos++ {
   230  			if arg[pos] <= val {
   231  				return pos
   232  			}
   233  		}
   234  		return endPos
   235  	}
   236  	argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg))
   237  	return firstLeq8SSSE3Asm(unsafe.Pointer(argHeader.Data), int(val), startPos, endPos)
   238  }