github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/tcpip/header/udp.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package header
    16  
    17  import (
    18  	"encoding/binary"
    19  	"math"
    20  
    21  	"github.com/MerlinKodo/gvisor/pkg/tcpip"
    22  	"github.com/MerlinKodo/gvisor/pkg/tcpip/checksum"
    23  )
    24  
// Byte offsets of the individual fields within a UDP header. The accessors
// on UDP slice the header at these offsets and read or write a 16-bit value.
const (
	udpSrcPort  = 0 // offset of the "source port" field
	udpDstPort  = 2 // offset of the "destination port" field
	udpLength   = 4 // offset of the "length" field
	udpChecksum = 6 // offset of the "checksum" field
)
    31  
const (
	// UDPMaximumPacketSize is the largest possible UDP packet (0xffff, i.e.
	// 65535).
	UDPMaximumPacketSize = 0xffff
)
    36  
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded; UDP.Encode writes each of
// these values into the corresponding header field.
type UDPFields struct {
	// SrcPort is the "source port" field of a UDP packet.
	SrcPort uint16

	// DstPort is the "destination port" field of a UDP packet.
	DstPort uint16

	// Length is the "length" field of a UDP packet.
	Length uint16

	// Checksum is the "checksum" field of a UDP packet.
	Checksum uint16
}
    52  
// UDP represents a UDP header stored in a byte slice. The header fields are
// read and written at fixed offsets from the start of the slice, and any
// bytes after the first UDPMinimumSize bytes are treated as the payload.
type UDP []byte
    55  
const (
	// UDPMinimumSize is the minimum size of a valid UDP packet, which is the
	// size of the UDP header itself (8 bytes).
	UDPMinimumSize = 8

	// UDPMaximumSize is the maximum size of a valid UDP packet. The length field
	// in the UDP header is 16 bits as per RFC 768, so no larger value can be
	// represented.
	UDPMaximumSize = math.MaxUint16

	// UDPProtocolNumber is UDP's transport protocol number.
	UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
    67  
    68  // SourcePort returns the "source port" field of the UDP header.
    69  func (b UDP) SourcePort() uint16 {
    70  	return binary.BigEndian.Uint16(b[udpSrcPort:])
    71  }
    72  
    73  // DestinationPort returns the "destination port" field of the UDP header.
    74  func (b UDP) DestinationPort() uint16 {
    75  	return binary.BigEndian.Uint16(b[udpDstPort:])
    76  }
    77  
    78  // Length returns the "length" field of the UDP header.
    79  func (b UDP) Length() uint16 {
    80  	return binary.BigEndian.Uint16(b[udpLength:])
    81  }
    82  
    83  // Payload returns the data contained in the UDP datagram.
    84  func (b UDP) Payload() []byte {
    85  	return b[UDPMinimumSize:]
    86  }
    87  
    88  // Checksum returns the "checksum" field of the UDP header.
    89  func (b UDP) Checksum() uint16 {
    90  	return binary.BigEndian.Uint16(b[udpChecksum:])
    91  }
    92  
    93  // SetSourcePort sets the "source port" field of the UDP header.
    94  func (b UDP) SetSourcePort(port uint16) {
    95  	binary.BigEndian.PutUint16(b[udpSrcPort:], port)
    96  }
    97  
    98  // SetDestinationPort sets the "destination port" field of the UDP header.
    99  func (b UDP) SetDestinationPort(port uint16) {
   100  	binary.BigEndian.PutUint16(b[udpDstPort:], port)
   101  }
   102  
   103  // SetChecksum sets the "checksum" field of the UDP header.
   104  func (b UDP) SetChecksum(xsum uint16) {
   105  	checksum.Put(b[udpChecksum:], xsum)
   106  }
   107  
   108  // SetLength sets the "length" field of the UDP header.
   109  func (b UDP) SetLength(length uint16) {
   110  	binary.BigEndian.PutUint16(b[udpLength:], length)
   111  }
   112  
   113  // CalculateChecksum calculates the checksum of the UDP packet, given the
   114  // checksum of the network-layer pseudo-header and the checksum of the payload.
   115  func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
   116  	// Calculate the rest of the checksum.
   117  	return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
   118  }
   119  
   120  // IsChecksumValid returns true iff the UDP header's checksum is valid.
   121  func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
   122  	xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length())
   123  	xsum = checksum.Combine(xsum, payloadChecksum)
   124  	return b.CalculateChecksum(xsum) == 0xffff
   125  }
   126  
   127  // Encode encodes all the fields of the UDP header.
   128  func (b UDP) Encode(u *UDPFields) {
   129  	b.SetSourcePort(u.SrcPort)
   130  	b.SetDestinationPort(u.DstPort)
   131  	b.SetLength(u.Length)
   132  	b.SetChecksum(u.Checksum)
   133  }
   134  
   135  // SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
   136  func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
   137  	old := b.SourcePort()
   138  	b.SetSourcePort(new)
   139  	b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
   140  }
   141  
   142  // SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
   143  func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
   144  	old := b.DestinationPort()
   145  	b.SetDestinationPort(new)
   146  	b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
   147  }
   148  
   149  // UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
   150  func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
   151  	xsum := b.Checksum()
   152  	if fullChecksum {
   153  		xsum = ^xsum
   154  	}
   155  
   156  	xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
   157  	if fullChecksum {
   158  		xsum = ^xsum
   159  	}
   160  
   161  	b.SetChecksum(xsum)
   162  }
   163  
   164  // UDPValid returns true if the pkt has a valid UDP header. It checks whether:
   165  //   - The length field is too small.
   166  //   - The length field is too large.
   167  //   - The checksum is invalid.
   168  //
   169  // UDPValid corresponds to net/netfilter/nf_conntrack_proto_udp.c:udp_error.
   170  func UDPValid(hdr UDP, payloadChecksum func() uint16, payloadSize uint16, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (lengthValid, csumValid bool) {
   171  	if length := hdr.Length(); length > payloadSize+UDPMinimumSize || length < UDPMinimumSize {
   172  		return false, false
   173  	}
   174  
   175  	if skipChecksumValidation {
   176  		return true, true
   177  	}
   178  
   179  	// On IPv4, UDP checksum is optional, and a zero value means the transmitter
   180  	// omitted the checksum generation, as per RFC 768:
   181  	//
   182  	//   An all zero transmitted checksum value means that the transmitter
   183  	//   generated  no checksum  (for debugging or for higher level protocols that
   184  	//   don't care).
   185  	//
   186  	// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
   187  	//
   188  	//   Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
   189  	//   checksum is not optional.
   190  	if netProto == IPv4ProtocolNumber && hdr.Checksum() == 0 {
   191  		return true, true
   192  	}
   193  
   194  	return true, hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum())
   195  }