inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/header/checksum.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package header provides the implementation of the encoding and decoding of 16 // network protocol headers. 17 package header 18 19 import ( 20 "encoding/binary" 21 "fmt" 22 23 "inet.af/netstack/tcpip" 24 "inet.af/netstack/tcpip/buffer" 25 ) 26 27 func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { 28 v := initial 29 30 if odd { 31 v += uint32(buf[0]) 32 buf = buf[1:] 33 } 34 35 l := len(buf) 36 odd = l&1 != 0 37 if odd { 38 l-- 39 v += uint32(buf[l]) << 8 40 } 41 42 for i := 0; i < l; i += 2 { 43 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 44 } 45 46 return ChecksumCombine(uint16(v), uint16(v>>16)), odd 47 } 48 49 func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { 50 v := initial 51 52 if odd { 53 v += uint32(buf[0]) 54 buf = buf[1:] 55 } 56 57 l := len(buf) 58 odd = l&1 != 0 59 if odd { 60 l-- 61 v += uint32(buf[l]) << 8 62 } 63 for (l - 64) >= 0 { 64 i := 0 65 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 66 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 67 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 68 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 69 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 70 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 71 v += (uint32(buf[i+12]) << 8) + 
uint32(buf[i+13]) 72 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 73 i += 16 74 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 75 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 76 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 77 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 78 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 79 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 80 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 81 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 82 i += 16 83 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 84 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 85 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 86 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 87 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 88 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 89 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 90 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 91 i += 16 92 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 93 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 94 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 95 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 96 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 97 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 98 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 99 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 100 buf = buf[64:] 101 l = l - 64 102 } 103 if (l - 32) >= 0 { 104 i := 0 105 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 106 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 107 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 108 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 109 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 110 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 111 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 112 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 113 i += 16 114 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 115 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 116 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 
117 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 118 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 119 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 120 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 121 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 122 buf = buf[32:] 123 l = l - 32 124 } 125 if (l - 16) >= 0 { 126 i := 0 127 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 128 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 129 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 130 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 131 v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) 132 v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) 133 v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) 134 v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) 135 buf = buf[16:] 136 l = l - 16 137 } 138 if (l - 8) >= 0 { 139 i := 0 140 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 141 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 142 v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) 143 v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) 144 buf = buf[8:] 145 l = l - 8 146 } 147 if (l - 4) >= 0 { 148 i := 0 149 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 150 v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) 151 buf = buf[4:] 152 l = l - 4 153 } 154 155 // At this point since l was even before we started unrolling 156 // there can be only two bytes left to add. 157 if l != 0 { 158 v += (uint32(buf[0]) << 8) + uint32(buf[1]) 159 } 160 161 return ChecksumCombine(uint16(v), uint16(v>>16)), odd 162 } 163 164 // ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in 165 // the given byte array. This function uses a non-optimized implementation. Its 166 // only retained for reference and to use as a benchmark/test. Most code should 167 // use the header.Checksum function. 168 // 169 // The initial checksum must have been computed on an even number of bytes. 
170 func ChecksumOld(buf []byte, initial uint16) uint16 { 171 s, _ := calculateChecksum(buf, false, uint32(initial)) 172 return s 173 } 174 175 // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the 176 // given byte array. This function uses an optimized unrolled version of the 177 // checksum algorithm. 178 // 179 // The initial checksum must have been computed on an even number of bytes. 180 func Checksum(buf []byte, initial uint16) uint16 { 181 s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) 182 return s 183 } 184 185 // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in 186 // the given VectorizedView. 187 // 188 // The initial checksum must have been computed on an even number of bytes. 189 func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { 190 var c Checksumer 191 for _, v := range vv.Views() { 192 c.Add([]byte(v)) 193 } 194 return ChecksumCombine(initial, c.Checksum()) 195 } 196 197 // Checksumer calculates checksum defined in RFC 1071. 198 type Checksumer struct { 199 sum uint16 200 odd bool 201 } 202 203 // Add adds b to checksum. 204 func (c *Checksumer) Add(b []byte) { 205 if len(b) > 0 { 206 c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) 207 } 208 } 209 210 // Checksum returns the latest checksum value. 211 func (c *Checksumer) Checksum() uint16 { 212 return c.sum 213 } 214 215 // ChecksumCombine combines the two uint16 to form their checksum. This is done 216 // by adding them and the carry. 217 // 218 // Note that checksum a must have been computed on an even number of bytes. 219 func ChecksumCombine(a, b uint16) uint16 { 220 v := uint32(a) + uint32(b) 221 return uint16(v + v>>16) 222 } 223 224 // PseudoHeaderChecksum calculates the pseudo-header checksum for the given 225 // destination protocol and network address. Pseudo-headers are needed by 226 // transport layers when calculating their own checksum. 
227 func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 { 228 xsum := Checksum([]byte(srcAddr), 0) 229 xsum = Checksum([]byte(dstAddr), xsum) 230 231 // Add the length portion of the checksum to the pseudo-checksum. 232 tmp := make([]byte, 2) 233 binary.BigEndian.PutUint16(tmp, totalLen) 234 xsum = Checksum(tmp, xsum) 235 236 return Checksum([]byte{0, uint8(protocol)}, xsum) 237 } 238 239 // checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated 240 // checksum. 241 // 242 // The value MUST begin at a 2-byte boundary in the original buffer. 243 func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { 244 // As per RFC 1071 page 4, 245 // (4) Incremental Update 246 // 247 // ... 248 // 249 // To update the checksum, simply add the differences of the 250 // sixteen bit integers that have been changed. To see why this 251 // works, observe that every 16-bit integer has an additive inverse 252 // and that addition is associative. From this it follows that 253 // given the original value m, the new value m', and the old 254 // checksum C, the new checksum C' is: 255 // 256 // C' = C + (-m) + m' = C + (m' - m) 257 return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) 258 } 259 260 // checksumUpdate2ByteAlignedAddress updates an address in a calculated 261 // checksum. 262 // 263 // The addresses must have the same length and must contain an even number 264 // of bytes. The address MUST begin at a 2-byte boundary in the original buffer. 
func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
	const wordSize = 2

	switch {
	case len(old) != len(new):
		panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
	case len(old)%wordSize != 0:
		panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
	}

	// RFC 1071 section 4 (Incremental Update): each changed 16-bit word m
	// becoming m' moves the checksum from C to C + (-m) + m'. Apply that to
	// every big-endian word of the address.
	for off := 0; off < len(old); off += wordSize {
		m := uint16(old[off])<<8 + uint16(old[off+1])
		mPrime := uint16(new[off])<<8 + uint16(new[off+1])
		xsum = checksumUpdate2ByteAlignedUint16(xsum, m, mPrime)
	}
	return xsum
}