github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/stack/transport_demuxer.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"fmt"
    19  	"math/rand"
    20  	"sync"
    21  
    22  	"github.com/FlowerWrong/netstack/tcpip"
    23  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    24  	"github.com/FlowerWrong/netstack/tcpip/hash/jenkins"
    25  	"github.com/FlowerWrong/netstack/tcpip/header"
    26  )
    27  
    28  type protocolIDs struct {
    29  	network   tcpip.NetworkProtocolNumber
    30  	transport tcpip.TransportProtocolNumber
    31  }
    32  
    33  // transportEndpoints manages all endpoints of a given protocol. It has its own
    34  // mutex so as to reduce interference between protocols.
    35  type transportEndpoints struct {
    36  	// mu protects all fields of the transportEndpoints.
    37  	mu        sync.RWMutex
    38  	endpoints map[TransportEndpointID]*endpointsByNic
    39  	// rawEndpoints contains endpoints for raw sockets, which receive all
    40  	// traffic of a given protocol regardless of port.
    41  	rawEndpoints []RawTransportEndpoint
    42  }
    43  
    44  type endpointsByNic struct {
    45  	mu        sync.RWMutex
    46  	endpoints map[tcpip.NICID]*multiPortEndpoint
    47  	// seed is a random secret for a jenkins hash.
    48  	seed uint32
    49  }
    50  
    51  // HandlePacket is called by the stack when new packets arrive to this transport
    52  // endpoint.
    53  func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
    54  	epsByNic.mu.RLock()
    55  
    56  	mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
    57  	if !ok {
    58  		if mpep, ok = epsByNic.endpoints[0]; !ok {
    59  			epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
    60  			return
    61  		}
    62  	}
    63  
    64  	// If this is a broadcast or multicast datagram, deliver the datagram to all
    65  	// endpoints bound to the right device.
    66  	if isMulticastOrBroadcast(id.LocalAddress) {
    67  		mpep.handlePacketAll(r, id, vv)
    68  		epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
    69  		return
    70  	}
    71  
    72  	// multiPortEndpoints are guaranteed to have at least one element.
    73  	selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
    74  	epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
    75  }
    76  
    77  // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
    78  func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
    79  	epsByNic.mu.RLock()
    80  	defer epsByNic.mu.RUnlock()
    81  
    82  	mpep, ok := epsByNic.endpoints[n.ID()]
    83  	if !ok {
    84  		mpep, ok = epsByNic.endpoints[0]
    85  	}
    86  	if !ok {
    87  		return
    88  	}
    89  
    90  	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
    91  	// broadcast like we are doing with handlePacket above?
    92  
    93  	// multiPortEndpoints are guaranteed to have at least one element.
    94  	selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
    95  }
    96  
    97  // registerEndpoint returns true if it succeeds. It fails and returns
    98  // false if ep already has an element with the same key.
    99  func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
   100  	epsByNic.mu.Lock()
   101  	defer epsByNic.mu.Unlock()
   102  
   103  	if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
   104  		// There was already a bind.
   105  		return multiPortEp.singleRegisterEndpoint(t, reusePort)
   106  	}
   107  
   108  	// This is a new binding.
   109  	multiPortEp := &multiPortEndpoint{}
   110  	multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
   111  	multiPortEp.reuse = reusePort
   112  	epsByNic.endpoints[bindToDevice] = multiPortEp
   113  	return multiPortEp.singleRegisterEndpoint(t, reusePort)
   114  }
   115  
   116  // unregisterEndpoint returns true if endpointsByNic has to be unregistered.
   117  func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
   118  	epsByNic.mu.Lock()
   119  	defer epsByNic.mu.Unlock()
   120  	multiPortEp, ok := epsByNic.endpoints[bindToDevice]
   121  	if !ok {
   122  		return false
   123  	}
   124  	if multiPortEp.unregisterEndpoint(t) {
   125  		delete(epsByNic.endpoints, bindToDevice)
   126  	}
   127  	return len(epsByNic.endpoints) == 0
   128  }
   129  
   130  // unregisterEndpoint unregisters the endpoint with the given id such that it
   131  // won't receive any more packets.
   132  func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
   133  	eps.mu.Lock()
   134  	defer eps.mu.Unlock()
   135  	epsByNic, ok := eps.endpoints[id]
   136  	if !ok {
   137  		return
   138  	}
   139  	if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
   140  		return
   141  	}
   142  	delete(eps.endpoints, id)
   143  }
   144  
   145  // transportDemuxer demultiplexes packets targeted at a transport endpoint
   146  // (i.e., after they've been parsed by the network layer). It does two levels
   147  // of demultiplexing: first based on the network and transport protocols, then
   148  // based on endpoints IDs. It should only be instantiated via
   149  // newTransportDemuxer.
   150  type transportDemuxer struct {
   151  	// protocol is immutable.
   152  	protocol map[protocolIDs]*transportEndpoints
   153  }
   154  
   155  func newTransportDemuxer(stack *Stack) *transportDemuxer {
   156  	d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
   157  
   158  	// Add each network and transport pair to the demuxer.
   159  	for netProto := range stack.networkProtocols {
   160  		for proto := range stack.transportProtocols {
   161  			d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
   162  				endpoints: make(map[TransportEndpointID]*endpointsByNic),
   163  			}
   164  		}
   165  	}
   166  
   167  	return d
   168  }
   169  
   170  // registerEndpoint registers the given endpoint with the dispatcher such that
   171  // packets that match the endpoint ID are delivered to it.
   172  func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
   173  	for i, n := range netProtos {
   174  		if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
   175  			d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
   176  			return err
   177  		}
   178  	}
   179  
   180  	return nil
   181  }
   182  
   183  // multiPortEndpoint is a container for TransportEndpoints which are bound to
   184  // the same pair of address and port. endpointsArr always has at least one
   185  // element.
   186  type multiPortEndpoint struct {
   187  	mu           sync.RWMutex
   188  	endpointsArr []TransportEndpoint
   189  	endpointsMap map[TransportEndpoint]int
   190  	// reuse indicates if more than one endpoint is allowed.
   191  	reuse bool
   192  }
   193  
   194  // reciprocalScale scales a value into range [0, n).
   195  //
   196  // This is similar to val % n, but faster.
   197  // See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
   198  func reciprocalScale(val, n uint32) uint32 {
   199  	return uint32((uint64(val) * uint64(n)) >> 32)
   200  }
   201  
   202  // selectEndpoint calculates a hash of destination and source addresses and
   203  // ports then uses it to select a socket. In this case, all packets from one
   204  // address will be sent to same endpoint.
   205  func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
   206  	if len(mpep.endpointsArr) == 1 {
   207  		return mpep.endpointsArr[0]
   208  	}
   209  
   210  	payload := []byte{
   211  		byte(id.LocalPort),
   212  		byte(id.LocalPort >> 8),
   213  		byte(id.RemotePort),
   214  		byte(id.RemotePort >> 8),
   215  	}
   216  
   217  	h := jenkins.Sum32(seed)
   218  	h.Write(payload)
   219  	h.Write([]byte(id.LocalAddress))
   220  	h.Write([]byte(id.RemoteAddress))
   221  	hash := h.Sum32()
   222  
   223  	idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
   224  	return mpep.endpointsArr[idx]
   225  }
   226  
   227  func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
   228  	ep.mu.RLock()
   229  	for i, endpoint := range ep.endpointsArr {
   230  		// HandlePacket modifies vv, so each endpoint needs its own copy except for
   231  		// the final one.
   232  		if i == len(ep.endpointsArr)-1 {
   233  			endpoint.HandlePacket(r, id, vv)
   234  			break
   235  		}
   236  		vvCopy := buffer.NewView(vv.Size())
   237  		copy(vvCopy, vv.ToView())
   238  		endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
   239  	}
   240  	ep.mu.RUnlock() // Don't use defer for performance reasons.
   241  }
   242  
   243  // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
   244  // list. The list might be empty already.
   245  func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
   246  	ep.mu.Lock()
   247  	defer ep.mu.Unlock()
   248  
   249  	if len(ep.endpointsArr) > 0 {
   250  		// If it was previously bound, we need to check if we can bind again.
   251  		if !ep.reuse || !reusePort {
   252  			return tcpip.ErrPortInUse
   253  		}
   254  	}
   255  
   256  	// A new endpoint is added into endpointsArr and its index there is saved in
   257  	// endpointsMap. This will allow us to remove endpoint from the array fast.
   258  	ep.endpointsMap[t] = len(ep.endpointsArr)
   259  	ep.endpointsArr = append(ep.endpointsArr, t)
   260  	return nil
   261  }
   262  
   263  // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
   264  func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
   265  	ep.mu.Lock()
   266  	defer ep.mu.Unlock()
   267  
   268  	idx, ok := ep.endpointsMap[t]
   269  	if !ok {
   270  		return false
   271  	}
   272  	delete(ep.endpointsMap, t)
   273  	l := len(ep.endpointsArr)
   274  	if l > 1 {
   275  		// The last endpoint in endpointsArr is moved instead of the deleted one.
   276  		lastEp := ep.endpointsArr[l-1]
   277  		ep.endpointsArr[idx] = lastEp
   278  		ep.endpointsMap[lastEp] = idx
   279  		ep.endpointsArr = ep.endpointsArr[0 : l-1]
   280  		return false
   281  	}
   282  	return true
   283  }
   284  
   285  func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
   286  	if id.RemotePort != 0 {
   287  		// TODO(eyalsoha): Why?
   288  		reusePort = false
   289  	}
   290  
   291  	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
   292  	if !ok {
   293  		return tcpip.ErrUnknownProtocol
   294  	}
   295  
   296  	eps.mu.Lock()
   297  	defer eps.mu.Unlock()
   298  
   299  	if epsByNic, ok := eps.endpoints[id]; ok {
   300  		// There was already a binding.
   301  		return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
   302  	}
   303  
   304  	// This is a new binding.
   305  	epsByNic := &endpointsByNic{
   306  		endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
   307  		seed:      rand.Uint32(),
   308  	}
   309  	eps.endpoints[id] = epsByNic
   310  
   311  	return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
   312  }
   313  
   314  // unregisterEndpoint unregisters the endpoint with the given id such that it
   315  // won't receive any more packets.
   316  func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
   317  	for _, n := range netProtos {
   318  		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
   319  			eps.unregisterEndpoint(id, ep, bindToDevice)
   320  		}
   321  	}
   322  }
   323  
   324  var loopbackSubnet = func() tcpip.Subnet {
   325  	sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
   326  	if err != nil {
   327  		panic(err)
   328  	}
   329  	return sn
   330  }()
   331  
   332  // deliverPacket attempts to find one or more matching transport endpoints, and
   333  // then, if matches are found, delivers the packet to them. Returns true if it
   334  // found one or more endpoints, false otherwise.
   335  func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
   336  	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
   337  	if !ok {
   338  		return false
   339  	}
   340  
   341  	eps.mu.RLock()
   342  
   343  	// Determine which transport endpoint or endpoints to deliver this packet to.
   344  	// If the packet is a broadcast or multicast, then find all matching
   345  	// transport endpoints.
   346  	var destEps []*endpointsByNic
   347  	if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
   348  		destEps = d.findAllEndpointsLocked(eps, vv, id)
   349  	} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
   350  		destEps = append(destEps, ep)
   351  	}
   352  
   353  	eps.mu.RUnlock()
   354  
   355  	// Fail if we didn't find at least one matching transport endpoint.
   356  	if len(destEps) == 0 {
   357  		// UDP packet could not be delivered to an unknown destination port.
   358  		if protocol == header.UDPProtocolNumber {
   359  			r.Stats().UDP.UnknownPortErrors.Increment()
   360  		}
   361  		return false
   362  	}
   363  
   364  	// Deliver the packet.
   365  	for _, ep := range destEps {
   366  		ep.handlePacket(r, id, vv)
   367  	}
   368  
   369  	return true
   370  }
   371  
   372  // deliverRawPacket attempts to deliver the given packet and returns whether it
   373  // was delivered successfully.
   374  func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
   375  	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
   376  	if !ok {
   377  		return false
   378  	}
   379  
   380  	// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
   381  	// raw endpoint first. If there are multiple raw endpoints, they all
   382  	// receive the packet.
   383  	foundRaw := false
   384  	eps.mu.RLock()
   385  	for _, rawEP := range eps.rawEndpoints {
   386  		// Each endpoint gets its own copy of the packet for the sake
   387  		// of save/restore.
   388  		rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
   389  		foundRaw = true
   390  	}
   391  	eps.mu.RUnlock()
   392  
   393  	return foundRaw
   394  }
   395  
   396  // deliverControlPacket attempts to deliver the given control packet. Returns
   397  // true if it found an endpoint, false otherwise.
   398  func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
   399  	eps, ok := d.protocol[protocolIDs{net, trans}]
   400  	if !ok {
   401  		return false
   402  	}
   403  
   404  	// Try to find the endpoint.
   405  	eps.mu.RLock()
   406  	ep := d.findEndpointLocked(eps, vv, id)
   407  	eps.mu.RUnlock()
   408  
   409  	// Fail if we didn't find one.
   410  	if ep == nil {
   411  		return false
   412  	}
   413  
   414  	// Deliver the packet.
   415  	ep.handleControlPacket(n, id, typ, extra, vv)
   416  
   417  	return true
   418  }
   419  
   420  func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
   421  	var matchedEPs []*endpointsByNic
   422  	// Try to find a match with the id as provided.
   423  	if ep, ok := eps.endpoints[id]; ok {
   424  		matchedEPs = append(matchedEPs, ep)
   425  	}
   426  
   427  	// Try to find a match with the id minus the local address.
   428  	nid := id
   429  
   430  	nid.LocalAddress = ""
   431  	if ep, ok := eps.endpoints[nid]; ok {
   432  		matchedEPs = append(matchedEPs, ep)
   433  	}
   434  
   435  	// Try to find a match with the id minus the remote part.
   436  	nid.LocalAddress = id.LocalAddress
   437  	nid.RemoteAddress = ""
   438  	nid.RemotePort = 0
   439  	if ep, ok := eps.endpoints[nid]; ok {
   440  		matchedEPs = append(matchedEPs, ep)
   441  	}
   442  
   443  	// Try to find a match with only the local port.
   444  	nid.LocalAddress = ""
   445  	if ep, ok := eps.endpoints[nid]; ok {
   446  		matchedEPs = append(matchedEPs, ep)
   447  	}
   448  
   449  	return matchedEPs
   450  }
   451  
   452  // findEndpointLocked returns the endpoint that most closely matches the given
   453  // id.
   454  func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
   455  	if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
   456  		return matchedEPs[0]
   457  	}
   458  	return nil
   459  }
   460  
   461  // registerRawEndpoint registers the given endpoint with the dispatcher such
   462  // that packets of the appropriate protocol are delivered to it. A single
   463  // packet can be sent to one or more raw endpoints along with a non-raw
   464  // endpoint.
   465  func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
   466  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   467  	if !ok {
   468  		return nil
   469  	}
   470  
   471  	eps.mu.Lock()
   472  	defer eps.mu.Unlock()
   473  	eps.rawEndpoints = append(eps.rawEndpoints, ep)
   474  
   475  	return nil
   476  }
   477  
   478  // unregisterRawEndpoint unregisters the raw endpoint for the given transport
   479  // protocol such that it won't receive any more packets.
   480  func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
   481  	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
   482  	if !ok {
   483  		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
   484  	}
   485  
   486  	eps.mu.Lock()
   487  	defer eps.mu.Unlock()
   488  	for i, rawEP := range eps.rawEndpoints {
   489  		if rawEP == ep {
   490  			eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
   491  			return
   492  		}
   493  	}
   494  }
   495  
   496  func isMulticastOrBroadcast(addr tcpip.Address) bool {
   497  	return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
   498  }