github.com/grailbio/base@v0.0.11/simd/bitwise_amd64.go.tpl (about)

     1  // Copyright 2021 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build amd64,!appengine
     6  
     7  package PACKAGE
     8  
     9  import (
    10  	"reflect"
    11  	"unsafe"
    12  )
    13  
    14  // ZZUnsafeInplace sets main[pos] := main[pos] OPCHAR arg[pos] for every position
    15  // in main[].
    16  //
    17  // WARNING: This is a function designed to be used in inner loops, which makes
    18  // assumptions about length and capacity which aren't checked at runtime.  Use
    19  // the safe version of this function when that's a problem.
    20  // Assumptions #2-3 are always satisfied when the last
    21  // potentially-size-increasing operation on arg[] is {Re}makeUnsafe(),
    22  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[].
    23  //
    24  // 1. len(arg) and len(main) must be equal.
    25  //
    26  // 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec).
    27  //
    28  // 3. The caller does not care if a few bytes past the end of main[] are
    29  // changed.
    30  func ZZUnsafeInplace(main, arg []byte) {
    31  	mainLen := len(main)
    32  	argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data)
    33  	mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data)
    34  	argWordsIter := argData
    35  	mainWordsIter := mainData
    36  	if mainLen > 2*BytesPerWord {
    37  		nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord
    38  		for widx := 0; widx < nWordMinus2; widx++ {
    39  			mainWord := *((*uintptr)(mainWordsIter))
    40  			argWord := *((*uintptr)(argWordsIter))
    41  			*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
    42  			mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord)
    43  			argWordsIter = unsafe.Add(argWordsIter, BytesPerWord)
    44  		}
    45  	} else if mainLen <= BytesPerWord {
    46  		mainWord := *((*uintptr)(mainWordsIter))
    47  		argWord := *((*uintptr)(argWordsIter))
    48  		*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
    49  		return
    50  	}
    51  	// The last two read-and-writes to main[] usually overlap.  To avoid a
    52  	// store-to-load forwarding slowdown, we read both words before writing
    53  	// either.
    54  	// shuffleLookupOddInplaceSSSE3Asm() uses the same strategy.
    55  	mainWord1 := *((*uintptr)(mainWordsIter))
    56  	argWord1 := *((*uintptr)(argWordsIter))
    57  	finalOffset := uintptr(mainLen - BytesPerWord)
    58  	mainFinalWordPtr := unsafe.Add(mainData, finalOffset)
    59  	argFinalWordPtr := unsafe.Add(argData, finalOffset)
    60  	mainWord2 := *((*uintptr)(mainFinalWordPtr))
    61  	argWord2 := *((*uintptr)(argFinalWordPtr))
    62  	*((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord1
    63  	*((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord2
    64  }
    65  
    66  // ZZInplace sets main[pos] := arg[pos] OPCHAR main[pos] for every position in
    67  // main[].  It panics if slice lengths don't match.
    68  func ZZInplace(main, arg []byte) {
    69  	// This takes ~6-8% longer than ZZUnsafeInplace on the short-array benchmark
    70  	// on my Mac.
    71  	mainLen := len(main)
    72  	if len(arg) != mainLen {
    73  		panic("ZZInplace() requires len(arg) == len(main).")
    74  	}
    75  	if mainLen < BytesPerWord {
    76  		// It's probably possible to do better here (e.g. when mainLen is in 4..7,
    77  		// operate on uint32s), but I won't worry about it unless/until that's
    78  		// actually a common case.
    79  		for pos, argByte := range arg {
    80  			main[pos] = main[pos] OPCHAR argByte
    81  		}
    82  		return
    83  	}
    84  	argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data)
    85  	mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data)
    86  	argWordsIter := argData
    87  	mainWordsIter := mainData
    88  	if mainLen > 2*BytesPerWord {
    89  		nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord
    90  		for widx := 0; widx < nWordMinus2; widx++ {
    91  			mainWord := *((*uintptr)(mainWordsIter))
    92  			argWord := *((*uintptr)(argWordsIter))
    93  			*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
    94  			mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord)
    95  			argWordsIter = unsafe.Add(argWordsIter, BytesPerWord)
    96  		}
    97  	}
    98  	mainWord1 := *((*uintptr)(mainWordsIter))
    99  	argWord1 := *((*uintptr)(argWordsIter))
   100  	finalOffset := uintptr(mainLen - BytesPerWord)
   101  	mainFinalWordPtr := unsafe.Add(mainData, finalOffset)
   102  	argFinalWordPtr := unsafe.Add(argData, finalOffset)
   103  	mainWord2 := *((*uintptr)(mainFinalWordPtr))
   104  	argWord2 := *((*uintptr)(argFinalWordPtr))
   105  	*((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord1
   106  	*((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord2
   107  }
   108  
   109  // ZZUnsafe sets dst[pos] := src1[pos] OPCHAR src2[pos] for every position in dst.
   110  //
   111  // WARNING: This is a function designed to be used in inner loops, which makes
   112  // assumptions about length and capacity which aren't checked at runtime.  Use
   113  // the safe version of this function when that's a problem.
   114  // Assumptions #2-3 are always satisfied when the last
   115  // potentially-size-increasing operation on src1[] is {Re}makeUnsafe(),
   116  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[].
   117  //
   118  // 1. len(src1), len(src2), and len(dst) must be equal.
   119  //
   120  // 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec).
   121  //
   122  // 3. The caller does not care if a few bytes past the end of dst[] are
   123  // changed.
   124  func ZZUnsafe(dst, src1, src2 []byte) {
   125  	src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1))
   126  	src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2))
   127  	dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
   128  	nWord := DivUpPow2(len(dst), BytesPerWord, Log2BytesPerWord)
   129  
   130  	src1Iter := unsafe.Pointer(src1Header.Data)
   131  	src2Iter := unsafe.Pointer(src2Header.Data)
   132  	dstIter := unsafe.Pointer(dstHeader.Data)
   133  	for widx := 0; widx < nWord; widx++ {
   134  		src1Word := *((*uintptr)(src1Iter))
   135  		src2Word := *((*uintptr)(src2Iter))
   136  		*((*uintptr)(dstIter)) = src1Word OPCHAR src2Word
   137  		src1Iter = unsafe.Add(src1Iter, BytesPerWord)
   138  		src2Iter = unsafe.Add(src2Iter, BytesPerWord)
   139  		dstIter = unsafe.Add(dstIter, BytesPerWord)
   140  	}
   141  }
   142  
   143  // ZZ sets dst[pos] := src1[pos] OPCHAR src2[pos] for every position in dst.  It
   144  // panics if slice lengths don't match.
   145  func ZZ(dst, src1, src2 []byte) {
   146  	dstLen := len(dst)
   147  	if (len(src1) != dstLen) || (len(src2) != dstLen) {
   148  		panic("ZZ() requires len(src1) == len(src2) == len(dst).")
   149  	}
   150  	if dstLen < BytesPerWord {
   151  		for pos, src1Byte := range src1 {
   152  			dst[pos] = src1Byte OPCHAR src2[pos]
   153  		}
   154  		return
   155  	}
   156  	src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data)
   157  	src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data)
   158  	dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data)
   159  	nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord
   160  
   161  	src1Iter := src1Data
   162  	src2Iter := src2Data
   163  	dstIter := dstData
   164  	for widx := 0; widx < nWordMinus1; widx++ {
   165  		src1Word := *((*uintptr)(src1Iter))
   166  		src2Word := *((*uintptr)(src2Iter))
   167  		*((*uintptr)(dstIter)) = src1Word OPCHAR src2Word
   168  		src1Iter = unsafe.Add(src1Iter, BytesPerWord)
   169  		src2Iter = unsafe.Add(src2Iter, BytesPerWord)
   170  		dstIter = unsafe.Add(dstIter, BytesPerWord)
   171  	}
   172  	// No store-forwarding problem here.
   173  	finalOffset := uintptr(dstLen - BytesPerWord)
   174  	src1Iter = unsafe.Add(src1Data, finalOffset)
   175  	src2Iter = unsafe.Add(src2Data, finalOffset)
   176  	dstIter = unsafe.Add(dstData, finalOffset)
   177  	src1Word := *((*uintptr)(src1Iter))
   178  	src2Word := *((*uintptr)(src2Iter))
   179  	*((*uintptr)(dstIter)) = src1Word OPCHAR src2Word
   180  }
   181  
   182  // ZZConst8UnsafeInplace sets main[pos] := main[pos] OPCHAR val for every position
   183  // in main[].
   184  //
   185  // WARNING: This is a function designed to be used in inner loops, which makes
   186  // assumptions about length and capacity which aren't checked at runtime.  Use
   187  // the safe version of this function when that's a problem.
   188  // These assumptions are always satisfied when the last
   189  // potentially-size-increasing operation on main[] is {Re}makeUnsafe(),
   190  // ResizeUnsafe(), or XcapUnsafe().
   191  //
   192  // 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec).
   193  //
   194  // 2. The caller does not care if a few bytes past the end of main[] are
   195  // changed.
   196  func ZZConst8UnsafeInplace(main []byte, val byte) {
   197  	mainLen := len(main)
   198  	argWord := 0x101010101010101 * uintptr(val)
   199  	mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data)
   200  	mainWordsIter := mainData
   201  	if mainLen > 2*BytesPerWord {
   202  		nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord
   203  		for widx := 0; widx < nWordMinus2; widx++ {
   204  			mainWord := *((*uintptr)(mainWordsIter))
   205  			*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
   206  			mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord)
   207  		}
   208  	} else if mainLen <= BytesPerWord {
   209  		mainWord := *((*uintptr)(mainWordsIter))
   210  		*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
   211  		return
   212  	}
   213  	mainWord1 := *((*uintptr)(mainWordsIter))
   214  	finalOffset := uintptr(mainLen - BytesPerWord)
   215  	mainFinalWordPtr := unsafe.Add(mainData, finalOffset)
   216  	mainWord2 := *((*uintptr)(mainFinalWordPtr))
   217  	*((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord
   218  	*((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord
   219  }
   220  
   221  // ZZConst8Inplace sets main[pos] := main[pos] OPCHAR val for every position in
   222  // main[].
   223  func ZZConst8Inplace(main []byte, val byte) {
   224  	mainLen := len(main)
   225  	if mainLen < BytesPerWord {
   226  		for pos, mainByte := range main {
   227  			main[pos] = mainByte OPCHAR val
   228  		}
   229  		return
   230  	}
   231  	argWord := 0x101010101010101 * uintptr(val)
   232  	mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data)
   233  	mainWordsIter := mainData
   234  	if mainLen > 2*BytesPerWord {
   235  		nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord
   236  		for widx := 0; widx < nWordMinus2; widx++ {
   237  			mainWord := *((*uintptr)(mainWordsIter))
   238  			*((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord
   239  			mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord)
   240  		}
   241  	}
   242  	mainWord1 := *((*uintptr)(mainWordsIter))
   243  	finalOffset := uintptr(mainLen - BytesPerWord)
   244  	mainFinalWordPtr := unsafe.Add(mainData, finalOffset)
   245  	mainWord2 := *((*uintptr)(mainFinalWordPtr))
   246  	*((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord
   247  	*((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord
   248  }
   249  
   250  // ZZConst8Unsafe sets dst[pos] := src[pos] OPCHAR val for every position in dst.
   251  //
   252  // WARNING: This is a function designed to be used in inner loops, which makes
   253  // assumptions about length and capacity which aren't checked at runtime.  Use
   254  // the safe version of this function when that's a problem.
   255  // Assumptions #2-3 are always satisfied when the last
   256  // potentially-size-increasing operation on src[] is {Re}makeUnsafe(),
   257  // ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[].
   258  //
   259  // 1. len(src) and len(dst) must be equal.
   260  //
   261  // 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec).
   262  //
   263  // 3. The caller does not care if a few bytes past the end of dst[] are
   264  // changed.
   265  func ZZConst8Unsafe(dst, src []byte, val byte) {
   266  	srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src))
   267  	dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst))
   268  	nWord := DivUpPow2(len(dst), BytesPerWord, Log2BytesPerWord)
   269  	argWord := 0x101010101010101 * uintptr(val)
   270  
   271  	srcIter := unsafe.Pointer(srcHeader.Data)
   272  	dstIter := unsafe.Pointer(dstHeader.Data)
   273  	for widx := 0; widx < nWord; widx++ {
   274  		srcWord := *((*uintptr)(srcIter))
   275  		*((*uintptr)(dstIter)) = srcWord OPCHAR argWord
   276  		srcIter = unsafe.Add(srcIter, BytesPerWord)
   277  		dstIter = unsafe.Add(dstIter, BytesPerWord)
   278  	}
   279  }
   280  
   281  // ZZConst8 sets dst[pos] := src[pos] OPCHAR val for every position in dst.  It
   282  // panics if slice lengths don't match.
   283  func ZZConst8(dst, src []byte, val byte) {
   284  	dstLen := len(dst)
   285  	if len(src) != dstLen {
   286  		panic("ZZConst8() requires len(src) == len(dst).")
   287  	}
   288  	if dstLen < BytesPerWord {
   289  		for pos, srcByte := range src {
   290  			dst[pos] = srcByte OPCHAR val
   291  		}
   292  		return
   293  	}
   294  	srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data)
   295  	dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data)
   296  	nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord
   297  	argWord := 0x101010101010101 * uintptr(val)
   298  
   299  	srcIter := unsafe.Pointer(srcData)
   300  	dstIter := unsafe.Pointer(dstData)
   301  	for widx := 0; widx < nWordMinus1; widx++ {
   302  		srcWord := *((*uintptr)(srcIter))
   303  		*((*uintptr)(dstIter)) = srcWord OPCHAR argWord
   304  		srcIter = unsafe.Add(srcIter, BytesPerWord)
   305  		dstIter = unsafe.Add(dstIter, BytesPerWord)
   306  	}
   307  	finalOffset := uintptr(dstLen - BytesPerWord)
   308  	srcIter = unsafe.Add(srcData, finalOffset)
   309  	dstIter = unsafe.Add(dstData, finalOffset)
   310  	srcWord := *((*uintptr)(srcIter))
   311  	*((*uintptr)(dstIter)) = srcWord OPCHAR argWord
   312  }