github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/header/checksum_test.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package header provides the implementation of the encoding and decoding of 16 // network protocol headers. 17 package header_test 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "sync" 24 "testing" 25 26 "github.com/SagerNet/gvisor/pkg/tcpip" 27 "github.com/SagerNet/gvisor/pkg/tcpip/buffer" 28 "github.com/SagerNet/gvisor/pkg/tcpip/header" 29 ) 30 31 func TestChecksumer(t *testing.T) { 32 testCases := []struct { 33 name string 34 data [][]byte 35 want uint16 36 }{ 37 { 38 name: "empty", 39 want: 0, 40 }, 41 { 42 name: "OneOddView", 43 data: [][]byte{ 44 []byte{1, 9, 0, 5, 4}, 45 }, 46 want: 1294, 47 }, 48 { 49 name: "TwoOddViews", 50 data: [][]byte{ 51 []byte{1, 9, 0, 5, 4}, 52 []byte{4, 3, 7, 1, 2, 123}, 53 }, 54 want: 33819, 55 }, 56 { 57 name: "OneEvenView", 58 data: [][]byte{ 59 []byte{1, 9, 0, 5}, 60 }, 61 want: 270, 62 }, 63 { 64 name: "TwoEvenViews", 65 data: [][]byte{ 66 buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), 67 buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), 68 }, 69 want: 30981, 70 }, 71 { 72 name: "ThreeViews", 73 data: [][]byte{ 74 []byte{77, 11, 33, 0, 55, 44}, 75 []byte{98, 1, 9, 0, 5, 4}, 76 []byte{4, 3, 7, 1, 2, 123, 99}, 77 }, 78 want: 34236, 79 }, 80 } 81 for _, tc := range testCases { 82 t.Run(tc.name, func(t *testing.T) { 83 var all bytes.Buffer 84 var c header.Checksumer 85 for _, b := range tc.data { 86 c.Add(b) 87 // Append to the buffer. We will check the checksum as a whole later. 88 if _, err := all.Write(b); err != nil { 89 t.Fatalf("all.Write(b) = _, %s; want _, nil", err) 90 } 91 } 92 if got, want := c.Checksum(), tc.want; got != want { 93 t.Errorf("c.Checksum() = %d, want %d", got, want) 94 } 95 if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { 96 t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) 97 } 98 }) 99 } 100 } 101 102 func TestChecksum(t *testing.T) { 103 var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} 104 type testCase struct { 105 buf []byte 106 initial uint16 107 csumOrig uint16 108 csumNew uint16 109 } 110 testCases := make([]testCase, 100000) 111 // Ensure same buffer generation for test consistency. 112 rnd := rand.New(rand.NewSource(42)) 113 for i := range testCases { 114 testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) 115 testCases[i].initial = uint16(rnd.Intn(65536)) 116 rnd.Read(testCases[i].buf) 117 } 118 119 for i := range testCases { 120 testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) 121 testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) 122 if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { 123 t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) 124 } 125 } 126 } 127 128 func BenchmarkChecksum(b *testing.B) { 129 var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} 130 131 checkSumImpls := []struct { 132 fn func([]byte, uint16) uint16 133 name string 134 }{ 135 {header.ChecksumOld, fmt.Sprintf("checksum_old")}, 136 {header.Checksum, fmt.Sprintf("checksum")}, 137 } 138 139 for _, csumImpl := range checkSumImpls { 140 // Ensure same buffer generation for test consistency. 141 rnd := rand.New(rand.NewSource(42)) 142 for _, bufSz := range bufSizes { 143 b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { 144 tc := struct { 145 buf []byte 146 initial uint16 147 csum uint16 148 }{ 149 buf: make([]byte, bufSz), 150 initial: uint16(rnd.Intn(65536)), 151 } 152 rnd.Read(tc.buf) 153 b.ResetTimer() 154 for i := 0; i < b.N; i++ { 155 tc.csum = csumImpl.fn(tc.buf, tc.initial) 156 } 157 }) 158 } 159 } 160 } 161 162 func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { 163 // icmpChecksum should not do any modifications of the header to 164 // calculate its checksum. Let's call it from a few go-routines and the 165 // race detector will trigger a warning if there are any concurrent 166 // read/write accesses. 167 168 const concurrency = 5 169 start := make(chan int) 170 ready := make(chan bool, concurrency) 171 var wg sync.WaitGroup 172 wg.Add(concurrency) 173 defer wg.Wait() 174 175 for i := 0; i < concurrency; i++ { 176 go func() { 177 defer wg.Done() 178 179 ready <- true 180 <-start 181 182 if got := headerChecksum(); want != got { 183 t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) 184 } 185 if got := icmpChecksum(); want != got { 186 t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) 187 } 188 }() 189 } 190 for i := 0; i < concurrency; i++ { 191 <-ready 192 } 193 close(start) 194 } 195 196 func TestICMPv4Checksum(t *testing.T) { 197 rnd := rand.New(rand.NewSource(42)) 198 199 h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) 200 if _, err := rnd.Read(h); err != nil { 201 t.Fatalf("rnd.Read failed: %v", err) 202 } 203 h.SetChecksum(0) 204 205 buf := make([]byte, 13) 206 if _, err := rnd.Read(buf); err != nil { 207 t.Fatalf("rnd.Read failed: %v", err) 208 } 209 vv := buffer.NewVectorisedView(len(buf), []buffer.View{ 210 buffer.NewViewFromBytes(buf[:5]), 211 buffer.NewViewFromBytes(buf[5:]), 212 }) 213 214 want := header.Checksum(vv.ToView(), 0) 215 want = ^header.Checksum(h, want) 216 h.SetChecksum(want) 217 218 testICMPChecksum(t, h.Checksum, func() uint16 { 219 return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) 220 }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) 221 } 222 223 func TestICMPv6Checksum(t *testing.T) { 224 rnd := rand.New(rand.NewSource(42)) 225 226 h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) 227 if _, err := rnd.Read(h); err != nil { 228 t.Fatalf("rnd.Read failed: %v", err) 229 } 230 h.SetChecksum(0) 231 232 buf := make([]byte, 13) 233 if _, err := rnd.Read(buf); err != nil { 234 t.Fatalf("rnd.Read failed: %v", err) 235 } 236 vv := buffer.NewVectorisedView(len(buf), []buffer.View{ 237 buffer.NewViewFromBytes(buf[:7]), 238 buffer.NewViewFromBytes(buf[7:10]), 239 buffer.NewViewFromBytes(buf[10:]), 240 }) 241 242 dst := header.IPv6Loopback 243 src := header.IPv6Loopback 244 245 want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) 246 want = header.Checksum(vv.ToView(), want) 247 want = ^header.Checksum(h, want) 248 h.SetChecksum(want) 249 250 testICMPChecksum(t, h.Checksum, func() uint16 { 251 return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 252 Header: h, 253 Src: src, 254 Dst: dst, 255 PayloadCsum: header.ChecksumVV(vv, 0), 256 PayloadLen: vv.Size(), 257 }) 258 }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) 259 } 260 261 func randomAddress(size int) tcpip.Address { 262 s := make([]byte, size) 263 for i := 0; i < size; i++ { 264 s[i] = byte(rand.Uint32()) 265 } 266 return tcpip.Address(s) 267 } 268 269 func TestChecksummableNetworkUpdateAddress(t *testing.T) { 270 tests := []struct { 271 name string 272 update func(header.IPv4, tcpip.Address) 273 }{ 274 { 275 name: "SetSourceAddressWithChecksumUpdate", 276 update: header.IPv4.SetSourceAddressWithChecksumUpdate, 277 }, 278 { 279 name: "SetDestinationAddressWithChecksumUpdate", 280 update: header.IPv4.SetDestinationAddressWithChecksumUpdate, 281 }, 282 } 283 284 for _, test := range tests { 285 t.Run(test.name, func(t *testing.T) { 286 for i := 0; i < 1000; i++ { 287 var origBytes [header.IPv4MinimumSize]byte 288 header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ 289 TOS: 1, 290 TotalLength: header.IPv4MinimumSize, 291 ID: 2, 292 Flags: 3, 293 FragmentOffset: 4, 294 TTL: 5, 295 Protocol: 6, 296 Checksum: 0, 297 SrcAddr: randomAddress(header.IPv4AddressSize), 298 DstAddr: randomAddress(header.IPv4AddressSize), 299 }) 300 301 addr := randomAddress(header.IPv4AddressSize) 302 303 bytesCopy := origBytes 304 h := header.IPv4(bytesCopy[:]) 305 origXSum := h.CalculateChecksum() 306 h.SetChecksum(^origXSum) 307 308 test.update(h, addr) 309 got := ^h.Checksum() 310 h.SetChecksum(0) 311 want := h.CalculateChecksum() 312 if got != want { 313 t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) 314 } 315 } 316 }) 317 } 318 } 319 320 func TestChecksummableTransportUpdatePort(t *testing.T) { 321 // The fields in the pseudo header is not tested here so we just use 0. 322 const pseudoHeaderXSum = 0 323 324 tests := []struct { 325 name string 326 transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) 327 proto tcpip.TransportProtocolNumber 328 }{ 329 { 330 name: "TCP", 331 transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { 332 h := header.TCP(make([]byte, header.TCPMinimumSize)) 333 h.Encode(&header.TCPFields{ 334 SrcPort: src, 335 DstPort: dst, 336 SeqNum: 1, 337 AckNum: 2, 338 DataOffset: header.TCPMinimumSize, 339 Flags: 3, 340 WindowSize: 4, 341 Checksum: 0, 342 UrgentPointer: 5, 343 }) 344 h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) 345 return h, h.CalculateChecksum 346 }, 347 proto: header.TCPProtocolNumber, 348 }, 349 { 350 name: "UDP", 351 transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { 352 h := header.UDP(make([]byte, header.UDPMinimumSize)) 353 h.Encode(&header.UDPFields{ 354 SrcPort: src, 355 DstPort: dst, 356 Length: 0, 357 Checksum: 0, 358 }) 359 h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) 360 return h, h.CalculateChecksum 361 }, 362 proto: header.UDPProtocolNumber, 363 }, 364 } 365 366 for i := 0; i < 1000; i++ { 367 origSrcPort := uint16(rand.Uint32()) 368 origDstPort := uint16(rand.Uint32()) 369 newPort := uint16(rand.Uint32()) 370 371 t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { 372 for _, test := range tests { 373 t.Run(test.name, func(t *testing.T) { 374 for _, subTest := range []struct { 375 name string 376 update func(header.ChecksummableTransport) 377 }{ 378 { 379 name: "Source port", 380 update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, 381 }, 382 { 383 name: "Destination port", 384 update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, 385 }, 386 } { 387 t.Run(subTest.name, func(t *testing.T) { 388 h, calcXSum := test.transportHdr(origSrcPort, origDstPort) 389 subTest.update(h) 390 // TCP and UDP hold the 1s complement of the fully calculated 391 // checksum. 392 got := ^h.Checksum() 393 h.SetChecksum(0) 394 395 if want := calcXSum(pseudoHeaderXSum); got != want { 396 h, _ := test.transportHdr(origSrcPort, origDstPort) 397 t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) 398 } 399 }) 400 } 401 }) 402 } 403 }) 404 } 405 } 406 407 func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { 408 const addressSize = 6 409 410 tests := []struct { 411 name string 412 transportHdr func() header.ChecksummableTransport 413 proto tcpip.TransportProtocolNumber 414 }{ 415 { 416 name: "TCP", 417 transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, 418 proto: header.TCPProtocolNumber, 419 }, 420 { 421 name: "UDP", 422 transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, 423 proto: header.UDPProtocolNumber, 424 }, 425 } 426 427 for i := 0; i < 1000; i++ { 428 permanent := randomAddress(addressSize) 429 old := randomAddress(addressSize) 430 new := randomAddress(addressSize) 431 432 t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { 433 for _, test := range tests { 434 t.Run(test.name, func(t *testing.T) { 435 for _, fullChecksum := range []bool{true, false} { 436 t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { 437 initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) 438 if fullChecksum { 439 // TCP and UDP hold the 1s complement of the fully calculated 440 // checksum. 441 initialXSum = ^initialXSum 442 } 443 444 h := test.transportHdr() 445 h.SetChecksum(initialXSum) 446 h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) 447 448 got := h.Checksum() 449 if fullChecksum { 450 got = ^got 451 } 452 if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { 453 t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) 454 } 455 }) 456 } 457 }) 458 } 459 }) 460 } 461 }