github.com/gopacket/gopacket@v1.1.0/layers/tcpip.go (about) 1 // Copyright 2012 Google, Inc. All rights reserved. 2 // Copyright 2009-2011 Andreas Krennmair. All rights reserved. 3 // 4 // Use of this source code is governed by a BSD-style license 5 // that can be found in the LICENSE file in the root of the source 6 // tree. 7 8 package layers 9 10 import ( 11 "errors" 12 "fmt" 13 14 "github.com/gopacket/gopacket" 15 ) 16 17 // Checksum computation for TCP/UDP. 18 type tcpipchecksum struct { 19 pseudoheader tcpipPseudoHeader 20 } 21 22 type tcpipPseudoHeader interface { 23 pseudoheaderChecksum() (uint32, error) 24 } 25 26 func (ip *IPv4) pseudoheaderChecksum() (csum uint32, err error) { 27 if err := ip.AddressTo4(); err != nil { 28 return 0, err 29 } 30 csum += (uint32(ip.SrcIP[0]) + uint32(ip.SrcIP[2])) << 8 31 csum += uint32(ip.SrcIP[1]) + uint32(ip.SrcIP[3]) 32 csum += (uint32(ip.DstIP[0]) + uint32(ip.DstIP[2])) << 8 33 csum += uint32(ip.DstIP[1]) + uint32(ip.DstIP[3]) 34 return csum, nil 35 } 36 37 func (ip *IPv6) pseudoheaderChecksum() (csum uint32, err error) { 38 if err := ip.AddressTo16(); err != nil { 39 return 0, err 40 } 41 for i := 0; i < 16; i += 2 { 42 csum += uint32(ip.SrcIP[i]) << 8 43 csum += uint32(ip.SrcIP[i+1]) 44 csum += uint32(ip.DstIP[i]) << 8 45 csum += uint32(ip.DstIP[i+1]) 46 } 47 return csum, nil 48 } 49 50 // Calculate the TCP/IP checksum defined in rfc1071. The passed-in csum is any 51 // initial checksum data that's already been computed. 52 func tcpipChecksum(data []byte, csum uint32) uint16 { 53 // to handle odd lengths, we loop to length - 1, incrementing by 2, then 54 // handle the last byte specifically by checking against the original 55 // length. 56 length := len(data) - 1 57 for i := 0; i < length; i += 2 { 58 // For our test packet, doing this manually is about 25% faster 59 // (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16. 60 csum += uint32(data[i]) << 8 61 csum += uint32(data[i+1]) 62 } 63 if len(data)%2 == 1 { 64 csum += uint32(data[length]) << 8 65 } 66 for csum > 0xffff { 67 csum = (csum >> 16) + (csum & 0xffff) 68 } 69 return ^uint16(csum) 70 } 71 72 // computeChecksum computes a TCP or UDP checksum. headerAndPayload is the 73 // serialized TCP or UDP header plus its payload, with the checksum zero'd 74 // out. headerProtocol is the IP protocol number of the upper-layer header. 75 func (c *tcpipchecksum) computeChecksum(headerAndPayload []byte, headerProtocol IPProtocol) (uint16, error) { 76 if c.pseudoheader == nil { 77 return 0, errors.New("TCP/IP layer 4 checksum cannot be computed without network layer... call SetNetworkLayerForChecksum to set which layer to use") 78 } 79 length := uint32(len(headerAndPayload)) 80 csum, err := c.pseudoheader.pseudoheaderChecksum() 81 if err != nil { 82 return 0, err 83 } 84 csum += uint32(headerProtocol) 85 csum += length & 0xffff 86 csum += length >> 16 87 return tcpipChecksum(headerAndPayload, csum), nil 88 } 89 90 // SetNetworkLayerForChecksum tells this layer which network layer is wrapping it. 91 // This is needed for computing the checksum when serializing, since TCP/IP transport 92 // layer checksums depends on fields in the IPv4 or IPv6 layer that contains it. 93 // The passed in layer must be an *IPv4 or *IPv6. 94 func (i *tcpipchecksum) SetNetworkLayerForChecksum(l gopacket.NetworkLayer) error { 95 switch v := l.(type) { 96 case *IPv4: 97 i.pseudoheader = v 98 case *IPv6: 99 i.pseudoheader = v 100 default: 101 return fmt.Errorf("cannot use layer type %v for tcp checksum network layer", l.LayerType()) 102 } 103 return nil 104 }