github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/header/checksum_test.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package header provides the implementation of the encoding and decoding of
    16  // network protocol headers.
    17  package header_test
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"sync"
    24  	"testing"
    25  
    26  	"github.com/SagerNet/gvisor/pkg/tcpip"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    29  )
    30  
    31  func TestChecksumer(t *testing.T) {
    32  	testCases := []struct {
    33  		name string
    34  		data [][]byte
    35  		want uint16
    36  	}{
    37  		{
    38  			name: "empty",
    39  			want: 0,
    40  		},
    41  		{
    42  			name: "OneOddView",
    43  			data: [][]byte{
    44  				[]byte{1, 9, 0, 5, 4},
    45  			},
    46  			want: 1294,
    47  		},
    48  		{
    49  			name: "TwoOddViews",
    50  			data: [][]byte{
    51  				[]byte{1, 9, 0, 5, 4},
    52  				[]byte{4, 3, 7, 1, 2, 123},
    53  			},
    54  			want: 33819,
    55  		},
    56  		{
    57  			name: "OneEvenView",
    58  			data: [][]byte{
    59  				[]byte{1, 9, 0, 5},
    60  			},
    61  			want: 270,
    62  		},
    63  		{
    64  			name: "TwoEvenViews",
    65  			data: [][]byte{
    66  				buffer.NewViewFromBytes([]byte{98, 1, 9, 0}),
    67  				buffer.NewViewFromBytes([]byte{9, 0, 5, 4}),
    68  			},
    69  			want: 30981,
    70  		},
    71  		{
    72  			name: "ThreeViews",
    73  			data: [][]byte{
    74  				[]byte{77, 11, 33, 0, 55, 44},
    75  				[]byte{98, 1, 9, 0, 5, 4},
    76  				[]byte{4, 3, 7, 1, 2, 123, 99},
    77  			},
    78  			want: 34236,
    79  		},
    80  	}
    81  	for _, tc := range testCases {
    82  		t.Run(tc.name, func(t *testing.T) {
    83  			var all bytes.Buffer
    84  			var c header.Checksumer
    85  			for _, b := range tc.data {
    86  				c.Add(b)
    87  				// Append to the buffer. We will check the checksum as a whole later.
    88  				if _, err := all.Write(b); err != nil {
    89  					t.Fatalf("all.Write(b) = _, %s; want _, nil", err)
    90  				}
    91  			}
    92  			if got, want := c.Checksum(), tc.want; got != want {
    93  				t.Errorf("c.Checksum() = %d, want %d", got, want)
    94  			}
    95  			if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want {
    96  				t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want)
    97  			}
    98  		})
    99  	}
   100  }
   101  
   102  func TestChecksum(t *testing.T) {
   103  	var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
   104  	type testCase struct {
   105  		buf      []byte
   106  		initial  uint16
   107  		csumOrig uint16
   108  		csumNew  uint16
   109  	}
   110  	testCases := make([]testCase, 100000)
   111  	// Ensure same buffer generation for test consistency.
   112  	rnd := rand.New(rand.NewSource(42))
   113  	for i := range testCases {
   114  		testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
   115  		testCases[i].initial = uint16(rnd.Intn(65536))
   116  		rnd.Read(testCases[i].buf)
   117  	}
   118  
   119  	for i := range testCases {
   120  		testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
   121  		testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
   122  		if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
   123  			t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
   124  		}
   125  	}
   126  }
   127  
   128  func BenchmarkChecksum(b *testing.B) {
   129  	var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
   130  
   131  	checkSumImpls := []struct {
   132  		fn   func([]byte, uint16) uint16
   133  		name string
   134  	}{
   135  		{header.ChecksumOld, fmt.Sprintf("checksum_old")},
   136  		{header.Checksum, fmt.Sprintf("checksum")},
   137  	}
   138  
   139  	for _, csumImpl := range checkSumImpls {
   140  		// Ensure same buffer generation for test consistency.
   141  		rnd := rand.New(rand.NewSource(42))
   142  		for _, bufSz := range bufSizes {
   143  			b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
   144  				tc := struct {
   145  					buf     []byte
   146  					initial uint16
   147  					csum    uint16
   148  				}{
   149  					buf:     make([]byte, bufSz),
   150  					initial: uint16(rnd.Intn(65536)),
   151  				}
   152  				rnd.Read(tc.buf)
   153  				b.ResetTimer()
   154  				for i := 0; i < b.N; i++ {
   155  					tc.csum = csumImpl.fn(tc.buf, tc.initial)
   156  				}
   157  			})
   158  		}
   159  	}
   160  }
   161  
   162  func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
   163  	// icmpChecksum should not do any modifications of the header to
   164  	// calculate its checksum. Let's call it from a few go-routines and the
   165  	// race detector will trigger a warning if there are any concurrent
   166  	// read/write accesses.
   167  
   168  	const concurrency = 5
   169  	start := make(chan int)
   170  	ready := make(chan bool, concurrency)
   171  	var wg sync.WaitGroup
   172  	wg.Add(concurrency)
   173  	defer wg.Wait()
   174  
   175  	for i := 0; i < concurrency; i++ {
   176  		go func() {
   177  			defer wg.Done()
   178  
   179  			ready <- true
   180  			<-start
   181  
   182  			if got := headerChecksum(); want != got {
   183  				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
   184  			}
   185  			if got := icmpChecksum(); want != got {
   186  				t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
   187  			}
   188  		}()
   189  	}
   190  	for i := 0; i < concurrency; i++ {
   191  		<-ready
   192  	}
   193  	close(start)
   194  }
   195  
   196  func TestICMPv4Checksum(t *testing.T) {
   197  	rnd := rand.New(rand.NewSource(42))
   198  
   199  	h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
   200  	if _, err := rnd.Read(h); err != nil {
   201  		t.Fatalf("rnd.Read failed: %v", err)
   202  	}
   203  	h.SetChecksum(0)
   204  
   205  	buf := make([]byte, 13)
   206  	if _, err := rnd.Read(buf); err != nil {
   207  		t.Fatalf("rnd.Read failed: %v", err)
   208  	}
   209  	vv := buffer.NewVectorisedView(len(buf), []buffer.View{
   210  		buffer.NewViewFromBytes(buf[:5]),
   211  		buffer.NewViewFromBytes(buf[5:]),
   212  	})
   213  
   214  	want := header.Checksum(vv.ToView(), 0)
   215  	want = ^header.Checksum(h, want)
   216  	h.SetChecksum(want)
   217  
   218  	testICMPChecksum(t, h.Checksum, func() uint16 {
   219  		return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0))
   220  	}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
   221  }
   222  
   223  func TestICMPv6Checksum(t *testing.T) {
   224  	rnd := rand.New(rand.NewSource(42))
   225  
   226  	h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
   227  	if _, err := rnd.Read(h); err != nil {
   228  		t.Fatalf("rnd.Read failed: %v", err)
   229  	}
   230  	h.SetChecksum(0)
   231  
   232  	buf := make([]byte, 13)
   233  	if _, err := rnd.Read(buf); err != nil {
   234  		t.Fatalf("rnd.Read failed: %v", err)
   235  	}
   236  	vv := buffer.NewVectorisedView(len(buf), []buffer.View{
   237  		buffer.NewViewFromBytes(buf[:7]),
   238  		buffer.NewViewFromBytes(buf[7:10]),
   239  		buffer.NewViewFromBytes(buf[10:]),
   240  	})
   241  
   242  	dst := header.IPv6Loopback
   243  	src := header.IPv6Loopback
   244  
   245  	want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
   246  	want = header.Checksum(vv.ToView(), want)
   247  	want = ^header.Checksum(h, want)
   248  	h.SetChecksum(want)
   249  
   250  	testICMPChecksum(t, h.Checksum, func() uint16 {
   251  		return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   252  			Header:      h,
   253  			Src:         src,
   254  			Dst:         dst,
   255  			PayloadCsum: header.ChecksumVV(vv, 0),
   256  			PayloadLen:  vv.Size(),
   257  		})
   258  	}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
   259  }
   260  
   261  func randomAddress(size int) tcpip.Address {
   262  	s := make([]byte, size)
   263  	for i := 0; i < size; i++ {
   264  		s[i] = byte(rand.Uint32())
   265  	}
   266  	return tcpip.Address(s)
   267  }
   268  
   269  func TestChecksummableNetworkUpdateAddress(t *testing.T) {
   270  	tests := []struct {
   271  		name   string
   272  		update func(header.IPv4, tcpip.Address)
   273  	}{
   274  		{
   275  			name:   "SetSourceAddressWithChecksumUpdate",
   276  			update: header.IPv4.SetSourceAddressWithChecksumUpdate,
   277  		},
   278  		{
   279  			name:   "SetDestinationAddressWithChecksumUpdate",
   280  			update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
   281  		},
   282  	}
   283  
   284  	for _, test := range tests {
   285  		t.Run(test.name, func(t *testing.T) {
   286  			for i := 0; i < 1000; i++ {
   287  				var origBytes [header.IPv4MinimumSize]byte
   288  				header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
   289  					TOS:            1,
   290  					TotalLength:    header.IPv4MinimumSize,
   291  					ID:             2,
   292  					Flags:          3,
   293  					FragmentOffset: 4,
   294  					TTL:            5,
   295  					Protocol:       6,
   296  					Checksum:       0,
   297  					SrcAddr:        randomAddress(header.IPv4AddressSize),
   298  					DstAddr:        randomAddress(header.IPv4AddressSize),
   299  				})
   300  
   301  				addr := randomAddress(header.IPv4AddressSize)
   302  
   303  				bytesCopy := origBytes
   304  				h := header.IPv4(bytesCopy[:])
   305  				origXSum := h.CalculateChecksum()
   306  				h.SetChecksum(^origXSum)
   307  
   308  				test.update(h, addr)
   309  				got := ^h.Checksum()
   310  				h.SetChecksum(0)
   311  				want := h.CalculateChecksum()
   312  				if got != want {
   313  					t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
   314  				}
   315  			}
   316  		})
   317  	}
   318  }
   319  
   320  func TestChecksummableTransportUpdatePort(t *testing.T) {
   321  	// The fields in the pseudo header is not tested here so we just use 0.
   322  	const pseudoHeaderXSum = 0
   323  
   324  	tests := []struct {
   325  		name         string
   326  		transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
   327  		proto        tcpip.TransportProtocolNumber
   328  	}{
   329  		{
   330  			name: "TCP",
   331  			transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
   332  				h := header.TCP(make([]byte, header.TCPMinimumSize))
   333  				h.Encode(&header.TCPFields{
   334  					SrcPort:       src,
   335  					DstPort:       dst,
   336  					SeqNum:        1,
   337  					AckNum:        2,
   338  					DataOffset:    header.TCPMinimumSize,
   339  					Flags:         3,
   340  					WindowSize:    4,
   341  					Checksum:      0,
   342  					UrgentPointer: 5,
   343  				})
   344  				h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
   345  				return h, h.CalculateChecksum
   346  			},
   347  			proto: header.TCPProtocolNumber,
   348  		},
   349  		{
   350  			name: "UDP",
   351  			transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
   352  				h := header.UDP(make([]byte, header.UDPMinimumSize))
   353  				h.Encode(&header.UDPFields{
   354  					SrcPort:  src,
   355  					DstPort:  dst,
   356  					Length:   0,
   357  					Checksum: 0,
   358  				})
   359  				h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
   360  				return h, h.CalculateChecksum
   361  			},
   362  			proto: header.UDPProtocolNumber,
   363  		},
   364  	}
   365  
   366  	for i := 0; i < 1000; i++ {
   367  		origSrcPort := uint16(rand.Uint32())
   368  		origDstPort := uint16(rand.Uint32())
   369  		newPort := uint16(rand.Uint32())
   370  
   371  		t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
   372  			for _, test := range tests {
   373  				t.Run(test.name, func(t *testing.T) {
   374  					for _, subTest := range []struct {
   375  						name   string
   376  						update func(header.ChecksummableTransport)
   377  					}{
   378  						{
   379  							name:   "Source port",
   380  							update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
   381  						},
   382  						{
   383  							name:   "Destination port",
   384  							update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
   385  						},
   386  					} {
   387  						t.Run(subTest.name, func(t *testing.T) {
   388  							h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
   389  							subTest.update(h)
   390  							// TCP and UDP hold the 1s complement of the fully calculated
   391  							// checksum.
   392  							got := ^h.Checksum()
   393  							h.SetChecksum(0)
   394  
   395  							if want := calcXSum(pseudoHeaderXSum); got != want {
   396  								h, _ := test.transportHdr(origSrcPort, origDstPort)
   397  								t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
   398  							}
   399  						})
   400  					}
   401  				})
   402  			}
   403  		})
   404  	}
   405  }
   406  
   407  func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
   408  	const addressSize = 6
   409  
   410  	tests := []struct {
   411  		name         string
   412  		transportHdr func() header.ChecksummableTransport
   413  		proto        tcpip.TransportProtocolNumber
   414  	}{
   415  		{
   416  			name:         "TCP",
   417  			transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
   418  			proto:        header.TCPProtocolNumber,
   419  		},
   420  		{
   421  			name:         "UDP",
   422  			transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
   423  			proto:        header.UDPProtocolNumber,
   424  		},
   425  	}
   426  
   427  	for i := 0; i < 1000; i++ {
   428  		permanent := randomAddress(addressSize)
   429  		old := randomAddress(addressSize)
   430  		new := randomAddress(addressSize)
   431  
   432  		t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
   433  			for _, test := range tests {
   434  				t.Run(test.name, func(t *testing.T) {
   435  					for _, fullChecksum := range []bool{true, false} {
   436  						t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
   437  							initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
   438  							if fullChecksum {
   439  								// TCP and UDP hold the 1s complement of the fully calculated
   440  								// checksum.
   441  								initialXSum = ^initialXSum
   442  							}
   443  
   444  							h := test.transportHdr()
   445  							h.SetChecksum(initialXSum)
   446  							h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
   447  
   448  							got := h.Checksum()
   449  							if fullChecksum {
   450  								got = ^got
   451  							}
   452  							if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
   453  								t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
   454  							}
   455  						})
   456  					}
   457  				})
   458  			}
   459  		})
   460  	}
   461  }