inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/transport/udp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"time"
    21  
    22  	"inet.af/netstack/sync"
    23  	"inet.af/netstack/tcpip"
    24  	"inet.af/netstack/tcpip/buffer"
    25  	"inet.af/netstack/tcpip/header"
    26  	"inet.af/netstack/tcpip/ports"
    27  	"inet.af/netstack/tcpip/stack"
    28  	"inet.af/netstack/tcpip/transport"
    29  	"inet.af/netstack/tcpip/transport/internal/network"
    30  	"inet.af/netstack/waiter"
    31  )
    32  
// udpPacket is one received datagram held on an endpoint's receive list.
// HandlePacket creates it and Read turns its fields into the read result and
// control messages: netProto selects which control messages apply,
// senderAddress is reported as the remote address, destinationAddress as the
// original destination address, packetInfo as the IP packet info and
// receivedAt as the timestamp. The size of data is what is charged against
// the endpoint's rcvBufSize.
//
// +stateify savable
type udpPacket struct {
	udpPacketEntry
	netProto           tcpip.NetworkProtocolNumber
	senderAddress      tcpip.FullAddress
	destinationAddress tcpip.FullAddress
	packetInfo         tcpip.IPPacketInfo
	data               buffer.VectorisedView `state:".(buffer.VectorisedView)"`
	receivedAt         time.Time             `state:".(int64)"`
	// tos stores the value HandlePacket reads from the network header: the
	// IPv4 TOS or the IPv6 traffic class, depending on netProto. Read reports
	// it as the receiveTOS or receiveTClass control message respectively.
	tos uint8
}
    45  
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue
	uniqueID    uint64
	net         network.Endpoint
	stats       tcpip.TransportEndpointStats
	ops         tcpip.SocketOptions

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu.
	//
	// rcvReady is set once the endpoint has been bound or connected, and
	// rcvClosed once it has been closed or shut down for reading; HandlePacket
	// drops packets unless rcvReady is set and rcvClosed is not. rcvBufSize is
	// the total payload size of the packets currently held in rcvList.
	rcvMu      sync.Mutex `state:"nosave"`
	rcvReady   bool
	rcvList    udpPacketList
	rcvBufSize int
	rcvClosed  bool

	// lastError is the error most recently stored by UpdateLastError and not
	// yet consumed by LastError. It is protected by lastErrorMu.
	lastErrorMu sync.Mutex `state:"nosave"`
	lastError   tcpip.Error

	// The following fields are protected by the mu mutex.
	mu        sync.RWMutex `state:"nosave"`
	portFlags ports.Flags

	// Values used to reserve a port or register a transport endpoint
	// (whichever happens first).
	boundBindToDevice tcpip.NICID
	boundPortFlags    ports.Flags

	// readShutdown is set by Close and by Shutdown with ShutdownRead.
	readShutdown bool

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber

	// frozen indicates if the packets should be delivered to the endpoint
	// during restore. While it is set, HandlePacket drops incoming packets.
	frozen bool

	// localPort and remotePort are the UDP ports of the endpoint. A zero
	// localPort means no port has been reserved yet; a zero remotePort is
	// treated by GetRemoteAddress as "not connected".
	localPort  uint16
	remotePort uint16
}
   103  
   104  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
   105  	e := &endpoint{
   106  		stack:       s,
   107  		waiterQueue: waiterQueue,
   108  		uniqueID:    s.UniqueID(),
   109  	}
   110  	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
   111  	e.ops.SetMulticastLoop(true)
   112  	e.ops.SetSendBufferSize(32*1024, false /* notify */)
   113  	e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
   114  	e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops)
   115  
   116  	// Override with stack defaults.
   117  	var ss tcpip.SendBufferSizeOption
   118  	if err := s.Option(&ss); err == nil {
   119  		e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   120  	}
   121  
   122  	var rs tcpip.ReceiveBufferSizeOption
   123  	if err := s.Option(&rs); err == nil {
   124  		e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   125  	}
   126  
   127  	return e
   128  }
   129  
// UniqueID implements stack.TransportEndpoint.
//
// It returns the identifier obtained from the stack's UniqueID when the
// endpoint was created.
func (e *endpoint) UniqueID() uint64 {
	return e.uniqueID
}
   134  
   135  func (e *endpoint) LastError() tcpip.Error {
   136  	e.lastErrorMu.Lock()
   137  	defer e.lastErrorMu.Unlock()
   138  
   139  	err := e.lastError
   140  	e.lastError = nil
   141  	return err
   142  }
   143  
   144  // UpdateLastError implements tcpip.SocketOptionsHandler.
   145  func (e *endpoint) UpdateLastError(err tcpip.Error) {
   146  	e.lastErrorMu.Lock()
   147  	e.lastError = err
   148  	e.lastErrorMu.Unlock()
   149  }
   150  
// Abort implements stack.TransportEndpoint.
//
// For a UDP endpoint aborting is the same as closing: it just calls Close.
func (e *endpoint) Abort() {
	e.Close()
}
   155  
   156  // Close puts the endpoint in a closed state and frees all resources
   157  // associated with it.
   158  func (e *endpoint) Close() {
   159  	e.mu.Lock()
   160  
   161  	switch state := e.net.State(); state {
   162  	case transport.DatagramEndpointStateInitial:
   163  	case transport.DatagramEndpointStateClosed:
   164  		e.mu.Unlock()
   165  		return
   166  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   167  		id := e.net.Info().ID
   168  		id.LocalPort = e.localPort
   169  		id.RemotePort = e.remotePort
   170  		e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice)
   171  		portRes := ports.Reservation{
   172  			Networks:     e.effectiveNetProtos,
   173  			Transport:    ProtocolNumber,
   174  			Addr:         id.LocalAddress,
   175  			Port:         id.LocalPort,
   176  			Flags:        e.boundPortFlags,
   177  			BindToDevice: e.boundBindToDevice,
   178  			Dest:         tcpip.FullAddress{},
   179  		}
   180  		e.stack.ReleasePort(portRes)
   181  		e.boundBindToDevice = 0
   182  		e.boundPortFlags = ports.Flags{}
   183  	default:
   184  		panic(fmt.Sprintf("unhandled state = %s", state))
   185  	}
   186  
   187  	// Close the receive list and drain it.
   188  	e.rcvMu.Lock()
   189  	e.rcvClosed = true
   190  	e.rcvBufSize = 0
   191  	for !e.rcvList.Empty() {
   192  		p := e.rcvList.Front()
   193  		e.rcvList.Remove(p)
   194  	}
   195  	e.rcvMu.Unlock()
   196  
   197  	e.net.Shutdown()
   198  	e.net.Close()
   199  	e.readShutdown = true
   200  	e.mu.Unlock()
   201  
   202  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
   203  }
   204  
// ModerateRecvBuf implements tcpip.Endpoint.
//
// It is a no-op for UDP endpoints; the argument is ignored.
func (*endpoint) ModerateRecvBuf(int) {}
   207  
   208  // Read implements tcpip.Endpoint.
   209  func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   210  	if err := e.LastError(); err != nil {
   211  		return tcpip.ReadResult{}, err
   212  	}
   213  
   214  	e.rcvMu.Lock()
   215  
   216  	if e.rcvList.Empty() {
   217  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   218  		if e.rcvClosed {
   219  			e.stats.ReadErrors.ReadClosed.Increment()
   220  			err = &tcpip.ErrClosedForReceive{}
   221  		}
   222  		e.rcvMu.Unlock()
   223  		return tcpip.ReadResult{}, err
   224  	}
   225  
   226  	p := e.rcvList.Front()
   227  	if !opts.Peek {
   228  		e.rcvList.Remove(p)
   229  		e.rcvBufSize -= p.data.Size()
   230  	}
   231  	e.rcvMu.Unlock()
   232  
   233  	// Control Messages
   234  	cm := tcpip.ControlMessages{
   235  		HasTimestamp: true,
   236  		Timestamp:    p.receivedAt,
   237  	}
   238  
   239  	switch p.netProto {
   240  	case header.IPv4ProtocolNumber:
   241  		if e.ops.GetReceiveTOS() {
   242  			cm.HasTOS = true
   243  			cm.TOS = p.tos
   244  		}
   245  
   246  		if e.ops.GetReceivePacketInfo() {
   247  			cm.HasIPPacketInfo = true
   248  			cm.PacketInfo = p.packetInfo
   249  		}
   250  	case header.IPv6ProtocolNumber:
   251  		if e.ops.GetReceiveTClass() {
   252  			cm.HasTClass = true
   253  			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
   254  			cm.TClass = uint32(p.tos)
   255  		}
   256  
   257  		if e.ops.GetIPv6ReceivePacketInfo() {
   258  			cm.HasIPv6PacketInfo = true
   259  			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
   260  				NIC:  p.packetInfo.NIC,
   261  				Addr: p.packetInfo.DestinationAddr,
   262  			}
   263  		}
   264  	default:
   265  		panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
   266  	}
   267  
   268  	if e.ops.GetReceiveOriginalDstAddress() {
   269  		cm.HasOriginalDstAddress = true
   270  		cm.OriginalDstAddress = p.destinationAddress
   271  	}
   272  
   273  	// Read Result
   274  	res := tcpip.ReadResult{
   275  		Total:           p.data.Size(),
   276  		ControlMessages: cm,
   277  	}
   278  	if opts.NeedRemoteAddr {
   279  		res.RemoteAddr = p.senderAddress
   280  	}
   281  
   282  	n, err := p.data.ReadTo(dst, opts.Peek)
   283  	if n == 0 && err != nil {
   284  		return res, &tcpip.ErrBadBuffer{}
   285  	}
   286  	res.Count = n
   287  	return res, nil
   288  }
   289  
// prepareForWriteInner prepares the endpoint for sending data. In particular,
// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu for reading. The lock is held for reading again
// when the function returns, but on the bind path it is released in between,
// so the endpoint state may have changed across the call.
//
// to is the explicit destination of the write, or nil if none was given; a
// bound (but not connected) endpoint requires it.
//
// Returns true for retry if preparation should be retried.
// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
	switch e.net.State() {
	case transport.DatagramEndpointStateInitial:
		// Not bound yet: fall out of the switch to the bind path below.
	case transport.DatagramEndpointStateConnected:
		return false, nil

	case transport.DatagramEndpointStateBound:
		if to == nil {
			return false, &tcpip.ErrDestinationRequired{}
		}
		return false, nil
	default:
		return false, &tcpip.ErrInvalidEndpointState{}
	}

	// Swap the read lock for the write lock. This is not atomic: other
	// goroutines may run between RUnlock and Lock. The deferred DowngradeLock
	// turns the write lock back into the read lock the caller expects.
	e.mu.RUnlock()
	e.mu.Lock()
	defer e.mu.DowngradeLock()

	// The state changed when we released the shared lock and re-acquired
	// it in exclusive mode. Try again.
	if e.net.State() != transport.DatagramEndpointStateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	// Bound successfully; ask the caller to re-run the checks above.
	return true, nil
}
   328  
   329  // Write writes data to the endpoint's peer. This method does not block
   330  // if the data cannot be written.
   331  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   332  	n, err := e.write(p, opts)
   333  	switch err.(type) {
   334  	case nil:
   335  		e.stats.PacketsSent.Increment()
   336  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   337  		e.stats.WriteErrors.InvalidArgs.Increment()
   338  	case *tcpip.ErrClosedForSend:
   339  		e.stats.WriteErrors.WriteClosed.Increment()
   340  	case *tcpip.ErrInvalidEndpointState:
   341  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   342  	case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   343  		// Errors indicating any problem with IP routing of the packet.
   344  		e.stats.SendErrors.NoRoute.Increment()
   345  	default:
   346  		// For all other errors when writing to the network layer.
   347  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   348  	}
   349  	return n, err
   350  }
   351  
   352  func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
   353  	e.mu.RLock()
   354  	defer e.mu.RUnlock()
   355  
   356  	// Prepare for write.
   357  	for {
   358  		retry, err := e.prepareForWriteInner(opts.To)
   359  		if err != nil {
   360  			return udpPacketInfo{}, err
   361  		}
   362  
   363  		if !retry {
   364  			break
   365  		}
   366  	}
   367  
   368  	dst, connected := e.net.GetRemoteAddress()
   369  	dst.Port = e.remotePort
   370  	if opts.To != nil {
   371  		if opts.To.Port == 0 {
   372  			// Port 0 is an invalid port to send to.
   373  			return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
   374  		}
   375  
   376  		dst = *opts.To
   377  	} else if !connected {
   378  		return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
   379  	}
   380  
   381  	ctx, err := e.net.AcquireContextForWrite(opts)
   382  	if err != nil {
   383  		return udpPacketInfo{}, err
   384  	}
   385  
   386  	// TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
   387  	v := make([]byte, p.Len())
   388  	if _, err := io.ReadFull(p, v); err != nil {
   389  		ctx.Release()
   390  		return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
   391  	}
   392  	if len(v) > header.UDPMaximumPacketSize {
   393  		// Payload can't possibly fit in a packet.
   394  		so := e.SocketOptions()
   395  		if so.GetRecvError() {
   396  			so.QueueLocalErr(
   397  				&tcpip.ErrMessageTooLong{},
   398  				e.net.NetProto(),
   399  				header.UDPMaximumPacketSize,
   400  				dst,
   401  				v,
   402  			)
   403  		}
   404  		ctx.Release()
   405  		return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
   406  	}
   407  
   408  	return udpPacketInfo{
   409  		ctx:        ctx,
   410  		data:       v,
   411  		localPort:  e.localPort,
   412  		remotePort: dst.Port,
   413  	}, nil
   414  }
   415  
   416  func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   417  	// Do not hold lock when sending as loopback is synchronous and if the UDP
   418  	// datagram ends up generating an ICMP response then it can result in a
   419  	// deadlock where the ICMP response handling ends up acquiring this endpoint's
   420  	// mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
   421  	// deadlock if another caller is trying to acquire e.mu in exclusive mode w/
   422  	// e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the
   423  	// lock can be eventually acquired.
   424  	//
   425  	// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
   426  	// locking is prohibited.
   427  
   428  	if err := e.LastError(); err != nil {
   429  		return 0, err
   430  	}
   431  
   432  	udpInfo, err := e.prepareForWrite(p, opts)
   433  	if err != nil {
   434  		return 0, err
   435  	}
   436  	defer udpInfo.ctx.Release()
   437  
   438  	pktInfo := udpInfo.ctx.PacketInfo()
   439  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   440  		ReserveHeaderBytes: header.UDPMinimumSize + int(pktInfo.MaxHeaderLength),
   441  		Data:               udpInfo.data.ToVectorisedView(),
   442  	})
   443  	defer pkt.DecRef()
   444  
   445  	// Initialize the UDP header.
   446  	udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
   447  	pkt.TransportProtocolNumber = ProtocolNumber
   448  
   449  	length := uint16(pkt.Size())
   450  	udp.Encode(&header.UDPFields{
   451  		SrcPort: udpInfo.localPort,
   452  		DstPort: udpInfo.remotePort,
   453  		Length:  length,
   454  	})
   455  
   456  	// Set the checksum field unless TX checksum offload is enabled.
   457  	// On IPv4, UDP checksum is optional, and a zero value indicates the
   458  	// transmitter skipped the checksum generation (RFC768).
   459  	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
   460  	if pktInfo.RequiresTXTransportChecksum &&
   461  		(!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
   462  		udp.SetChecksum(^udp.CalculateChecksum(header.ChecksumCombine(
   463  			header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
   464  			pkt.Data().AsRange().Checksum(),
   465  		)))
   466  	}
   467  	if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
   468  		e.stack.Stats().UDP.PacketSendErrors.Increment()
   469  		return 0, err
   470  	}
   471  
   472  	// Track count of packets sent.
   473  	e.stack.Stats().UDP.PacketsSent.Increment()
   474  	return int64(len(udpInfo.data)), nil
   475  }
   476  
   477  // OnReuseAddressSet implements tcpip.SocketOptionsHandler.
   478  func (e *endpoint) OnReuseAddressSet(v bool) {
   479  	e.mu.Lock()
   480  	e.portFlags.MostRecent = v
   481  	e.mu.Unlock()
   482  }
   483  
   484  // OnReusePortSet implements tcpip.SocketOptionsHandler.
   485  func (e *endpoint) OnReusePortSet(v bool) {
   486  	e.mu.Lock()
   487  	e.portFlags.LoadBalanced = v
   488  	e.mu.Unlock()
   489  }
   490  
// SetSockOptInt implements tcpip.Endpoint.
//
// UDP has no integer options of its own to set; the call is passed on to the
// underlying network endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
	return e.net.SetSockOptInt(opt, v)
}
   495  
// Compile-time check that *endpoint satisfies tcpip.SocketOptionsHandler.
var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
   497  
   498  // HasNIC implements tcpip.SocketOptionsHandler.
   499  func (e *endpoint) HasNIC(id int32) bool {
   500  	return e.stack.HasNIC(tcpip.NICID(id))
   501  }
   502  
// SetSockOpt implements tcpip.Endpoint.
//
// The option is passed on unchanged to the underlying network endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
	return e.net.SetSockOpt(opt)
}
   507  
   508  // GetSockOptInt implements tcpip.Endpoint.
   509  func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   510  	switch opt {
   511  	case tcpip.ReceiveQueueSizeOption:
   512  		v := 0
   513  		e.rcvMu.Lock()
   514  		if !e.rcvList.Empty() {
   515  			p := e.rcvList.Front()
   516  			v = p.data.Size()
   517  		}
   518  		e.rcvMu.Unlock()
   519  		return v, nil
   520  
   521  	default:
   522  		return e.net.GetSockOptInt(opt)
   523  	}
   524  }
   525  
// GetSockOpt implements tcpip.Endpoint.
//
// The option is passed on unchanged to the underlying network endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
	return e.net.GetSockOpt(opt)
}
   530  
// udpPacketInfo holds information needed to send a UDP packet.
//
// It is produced by prepareForWrite and consumed by write: ctx is the network
// write context (which write releases when it is done), data is the payload
// copied from the caller, and localPort/remotePort become the source and
// destination ports of the UDP header.
type udpPacketInfo struct {
	ctx        network.WriteContext
	data       buffer.View
	localPort  uint16
	remotePort uint16
}
   538  
   539  // Disconnect implements tcpip.Endpoint.
   540  func (e *endpoint) Disconnect() tcpip.Error {
   541  	e.mu.Lock()
   542  	defer e.mu.Unlock()
   543  
   544  	if e.net.State() != transport.DatagramEndpointStateConnected {
   545  		return nil
   546  	}
   547  	var (
   548  		id  stack.TransportEndpointID
   549  		btd tcpip.NICID
   550  	)
   551  
   552  	// We change this value below and we need the old value to unregister
   553  	// the endpoint.
   554  	boundPortFlags := e.boundPortFlags
   555  
   556  	// Exclude ephemerally bound endpoints.
   557  	info := e.net.Info()
   558  	info.ID.LocalPort = e.localPort
   559  	info.ID.RemotePort = e.remotePort
   560  	if e.net.WasBound() {
   561  		var err tcpip.Error
   562  		id = stack.TransportEndpointID{
   563  			LocalPort:    info.ID.LocalPort,
   564  			LocalAddress: info.ID.LocalAddress,
   565  		}
   566  		id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
   567  		if err != nil {
   568  			return err
   569  		}
   570  		boundPortFlags = e.boundPortFlags
   571  	} else {
   572  		if info.ID.LocalPort != 0 {
   573  			// Release the ephemeral port.
   574  			portRes := ports.Reservation{
   575  				Networks:     e.effectiveNetProtos,
   576  				Transport:    ProtocolNumber,
   577  				Addr:         info.ID.LocalAddress,
   578  				Port:         info.ID.LocalPort,
   579  				Flags:        boundPortFlags,
   580  				BindToDevice: e.boundBindToDevice,
   581  				Dest:         tcpip.FullAddress{},
   582  			}
   583  			e.stack.ReleasePort(portRes)
   584  			e.boundPortFlags = ports.Flags{}
   585  		}
   586  	}
   587  
   588  	e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
   589  	e.boundBindToDevice = btd
   590  	e.localPort = id.LocalPort
   591  	e.remotePort = id.RemotePort
   592  
   593  	e.net.Disconnect()
   594  
   595  	return nil
   596  }
   597  
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
//
// The network endpoint performs the connect and invokes the callback below
// with the previous and next endpoint IDs. The callback registers the
// endpoint under the next ID before removing the previous registration; if
// the new registration fails, the callback returns the error and the previous
// registration is left untouched. Once connected, the endpoint is marked
// ready to receive.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
		// The IDs handed in get their ports here: the local port is kept, the
		// remote port is the one being connected to.
		nextID.LocalPort = e.localPort
		nextID.RemotePort = addr.Port

		// Even if we're connected, this endpoint can still be used to send
		// packets on a different network protocol, so we register both
		// protocols when this is an ipv6 endpoint with v6only set to false.
		netProtos := []tcpip.NetworkProtocolNumber{netProto}
		if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() {
			netProtos = []tcpip.NetworkProtocolNumber{
				header.IPv4ProtocolNumber,
				header.IPv6ProtocolNumber,
			}
		}

		// registerWithStack overwrites e.boundPortFlags; keep the old value
		// for removing the old registration.
		oldPortFlags := e.boundPortFlags

		// nextID is the closure parameter being reassigned here (it may come
		// back with a newly reserved local port); btd and err are new.
		nextID, btd, err := e.registerWithStack(netProtos, nextID)
		if err != nil {
			return err
		}

		// Remove the old registration.
		if e.localPort != 0 {
			previousID.LocalPort = e.localPort
			previousID.RemotePort = e.remotePort
			e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
		}

		e.localPort = nextID.LocalPort
		e.remotePort = nextID.RemotePort
		e.boundBindToDevice = btd
		e.effectiveNetProtos = netProtos
		return nil
	})
	if err != nil {
		return err
	}

	// From now on HandlePacket may queue packets for this endpoint.
	e.rcvMu.Lock()
	e.rcvReady = true
	e.rcvMu.Unlock()
	return nil
}
   647  
// ConnectEndpoint is not supported. It always fails with
// ErrInvalidEndpointState.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
	return &tcpip.ErrInvalidEndpointState{}
}
   652  
   653  // Shutdown closes the read and/or write end of the endpoint connection
   654  // to its peer.
   655  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
   656  	e.mu.Lock()
   657  	defer e.mu.Unlock()
   658  
   659  	switch state := e.net.State(); state {
   660  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   661  		return &tcpip.ErrNotConnected{}
   662  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   663  	default:
   664  		panic(fmt.Sprintf("unhandled state = %s", state))
   665  	}
   666  
   667  	if flags&tcpip.ShutdownWrite != 0 {
   668  		if err := e.net.Shutdown(); err != nil {
   669  			return err
   670  		}
   671  	}
   672  
   673  	if flags&tcpip.ShutdownRead != 0 {
   674  		e.readShutdown = true
   675  
   676  		e.rcvMu.Lock()
   677  		wasClosed := e.rcvClosed
   678  		e.rcvClosed = true
   679  		e.rcvMu.Unlock()
   680  
   681  		if !wasClosed {
   682  			e.waiterQueue.Notify(waiter.ReadableEvents)
   683  		}
   684  	}
   685  
   686  	return nil
   687  }
   688  
// Listen is not supported by UDP, it just fails with ErrNotSupported. The
// backlog argument is ignored.
func (*endpoint) Listen(int) tcpip.Error {
	return &tcpip.ErrNotSupported{}
}
   693  
// Accept is not supported by UDP, it just fails with ErrNotSupported and
// returns neither an endpoint nor a waiter queue.
func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
	return nil, nil, &tcpip.ErrNotSupported{}
}
   698  
   699  func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
   700  	bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
   701  	if e.localPort == 0 {
   702  		portRes := ports.Reservation{
   703  			Networks:     netProtos,
   704  			Transport:    ProtocolNumber,
   705  			Addr:         id.LocalAddress,
   706  			Port:         id.LocalPort,
   707  			Flags:        e.portFlags,
   708  			BindToDevice: bindToDevice,
   709  			Dest:         tcpip.FullAddress{},
   710  		}
   711  		port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */)
   712  		if err != nil {
   713  			return id, bindToDevice, err
   714  		}
   715  		id.LocalPort = port
   716  	}
   717  	e.boundPortFlags = e.portFlags
   718  
   719  	err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
   720  	if err != nil {
   721  		portRes := ports.Reservation{
   722  			Networks:     netProtos,
   723  			Transport:    ProtocolNumber,
   724  			Addr:         id.LocalAddress,
   725  			Port:         id.LocalPort,
   726  			Flags:        e.boundPortFlags,
   727  			BindToDevice: bindToDevice,
   728  			Dest:         tcpip.FullAddress{},
   729  		}
   730  		e.stack.ReleasePort(portRes)
   731  		e.boundPortFlags = ports.Flags{}
   732  	}
   733  	return id, bindToDevice, err
   734  }
   735  
   736  func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
   737  	// Don't allow binding once endpoint is not in the initial state
   738  	// anymore.
   739  	if e.net.State() != transport.DatagramEndpointStateInitial {
   740  		return &tcpip.ErrInvalidEndpointState{}
   741  	}
   742  
   743  	err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
   744  		// Expand netProtos to include v4 and v6 if the caller is binding to a
   745  		// wildcard (empty) address, and this is an IPv6 endpoint with v6only
   746  		// set to false.
   747  		netProtos := []tcpip.NetworkProtocolNumber{boundNetProto}
   748  		if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == "" {
   749  			netProtos = []tcpip.NetworkProtocolNumber{
   750  				header.IPv6ProtocolNumber,
   751  				header.IPv4ProtocolNumber,
   752  			}
   753  		}
   754  
   755  		id := stack.TransportEndpointID{
   756  			LocalPort:    addr.Port,
   757  			LocalAddress: boundAddr,
   758  		}
   759  		id, btd, err := e.registerWithStack(netProtos, id)
   760  		if err != nil {
   761  			return err
   762  		}
   763  
   764  		e.localPort = id.LocalPort
   765  		e.boundBindToDevice = btd
   766  		e.effectiveNetProtos = netProtos
   767  		return nil
   768  	})
   769  	if err != nil {
   770  		return err
   771  	}
   772  
   773  	e.rcvMu.Lock()
   774  	e.rcvReady = true
   775  	e.rcvMu.Unlock()
   776  	return nil
   777  }
   778  
   779  // Bind binds the endpoint to a specific local address and port.
   780  // Specifying a NIC is optional.
   781  func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   782  	e.mu.Lock()
   783  	defer e.mu.Unlock()
   784  
   785  	err := e.bindLocked(addr)
   786  	if err != nil {
   787  		return err
   788  	}
   789  
   790  	return nil
   791  }
   792  
   793  // GetLocalAddress returns the address to which the endpoint is bound.
   794  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   795  	e.mu.RLock()
   796  	defer e.mu.RUnlock()
   797  
   798  	addr := e.net.GetLocalAddress()
   799  	addr.Port = e.localPort
   800  	return addr, nil
   801  }
   802  
   803  // GetRemoteAddress returns the address to which the endpoint is connected.
   804  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   805  	e.mu.RLock()
   806  	defer e.mu.RUnlock()
   807  
   808  	addr, connected := e.net.GetRemoteAddress()
   809  	if !connected || e.remotePort == 0 {
   810  		return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   811  	}
   812  
   813  	addr.Port = e.remotePort
   814  	return addr, nil
   815  }
   816  
   817  // Readiness returns the current readiness of the endpoint. For example, if
   818  // waiter.EventIn is set, the endpoint is immediately readable.
   819  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   820  	// The endpoint is always writable.
   821  	result := waiter.WritableEvents & mask
   822  
   823  	// Determine if the endpoint is readable if requested.
   824  	if mask&waiter.ReadableEvents != 0 {
   825  		e.rcvMu.Lock()
   826  		if !e.rcvList.Empty() || e.rcvClosed {
   827  			result |= waiter.ReadableEvents
   828  		}
   829  		e.rcvMu.Unlock()
   830  	}
   831  
   832  	e.lastErrorMu.Lock()
   833  	hasError := e.lastError != nil
   834  	e.lastErrorMu.Unlock()
   835  	if hasError {
   836  		result |= waiter.EventErr
   837  	}
   838  	return result
   839  }
   840  
   841  // verifyChecksum verifies the checksum unless RX checksum offload is enabled.
   842  func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
   843  	if pkt.RXTransportChecksumValidated {
   844  		return true
   845  	}
   846  
   847  	// On IPv4, UDP checksum is optional, and a zero value means the transmitter
   848  	// omitted the checksum generation, as per RFC 768:
   849  	//
   850  	//   An all zero transmitted checksum value means that the transmitter
   851  	//   generated  no checksum  (for debugging or for higher level protocols that
   852  	//   don't care).
   853  	//
   854  	// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
   855  	//
   856  	//   Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
   857  	//   checksum is not optional.
   858  	if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
   859  		return true
   860  	}
   861  
   862  	netHdr := pkt.Network()
   863  	payloadChecksum := pkt.Data().AsRange().Checksum()
   864  	return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
   865  }
   866  
   867  // HandlePacket is called by the stack when new packets arrive to this transport
   868  // endpoint.
   869  func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
   870  	// Get the header then trim it from the view.
   871  	hdr := header.UDP(pkt.TransportHeader().View())
   872  	if int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
   873  		// Malformed packet.
   874  		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
   875  		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
   876  		return
   877  	}
   878  
   879  	if !verifyChecksum(hdr, pkt) {
   880  		e.stack.Stats().UDP.ChecksumErrors.Increment()
   881  		e.stats.ReceiveErrors.ChecksumErrors.Increment()
   882  		return
   883  	}
   884  
   885  	e.stack.Stats().UDP.PacketsReceived.Increment()
   886  	e.stats.PacketsReceived.Increment()
   887  
   888  	e.rcvMu.Lock()
   889  	// Drop the packet if our buffer is currently full.
   890  	if !e.rcvReady || e.rcvClosed {
   891  		e.rcvMu.Unlock()
   892  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
   893  		e.stats.ReceiveErrors.ClosedReceiver.Increment()
   894  		return
   895  	}
   896  
   897  	rcvBufSize := e.ops.GetReceiveBufferSize()
   898  	if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
   899  		e.rcvMu.Unlock()
   900  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
   901  		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
   902  		return
   903  	}
   904  
   905  	wasEmpty := e.rcvBufSize == 0
   906  
   907  	// Push new packet into receive list and increment the buffer size.
   908  	packet := &udpPacket{
   909  		netProto: pkt.NetworkProtocolNumber,
   910  		senderAddress: tcpip.FullAddress{
   911  			NIC:  pkt.NICID,
   912  			Addr: id.RemoteAddress,
   913  			Port: hdr.SourcePort(),
   914  		},
   915  		destinationAddress: tcpip.FullAddress{
   916  			NIC:  pkt.NICID,
   917  			Addr: id.LocalAddress,
   918  			Port: hdr.DestinationPort(),
   919  		},
   920  		data: pkt.Data().ExtractVV(),
   921  	}
   922  	e.rcvList.PushBack(packet)
   923  	e.rcvBufSize += packet.data.Size()
   924  
   925  	// Save any useful information from the network header to the packet.
   926  	switch pkt.NetworkProtocolNumber {
   927  	case header.IPv4ProtocolNumber:
   928  		packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
   929  	case header.IPv6ProtocolNumber:
   930  		packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS()
   931  	}
   932  
   933  	// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
   934  	// address. packetInfo.LocalAddr should hold a unicast address that can be
   935  	// used to respond to the incoming packet.
   936  	localAddr := pkt.Network().DestinationAddress()
   937  	packet.packetInfo.LocalAddr = localAddr
   938  	packet.packetInfo.DestinationAddr = localAddr
   939  	packet.packetInfo.NIC = pkt.NICID
   940  	packet.receivedAt = e.stack.Clock().Now()
   941  
   942  	e.rcvMu.Unlock()
   943  
   944  	// Notify any waiters that there's data to be read now.
   945  	if wasEmpty {
   946  		e.waiterQueue.Notify(waiter.ReadableEvents)
   947  	}
   948  }
   949  
   950  func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) {
   951  	// Update last error first.
   952  	e.lastErrorMu.Lock()
   953  	e.lastError = err
   954  	e.lastErrorMu.Unlock()
   955  
   956  	// Update the error queue if IP_RECVERR is enabled.
   957  	if e.SocketOptions().GetRecvError() {
   958  		// Linux passes the payload without the UDP header.
   959  		var payload []byte
   960  		udp := header.UDP(pkt.Data().AsRange().ToOwnedView())
   961  		if len(udp) >= header.UDPMinimumSize {
   962  			payload = udp.Payload()
   963  		}
   964  
   965  		id := e.net.Info().ID
   966  		e.SocketOptions().QueueErr(&tcpip.SockError{
   967  			Err:     err,
   968  			Cause:   transErr,
   969  			Payload: payload,
   970  			Dst: tcpip.FullAddress{
   971  				NIC:  pkt.NICID,
   972  				Addr: id.RemoteAddress,
   973  				Port: e.remotePort,
   974  			},
   975  			Offender: tcpip.FullAddress{
   976  				NIC:  pkt.NICID,
   977  				Addr: id.LocalAddress,
   978  				Port: e.localPort,
   979  			},
   980  			NetProto: pkt.NetworkProtocolNumber,
   981  		})
   982  	}
   983  
   984  	// Notify of the error.
   985  	e.waiterQueue.Notify(waiter.EventErr)
   986  }
   987  
   988  // HandleError implements stack.TransportEndpoint.
   989  func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
   990  	// TODO(gvisor.dev/issues/5270): Handle all transport errors.
   991  	switch transErr.Kind() {
   992  	case stack.DestinationPortUnreachableTransportError:
   993  		if e.net.State() == transport.DatagramEndpointStateConnected {
   994  			e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
   995  		}
   996  	}
   997  }
   998  
   999  // State implements tcpip.Endpoint.
  1000  func (e *endpoint) State() uint32 {
  1001  	return uint32(e.net.State())
  1002  }
  1003  
  1004  // Info returns a copy of the endpoint info.
  1005  func (e *endpoint) Info() tcpip.EndpointInfo {
  1006  	e.mu.RLock()
  1007  	defer e.mu.RUnlock()
  1008  	info := e.net.Info()
  1009  	info.ID.LocalPort = e.localPort
  1010  	info.ID.RemotePort = e.remotePort
  1011  	return &info
  1012  }
  1013  
  1014  // Stats returns a pointer to the endpoint stats.
  1015  func (e *endpoint) Stats() tcpip.EndpointStats {
  1016  	return &e.stats
  1017  }
  1018  
  1019  // Wait implements tcpip.Endpoint.
  1020  func (*endpoint) Wait() {}
  1021  
  1022  // SetOwner implements tcpip.Endpoint.
  1023  func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
  1024  	e.net.SetOwner(owner)
  1025  }
  1026  
  1027  // SocketOptions implements tcpip.Endpoint.
  1028  func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
  1029  	return &e.ops
  1030  }
  1031  
  1032  // freeze prevents any more packets from being delivered to the endpoint.
  1033  func (e *endpoint) freeze() {
  1034  	e.mu.Lock()
  1035  	e.frozen = true
  1036  	e.mu.Unlock()
  1037  }
  1038  
  1039  // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
  1040  // new packets to be delivered again.
  1041  func (e *endpoint) thaw() {
  1042  	e.mu.Lock()
  1043  	e.frozen = false
  1044  	e.mu.Unlock()
  1045  }