inet.af/netstack@v0.0.0-20220214151720-7585b01ddccf/tcpip/stack/transport_demuxer.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"inet.af/netstack/sync"
    21  	"inet.af/netstack/tcpip"
    22  	"inet.af/netstack/tcpip/hash/jenkins"
    23  	"inet.af/netstack/tcpip/header"
    24  	"inet.af/netstack/tcpip/ports"
    25  )
    26  
// protocolIDs pairs a network protocol number with a transport protocol
// number. It is the key type of the demuxer's per-protocol maps.
type protocolIDs struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
}
    31  
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	mu sync.RWMutex
	// endpoints maps an endpoint ID to the per-NIC container holding the
	// endpoints registered under that ID.
	//
	// +checklocks:mu
	endpoints map[TransportEndpointID]*endpointsByNIC
	// rawEndpoints contains endpoints for raw sockets, which receive all
	// traffic of a given protocol regardless of port.
	//
	// +checklocks:mu
	rawEndpoints []RawTransportEndpoint
}
    44  
    45  // unregisterEndpoint unregisters the endpoint with the given id such that it
    46  // won't receive any more packets.
    47  func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
    48  	eps.mu.Lock()
    49  	defer eps.mu.Unlock()
    50  	epsByNIC, ok := eps.endpoints[id]
    51  	if !ok {
    52  		return
    53  	}
    54  	if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
    55  		return
    56  	}
    57  	delete(eps.endpoints, id)
    58  }
    59  
    60  func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
    61  	eps.mu.RLock()
    62  	defer eps.mu.RUnlock()
    63  	es := make([]TransportEndpoint, 0, len(eps.endpoints))
    64  	for _, e := range eps.endpoints {
    65  		es = append(es, e.transportEndpoints()...)
    66  	}
    67  	return es
    68  }
    69  
    70  // iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
    71  // descending order of match quality. If a call to yield returns false,
    72  // iterEndpointsLocked stops iteration and returns immediately.
    73  //
    74  // +checklocksread:eps.mu
    75  func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
    76  	// Try to find a match with the id as provided.
    77  	if ep, ok := eps.endpoints[id]; ok {
    78  		if !yield(ep) {
    79  			return
    80  		}
    81  	}
    82  
    83  	// Try to find a match with the id minus the local address.
    84  	nid := id
    85  
    86  	nid.LocalAddress = ""
    87  	if ep, ok := eps.endpoints[nid]; ok {
    88  		if !yield(ep) {
    89  			return
    90  		}
    91  	}
    92  
    93  	// Try to find a match with the id minus the remote part.
    94  	nid.LocalAddress = id.LocalAddress
    95  	nid.RemoteAddress = ""
    96  	nid.RemotePort = 0
    97  	if ep, ok := eps.endpoints[nid]; ok {
    98  		if !yield(ep) {
    99  			return
   100  		}
   101  	}
   102  
   103  	// Try to find a match with only the local port.
   104  	nid.LocalAddress = ""
   105  	if ep, ok := eps.endpoints[nid]; ok {
   106  		if !yield(ep) {
   107  			return
   108  		}
   109  	}
   110  }
   111  
   112  // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
   113  // descending order of match quality.
   114  //
   115  // +checklocksread:eps.mu
   116  func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
   117  	var matchedEPs []*endpointsByNIC
   118  	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
   119  		matchedEPs = append(matchedEPs, ep)
   120  		return true
   121  	})
   122  	return matchedEPs
   123  }
   124  
   125  // findEndpointLocked returns the endpoint that most closely matches the given id.
   126  //
   127  // +checklocksread:eps.mu
   128  func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
   129  	var matchedEP *endpointsByNIC
   130  	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
   131  		matchedEP = ep
   132  		return false
   133  	})
   134  	return matchedEP
   135  }
   136  
// endpointsByNIC holds the endpoints registered under one TransportEndpointID,
// grouped by the NIC ID they were registered with (bindToDevice). Lookups in
// this file fall back to the entry for NIC ID 0 when there is no entry for
// the specific NIC.
type endpointsByNIC struct {
	// seed is a random secret for a jenkins hash.
	seed uint32

	mu sync.RWMutex
	// endpoints maps a NIC ID to the endpoints registered for it.
	//
	// +checklocks:mu
	endpoints map[tcpip.NICID]*multiPortEndpoint
}
   145  
   146  func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
   147  	epsByNIC.mu.RLock()
   148  	defer epsByNIC.mu.RUnlock()
   149  	var eps []TransportEndpoint
   150  	for _, ep := range epsByNIC.endpoints {
   151  		eps = append(eps, ep.transportEndpoints()...)
   152  	}
   153  	return eps
   154  }
   155  
   156  // handlePacket is called by the stack when new packets arrive to this transport
   157  // endpoint. It returns false if the packet could not be matched to any
   158  // transport endpoint, true otherwise.
   159  func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
   160  	epsByNIC.mu.RLock()
   161  
   162  	mpep, ok := epsByNIC.endpoints[pkt.NICID]
   163  	if !ok {
   164  		if mpep, ok = epsByNIC.endpoints[0]; !ok {
   165  			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   166  			return false
   167  		}
   168  	}
   169  
   170  	// If this is a broadcast or multicast datagram, deliver the datagram to all
   171  	// endpoints bound to the right device.
   172  	if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
   173  		mpep.handlePacketAll(id, pkt)
   174  		epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   175  		return true
   176  	}
   177  	// multiPortEndpoints are guaranteed to have at least one element.
   178  	transEP := mpep.selectEndpoint(id, epsByNIC.seed)
   179  	if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
   180  		queuedProtocol.QueuePacket(transEP, id, pkt)
   181  		epsByNIC.mu.RUnlock()
   182  		return true
   183  	}
   184  
   185  	transEP.HandlePacket(id, pkt)
   186  	epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   187  	return true
   188  }
   189  
   190  // handleError delivers an error to the transport endpoint identified by id.
   191  func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) {
   192  	epsByNIC.mu.RLock()
   193  	defer epsByNIC.mu.RUnlock()
   194  
   195  	mpep, ok := epsByNIC.endpoints[n.ID()]
   196  	if !ok {
   197  		mpep, ok = epsByNIC.endpoints[0]
   198  	}
   199  	if !ok {
   200  		return
   201  	}
   202  
   203  	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
   204  	// broadcast like we are doing with handlePacket above?
   205  
   206  	// multiPortEndpoints are guaranteed to have at least one element.
   207  	mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt)
   208  }
   209  
   210  // registerEndpoint returns true if it succeeds. It fails and returns
   211  // false if ep already has an element with the same key.
   212  func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   213  	epsByNIC.mu.Lock()
   214  	defer epsByNIC.mu.Unlock()
   215  
   216  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   217  	if !ok {
   218  		multiPortEp = &multiPortEndpoint{
   219  			demux:      d,
   220  			netProto:   netProto,
   221  			transProto: transProto,
   222  		}
   223  	}
   224  
   225  	if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
   226  		return err
   227  	}
   228  	// Only add this newly created multiportEndpoint if the singleRegisterEndpoint
   229  	// succeeded.
   230  	if !ok {
   231  		epsByNIC.endpoints[bindToDevice] = multiPortEp
   232  	}
   233  	return nil
   234  }
   235  
   236  func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   237  	epsByNIC.mu.RLock()
   238  	defer epsByNIC.mu.RUnlock()
   239  
   240  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   241  	if !ok {
   242  		return nil
   243  	}
   244  
   245  	return multiPortEp.singleCheckEndpoint(flags)
   246  }
   247  
   248  // unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
   249  func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
   250  	epsByNIC.mu.Lock()
   251  	defer epsByNIC.mu.Unlock()
   252  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   253  	if !ok {
   254  		return false
   255  	}
   256  	if multiPortEp.unregisterEndpoint(t, flags) {
   257  		delete(epsByNIC.endpoints, bindToDevice)
   258  	}
   259  	return len(epsByNIC.endpoints) == 0
   260  }
   261  
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
	stack *Stack

	// protocol is immutable. It has one entry per (network, transport)
	// protocol pair, populated by newTransportDemuxer.
	protocol map[protocolIDs]*transportEndpoints
	// queuedProtocols has an entry for each protocol pair whose transport
	// protocol implements queuedTransportProtocol.
	queuedProtocols map[protocolIDs]queuedTransportProtocol
}
   274  
// queuedTransportProtocol, if supported by a protocol implementation, will
// cause the dispatcher to deliver packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
	// QueuePacket receives the packet pkt destined for endpoint ep, together
	// with the endpoint ID the packet was matched against.
	QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
   281  
   282  func newTransportDemuxer(stack *Stack) *transportDemuxer {
   283  	d := &transportDemuxer{
   284  		stack:           stack,
   285  		protocol:        make(map[protocolIDs]*transportEndpoints),
   286  		queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
   287  	}
   288  
   289  	// Add each network and transport pair to the demuxer.
   290  	for netProto := range stack.networkProtocols {
   291  		for proto := range stack.transportProtocols {
   292  			protoIDs := protocolIDs{netProto, proto}
   293  			d.protocol[protoIDs] = &transportEndpoints{
   294  				endpoints: make(map[TransportEndpointID]*endpointsByNIC),
   295  			}
   296  			qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
   297  			if isQueued {
   298  				d.queuedProtocols[protoIDs] = qTransProto
   299  			}
   300  		}
   301  	}
   302  
   303  	return d
   304  }
   305  
   306  // registerEndpoint registers the given endpoint with the dispatcher such that
   307  // packets that match the endpoint ID are delivered to it.
   308  func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   309  	for i, n := range netProtos {
   310  		if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
   311  			d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
   312  			return err
   313  		}
   314  	}
   315  
   316  	return nil
   317  }
   318  
   319  // checkEndpoint checks if an endpoint can be registered with the dispatcher.
   320  func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   321  	for _, n := range netProtos {
   322  		if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
   323  			return err
   324  		}
   325  	}
   326  
   327  	return nil
   328  }
   329  
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpoints always has at least one
// element.
//
// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
// this to ensure that the underlying endpoints get saved/restored, but do not
// use the restored copy.
//
// +stateify savable
type multiPortEndpoint struct {
	// demux is the demuxer this container was registered with; it is used to
	// look up whether packets must be queued for the protocol pair below.
	demux      *transportDemuxer
	netProto   tcpip.NetworkProtocolNumber
	transProto tcpip.TransportProtocolNumber

	// flags counts the multi-bind flag bits of the registered endpoints.
	flags ports.FlagCounter

	mu sync.RWMutex `state:"nosave"`
	// endpoints stores the transport endpoints in the order in which they
	// were bound. This is required for UDP SO_REUSEADDR.
	//
	// +checklocks:mu
	endpoints []TransportEndpoint
}
   353  
   354  func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
   355  	ep.mu.RLock()
   356  	eps := append([]TransportEndpoint(nil), ep.endpoints...)
   357  	ep.mu.RUnlock()
   358  	return eps
   359  }
   360  
   361  // reciprocalScale scales a value into range [0, n).
   362  //
   363  // This is similar to val % n, but faster.
   364  // See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
   365  func reciprocalScale(val, n uint32) uint32 {
   366  	return uint32((uint64(val) * uint64(n)) >> 32)
   367  }
   368  
   369  // selectEndpoint calculates a hash of destination and source addresses and
   370  // ports then uses it to select a socket. In this case, all packets from one
   371  // address will be sent to same endpoint.
   372  func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
   373  	ep.mu.RLock()
   374  	defer ep.mu.RUnlock()
   375  
   376  	if len(ep.endpoints) == 1 {
   377  		return ep.endpoints[0]
   378  	}
   379  
   380  	if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
   381  		return ep.endpoints[len(ep.endpoints)-1]
   382  	}
   383  
   384  	payload := []byte{
   385  		byte(id.LocalPort),
   386  		byte(id.LocalPort >> 8),
   387  		byte(id.RemotePort),
   388  		byte(id.RemotePort >> 8),
   389  	}
   390  
   391  	h := jenkins.Sum32(seed)
   392  	h.Write(payload)
   393  	h.Write([]byte(id.LocalAddress))
   394  	h.Write([]byte(id.RemoteAddress))
   395  	hash := h.Sum32()
   396  
   397  	idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
   398  	return ep.endpoints[idx]
   399  }
   400  
   401  func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
   402  	ep.mu.RLock()
   403  	queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
   404  	// HandlePacket may modify pkt, so each endpoint needs
   405  	// its own copy except for the final one.
   406  	for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
   407  		clone := pkt.Clone()
   408  		if mustQueue {
   409  			queuedProtocol.QueuePacket(endpoint, id, clone)
   410  		} else {
   411  			endpoint.HandlePacket(id, clone)
   412  		}
   413  		clone.DecRef()
   414  	}
   415  	if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
   416  		queuedProtocol.QueuePacket(endpoint, id, pkt)
   417  	} else {
   418  		endpoint.HandlePacket(id, pkt)
   419  	}
   420  	ep.mu.RUnlock() // Don't use defer for performance reasons.
   421  }
   422  
   423  // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
   424  // list. The list might be empty already.
   425  func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
   426  	ep.mu.Lock()
   427  	defer ep.mu.Unlock()
   428  	bits := flags.Bits() & ports.MultiBindFlagMask
   429  
   430  	if len(ep.endpoints) != 0 {
   431  		// If it was previously bound, we need to check if we can bind again.
   432  		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
   433  			return &tcpip.ErrPortInUse{}
   434  		}
   435  	}
   436  
   437  	ep.endpoints = append(ep.endpoints, t)
   438  	ep.flags.AddRef(bits)
   439  
   440  	return nil
   441  }
   442  
   443  func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
   444  	ep.mu.RLock()
   445  	defer ep.mu.RUnlock()
   446  
   447  	bits := flags.Bits() & ports.MultiBindFlagMask
   448  
   449  	if len(ep.endpoints) != 0 {
   450  		// If it was previously bound, we need to check if we can bind again.
   451  		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
   452  			return &tcpip.ErrPortInUse{}
   453  		}
   454  	}
   455  
   456  	return nil
   457  }
   458  
   459  // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
   460  func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
   461  	ep.mu.Lock()
   462  	defer ep.mu.Unlock()
   463  
   464  	for i, endpoint := range ep.endpoints {
   465  		if endpoint == t {
   466  			copy(ep.endpoints[i:], ep.endpoints[i+1:])
   467  			ep.endpoints[len(ep.endpoints)-1] = nil
   468  			ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
   469  
   470  			ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
   471  			break
   472  		}
   473  	}
   474  	return len(ep.endpoints) == 0
   475  }
   476  
   477  func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   478  	if id.RemotePort != 0 {
   479  		// SO_REUSEPORT only applies to bound/listening endpoints.
   480  		flags.LoadBalanced = false
   481  	}
   482  
   483  	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
   484  	if !ok {
   485  		return &tcpip.ErrUnknownProtocol{}
   486  	}
   487  
   488  	eps.mu.Lock()
   489  	defer eps.mu.Unlock()
   490  	epsByNIC, ok := eps.endpoints[id]
   491  	if !ok {
   492  		epsByNIC = &endpointsByNIC{
   493  			endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
   494  			seed:      d.stack.seed,
   495  		}
   496  	}
   497  	if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
   498  		return err
   499  	}
   500  	// Only add this newly created epsByNIC if registerEndpoint succeeded.
   501  	if !ok {
   502  		eps.endpoints[id] = epsByNIC
   503  	}
   504  	return nil
   505  }
   506  
   507  func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   508  	if id.RemotePort != 0 {
   509  		// SO_REUSEPORT only applies to bound/listening endpoints.
   510  		flags.LoadBalanced = false
   511  	}
   512  
   513  	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
   514  	if !ok {
   515  		return &tcpip.ErrUnknownProtocol{}
   516  	}
   517  
   518  	eps.mu.RLock()
   519  	defer eps.mu.RUnlock()
   520  
   521  	epsByNIC, ok := eps.endpoints[id]
   522  	if !ok {
   523  		return nil
   524  	}
   525  
   526  	return epsByNIC.checkEndpoint(flags, bindToDevice)
   527  }
   528  
   529  // unregisterEndpoint unregisters the endpoint with the given id such that it
   530  // won't receive any more packets.
   531  func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
   532  	if id.RemotePort != 0 {
   533  		// SO_REUSEPORT only applies to bound/listening endpoints.
   534  		flags.LoadBalanced = false
   535  	}
   536  
   537  	for _, n := range netProtos {
   538  		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
   539  			eps.unregisterEndpoint(id, ep, flags, bindToDevice)
   540  		}
   541  	}
   542  }
   543  
   544  // deliverPacket attempts to find one or more matching transport endpoints, and
   545  // then, if matches are found, delivers the packet to them. Returns true if
   546  // the packet no longer needs to be handled.
   547  func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
   548  	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
   549  	if !ok {
   550  		return false
   551  	}
   552  
   553  	// If the packet is a UDP broadcast or multicast, then find all matching
   554  	// transport endpoints.
   555  	if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
   556  		eps.mu.RLock()
   557  		destEPs := eps.findAllEndpointsLocked(id)
   558  		eps.mu.RUnlock()
   559  		// Fail if we didn't find at least one matching transport endpoint.
   560  		if len(destEPs) == 0 {
   561  			d.stack.stats.UDP.UnknownPortErrors.Increment()
   562  			return false
   563  		}
   564  		// handlePacket takes may modify pkt, so each endpoint needs its own
   565  		// copy except for the final one.
   566  		for _, ep := range destEPs[:len(destEPs)-1] {
   567  			clone := pkt.Clone()
   568  			ep.handlePacket(id, clone)
   569  			clone.DecRef()
   570  		}
   571  		destEPs[len(destEPs)-1].handlePacket(id, pkt)
   572  		return true
   573  	}
   574  
   575  	// If the packet is a TCP packet with a unspecified source or non-unicast
   576  	// destination address, then do nothing further and instruct the caller to do
   577  	// the same. The network layer handles address validation for specified source
   578  	// addresses.
   579  	if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
   580  		// TCP can only be used to communicate between a single source and a
   581  		// single destination; the addresses must be unicast.e
   582  		d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
   583  		return true
   584  	}
   585  
   586  	eps.mu.RLock()
   587  	ep := eps.findEndpointLocked(id)
   588  	eps.mu.RUnlock()
   589  	if ep == nil {
   590  		if protocol == header.UDPProtocolNumber {
   591  			d.stack.stats.UDP.UnknownPortErrors.Increment()
   592  		}
   593  		return false
   594  	}
   595  	return ep.handlePacket(id, pkt)
   596  }
   597  
   598  // deliverRawPacket attempts to deliver the given packet and returns whether it
   599  // was delivered successfully.
   600  func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
   601  	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
   602  	if !ok {
   603  		return false
   604  	}
   605  
   606  	// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
   607  	// raw endpoint first. If there are multiple raw endpoints, they all
   608  	// receive the packet.
   609  	eps.mu.RLock()
   610  	// Copy the list of raw endpoints to avoid packet handling under lock.
   611  	var rawEPs []RawTransportEndpoint
   612  	if n := len(eps.rawEndpoints); n != 0 {
   613  		rawEPs = make([]RawTransportEndpoint, n)
   614  		if m := copy(rawEPs, eps.rawEndpoints); m != n {
   615  			panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
   616  		}
   617  	}
   618  	eps.mu.RUnlock()
   619  	for _, rawEP := range rawEPs {
   620  		// Each endpoint gets its own copy of the packet for the sake
   621  		// of save/restore.
   622  		clone := pkt.Clone()
   623  		rawEP.HandlePacket(clone)
   624  		clone.DecRef()
   625  	}
   626  
   627  	return len(rawEPs) != 0
   628  }
   629  
   630  // deliverError attempts to deliver the given error to the appropriate transport
   631  // endpoint.
   632  //
   633  // Returns true if the error was delivered.
   634  func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool {
   635  	eps, ok := d.protocol[protocolIDs{net, trans}]
   636  	if !ok {
   637  		return false
   638  	}
   639  
   640  	eps.mu.RLock()
   641  	ep := eps.findEndpointLocked(id)
   642  	eps.mu.RUnlock()
   643  	if ep == nil {
   644  		return false
   645  	}
   646  
   647  	ep.handleError(n, id, transErr, pkt)
   648  	return true
   649  }
   650  
   651  // findTransportEndpoint find a single endpoint that most closely matches the provided id.
   652  func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
   653  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   654  	if !ok {
   655  		return nil
   656  	}
   657  
   658  	eps.mu.RLock()
   659  	epsByNIC := eps.findEndpointLocked(id)
   660  	if epsByNIC == nil {
   661  		eps.mu.RUnlock()
   662  		return nil
   663  	}
   664  
   665  	epsByNIC.mu.RLock()
   666  	eps.mu.RUnlock()
   667  
   668  	mpep, ok := epsByNIC.endpoints[nicID]
   669  	if !ok {
   670  		if mpep, ok = epsByNIC.endpoints[0]; !ok {
   671  			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   672  			return nil
   673  		}
   674  	}
   675  
   676  	ep := mpep.selectEndpoint(id, epsByNIC.seed)
   677  	epsByNIC.mu.RUnlock()
   678  	return ep
   679  }
   680  
   681  // registerRawEndpoint registers the given endpoint with the dispatcher such
   682  // that packets of the appropriate protocol are delivered to it. A single
   683  // packet can be sent to one or more raw endpoints along with a non-raw
   684  // endpoint.
   685  func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
   686  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   687  	if !ok {
   688  		return &tcpip.ErrNotSupported{}
   689  	}
   690  
   691  	eps.mu.Lock()
   692  	eps.rawEndpoints = append(eps.rawEndpoints, ep)
   693  	eps.mu.Unlock()
   694  
   695  	return nil
   696  }
   697  
   698  // unregisterRawEndpoint unregisters the raw endpoint for the given transport
   699  // protocol such that it won't receive any more packets.
   700  func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
   701  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   702  	if !ok {
   703  		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
   704  	}
   705  
   706  	eps.mu.Lock()
   707  	for i, rawEP := range eps.rawEndpoints {
   708  		if rawEP == ep {
   709  			lastIdx := len(eps.rawEndpoints) - 1
   710  			eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
   711  			eps.rawEndpoints[lastIdx] = nil
   712  			eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
   713  			break
   714  		}
   715  	}
   716  	eps.mu.Unlock()
   717  }
   718  
   719  func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
   720  	return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
   721  }
   722  
   723  func isSpecified(addr tcpip.Address) bool {
   724  	return addr != header.IPv4Any && addr != header.IPv6Any
   725  }