gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/header/checksum_test.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package header provides the implementation of the encoding and decoding of 16 // network protocol headers. 17 package header_test 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "sync" 24 "testing" 25 26 "gvisor.dev/gvisor/pkg/buffer" 27 "gvisor.dev/gvisor/pkg/tcpip" 28 "gvisor.dev/gvisor/pkg/tcpip/checksum" 29 "gvisor.dev/gvisor/pkg/tcpip/header" 30 ) 31 32 func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { 33 // icmpChecksum should not do any modifications of the header to 34 // calculate its checksum. Let's call it from a few go-routines and the 35 // race detector will trigger a warning if there are any concurrent 36 // read/write accesses. 37 38 const concurrency = 5 39 start := make(chan int) 40 ready := make(chan bool, concurrency) 41 var wg sync.WaitGroup 42 wg.Add(concurrency) 43 defer wg.Wait() 44 45 for i := 0; i < concurrency; i++ { 46 go func() { 47 defer wg.Done() 48 49 ready <- true 50 <-start 51 52 if got := headerChecksum(); want != got { 53 t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) 54 } 55 if got := icmpChecksum(); want != got { 56 t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) 57 } 58 }() 59 } 60 for i := 0; i < concurrency; i++ { 61 <-ready 62 } 63 close(start) 64 } 65 66 // TODO(b/239732156): Replace magic constants with names corresponding to what 67 // they represent ICMP. 68 func TestICMPv4Checksum(t *testing.T) { 69 rnd := rand.New(rand.NewSource(42)) 70 71 h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) 72 if _, err := rnd.Read(h); err != nil { 73 t.Fatalf("rnd.Read failed: %v", err) 74 } 75 h.SetChecksum(0) 76 77 buf := make([]byte, 13) 78 if _, err := rnd.Read(buf); err != nil { 79 t.Fatalf("rnd.Read failed: %v", err) 80 } 81 b := buffer.MakeWithData(buf[:5]) 82 b.Append(buffer.NewViewWithData(buf[5:])) 83 84 want := checksum.Checksum(b.Flatten(), 0) 85 want = ^checksum.Checksum(h, want) 86 h.SetChecksum(want) 87 88 testICMPChecksum(t, h.Checksum, func() uint16 { 89 return header.ICMPv4Checksum(h, b.Checksum(0)) 90 }, want, fmt.Sprintf("header: {% x} data {% x}", h, b.Flatten())) 91 } 92 93 func TestICMPv4ChecksumUpdate(t *testing.T) { 94 const icmpIdent = 0 95 96 data := make([]byte, header.ICMPv4MinimumSize) 97 h := header.ICMPv4(data) 98 h.SetType(header.ICMPv4EchoReply) 99 h.SetCode(header.ICMPv4UnusedCode) 100 h.SetIdent(icmpIdent) 101 h.SetChecksum(^checksum.Checksum(data, 0)) 102 103 updated := header.ICMPv4(bytes.Clone(data)) 104 // Perform an incremental checksum update where we aren't actually changing the ID. 105 updated.SetIdentWithChecksumUpdate(icmpIdent) 106 if updated.Checksum() != h.Checksum() { 107 t.Errorf("got updated.Checksum() = %x, want = %x", updated.Checksum(), h.Checksum()) 108 } 109 } 110 111 func TestICMPv6Checksum(t *testing.T) { 112 rnd := rand.New(rand.NewSource(42)) 113 114 h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) 115 if _, err := rnd.Read(h); err != nil { 116 t.Fatalf("rnd.Read failed: %v", err) 117 } 118 h.SetChecksum(0) 119 120 buf := make([]byte, 13) 121 if _, err := rnd.Read(buf); err != nil { 122 t.Fatalf("rnd.Read failed: %v", err) 123 } 124 b := buffer.MakeWithData(buf[:7]) 125 b.Append(buffer.NewViewWithData(buf[7:10])) 126 b.Append(buffer.NewViewWithData(buf[10:])) 127 128 dst := header.IPv6Loopback 129 src := header.IPv6Loopback 130 131 want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+int(b.Size()))) 132 want = checksum.Checksum(b.Flatten(), want) 133 want = ^checksum.Checksum(h, want) 134 h.SetChecksum(want) 135 136 testICMPChecksum(t, h.Checksum, func() uint16 { 137 return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 138 Header: h, 139 Src: src, 140 Dst: dst, 141 PayloadCsum: b.Checksum(0), 142 PayloadLen: int(b.Size()), 143 }) 144 }, want, fmt.Sprintf("header: {% x} data {% x}", h, b.Flatten())) 145 } 146 147 func randomAddress(size int) tcpip.Address { 148 s := make([]byte, size) 149 for i := 0; i < size; i++ { 150 s[i] = byte(rand.Uint32()) 151 } 152 return tcpip.AddrFromSlice(s) 153 } 154 155 func TestChecksummableNetworkUpdateAddress(t *testing.T) { 156 tests := []struct { 157 name string 158 update func(header.IPv4, tcpip.Address) 159 }{ 160 { 161 name: "SetSourceAddressWithChecksumUpdate", 162 update: header.IPv4.SetSourceAddressWithChecksumUpdate, 163 }, 164 { 165 name: "SetDestinationAddressWithChecksumUpdate", 166 update: header.IPv4.SetDestinationAddressWithChecksumUpdate, 167 }, 168 } 169 170 for _, test := range tests { 171 t.Run(test.name, func(t *testing.T) { 172 for i := 0; i < 1000; i++ { 173 var origBytes [header.IPv4MinimumSize]byte 174 header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ 175 TOS: 1, 176 TotalLength: header.IPv4MinimumSize, 177 ID: 2, 178 Flags: 3, 179 FragmentOffset: 4, 180 TTL: 5, 181 Protocol: 6, 182 Checksum: 0, 183 SrcAddr: randomAddress(header.IPv4AddressSize), 184 DstAddr: randomAddress(header.IPv4AddressSize), 185 }) 186 187 addr := randomAddress(header.IPv4AddressSize) 188 189 bytesCopy := origBytes 190 h := header.IPv4(bytesCopy[:]) 191 origXSum := h.CalculateChecksum() 192 h.SetChecksum(^origXSum) 193 194 test.update(h, addr) 195 got := ^h.Checksum() 196 h.SetChecksum(0) 197 want := h.CalculateChecksum() 198 if got != want { 199 t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) 200 } 201 } 202 }) 203 } 204 } 205 206 func TestChecksummableTransportUpdatePort(t *testing.T) { 207 // The fields in the pseudo header is not tested here so we just use 0. 208 const pseudoHeaderXSum = 0 209 210 tests := []struct { 211 name string 212 transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) 213 proto tcpip.TransportProtocolNumber 214 }{ 215 { 216 name: "TCP", 217 transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { 218 h := header.TCP(make([]byte, header.TCPMinimumSize)) 219 h.Encode(&header.TCPFields{ 220 SrcPort: src, 221 DstPort: dst, 222 SeqNum: 1, 223 AckNum: 2, 224 DataOffset: header.TCPMinimumSize, 225 Flags: 3, 226 WindowSize: 4, 227 Checksum: 0, 228 UrgentPointer: 5, 229 }) 230 h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) 231 return h, h.CalculateChecksum 232 }, 233 proto: header.TCPProtocolNumber, 234 }, 235 { 236 name: "UDP", 237 transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { 238 h := header.UDP(make([]byte, header.UDPMinimumSize)) 239 h.Encode(&header.UDPFields{ 240 SrcPort: src, 241 DstPort: dst, 242 Length: 0, 243 Checksum: 0, 244 }) 245 h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) 246 return h, h.CalculateChecksum 247 }, 248 proto: header.UDPProtocolNumber, 249 }, 250 } 251 252 for i := 0; i < 1000; i++ { 253 origSrcPort := uint16(rand.Uint32()) 254 origDstPort := uint16(rand.Uint32()) 255 newPort := uint16(rand.Uint32()) 256 257 t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(t *testing.T) { 258 for _, test := range tests { 259 t.Run(test.name, func(t *testing.T) { 260 for _, subTest := range []struct { 261 name string 262 update func(header.ChecksummableTransport) 263 }{ 264 { 265 name: "Source port", 266 update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, 267 }, 268 { 269 name: "Destination port", 270 update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, 271 }, 272 } { 273 t.Run(subTest.name, func(t *testing.T) { 274 h, calcXSum := test.transportHdr(origSrcPort, origDstPort) 275 subTest.update(h) 276 // TCP and UDP hold the 1s complement of the fully calculated 277 // checksum. 278 got := ^h.Checksum() 279 h.SetChecksum(0) 280 281 if want := calcXSum(pseudoHeaderXSum); got != want { 282 h, _ := test.transportHdr(origSrcPort, origDstPort) 283 t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) 284 } 285 }) 286 } 287 }) 288 } 289 }) 290 } 291 } 292 293 func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { 294 const addressSize = 16 295 296 tests := []struct { 297 name string 298 transportHdr func() header.ChecksummableTransport 299 proto tcpip.TransportProtocolNumber 300 }{ 301 { 302 name: "TCP", 303 transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, 304 proto: header.TCPProtocolNumber, 305 }, 306 { 307 name: "UDP", 308 transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, 309 proto: header.UDPProtocolNumber, 310 }, 311 } 312 313 for i := 0; i < 1000; i++ { 314 permanent := randomAddress(addressSize) 315 old := randomAddress(addressSize) 316 new := randomAddress(addressSize) 317 318 t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { 319 for _, test := range tests { 320 t.Run(test.name, func(t *testing.T) { 321 for _, fullChecksum := range []bool{true, false} { 322 t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { 323 initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) 324 if fullChecksum { 325 // TCP and UDP hold the 1s complement of the fully calculated 326 // checksum. 327 initialXSum = ^initialXSum 328 } 329 330 h := test.transportHdr() 331 h.SetChecksum(initialXSum) 332 h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) 333 334 got := h.Checksum() 335 if fullChecksum { 336 got = ^got 337 } 338 if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { 339 t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) 340 } 341 }) 342 } 343 }) 344 } 345 }) 346 } 347 }