github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/stack/transport_demuxer.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/hash/jenkins"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports"
    24  )
    25  
// protocolIDs pairs a network protocol number with a transport protocol
// number. It is the key type of the transportDemuxer's per-protocol maps.
type protocolIDs struct {
	// network is the network protocol number of the pair.
	network   tcpip.NetworkProtocolNumber
	// transport is the transport protocol number of the pair.
	transport tcpip.TransportProtocolNumber
}
    30  
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	// mu protects the fields below.
	mu transportEndpointsRWMutex
	// endpoints maps an endpoint ID to the per-NIC collection of endpoints
	// registered under that ID.
	//
	// +checklocks:mu
	endpoints map[TransportEndpointID]*endpointsByNIC
	// rawEndpoints contains endpoints for raw sockets, which receive all
	// traffic of a given protocol regardless of port.
	//
	// +checklocks:mu
	rawEndpoints []RawTransportEndpoint
}
    43  
    44  // unregisterEndpoint unregisters the endpoint with the given id such that it
    45  // won't receive any more packets.
    46  func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
    47  	eps.mu.Lock()
    48  	defer eps.mu.Unlock()
    49  	epsByNIC, ok := eps.endpoints[id]
    50  	if !ok {
    51  		return
    52  	}
    53  	if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
    54  		return
    55  	}
    56  	delete(eps.endpoints, id)
    57  }
    58  
    59  func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
    60  	eps.mu.RLock()
    61  	defer eps.mu.RUnlock()
    62  	es := make([]TransportEndpoint, 0, len(eps.endpoints))
    63  	for _, e := range eps.endpoints {
    64  		es = append(es, e.transportEndpoints()...)
    65  	}
    66  	return es
    67  }
    68  
    69  // iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
    70  // descending order of match quality. If a call to yield returns false,
    71  // iterEndpointsLocked stops iteration and returns immediately.
    72  //
    73  // +checklocksread:eps.mu
    74  func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
    75  	// Try to find a match with the id as provided.
    76  	if ep, ok := eps.endpoints[id]; ok {
    77  		if !yield(ep) {
    78  			return
    79  		}
    80  	}
    81  
    82  	// Try to find a match with the id minus the local address.
    83  	nid := id
    84  
    85  	nid.LocalAddress = tcpip.Address{}
    86  	if ep, ok := eps.endpoints[nid]; ok {
    87  		if !yield(ep) {
    88  			return
    89  		}
    90  	}
    91  
    92  	// Try to find a match with the id minus the remote part.
    93  	nid.LocalAddress = id.LocalAddress
    94  	nid.RemoteAddress = tcpip.Address{}
    95  	nid.RemotePort = 0
    96  	if ep, ok := eps.endpoints[nid]; ok {
    97  		if !yield(ep) {
    98  			return
    99  		}
   100  	}
   101  
   102  	// Try to find a match with only the local port.
   103  	nid.LocalAddress = tcpip.Address{}
   104  	if ep, ok := eps.endpoints[nid]; ok {
   105  		if !yield(ep) {
   106  			return
   107  		}
   108  	}
   109  }
   110  
   111  // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
   112  // descending order of match quality.
   113  //
   114  // +checklocksread:eps.mu
   115  func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
   116  	var matchedEPs []*endpointsByNIC
   117  	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
   118  		matchedEPs = append(matchedEPs, ep)
   119  		return true
   120  	})
   121  	return matchedEPs
   122  }
   123  
   124  // findEndpointLocked returns the endpoint that most closely matches the given id.
   125  //
   126  // +checklocksread:eps.mu
   127  func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
   128  	var matchedEP *endpointsByNIC
   129  	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
   130  		matchedEP = ep
   131  		return false
   132  	})
   133  	return matchedEP
   134  }
   135  
// endpointsByNIC groups the endpoints registered under one endpoint ID by the
// NIC ID they were registered with.
type endpointsByNIC struct {
	// seed is a random secret for a jenkins hash.
	seed uint32

	// mu protects endpoints.
	mu endpointsByNICRWMutex
	// endpoints maps a NIC ID (the bindToDevice value used at registration)
	// to the endpoints registered for it. Packet and error delivery fall back
	// to the entry under NIC ID 0 when there is no entry for the specific NIC.
	//
	// +checklocks:mu
	endpoints map[tcpip.NICID]*multiPortEndpoint
}
   144  
   145  func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
   146  	epsByNIC.mu.RLock()
   147  	defer epsByNIC.mu.RUnlock()
   148  	var eps []TransportEndpoint
   149  	for _, ep := range epsByNIC.endpoints {
   150  		eps = append(eps, ep.transportEndpoints()...)
   151  	}
   152  	return eps
   153  }
   154  
   155  // handlePacket is called by the stack when new packets arrive to this transport
   156  // endpoint. It returns false if the packet could not be matched to any
   157  // transport endpoint, true otherwise.
   158  func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt PacketBufferPtr) bool {
   159  	epsByNIC.mu.RLock()
   160  
   161  	mpep, ok := epsByNIC.endpoints[pkt.NICID]
   162  	if !ok {
   163  		if mpep, ok = epsByNIC.endpoints[0]; !ok {
   164  			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   165  			return false
   166  		}
   167  	}
   168  
   169  	// If this is a broadcast or multicast datagram, deliver the datagram to all
   170  	// endpoints bound to the right device.
   171  	if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
   172  		mpep.handlePacketAll(id, pkt)
   173  		epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   174  		return true
   175  	}
   176  	// multiPortEndpoints are guaranteed to have at least one element.
   177  	transEP := mpep.selectEndpoint(id, epsByNIC.seed)
   178  	if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
   179  		queuedProtocol.QueuePacket(transEP, id, pkt)
   180  		epsByNIC.mu.RUnlock()
   181  		return true
   182  	}
   183  	epsByNIC.mu.RUnlock()
   184  
   185  	transEP.HandlePacket(id, pkt)
   186  	return true
   187  }
   188  
   189  // handleError delivers an error to the transport endpoint identified by id.
   190  func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt PacketBufferPtr) {
   191  	epsByNIC.mu.RLock()
   192  
   193  	mpep, ok := epsByNIC.endpoints[n.ID()]
   194  	if !ok {
   195  		mpep, ok = epsByNIC.endpoints[0]
   196  	}
   197  	if !ok {
   198  		epsByNIC.mu.RUnlock()
   199  		return
   200  	}
   201  
   202  	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
   203  	// broadcast like we are doing with handlePacket above?
   204  
   205  	// multiPortEndpoints are guaranteed to have at least one element.
   206  	transEP := mpep.selectEndpoint(id, epsByNIC.seed)
   207  	epsByNIC.mu.RUnlock()
   208  
   209  	transEP.HandleError(transErr, pkt)
   210  }
   211  
   212  // registerEndpoint returns true if it succeeds. It fails and returns
   213  // false if ep already has an element with the same key.
   214  func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   215  	epsByNIC.mu.Lock()
   216  	defer epsByNIC.mu.Unlock()
   217  
   218  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   219  	if !ok {
   220  		multiPortEp = &multiPortEndpoint{
   221  			demux:      d,
   222  			netProto:   netProto,
   223  			transProto: transProto,
   224  		}
   225  	}
   226  
   227  	if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
   228  		return err
   229  	}
   230  	// Only add this newly created multiportEndpoint if the singleRegisterEndpoint
   231  	// succeeded.
   232  	if !ok {
   233  		epsByNIC.endpoints[bindToDevice] = multiPortEp
   234  	}
   235  	return nil
   236  }
   237  
   238  func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   239  	epsByNIC.mu.RLock()
   240  	defer epsByNIC.mu.RUnlock()
   241  
   242  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   243  	if !ok {
   244  		return nil
   245  	}
   246  
   247  	return multiPortEp.singleCheckEndpoint(flags)
   248  }
   249  
   250  // unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
   251  func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
   252  	epsByNIC.mu.Lock()
   253  	defer epsByNIC.mu.Unlock()
   254  	multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
   255  	if !ok {
   256  		return false
   257  	}
   258  	if multiPortEp.unregisterEndpoint(t, flags) {
   259  		delete(epsByNIC.endpoints, bindToDevice)
   260  	}
   261  	return len(epsByNIC.endpoints) == 0
   262  }
   263  
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
	// stack is the Stack this demuxer was created for.
	stack *Stack

	// protocol is immutable.
	protocol        map[protocolIDs]*transportEndpoints
	// queuedProtocols holds, for each protocol pair whose transport protocol
	// implements queuedTransportProtocol, that implementation. It is filled
	// in by newTransportDemuxer.
	queuedProtocols map[protocolIDs]queuedTransportProtocol
}
   276  
// queuedTransportProtocol, if supported by a protocol implementation, will
// cause the dispatcher to deliver packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
	// QueuePacket is given the selected endpoint ep, the endpoint ID id of the
	// packet, and the packet pkt itself.
	QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt PacketBufferPtr)
}
   283  
   284  func newTransportDemuxer(stack *Stack) *transportDemuxer {
   285  	d := &transportDemuxer{
   286  		stack:           stack,
   287  		protocol:        make(map[protocolIDs]*transportEndpoints),
   288  		queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
   289  	}
   290  
   291  	// Add each network and transport pair to the demuxer.
   292  	for netProto := range stack.networkProtocols {
   293  		for proto := range stack.transportProtocols {
   294  			protoIDs := protocolIDs{netProto, proto}
   295  			d.protocol[protoIDs] = &transportEndpoints{
   296  				endpoints: make(map[TransportEndpointID]*endpointsByNIC),
   297  			}
   298  			qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
   299  			if isQueued {
   300  				d.queuedProtocols[protoIDs] = qTransProto
   301  			}
   302  		}
   303  	}
   304  
   305  	return d
   306  }
   307  
   308  // registerEndpoint registers the given endpoint with the dispatcher such that
   309  // packets that match the endpoint ID are delivered to it.
   310  func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   311  	for i, n := range netProtos {
   312  		if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
   313  			d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
   314  			return err
   315  		}
   316  	}
   317  
   318  	return nil
   319  }
   320  
   321  // checkEndpoint checks if an endpoint can be registered with the dispatcher.
   322  func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   323  	for _, n := range netProtos {
   324  		if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
   325  			return err
   326  		}
   327  	}
   328  
   329  	return nil
   330  }
   331  
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
//
// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
// this to ensure that the underlying endpoints get saved/restored, but do not
// use the restored copy.
//
// +stateify savable
type multiPortEndpoint struct {
	// demux is the demuxer this container was registered through; its
	// queuedProtocols map is consulted when delivering to all endpoints.
	demux      *transportDemuxer
	// netProto and transProto identify the protocol pair of the contained
	// endpoints.
	netProto   tcpip.NetworkProtocolNumber
	transProto tcpip.TransportProtocolNumber

	// flags counts the multi-bind flag bits of the registered endpoints; it is
	// updated on every registration and unregistration.
	flags ports.FlagCounter

	// mu protects endpoints.
	mu multiPortEndpointRWMutex `state:"nosave"`
	// endpoints stores the transport endpoints in the order in which they
	// were bound. This is required for UDP SO_REUSEADDR.
	//
	// +checklocks:mu
	endpoints []TransportEndpoint
}
   355  
   356  func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
   357  	ep.mu.RLock()
   358  	eps := append([]TransportEndpoint(nil), ep.endpoints...)
   359  	ep.mu.RUnlock()
   360  	return eps
   361  }
   362  
   363  // reciprocalScale scales a value into range [0, n).
   364  //
   365  // This is similar to val % n, but faster.
   366  // See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
   367  func reciprocalScale(val, n uint32) uint32 {
   368  	return uint32((uint64(val) * uint64(n)) >> 32)
   369  }
   370  
   371  // selectEndpoint calculates a hash of destination and source addresses and
   372  // ports then uses it to select a socket. In this case, all packets from one
   373  // address will be sent to same endpoint.
   374  func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
   375  	ep.mu.RLock()
   376  	defer ep.mu.RUnlock()
   377  
   378  	if len(ep.endpoints) == 1 {
   379  		return ep.endpoints[0]
   380  	}
   381  
   382  	if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
   383  		return ep.endpoints[len(ep.endpoints)-1]
   384  	}
   385  
   386  	payload := []byte{
   387  		byte(id.LocalPort),
   388  		byte(id.LocalPort >> 8),
   389  		byte(id.RemotePort),
   390  		byte(id.RemotePort >> 8),
   391  	}
   392  
   393  	h := jenkins.Sum32(seed)
   394  	h.Write(payload)
   395  	h.Write(id.LocalAddress.AsSlice())
   396  	h.Write(id.RemoteAddress.AsSlice())
   397  	hash := h.Sum32()
   398  
   399  	idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
   400  	return ep.endpoints[idx]
   401  }
   402  
   403  func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt PacketBufferPtr) {
   404  	ep.mu.RLock()
   405  	queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
   406  	// HandlePacket may modify pkt, so each endpoint needs
   407  	// its own copy except for the final one.
   408  	for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
   409  		clone := pkt.Clone()
   410  		if mustQueue {
   411  			queuedProtocol.QueuePacket(endpoint, id, clone)
   412  		} else {
   413  			endpoint.HandlePacket(id, clone)
   414  		}
   415  		clone.DecRef()
   416  	}
   417  	if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
   418  		queuedProtocol.QueuePacket(endpoint, id, pkt)
   419  	} else {
   420  		endpoint.HandlePacket(id, pkt)
   421  	}
   422  	ep.mu.RUnlock() // Don't use defer for performance reasons.
   423  }
   424  
   425  // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
   426  // list. The list might be empty already.
   427  func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
   428  	ep.mu.Lock()
   429  	defer ep.mu.Unlock()
   430  	bits := flags.Bits() & ports.MultiBindFlagMask
   431  
   432  	if len(ep.endpoints) != 0 {
   433  		// If it was previously bound, we need to check if we can bind again.
   434  		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
   435  			return &tcpip.ErrPortInUse{}
   436  		}
   437  	}
   438  
   439  	ep.endpoints = append(ep.endpoints, t)
   440  	ep.flags.AddRef(bits)
   441  
   442  	return nil
   443  }
   444  
   445  func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
   446  	ep.mu.RLock()
   447  	defer ep.mu.RUnlock()
   448  
   449  	bits := flags.Bits() & ports.MultiBindFlagMask
   450  
   451  	if len(ep.endpoints) != 0 {
   452  		// If it was previously bound, we need to check if we can bind again.
   453  		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
   454  			return &tcpip.ErrPortInUse{}
   455  		}
   456  	}
   457  
   458  	return nil
   459  }
   460  
   461  // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
   462  func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
   463  	ep.mu.Lock()
   464  	defer ep.mu.Unlock()
   465  
   466  	for i, endpoint := range ep.endpoints {
   467  		if endpoint == t {
   468  			copy(ep.endpoints[i:], ep.endpoints[i+1:])
   469  			ep.endpoints[len(ep.endpoints)-1] = nil
   470  			ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
   471  
   472  			ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
   473  			break
   474  		}
   475  	}
   476  	return len(ep.endpoints) == 0
   477  }
   478  
   479  func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   480  	if id.RemotePort != 0 {
   481  		// SO_REUSEPORT only applies to bound/listening endpoints.
   482  		flags.LoadBalanced = false
   483  	}
   484  
   485  	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
   486  	if !ok {
   487  		return &tcpip.ErrUnknownProtocol{}
   488  	}
   489  
   490  	eps.mu.Lock()
   491  	defer eps.mu.Unlock()
   492  	epsByNIC, ok := eps.endpoints[id]
   493  	if !ok {
   494  		epsByNIC = &endpointsByNIC{
   495  			endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
   496  			seed:      d.stack.seed,
   497  		}
   498  	}
   499  	if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
   500  		return err
   501  	}
   502  	// Only add this newly created epsByNIC if registerEndpoint succeeded.
   503  	if !ok {
   504  		eps.endpoints[id] = epsByNIC
   505  	}
   506  	return nil
   507  }
   508  
   509  func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
   510  	if id.RemotePort != 0 {
   511  		// SO_REUSEPORT only applies to bound/listening endpoints.
   512  		flags.LoadBalanced = false
   513  	}
   514  
   515  	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
   516  	if !ok {
   517  		return &tcpip.ErrUnknownProtocol{}
   518  	}
   519  
   520  	eps.mu.RLock()
   521  	defer eps.mu.RUnlock()
   522  
   523  	epsByNIC, ok := eps.endpoints[id]
   524  	if !ok {
   525  		return nil
   526  	}
   527  
   528  	return epsByNIC.checkEndpoint(flags, bindToDevice)
   529  }
   530  
   531  // unregisterEndpoint unregisters the endpoint with the given id such that it
   532  // won't receive any more packets.
   533  func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
   534  	if id.RemotePort != 0 {
   535  		// SO_REUSEPORT only applies to bound/listening endpoints.
   536  		flags.LoadBalanced = false
   537  	}
   538  
   539  	for _, n := range netProtos {
   540  		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
   541  			eps.unregisterEndpoint(id, ep, flags, bindToDevice)
   542  		}
   543  	}
   544  }
   545  
   546  // deliverPacket attempts to find one or more matching transport endpoints, and
   547  // then, if matches are found, delivers the packet to them. Returns true if
   548  // the packet no longer needs to be handled.
   549  func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt PacketBufferPtr, id TransportEndpointID) bool {
   550  	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
   551  	if !ok {
   552  		return false
   553  	}
   554  
   555  	// If the packet is a UDP broadcast or multicast, then find all matching
   556  	// transport endpoints.
   557  	if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
   558  		eps.mu.RLock()
   559  		destEPs := eps.findAllEndpointsLocked(id)
   560  		eps.mu.RUnlock()
   561  		// Fail if we didn't find at least one matching transport endpoint.
   562  		if len(destEPs) == 0 {
   563  			d.stack.stats.UDP.UnknownPortErrors.Increment()
   564  			return false
   565  		}
   566  		// handlePacket takes may modify pkt, so each endpoint needs its own
   567  		// copy except for the final one.
   568  		for _, ep := range destEPs[:len(destEPs)-1] {
   569  			clone := pkt.Clone()
   570  			ep.handlePacket(id, clone)
   571  			clone.DecRef()
   572  		}
   573  		destEPs[len(destEPs)-1].handlePacket(id, pkt)
   574  		return true
   575  	}
   576  
   577  	// If the packet is a TCP packet with a unspecified source or non-unicast
   578  	// destination address, then do nothing further and instruct the caller to do
   579  	// the same. The network layer handles address validation for specified source
   580  	// addresses.
   581  	if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
   582  		// TCP can only be used to communicate between a single source and a
   583  		// single destination; the addresses must be unicast.e
   584  		d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
   585  		return true
   586  	}
   587  
   588  	eps.mu.RLock()
   589  	ep := eps.findEndpointLocked(id)
   590  	eps.mu.RUnlock()
   591  	if ep == nil {
   592  		if protocol == header.UDPProtocolNumber {
   593  			d.stack.stats.UDP.UnknownPortErrors.Increment()
   594  		}
   595  		return false
   596  	}
   597  	return ep.handlePacket(id, pkt)
   598  }
   599  
   600  // deliverRawPacket attempts to deliver the given packet and returns whether it
   601  // was delivered successfully.
   602  func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt PacketBufferPtr) bool {
   603  	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
   604  	if !ok {
   605  		return false
   606  	}
   607  
   608  	// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
   609  	// raw endpoint first. If there are multiple raw endpoints, they all
   610  	// receive the packet.
   611  	eps.mu.RLock()
   612  	// Copy the list of raw endpoints to avoid packet handling under lock.
   613  	var rawEPs []RawTransportEndpoint
   614  	if n := len(eps.rawEndpoints); n != 0 {
   615  		rawEPs = make([]RawTransportEndpoint, n)
   616  		if m := copy(rawEPs, eps.rawEndpoints); m != n {
   617  			panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
   618  		}
   619  	}
   620  	eps.mu.RUnlock()
   621  	for _, rawEP := range rawEPs {
   622  		// Each endpoint gets its own copy of the packet for the sake
   623  		// of save/restore.
   624  		clone := pkt.Clone()
   625  		rawEP.HandlePacket(clone)
   626  		clone.DecRef()
   627  	}
   628  
   629  	return len(rawEPs) != 0
   630  }
   631  
   632  // deliverError attempts to deliver the given error to the appropriate transport
   633  // endpoint.
   634  //
   635  // Returns true if the error was delivered.
   636  func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt PacketBufferPtr, id TransportEndpointID) bool {
   637  	eps, ok := d.protocol[protocolIDs{net, trans}]
   638  	if !ok {
   639  		return false
   640  	}
   641  
   642  	eps.mu.RLock()
   643  	ep := eps.findEndpointLocked(id)
   644  	eps.mu.RUnlock()
   645  	if ep == nil {
   646  		return false
   647  	}
   648  
   649  	ep.handleError(n, id, transErr, pkt)
   650  	return true
   651  }
   652  
   653  // findTransportEndpoint find a single endpoint that most closely matches the provided id.
   654  func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
   655  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   656  	if !ok {
   657  		return nil
   658  	}
   659  
   660  	eps.mu.RLock()
   661  	epsByNIC := eps.findEndpointLocked(id)
   662  	if epsByNIC == nil {
   663  		eps.mu.RUnlock()
   664  		return nil
   665  	}
   666  
   667  	epsByNIC.mu.RLock()
   668  	eps.mu.RUnlock()
   669  
   670  	mpep, ok := epsByNIC.endpoints[nicID]
   671  	if !ok {
   672  		if mpep, ok = epsByNIC.endpoints[0]; !ok {
   673  			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
   674  			return nil
   675  		}
   676  	}
   677  
   678  	ep := mpep.selectEndpoint(id, epsByNIC.seed)
   679  	epsByNIC.mu.RUnlock()
   680  	return ep
   681  }
   682  
   683  // registerRawEndpoint registers the given endpoint with the dispatcher such
   684  // that packets of the appropriate protocol are delivered to it. A single
   685  // packet can be sent to one or more raw endpoints along with a non-raw
   686  // endpoint.
   687  func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
   688  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   689  	if !ok {
   690  		return &tcpip.ErrNotSupported{}
   691  	}
   692  
   693  	eps.mu.Lock()
   694  	eps.rawEndpoints = append(eps.rawEndpoints, ep)
   695  	eps.mu.Unlock()
   696  
   697  	return nil
   698  }
   699  
   700  // unregisterRawEndpoint unregisters the raw endpoint for the given transport
   701  // protocol such that it won't receive any more packets.
   702  func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
   703  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   704  	if !ok {
   705  		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
   706  	}
   707  
   708  	eps.mu.Lock()
   709  	for i, rawEP := range eps.rawEndpoints {
   710  		if rawEP == ep {
   711  			lastIdx := len(eps.rawEndpoints) - 1
   712  			eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
   713  			eps.rawEndpoints[lastIdx] = nil
   714  			eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
   715  			break
   716  		}
   717  	}
   718  	eps.mu.Unlock()
   719  }
   720  
   721  func isInboundMulticastOrBroadcast(pkt PacketBufferPtr, localAddr tcpip.Address) bool {
   722  	return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
   723  }
   724  
   725  func isSpecified(addr tcpip.Address) bool {
   726  	return addr != header.IPv4Any && addr != header.IPv6Any
   727  }