gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/udp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"math"
    22  	"time"
    23  
    24  	"gvisor.dev/gvisor/pkg/sync"
    25  	"gvisor.dev/gvisor/pkg/tcpip"
    26  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    27  	"gvisor.dev/gvisor/pkg/tcpip/header"
    28  	"gvisor.dev/gvisor/pkg/tcpip/ports"
    29  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    30  	"gvisor.dev/gvisor/pkg/tcpip/transport"
    31  	"gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
    32  	"gvisor.dev/gvisor/pkg/waiter"
    33  )
    34  
// udpPacket is one entry of an endpoint's receive list: a received datagram
// plus the metadata that Read reports alongside it.
//
// +stateify savable
type udpPacket struct {
	udpPacketEntry
	// netProto selects which set of control messages (IPv4 or IPv6) Read may
	// attach to this packet.
	netProto tcpip.NetworkProtocolNumber
	// senderAddress is returned as the remote address when the reader asks
	// for it.
	senderAddress tcpip.FullAddress
	// destinationAddress is reported as the original destination address when
	// that socket option is enabled.
	destinationAddress tcpip.FullAddress
	// packetInfo backs the IPv4/IPv6 packet-info control messages.
	packetInfo tcpip.IPPacketInfo
	// pkt holds the payload. The reference held here is dropped with DecRef
	// when the packet leaves the receive list.
	pkt *stack.PacketBuffer
	// receivedAt is reported to readers as the timestamp control message.
	receivedAt time.Time `state:".(int64)"`
	// tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class
	// for IPv6.
	tosOrTClass uint8
	// ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6
	ttlOrHopLimit uint8
}
    50  
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue
	uniqueID    uint64
	net         network.Endpoint
	stats       tcpip.TransportEndpointStats
	ops         tcpip.SocketOptions

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu.
	rcvMu sync.Mutex `state:"nosave"`
	// rcvReady is set once the endpoint has been bound or connected.
	rcvReady bool
	rcvList  udpPacketList
	// rcvBufSize is reduced by a packet's data size when the packet is
	// dequeued by Read, and reset to zero by Close.
	rcvBufSize int
	// rcvClosed is set by Close and by Shutdown for reading; reads on an empty
	// queue then fail with ErrClosedForReceive instead of ErrWouldBlock.
	rcvClosed bool

	// lastError is the pending error handed out (and cleared) by LastError.
	// It is protected by lastErrorMu.
	lastErrorMu sync.Mutex `state:"nosave"`
	lastError   tcpip.Error

	// The following fields are protected by the mu mutex.
	mu sync.RWMutex `state:"nosave"`
	// portFlags holds the flags to use for the next port reservation or
	// registration; it is updated by OnReuseAddressSet and OnReusePortSet.
	portFlags ports.Flags

	// Values used to reserve a port or register a transport endpoint
	// (whichever happens first).
	boundBindToDevice tcpip.NICID
	boundPortFlags    ports.Flags

	// readShutdown is set by Close and by Shutdown for reading.
	readShutdown bool

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber

	// frozen indicates if the packets should be delivered to the endpoint
	// during restore.
	frozen bool

	// localPort and remotePort are the UDP ports of this endpoint; they are
	// patched into the network-level ID/addresses wherever a full ID or
	// address is needed.
	localPort  uint16
	remotePort uint16
}
   108  
   109  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
   110  	e := &endpoint{
   111  		stack:       s,
   112  		waiterQueue: waiterQueue,
   113  		uniqueID:    s.UniqueID(),
   114  	}
   115  	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
   116  	e.ops.SetMulticastLoop(true)
   117  	e.ops.SetSendBufferSize(32*1024, false /* notify */)
   118  	e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
   119  	e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops, waiterQueue)
   120  
   121  	// Override with stack defaults.
   122  	var ss tcpip.SendBufferSizeOption
   123  	if err := s.Option(&ss); err == nil {
   124  		e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   125  	}
   126  
   127  	var rs tcpip.ReceiveBufferSizeOption
   128  	if err := s.Option(&rs); err == nil {
   129  		e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   130  	}
   131  
   132  	return e
   133  }
   134  
// WakeupWriters implements tcpip.SocketOptionsHandler.
//
// It delegates to the network endpoint, which decides whether writers need to
// be signalled.
func (e *endpoint) WakeupWriters() {
	e.net.MaybeSignalWritable()
}
   139  
// UniqueID implements stack.TransportEndpoint.
//
// The ID is assigned once in newEndpoint and never changes.
func (e *endpoint) UniqueID() uint64 {
	return e.uniqueID
}
   144  
   145  func (e *endpoint) LastError() tcpip.Error {
   146  	e.lastErrorMu.Lock()
   147  	defer e.lastErrorMu.Unlock()
   148  
   149  	err := e.lastError
   150  	e.lastError = nil
   151  	return err
   152  }
   153  
   154  // UpdateLastError implements tcpip.SocketOptionsHandler.
   155  func (e *endpoint) UpdateLastError(err tcpip.Error) {
   156  	e.lastErrorMu.Lock()
   157  	e.lastError = err
   158  	e.lastErrorMu.Unlock()
   159  }
   160  
// Abort implements stack.TransportEndpoint.
//
// For UDP, aborting is the same as closing the endpoint.
func (e *endpoint) Abort() {
	e.Close()
}
   165  
   166  // Close puts the endpoint in a closed state and frees all resources
   167  // associated with it.
   168  func (e *endpoint) Close() {
   169  	e.mu.Lock()
   170  
   171  	switch state := e.net.State(); state {
   172  	case transport.DatagramEndpointStateInitial:
   173  	case transport.DatagramEndpointStateClosed:
   174  		e.mu.Unlock()
   175  		return
   176  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   177  		id := e.net.Info().ID
   178  		id.LocalPort = e.localPort
   179  		id.RemotePort = e.remotePort
   180  		e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
   181  		portRes := ports.Reservation{
   182  			Networks:     e.effectiveNetProtos,
   183  			Transport:    ProtocolNumber,
   184  			Addr:         id.LocalAddress,
   185  			Port:         id.LocalPort,
   186  			Flags:        e.boundPortFlags,
   187  			BindToDevice: e.boundBindToDevice,
   188  			Dest:         tcpip.FullAddress{},
   189  		}
   190  		e.stack.ReleasePort(portRes)
   191  		e.boundBindToDevice = 0
   192  		e.boundPortFlags = ports.Flags{}
   193  	default:
   194  		panic(fmt.Sprintf("unhandled state = %s", state))
   195  	}
   196  
   197  	// Close the receive list and drain it.
   198  	e.rcvMu.Lock()
   199  	e.rcvClosed = true
   200  	e.rcvBufSize = 0
   201  	for !e.rcvList.Empty() {
   202  		p := e.rcvList.Front()
   203  		e.rcvList.Remove(p)
   204  		p.pkt.DecRef()
   205  	}
   206  	e.rcvMu.Unlock()
   207  
   208  	e.net.Shutdown()
   209  	e.net.Close()
   210  	e.readShutdown = true
   211  	e.mu.Unlock()
   212  
   213  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
   214  }
   215  
// ModerateRecvBuf implements tcpip.Endpoint.
//
// It is a no-op for this endpoint.
func (*endpoint) ModerateRecvBuf(int) {}
   218  
   219  // Read implements tcpip.Endpoint.
   220  func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   221  	if err := e.LastError(); err != nil {
   222  		return tcpip.ReadResult{}, err
   223  	}
   224  
   225  	e.rcvMu.Lock()
   226  
   227  	if e.rcvList.Empty() {
   228  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   229  		if e.rcvClosed {
   230  			e.stats.ReadErrors.ReadClosed.Increment()
   231  			err = &tcpip.ErrClosedForReceive{}
   232  		}
   233  		e.rcvMu.Unlock()
   234  		return tcpip.ReadResult{}, err
   235  	}
   236  
   237  	p := e.rcvList.Front()
   238  	if !opts.Peek {
   239  		e.rcvList.Remove(p)
   240  		defer p.pkt.DecRef()
   241  		e.rcvBufSize -= p.pkt.Data().Size()
   242  	}
   243  	e.rcvMu.Unlock()
   244  
   245  	// Control Messages
   246  	// TODO(https://gvisor.dev/issue/7012): Share control message code with other
   247  	// network endpoints.
   248  	cm := tcpip.ReceivableControlMessages{
   249  		HasTimestamp: true,
   250  		Timestamp:    p.receivedAt,
   251  	}
   252  	switch p.netProto {
   253  	case header.IPv4ProtocolNumber:
   254  		if e.ops.GetReceiveTOS() {
   255  			cm.HasTOS = true
   256  			cm.TOS = p.tosOrTClass
   257  		}
   258  		if e.ops.GetReceiveTTL() {
   259  			cm.HasTTL = true
   260  			cm.TTL = p.ttlOrHopLimit
   261  		}
   262  		if e.ops.GetReceivePacketInfo() {
   263  			cm.HasIPPacketInfo = true
   264  			cm.PacketInfo = p.packetInfo
   265  		}
   266  	case header.IPv6ProtocolNumber:
   267  		if e.ops.GetReceiveTClass() {
   268  			cm.HasTClass = true
   269  			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
   270  			cm.TClass = uint32(p.tosOrTClass)
   271  		}
   272  		if e.ops.GetReceiveHopLimit() {
   273  			cm.HasHopLimit = true
   274  			cm.HopLimit = p.ttlOrHopLimit
   275  		}
   276  		if e.ops.GetIPv6ReceivePacketInfo() {
   277  			cm.HasIPv6PacketInfo = true
   278  			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
   279  				NIC:  p.packetInfo.NIC,
   280  				Addr: p.packetInfo.DestinationAddr,
   281  			}
   282  		}
   283  	default:
   284  		panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
   285  	}
   286  
   287  	if e.ops.GetReceiveOriginalDstAddress() {
   288  		cm.HasOriginalDstAddress = true
   289  		cm.OriginalDstAddress = p.destinationAddress
   290  	}
   291  
   292  	// Read Result
   293  	res := tcpip.ReadResult{
   294  		Total:           p.pkt.Data().Size(),
   295  		ControlMessages: cm,
   296  	}
   297  	if opts.NeedRemoteAddr {
   298  		res.RemoteAddr = p.senderAddress
   299  	}
   300  
   301  	n, err := p.pkt.Data().ReadTo(dst, opts.Peek)
   302  	if n == 0 && err != nil {
   303  		return res, &tcpip.ErrBadBuffer{}
   304  	}
   305  	res.Count = n
   306  	return res, nil
   307  }
   308  
   309  // prepareForWriteInner prepares the endpoint for sending data. In particular,
   310  // it binds it if it's still in the initial state. To do so, it must first
   311  // reacquire the mutex in exclusive mode.
   312  //
   313  // Returns true for retry if preparation should be retried.
   314  // +checklocksread:e.mu
   315  func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
   316  	switch e.net.State() {
   317  	case transport.DatagramEndpointStateInitial:
   318  	case transport.DatagramEndpointStateConnected:
   319  		return false, nil
   320  
   321  	case transport.DatagramEndpointStateBound:
   322  		if to == nil {
   323  			return false, &tcpip.ErrDestinationRequired{}
   324  		}
   325  		return false, nil
   326  	default:
   327  		return false, &tcpip.ErrInvalidEndpointState{}
   328  	}
   329  
   330  	e.mu.RUnlock()
   331  	e.mu.Lock()
   332  	defer e.mu.DowngradeLock()
   333  
   334  	// The state changed when we released the shared locked and re-acquired
   335  	// it in exclusive mode. Try again.
   336  	if e.net.State() != transport.DatagramEndpointStateInitial {
   337  		return true, nil
   338  	}
   339  
   340  	// The state is still 'initial', so try to bind the endpoint.
   341  	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
   342  		return false, err
   343  	}
   344  
   345  	return true, nil
   346  }
   347  
   348  var _ tcpip.EndpointWithPreflight = (*endpoint)(nil)
   349  
   350  // Validates the passed WriteOptions and prepares the endpoint for writes
   351  // using those options. If the endpoint is unbound and the `To` address
   352  // is specified, binds the endpoint to that address.
   353  func (e *endpoint) Preflight(opts tcpip.WriteOptions) tcpip.Error {
   354  	var r bytes.Reader
   355  	udpInfo, err := e.prepareForWrite(&r, opts)
   356  	if err == nil {
   357  		udpInfo.ctx.Release()
   358  	}
   359  	return err
   360  }
   361  
   362  // Write writes data to the endpoint's peer. This method does not block
   363  // if the data cannot be written.
   364  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   365  	n, err := e.write(p, opts)
   366  	switch err.(type) {
   367  	case nil:
   368  		e.stats.PacketsSent.Increment()
   369  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   370  		e.stats.WriteErrors.InvalidArgs.Increment()
   371  	case *tcpip.ErrClosedForSend:
   372  		e.stats.WriteErrors.WriteClosed.Increment()
   373  	case *tcpip.ErrInvalidEndpointState:
   374  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   375  	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   376  		// Errors indicating any problem with IP routing of the packet.
   377  		e.stats.SendErrors.NoRoute.Increment()
   378  	default:
   379  		// For all other errors when writing to the network layer.
   380  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   381  	}
   382  	return n, err
   383  }
   384  
   385  func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
   386  	e.mu.RLock()
   387  	defer e.mu.RUnlock()
   388  
   389  	// Prepare for write.
   390  	for {
   391  		retry, err := e.prepareForWriteInner(opts.To)
   392  		if err != nil {
   393  			return udpPacketInfo{}, err
   394  		}
   395  
   396  		if !retry {
   397  			break
   398  		}
   399  	}
   400  
   401  	dst, connected := e.net.GetRemoteAddress()
   402  	dst.Port = e.remotePort
   403  	if opts.To != nil {
   404  		if opts.To.Port == 0 {
   405  			// Port 0 is an invalid port to send to.
   406  			return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
   407  		}
   408  
   409  		dst = *opts.To
   410  	} else if !connected {
   411  		return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
   412  	}
   413  
   414  	ctx, err := e.net.AcquireContextForWrite(opts)
   415  	if err != nil {
   416  		return udpPacketInfo{}, err
   417  	}
   418  
   419  	if p.Len() > header.UDPMaximumPacketSize {
   420  		// Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet
   421  		// errors aren't report to the error queue at all.
   422  		if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
   423  			so := e.SocketOptions()
   424  			if so.GetIPv6RecvError() {
   425  				so.QueueLocalErr(
   426  					&tcpip.ErrMessageTooLong{},
   427  					e.net.NetProto(),
   428  					uint32(p.Len()),
   429  					dst,
   430  					nil,
   431  				)
   432  			}
   433  		}
   434  		ctx.Release()
   435  		return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
   436  	}
   437  
   438  	return udpPacketInfo{
   439  		ctx:        ctx,
   440  		localPort:  e.localPort,
   441  		remotePort: dst.Port,
   442  	}, nil
   443  }
   444  
   445  func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   446  	// Do not hold lock when sending as loopback is synchronous and if the UDP
   447  	// datagram ends up generating an ICMP response then it can result in a
   448  	// deadlock where the ICMP response handling ends up acquiring this endpoint's
   449  	// mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
   450  	// deadlock if another caller is trying to acquire e.mu in exclusive mode w/
   451  	// e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
   452  	// lock can be eventually acquired.
   453  	//
   454  	// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
   455  	// locking is prohibited.
   456  
   457  	if err := e.LastError(); err != nil {
   458  		return 0, err
   459  	}
   460  
   461  	udpInfo, err := e.prepareForWrite(p, opts)
   462  	if err != nil {
   463  		return 0, err
   464  	}
   465  	defer udpInfo.ctx.Release()
   466  
   467  	dataSz := p.Len()
   468  	pktInfo := udpInfo.ctx.PacketInfo()
   469  	pkt := udpInfo.ctx.TryNewPacketBufferFromPayloader(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), p)
   470  	if pkt == nil {
   471  		return 0, &tcpip.ErrWouldBlock{}
   472  	}
   473  	defer pkt.DecRef()
   474  
   475  	// Initialize the UDP header.
   476  	udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
   477  	pkt.TransportProtocolNumber = ProtocolNumber
   478  
   479  	length := uint16(pkt.Size())
   480  	udp.Encode(&header.UDPFields{
   481  		SrcPort: udpInfo.localPort,
   482  		DstPort: udpInfo.remotePort,
   483  		Length:  length,
   484  	})
   485  
   486  	// Set the checksum field unless TX checksum offload is enabled.
   487  	// On IPv4, UDP checksum is optional, and a zero value indicates the
   488  	// transmitter skipped the checksum generation (RFC768).
   489  	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
   490  	if pktInfo.RequiresTXTransportChecksum &&
   491  		(!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
   492  		xsum := udp.CalculateChecksum(checksum.Combine(
   493  			header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
   494  			pkt.Data().Checksum(),
   495  		))
   496  		// As per RFC 768 page 2,
   497  		//
   498  		//   Checksum is the 16-bit one's complement of the one's complement sum of
   499  		//   a pseudo header of information from the IP header, the UDP header, and
   500  		//   the data, padded with zero octets at the end (if necessary) to make a
   501  		//   multiple of two octets.
   502  		//
   503  		//	 The pseudo header conceptually prefixed to the UDP header contains the
   504  		//   source address, the destination address, the protocol, and the UDP
   505  		//   length. This information gives protection against misrouted datagrams.
   506  		//   This checksum procedure is the same as is used in TCP.
   507  		//
   508  		//   If the computed checksum is zero, it is transmitted as all ones (the
   509  		//   equivalent in one's complement arithmetic). An all zero transmitted
   510  		//   checksum value means that the transmitter generated no checksum (for
   511  		//   debugging or for higher level protocols that don't care).
   512  		//
   513  		// To avoid the zero value, we only calculate the one's complement of the
   514  		// one's complement sum if the sum is not all ones.
   515  		if xsum != math.MaxUint16 {
   516  			xsum = ^xsum
   517  		}
   518  		udp.SetChecksum(xsum)
   519  	}
   520  	if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
   521  		e.stack.Stats().UDP.PacketSendErrors.Increment()
   522  		return 0, err
   523  	}
   524  
   525  	// Track count of packets sent.
   526  	e.stack.Stats().UDP.PacketsSent.Increment()
   527  	return int64(dataSz), nil
   528  }
   529  
   530  // OnReuseAddressSet implements tcpip.SocketOptionsHandler.
   531  func (e *endpoint) OnReuseAddressSet(v bool) {
   532  	e.mu.Lock()
   533  	e.portFlags.MostRecent = v
   534  	e.mu.Unlock()
   535  }
   536  
   537  // OnReusePortSet implements tcpip.SocketOptionsHandler.
   538  func (e *endpoint) OnReusePortSet(v bool) {
   539  	e.mu.Lock()
   540  	e.portFlags.LoadBalanced = v
   541  	e.mu.Unlock()
   542  }
   543  
// SetSockOptInt implements tcpip.Endpoint.
//
// All integer options are handled by the network endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
	return e.net.SetSockOptInt(opt, v)
}
   548  
// Compile-time check that endpoint satisfies tcpip.SocketOptionsHandler.
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)

// HasNIC implements tcpip.SocketOptionsHandler.
//
// It reports whether the stack has a NIC with the given ID.
func (e *endpoint) HasNIC(id int32) bool {
	return e.stack.HasNIC(tcpip.NICID(id))
}
   555  
// SetSockOpt implements tcpip.Endpoint.
//
// All options are handled by the network endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
	return e.net.SetSockOpt(opt)
}
   560  
   561  // GetSockOptInt implements tcpip.Endpoint.
   562  func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   563  	switch opt {
   564  	case tcpip.ReceiveQueueSizeOption:
   565  		v := 0
   566  		e.rcvMu.Lock()
   567  		if !e.rcvList.Empty() {
   568  			p := e.rcvList.Front()
   569  			v = p.pkt.Data().Size()
   570  		}
   571  		e.rcvMu.Unlock()
   572  		return v, nil
   573  
   574  	default:
   575  		return e.net.GetSockOptInt(opt)
   576  	}
   577  }
   578  
// GetSockOpt implements tcpip.Endpoint.
//
// All options are handled by the network endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
	return e.net.GetSockOpt(opt)
}
   583  
// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
	// ctx is the network-layer write context used to allocate and send the
	// packet. Whoever receives a udpPacketInfo must release it.
	ctx network.WriteContext
	// localPort is written as the UDP source port.
	localPort uint16
	// remotePort is written as the UDP destination port.
	remotePort uint16
}
   590  
   591  // Disconnect implements tcpip.Endpoint.
   592  func (e *endpoint) Disconnect() tcpip.Error {
   593  	e.mu.Lock()
   594  	defer e.mu.Unlock()
   595  
   596  	if e.net.State() != transport.DatagramEndpointStateConnected {
   597  		return nil
   598  	}
   599  	var (
   600  		id  stack.TransportEndpointID
   601  		btd tcpip.NICID
   602  	)
   603  
   604  	// We change this value below and we need the old value to unregister
   605  	// the endpoint.
   606  	boundPortFlags := e.boundPortFlags
   607  
   608  	// Exclude ephemerally bound endpoints.
   609  	info := e.net.Info()
   610  	info.ID.LocalPort = e.localPort
   611  	info.ID.RemotePort = e.remotePort
   612  	if e.net.WasBound() {
   613  		var err tcpip.Error
   614  		id = stack.TransportEndpointID{
   615  			LocalPort:    info.ID.LocalPort,
   616  			LocalAddress: info.ID.LocalAddress,
   617  		}
   618  		id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
   619  		if err != nil {
   620  			return err
   621  		}
   622  		boundPortFlags = e.boundPortFlags
   623  	} else {
   624  		if info.ID.LocalPort != 0 {
   625  			// Release the ephemeral port.
   626  			portRes := ports.Reservation{
   627  				Networks:     e.effectiveNetProtos,
   628  				Transport:    ProtocolNumber,
   629  				Addr:         info.ID.LocalAddress,
   630  				Port:         info.ID.LocalPort,
   631  				Flags:        boundPortFlags,
   632  				BindToDevice: e.boundBindToDevice,
   633  				Dest:         tcpip.FullAddress{},
   634  			}
   635  			e.stack.ReleasePort(portRes)
   636  			e.boundPortFlags = ports.Flags{}
   637  		}
   638  	}
   639  
   640  	e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
   641  	e.boundBindToDevice = btd
   642  	e.localPort = id.LocalPort
   643  	e.remotePort = id.RemotePort
   644  
   645  	e.net.Disconnect()
   646  
   647  	return nil
   648  }
   649  
   650  // Connect connects the endpoint to its peer. Specifying a NIC is optional.
   651  func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
   652  	e.mu.Lock()
   653  	defer e.mu.Unlock()
   654  
   655  	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
   656  		nextID.LocalPort = e.localPort
   657  		nextID.RemotePort = addr.Port
   658  
   659  		// Even if we're connected, this endpoint can still be used to send
   660  		// packets on a different network protocol, so we register both even if
   661  		// v6only is set to false and this is an ipv6 endpoint.
   662  		netProtos := []tcpip.NetworkProtocolNumber{netProto}
   663  		if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) {
   664  			netProtos = []tcpip.NetworkProtocolNumber{
   665  				header.IPv4ProtocolNumber,
   666  				header.IPv6ProtocolNumber,
   667  			}
   668  		}
   669  
   670  		oldPortFlags := e.boundPortFlags
   671  
   672  		// Remove the old registration.
   673  		if e.localPort != 0 {
   674  			previousID.LocalPort = e.localPort
   675  			previousID.RemotePort = e.remotePort
   676  			e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
   677  		}
   678  
   679  		nextID, btd, err := e.registerWithStack(netProtos, nextID)
   680  		if err != nil {
   681  			return err
   682  		}
   683  
   684  		e.localPort = nextID.LocalPort
   685  		e.remotePort = nextID.RemotePort
   686  		e.boundBindToDevice = btd
   687  		e.effectiveNetProtos = netProtos
   688  		return nil
   689  	})
   690  	if err != nil {
   691  		return err
   692  	}
   693  
   694  	e.rcvMu.Lock()
   695  	e.rcvReady = true
   696  	e.rcvMu.Unlock()
   697  	return nil
   698  }
   699  
// ConnectEndpoint is not supported; it always fails with
// ErrInvalidEndpointState.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
	return &tcpip.ErrInvalidEndpointState{}
}
   704  
   705  // Shutdown closes the read and/or write end of the endpoint connection
   706  // to its peer.
   707  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
   708  	e.mu.Lock()
   709  	defer e.mu.Unlock()
   710  
   711  	switch state := e.net.State(); state {
   712  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   713  		return &tcpip.ErrNotConnected{}
   714  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   715  	default:
   716  		panic(fmt.Sprintf("unhandled state = %s", state))
   717  	}
   718  
   719  	if flags&tcpip.ShutdownWrite != 0 {
   720  		if err := e.net.Shutdown(); err != nil {
   721  			return err
   722  		}
   723  	}
   724  
   725  	if flags&tcpip.ShutdownRead != 0 {
   726  		e.readShutdown = true
   727  
   728  		e.rcvMu.Lock()
   729  		wasClosed := e.rcvClosed
   730  		e.rcvClosed = true
   731  		e.rcvMu.Unlock()
   732  
   733  		if !wasClosed {
   734  			e.waiterQueue.Notify(waiter.ReadableEvents)
   735  		}
   736  	}
   737  
   738  	if e.net.State() == transport.DatagramEndpointStateBound {
   739  		return &tcpip.ErrNotConnected{}
   740  	}
   741  	return nil
   742  }
   743  
// Listen is not supported by UDP; it always fails with ErrNotSupported.
func (*endpoint) Listen(int) tcpip.Error {
	return &tcpip.ErrNotSupported{}
}
   748  
// Accept is not supported by UDP; it always fails with ErrNotSupported and
// returns no endpoint or waiter queue.
func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
	return nil, nil, &tcpip.ErrNotSupported{}
}
   753  
   754  func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
   755  	bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
   756  	if e.localPort == 0 {
   757  		portRes := ports.Reservation{
   758  			Networks:     netProtos,
   759  			Transport:    ProtocolNumber,
   760  			Addr:         id.LocalAddress,
   761  			Port:         id.LocalPort,
   762  			Flags:        e.portFlags,
   763  			BindToDevice: bindToDevice,
   764  			Dest:         tcpip.FullAddress{},
   765  		}
   766  		port, err := e.stack.ReservePort(e.stack.SecureRNG(), portRes, nil /* testPort */)
   767  		if err != nil {
   768  			return id, bindToDevice, err
   769  		}
   770  		id.LocalPort = port
   771  	}
   772  	e.boundPortFlags = e.portFlags
   773  
   774  	err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
   775  	if err != nil {
   776  		portRes := ports.Reservation{
   777  			Networks:     netProtos,
   778  			Transport:    ProtocolNumber,
   779  			Addr:         id.LocalAddress,
   780  			Port:         id.LocalPort,
   781  			Flags:        e.boundPortFlags,
   782  			BindToDevice: bindToDevice,
   783  			Dest:         tcpip.FullAddress{},
   784  		}
   785  		e.stack.ReleasePort(portRes)
   786  		e.boundPortFlags = ports.Flags{}
   787  	}
   788  	return id, bindToDevice, err
   789  }
   790  
   791  func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
   792  	// Don't allow binding once endpoint is not in the initial state
   793  	// anymore.
   794  	if e.net.State() != transport.DatagramEndpointStateInitial {
   795  		return &tcpip.ErrInvalidEndpointState{}
   796  	}
   797  
   798  	err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
   799  		// Expand netProtos to include v4 and v6 if the caller is binding to a
   800  		// wildcard (empty) address, and this is an IPv6 endpoint with v6only
   801  		// set to false.
   802  		netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
   803  		if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == (tcpip.Address{}) && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) {
   804  			netProtos = []tcpip.NetworkProtocolNumber{
   805  				header.IPv6ProtocolNumber,
   806  				header.IPv4ProtocolNumber,
   807  			}
   808  		}
   809  
   810  		id := stack.TransportEndpointID{
   811  			LocalPort:    addr.Port,
   812  			LocalAddress: boundAddr,
   813  		}
   814  		id, btd, err := e.registerWithStack(netProtos, id)
   815  		if err != nil {
   816  			return err
   817  		}
   818  
   819  		e.localPort = id.LocalPort
   820  		e.boundBindToDevice = btd
   821  		e.effectiveNetProtos = netProtos
   822  		return nil
   823  	})
   824  	if err != nil {
   825  		return err
   826  	}
   827  
   828  	e.rcvMu.Lock()
   829  	e.rcvReady = true
   830  	e.rcvMu.Unlock()
   831  	return nil
   832  }
   833  
   834  // Bind binds the endpoint to a specific local address and port.
   835  // Specifying a NIC is optional.
   836  func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   837  	e.mu.Lock()
   838  	defer e.mu.Unlock()
   839  
   840  	err := e.bindLocked(addr)
   841  	if err != nil {
   842  		return err
   843  	}
   844  
   845  	return nil
   846  }
   847  
   848  // GetLocalAddress returns the address to which the endpoint is bound.
   849  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   850  	e.mu.RLock()
   851  	defer e.mu.RUnlock()
   852  
   853  	addr := e.net.GetLocalAddress()
   854  	addr.Port = e.localPort
   855  	return addr, nil
   856  }
   857  
   858  // GetRemoteAddress returns the address to which the endpoint is connected.
   859  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   860  	e.mu.RLock()
   861  	defer e.mu.RUnlock()
   862  
   863  	addr, connected := e.net.GetRemoteAddress()
   864  	if !connected || e.remotePort == 0 {
   865  		return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   866  	}
   867  
   868  	addr.Port = e.remotePort
   869  	return addr, nil
   870  }
   871  
   872  // Readiness returns the current readiness of the endpoint. For example, if
   873  // waiter.EventIn is set, the endpoint is immediately readable.
   874  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   875  	var result waiter.EventMask
   876  
   877  	if e.net.HasSendSpace() {
   878  		result |= waiter.WritableEvents & mask
   879  	}
   880  
   881  	// Determine if the endpoint is readable if requested.
   882  	if mask&waiter.ReadableEvents != 0 {
   883  		e.rcvMu.Lock()
   884  		if !e.rcvList.Empty() || e.rcvClosed {
   885  			result |= waiter.ReadableEvents
   886  		}
   887  		e.rcvMu.Unlock()
   888  	}
   889  
   890  	e.lastErrorMu.Lock()
   891  	hasError := e.lastError != nil
   892  	e.lastErrorMu.Unlock()
   893  	if hasError {
   894  		result |= waiter.EventErr
   895  	}
   896  	return result
   897  }
   898  
   899  // HandlePacket is called by the stack when new packets arrive to this transport
   900  // endpoint.
   901  func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
   902  	// Get the header then trim it from the view.
   903  	hdr := header.UDP(pkt.TransportHeader().Slice())
   904  	netHdr := pkt.Network()
   905  	lengthValid, csumValid := header.UDPValid(
   906  		hdr,
   907  		func() uint16 { return pkt.Data().Checksum() },
   908  		uint16(pkt.Data().Size()),
   909  		pkt.NetworkProtocolNumber,
   910  		netHdr.SourceAddress(),
   911  		netHdr.DestinationAddress(),
   912  		pkt.RXChecksumValidated)
   913  	if !lengthValid {
   914  		// Malformed packet.
   915  		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
   916  		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
   917  		return
   918  	}
   919  
   920  	if !csumValid {
   921  		e.stack.Stats().UDP.ChecksumErrors.Increment()
   922  		e.stats.ReceiveErrors.ChecksumErrors.Increment()
   923  		return
   924  	}
   925  
   926  	e.stack.Stats().UDP.PacketsReceived.Increment()
   927  	e.stats.PacketsReceived.Increment()
   928  
   929  	e.rcvMu.Lock()
   930  	// Drop the packet if our buffer is not ready to receive packets.
   931  	if !e.rcvReady || e.rcvClosed {
   932  		e.rcvMu.Unlock()
   933  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
   934  		e.stats.ReceiveErrors.ClosedReceiver.Increment()
   935  		return
   936  	}
   937  
   938  	rcvBufSize := e.ops.GetReceiveBufferSize()
   939  	// Drop the packet if our buffer is currently full.
   940  	if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
   941  		e.rcvMu.Unlock()
   942  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
   943  		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
   944  		return
   945  	}
   946  
   947  	wasEmpty := e.rcvBufSize == 0
   948  
   949  	// Push new packet into receive list and increment the buffer size.
   950  	packet := &udpPacket{
   951  		netProto: pkt.NetworkProtocolNumber,
   952  		senderAddress: tcpip.FullAddress{
   953  			NIC:  pkt.NICID,
   954  			Addr: id.RemoteAddress,
   955  			Port: hdr.SourcePort(),
   956  		},
   957  		destinationAddress: tcpip.FullAddress{
   958  			NIC:  pkt.NICID,
   959  			Addr: id.LocalAddress,
   960  			Port: hdr.DestinationPort(),
   961  		},
   962  		pkt: pkt.IncRef(),
   963  	}
   964  	e.rcvList.PushBack(packet)
   965  	e.rcvBufSize += pkt.Data().Size()
   966  
   967  	// Save any useful information from the network header to the packet.
   968  	packet.tosOrTClass, _ = pkt.Network().TOS()
   969  	switch pkt.NetworkProtocolNumber {
   970  	case header.IPv4ProtocolNumber:
   971  		packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL()
   972  	case header.IPv6ProtocolNumber:
   973  		packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit()
   974  	}
   975  
   976  	// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
   977  	// address. packetInfo.LocalAddr should hold a unicast address that can be
   978  	// used to respond to the incoming packet.
   979  	localAddr := pkt.Network().DestinationAddress()
   980  	packet.packetInfo.LocalAddr = localAddr
   981  	packet.packetInfo.DestinationAddr = localAddr
   982  	packet.packetInfo.NIC = pkt.NICID
   983  	packet.receivedAt = e.stack.Clock().Now()
   984  
   985  	e.rcvMu.Unlock()
   986  
   987  	// Notify any waiters that there's data to be read now.
   988  	if wasEmpty {
   989  		e.waiterQueue.Notify(waiter.ReadableEvents)
   990  	}
   991  }
   992  
   993  func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) {
   994  	// Update last error first.
   995  	e.lastErrorMu.Lock()
   996  	e.lastError = err
   997  	e.lastErrorMu.Unlock()
   998  
   999  	var recvErr bool
  1000  	switch pkt.NetworkProtocolNumber {
  1001  	case header.IPv4ProtocolNumber:
  1002  		recvErr = e.SocketOptions().GetIPv4RecvError()
  1003  	case header.IPv6ProtocolNumber:
  1004  		recvErr = e.SocketOptions().GetIPv6RecvError()
  1005  	default:
  1006  		panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber))
  1007  	}
  1008  
  1009  	if recvErr {
  1010  		// Linux passes the payload without the UDP header.
  1011  		payload := pkt.Data().AsRange().ToView()
  1012  		udp := header.UDP(payload.AsSlice())
  1013  		if len(udp) >= header.UDPMinimumSize {
  1014  			payload.TrimFront(header.UDPMinimumSize)
  1015  		}
  1016  
  1017  		id := e.net.Info().ID
  1018  		e.mu.RLock()
  1019  		e.SocketOptions().QueueErr(&tcpip.SockError{
  1020  			Err:     err,
  1021  			Cause:   transErr,
  1022  			Payload: payload,
  1023  			Dst: tcpip.FullAddress{
  1024  				NIC:  pkt.NICID,
  1025  				Addr: id.RemoteAddress,
  1026  				Port: e.remotePort,
  1027  			},
  1028  			Offender: tcpip.FullAddress{
  1029  				NIC:  pkt.NICID,
  1030  				Addr: id.LocalAddress,
  1031  				Port: e.localPort,
  1032  			},
  1033  			NetProto: pkt.NetworkProtocolNumber,
  1034  		})
  1035  		e.mu.RUnlock()
  1036  	}
  1037  
  1038  	// Notify of the error.
  1039  	e.waiterQueue.Notify(waiter.EventErr)
  1040  }
  1041  
  1042  // HandleError implements stack.TransportEndpoint.
  1043  func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
  1044  	// TODO(gvisor.dev/issues/5270): Handle all transport errors.
  1045  	switch transErr.Kind() {
  1046  	case stack.DestinationPortUnreachableTransportError:
  1047  		if e.net.State() == transport.DatagramEndpointStateConnected {
  1048  			e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
  1049  		}
  1050  	}
  1051  }
  1052  
  1053  // State implements tcpip.Endpoint.
  1054  func (e *endpoint) State() uint32 {
  1055  	return uint32(e.net.State())
  1056  }
  1057  
  1058  // Info returns a copy of the endpoint info.
  1059  func (e *endpoint) Info() tcpip.EndpointInfo {
  1060  	e.mu.RLock()
  1061  	defer e.mu.RUnlock()
  1062  	info := e.net.Info()
  1063  	info.ID.LocalPort = e.localPort
  1064  	info.ID.RemotePort = e.remotePort
  1065  	return &info
  1066  }
  1067  
  1068  // Stats returns a pointer to the endpoint stats.
  1069  func (e *endpoint) Stats() tcpip.EndpointStats {
  1070  	return &e.stats
  1071  }
  1072  
  1073  // Wait implements tcpip.Endpoint.
  1074  func (*endpoint) Wait() {}
  1075  
// SetOwner implements tcpip.Endpoint.
//
// The owner is forwarded unchanged to the underlying network endpoint; no
// state is kept at the UDP layer by this method.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
	e.net.SetOwner(owner)
}
  1080  
  1081  // SocketOptions implements tcpip.Endpoint.
  1082  func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
  1083  	return &e.ops
  1084  }
  1085  
  1086  // freeze prevents any more packets from being delivered to the endpoint.
  1087  func (e *endpoint) freeze() {
  1088  	e.mu.Lock()
  1089  	e.frozen = true
  1090  	e.mu.Unlock()
  1091  }
  1092  
  1093  // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
  1094  // new packets to be delivered again.
  1095  func (e *endpoint) thaw() {
  1096  	e.mu.Lock()
  1097  	e.frozen = false
  1098  	e.mu.Unlock()
  1099  }