github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/tcpip/checksum/checksum_unsafe.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package checksum
    16  
    17  import (
    18  	"encoding/binary"
    19  	"math/bits"
    20  	"unsafe"
    21  )
    22  
// calculateChecksum computes a partial RFC 1071 internet checksum over buf,
// folded to 16 bits, continuing from the partial sum initial. It returns the
// new partial sum and the new value of odd, so successive calls can checksum
// a logical buffer in chunks.
//
// Note: odd indicates whether initial is a partial checksum over an odd number
// of bytes.
func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
	// Use a larger-than-uint16 accumulator to benefit from parallel summation
	// as described in RFC 1071 1.2.C.
	acc := uint64(initial)

	// Handle an odd number of previously-summed bytes, and get the return
	// value for odd.
	if odd {
		acc += uint64(buf[0])
		buf = buf[1:]
	}
	odd = len(buf)&1 != 0

	// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
	// smaller bufs.
	if len(buf) < 8 {
		if len(buf) >= 4 {
			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
			acc += (uint64(buf[2]) << 8) + uint64(buf[3])
			buf = buf[4:]
		}
		if len(buf) >= 2 {
			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
			buf = buf[2:]
		}
		if len(buf) >= 1 {
			// A trailing lone byte is the high byte of its 16-bit word.
			acc += uint64(buf[0]) << 8
			// buf = buf[1:] is skipped because it's unused and nogo will
			// complain.
		}
		return reduce(acc), odd
	}

	// On little-endian architectures, multi-byte loads from buf will load
	// bytes in the wrong order. Rather than byte-swap after each load (slow),
	// we byte-swap the accumulator before summing any bytes and byte-swap it
	// back before returning, which still produces the correct result as
	// described in RFC 1071 1.2.B "Byte Order Independence".
	//
	// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
	// preserve this property by byte-swapping only the lower 32 bits of acc,
	// so that additions to acc performed during alignment can't overflow.
	acc = uint64(bswapIfLittleEndian32(uint32(acc)))

	// Align &buf[0] to an 8-byte boundary: consume 1, 2, and then 4 leading
	// bytes as needed to clear the low 3 bits of the address, so every
	// 8-byte load below is aligned.
	bswapped := false
	if sliceAddr(buf)&1 != 0 {
		// Compute the rest of the partial checksum with bytes swapped, and
		// swap back before returning; see the last paragraph of
		// RFC 1071 1.2.B.
		acc = uint64(bits.ReverseBytes32(uint32(acc)))
		bswapped = true
		// No `<< 8` here due to the byte swap we just did.
		acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
		buf = buf[1:]
	}
	if sliceAddr(buf)&2 != 0 {
		acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
		buf = buf[2:]
	}
	if sliceAddr(buf)&4 != 0 {
		acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
		buf = buf[4:]
	}

	// Sum 64 bytes at a time. Beyond this point, additions to acc may
	// overflow, so we have to handle carrying. Each dropped carry is added
	// back in (end-around carry), which preserves the one's complement sum.
	for len(buf) >= 64 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[64:]
	}

	// Sum the remaining 0-63 bytes.
	if len(buf) >= 32 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[32:]
	}
	if len(buf) >= 16 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[16:]
	}
	if len(buf) >= 8 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[8:]
	}
	if len(buf) >= 4 {
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[4:]
	}
	if len(buf) >= 2 {
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[2:]
	}
	if len(buf) >= 1 {
		// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		// buf = buf[1:] is skipped because it's unused and nogo will complain.
	}

	// Reduce the checksum to 16 bits and undo byte swaps before returning.
	acc16 := bswapIfLittleEndian16(reduce(acc))
	if bswapped {
		acc16 = bits.ReverseBytes16(acc16)
	}
	return acc16, odd
}
   156  
   157  func reduce(acc uint64) uint16 {
   158  	// Ideally we would do:
   159  	//   return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
   160  	// for more instruction-level parallelism; however, there is no
   161  	// bits.Add16().
   162  	acc = (acc >> 32) + (acc & 0xffff_ffff)  // at most 0x1_ffff_fffe
   163  	acc32 := uint32(acc>>32 + acc)           // at most 0xffff_ffff
   164  	acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe
   165  	return uint16(acc32>>16 + acc32)         // at most 0xffff
   166  }
   167  
   168  func bswapIfLittleEndian32(val uint32) uint32 {
   169  	return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:])
   170  }
   171  
   172  func bswapIfLittleEndian16(val uint16) uint16 {
   173  	return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
   174  }
   175  
   176  func bswapIfBigEndian16(val uint16) uint16 {
   177  	return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
   178  }
   179  
   180  func sliceAddr(buf []byte) uintptr {
   181  	return uintptr(unsafe.Pointer(unsafe.SliceData(buf)))
   182  }