github.com/lightlus/netstack@v1.2.0/tcpip/header/checksum.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package header provides the implementation of the encoding and decoding of
    16  // network protocol headers.
    17  package header
    18  
    19  import (
    20  	"encoding/binary"
    21  
    22  	"github.com/lightlus/netstack/tcpip"
    23  	"github.com/lightlus/netstack/tcpip/buffer"
    24  )
    25  
    26  func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
    27  	v := initial
    28  
    29  	if odd {
    30  		v += uint32(buf[0])
    31  		buf = buf[1:]
    32  	}
    33  
    34  	l := len(buf)
    35  	odd = l&1 != 0
    36  	if odd {
    37  		l--
    38  		v += uint32(buf[l]) << 8
    39  	}
    40  
    41  	for i := 0; i < l; i += 2 {
    42  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    43  	}
    44  
    45  	return ChecksumCombine(uint16(v), uint16(v>>16)), odd
    46  }
    47  
    48  // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
    49  // given byte array.
    50  //
    51  // The initial checksum must have been computed on an even number of bytes.
    52  func Checksum(buf []byte, initial uint16) uint16 {
    53  	s, _ := calculateChecksum(buf, false, uint32(initial))
    54  	return s
    55  }
    56  
    57  // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
    58  // the given VectorizedView.
    59  //
    60  // The initial checksum must have been computed on an even number of bytes.
    61  func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
    62  	return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
    63  }
    64  
    65  // ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
    66  // bytes in the given VectorizedView.
    67  //
    68  // The initial checksum must have been computed on an even number of bytes.
    69  func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
    70  	odd := false
    71  	sum := initial
    72  	for _, v := range vv.Views() {
    73  		if len(v) == 0 {
    74  			continue
    75  		}
    76  
    77  		if off >= len(v) {
    78  			off -= len(v)
    79  			continue
    80  		}
    81  		v = v[off:]
    82  
    83  		l := len(v)
    84  		if l > size {
    85  			l = size
    86  		}
    87  		v = v[:l]
    88  
    89  		sum, odd = calculateChecksum(v, odd, uint32(sum))
    90  
    91  		size -= len(v)
    92  		if size == 0 {
    93  			break
    94  		}
    95  		off = 0
    96  	}
    97  	return sum
    98  }
    99  
   100  // ChecksumCombine combines the two uint16 to form their checksum. This is done
   101  // by adding them and the carry.
   102  //
   103  // Note that checksum a must have been computed on an even number of bytes.
   104  func ChecksumCombine(a, b uint16) uint16 {
   105  	v := uint32(a) + uint32(b)
   106  	return uint16(v + v>>16)
   107  }
   108  
   109  // PseudoHeaderChecksum calculates the pseudo-header checksum for the given
   110  // destination protocol and network address. Pseudo-headers are needed by
   111  // transport layers when calculating their own checksum.
   112  func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
   113  	xsum := Checksum([]byte(srcAddr), 0)
   114  	xsum = Checksum([]byte(dstAddr), xsum)
   115  
   116  	// Add the length portion of the checksum to the pseudo-checksum.
   117  	tmp := make([]byte, 2)
   118  	binary.BigEndian.PutUint16(tmp, totalLen)
   119  	xsum = Checksum(tmp, xsum)
   120  
   121  	return Checksum([]byte{0, uint8(protocol)}, xsum)
   122  }