github.com/gopacket/gopacket@v1.1.0/layers/tcpip.go (about)

     1  // Copyright 2012 Google, Inc. All rights reserved.
     2  // Copyright 2009-2011 Andreas Krennmair. All rights reserved.
     3  //
     4  // Use of this source code is governed by a BSD-style license
     5  // that can be found in the LICENSE file in the root of the source
     6  // tree.
     7  
     8  package layers
     9  
    10  import (
    11  	"errors"
    12  	"fmt"
    13  
    14  	"github.com/gopacket/gopacket"
    15  )
    16  
    17  // Checksum computation for TCP/UDP.
    18  type tcpipchecksum struct {
    19  	pseudoheader tcpipPseudoHeader
    20  }
    21  
    22  type tcpipPseudoHeader interface {
    23  	pseudoheaderChecksum() (uint32, error)
    24  }
    25  
    26  func (ip *IPv4) pseudoheaderChecksum() (csum uint32, err error) {
    27  	if err := ip.AddressTo4(); err != nil {
    28  		return 0, err
    29  	}
    30  	csum += (uint32(ip.SrcIP[0]) + uint32(ip.SrcIP[2])) << 8
    31  	csum += uint32(ip.SrcIP[1]) + uint32(ip.SrcIP[3])
    32  	csum += (uint32(ip.DstIP[0]) + uint32(ip.DstIP[2])) << 8
    33  	csum += uint32(ip.DstIP[1]) + uint32(ip.DstIP[3])
    34  	return csum, nil
    35  }
    36  
    37  func (ip *IPv6) pseudoheaderChecksum() (csum uint32, err error) {
    38  	if err := ip.AddressTo16(); err != nil {
    39  		return 0, err
    40  	}
    41  	for i := 0; i < 16; i += 2 {
    42  		csum += uint32(ip.SrcIP[i]) << 8
    43  		csum += uint32(ip.SrcIP[i+1])
    44  		csum += uint32(ip.DstIP[i]) << 8
    45  		csum += uint32(ip.DstIP[i+1])
    46  	}
    47  	return csum, nil
    48  }
    49  
    50  // Calculate the TCP/IP checksum defined in rfc1071.  The passed-in csum is any
    51  // initial checksum data that's already been computed.
    52  func tcpipChecksum(data []byte, csum uint32) uint16 {
    53  	// to handle odd lengths, we loop to length - 1, incrementing by 2, then
    54  	// handle the last byte specifically by checking against the original
    55  	// length.
    56  	length := len(data) - 1
    57  	for i := 0; i < length; i += 2 {
    58  		// For our test packet, doing this manually is about 25% faster
    59  		// (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16.
    60  		csum += uint32(data[i]) << 8
    61  		csum += uint32(data[i+1])
    62  	}
    63  	if len(data)%2 == 1 {
    64  		csum += uint32(data[length]) << 8
    65  	}
    66  	for csum > 0xffff {
    67  		csum = (csum >> 16) + (csum & 0xffff)
    68  	}
    69  	return ^uint16(csum)
    70  }
    71  
    72  // computeChecksum computes a TCP or UDP checksum.  headerAndPayload is the
    73  // serialized TCP or UDP header plus its payload, with the checksum zero'd
    74  // out. headerProtocol is the IP protocol number of the upper-layer header.
    75  func (c *tcpipchecksum) computeChecksum(headerAndPayload []byte, headerProtocol IPProtocol) (uint16, error) {
    76  	if c.pseudoheader == nil {
    77  		return 0, errors.New("TCP/IP layer 4 checksum cannot be computed without network layer... call SetNetworkLayerForChecksum to set which layer to use")
    78  	}
    79  	length := uint32(len(headerAndPayload))
    80  	csum, err := c.pseudoheader.pseudoheaderChecksum()
    81  	if err != nil {
    82  		return 0, err
    83  	}
    84  	csum += uint32(headerProtocol)
    85  	csum += length & 0xffff
    86  	csum += length >> 16
    87  	return tcpipChecksum(headerAndPayload, csum), nil
    88  }
    89  
    90  // SetNetworkLayerForChecksum tells this layer which network layer is wrapping it.
    91  // This is needed for computing the checksum when serializing, since TCP/IP transport
    92  // layer checksums depends on fields in the IPv4 or IPv6 layer that contains it.
    93  // The passed in layer must be an *IPv4 or *IPv6.
    94  func (i *tcpipchecksum) SetNetworkLayerForChecksum(l gopacket.NetworkLayer) error {
    95  	switch v := l.(type) {
    96  	case *IPv4:
    97  		i.pseudoheader = v
    98  	case *IPv6:
    99  		i.pseudoheader = v
   100  	default:
   101  		return fmt.Errorf("cannot use layer type %v for tcp checksum network layer", l.LayerType())
   102  	}
   103  	return nil
   104  }