// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// calculateChecksum folds buf into the running 16-bit ones' complement sum
// initial and returns the updated (not complemented) sum.
//
// odd indicates whether initial is a partial checksum over an odd number of
// bytes; if so, the first byte of buf is summed as the low-order byte of the
// 16-bit word left incomplete by the previous call. The returned bool is the
// updated value of odd, i.e. whether an odd number of bytes has been summed
// once buf is included.
//
// An empty buf leaves both the sum and odd unchanged.
func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
	// Nothing to add. Returning here also keeps the odd-byte handling below
	// from indexing into an empty slice when odd is true.
	if len(buf) == 0 {
		return initial, odd
	}

	// Use a larger-than-uint16 accumulator to benefit from parallel summation
	// as described in RFC 1071 1.2.C.
	acc := uint64(initial)

	// Handle an odd number of previously-summed bytes, and get the return
	// value for odd.
	if odd {
		acc += uint64(buf[0])
		buf = buf[1:]
	}
	odd = len(buf)&1 != 0

	// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
	// smaller bufs.
	if len(buf) < 8 {
		if len(buf) >= 4 {
			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
			acc += (uint64(buf[2]) << 8) + uint64(buf[3])
			buf = buf[4:]
		}
		if len(buf) >= 2 {
			acc += (uint64(buf[0]) << 8) + uint64(buf[1])
			buf = buf[2:]
		}
		if len(buf) >= 1 {
			// A trailing byte is the high-order byte of a zero-padded word.
			acc += uint64(buf[0]) << 8
			// buf = buf[1:] is skipped because it's unused and nogo will
			// complain.
		}
		return reduce(acc), odd
	}

	// On little-endian architectures, multi-byte loads from buf will load
	// bytes in the wrong order. Rather than byte-swap after each load (slow),
	// we byte-swap the accumulator before summing any bytes and byte-swap it
	// back before returning, which still produces the correct result as
	// described in RFC 1071 1.2.B "Byte Order Independence".
	//
	// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
	// preserve this property by byte-swapping only the lower 32 bits of acc,
	// so that additions to acc performed during alignment can't overflow.
	acc = uint64(bswapIfLittleEndian32(uint32(acc)))

	// Align &buf[0] to an 8-byte boundary. len(buf) >= 8 here, and alignment
	// consumes at most 1+2+4 = 7 bytes, so every load below is in bounds.
	bswapped := false
	if sliceAddr(buf)&1 != 0 {
		// Compute the rest of the partial checksum with bytes swapped, and
		// swap back before returning; see the last paragraph of
		// RFC 1071 1.2.B.
		acc = uint64(bits.ReverseBytes32(uint32(acc)))
		bswapped = true
		// No `<< 8` here due to the byte swap we just did.
		acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
		buf = buf[1:]
	}
	if sliceAddr(buf)&2 != 0 {
		acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
		buf = buf[2:]
	}
	if sliceAddr(buf)&4 != 0 {
		acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
		buf = buf[4:]
	}

	// Sum 64 bytes at a time. Beyond this point, additions to acc may
	// overflow, so we have to handle carrying. The final Add64 of each group
	// folds the carry-out back into acc (end-around carry).
	for len(buf) >= 64 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[64:]
	}

	// Sum the remaining 0-63 bytes.
	if len(buf) >= 32 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[32:]
	}
	if len(buf) >= 16 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[16:]
	}
	if len(buf) >= 8 {
		var carry uint64
		acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[8:]
	}
	if len(buf) >= 4 {
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[4:]
	}
	if len(buf) >= 2 {
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		buf = buf[2:]
	}
	if len(buf) >= 1 {
		// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
		var carry uint64
		acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
		acc, _ = bits.Add64(acc, 0, carry)
		// buf = buf[1:] is skipped because it's unused and nogo will complain.
	}

	// Reduce the checksum to 16 bits and undo byte swaps before returning.
	acc16 := bswapIfLittleEndian16(reduce(acc))
	if bswapped {
		acc16 = bits.ReverseBytes16(acc16)
	}
	return acc16, odd
}

// reduce folds the 64-bit accumulator acc down to 16 bits, adding each
// carry-out back into the low bits (ones' complement addition).
func reduce(acc uint64) uint16 {
	// Ideally we would do:
	//   return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
	// for more instruction-level parallelism; however, there is no
	// bits.Add16().
	acc = (acc >> 32) + (acc & 0xffff_ffff)  // at most 0x1_ffff_fffe
	acc32 := uint32(acc>>32 + acc)           // at most 0xffff_ffff
	acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe
	return uint16(acc32>>16 + acc32)         // at most 0xffff
}

// bswapIfLittleEndian32 reinterprets the in-memory bytes of val as a
// big-endian uint32: the bytes are reversed on a little-endian host and left
// unchanged on a big-endian host.
func bswapIfLittleEndian32(val uint32) uint32 {
	return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:])
}

// bswapIfLittleEndian16 reinterprets the in-memory bytes of val as a
// big-endian uint16: the bytes are swapped on a little-endian host and left
// unchanged on a big-endian host.
func bswapIfLittleEndian16(val uint16) uint16 {
	return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
}

// bswapIfBigEndian16 reinterprets the in-memory bytes of val as a
// little-endian uint16: the bytes are swapped on a big-endian host and left
// unchanged on a little-endian host.
func bswapIfBigEndian16(val uint16) uint16 {
	return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
}

// sliceAddr returns the address of buf's first element as an integer. It is
// used only to inspect alignment bits and is never converted back to a
// pointer.
func sliceAddr(buf []byte) uintptr {
	return uintptr(unsafe.Pointer(unsafe.SliceData(buf)))
}