inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/header/checksum.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package header provides the implementation of the encoding and decoding of
    16  // network protocol headers.
    17  package header
    18  
    19  import (
    20  	"encoding/binary"
    21  	"fmt"
    22  
    23  	"inet.af/netstack/tcpip"
    24  	"inet.af/netstack/tcpip/buffer"
    25  )
    26  
    27  func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
    28  	v := initial
    29  
    30  	if odd {
    31  		v += uint32(buf[0])
    32  		buf = buf[1:]
    33  	}
    34  
    35  	l := len(buf)
    36  	odd = l&1 != 0
    37  	if odd {
    38  		l--
    39  		v += uint32(buf[l]) << 8
    40  	}
    41  
    42  	for i := 0; i < l; i += 2 {
    43  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    44  	}
    45  
    46  	return ChecksumCombine(uint16(v), uint16(v>>16)), odd
    47  }
    48  
    49  func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
    50  	v := initial
    51  
    52  	if odd {
    53  		v += uint32(buf[0])
    54  		buf = buf[1:]
    55  	}
    56  
    57  	l := len(buf)
    58  	odd = l&1 != 0
    59  	if odd {
    60  		l--
    61  		v += uint32(buf[l]) << 8
    62  	}
    63  	for (l - 64) >= 0 {
    64  		i := 0
    65  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    66  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    67  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    68  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    69  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    70  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    71  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    72  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    73  		i += 16
    74  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    75  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    76  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    77  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    78  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    79  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    80  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    81  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    82  		i += 16
    83  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    84  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    85  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    86  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    87  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    88  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    89  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    90  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    91  		i += 16
    92  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    93  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    94  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    95  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    96  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    97  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    98  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    99  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   100  		buf = buf[64:]
   101  		l = l - 64
   102  	}
   103  	if (l - 32) >= 0 {
   104  		i := 0
   105  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   106  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   107  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   108  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   109  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
   110  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
   111  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
   112  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   113  		i += 16
   114  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   115  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   116  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   117  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   118  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
   119  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
   120  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
   121  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   122  		buf = buf[32:]
   123  		l = l - 32
   124  	}
   125  	if (l - 16) >= 0 {
   126  		i := 0
   127  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   128  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   129  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   130  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   131  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
   132  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
   133  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
   134  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   135  		buf = buf[16:]
   136  		l = l - 16
   137  	}
   138  	if (l - 8) >= 0 {
   139  		i := 0
   140  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   141  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   142  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   143  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   144  		buf = buf[8:]
   145  		l = l - 8
   146  	}
   147  	if (l - 4) >= 0 {
   148  		i := 0
   149  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   150  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   151  		buf = buf[4:]
   152  		l = l - 4
   153  	}
   154  
   155  	// At this point since l was even before we started unrolling
   156  	// there can be only two bytes left to add.
   157  	if l != 0 {
   158  		v += (uint32(buf[0]) << 8) + uint32(buf[1])
   159  	}
   160  
   161  	return ChecksumCombine(uint16(v), uint16(v>>16)), odd
   162  }
   163  
   164  // ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
   165  // the given byte array. This function uses a non-optimized implementation. Its
   166  // only retained for reference and to use as a benchmark/test. Most code should
   167  // use the header.Checksum function.
   168  //
   169  // The initial checksum must have been computed on an even number of bytes.
   170  func ChecksumOld(buf []byte, initial uint16) uint16 {
   171  	s, _ := calculateChecksum(buf, false, uint32(initial))
   172  	return s
   173  }
   174  
   175  // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
   176  // given byte array. This function uses an optimized unrolled version of the
   177  // checksum algorithm.
   178  //
   179  // The initial checksum must have been computed on an even number of bytes.
   180  func Checksum(buf []byte, initial uint16) uint16 {
   181  	s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
   182  	return s
   183  }
   184  
   185  // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
   186  // the given VectorizedView.
   187  //
   188  // The initial checksum must have been computed on an even number of bytes.
   189  func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
   190  	var c Checksumer
   191  	for _, v := range vv.Views() {
   192  		c.Add([]byte(v))
   193  	}
   194  	return ChecksumCombine(initial, c.Checksum())
   195  }
   196  
   197  // Checksumer calculates checksum defined in RFC 1071.
   198  type Checksumer struct {
   199  	sum uint16
   200  	odd bool
   201  }
   202  
   203  // Add adds b to checksum.
   204  func (c *Checksumer) Add(b []byte) {
   205  	if len(b) > 0 {
   206  		c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum))
   207  	}
   208  }
   209  
   210  // Checksum returns the latest checksum value.
   211  func (c *Checksumer) Checksum() uint16 {
   212  	return c.sum
   213  }
   214  
   215  // ChecksumCombine combines the two uint16 to form their checksum. This is done
   216  // by adding them and the carry.
   217  //
   218  // Note that checksum a must have been computed on an even number of bytes.
   219  func ChecksumCombine(a, b uint16) uint16 {
   220  	v := uint32(a) + uint32(b)
   221  	return uint16(v + v>>16)
   222  }
   223  
   224  // PseudoHeaderChecksum calculates the pseudo-header checksum for the given
   225  // destination protocol and network address. Pseudo-headers are needed by
   226  // transport layers when calculating their own checksum.
   227  func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
   228  	xsum := Checksum([]byte(srcAddr), 0)
   229  	xsum = Checksum([]byte(dstAddr), xsum)
   230  
   231  	// Add the length portion of the checksum to the pseudo-checksum.
   232  	tmp := make([]byte, 2)
   233  	binary.BigEndian.PutUint16(tmp, totalLen)
   234  	xsum = Checksum(tmp, xsum)
   235  
   236  	return Checksum([]byte{0, uint8(protocol)}, xsum)
   237  }
   238  
   239  // checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
   240  // checksum.
   241  //
   242  // The value MUST begin at a 2-byte boundary in the original buffer.
   243  func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
   244  	// As per RFC 1071 page 4,
   245  	//	(4)  Incremental Update
   246  	//
   247  	//        ...
   248  	//
   249  	//        To update the checksum, simply add the differences of the
   250  	//        sixteen bit integers that have been changed.  To see why this
   251  	//        works, observe that every 16-bit integer has an additive inverse
   252  	//        and that addition is associative.  From this it follows that
   253  	//        given the original value m, the new value m', and the old
   254  	//        checksum C, the new checksum C' is:
   255  	//
   256  	//                C' = C + (-m) + m' = C + (m' - m)
   257  	return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
   258  }
   259  
   260  // checksumUpdate2ByteAlignedAddress updates an address in a calculated
   261  // checksum.
   262  //
   263  // The addresses must have the same length and must contain an even number
   264  // of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
   265  func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
   266  	const uint16Bytes = 2
   267  
   268  	if len(old) != len(new) {
   269  		panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
   270  	}
   271  
   272  	if len(old)%uint16Bytes != 0 {
   273  		panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
   274  	}
   275  
   276  	// As per RFC 1071 page 4,
   277  	//	(4)  Incremental Update
   278  	//
   279  	//        ...
   280  	//
   281  	//        To update the checksum, simply add the differences of the
   282  	//        sixteen bit integers that have been changed.  To see why this
   283  	//        works, observe that every 16-bit integer has an additive inverse
   284  	//        and that addition is associative.  From this it follows that
   285  	//        given the original value m, the new value m', and the old
   286  	//        checksum C, the new checksum C' is:
   287  	//
   288  	//                C' = C + (-m) + m' = C + (m' - m)
   289  	for len(old) != 0 {
   290  		// Convert the 2 byte sequences to uint16 values then apply the increment
   291  		// update.
   292  		xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
   293  		old = old[uint16Bytes:]
   294  		new = new[uint16Bytes:]
   295  	}
   296  
   297  	return xsum
   298  }