github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/transport/udp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package udp
    16  
    17  import (
    18  	"sync"
    19  
    20  	"github.com/FlowerWrong/netstack/tcpip"
    21  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    22  	"github.com/FlowerWrong/netstack/tcpip/header"
    23  	"github.com/FlowerWrong/netstack/tcpip/iptables"
    24  	"github.com/FlowerWrong/netstack/tcpip/stack"
    25  	"github.com/FlowerWrong/netstack/waiter"
    26  )
    27  
// udpPacket is one received datagram as it is held on an endpoint's receive
// list (endpoint.rcvList).
//
// +stateify savable
type udpPacket struct {
	// udpPacketEntry is declared outside this view; presumably it carries
	// the linkage that lets a udpPacket sit in a udpPacketList — confirm
	// against its definition.
	udpPacketEntry
	// senderAddress is copied into the caller-supplied address by Read.
	senderAddress tcpip.FullAddress
	// data is the datagram payload. Read subtracts data.Size() from the
	// endpoint's rcvBufSize when the packet is dequeued.
	data          buffer.VectorisedView
	// timestamp is handed back to the reader through
	// tcpip.ControlMessages.Timestamp.
	timestamp     int64
	// views is used as buffer for data when its length is large
	// enough to store a VectorisedView.
	views [8]buffer.View
}
    38  
// EndpointState represents the state of a UDP endpoint.
type EndpointState uint32

// Endpoint states. Note that they are represented in a netstack-specific
// manner and may not be meaningful externally. Specifically, they need to be
// translated to Linux's representation for these states if presented to
// userspace.
const (
	// StateInitial is the zero value: the endpoint has been neither bound
	// nor connected. Disconnect may also return an endpoint to this state.
	StateInitial EndpointState = iota
	// StateBound is entered by a successful bindLocked, and by Disconnect
	// when the endpoint keeps its local registration.
	StateBound
	// StateConnected is entered by a successful Connect.
	StateConnected
	// StateClosed is entered by Close.
	StateClosed
)
    51  
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack
	netProto    tcpip.NetworkProtocolNumber
	waiterQueue *waiter.Queue

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu. Where both locks are needed (Close, Shutdown,
	// Connect, bindLocked), rcvMu is taken while mu is already held.
	rcvMu         sync.Mutex
	rcvReady      bool
	rcvList       udpPacketList
	rcvBufSizeMax int
	rcvBufSize    int
	rcvClosed     bool

	// The following fields are protected by the mu mutex.
	mu             sync.RWMutex
	sndBufSize     int
	id             stack.TransportEndpointID
	state          EndpointState
	bindNICID      tcpip.NICID
	regNICID       tcpip.NICID
	route          stack.Route
	dstPort        uint16
	v6only         bool
	ttl            uint8
	multicastTTL   uint8
	multicastAddr  tcpip.Address
	multicastNICID tcpip.NICID
	multicastLoop  bool
	reusePort      bool
	bindToDevice   tcpip.NICID
	broadcast      bool

	// shutdownFlags represent the current shutdown state of the endpoint.
	// Close and Shutdown update it with mu held exclusively; Write reads it
	// with mu held for reading.
	shutdownFlags tcpip.ShutdownFlags

	// multicastMemberships that need to be removed when the endpoint is
	// closed. Protected by the mu mutex.
	multicastMemberships []multicastMembership

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber
}
   110  
// multicastMembership records one multicast group joined through
// AddMembershipOption, so that the group can be left again by
// RemoveMembershipOption or by Close. Two memberships are the same when both
// fields compare equal.
//
// +stateify savable
type multicastMembership struct {
	// nicID is the NIC on which the group was joined.
	nicID         tcpip.NICID
	// multicastAddr is the address of the joined group.
	multicastAddr tcpip.Address
}
   116  
   117  func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
   118  	return &endpoint{
   119  		stack:       stack,
   120  		netProto:    netProto,
   121  		waiterQueue: waiterQueue,
   122  		// RFC 1075 section 5.4 recommends a TTL of 1 for membership
   123  		// requests.
   124  		//
   125  		// RFC 5135 4.2.1 appears to assume that IGMP messages have a
   126  		// TTL of 1.
   127  		//
   128  		// RFC 5135 Appendix A defines TTL=1: A multicast source that
   129  		// wants its traffic to not traverse a router (e.g., leave a
   130  		// home network) may find it useful to send traffic with IP
   131  		// TTL=1.
   132  		//
   133  		// Linux defaults to TTL=1.
   134  		multicastTTL:  1,
   135  		multicastLoop: true,
   136  		rcvBufSizeMax: 32 * 1024,
   137  		sndBufSize:    32 * 1024,
   138  	}
   139  }
   140  
   141  // Close puts the endpoint in a closed state and frees all resources
   142  // associated with it.
   143  func (e *endpoint) Close() {
   144  	e.mu.Lock()
   145  	e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
   146  
   147  	switch e.state {
   148  	case StateBound, StateConnected:
   149  		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
   150  		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
   151  	}
   152  
   153  	for _, mem := range e.multicastMemberships {
   154  		e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
   155  	}
   156  	e.multicastMemberships = nil
   157  
   158  	// Close the receive list and drain it.
   159  	e.rcvMu.Lock()
   160  	e.rcvClosed = true
   161  	e.rcvBufSize = 0
   162  	for !e.rcvList.Empty() {
   163  		p := e.rcvList.Front()
   164  		e.rcvList.Remove(p)
   165  	}
   166  	e.rcvMu.Unlock()
   167  
   168  	e.route.Release()
   169  
   170  	// Update the state.
   171  	e.state = StateClosed
   172  
   173  	e.mu.Unlock()
   174  
   175  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
   176  }
   177  
   178  // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
   179  func (e *endpoint) ModerateRecvBuf(copied int) {}
   180  
   181  // IPTables implements tcpip.Endpoint.IPTables.
   182  func (e *endpoint) IPTables() (iptables.IPTables, error) {
   183  	return e.stack.IPTables(), nil
   184  }
   185  
   186  // Read reads data from the endpoint. This method does not block if
   187  // there is no data pending.
   188  func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
   189  	e.rcvMu.Lock()
   190  
   191  	if e.rcvList.Empty() {
   192  		err := tcpip.ErrWouldBlock
   193  		if e.rcvClosed {
   194  			err = tcpip.ErrClosedForReceive
   195  		}
   196  		e.rcvMu.Unlock()
   197  		return buffer.View{}, tcpip.ControlMessages{}, err
   198  	}
   199  
   200  	p := e.rcvList.Front()
   201  	e.rcvList.Remove(p)
   202  	e.rcvBufSize -= p.data.Size()
   203  	e.rcvMu.Unlock()
   204  
   205  	if addr != nil {
   206  		*addr = p.senderAddress
   207  	}
   208  
   209  	return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
   210  }
   211  
// prepareForWrite prepares the endpoint for sending data. In particular, it
// binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu for reading; the lock is held for reading again
// when this function returns, although it may have been dropped in between.
//
// Outcomes by state:
//   - StateConnected: ready, returns (false, nil).
//   - StateBound: ready when a destination is given; otherwise
//     tcpip.ErrDestinationRequired.
//   - StateInitial: binds to an empty FullAddress and asks for a retry.
//   - anything else: tcpip.ErrInvalidEndpointState.
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
	switch e.state {
	case StateInitial:
		// Leaves the switch (Go cases do not fall through) and continues
		// with the bind below.
	case StateConnected:
		return false, nil

	case StateBound:
		if to == nil {
			return false, tcpip.ErrDestinationRequired
		}
		return false, nil
	default:
		return false, tcpip.ErrInvalidEndpointState
	}

	// Swap the read lock for the write lock. Deferred calls run in reverse
	// order, so on return the write lock is released first and the read
	// lock the caller expects is re-taken afterwards.
	e.mu.RUnlock()
	defer e.mu.RLock()

	e.mu.Lock()
	defer e.mu.Unlock()

	// The state changed when we released the shared locked and re-acquired
	// it in exclusive mode. Try again.
	if e.state != StateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	// Bound now; the caller retries so the new state is re-examined under
	// the read lock.
	return true, nil
}
   251  
   252  // connectRoute establishes a route to the specified interface or the
   253  // configured multicast interface if no interface is specified and the
   254  // specified address is a multicast address.
   255  func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
   256  	localAddr := e.id.LocalAddress
   257  	if isBroadcastOrMulticast(localAddr) {
   258  		// A packet can only originate from a unicast address (i.e., an interface).
   259  		localAddr = ""
   260  	}
   261  
   262  	if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
   263  		if nicid == 0 {
   264  			nicid = e.multicastNICID
   265  		}
   266  		if localAddr == "" && nicid == 0 {
   267  			localAddr = e.multicastAddr
   268  		}
   269  	}
   270  
   271  	// Find a route to the desired destination.
   272  	r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
   273  	if err != nil {
   274  		return stack.Route{}, 0, err
   275  	}
   276  	return r, nicid, nil
   277  }
   278  
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
//
// With opts.To unset the datagram goes to the connected peer over the
// endpoint's own route; otherwise a temporary route to opts.To is looked up
// and released before returning. An endpoint still in the initial state is
// bound first (see prepareForWrite).
//
// On success it returns the number of payload bytes sent. When link address
// resolution is still in progress it returns tcpip.ErrNoLinkAddress together
// with the channel obtained from Route.Resolve. Other failures visible here
// include tcpip.ErrInvalidOptionValue (opts.More set), tcpip.ErrClosedForSend
// (write side shut down), tcpip.ErrNoRoute (destination NIC differs from the
// bound NIC), tcpip.ErrBroadcastDisabled and tcpip.ErrMessageTooLong.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
	// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
	if opts.More {
		return 0, nil, tcpip.ErrInvalidOptionValue
	}

	to := opts.To

	e.mu.RLock()
	defer e.mu.RUnlock()

	// If we've shutdown with SHUT_WR we are in an invalid state for sending.
	if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
		return 0, nil, tcpip.ErrClosedForSend
	}

	// Prepare for write. prepareForWrite may drop and re-take the read
	// lock, so it is repeated until it reports that no retry is needed.
	for {
		retry, err := e.prepareForWrite(to)
		if err != nil {
			return 0, nil, err
		}

		if !retry {
			break
		}
	}

	var route *stack.Route
	var dstPort uint16
	if to == nil {
		// No explicit destination: use the endpoint's own route and port.
		route = &e.route
		dstPort = e.dstPort

		if route.IsResolutionRequired() {
			// Promote lock to exclusive if using a shared route, given that it may need to
			// change in Route.Resolve() call below.
			//
			// Deferred calls run in reverse order, so on return the write
			// lock is released, the read lock is re-taken, and then the
			// RUnlock deferred at the top of the function releases it.
			e.mu.RUnlock()
			defer e.mu.RLock()

			e.mu.Lock()
			defer e.mu.Unlock()

			// Recheck state after lock was re-acquired.
			if e.state != StateConnected {
				return 0, nil, tcpip.ErrInvalidEndpointState
			}
		}
	} else {
		// Reject destination address if it goes through a different
		// NIC than the endpoint was bound to.
		nicid := to.NIC
		if e.bindNICID != 0 {
			if nicid != 0 && nicid != e.bindNICID {
				return 0, nil, tcpip.ErrNoRoute
			}

			nicid = e.bindNICID
		}

		if to.Addr == header.IPv4Broadcast && !e.broadcast {
			return 0, nil, tcpip.ErrBroadcastDisabled
		}

		// checkV4Mapped may rewrite to.Addr in place when it is an
		// IPv4-mapped IPv6 address.
		netProto, err := e.checkV4Mapped(to, false)
		if err != nil {
			return 0, nil, err
		}

		r, _, err := e.connectRoute(nicid, *to, netProto)
		if err != nil {
			return 0, nil, err
		}
		// The route exists only for this datagram.
		defer r.Release()

		route = &r
		dstPort = to.Port
	}

	if route.IsResolutionRequired() {
		if ch, err := route.Resolve(nil); err != nil {
			if err == tcpip.ErrWouldBlock {
				// Resolution is pending; hand the channel to the caller.
				return 0, ch, tcpip.ErrNoLinkAddress
			}
			return 0, nil, err
		}
	}

	v, err := p.FullPayload()
	if err != nil {
		return 0, nil, err
	}
	if len(v) > header.UDPMaximumPacketSize {
		// Payload can't possibly fit in a packet.
		return 0, nil, tcpip.ErrMessageTooLong
	}

	// A unicast TTL of zero means the default TTL is requested from
	// sendUDP instead of the stored value.
	ttl := e.ttl
	useDefaultTTL := ttl == 0

	if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
		ttl = e.multicastTTL
		// Multicast allows a 0 TTL.
		useDefaultTTL = false
	}

	if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
		return 0, nil, err
	}
	return int64(len(v)), nil, nil
}
   392  
   393  // Peek only returns data from a single datagram, so do nothing here.
   394  func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
   395  	return 0, tcpip.ControlMessages{}, nil
   396  }
   397  
   398  // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
   399  func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
   400  	return nil
   401  }
   402  
   403  // SetSockOpt implements tcpip.Endpoint.SetSockOpt.
   404  func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
   405  	switch v := opt.(type) {
   406  	case tcpip.V6OnlyOption:
   407  		// We only recognize this option on v6 endpoints.
   408  		if e.netProto != header.IPv6ProtocolNumber {
   409  			return tcpip.ErrInvalidEndpointState
   410  		}
   411  
   412  		e.mu.Lock()
   413  		defer e.mu.Unlock()
   414  
   415  		// We only allow this to be set when we're in the initial state.
   416  		if e.state != StateInitial {
   417  			return tcpip.ErrInvalidEndpointState
   418  		}
   419  
   420  		e.v6only = v != 0
   421  
   422  	case tcpip.TTLOption:
   423  		e.mu.Lock()
   424  		e.ttl = uint8(v)
   425  		e.mu.Unlock()
   426  
   427  	case tcpip.MulticastTTLOption:
   428  		e.mu.Lock()
   429  		e.multicastTTL = uint8(v)
   430  		e.mu.Unlock()
   431  
   432  	case tcpip.MulticastInterfaceOption:
   433  		e.mu.Lock()
   434  		defer e.mu.Unlock()
   435  
   436  		fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
   437  		netProto, err := e.checkV4Mapped(&fa, false)
   438  		if err != nil {
   439  			return err
   440  		}
   441  		nic := v.NIC
   442  		addr := fa.Addr
   443  
   444  		if nic == 0 && addr == "" {
   445  			e.multicastAddr = ""
   446  			e.multicastNICID = 0
   447  			break
   448  		}
   449  
   450  		if nic != 0 {
   451  			if !e.stack.CheckNIC(nic) {
   452  				return tcpip.ErrBadLocalAddress
   453  			}
   454  		} else {
   455  			nic = e.stack.CheckLocalAddress(0, netProto, addr)
   456  			if nic == 0 {
   457  				return tcpip.ErrBadLocalAddress
   458  			}
   459  		}
   460  
   461  		if e.bindNICID != 0 && e.bindNICID != nic {
   462  			return tcpip.ErrInvalidEndpointState
   463  		}
   464  
   465  		e.multicastNICID = nic
   466  		e.multicastAddr = addr
   467  
   468  	case tcpip.AddMembershipOption:
   469  		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
   470  			return tcpip.ErrInvalidOptionValue
   471  		}
   472  
   473  		nicID := v.NIC
   474  
   475  		// The interface address is considered not-set if it is empty or contains
   476  		// all-zeros. The former represent the zero-value in golang, the latter the
   477  		// same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
   478  		allZeros := header.IPv4Any
   479  		if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
   480  			if nicID == 0 {
   481  				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
   482  				if err == nil {
   483  					nicID = r.NICID()
   484  					r.Release()
   485  				}
   486  			}
   487  		} else {
   488  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   489  		}
   490  		if nicID == 0 {
   491  			return tcpip.ErrUnknownDevice
   492  		}
   493  
   494  		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   495  
   496  		e.mu.Lock()
   497  		defer e.mu.Unlock()
   498  
   499  		for _, mem := range e.multicastMemberships {
   500  			if mem == memToInsert {
   501  				return tcpip.ErrPortInUse
   502  			}
   503  		}
   504  
   505  		if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
   506  			return err
   507  		}
   508  
   509  		e.multicastMemberships = append(e.multicastMemberships, memToInsert)
   510  
   511  	case tcpip.RemoveMembershipOption:
   512  		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
   513  			return tcpip.ErrInvalidOptionValue
   514  		}
   515  
   516  		nicID := v.NIC
   517  		if v.InterfaceAddr == header.IPv4Any {
   518  			if nicID == 0 {
   519  				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
   520  				if err == nil {
   521  					nicID = r.NICID()
   522  					r.Release()
   523  				}
   524  			}
   525  		} else {
   526  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   527  		}
   528  		if nicID == 0 {
   529  			return tcpip.ErrUnknownDevice
   530  		}
   531  
   532  		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   533  		memToRemoveIndex := -1
   534  
   535  		e.mu.Lock()
   536  		defer e.mu.Unlock()
   537  
   538  		for i, mem := range e.multicastMemberships {
   539  			if mem == memToRemove {
   540  				memToRemoveIndex = i
   541  				break
   542  			}
   543  		}
   544  		if memToRemoveIndex == -1 {
   545  			return tcpip.ErrBadLocalAddress
   546  		}
   547  
   548  		if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
   549  			return err
   550  		}
   551  
   552  		e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
   553  		e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
   554  
   555  	case tcpip.MulticastLoopOption:
   556  		e.mu.Lock()
   557  		e.multicastLoop = bool(v)
   558  		e.mu.Unlock()
   559  
   560  	case tcpip.ReusePortOption:
   561  		e.mu.Lock()
   562  		e.reusePort = v != 0
   563  		e.mu.Unlock()
   564  
   565  	case tcpip.BindToDeviceOption:
   566  		e.mu.Lock()
   567  		defer e.mu.Unlock()
   568  		if v == "" {
   569  			e.bindToDevice = 0
   570  			return nil
   571  		}
   572  		for nicid, nic := range e.stack.NICInfo() {
   573  			if nic.Name == string(v) {
   574  				e.bindToDevice = nicid
   575  				return nil
   576  			}
   577  		}
   578  		return tcpip.ErrUnknownDevice
   579  
   580  	case tcpip.BroadcastOption:
   581  		e.mu.Lock()
   582  		e.broadcast = v != 0
   583  		e.mu.Unlock()
   584  
   585  		return nil
   586  	}
   587  	return nil
   588  }
   589  
   590  // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
   591  func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
   592  	switch opt {
   593  	case tcpip.ReceiveQueueSizeOption:
   594  		v := 0
   595  		e.rcvMu.Lock()
   596  		if !e.rcvList.Empty() {
   597  			p := e.rcvList.Front()
   598  			v = p.data.Size()
   599  		}
   600  		e.rcvMu.Unlock()
   601  		return v, nil
   602  
   603  	case tcpip.SendBufferSizeOption:
   604  		e.mu.Lock()
   605  		v := e.sndBufSize
   606  		e.mu.Unlock()
   607  		return v, nil
   608  
   609  	case tcpip.ReceiveBufferSizeOption:
   610  		e.rcvMu.Lock()
   611  		v := e.rcvBufSizeMax
   612  		e.rcvMu.Unlock()
   613  		return v, nil
   614  	}
   615  
   616  	return -1, tcpip.ErrUnknownProtocolOption
   617  }
   618  
   619  // GetSockOpt implements tcpip.Endpoint.GetSockOpt.
   620  func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
   621  	switch o := opt.(type) {
   622  	case tcpip.ErrorOption:
   623  		return nil
   624  
   625  	case *tcpip.V6OnlyOption:
   626  		// We only recognize this option on v6 endpoints.
   627  		if e.netProto != header.IPv6ProtocolNumber {
   628  			return tcpip.ErrUnknownProtocolOption
   629  		}
   630  
   631  		e.mu.Lock()
   632  		v := e.v6only
   633  		e.mu.Unlock()
   634  
   635  		*o = 0
   636  		if v {
   637  			*o = 1
   638  		}
   639  		return nil
   640  
   641  	case *tcpip.TTLOption:
   642  		e.mu.Lock()
   643  		*o = tcpip.TTLOption(e.ttl)
   644  		e.mu.Unlock()
   645  		return nil
   646  
   647  	case *tcpip.MulticastTTLOption:
   648  		e.mu.Lock()
   649  		*o = tcpip.MulticastTTLOption(e.multicastTTL)
   650  		e.mu.Unlock()
   651  		return nil
   652  
   653  	case *tcpip.MulticastInterfaceOption:
   654  		e.mu.Lock()
   655  		*o = tcpip.MulticastInterfaceOption{
   656  			e.multicastNICID,
   657  			e.multicastAddr,
   658  		}
   659  		e.mu.Unlock()
   660  		return nil
   661  
   662  	case *tcpip.MulticastLoopOption:
   663  		e.mu.RLock()
   664  		v := e.multicastLoop
   665  		e.mu.RUnlock()
   666  
   667  		*o = tcpip.MulticastLoopOption(v)
   668  		return nil
   669  
   670  	case *tcpip.ReusePortOption:
   671  		e.mu.RLock()
   672  		v := e.reusePort
   673  		e.mu.RUnlock()
   674  
   675  		*o = 0
   676  		if v {
   677  			*o = 1
   678  		}
   679  		return nil
   680  
   681  	case *tcpip.BindToDeviceOption:
   682  		e.mu.RLock()
   683  		defer e.mu.RUnlock()
   684  		if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
   685  			*o = tcpip.BindToDeviceOption(nic.Name)
   686  			return nil
   687  		}
   688  		*o = tcpip.BindToDeviceOption("")
   689  		return nil
   690  
   691  	case *tcpip.KeepaliveEnabledOption:
   692  		*o = 0
   693  		return nil
   694  
   695  	case *tcpip.BroadcastOption:
   696  		e.mu.RLock()
   697  		v := e.broadcast
   698  		e.mu.RUnlock()
   699  
   700  		*o = 0
   701  		if v {
   702  			*o = 1
   703  		}
   704  		return nil
   705  
   706  	default:
   707  		return tcpip.ErrUnknownProtocolOption
   708  	}
   709  }
   710  
   711  // sendUDP sends a UDP segment via the provided network endpoint and under the
   712  // provided identity.
   713  func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error {
   714  	// Allocate a buffer for the UDP header.
   715  	hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
   716  
   717  	// Initialize the header.
   718  	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
   719  
   720  	length := uint16(hdr.UsedLength() + data.Size())
   721  	udp.Encode(&header.UDPFields{
   722  		SrcPort: localPort,
   723  		DstPort: remotePort,
   724  		Length:  length,
   725  	})
   726  
   727  	// Only calculate the checksum if offloading isn't supported.
   728  	if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
   729  		xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
   730  		for _, v := range data.Views() {
   731  			xsum = header.Checksum(v, xsum)
   732  		}
   733  		udp.SetChecksum(^udp.CalculateChecksum(xsum))
   734  	}
   735  
   736  	// Track count of packets sent.
   737  	r.Stats().UDP.PacketsSent.Increment()
   738  
   739  	return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL)
   740  }
   741  
   742  func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
   743  	netProto := e.netProto
   744  	if len(addr.Addr) == 0 {
   745  		return netProto, nil
   746  	}
   747  	if header.IsV4MappedAddress(addr.Addr) {
   748  		// Fail if using a v4 mapped address on a v6only endpoint.
   749  		if e.v6only {
   750  			return 0, tcpip.ErrNoRoute
   751  		}
   752  
   753  		netProto = header.IPv4ProtocolNumber
   754  		addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
   755  		if addr.Addr == header.IPv4Any {
   756  			addr.Addr = ""
   757  		}
   758  
   759  		// Fail if we are bound to an IPv6 address.
   760  		if !allowMismatch && len(e.id.LocalAddress) == 16 {
   761  			return 0, tcpip.ErrNetworkUnreachable
   762  		}
   763  	}
   764  
   765  	// Fail if we're bound to an address length different from the one we're
   766  	// checking.
   767  	if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
   768  		return 0, tcpip.ErrInvalidEndpointState
   769  	}
   770  
   771  	return netProto, nil
   772  }
   773  
   774  // Disconnect implements tcpip.Endpoint.Disconnect.
   775  func (e *endpoint) Disconnect() *tcpip.Error {
   776  	e.mu.Lock()
   777  	defer e.mu.Unlock()
   778  
   779  	if e.state != StateConnected {
   780  		return nil
   781  	}
   782  	id := stack.TransportEndpointID{}
   783  	// Exclude ephemerally bound endpoints.
   784  	if e.bindNICID != 0 || e.id.LocalAddress == "" {
   785  		var err *tcpip.Error
   786  		id = stack.TransportEndpointID{
   787  			LocalPort:    e.id.LocalPort,
   788  			LocalAddress: e.id.LocalAddress,
   789  		}
   790  		id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
   791  		if err != nil {
   792  			return err
   793  		}
   794  		e.state = StateBound
   795  	} else {
   796  		if e.id.LocalPort != 0 {
   797  			// Release the ephemeral port.
   798  			e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
   799  		}
   800  		e.state = StateInitial
   801  	}
   802  
   803  	e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
   804  	e.id = id
   805  	e.route.Release()
   806  	e.route = stack.Route{}
   807  	e.dstPort = 0
   808  
   809  	return nil
   810  }
   811  
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
//
// Connecting is allowed from the initial, bound and connected states; any
// other state yields tcpip.ErrInvalidEndpointState, as does port zero or a NIC
// that differs from the one the endpoint is bound to. On success the endpoint
// is registered under the new ID, any previous registration is removed, and
// the state becomes StateConnected.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
	// addr is a copy, so the in-place rewrite checkV4Mapped may perform on
	// an IPv4-mapped address stays local to this call.
	netProto, err := e.checkV4Mapped(&addr, false)
	if err != nil {
		return err
	}
	if addr.Port == 0 {
		// We don't support connecting to port zero.
		return tcpip.ErrInvalidEndpointState
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	nicid := addr.NIC
	var localPort uint16
	switch e.state {
	case StateInitial:
	case StateBound, StateConnected:
		// Keep the local port the endpoint already has.
		localPort = e.id.LocalPort
		if e.bindNICID == 0 {
			break
		}

		// Bound to a NIC: the peer must be reached through that NIC.
		if nicid != 0 && nicid != e.bindNICID {
			return tcpip.ErrInvalidEndpointState
		}

		nicid = e.bindNICID
	default:
		return tcpip.ErrInvalidEndpointState
	}

	r, nicid, err := e.connectRoute(nicid, addr, netProto)
	if err != nil {
		return err
	}
	// r is released on return; the endpoint stores a clone of it below.
	defer r.Release()

	id := stack.TransportEndpointID{
		LocalAddress:  e.id.LocalAddress,
		LocalPort:     localPort,
		RemotePort:    addr.Port,
		RemoteAddress: r.RemoteAddress,
	}

	// An endpoint that was never bound takes its local address from the
	// route.
	if e.state == StateInitial {
		id.LocalAddress = r.LocalAddress
	}

	// Even if we're connected, this endpoint can still be used to send
	// packets on a different network protocol, so we register both even if
	// v6only is set to false and this is an ipv6 endpoint.
	netProtos := []tcpip.NetworkProtocolNumber{netProto}
	if netProto == header.IPv6ProtocolNumber && !e.v6only {
		netProtos = []tcpip.NetworkProtocolNumber{
			header.IPv4ProtocolNumber,
			header.IPv6ProtocolNumber,
		}
	}

	// Register the new ID first; on failure the endpoint is left as it was.
	id, err = e.registerWithStack(nicid, netProtos, id)
	if err != nil {
		return err
	}

	// Remove the old registration.
	if e.id.LocalPort != 0 {
		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
	}

	e.id = id
	e.route = r.Clone()
	e.dstPort = addr.Port
	e.regNICID = nicid
	e.effectiveNetProtos = netProtos

	e.state = StateConnected

	e.rcvMu.Lock()
	e.rcvReady = true
	e.rcvMu.Unlock()

	return nil
}
   897  
   898  // ConnectEndpoint is not supported.
   899  func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
   900  	return tcpip.ErrInvalidEndpointState
   901  }
   902  
   903  // Shutdown closes the read and/or write end of the endpoint connection
   904  // to its peer.
   905  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
   906  	e.mu.Lock()
   907  	defer e.mu.Unlock()
   908  
   909  	// A socket in the bound state can still receive multicast messages,
   910  	// so we need to notify waiters on shutdown.
   911  	if e.state != StateBound && e.state != StateConnected {
   912  		return tcpip.ErrNotConnected
   913  	}
   914  
   915  	e.shutdownFlags |= flags
   916  
   917  	if flags&tcpip.ShutdownRead != 0 {
   918  		e.rcvMu.Lock()
   919  		wasClosed := e.rcvClosed
   920  		e.rcvClosed = true
   921  		e.rcvMu.Unlock()
   922  
   923  		if !wasClosed {
   924  			e.waiterQueue.Notify(waiter.EventIn)
   925  		}
   926  	}
   927  
   928  	return nil
   929  }
   930  
   931  // Listen is not supported by UDP, it just fails.
   932  func (*endpoint) Listen(int) *tcpip.Error {
   933  	return tcpip.ErrNotSupported
   934  }
   935  
   936  // Accept is not supported by UDP, it just fails.
   937  func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
   938  	return nil, nil, tcpip.ErrNotSupported
   939  }
   940  
   941  func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
   942  	if e.id.LocalPort == 0 {
   943  		port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
   944  		if err != nil {
   945  			return id, err
   946  		}
   947  		id.LocalPort = port
   948  	}
   949  
   950  	err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
   951  	if err != nil {
   952  		e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
   953  	}
   954  	return id, err
   955  }
   956  
   957  func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
   958  	// Don't allow binding once endpoint is not in the initial state
   959  	// anymore.
   960  	if e.state != StateInitial {
   961  		return tcpip.ErrInvalidEndpointState
   962  	}
   963  
   964  	netProto, err := e.checkV4Mapped(&addr, true)
   965  	if err != nil {
   966  		return err
   967  	}
   968  
   969  	// Expand netProtos to include v4 and v6 if the caller is binding to a
   970  	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
   971  	// set to false.
   972  	netProtos := []tcpip.NetworkProtocolNumber{netProto}
   973  	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
   974  		netProtos = []tcpip.NetworkProtocolNumber{
   975  			header.IPv6ProtocolNumber,
   976  			header.IPv4ProtocolNumber,
   977  		}
   978  	}
   979  
   980  	nicid := addr.NIC
   981  	if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
   982  		// A local unicast address was specified, verify that it's valid.
   983  		nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
   984  		if nicid == 0 {
   985  			return tcpip.ErrBadLocalAddress
   986  		}
   987  	}
   988  
   989  	id := stack.TransportEndpointID{
   990  		LocalPort:    addr.Port,
   991  		LocalAddress: addr.Addr,
   992  	}
   993  	id, err = e.registerWithStack(nicid, netProtos, id)
   994  	if err != nil {
   995  		return err
   996  	}
   997  
   998  	e.id = id
   999  	e.regNICID = nicid
  1000  	e.effectiveNetProtos = netProtos
  1001  
  1002  	// Mark endpoint as bound.
  1003  	e.state = StateBound
  1004  
  1005  	e.rcvMu.Lock()
  1006  	e.rcvReady = true
  1007  	e.rcvMu.Unlock()
  1008  
  1009  	return nil
  1010  }
  1011  
  1012  // Bind binds the endpoint to a specific local address and port.
  1013  // Specifying a NIC is optional.
  1014  func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
  1015  	e.mu.Lock()
  1016  	defer e.mu.Unlock()
  1017  
  1018  	err := e.bindLocked(addr)
  1019  	if err != nil {
  1020  		return err
  1021  	}
  1022  
  1023  	// Save the effective NICID generated by bindLocked.
  1024  	e.bindNICID = e.regNICID
  1025  
  1026  	return nil
  1027  }
  1028  
  1029  // GetLocalAddress returns the address to which the endpoint is bound.
  1030  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
  1031  	e.mu.RLock()
  1032  	defer e.mu.RUnlock()
  1033  
  1034  	return tcpip.FullAddress{
  1035  		NIC:  e.regNICID,
  1036  		Addr: e.id.LocalAddress,
  1037  		Port: e.id.LocalPort,
  1038  	}, nil
  1039  }
  1040  
  1041  // GetRemoteAddress returns the address to which the endpoint is connected.
  1042  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
  1043  	e.mu.RLock()
  1044  	defer e.mu.RUnlock()
  1045  
  1046  	if e.state != StateConnected {
  1047  		return tcpip.FullAddress{}, tcpip.ErrNotConnected
  1048  	}
  1049  
  1050  	return tcpip.FullAddress{
  1051  		NIC:  e.regNICID,
  1052  		Addr: e.id.RemoteAddress,
  1053  		Port: e.id.RemotePort,
  1054  	}, nil
  1055  }
  1056  
  1057  // Readiness returns the current readiness of the endpoint. For example, if
  1058  // waiter.EventIn is set, the endpoint is immediately readable.
  1059  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
  1060  	// The endpoint is always writable.
  1061  	result := waiter.EventOut & mask
  1062  
  1063  	// Determine if the endpoint is readable if requested.
  1064  	if (mask & waiter.EventIn) != 0 {
  1065  		e.rcvMu.Lock()
  1066  		if !e.rcvList.Empty() || e.rcvClosed {
  1067  			result |= waiter.EventIn
  1068  		}
  1069  		e.rcvMu.Unlock()
  1070  	}
  1071  
  1072  	return result
  1073  }
  1074  
  1075  // HandlePacket is called by the stack when new packets arrive to this transport
  1076  // endpoint.
  1077  func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
  1078  	// Get the header then trim it from the view.
  1079  	hdr := header.UDP(vv.First())
  1080  	if int(hdr.Length()) > vv.Size() {
  1081  		// Malformed packet.
  1082  		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
  1083  		return
  1084  	}
  1085  
  1086  	vv.TrimFront(header.UDPMinimumSize)
  1087  
  1088  	e.rcvMu.Lock()
  1089  	e.stack.Stats().UDP.PacketsReceived.Increment()
  1090  
  1091  	// Drop the packet if our buffer is currently full.
  1092  	if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
  1093  		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
  1094  		e.rcvMu.Unlock()
  1095  		return
  1096  	}
  1097  
  1098  	wasEmpty := e.rcvBufSize == 0
  1099  
  1100  	// Push new packet into receive list and increment the buffer size.
  1101  	pkt := &udpPacket{
  1102  		senderAddress: tcpip.FullAddress{
  1103  			NIC:  r.NICID(),
  1104  			Addr: id.RemoteAddress,
  1105  			Port: hdr.SourcePort(),
  1106  		},
  1107  	}
  1108  	pkt.data = vv.Clone(pkt.views[:])
  1109  	e.rcvList.PushBack(pkt)
  1110  	e.rcvBufSize += vv.Size()
  1111  
  1112  	pkt.timestamp = e.stack.NowNanoseconds()
  1113  
  1114  	e.rcvMu.Unlock()
  1115  
  1116  	// Notify any waiters that there's data to be read now.
  1117  	if wasEmpty {
  1118  		e.waiterQueue.Notify(waiter.EventIn)
  1119  	}
  1120  }
  1121  
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
//
// The body is empty: control packets delivered to this endpoint are ignored
// and none of the arguments are used.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
  1125  
  1126  // State implements tcpip.Endpoint.State.
  1127  func (e *endpoint) State() uint32 {
  1128  	e.mu.Lock()
  1129  	defer e.mu.Unlock()
  1130  	return uint32(e.state)
  1131  }
  1132  
  1133  func isBroadcastOrMulticast(a tcpip.Address) bool {
  1134  	return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
  1135  }