gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/header/checksum_test.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package header provides the implementation of the encoding and decoding of
    16  // network protocol headers.
    17  package header_test
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"sync"
    24  	"testing"
    25  
    26  	"gvisor.dev/gvisor/pkg/buffer"
    27  	"gvisor.dev/gvisor/pkg/tcpip"
    28  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    29  	"gvisor.dev/gvisor/pkg/tcpip/header"
    30  )
    31  
    32  func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
    33  	// icmpChecksum should not do any modifications of the header to
    34  	// calculate its checksum. Let's call it from a few go-routines and the
    35  	// race detector will trigger a warning if there are any concurrent
    36  	// read/write accesses.
    37  
    38  	const concurrency = 5
    39  	start := make(chan int)
    40  	ready := make(chan bool, concurrency)
    41  	var wg sync.WaitGroup
    42  	wg.Add(concurrency)
    43  	defer wg.Wait()
    44  
    45  	for i := 0; i < concurrency; i++ {
    46  		go func() {
    47  			defer wg.Done()
    48  
    49  			ready <- true
    50  			<-start
    51  
    52  			if got := headerChecksum(); want != got {
    53  				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
    54  			}
    55  			if got := icmpChecksum(); want != got {
    56  				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
    57  			}
    58  		}()
    59  	}
    60  	for i := 0; i < concurrency; i++ {
    61  		<-ready
    62  	}
    63  	close(start)
    64  }
    65  
    66  // TODO(b/239732156): Replace magic constants with names corresponding to what
    67  // they represent ICMP.
    68  func TestICMPv4Checksum(t *testing.T) {
    69  	rnd := rand.New(rand.NewSource(42))
    70  
    71  	h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
    72  	if _, err := rnd.Read(h); err != nil {
    73  		t.Fatalf("rnd.Read failed: %v", err)
    74  	}
    75  	h.SetChecksum(0)
    76  
    77  	buf := make([]byte, 13)
    78  	if _, err := rnd.Read(buf); err != nil {
    79  		t.Fatalf("rnd.Read failed: %v", err)
    80  	}
    81  	b := buffer.MakeWithData(buf[:5])
    82  	b.Append(buffer.NewViewWithData(buf[5:]))
    83  
    84  	want := checksum.Checksum(b.Flatten(), 0)
    85  	want = ^checksum.Checksum(h, want)
    86  	h.SetChecksum(want)
    87  
    88  	testICMPChecksum(t, h.Checksum, func() uint16 {
    89  		return header.ICMPv4Checksum(h, b.Checksum(0))
    90  	}, want, fmt.Sprintf("header: {% x} data {% x}", h, b.Flatten()))
    91  }
    92  
    93  func TestICMPv4ChecksumUpdate(t *testing.T) {
    94  	const icmpIdent = 0
    95  
    96  	data := make([]byte, header.ICMPv4MinimumSize)
    97  	h := header.ICMPv4(data)
    98  	h.SetType(header.ICMPv4EchoReply)
    99  	h.SetCode(header.ICMPv4UnusedCode)
   100  	h.SetIdent(icmpIdent)
   101  	h.SetChecksum(^checksum.Checksum(data, 0))
   102  
   103  	updated := header.ICMPv4(bytes.Clone(data))
   104  	// Perform an incremental checksum update where we aren't actually changing the ID.
   105  	updated.SetIdentWithChecksumUpdate(icmpIdent)
   106  	if updated.Checksum() != h.Checksum() {
   107  		t.Errorf("got updated.Checksum() = %x, want = %x", updated.Checksum(), h.Checksum())
   108  	}
   109  }
   110  
   111  func TestICMPv6Checksum(t *testing.T) {
   112  	rnd := rand.New(rand.NewSource(42))
   113  
   114  	h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
   115  	if _, err := rnd.Read(h); err != nil {
   116  		t.Fatalf("rnd.Read failed: %v", err)
   117  	}
   118  	h.SetChecksum(0)
   119  
   120  	buf := make([]byte, 13)
   121  	if _, err := rnd.Read(buf); err != nil {
   122  		t.Fatalf("rnd.Read failed: %v", err)
   123  	}
   124  	b := buffer.MakeWithData(buf[:7])
   125  	b.Append(buffer.NewViewWithData(buf[7:10]))
   126  	b.Append(buffer.NewViewWithData(buf[10:]))
   127  
   128  	dst := header.IPv6Loopback
   129  	src := header.IPv6Loopback
   130  
   131  	want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+int(b.Size())))
   132  	want = checksum.Checksum(b.Flatten(), want)
   133  	want = ^checksum.Checksum(h, want)
   134  	h.SetChecksum(want)
   135  
   136  	testICMPChecksum(t, h.Checksum, func() uint16 {
   137  		return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   138  			Header:      h,
   139  			Src:         src,
   140  			Dst:         dst,
   141  			PayloadCsum: b.Checksum(0),
   142  			PayloadLen:  int(b.Size()),
   143  		})
   144  	}, want, fmt.Sprintf("header: {% x} data {% x}", h, b.Flatten()))
   145  }
   146  
   147  func randomAddress(size int) tcpip.Address {
   148  	s := make([]byte, size)
   149  	for i := 0; i < size; i++ {
   150  		s[i] = byte(rand.Uint32())
   151  	}
   152  	return tcpip.AddrFromSlice(s)
   153  }
   154  
   155  func TestChecksummableNetworkUpdateAddress(t *testing.T) {
   156  	tests := []struct {
   157  		name   string
   158  		update func(header.IPv4, tcpip.Address)
   159  	}{
   160  		{
   161  			name:   "SetSourceAddressWithChecksumUpdate",
   162  			update: header.IPv4.SetSourceAddressWithChecksumUpdate,
   163  		},
   164  		{
   165  			name:   "SetDestinationAddressWithChecksumUpdate",
   166  			update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
   167  		},
   168  	}
   169  
   170  	for _, test := range tests {
   171  		t.Run(test.name, func(t *testing.T) {
   172  			for i := 0; i < 1000; i++ {
   173  				var origBytes [header.IPv4MinimumSize]byte
   174  				header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
   175  					TOS:            1,
   176  					TotalLength:    header.IPv4MinimumSize,
   177  					ID:             2,
   178  					Flags:          3,
   179  					FragmentOffset: 4,
   180  					TTL:            5,
   181  					Protocol:       6,
   182  					Checksum:       0,
   183  					SrcAddr:        randomAddress(header.IPv4AddressSize),
   184  					DstAddr:        randomAddress(header.IPv4AddressSize),
   185  				})
   186  
   187  				addr := randomAddress(header.IPv4AddressSize)
   188  
   189  				bytesCopy := origBytes
   190  				h := header.IPv4(bytesCopy[:])
   191  				origXSum := h.CalculateChecksum()
   192  				h.SetChecksum(^origXSum)
   193  
   194  				test.update(h, addr)
   195  				got := ^h.Checksum()
   196  				h.SetChecksum(0)
   197  				want := h.CalculateChecksum()
   198  				if got != want {
   199  					t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
   200  				}
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  func TestChecksummableTransportUpdatePort(t *testing.T) {
   207  	// The fields in the pseudo header is not tested here so we just use 0.
   208  	const pseudoHeaderXSum = 0
   209  
   210  	tests := []struct {
   211  		name         string
   212  		transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
   213  		proto        tcpip.TransportProtocolNumber
   214  	}{
   215  		{
   216  			name: "TCP",
   217  			transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
   218  				h := header.TCP(make([]byte, header.TCPMinimumSize))
   219  				h.Encode(&header.TCPFields{
   220  					SrcPort:       src,
   221  					DstPort:       dst,
   222  					SeqNum:        1,
   223  					AckNum:        2,
   224  					DataOffset:    header.TCPMinimumSize,
   225  					Flags:         3,
   226  					WindowSize:    4,
   227  					Checksum:      0,
   228  					UrgentPointer: 5,
   229  				})
   230  				h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
   231  				return h, h.CalculateChecksum
   232  			},
   233  			proto: header.TCPProtocolNumber,
   234  		},
   235  		{
   236  			name: "UDP",
   237  			transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
   238  				h := header.UDP(make([]byte, header.UDPMinimumSize))
   239  				h.Encode(&header.UDPFields{
   240  					SrcPort:  src,
   241  					DstPort:  dst,
   242  					Length:   0,
   243  					Checksum: 0,
   244  				})
   245  				h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
   246  				return h, h.CalculateChecksum
   247  			},
   248  			proto: header.UDPProtocolNumber,
   249  		},
   250  	}
   251  
   252  	for i := 0; i < 1000; i++ {
   253  		origSrcPort := uint16(rand.Uint32())
   254  		origDstPort := uint16(rand.Uint32())
   255  		newPort := uint16(rand.Uint32())
   256  
   257  		t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(t *testing.T) {
   258  			for _, test := range tests {
   259  				t.Run(test.name, func(t *testing.T) {
   260  					for _, subTest := range []struct {
   261  						name   string
   262  						update func(header.ChecksummableTransport)
   263  					}{
   264  						{
   265  							name:   "Source port",
   266  							update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
   267  						},
   268  						{
   269  							name:   "Destination port",
   270  							update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
   271  						},
   272  					} {
   273  						t.Run(subTest.name, func(t *testing.T) {
   274  							h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
   275  							subTest.update(h)
   276  							// TCP and UDP hold the 1s complement of the fully calculated
   277  							// checksum.
   278  							got := ^h.Checksum()
   279  							h.SetChecksum(0)
   280  
   281  							if want := calcXSum(pseudoHeaderXSum); got != want {
   282  								h, _ := test.transportHdr(origSrcPort, origDstPort)
   283  								t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
   284  							}
   285  						})
   286  					}
   287  				})
   288  			}
   289  		})
   290  	}
   291  }
   292  
   293  func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
   294  	const addressSize = 16
   295  
   296  	tests := []struct {
   297  		name         string
   298  		transportHdr func() header.ChecksummableTransport
   299  		proto        tcpip.TransportProtocolNumber
   300  	}{
   301  		{
   302  			name:         "TCP",
   303  			transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
   304  			proto:        header.TCPProtocolNumber,
   305  		},
   306  		{
   307  			name:         "UDP",
   308  			transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
   309  			proto:        header.UDPProtocolNumber,
   310  		},
   311  	}
   312  
   313  	for i := 0; i < 1000; i++ {
   314  		permanent := randomAddress(addressSize)
   315  		old := randomAddress(addressSize)
   316  		new := randomAddress(addressSize)
   317  
   318  		t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
   319  			for _, test := range tests {
   320  				t.Run(test.name, func(t *testing.T) {
   321  					for _, fullChecksum := range []bool{true, false} {
   322  						t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
   323  							initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
   324  							if fullChecksum {
   325  								// TCP and UDP hold the 1s complement of the fully calculated
   326  								// checksum.
   327  								initialXSum = ^initialXSum
   328  							}
   329  
   330  							h := test.transportHdr()
   331  							h.SetChecksum(initialXSum)
   332  							h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
   333  
   334  							got := h.Checksum()
   335  							if fullChecksum {
   336  								got = ^got
   337  							}
   338  							if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
   339  								t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
   340  							}
   341  						})
   342  					}
   343  				})
   344  			}
   345  		})
   346  	}
   347  }