gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/checksum/checksum_test.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package checksum provides an implementation of the Internet checksum
// (RFC 1071).
    17  package checksum
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"testing"
    24  )
    25  
    26  func TestChecksumer(t *testing.T) {
    27  	testCases := []struct {
    28  		name string
    29  		data [][]byte
    30  		want uint16
    31  	}{
    32  		{
    33  			name: "empty",
    34  			want: 0,
    35  		},
    36  		{
    37  			name: "OneOddView",
    38  			data: [][]byte{
    39  				{1, 9, 0, 5, 4},
    40  			},
    41  			want: 1294,
    42  		},
    43  		{
    44  			name: "TwoOddViews",
    45  			data: [][]byte{
    46  				{1, 9, 0, 5, 4},
    47  				{4, 3, 7, 1, 2, 123},
    48  			},
    49  			want: 33819,
    50  		},
    51  		{
    52  			name: "OneEvenView",
    53  			data: [][]byte{
    54  				{1, 9, 0, 5},
    55  			},
    56  			want: 270,
    57  		},
    58  		{
    59  			name: "TwoEvenViews",
    60  			data: [][]byte{
    61  				[]byte{98, 1, 9, 0},
    62  				[]byte{9, 0, 5, 4},
    63  			},
    64  			want: 30981,
    65  		},
    66  		{
    67  			name: "ThreeViews",
    68  			data: [][]byte{
    69  				{77, 11, 33, 0, 55, 44},
    70  				{98, 1, 9, 0, 5, 4},
    71  				{4, 3, 7, 1, 2, 123, 99},
    72  			},
    73  			want: 34236,
    74  		},
    75  	}
    76  	for _, tc := range testCases {
    77  		t.Run(tc.name, func(t *testing.T) {
    78  			var all bytes.Buffer
    79  			var c Checksumer
    80  			for _, b := range tc.data {
    81  				c.Add(b)
    82  				// Append to the buffer. We will check the checksum as a whole later.
    83  				if _, err := all.Write(b); err != nil {
    84  					t.Fatalf("all.Write(b) = _, %s; want _, nil", err)
    85  				}
    86  			}
    87  			if got, want := c.Checksum(), tc.want; got != want {
    88  				t.Errorf("c.Checksum() = %d, want %d", got, want)
    89  			}
    90  			if got, want := Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want {
    91  				t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want)
    92  			}
    93  		})
    94  	}
    95  }
    96  
    97  func TestChecksum(t *testing.T) {
    98  	var bufSizes = []int{
    99  		0,
   100  		1,
   101  		2,
   102  		3,
   103  		4,
   104  		7,
   105  		8,
   106  		15,
   107  		16,
   108  		31,
   109  		32,
   110  		63,
   111  		64,
   112  		127,
   113  		128,
   114  		255,
   115  		256,
   116  		257,
   117  		1023,
   118  		1024,
   119  	}
   120  	type testCase struct {
   121  		buf     []byte
   122  		initial uint16
   123  	}
   124  	testCases := make([]testCase, 100000)
   125  	// Ensure same buffer generation for test consistency.
   126  	rnd := rand.New(rand.NewSource(42))
   127  	for i := range testCases {
   128  		testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
   129  		testCases[i].initial = uint16(rnd.Intn(65536))
   130  		rnd.Read(testCases[i].buf)
   131  	}
   132  
   133  	checkSumImpls := []struct {
   134  		fn   func([]byte, uint16) uint16
   135  		name string
   136  	}{
   137  		{old, "checksum_old"},
   138  		{Checksum, "checksum"},
   139  	}
   140  
   141  	for _, tc := range testCases {
   142  		t.Run(fmt.Sprintf("buf size %d", len(tc.buf)), func(t *testing.T) {
   143  			// Also test different offsets into the buffers. This
   144  			// tests the correctess of optimizations dealing with
   145  			// non-64-bit aligned numbers.
   146  			for offset := 0; offset < 8; offset++ {
   147  				t.Run(fmt.Sprintf("offset %d", offset), func(t *testing.T) {
   148  					if offset > len(tc.buf) {
   149  						t.Skip("offset is greater than buffer size")
   150  					}
   151  					buf := tc.buf[offset:]
   152  					for i := 0; i < len(checkSumImpls)-1; i++ {
   153  						first := checkSumImpls[i].fn(buf, tc.initial)
   154  						second := checkSumImpls[i+1].fn(buf, tc.initial)
   155  						if first != second {
   156  							t.Fatalf("for (buf = 0x%x, initial = 0x%x) checksum %q does not match %q: got: 0x%x and 0x%x", buf, tc.initial, checkSumImpls[i].name, checkSumImpls[i+1].name, first, second)
   157  						}
   158  					}
   159  				})
   160  			}
   161  		})
   162  	}
   163  }
   164  
   165  // TestIncrementalChecksum tests for breakages of Checksummer as described in
   166  // b/289284842.
   167  func TestIncrementalChecksum(t *testing.T) {
   168  	buf := []byte{
   169  		0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31,
   170  		0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c,
   171  		0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
   172  		0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52,
   173  		0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d,
   174  		0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
   175  	}
   176  
   177  	// Go through buf and check that checksum(buf[:end]) is equivalent to
   178  	// an incremental checksum of two chunks of buf[:end].
   179  	for end := 2; end <= len(buf); end++ {
   180  		for start := 1; start < end; start++ {
   181  			t.Run(fmt.Sprintf("end=%d start=%d", end, start), func(t *testing.T) {
   182  				var cs Checksumer
   183  				cs.Add(buf[:end])
   184  				csum := cs.Checksum()
   185  
   186  				cs = Checksumer{}
   187  				cs.Add(buf[:start])
   188  				cs.Add(buf[start:end])
   189  				csumIncremental := cs.Checksum()
   190  
   191  				if want := old(buf[:end], 0); csum != want {
   192  					t.Fatalf("checksum is wrong: got %x, expected %x", csum, want)
   193  				}
   194  				if csum != csumIncremental {
   195  					t.Errorf("checksums should be the same: %x %x", csum, csumIncremental)
   196  				}
   197  			})
   198  		}
   199  	}
   200  }
   201  
   202  func BenchmarkChecksum(b *testing.B) {
   203  	var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
   204  
   205  	checkSumImpls := []struct {
   206  		fn   func([]byte, uint16) uint16
   207  		name string
   208  	}{
   209  		{old, "checksum_old"},
   210  		{Checksum, "checksum"},
   211  	}
   212  
   213  	for _, csumImpl := range checkSumImpls {
   214  		// Ensure same buffer generation for test consistency.
   215  		rnd := rand.New(rand.NewSource(42))
   216  		for _, bufSz := range bufSizes {
   217  			b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
   218  				tc := struct {
   219  					buf     []byte
   220  					initial uint16
   221  					csum    uint16
   222  				}{
   223  					buf:     make([]byte, bufSz),
   224  					initial: uint16(rnd.Intn(65536)),
   225  				}
   226  				rnd.Read(tc.buf)
   227  				b.ResetTimer()
   228  				for i := 0; i < b.N; i++ {
   229  					tc.csum = csumImpl.fn(tc.buf, tc.initial)
   230  				}
   231  			})
   232  		}
   233  	}
   234  }
   235  
   236  // old calculates the checksum (as defined in RFC 1071) of the bytes in
   237  // the given byte array. This function uses a non-optimized implementation. Its
   238  // only retained for reference and to use as a benchmark/test. Most code should
   239  // use the header.Checksum function.
   240  //
   241  // The initial checksum must have been computed on an even number of bytes.
   242  func old(buf []byte, initial uint16) uint16 {
   243  	s, _ := oldCalculateChecksum(buf, false, uint32(initial))
   244  	return s
   245  }
   246  
   247  func oldCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
   248  	v := initial
   249  
   250  	if odd {
   251  		v += uint32(buf[0])
   252  		buf = buf[1:]
   253  	}
   254  
   255  	l := len(buf)
   256  	odd = l&1 != 0
   257  	if odd {
   258  		l--
   259  		v += uint32(buf[l]) << 8
   260  	}
   261  
   262  	for i := 0; i < l; i += 2 {
   263  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   264  	}
   265  
   266  	return Combine(uint16(v), uint16(v>>16)), odd
   267  }