gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/checksum/checksum_test.go (about) 1 // Copyright 2019 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package header provides the implementation of the encoding and decoding of 16 // network protocol headers. 17 package checksum 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "testing" 24 ) 25 26 func TestChecksumer(t *testing.T) { 27 testCases := []struct { 28 name string 29 data [][]byte 30 want uint16 31 }{ 32 { 33 name: "empty", 34 want: 0, 35 }, 36 { 37 name: "OneOddView", 38 data: [][]byte{ 39 {1, 9, 0, 5, 4}, 40 }, 41 want: 1294, 42 }, 43 { 44 name: "TwoOddViews", 45 data: [][]byte{ 46 {1, 9, 0, 5, 4}, 47 {4, 3, 7, 1, 2, 123}, 48 }, 49 want: 33819, 50 }, 51 { 52 name: "OneEvenView", 53 data: [][]byte{ 54 {1, 9, 0, 5}, 55 }, 56 want: 270, 57 }, 58 { 59 name: "TwoEvenViews", 60 data: [][]byte{ 61 []byte{98, 1, 9, 0}, 62 []byte{9, 0, 5, 4}, 63 }, 64 want: 30981, 65 }, 66 { 67 name: "ThreeViews", 68 data: [][]byte{ 69 {77, 11, 33, 0, 55, 44}, 70 {98, 1, 9, 0, 5, 4}, 71 {4, 3, 7, 1, 2, 123, 99}, 72 }, 73 want: 34236, 74 }, 75 } 76 for _, tc := range testCases { 77 t.Run(tc.name, func(t *testing.T) { 78 var all bytes.Buffer 79 var c Checksumer 80 for _, b := range tc.data { 81 c.Add(b) 82 // Append to the buffer. We will check the checksum as a whole later. 83 if _, err := all.Write(b); err != nil { 84 t.Fatalf("all.Write(b) = _, %s; want _, nil", err) 85 } 86 } 87 if got, want := c.Checksum(), tc.want; got != want { 88 t.Errorf("c.Checksum() = %d, want %d", got, want) 89 } 90 if got, want := Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { 91 t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) 92 } 93 }) 94 } 95 } 96 97 func TestChecksum(t *testing.T) { 98 var bufSizes = []int{ 99 0, 100 1, 101 2, 102 3, 103 4, 104 7, 105 8, 106 15, 107 16, 108 31, 109 32, 110 63, 111 64, 112 127, 113 128, 114 255, 115 256, 116 257, 117 1023, 118 1024, 119 } 120 type testCase struct { 121 buf []byte 122 initial uint16 123 } 124 testCases := make([]testCase, 100000) 125 // Ensure same buffer generation for test consistency. 126 rnd := rand.New(rand.NewSource(42)) 127 for i := range testCases { 128 testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) 129 testCases[i].initial = uint16(rnd.Intn(65536)) 130 rnd.Read(testCases[i].buf) 131 } 132 133 checkSumImpls := []struct { 134 fn func([]byte, uint16) uint16 135 name string 136 }{ 137 {old, "checksum_old"}, 138 {Checksum, "checksum"}, 139 } 140 141 for _, tc := range testCases { 142 t.Run(fmt.Sprintf("buf size %d", len(tc.buf)), func(t *testing.T) { 143 // Also test different offsets into the buffers. This 144 // tests the correctess of optimizations dealing with 145 // non-64-bit aligned numbers. 146 for offset := 0; offset < 8; offset++ { 147 t.Run(fmt.Sprintf("offset %d", offset), func(t *testing.T) { 148 if offset > len(tc.buf) { 149 t.Skip("offset is greater than buffer size") 150 } 151 buf := tc.buf[offset:] 152 for i := 0; i < len(checkSumImpls)-1; i++ { 153 first := checkSumImpls[i].fn(buf, tc.initial) 154 second := checkSumImpls[i+1].fn(buf, tc.initial) 155 if first != second { 156 t.Fatalf("for (buf = 0x%x, initial = 0x%x) checksum %q does not match %q: got: 0x%x and 0x%x", buf, tc.initial, checkSumImpls[i].name, checkSumImpls[i+1].name, first, second) 157 } 158 } 159 }) 160 } 161 }) 162 } 163 } 164 165 // TestIncrementalChecksum tests for breakages of Checksummer as described in 166 // b/289284842. 167 func TestIncrementalChecksum(t *testing.T) { 168 buf := []byte{ 169 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 170 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 171 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 172 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 173 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 174 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 175 } 176 177 // Go through buf and check that checksum(buf[:end]) is equivalent to 178 // an incremental checksum of two chunks of buf[:end]. 179 for end := 2; end <= len(buf); end++ { 180 for start := 1; start < end; start++ { 181 t.Run(fmt.Sprintf("end=%d start=%d", end, start), func(t *testing.T) { 182 var cs Checksumer 183 cs.Add(buf[:end]) 184 csum := cs.Checksum() 185 186 cs = Checksumer{} 187 cs.Add(buf[:start]) 188 cs.Add(buf[start:end]) 189 csumIncremental := cs.Checksum() 190 191 if want := old(buf[:end], 0); csum != want { 192 t.Fatalf("checksum is wrong: got %x, expected %x", csum, want) 193 } 194 if csum != csumIncremental { 195 t.Errorf("checksums should be the same: %x %x", csum, csumIncremental) 196 } 197 }) 198 } 199 } 200 } 201 202 func BenchmarkChecksum(b *testing.B) { 203 var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} 204 205 checkSumImpls := []struct { 206 fn func([]byte, uint16) uint16 207 name string 208 }{ 209 {old, "checksum_old"}, 210 {Checksum, "checksum"}, 211 } 212 213 for _, csumImpl := range checkSumImpls { 214 // Ensure same buffer generation for test consistency. 215 rnd := rand.New(rand.NewSource(42)) 216 for _, bufSz := range bufSizes { 217 b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { 218 tc := struct { 219 buf []byte 220 initial uint16 221 csum uint16 222 }{ 223 buf: make([]byte, bufSz), 224 initial: uint16(rnd.Intn(65536)), 225 } 226 rnd.Read(tc.buf) 227 b.ResetTimer() 228 for i := 0; i < b.N; i++ { 229 tc.csum = csumImpl.fn(tc.buf, tc.initial) 230 } 231 }) 232 } 233 } 234 } 235 236 // old calculates the checksum (as defined in RFC 1071) of the bytes in 237 // the given byte array. This function uses a non-optimized implementation. Its 238 // only retained for reference and to use as a benchmark/test. Most code should 239 // use the header.Checksum function. 240 // 241 // The initial checksum must have been computed on an even number of bytes. 242 func old(buf []byte, initial uint16) uint16 { 243 s, _ := oldCalculateChecksum(buf, false, uint32(initial)) 244 return s 245 } 246 247 func oldCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { 248 v := initial 249 250 if odd { 251 v += uint32(buf[0]) 252 buf = buf[1:] 253 } 254 255 l := len(buf) 256 odd = l&1 != 0 257 if odd { 258 l-- 259 v += uint32(buf[l]) << 8 260 } 261 262 for i := 0; i < l; i += 2 { 263 v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) 264 } 265 266 return Combine(uint16(v), uint16(v>>16)), odd 267 }