github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/header/udp.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package header
    16  
    17  import (
    18  	"encoding/binary"
    19  	"math"
    20  
    21  	"github.com/SagerNet/gvisor/pkg/tcpip"
    22  )
    23  
    24  const (
    25  	udpSrcPort  = 0
    26  	udpDstPort  = 2
    27  	udpLength   = 4
    28  	udpChecksum = 6
    29  )
    30  
    31  const (
    32  	// UDPMaximumPacketSize is the largest possible UDP packet.
    33  	UDPMaximumPacketSize = 0xffff
    34  )
    35  
    36  // UDPFields contains the fields of a UDP packet. It is used to describe the
    37  // fields of a packet that needs to be encoded.
    38  type UDPFields struct {
    39  	// SrcPort is the "source port" field of a UDP packet.
    40  	SrcPort uint16
    41  
    42  	// DstPort is the "destination port" field of a UDP packet.
    43  	DstPort uint16
    44  
    45  	// Length is the "length" field of a UDP packet.
    46  	Length uint16
    47  
    48  	// Checksum is the "checksum" field of a UDP packet.
    49  	Checksum uint16
    50  }
    51  
    52  // UDP represents a UDP header stored in a byte array.
    53  type UDP []byte
    54  
    55  const (
    56  	// UDPMinimumSize is the minimum size of a valid UDP packet.
    57  	UDPMinimumSize = 8
    58  
    59  	// UDPMaximumSize is the maximum size of a valid UDP packet. The length field
    60  	// in the UDP header is 16 bits as per RFC 768.
    61  	UDPMaximumSize = math.MaxUint16
    62  
    63  	// UDPProtocolNumber is UDP's transport protocol number.
    64  	UDPProtocolNumber tcpip.TransportProtocolNumber = 17
    65  )
    66  
    67  // SourcePort returns the "source port" field of the UDP header.
    68  func (b UDP) SourcePort() uint16 {
    69  	return binary.BigEndian.Uint16(b[udpSrcPort:])
    70  }
    71  
    72  // DestinationPort returns the "destination port" field of the UDP header.
    73  func (b UDP) DestinationPort() uint16 {
    74  	return binary.BigEndian.Uint16(b[udpDstPort:])
    75  }
    76  
    77  // Length returns the "length" field of the UDP header.
    78  func (b UDP) Length() uint16 {
    79  	return binary.BigEndian.Uint16(b[udpLength:])
    80  }
    81  
    82  // Payload returns the data contained in the UDP datagram.
    83  func (b UDP) Payload() []byte {
    84  	return b[UDPMinimumSize:]
    85  }
    86  
    87  // Checksum returns the "checksum" field of the UDP header.
    88  func (b UDP) Checksum() uint16 {
    89  	return binary.BigEndian.Uint16(b[udpChecksum:])
    90  }
    91  
    92  // SetSourcePort sets the "source port" field of the UDP header.
    93  func (b UDP) SetSourcePort(port uint16) {
    94  	binary.BigEndian.PutUint16(b[udpSrcPort:], port)
    95  }
    96  
    97  // SetDestinationPort sets the "destination port" field of the UDP header.
    98  func (b UDP) SetDestinationPort(port uint16) {
    99  	binary.BigEndian.PutUint16(b[udpDstPort:], port)
   100  }
   101  
   102  // SetChecksum sets the "checksum" field of the UDP header.
   103  func (b UDP) SetChecksum(checksum uint16) {
   104  	binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
   105  }
   106  
   107  // SetLength sets the "length" field of the UDP header.
   108  func (b UDP) SetLength(length uint16) {
   109  	binary.BigEndian.PutUint16(b[udpLength:], length)
   110  }
   111  
   112  // CalculateChecksum calculates the checksum of the UDP packet, given the
   113  // checksum of the network-layer pseudo-header and the checksum of the payload.
   114  func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
   115  	// Calculate the rest of the checksum.
   116  	return Checksum(b[:UDPMinimumSize], partialChecksum)
   117  }
   118  
   119  // IsChecksumValid returns true iff the UDP header's checksum is valid.
   120  func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
   121  	xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length())
   122  	xsum = ChecksumCombine(xsum, payloadChecksum)
   123  	return b.CalculateChecksum(xsum) == 0xffff
   124  }
   125  
   126  // Encode encodes all the fields of the UDP header.
   127  func (b UDP) Encode(u *UDPFields) {
   128  	binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
   129  	binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
   130  	binary.BigEndian.PutUint16(b[udpLength:], u.Length)
   131  	binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
   132  }
   133  
   134  // SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
   135  func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
   136  	old := b.SourcePort()
   137  	b.SetSourcePort(new)
   138  	b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
   139  }
   140  
   141  // SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
   142  func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
   143  	old := b.DestinationPort()
   144  	b.SetDestinationPort(new)
   145  	b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
   146  }
   147  
   148  // UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
   149  func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
   150  	xsum := b.Checksum()
   151  	if fullChecksum {
   152  		xsum = ^xsum
   153  	}
   154  
   155  	xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
   156  	if fullChecksum {
   157  		xsum = ^xsum
   158  	}
   159  
   160  	b.SetChecksum(xsum)
   161  }