github.com/polevpn/netstack@v1.10.9/tcpip/transport/udp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"sync"
    19  
    20  	"github.com/polevpn/netstack/tcpip"
    21  	"github.com/polevpn/netstack/tcpip/buffer"
    22  	"github.com/polevpn/netstack/tcpip/header"
    23  	"github.com/polevpn/netstack/tcpip/iptables"
    24  	"github.com/polevpn/netstack/tcpip/stack"
    25  	"github.com/polevpn/netstack/waiter"
    26  )
    27  
// udpPacket is one received datagram held on an endpoint's receive list
// (endpoint.rcvList) until it is handed to a reader.
//
// Read copies senderAddress into the caller-supplied address, converts data
// into a single view for the caller, subtracts data.Size() from the
// endpoint's rcvBufSize, and reports timestamp through
// tcpip.ControlMessages.Timestamp.
//
// +stateify savable
type udpPacket struct {
	udpPacketEntry
	senderAddress tcpip.FullAddress
	data          buffer.VectorisedView
	timestamp     int64
}
    35  
    36  // EndpointState represents the state of a UDP endpoint.
    37  type EndpointState uint32
    38  
    39  // Endpoint states. Note that are represented in a netstack-specific manner and
    40  // may not be meaningful externally. Specifically, they need to be translated to
    41  // Linux's representation for these states if presented to userspace.
    42  const (
    43  	StateInitial EndpointState = iota
    44  	StateBound
    45  	StateConnected
    46  	StateClosed
    47  )
    48  
    49  // String implements fmt.Stringer.String.
    50  func (s EndpointState) String() string {
    51  	switch s {
    52  	case StateInitial:
    53  		return "INITIAL"
    54  	case StateBound:
    55  		return "BOUND"
    56  	case StateConnected:
    57  		return "CONNECTING"
    58  	case StateClosed:
    59  		return "CLOSED"
    60  	default:
    61  		return "UNKNOWN"
    62  	}
    63  }
    64  
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// Two mutexes are used: mu guards the endpoint's configuration and state, and
// rcvMu guards the receive queue. Where both are held in this file (Close,
// Connect, Shutdown), mu is acquired first and rcvMu second.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	stack.TransportEndpointInfo

	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack
	waiterQueue *waiter.Queue
	uniqueID    uint64

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu. rcvReady is set once the endpoint is connected;
	// rcvClosed is set by Close and by Shutdown with ShutdownRead; rcvBufSize
	// is decreased by the payload size of each datagram handed out by Read.
	rcvMu         sync.Mutex
	rcvReady      bool
	rcvList       udpPacketList
	rcvBufSizeMax int
	rcvBufSize    int
	rcvClosed     bool

	// The following fields are protected by the mu mutex. A ttl of zero
	// makes the send path fall back to the route's default TTL, except for
	// multicast destinations, which always use multicastTTL.
	mu             sync.RWMutex
	sndBufSize     int
	state          EndpointState
	route          stack.Route
	dstPort        uint16
	v6only         bool
	ttl            uint8
	multicastTTL   uint8
	multicastAddr  tcpip.Address
	multicastNICID tcpip.NICID
	multicastLoop  bool
	reusePort      bool
	bindToDevice   tcpip.NICID
	broadcast      bool

	// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
	// applied while sending packets. Defaults to 0 as on Linux.
	sendTOS uint8

	// shutdownFlags represent the current shutdown state of the endpoint.
	shutdownFlags tcpip.ShutdownFlags

	// multicastMemberships lists the groups that need to be left when the
	// endpoint is closed. Protected by the mu mutex.
	multicastMemberships []multicastMembership

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber

	// TODO(b/142022063): Add ability to save and restore per endpoint stats.
	stats tcpip.TransportEndpointStats
}
   129  
// multicastMembership identifies one multicast group joined by an endpoint:
// the NIC the group was joined on and the group address. Values are compared
// with == when memberships are added or removed via SetSockOpt.
//
// +stateify savable
type multicastMembership struct {
	nicID         tcpip.NICID
	multicastAddr tcpip.Address
}
   135  
   136  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
   137  	return &endpoint{
   138  		stack: s,
   139  		TransportEndpointInfo: stack.TransportEndpointInfo{
   140  			NetProto:   netProto,
   141  			TransProto: header.UDPProtocolNumber,
   142  		},
   143  		waiterQueue: waiterQueue,
   144  		// RFC 1075 section 5.4 recommends a TTL of 1 for membership
   145  		// requests.
   146  		//
   147  		// RFC 5135 4.2.1 appears to assume that IGMP messages have a
   148  		// TTL of 1.
   149  		//
   150  		// RFC 5135 Appendix A defines TTL=1: A multicast source that
   151  		// wants its traffic to not traverse a router (e.g., leave a
   152  		// home network) may find it useful to send traffic with IP
   153  		// TTL=1.
   154  		//
   155  		// Linux defaults to TTL=1.
   156  		multicastTTL:  1,
   157  		multicastLoop: true,
   158  		rcvBufSizeMax: 32 * 1024,
   159  		sndBufSize:    32 * 1024,
   160  		state:         StateInitial,
   161  		uniqueID:      s.UniqueID(),
   162  	}
   163  }
   164  
// UniqueID implements stack.TransportEndpoint.UniqueID.
//
// It returns the identifier obtained from the stack when the endpoint was
// created; the value is never modified afterwards, so no locking is needed.
func (e *endpoint) UniqueID() uint64 {
	return e.uniqueID
}
   169  
   170  // Close puts the endpoint in a closed state and frees all resources
   171  // associated with it.
   172  func (e *endpoint) Close() {
   173  	e.mu.Lock()
   174  	e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
   175  
   176  	switch e.state {
   177  	case StateBound, StateConnected:
   178  		e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
   179  		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
   180  	}
   181  
   182  	for _, mem := range e.multicastMemberships {
   183  		e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
   184  	}
   185  	e.multicastMemberships = nil
   186  
   187  	// Close the receive list and drain it.
   188  	e.rcvMu.Lock()
   189  	e.rcvClosed = true
   190  	e.rcvBufSize = 0
   191  	for !e.rcvList.Empty() {
   192  		p := e.rcvList.Front()
   193  		e.rcvList.Remove(p)
   194  	}
   195  	e.rcvMu.Unlock()
   196  
   197  	e.route.Release()
   198  
   199  	// Update the state.
   200  	e.state = StateClosed
   201  
   202  	e.mu.Unlock()
   203  
   204  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
   205  }
   206  
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
//
// It is a no-op for UDP endpoints: the copied byte count is ignored.
func (e *endpoint) ModerateRecvBuf(copied int) {}
   209  
// IPTables implements tcpip.Endpoint.IPTables.
//
// It returns the IPTables of the stack this endpoint belongs to; the error
// result is always nil.
func (e *endpoint) IPTables() (iptables.IPTables, error) {
	return e.stack.IPTables(), nil
}
   214  
   215  // Read reads data from the endpoint. This method does not block if
   216  // there is no data pending.
   217  func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
   218  	e.rcvMu.Lock()
   219  
   220  	if e.rcvList.Empty() {
   221  		err := tcpip.ErrWouldBlock
   222  		if e.rcvClosed {
   223  			e.stats.ReadErrors.ReadClosed.Increment()
   224  			err = tcpip.ErrClosedForReceive
   225  		}
   226  		e.rcvMu.Unlock()
   227  		return buffer.View{}, tcpip.ControlMessages{}, err
   228  	}
   229  
   230  	p := e.rcvList.Front()
   231  	e.rcvList.Remove(p)
   232  	e.rcvBufSize -= p.data.Size()
   233  	e.rcvMu.Unlock()
   234  
   235  	if addr != nil {
   236  		*addr = p.senderAddress
   237  	}
   238  
   239  	return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
   240  }
   241  
// prepareForWrite prepares the endpoint for sending data. In particular, it
// binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu in shared (read) mode; the lock is held in shared
// mode again when this function returns, although it may have been released
// in between.
//
// Per state:
//   - StateConnected: nothing to do.
//   - StateBound: a destination is required, so a nil "to" yields
//     tcpip.ErrDestinationRequired.
//   - StateInitial: the endpoint is bound with an empty address, then the
//     caller is asked to retry.
//   - anything else: tcpip.ErrInvalidEndpointState.
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
	switch e.state {
	case StateInitial:
		// Fall out of the switch to the binding logic below.
	case StateConnected:
		return false, nil

	case StateBound:
		if to == nil {
			return false, tcpip.ErrDestinationRequired
		}
		return false, nil
	default:
		return false, tcpip.ErrInvalidEndpointState
	}

	// Swap the shared lock for the exclusive one. Deferred calls run in
	// reverse order, so on return the exclusive lock is released first and
	// the shared lock is then re-acquired for the caller.
	e.mu.RUnlock()
	defer e.mu.RLock()

	e.mu.Lock()
	defer e.mu.Unlock()

	// The state changed when we released the shared lock and re-acquired
	// it in exclusive mode. Try again.
	if e.state != StateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	// Ask the caller to re-evaluate the (now changed) state.
	return true, nil
}
   281  
   282  // connectRoute establishes a route to the specified interface or the
   283  // configured multicast interface if no interface is specified and the
   284  // specified address is a multicast address.
   285  func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
   286  	localAddr := e.ID.LocalAddress
   287  	if isBroadcastOrMulticast(localAddr) {
   288  		// A packet can only originate from a unicast address (i.e., an interface).
   289  		localAddr = ""
   290  	}
   291  
   292  	if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
   293  		if nicID == 0 {
   294  			nicID = e.multicastNICID
   295  		}
   296  		if localAddr == "" && nicID == 0 {
   297  			localAddr = e.multicastAddr
   298  		}
   299  	}
   300  
   301  	// Find a route to the desired destination.
   302  	r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
   303  	if err != nil {
   304  		return stack.Route{}, 0, err
   305  	}
   306  	return r, nicID, nil
   307  }
   308  
   309  // Write writes data to the endpoint's peer. This method does not block
   310  // if the data cannot be written.
   311  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
   312  	n, ch, err := e.write(p, opts)
   313  	switch err {
   314  	case nil:
   315  		e.stats.PacketsSent.Increment()
   316  	case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
   317  		e.stats.WriteErrors.InvalidArgs.Increment()
   318  	case tcpip.ErrClosedForSend:
   319  		e.stats.WriteErrors.WriteClosed.Increment()
   320  	case tcpip.ErrInvalidEndpointState:
   321  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   322  	case tcpip.ErrNoLinkAddress:
   323  		e.stats.SendErrors.NoLinkAddr.Increment()
   324  	case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
   325  		// Errors indicating any problem with IP routing of the packet.
   326  		e.stats.SendErrors.NoRoute.Increment()
   327  	default:
   328  		// For all other errors when writing to the network layer.
   329  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   330  	}
   331  	return n, ch, err
   332  }
   333  
// write implements the body of Write without the statistics bookkeeping.
//
// It sends the full payload of p as one datagram to opts.To, or to the
// connected peer when opts.To is nil, and returns the number of payload bytes
// sent. When link address resolution is still pending it returns
// tcpip.ErrNoLinkAddress together with the channel returned by Route.Resolve.
//
// e.mu is held in shared mode for the duration of the call, and is
// temporarily swapped for the exclusive lock when the endpoint has to be
// bound (see prepareForWrite) or when the connected endpoint's own route
// needs resolving.
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
	// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
	if opts.More {
		return 0, nil, tcpip.ErrInvalidOptionValue
	}

	to := opts.To

	e.mu.RLock()
	defer e.mu.RUnlock()

	// If we've shutdown with SHUT_WR we are in an invalid state for sending.
	if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
		return 0, nil, tcpip.ErrClosedForSend
	}

	// Prepare for write. prepareForWrite asks for a retry whenever it had to
	// drop the shared lock, because the state may have changed meanwhile.
	for {
		retry, err := e.prepareForWrite(to)
		if err != nil {
			return 0, nil, err
		}

		if !retry {
			break
		}
	}

	var route *stack.Route
	var dstPort uint16
	if to == nil {
		// No explicit destination: use the connected peer's route and port.
		route = &e.route
		dstPort = e.dstPort

		if route.IsResolutionRequired() {
			// Promote lock to exclusive if using a shared route, given that it may need to
			// change in Route.Resolve() call below.
			//
			// The deferred calls unwind in reverse order: the exclusive lock
			// is released, the shared lock is re-taken, and finally the
			// deferred RUnlock at the top of the function releases it.
			e.mu.RUnlock()
			defer e.mu.RLock()

			e.mu.Lock()
			defer e.mu.Unlock()

			// Recheck state after lock was re-acquired.
			if e.state != StateConnected {
				return 0, nil, tcpip.ErrInvalidEndpointState
			}
		}
	} else {
		// Reject destination address if it goes through a different
		// NIC than the endpoint was bound to.
		nicID := to.NIC
		if e.BindNICID != 0 {
			if nicID != 0 && nicID != e.BindNICID {
				return 0, nil, tcpip.ErrNoRoute
			}

			nicID = e.BindNICID
		}

		// Sending to the IPv4 broadcast address requires the broadcast
		// option to have been enabled on the endpoint.
		if to.Addr == header.IPv4Broadcast && !e.broadcast {
			return 0, nil, tcpip.ErrBroadcastDisabled
		}

		netProto, err := e.checkV4Mapped(to, false)
		if err != nil {
			return 0, nil, err
		}

		// This route is private to the call and released on return.
		r, _, err := e.connectRoute(nicID, *to, netProto)
		if err != nil {
			return 0, nil, err
		}
		defer r.Release()

		route = &r
		dstPort = to.Port
	}

	if route.IsResolutionRequired() {
		if ch, err := route.Resolve(nil); err != nil {
			if err == tcpip.ErrWouldBlock {
				// Resolution is in progress; hand the channel to the caller.
				return 0, ch, tcpip.ErrNoLinkAddress
			}
			return 0, nil, err
		}
	}

	v, err := p.FullPayload()
	if err != nil {
		return 0, nil, err
	}
	if len(v) > header.UDPMaximumPacketSize {
		// Payload can't possibly fit in a packet.
		return 0, nil, tcpip.ErrMessageTooLong
	}

	// A zero unicast TTL means "use the route's default".
	ttl := e.ttl
	useDefaultTTL := ttl == 0

	if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
		ttl = e.multicastTTL
		// Multicast allows a 0 TTL.
		useDefaultTTL = false
	}

	if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
		return 0, nil, err
	}
	return int64(len(v)), nil, nil
}
   445  
// Peek only returns data from a single datagram, so do nothing here.
//
// The argument is ignored; the result is always zero bytes, empty control
// messages and a nil error.
func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
	return 0, tcpip.ControlMessages{}, nil
}
   450  
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
//
// No integer options are acted upon by this endpoint: every opt/v pair is
// accepted, ignored, and reported as success (nil).
func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
	return nil
}
   455  
   456  // SetSockOpt implements tcpip.Endpoint.SetSockOpt.
   457  func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
   458  	switch v := opt.(type) {
   459  	case tcpip.V6OnlyOption:
   460  		// We only recognize this option on v6 endpoints.
   461  		if e.NetProto != header.IPv6ProtocolNumber {
   462  			return tcpip.ErrInvalidEndpointState
   463  		}
   464  
   465  		e.mu.Lock()
   466  		defer e.mu.Unlock()
   467  
   468  		// We only allow this to be set when we're in the initial state.
   469  		if e.state != StateInitial {
   470  			return tcpip.ErrInvalidEndpointState
   471  		}
   472  
   473  		e.v6only = v != 0
   474  
   475  	case tcpip.TTLOption:
   476  		e.mu.Lock()
   477  		e.ttl = uint8(v)
   478  		e.mu.Unlock()
   479  
   480  	case tcpip.MulticastTTLOption:
   481  		e.mu.Lock()
   482  		e.multicastTTL = uint8(v)
   483  		e.mu.Unlock()
   484  
   485  	case tcpip.MulticastInterfaceOption:
   486  		e.mu.Lock()
   487  		defer e.mu.Unlock()
   488  
   489  		fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
   490  		netProto, err := e.checkV4Mapped(&fa, false)
   491  		if err != nil {
   492  			return err
   493  		}
   494  		nic := v.NIC
   495  		addr := fa.Addr
   496  
   497  		if nic == 0 && addr == "" {
   498  			e.multicastAddr = ""
   499  			e.multicastNICID = 0
   500  			break
   501  		}
   502  
   503  		if nic != 0 {
   504  			if !e.stack.CheckNIC(nic) {
   505  				return tcpip.ErrBadLocalAddress
   506  			}
   507  		} else {
   508  			nic = e.stack.CheckLocalAddress(0, netProto, addr)
   509  			if nic == 0 {
   510  				return tcpip.ErrBadLocalAddress
   511  			}
   512  		}
   513  
   514  		if e.BindNICID != 0 && e.BindNICID != nic {
   515  			return tcpip.ErrInvalidEndpointState
   516  		}
   517  
   518  		e.multicastNICID = nic
   519  		e.multicastAddr = addr
   520  
   521  	case tcpip.AddMembershipOption:
   522  		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
   523  			return tcpip.ErrInvalidOptionValue
   524  		}
   525  
   526  		nicID := v.NIC
   527  
   528  		// The interface address is considered not-set if it is empty or contains
   529  		// all-zeros. The former represent the zero-value in golang, the latter the
   530  		// same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
   531  		allZeros := header.IPv4Any
   532  		if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
   533  			if nicID == 0 {
   534  				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
   535  				if err == nil {
   536  					nicID = r.NICID()
   537  					r.Release()
   538  				}
   539  			}
   540  		} else {
   541  			nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
   542  		}
   543  		if nicID == 0 {
   544  			return tcpip.ErrUnknownDevice
   545  		}
   546  
   547  		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   548  
   549  		e.mu.Lock()
   550  		defer e.mu.Unlock()
   551  
   552  		for _, mem := range e.multicastMemberships {
   553  			if mem == memToInsert {
   554  				return tcpip.ErrPortInUse
   555  			}
   556  		}
   557  
   558  		if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
   559  			return err
   560  		}
   561  
   562  		e.multicastMemberships = append(e.multicastMemberships, memToInsert)
   563  
   564  	case tcpip.RemoveMembershipOption:
   565  		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
   566  			return tcpip.ErrInvalidOptionValue
   567  		}
   568  
   569  		nicID := v.NIC
   570  		if v.InterfaceAddr == header.IPv4Any {
   571  			if nicID == 0 {
   572  				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
   573  				if err == nil {
   574  					nicID = r.NICID()
   575  					r.Release()
   576  				}
   577  			}
   578  		} else {
   579  			nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
   580  		}
   581  		if nicID == 0 {
   582  			return tcpip.ErrUnknownDevice
   583  		}
   584  
   585  		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   586  		memToRemoveIndex := -1
   587  
   588  		e.mu.Lock()
   589  		defer e.mu.Unlock()
   590  
   591  		for i, mem := range e.multicastMemberships {
   592  			if mem == memToRemove {
   593  				memToRemoveIndex = i
   594  				break
   595  			}
   596  		}
   597  		if memToRemoveIndex == -1 {
   598  			return tcpip.ErrBadLocalAddress
   599  		}
   600  
   601  		if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
   602  			return err
   603  		}
   604  
   605  		e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
   606  		e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
   607  
   608  	case tcpip.MulticastLoopOption:
   609  		e.mu.Lock()
   610  		e.multicastLoop = bool(v)
   611  		e.mu.Unlock()
   612  
   613  	case tcpip.ReusePortOption:
   614  		e.mu.Lock()
   615  		e.reusePort = v != 0
   616  		e.mu.Unlock()
   617  
   618  	case tcpip.BindToDeviceOption:
   619  		e.mu.Lock()
   620  		defer e.mu.Unlock()
   621  		if v == "" {
   622  			e.bindToDevice = 0
   623  			return nil
   624  		}
   625  		for nicID, nic := range e.stack.NICInfo() {
   626  			if nic.Name == string(v) {
   627  				e.bindToDevice = nicID
   628  				return nil
   629  			}
   630  		}
   631  		return tcpip.ErrUnknownDevice
   632  
   633  	case tcpip.BroadcastOption:
   634  		e.mu.Lock()
   635  		e.broadcast = v != 0
   636  		e.mu.Unlock()
   637  
   638  		return nil
   639  
   640  	case tcpip.IPv4TOSOption:
   641  		e.mu.Lock()
   642  		e.sendTOS = uint8(v)
   643  		e.mu.Unlock()
   644  		return nil
   645  
   646  	case tcpip.IPv6TrafficClassOption:
   647  		e.mu.Lock()
   648  		e.sendTOS = uint8(v)
   649  		e.mu.Unlock()
   650  		return nil
   651  	}
   652  	return nil
   653  }
   654  
   655  // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
   656  func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
   657  	switch opt {
   658  	case tcpip.ReceiveQueueSizeOption:
   659  		v := 0
   660  		e.rcvMu.Lock()
   661  		if !e.rcvList.Empty() {
   662  			p := e.rcvList.Front()
   663  			v = p.data.Size()
   664  		}
   665  		e.rcvMu.Unlock()
   666  		return v, nil
   667  
   668  	case tcpip.SendBufferSizeOption:
   669  		e.mu.Lock()
   670  		v := e.sndBufSize
   671  		e.mu.Unlock()
   672  		return v, nil
   673  
   674  	case tcpip.ReceiveBufferSizeOption:
   675  		e.rcvMu.Lock()
   676  		v := e.rcvBufSizeMax
   677  		e.rcvMu.Unlock()
   678  		return v, nil
   679  	}
   680  
   681  	return -1, tcpip.ErrUnknownProtocolOption
   682  }
   683  
   684  // GetSockOpt implements tcpip.Endpoint.GetSockOpt.
   685  func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
   686  	switch o := opt.(type) {
   687  	case tcpip.ErrorOption:
   688  		return nil
   689  
   690  	case *tcpip.V6OnlyOption:
   691  		// We only recognize this option on v6 endpoints.
   692  		if e.NetProto != header.IPv6ProtocolNumber {
   693  			return tcpip.ErrUnknownProtocolOption
   694  		}
   695  
   696  		e.mu.Lock()
   697  		v := e.v6only
   698  		e.mu.Unlock()
   699  
   700  		*o = 0
   701  		if v {
   702  			*o = 1
   703  		}
   704  		return nil
   705  
   706  	case *tcpip.TTLOption:
   707  		e.mu.Lock()
   708  		*o = tcpip.TTLOption(e.ttl)
   709  		e.mu.Unlock()
   710  		return nil
   711  
   712  	case *tcpip.MulticastTTLOption:
   713  		e.mu.Lock()
   714  		*o = tcpip.MulticastTTLOption(e.multicastTTL)
   715  		e.mu.Unlock()
   716  		return nil
   717  
   718  	case *tcpip.MulticastInterfaceOption:
   719  		e.mu.Lock()
   720  		*o = tcpip.MulticastInterfaceOption{
   721  			e.multicastNICID,
   722  			e.multicastAddr,
   723  		}
   724  		e.mu.Unlock()
   725  		return nil
   726  
   727  	case *tcpip.MulticastLoopOption:
   728  		e.mu.RLock()
   729  		v := e.multicastLoop
   730  		e.mu.RUnlock()
   731  
   732  		*o = tcpip.MulticastLoopOption(v)
   733  		return nil
   734  
   735  	case *tcpip.ReuseAddressOption:
   736  		*o = 0
   737  		return nil
   738  
   739  	case *tcpip.ReusePortOption:
   740  		e.mu.RLock()
   741  		v := e.reusePort
   742  		e.mu.RUnlock()
   743  
   744  		*o = 0
   745  		if v {
   746  			*o = 1
   747  		}
   748  		return nil
   749  
   750  	case *tcpip.BindToDeviceOption:
   751  		e.mu.RLock()
   752  		defer e.mu.RUnlock()
   753  		if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
   754  			*o = tcpip.BindToDeviceOption(nic.Name)
   755  			return nil
   756  		}
   757  		*o = tcpip.BindToDeviceOption("")
   758  		return nil
   759  
   760  	case *tcpip.KeepaliveEnabledOption:
   761  		*o = 0
   762  		return nil
   763  
   764  	case *tcpip.BroadcastOption:
   765  		e.mu.RLock()
   766  		v := e.broadcast
   767  		e.mu.RUnlock()
   768  
   769  		*o = 0
   770  		if v {
   771  			*o = 1
   772  		}
   773  		return nil
   774  
   775  	case *tcpip.IPv4TOSOption:
   776  		e.mu.RLock()
   777  		*o = tcpip.IPv4TOSOption(e.sendTOS)
   778  		e.mu.RUnlock()
   779  		return nil
   780  
   781  	case *tcpip.IPv6TrafficClassOption:
   782  		e.mu.RLock()
   783  		*o = tcpip.IPv6TrafficClassOption(e.sendTOS)
   784  		e.mu.RUnlock()
   785  		return nil
   786  
   787  	default:
   788  		return tcpip.ErrUnknownProtocolOption
   789  	}
   790  }
   791  
   792  // sendUDP sends a UDP segment via the provided network endpoint and under the
   793  // provided identity.
   794  func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
   795  	// Allocate a buffer for the UDP header.
   796  	hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
   797  
   798  	// Initialize the header.
   799  	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
   800  
   801  	length := uint16(hdr.UsedLength() + data.Size())
   802  	udp.Encode(&header.UDPFields{
   803  		SrcPort: localPort,
   804  		DstPort: remotePort,
   805  		Length:  length,
   806  	})
   807  
   808  	// Only calculate the checksum if offloading isn't supported.
   809  	if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
   810  		xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
   811  		for _, v := range data.Views() {
   812  			xsum = header.Checksum(v, xsum)
   813  		}
   814  		udp.SetChecksum(^udp.CalculateChecksum(xsum))
   815  	}
   816  
   817  	if useDefaultTTL {
   818  		ttl = r.DefaultTTL()
   819  	}
   820  	if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{
   821  		Header: hdr,
   822  		Data:   data,
   823  	}); err != nil {
   824  		r.Stats().UDP.PacketSendErrors.Increment()
   825  		return err
   826  	}
   827  
   828  	// Track count of packets sent.
   829  	r.Stats().UDP.PacketsSent.Increment()
   830  	return nil
   831  }
   832  
   833  func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
   834  	netProto := e.NetProto
   835  	if len(addr.Addr) == 0 {
   836  		return netProto, nil
   837  	}
   838  	if header.IsV4MappedAddress(addr.Addr) {
   839  		// Fail if using a v4 mapped address on a v6only endpoint.
   840  		if e.v6only {
   841  			return 0, tcpip.ErrNoRoute
   842  		}
   843  
   844  		netProto = header.IPv4ProtocolNumber
   845  		addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
   846  		if addr.Addr == header.IPv4Any {
   847  			addr.Addr = ""
   848  		}
   849  
   850  		// Fail if we are bound to an IPv6 address.
   851  		if !allowMismatch && len(e.ID.LocalAddress) == 16 {
   852  			return 0, tcpip.ErrNetworkUnreachable
   853  		}
   854  	}
   855  
   856  	// Fail if we're bound to an address length different from the one we're
   857  	// checking.
   858  	if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
   859  		return 0, tcpip.ErrInvalidEndpointState
   860  	}
   861  
   862  	return netProto, nil
   863  }
   864  
// Disconnect implements tcpip.Endpoint.Disconnect.
//
// It is a no-op unless the endpoint is in StateConnected. Otherwise the
// connected registration is removed, the route is released and the remote
// port is cleared. The endpoint ends up in StateBound, re-registered under
// its local address and port only, when it has a bind NIC or an empty local
// address; in every other case it returns to StateInitial and releases its
// local port if it has one.
//
// If re-registering fails, the error is returned and the endpoint is left
// connected.
func (e *endpoint) Disconnect() *tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.state != StateConnected {
		return nil
	}
	// id is the identity the endpoint will have after disconnecting; it stays
	// empty when the endpoint goes back to the initial state.
	id := stack.TransportEndpointID{}
	// Exclude ephemerally bound endpoints.
	if e.BindNICID != 0 || e.ID.LocalAddress == "" {
		var err *tcpip.Error
		// Keep only the local half of the current identity.
		id = stack.TransportEndpointID{
			LocalPort:    e.ID.LocalPort,
			LocalAddress: e.ID.LocalAddress,
		}
		id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
		if err != nil {
			return err
		}
		e.state = StateBound
	} else {
		if e.ID.LocalPort != 0 {
			// Release the ephemeral port.
			e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
		}
		e.state = StateInitial
	}

	// Drop the old (connected) registration; e.ID still holds the connected
	// identity at this point.
	e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
	e.ID = id
	e.route.Release()
	e.route = stack.Route{}
	e.dstPort = 0

	return nil
}
   902  
   903  // Connect connects the endpoint to its peer. Specifying a NIC is optional.
   904  func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
   905  	netProto, err := e.checkV4Mapped(&addr, false)
   906  	if err != nil {
   907  		return err
   908  	}
   909  	if addr.Port == 0 {
   910  		// We don't support connecting to port zero.
   911  		return tcpip.ErrInvalidEndpointState
   912  	}
   913  
   914  	e.mu.Lock()
   915  	defer e.mu.Unlock()
   916  
   917  	nicID := addr.NIC
   918  	var localPort uint16
   919  	switch e.state {
   920  	case StateInitial:
   921  	case StateBound, StateConnected:
   922  		localPort = e.ID.LocalPort
   923  		if e.BindNICID == 0 {
   924  			break
   925  		}
   926  
   927  		if nicID != 0 && nicID != e.BindNICID {
   928  			return tcpip.ErrInvalidEndpointState
   929  		}
   930  
   931  		nicID = e.BindNICID
   932  	default:
   933  		return tcpip.ErrInvalidEndpointState
   934  	}
   935  
   936  	r, nicID, err := e.connectRoute(nicID, addr, netProto)
   937  	if err != nil {
   938  		return err
   939  	}
   940  	defer r.Release()
   941  
   942  	id := stack.TransportEndpointID{
   943  		LocalAddress:  e.ID.LocalAddress,
   944  		LocalPort:     localPort,
   945  		RemotePort:    addr.Port,
   946  		RemoteAddress: r.RemoteAddress,
   947  	}
   948  
   949  	if e.state == StateInitial {
   950  		id.LocalAddress = r.LocalAddress
   951  	}
   952  
   953  	// Even if we're connected, this endpoint can still be used to send
   954  	// packets on a different network protocol, so we register both even if
   955  	// v6only is set to false and this is an ipv6 endpoint.
   956  	netProtos := []tcpip.NetworkProtocolNumber{netProto}
   957  	if netProto == header.IPv6ProtocolNumber && !e.v6only {
   958  		netProtos = []tcpip.NetworkProtocolNumber{
   959  			header.IPv4ProtocolNumber,
   960  			header.IPv6ProtocolNumber,
   961  		}
   962  	}
   963  
   964  	id, err = e.registerWithStack(nicID, netProtos, id)
   965  	if err != nil {
   966  		return err
   967  	}
   968  
   969  	// Remove the old registration.
   970  	if e.ID.LocalPort != 0 {
   971  		e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
   972  	}
   973  
   974  	e.ID = id
   975  	e.route = r.Clone()
   976  	e.dstPort = addr.Port
   977  	e.RegisterNICID = nicID
   978  	e.effectiveNetProtos = netProtos
   979  
   980  	e.state = StateConnected
   981  
   982  	e.rcvMu.Lock()
   983  	e.rcvReady = true
   984  	e.rcvMu.Unlock()
   985  
   986  	return nil
   987  }
   988  
// ConnectEndpoint is not supported by UDP endpoints. It ignores its argument
// and always returns tcpip.ErrInvalidEndpointState.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
	return tcpip.ErrInvalidEndpointState
}
   993  
   994  // Shutdown closes the read and/or write end of the endpoint connection
   995  // to its peer.
   996  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
   997  	e.mu.Lock()
   998  	defer e.mu.Unlock()
   999  
  1000  	// A socket in the bound state can still receive multicast messages,
  1001  	// so we need to notify waiters on shutdown.
  1002  	if e.state != StateBound && e.state != StateConnected {
  1003  		return tcpip.ErrNotConnected
  1004  	}
  1005  
  1006  	e.shutdownFlags |= flags
  1007  
  1008  	if flags&tcpip.ShutdownRead != 0 {
  1009  		e.rcvMu.Lock()
  1010  		wasClosed := e.rcvClosed
  1011  		e.rcvClosed = true
  1012  		e.rcvMu.Unlock()
  1013  
  1014  		if !wasClosed {
  1015  			e.waiterQueue.Notify(waiter.EventIn)
  1016  		}
  1017  	}
  1018  
  1019  	return nil
  1020  }
  1021  
// Listen is not supported by UDP. It ignores the backlog argument and always
// returns tcpip.ErrNotSupported.
func (*endpoint) Listen(int) *tcpip.Error {
	return tcpip.ErrNotSupported
}
  1026  
// Accept is not supported by UDP. It always returns a nil endpoint, a nil
// waiter queue and tcpip.ErrNotSupported.
func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
	return nil, nil, tcpip.ErrNotSupported
}
  1031  
  1032  func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
  1033  	if e.ID.LocalPort == 0 {
  1034  		port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
  1035  		if err != nil {
  1036  			return id, err
  1037  		}
  1038  		id.LocalPort = port
  1039  	}
  1040  
  1041  	err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
  1042  	if err != nil {
  1043  		e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
  1044  	}
  1045  	return id, err
  1046  }
  1047  
  1048  func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
  1049  	// Don't allow binding once endpoint is not in the initial state
  1050  	// anymore.
  1051  	if e.state != StateInitial {
  1052  		return tcpip.ErrInvalidEndpointState
  1053  	}
  1054  
  1055  	netProto, err := e.checkV4Mapped(&addr, true)
  1056  	if err != nil {
  1057  		return err
  1058  	}
  1059  
  1060  	// Expand netProtos to include v4 and v6 if the caller is binding to a
  1061  	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
  1062  	// set to false.
  1063  	netProtos := []tcpip.NetworkProtocolNumber{netProto}
  1064  	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
  1065  		netProtos = []tcpip.NetworkProtocolNumber{
  1066  			header.IPv6ProtocolNumber,
  1067  			header.IPv4ProtocolNumber,
  1068  		}
  1069  	}
  1070  
  1071  	nicID := addr.NIC
  1072  	if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
  1073  		// A local unicast address was specified, verify that it's valid.
  1074  		nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
  1075  		if nicID == 0 {
  1076  			return tcpip.ErrBadLocalAddress
  1077  		}
  1078  	}
  1079  
  1080  	id := stack.TransportEndpointID{
  1081  		LocalPort:    addr.Port,
  1082  		LocalAddress: addr.Addr,
  1083  	}
  1084  	id, err = e.registerWithStack(nicID, netProtos, id)
  1085  	if err != nil {
  1086  		return err
  1087  	}
  1088  
  1089  	e.ID = id
  1090  	e.RegisterNICID = nicID
  1091  	e.effectiveNetProtos = netProtos
  1092  
  1093  	// Mark endpoint as bound.
  1094  	e.state = StateBound
  1095  
  1096  	e.rcvMu.Lock()
  1097  	e.rcvReady = true
  1098  	e.rcvMu.Unlock()
  1099  
  1100  	return nil
  1101  }
  1102  
  1103  // Bind binds the endpoint to a specific local address and port.
  1104  // Specifying a NIC is optional.
  1105  func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
  1106  	e.mu.Lock()
  1107  	defer e.mu.Unlock()
  1108  
  1109  	err := e.bindLocked(addr)
  1110  	if err != nil {
  1111  		return err
  1112  	}
  1113  
  1114  	// Save the effective NICID generated by bindLocked.
  1115  	e.BindNICID = e.RegisterNICID
  1116  
  1117  	return nil
  1118  }
  1119  
  1120  // GetLocalAddress returns the address to which the endpoint is bound.
  1121  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
  1122  	e.mu.RLock()
  1123  	defer e.mu.RUnlock()
  1124  
  1125  	return tcpip.FullAddress{
  1126  		NIC:  e.RegisterNICID,
  1127  		Addr: e.ID.LocalAddress,
  1128  		Port: e.ID.LocalPort,
  1129  	}, nil
  1130  }
  1131  
  1132  // GetRemoteAddress returns the address to which the endpoint is connected.
  1133  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
  1134  	e.mu.RLock()
  1135  	defer e.mu.RUnlock()
  1136  
  1137  	if e.state != StateConnected {
  1138  		return tcpip.FullAddress{}, tcpip.ErrNotConnected
  1139  	}
  1140  
  1141  	return tcpip.FullAddress{
  1142  		NIC:  e.RegisterNICID,
  1143  		Addr: e.ID.RemoteAddress,
  1144  		Port: e.ID.RemotePort,
  1145  	}, nil
  1146  }
  1147  
  1148  // Readiness returns the current readiness of the endpoint. For example, if
  1149  // waiter.EventIn is set, the endpoint is immediately readable.
  1150  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
  1151  	// The endpoint is always writable.
  1152  	result := waiter.EventOut & mask
  1153  
  1154  	// Determine if the endpoint is readable if requested.
  1155  	if (mask & waiter.EventIn) != 0 {
  1156  		e.rcvMu.Lock()
  1157  		if !e.rcvList.Empty() || e.rcvClosed {
  1158  			result |= waiter.EventIn
  1159  		}
  1160  		e.rcvMu.Unlock()
  1161  	}
  1162  
  1163  	return result
  1164  }
  1165  
  1166  // HandlePacket is called by the stack when new packets arrive to this transport
  1167  // endpoint.
  1168  func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
  1169  	// Get the header then trim it from the view.
  1170  	hdr := header.UDP(pkt.Data.First())
  1171  	if int(hdr.Length()) > pkt.Data.Size() {
  1172  		// Malformed packet.
  1173  		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
  1174  		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
  1175  		return
  1176  	}
  1177  
  1178  	pkt.Data.TrimFront(header.UDPMinimumSize)
  1179  
  1180  	e.rcvMu.Lock()
  1181  	e.stack.Stats().UDP.PacketsReceived.Increment()
  1182  	e.stats.PacketsReceived.Increment()
  1183  
  1184  	// Drop the packet if our buffer is currently full.
  1185  	if !e.rcvReady || e.rcvClosed {
  1186  		e.rcvMu.Unlock()
  1187  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
  1188  		e.stats.ReceiveErrors.ClosedReceiver.Increment()
  1189  		return
  1190  	}
  1191  
  1192  	if e.rcvBufSize >= e.rcvBufSizeMax {
  1193  		e.rcvMu.Unlock()
  1194  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
  1195  		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
  1196  		return
  1197  	}
  1198  
  1199  	wasEmpty := e.rcvBufSize == 0
  1200  
  1201  	// Push new packet into receive list and increment the buffer size.
  1202  	packet := &udpPacket{
  1203  		senderAddress: tcpip.FullAddress{
  1204  			NIC:  r.NICID(),
  1205  			Addr: id.RemoteAddress,
  1206  			Port: hdr.SourcePort(),
  1207  		},
  1208  	}
  1209  	packet.data = pkt.Data
  1210  	e.rcvList.PushBack(packet)
  1211  	e.rcvBufSize += pkt.Data.Size()
  1212  
  1213  	packet.timestamp = e.stack.NowNanoseconds()
  1214  
  1215  	e.rcvMu.Unlock()
  1216  
  1217  	// Notify any waiters that there's data to be read now.
  1218  	if wasEmpty {
  1219  		e.waiterQueue.Notify(waiter.EventIn)
  1220  	}
  1221  }
  1222  
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
// It is a no-op: this endpoint ignores all control packets.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) {
}
  1226  
  1227  // State implements tcpip.Endpoint.State.
  1228  func (e *endpoint) State() uint32 {
  1229  	e.mu.Lock()
  1230  	defer e.mu.Unlock()
  1231  	return uint32(e.state)
  1232  }
  1233  
  1234  // Info returns a copy of the endpoint info.
  1235  func (e *endpoint) Info() tcpip.EndpointInfo {
  1236  	e.mu.RLock()
  1237  	// Make a copy of the endpoint info.
  1238  	ret := e.TransportEndpointInfo
  1239  	e.mu.RUnlock()
  1240  	return &ret
  1241  }
  1242  
// Stats returns a pointer to the endpoint stats. The pointer refers to the
// endpoint's own stats field rather than a copy, and no lock is taken.
func (e *endpoint) Stats() tcpip.EndpointStats {
	return &e.stats
}
  1247  
// Wait implements tcpip.Endpoint.Wait. It is a no-op and returns immediately.
func (*endpoint) Wait() {}
  1250  
  1251  func isBroadcastOrMulticast(a tcpip.Address) bool {
  1252  	return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
  1253  }