github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/udp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"math"
    22  	"time"
    23  
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/buffer"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/checksum"
    28  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    29  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports"
    30  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
    31  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/transport"
    32  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/transport/internal/network"
    33  	"github.com/nicocha30/gvisor-ligolo/pkg/waiter"
    34  )
    35  
// udpPacket is one received datagram waiting on an endpoint's receive queue
// (rcvList), together with the per-packet metadata that Read uses to report
// the sender address and the receivable control messages.
//
// +stateify savable
type udpPacket struct {
	// udpPacketEntry embeds the linkage that lets a udpPacket live in a
	// udpPacketList (presumably generated list code defined elsewhere).
	udpPacketEntry

	// netProto is the network protocol the datagram arrived on. Read only
	// understands IPv4 and IPv6 and panics on anything else.
	netProto tcpip.NetworkProtocolNumber
	// senderAddress is reported as ReadResult.RemoteAddr when requested.
	senderAddress tcpip.FullAddress
	// destinationAddress is reported as the original destination control
	// message when that option is enabled.
	destinationAddress tcpip.FullAddress
	// packetInfo backs the IPv4 packet-info control message and, for IPv6,
	// supplies the NIC and destination address of the IPv6 packet info.
	packetInfo tcpip.IPPacketInfo
	// pkt holds the datagram payload (pkt.Data()); the reference is dropped
	// when the packet is dequeued by Read or discarded by Close.
	pkt stack.PacketBufferPtr
	// receivedAt is reported as the receive timestamp; it is saved as an
	// int64 by the state machinery.
	receivedAt time.Time `state:".(int64)"`
	// tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class
	// for IPv6.
	tosOrTClass uint8
	// ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6.
	ttlOrHopLimit uint8
}
    51  
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	// Presumably supplies default implementations for the
	// tcpip.SocketOptionsHandler callbacks not defined on endpoint.
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue // notified by Close and by Shutdown(ShutdownRead)
	uniqueID    uint64        // taken from the stack at creation; see UniqueID
	// net tracks the datagram endpoint state and local/remote addresses, and
	// hands out write contexts for sending.
	net network.Endpoint
	// stats holds per-endpoint counters updated by Read and Write.
	stats tcpip.TransportEndpointStats
	// ops holds the socket options; newEndpoint registers e as their handler.
	ops tcpip.SocketOptions

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu.
	rcvMu      sync.Mutex `state:"nosave"`
	rcvReady   bool          // set once the endpoint is bound or connected
	rcvList    udpPacketList // queued datagrams; Read consumes from the front
	rcvBufSize int           // payload bytes queued; Read subtracts, Close zeroes
	rcvClosed  bool          // set by Close and Shutdown(ShutdownRead)

	// lastErrorMu guards lastError, the pending error that LastError reports
	// and clears.
	lastErrorMu sync.Mutex `state:"nosave"`
	lastError   tcpip.Error

	// The following fields are protected by the mu mutex.
	mu sync.RWMutex `state:"nosave"`
	// portFlags are the reuse flags as currently configured through
	// OnReuseAddressSet and OnReusePortSet.
	portFlags ports.Flags

	// Values used to reserve a port or register a transport endpoint
	// (whichever happens first).
	boundBindToDevice tcpip.NICID
	boundPortFlags    ports.Flags

	// readShutdown is set by Close and by Shutdown(ShutdownRead).
	readShutdown bool

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber

	// frozen indicates if the packets should be delivered to the endpoint
	// during restore.
	frozen bool

	// localPort is the local port the endpoint is registered with (0 if none).
	localPort uint16
	// remotePort is the peer's port while connected; GetRemoteAddress treats
	// 0 as "not connected".
	remotePort uint16
}
   109  
   110  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
   111  	e := &endpoint{
   112  		stack:       s,
   113  		waiterQueue: waiterQueue,
   114  		uniqueID:    s.UniqueID(),
   115  	}
   116  	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
   117  	e.ops.SetMulticastLoop(true)
   118  	e.ops.SetSendBufferSize(32*1024, false /* notify */)
   119  	e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
   120  	e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops, waiterQueue)
   121  
   122  	// Override with stack defaults.
   123  	var ss tcpip.SendBufferSizeOption
   124  	if err := s.Option(&ss); err == nil {
   125  		e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   126  	}
   127  
   128  	var rs tcpip.ReceiveBufferSizeOption
   129  	if err := s.Option(&rs); err == nil {
   130  		e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   131  	}
   132  
   133  	return e
   134  }
   135  
   136  // WakeupWriters implements tcpip.SocketOptionsHandler.
   137  func (e *endpoint) WakeupWriters() {
   138  	e.net.MaybeSignalWritable()
   139  }
   140  
   141  // UniqueID implements stack.TransportEndpoint.
   142  func (e *endpoint) UniqueID() uint64 {
   143  	return e.uniqueID
   144  }
   145  
   146  func (e *endpoint) LastError() tcpip.Error {
   147  	e.lastErrorMu.Lock()
   148  	defer e.lastErrorMu.Unlock()
   149  
   150  	err := e.lastError
   151  	e.lastError = nil
   152  	return err
   153  }
   154  
   155  // UpdateLastError implements tcpip.SocketOptionsHandler.
   156  func (e *endpoint) UpdateLastError(err tcpip.Error) {
   157  	e.lastErrorMu.Lock()
   158  	e.lastError = err
   159  	e.lastErrorMu.Unlock()
   160  }
   161  
   162  // Abort implements stack.TransportEndpoint.
   163  func (e *endpoint) Abort() {
   164  	e.Close()
   165  }
   166  
   167  // Close puts the endpoint in a closed state and frees all resources
   168  // associated with it.
   169  func (e *endpoint) Close() {
   170  	e.mu.Lock()
   171  
   172  	switch state := e.net.State(); state {
   173  	case transport.DatagramEndpointStateInitial:
   174  	case transport.DatagramEndpointStateClosed:
   175  		e.mu.Unlock()
   176  		return
   177  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   178  		id := e.net.Info().ID
   179  		id.LocalPort = e.localPort
   180  		id.RemotePort = e.remotePort
   181  		e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
   182  		portRes := ports.Reservation{
   183  			Networks:     e.effectiveNetProtos,
   184  			Transport:    ProtocolNumber,
   185  			Addr:         id.LocalAddress,
   186  			Port:         id.LocalPort,
   187  			Flags:        e.boundPortFlags,
   188  			BindToDevice: e.boundBindToDevice,
   189  			Dest:         tcpip.FullAddress{},
   190  		}
   191  		e.stack.ReleasePort(portRes)
   192  		e.boundBindToDevice = 0
   193  		e.boundPortFlags = ports.Flags{}
   194  	default:
   195  		panic(fmt.Sprintf("unhandled state = %s", state))
   196  	}
   197  
   198  	// Close the receive list and drain it.
   199  	e.rcvMu.Lock()
   200  	e.rcvClosed = true
   201  	e.rcvBufSize = 0
   202  	for !e.rcvList.Empty() {
   203  		p := e.rcvList.Front()
   204  		e.rcvList.Remove(p)
   205  		p.pkt.DecRef()
   206  	}
   207  	e.rcvMu.Unlock()
   208  
   209  	e.net.Shutdown()
   210  	e.net.Close()
   211  	e.readShutdown = true
   212  	e.mu.Unlock()
   213  
   214  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
   215  }
   216  
   217  // ModerateRecvBuf implements tcpip.Endpoint.
   218  func (*endpoint) ModerateRecvBuf(int) {}
   219  
   220  // Read implements tcpip.Endpoint.
   221  func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   222  	if err := e.LastError(); err != nil {
   223  		return tcpip.ReadResult{}, err
   224  	}
   225  
   226  	e.rcvMu.Lock()
   227  
   228  	if e.rcvList.Empty() {
   229  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   230  		if e.rcvClosed {
   231  			e.stats.ReadErrors.ReadClosed.Increment()
   232  			err = &tcpip.ErrClosedForReceive{}
   233  		}
   234  		e.rcvMu.Unlock()
   235  		return tcpip.ReadResult{}, err
   236  	}
   237  
   238  	p := e.rcvList.Front()
   239  	if !opts.Peek {
   240  		e.rcvList.Remove(p)
   241  		defer p.pkt.DecRef()
   242  		e.rcvBufSize -= p.pkt.Data().Size()
   243  	}
   244  	e.rcvMu.Unlock()
   245  
   246  	// Control Messages
   247  	// TODO(https://gvisor.dev/issue/7012): Share control message code with other
   248  	// network endpoints.
   249  	cm := tcpip.ReceivableControlMessages{
   250  		HasTimestamp: true,
   251  		Timestamp:    p.receivedAt,
   252  	}
   253  	switch p.netProto {
   254  	case header.IPv4ProtocolNumber:
   255  		if e.ops.GetReceiveTOS() {
   256  			cm.HasTOS = true
   257  			cm.TOS = p.tosOrTClass
   258  		}
   259  		if e.ops.GetReceiveTTL() {
   260  			cm.HasTTL = true
   261  			cm.TTL = p.ttlOrHopLimit
   262  		}
   263  		if e.ops.GetReceivePacketInfo() {
   264  			cm.HasIPPacketInfo = true
   265  			cm.PacketInfo = p.packetInfo
   266  		}
   267  	case header.IPv6ProtocolNumber:
   268  		if e.ops.GetReceiveTClass() {
   269  			cm.HasTClass = true
   270  			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
   271  			cm.TClass = uint32(p.tosOrTClass)
   272  		}
   273  		if e.ops.GetReceiveHopLimit() {
   274  			cm.HasHopLimit = true
   275  			cm.HopLimit = p.ttlOrHopLimit
   276  		}
   277  		if e.ops.GetIPv6ReceivePacketInfo() {
   278  			cm.HasIPv6PacketInfo = true
   279  			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
   280  				NIC:  p.packetInfo.NIC,
   281  				Addr: p.packetInfo.DestinationAddr,
   282  			}
   283  		}
   284  	default:
   285  		panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
   286  	}
   287  
   288  	if e.ops.GetReceiveOriginalDstAddress() {
   289  		cm.HasOriginalDstAddress = true
   290  		cm.OriginalDstAddress = p.destinationAddress
   291  	}
   292  
   293  	// Read Result
   294  	res := tcpip.ReadResult{
   295  		Total:           p.pkt.Data().Size(),
   296  		ControlMessages: cm,
   297  	}
   298  	if opts.NeedRemoteAddr {
   299  		res.RemoteAddr = p.senderAddress
   300  	}
   301  
   302  	n, err := p.pkt.Data().ReadTo(dst, opts.Peek)
   303  	if n == 0 && err != nil {
   304  		return res, &tcpip.ErrBadBuffer{}
   305  	}
   306  	res.Count = n
   307  	return res, nil
   308  }
   309  
// prepareForWriteInner prepares the endpoint for sending data. In particular,
// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu for reading. On return e.mu is again held for
// reading: when the lock had to be upgraded to bind, the deferred
// DowngradeLock converts it back before returning.
//
// Returns true for retry if preparation should be retried, i.e. the caller
// must call again because the state may have changed while e.mu was
// released. Bound endpoints without a destination (to == nil) get
// ErrDestinationRequired; states other than initial/bound/connected get
// ErrInvalidEndpointState.
// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
	switch e.net.State() {
	case transport.DatagramEndpointStateInitial:
		// Not bound yet: fall out of the switch and bind below.
	case transport.DatagramEndpointStateConnected:
		return false, nil

	case transport.DatagramEndpointStateBound:
		// A bound but unconnected endpoint needs an explicit destination.
		if to == nil {
			return false, &tcpip.ErrDestinationRequired{}
		}
		return false, nil
	default:
		return false, &tcpip.ErrInvalidEndpointState{}
	}

	// Upgrade to an exclusive lock so the endpoint can be bound.
	e.mu.RUnlock()
	e.mu.Lock()
	defer e.mu.DowngradeLock()

	// The state may have changed while we released the shared lock and
	// re-acquired it in exclusive mode. Try again.
	if e.net.State() != transport.DatagramEndpointStateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint to an
	// unspecified local address.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	// Retry so the caller re-evaluates the now-bound state.
	return true, nil
}
   348  
   349  var _ tcpip.EndpointWithPreflight = (*endpoint)(nil)
   350  
   351  // Validates the passed WriteOptions and prepares the endpoint for writes
   352  // using those options. If the endpoint is unbound and the `To` address
   353  // is specified, binds the endpoint to that address.
   354  func (e *endpoint) Preflight(opts tcpip.WriteOptions) tcpip.Error {
   355  	var r bytes.Reader
   356  	udpInfo, err := e.prepareForWrite(&r, opts)
   357  	if err == nil {
   358  		udpInfo.ctx.Release()
   359  	}
   360  	return err
   361  }
   362  
   363  // Write writes data to the endpoint's peer. This method does not block
   364  // if the data cannot be written.
   365  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   366  	n, err := e.write(p, opts)
   367  	switch err.(type) {
   368  	case nil:
   369  		e.stats.PacketsSent.Increment()
   370  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   371  		e.stats.WriteErrors.InvalidArgs.Increment()
   372  	case *tcpip.ErrClosedForSend:
   373  		e.stats.WriteErrors.WriteClosed.Increment()
   374  	case *tcpip.ErrInvalidEndpointState:
   375  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   376  	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   377  		// Errors indicating any problem with IP routing of the packet.
   378  		e.stats.SendErrors.NoRoute.Increment()
   379  	default:
   380  		// For all other errors when writing to the network layer.
   381  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   382  	}
   383  	return n, err
   384  }
   385  
   386  func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
   387  	e.mu.RLock()
   388  	defer e.mu.RUnlock()
   389  
   390  	// Prepare for write.
   391  	for {
   392  		retry, err := e.prepareForWriteInner(opts.To)
   393  		if err != nil {
   394  			return udpPacketInfo{}, err
   395  		}
   396  
   397  		if !retry {
   398  			break
   399  		}
   400  	}
   401  
   402  	dst, connected := e.net.GetRemoteAddress()
   403  	dst.Port = e.remotePort
   404  	if opts.To != nil {
   405  		if opts.To.Port == 0 {
   406  			// Port 0 is an invalid port to send to.
   407  			return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
   408  		}
   409  
   410  		dst = *opts.To
   411  	} else if !connected {
   412  		return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
   413  	}
   414  
   415  	ctx, err := e.net.AcquireContextForWrite(opts)
   416  	if err != nil {
   417  		return udpPacketInfo{}, err
   418  	}
   419  
   420  	if p.Len() > header.UDPMaximumPacketSize {
   421  		// Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet
   422  		// errors aren't report to the error queue at all.
   423  		if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
   424  			so := e.SocketOptions()
   425  			if so.GetIPv6RecvError() {
   426  				so.QueueLocalErr(
   427  					&tcpip.ErrMessageTooLong{},
   428  					e.net.NetProto(),
   429  					uint32(p.Len()),
   430  					dst,
   431  					nil,
   432  				)
   433  			}
   434  		}
   435  		ctx.Release()
   436  		return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
   437  	}
   438  
   439  	var buf buffer.Buffer
   440  	if _, err := buf.WriteFromReader(p, int64(p.Len())); err != nil {
   441  		buf.Release()
   442  		ctx.Release()
   443  		return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
   444  	}
   445  
   446  	return udpPacketInfo{
   447  		ctx:        ctx,
   448  		data:       buf,
   449  		localPort:  e.localPort,
   450  		remotePort: dst.Port,
   451  	}, nil
   452  }
   453  
// write sends the payload described by p as a single UDP datagram and returns
// the number of payload bytes sent. A pending asynchronous error (LastError)
// is returned instead of sending.
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
	// Do not hold the lock when sending. Loopback delivery is synchronous, and
	// if the UDP datagram generates an ICMP response, handling that response
	// acquires this endpoint's mutex with e.mu.RLock() in
	// endpoint.HandleControlPacket. That recursive read lock can deadlock if
	// another caller is waiting to acquire e.mu in exclusive mode with
	// e.mu.Lock(), because a pending e.mu.Lock() blocks any new read locks so
	// that the writer can eventually acquire the lock.
	//
	// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
	// locking is prohibited.

	if err := e.LastError(); err != nil {
		return 0, err
	}

	udpInfo, err := e.prepareForWrite(p, opts)
	if err != nil {
		return 0, err
	}
	defer udpInfo.ctx.Release()

	dataSz := udpInfo.data.Size()
	pktInfo := udpInfo.ctx.PacketInfo()
	// Reserve room for the UDP header plus the network layer's headers.
	pkt := udpInfo.ctx.TryNewPacketBuffer(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), udpInfo.data)
	if pkt.IsNil() {
		// NOTE(review): if TryNewPacketBuffer does not take ownership of
		// udpInfo.data when it fails, the payload buffer is never released
		// on this path — confirm against its implementation.
		return 0, &tcpip.ErrWouldBlock{}
	}
	// Deferred after ctx.Release, so the packet reference is dropped first.
	defer pkt.DecRef()

	// Initialize the UDP header.
	udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
	pkt.TransportProtocolNumber = ProtocolNumber

	// Only the UDP header has been pushed so far, so this is presumably the
	// UDP header plus payload length.
	length := uint16(pkt.Size())
	udp.Encode(&header.UDPFields{
		SrcPort: udpInfo.localPort,
		DstPort: udpInfo.remotePort,
		Length:  length,
	})

	// Set the checksum field unless TX checksum offload is enabled.
	// On IPv4, UDP checksum is optional, and a zero value indicates the
	// transmitter skipped the checksum generation (RFC768).
	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
	if pktInfo.RequiresTXTransportChecksum &&
		(!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
		xsum := udp.CalculateChecksum(checksum.Combine(
			header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
			pkt.Data().Checksum(),
		))
		// As per RFC 768 page 2,
		//
		//   Checksum is the 16-bit one's complement of the one's complement sum of
		//   a pseudo header of information from the IP header, the UDP header, and
		//   the data, padded with zero octets at the end (if necessary) to make a
		//   multiple of two octets.
		//
		//   The pseudo header conceptually prefixed to the UDP header contains the
		//   source address, the destination address, the protocol, and the UDP
		//   length. This information gives protection against misrouted datagrams.
		//   This checksum procedure is the same as is used in TCP.
		//
		//   If the computed checksum is zero, it is transmitted as all ones (the
		//   equivalent in one's complement arithmetic). An all zero transmitted
		//   checksum value means that the transmitter generated no checksum (for
		//   debugging or for higher level protocols that don't care).
		//
		// To avoid the zero value, we only calculate the one's complement of the
		// one's complement sum if the sum is not all ones.
		if xsum != math.MaxUint16 {
			xsum = ^xsum
		}
		udp.SetChecksum(xsum)
	}
	if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
		e.stack.Stats().UDP.PacketSendErrors.Increment()
		return 0, err
	}

	// Track count of packets sent.
	e.stack.Stats().UDP.PacketsSent.Increment()
	return int64(dataSz), nil
}
   538  
   539  // OnReuseAddressSet implements tcpip.SocketOptionsHandler.
   540  func (e *endpoint) OnReuseAddressSet(v bool) {
   541  	e.mu.Lock()
   542  	e.portFlags.MostRecent = v
   543  	e.mu.Unlock()
   544  }
   545  
   546  // OnReusePortSet implements tcpip.SocketOptionsHandler.
   547  func (e *endpoint) OnReusePortSet(v bool) {
   548  	e.mu.Lock()
   549  	e.portFlags.LoadBalanced = v
   550  	e.mu.Unlock()
   551  }
   552  
   553  // SetSockOptInt implements tcpip.Endpoint.
   554  func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   555  	return e.net.SetSockOptInt(opt, v)
   556  }
   557  
   558  var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
   559  
   560  // HasNIC implements tcpip.SocketOptionsHandler.
   561  func (e *endpoint) HasNIC(id int32) bool {
   562  	return e.stack.HasNIC(tcpip.NICID(id))
   563  }
   564  
   565  // SetSockOpt implements tcpip.Endpoint.
   566  func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   567  	return e.net.SetSockOpt(opt)
   568  }
   569  
   570  // GetSockOptInt implements tcpip.Endpoint.
   571  func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   572  	switch opt {
   573  	case tcpip.ReceiveQueueSizeOption:
   574  		v := 0
   575  		e.rcvMu.Lock()
   576  		if !e.rcvList.Empty() {
   577  			p := e.rcvList.Front()
   578  			v = p.pkt.Data().Size()
   579  		}
   580  		e.rcvMu.Unlock()
   581  		return v, nil
   582  
   583  	default:
   584  		return e.net.GetSockOptInt(opt)
   585  	}
   586  }
   587  
   588  // GetSockOpt implements tcpip.Endpoint.
   589  func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
   590  	return e.net.GetSockOpt(opt)
   591  }
   592  
   593  // udpPacketInfo holds information needed to send a UDP packet.
   594  type udpPacketInfo struct {
   595  	ctx        network.WriteContext
   596  	data       buffer.Buffer
   597  	localPort  uint16
   598  	remotePort uint16
   599  }
   600  
// Disconnect implements tcpip.Endpoint. It undoes a prior Connect and is a
// no-op for an endpoint that is not connected.
//
// If the endpoint had been explicitly bound before connecting
// (e.net.WasBound), it is re-registered with the stack under only its local
// address and port. Otherwise the local port it was given is released and
// localPort becomes 0. In both cases the connected registration is removed,
// remotePort is reset to 0, and e.net.Disconnect is called.
func (e *endpoint) Disconnect() tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.net.State() != transport.DatagramEndpointStateConnected {
		return nil
	}
	var (
		// id and btd describe the post-disconnect registration; they stay
		// zero when the endpoint was not explicitly bound.
		id  stack.TransportEndpointID
		btd tcpip.NICID
	)

	// We change this value below and we need the old value to unregister
	// the endpoint.
	boundPortFlags := e.boundPortFlags

	// Exclude ephemerally bound endpoints.
	info := e.net.Info()
	info.ID.LocalPort = e.localPort
	info.ID.RemotePort = e.remotePort
	if e.net.WasBound() {
		var err tcpip.Error
		id = stack.TransportEndpointID{
			LocalPort:    info.ID.LocalPort,
			LocalAddress: info.ID.LocalAddress,
		}
		id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
		if err != nil {
			return err
		}
		// NOTE(review): registerWithStack has just set e.boundPortFlags to the
		// current e.portFlags, so this overwrites the old flags saved above and
		// the Unregister below uses the new ones — contrary to the comment on
		// boundPortFlags. Confirm this is intended.
		boundPortFlags = e.boundPortFlags
	} else {
		if info.ID.LocalPort != 0 {
			// Release the ephemeral port.
			portRes := ports.Reservation{
				Networks:     e.effectiveNetProtos,
				Transport:    ProtocolNumber,
				Addr:         info.ID.LocalAddress,
				Port:         info.ID.LocalPort,
				Flags:        boundPortFlags,
				BindToDevice: e.boundBindToDevice,
				Dest:         tcpip.FullAddress{},
			}
			e.stack.ReleasePort(portRes)
			e.boundPortFlags = ports.Flags{}
		}
	}

	// Remove the connected registration (it includes the remote port).
	e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
	e.boundBindToDevice = btd
	e.localPort = id.LocalPort
	e.remotePort = id.RemotePort

	e.net.Disconnect()

	return nil
}
   659  
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
//
// It implements tcpip.Endpoint. Inside e.net.ConnectAndThen, the endpoint is
// registered with the stack under the connected ID (reserving a local port
// first if it has none), and any previous registration is then removed. On
// success the receive side is marked ready.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
		nextID.LocalPort = e.localPort
		nextID.RemotePort = addr.Port

		// Even when connected, a dual-stack endpoint can still be used to send
		// packets on a different network protocol. So for an IPv6 endpoint with
		// v6only unset (and IPv4 supported by the stack), register for both
		// IPv4 and IPv6.
		netProtos := []tcpip.NetworkProtocolNumber{netProto}
		if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) {
			netProtos = []tcpip.NetworkProtocolNumber{
				header.IPv4ProtocolNumber,
				header.IPv6ProtocolNumber,
			}
		}

		// registerWithStack overwrites e.boundPortFlags, so keep the flags the
		// previous registration was made with.
		oldPortFlags := e.boundPortFlags

		nextID, btd, err := e.registerWithStack(netProtos, nextID)
		if err != nil {
			return err
		}

		// Remove the old registration. It exists only if a local port had
		// already been assigned (bound or previously connected).
		if e.localPort != 0 {
			previousID.LocalPort = e.localPort
			previousID.RemotePort = e.remotePort
			e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
		}

		e.localPort = nextID.LocalPort
		e.remotePort = nextID.RemotePort
		e.boundBindToDevice = btd
		e.effectiveNetProtos = netProtos
		return nil
	})
	if err != nil {
		return err
	}

	e.rcvMu.Lock()
	e.rcvReady = true
	e.rcvMu.Unlock()
	return nil
}
   709  
   710  // ConnectEndpoint is not supported.
   711  func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
   712  	return &tcpip.ErrInvalidEndpointState{}
   713  }
   714  
   715  // Shutdown closes the read and/or write end of the endpoint connection
   716  // to its peer.
   717  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
   718  	e.mu.Lock()
   719  	defer e.mu.Unlock()
   720  
   721  	switch state := e.net.State(); state {
   722  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   723  		return &tcpip.ErrNotConnected{}
   724  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   725  	default:
   726  		panic(fmt.Sprintf("unhandled state = %s", state))
   727  	}
   728  
   729  	if flags&tcpip.ShutdownWrite != 0 {
   730  		if err := e.net.Shutdown(); err != nil {
   731  			return err
   732  		}
   733  	}
   734  
   735  	if flags&tcpip.ShutdownRead != 0 {
   736  		e.readShutdown = true
   737  
   738  		e.rcvMu.Lock()
   739  		wasClosed := e.rcvClosed
   740  		e.rcvClosed = true
   741  		e.rcvMu.Unlock()
   742  
   743  		if !wasClosed {
   744  			e.waiterQueue.Notify(waiter.ReadableEvents)
   745  		}
   746  	}
   747  
   748  	if e.net.State() == transport.DatagramEndpointStateBound {
   749  		return &tcpip.ErrNotConnected{}
   750  	}
   751  	return nil
   752  }
   753  
   754  // Listen is not supported by UDP, it just fails.
   755  func (*endpoint) Listen(int) tcpip.Error {
   756  	return &tcpip.ErrNotSupported{}
   757  }
   758  
   759  // Accept is not supported by UDP, it just fails.
   760  func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
   761  	return nil, nil, &tcpip.ErrNotSupported{}
   762  }
   763  
// registerWithStack registers the endpoint with the stack for netProtos under
// id, reserving a local port first if the endpoint has none (e.localPort == 0).
// It records e.portFlags as e.boundPortFlags.
//
// It returns id (with LocalPort filled in when a port was reserved here), the
// bind-to-device NIC used for the registration, and any error. If the
// registration itself fails, the port reservation for id is released and
// e.boundPortFlags is cleared. All callers in this file hold e.mu exclusively.
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
	bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
	if e.localPort == 0 {
		// id.LocalPort may be 0 here; ReservePort presumably picks an
		// ephemeral port in that case — confirm in ports.
		portRes := ports.Reservation{
			Networks:     netProtos,
			Transport:    ProtocolNumber,
			Addr:         id.LocalAddress,
			Port:         id.LocalPort,
			Flags:        e.portFlags,
			BindToDevice: bindToDevice,
			Dest:         tcpip.FullAddress{},
		}
		port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
		if err != nil {
			return id, bindToDevice, err
		}
		id.LocalPort = port
	}
	e.boundPortFlags = e.portFlags

	err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
	if err != nil {
		// NOTE(review): this release also runs when e.localPort != 0, i.e.
		// when the port was reserved earlier (e.g. by Bind) rather than by
		// this call. Confirm that dropping that reservation is intended.
		portRes := ports.Reservation{
			Networks:     netProtos,
			Transport:    ProtocolNumber,
			Addr:         id.LocalAddress,
			Port:         id.LocalPort,
			Flags:        e.boundPortFlags,
			BindToDevice: bindToDevice,
			Dest:         tcpip.FullAddress{},
		}
		e.stack.ReleasePort(portRes)
		e.boundPortFlags = ports.Flags{}
	}
	return id, bindToDevice, err
}
   800  
   801  func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
   802  	// Don't allow binding once endpoint is not in the initial state
   803  	// anymore.
   804  	if e.net.State() != transport.DatagramEndpointStateInitial {
   805  		return &tcpip.ErrInvalidEndpointState{}
   806  	}
   807  
   808  	err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
   809  		// Expand netProtos to include v4 and v6 if the caller is binding to a
   810  		// wildcard (empty) address, and this is an IPv6 endpoint with v6only
   811  		// set to false.
   812  		netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
   813  		if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == (tcpip.Address{}) && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) {
   814  			netProtos = []tcpip.NetworkProtocolNumber{
   815  				header.IPv6ProtocolNumber,
   816  				header.IPv4ProtocolNumber,
   817  			}
   818  		}
   819  
   820  		id := stack.TransportEndpointID{
   821  			LocalPort:    addr.Port,
   822  			LocalAddress: boundAddr,
   823  		}
   824  		id, btd, err := e.registerWithStack(netProtos, id)
   825  		if err != nil {
   826  			return err
   827  		}
   828  
   829  		e.localPort = id.LocalPort
   830  		e.boundBindToDevice = btd
   831  		e.effectiveNetProtos = netProtos
   832  		return nil
   833  	})
   834  	if err != nil {
   835  		return err
   836  	}
   837  
   838  	e.rcvMu.Lock()
   839  	e.rcvReady = true
   840  	e.rcvMu.Unlock()
   841  	return nil
   842  }
   843  
   844  // Bind binds the endpoint to a specific local address and port.
   845  // Specifying a NIC is optional.
   846  func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   847  	e.mu.Lock()
   848  	defer e.mu.Unlock()
   849  
   850  	err := e.bindLocked(addr)
   851  	if err != nil {
   852  		return err
   853  	}
   854  
   855  	return nil
   856  }
   857  
   858  // GetLocalAddress returns the address to which the endpoint is bound.
   859  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   860  	e.mu.RLock()
   861  	defer e.mu.RUnlock()
   862  
   863  	addr := e.net.GetLocalAddress()
   864  	addr.Port = e.localPort
   865  	return addr, nil
   866  }
   867  
   868  // GetRemoteAddress returns the address to which the endpoint is connected.
   869  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   870  	e.mu.RLock()
   871  	defer e.mu.RUnlock()
   872  
   873  	addr, connected := e.net.GetRemoteAddress()
   874  	if !connected || e.remotePort == 0 {
   875  		return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   876  	}
   877  
   878  	addr.Port = e.remotePort
   879  	return addr, nil
   880  }
   881  
   882  // Readiness returns the current readiness of the endpoint. For example, if
   883  // waiter.EventIn is set, the endpoint is immediately readable.
   884  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   885  	var result waiter.EventMask
   886  
   887  	if e.net.HasSendSpace() {
   888  		result |= waiter.WritableEvents & mask
   889  	}
   890  
   891  	// Determine if the endpoint is readable if requested.
   892  	if mask&waiter.ReadableEvents != 0 {
   893  		e.rcvMu.Lock()
   894  		if !e.rcvList.Empty() || e.rcvClosed {
   895  			result |= waiter.ReadableEvents
   896  		}
   897  		e.rcvMu.Unlock()
   898  	}
   899  
   900  	e.lastErrorMu.Lock()
   901  	hasError := e.lastError != nil
   902  	e.lastErrorMu.Unlock()
   903  	if hasError {
   904  		result |= waiter.EventErr
   905  	}
   906  	return result
   907  }
   908  
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
//
// It validates the UDP length and checksum and updates stack-wide and
// per-endpoint statistics. It then queues the packet on rcvList, taking a
// new reference via pkt.IncRef(). The packet is dropped, with the matching
// error counter incremented, when:
//   - the length is invalid (MalformedPacketsReceived),
//   - the checksum is invalid (ChecksumErrors),
//   - the receiver is not ready or is closed (ClosedReceiver), or
//   - the endpoint is frozen or its receive buffer is full (ReceiveBufferOverflow).
//
// Readers are notified only when the queue goes from empty to non-empty.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) {
	// Interpret the transport header as a UDP header. Nothing is trimmed here;
	// pkt.Data() is treated as the UDP payload below.
	hdr := header.UDP(pkt.TransportHeader().Slice())
	netHdr := pkt.Network()
	// The checksum callback is passed as a closure so the payload checksum is
	// only computed if UDPValid actually needs it (e.g. when
	// pkt.RXChecksumValidated is false) — NOTE(review): confirm in UDPValid.
	lengthValid, csumValid := header.UDPValid(
		hdr,
		func() uint16 { return pkt.Data().Checksum() },
		uint16(pkt.Data().Size()),
		pkt.NetworkProtocolNumber,
		netHdr.SourceAddress(),
		netHdr.DestinationAddress(),
		pkt.RXChecksumValidated)
	if !lengthValid {
		// Malformed packet.
		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
		return
	}

	if !csumValid {
		e.stack.Stats().UDP.ChecksumErrors.Increment()
		e.stats.ReceiveErrors.ChecksumErrors.Increment()
		return
	}

	// The packet is counted as received before the receive-buffer checks, so
	// packets dropped below are included in PacketsReceived.
	e.stack.Stats().UDP.PacketsReceived.Increment()
	e.stats.PacketsReceived.Increment()

	e.rcvMu.Lock()
	// Drop the packet if our buffer is not ready to receive packets.
	if !e.rcvReady || e.rcvClosed {
		e.rcvMu.Unlock()
		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
		e.stats.ReceiveErrors.ClosedReceiver.Increment()
		return
	}

	rcvBufSize := e.ops.GetReceiveBufferSize()
	// Drop the packet if our buffer is currently full. The check compares the
	// bytes already queued against the limit, so a single packet may push
	// e.rcvBufSize past the limit.
	// NOTE(review): e.frozen is also written by freeze()/thaw(); confirm those
	// writers hold rcvMu so that this read is synchronized.
	if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
		e.rcvMu.Unlock()
		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
		return
	}

	// Remember whether the queue was empty so waiters are only woken on the
	// empty -> non-empty transition.
	wasEmpty := e.rcvBufSize == 0

	// Push new packet into receive list and increment the buffer size.
	packet := &udpPacket{
		netProto: pkt.NetworkProtocolNumber,
		senderAddress: tcpip.FullAddress{
			NIC:  pkt.NICID,
			Addr: id.RemoteAddress,
			Port: hdr.SourcePort(),
		},
		destinationAddress: tcpip.FullAddress{
			NIC:  pkt.NICID,
			Addr: id.LocalAddress,
			Port: hdr.DestinationPort(),
		},
		pkt: pkt.IncRef(),
	}
	e.rcvList.PushBack(packet)
	e.rcvBufSize += pkt.Data().Size()

	// Save any useful information from the network header to the packet.
	// These fields are filled in after PushBack, but rcvMu is still held, so
	// readers cannot observe the packet before they are set.
	packet.tosOrTClass, _ = pkt.Network().TOS()
	switch pkt.NetworkProtocolNumber {
	case header.IPv4ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL()
	case header.IPv6ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit()
	}

	// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
	// address. packetInfo.LocalAddr should hold a unicast address that can be
	// used to respond to the incoming packet.
	localAddr := pkt.Network().DestinationAddress()
	packet.packetInfo.LocalAddr = localAddr
	packet.packetInfo.DestinationAddr = localAddr
	packet.packetInfo.NIC = pkt.NICID
	packet.receivedAt = e.stack.Clock().Now()

	e.rcvMu.Unlock()

	// Notify any waiters that there's data to be read now.
	if wasEmpty {
		e.waiterQueue.Notify(waiter.ReadableEvents)
	}
}
  1002  
  1003  func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt stack.PacketBufferPtr) {
  1004  	// Update last error first.
  1005  	e.lastErrorMu.Lock()
  1006  	e.lastError = err
  1007  	e.lastErrorMu.Unlock()
  1008  
  1009  	var recvErr bool
  1010  	switch pkt.NetworkProtocolNumber {
  1011  	case header.IPv4ProtocolNumber:
  1012  		recvErr = e.SocketOptions().GetIPv4RecvError()
  1013  	case header.IPv6ProtocolNumber:
  1014  		recvErr = e.SocketOptions().GetIPv6RecvError()
  1015  	default:
  1016  		panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber))
  1017  	}
  1018  
  1019  	if recvErr {
  1020  		// Linux passes the payload without the UDP header.
  1021  		payload := pkt.Data().AsRange().ToView()
  1022  		udp := header.UDP(payload.AsSlice())
  1023  		if len(udp) >= header.UDPMinimumSize {
  1024  			payload.TrimFront(header.UDPMinimumSize)
  1025  		}
  1026  
  1027  		id := e.net.Info().ID
  1028  		e.SocketOptions().QueueErr(&tcpip.SockError{
  1029  			Err:     err,
  1030  			Cause:   transErr,
  1031  			Payload: payload,
  1032  			Dst: tcpip.FullAddress{
  1033  				NIC:  pkt.NICID,
  1034  				Addr: id.RemoteAddress,
  1035  				Port: e.remotePort,
  1036  			},
  1037  			Offender: tcpip.FullAddress{
  1038  				NIC:  pkt.NICID,
  1039  				Addr: id.LocalAddress,
  1040  				Port: e.localPort,
  1041  			},
  1042  			NetProto: pkt.NetworkProtocolNumber,
  1043  		})
  1044  	}
  1045  
  1046  	// Notify of the error.
  1047  	e.waiterQueue.Notify(waiter.EventErr)
  1048  }
  1049  
  1050  // HandleError implements stack.TransportEndpoint.
  1051  func (e *endpoint) HandleError(transErr stack.TransportError, pkt stack.PacketBufferPtr) {
  1052  	// TODO(gvisor.dev/issues/5270): Handle all transport errors.
  1053  	switch transErr.Kind() {
  1054  	case stack.DestinationPortUnreachableTransportError:
  1055  		if e.net.State() == transport.DatagramEndpointStateConnected {
  1056  			e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
  1057  		}
  1058  	}
  1059  }
  1060  
  1061  // State implements tcpip.Endpoint.
  1062  func (e *endpoint) State() uint32 {
  1063  	return uint32(e.net.State())
  1064  }
  1065  
  1066  // Info returns a copy of the endpoint info.
  1067  func (e *endpoint) Info() tcpip.EndpointInfo {
  1068  	e.mu.RLock()
  1069  	defer e.mu.RUnlock()
  1070  	info := e.net.Info()
  1071  	info.ID.LocalPort = e.localPort
  1072  	info.ID.RemotePort = e.remotePort
  1073  	return &info
  1074  }
  1075  
  1076  // Stats returns a pointer to the endpoint stats.
  1077  func (e *endpoint) Stats() tcpip.EndpointStats {
  1078  	return &e.stats
  1079  }
  1080  
  1081  // Wait implements tcpip.Endpoint.
  1082  func (*endpoint) Wait() {}
  1083  
  1084  // SetOwner implements tcpip.Endpoint.
  1085  func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
  1086  	e.net.SetOwner(owner)
  1087  }
  1088  
  1089  // SocketOptions implements tcpip.Endpoint.
  1090  func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
  1091  	return &e.ops
  1092  }
  1093  
  1094  // freeze prevents any more packets from being delivered to the endpoint.
  1095  func (e *endpoint) freeze() {
  1096  	e.mu.Lock()
  1097  	e.frozen = true
  1098  	e.mu.Unlock()
  1099  }
  1100  
  1101  // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
  1102  // new packets to be delivered again.
  1103  func (e *endpoint) thaw() {
  1104  	e.mu.Lock()
  1105  	e.frozen = false
  1106  	e.mu.Unlock()
  1107  }