github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/ports/ports.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ports provides PortManager that manages allocating, reserving and
    16  // releasing ports.
    17  package ports
    18  
    19  import (
    20  	"math"
    21  	"math/rand"
    22  	"sync/atomic"
    23  
    24  	"github.com/SagerNet/gvisor/pkg/sync"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    27  )
    28  
    29  const (
    30  	firstEphemeral               = 16000
    31  	anyIPAddress   tcpip.Address = ""
    32  )
    33  
// Reservation describes a port reservation: the local protocol/address/port/
// device combination being claimed, the destination it is associated with,
// and the flags that decide whether it may share the port with other
// reservations.
type Reservation struct {
	// Networks is a list of network protocols to which the reservation
	// applies. Can be IPv4, IPv6, or both.
	Networks []tcpip.NetworkProtocolNumber

	// Transport is the transport protocol to which the reservation applies.
	Transport tcpip.TransportProtocolNumber

	// Addr is the address of the local endpoint. The empty address
	// (anyIPAddress) is checked for conflicts against every address.
	Addr tcpip.Address

	// Port is the local port number. When zero is passed to ReservePort, a
	// port is picked from the ephemeral range instead.
	Port uint16

	// Flags describe features of the reservation. Reservations that overlap
	// may coexist only if the bits of their flags intersect.
	Flags Flags

	// BindToDevice is the NIC to which the reservation applies. Zero means
	// the reservation is not bound to a specific device.
	BindToDevice tcpip.NICID

	// Dest is the destination address. An empty Dest.Addr acts as a wildcard
	// destination when computing TupleOnlyFlag sharing.
	Dest tcpip.FullAddress
}
    58  
    59  func (rs Reservation) dst() destination {
    60  	return destination{
    61  		rs.Dest.Addr,
    62  		rs.Dest.Port,
    63  	}
    64  }
    65  
// portDescriptor is the outermost key of PortManager.allocatedPorts: a port
// number qualified by its network and transport protocol.
type portDescriptor struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
	port      uint16
}

// destination identifies the remote end associated with a reservation. An
// empty addr (anyIPAddress) is a wildcard destination.
type destination struct {
	addr tcpip.Address
	port uint16
}
    76  
    77  // destToCounter maps each destination to the FlagCounter that represents
    78  // endpoints to that destination.
    79  //
    80  // destToCounter is never empty. When it has no elements, it is removed from
    81  // the map that references it.
    82  type destToCounter map[destination]FlagCounter
    83  
    84  // intersectionFlags calculates the intersection of flag bit values which affect
    85  // the specified destination.
    86  //
    87  // If no destinations are present, all flag values are returned as there are no
    88  // entries to limit possible flag values of a new entry.
    89  //
    90  // In addition to the intersection, the number of intersecting refs is
    91  // returned.
    92  func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
    93  	intersection := FlagMask
    94  	var count int
    95  
    96  	for dest, counter := range dc {
    97  		if dest == res.dst() {
    98  			intersection &= counter.SharedFlags()
    99  			count++
   100  			continue
   101  		}
   102  		// Wildcard destinations affect all destinations for TupleOnly.
   103  		if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
   104  			// Only bitwise and the TupleOnlyFlag.
   105  			intersection &= (^TupleOnlyFlag) | counter.SharedFlags()
   106  			count++
   107  		}
   108  	}
   109  
   110  	return intersection, count
   111  }
   112  
   113  // deviceToDest maps NICs to destinations for which there are port reservations.
   114  //
   115  // deviceToDest is never empty. When it has no elements, it is removed from the
   116  // map that references it.
   117  type deviceToDest map[tcpip.NICID]destToCounter
   118  
   119  // isAvailable checks whether binding is possible by device. If not binding to
   120  // a device, check against all FlagCounters. If binding to a specific device,
   121  // check against the unspecified device and the provided device.
   122  //
   123  // If either of the port reuse flags is enabled on any of the nodes, all nodes
   124  // sharing a port must share at least one reuse flag. This matches Linux's
   125  // behavior.
   126  func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool {
   127  	flagBits := res.Flags.Bits()
   128  	if res.BindToDevice == 0 {
   129  		intersection := FlagMask
   130  		for _, dest := range dd {
   131  			flags, count := dest.intersectionFlags(res)
   132  			if count == 0 {
   133  				continue
   134  			}
   135  			intersection &= flags
   136  			if intersection&flagBits == 0 {
   137  				// Can't bind because the (addr,port) was
   138  				// previously bound without reuse.
   139  				return false
   140  			}
   141  		}
   142  		if !portSpecified && res.Transport == header.TCPProtocolNumber {
   143  			return false
   144  		}
   145  		return true
   146  	}
   147  
   148  	intersection := FlagMask
   149  
   150  	if dests, ok := dd[0]; ok {
   151  		var count int
   152  		intersection, count = dests.intersectionFlags(res)
   153  		if count > 0 {
   154  			if intersection&flagBits == 0 {
   155  				return false
   156  			}
   157  			if !portSpecified && res.Transport == header.TCPProtocolNumber {
   158  				return false
   159  			}
   160  		}
   161  	}
   162  
   163  	if dests, ok := dd[res.BindToDevice]; ok {
   164  		flags, count := dests.intersectionFlags(res)
   165  		intersection &= flags
   166  		if count > 0 {
   167  			if intersection&flagBits == 0 {
   168  				return false
   169  			}
   170  			if !portSpecified && res.Transport == header.TCPProtocolNumber {
   171  				return false
   172  			}
   173  		}
   174  	}
   175  
   176  	return true
   177  }
   178  
   179  // addrToDevice maps IP addresses to NICs that have port reservations.
   180  type addrToDevice map[tcpip.Address]deviceToDest
   181  
   182  // isAvailable checks whether an IP address is available to bind to. If the
   183  // address is the "any" address, check all other addresses. Otherwise, just
   184  // check against the "any" address and the provided address.
   185  func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool {
   186  	if res.Addr == anyIPAddress {
   187  		// If binding to the "any" address then check that there are no
   188  		// conflicts with all addresses.
   189  		for _, devices := range ad {
   190  			if !devices.isAvailable(res, portSpecified) {
   191  				return false
   192  			}
   193  		}
   194  		return true
   195  	}
   196  
   197  	// Check that there is no conflict with the "any" address.
   198  	if devices, ok := ad[anyIPAddress]; ok {
   199  		if !devices.isAvailable(res, portSpecified) {
   200  			return false
   201  		}
   202  	}
   203  
   204  	// Check that this is no conflict with the provided address.
   205  	if devices, ok := ad[res.Addr]; ok {
   206  		if !devices.isAvailable(res, portSpecified) {
   207  			return false
   208  		}
   209  	}
   210  
   211  	return true
   212  }
   213  
// PortManager manages allocating, reserving and releasing ports.
//
// Use NewPortManager to create one: the zero value has a nil allocatedPorts
// map (reserving a port would write to it) and an empty ephemeral range.
type PortManager struct {
	// mu protects allocatedPorts.
	// LOCK ORDERING: mu > ephemeralMu.
	mu sync.RWMutex
	// allocatedPorts is a nesting of maps that ultimately map Reservations
	// to FlagCounters describing whether the Reservation is valid and can
	// be reused. The nesting is portDescriptor -> local address -> NIC ->
	// destination -> FlagCounter.
	allocatedPorts map[portDescriptor]addrToDevice

	// ephemeralMu protects firstEphemeral and numEphemeral, which describe
	// the inclusive ephemeral range
	// [firstEphemeral, firstEphemeral+numEphemeral-1].
	ephemeralMu    sync.RWMutex
	firstEphemeral uint16
	numEphemeral   uint16

	// hint is used to pick ephemeral ports in a stable order for a given
	// port offset. PickEphemeralPortStable adds it to the caller's offset
	// and increments it after each successful pick.
	//
	// hint must be accessed using the portHint/incPortHint helpers.
	// TODO(github.com/SagerNet/issue/940): S/R this field.
	hint uint32
}
   236  
   237  // NewPortManager creates new PortManager.
   238  func NewPortManager() *PortManager {
   239  	return &PortManager{
   240  		allocatedPorts: make(map[portDescriptor]addrToDevice),
   241  		firstEphemeral: firstEphemeral,
   242  		numEphemeral:   math.MaxUint16 - firstEphemeral + 1,
   243  	}
   244  }
   245  
   246  // PortTester indicates whether the passed in port is suitable. Returning an
   247  // error causes the function to which the PortTester is passed to return that
   248  // error.
   249  type PortTester func(port uint16) (good bool, err tcpip.Error)
   250  
   251  // PickEphemeralPort randomly chooses a starting point and iterates over all
   252  // possible ephemeral ports, allowing the caller to decide whether a given port
   253  // is suitable for its needs, and stopping when a port is found or an error
   254  // occurs.
   255  func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) {
   256  	pm.ephemeralMu.RLock()
   257  	firstEphemeral := pm.firstEphemeral
   258  	numEphemeral := pm.numEphemeral
   259  	pm.ephemeralMu.RUnlock()
   260  
   261  	offset := uint32(rng.Int31n(int32(numEphemeral)))
   262  	return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
   263  }
   264  
   265  // portHint atomically reads and returns the pm.hint value.
   266  func (pm *PortManager) portHint() uint32 {
   267  	return atomic.LoadUint32(&pm.hint)
   268  }
   269  
   270  // incPortHint atomically increments pm.hint by 1.
   271  func (pm *PortManager) incPortHint() {
   272  	atomic.AddUint32(&pm.hint, 1)
   273  }
   274  
   275  // PickEphemeralPortStable starts at the specified offset + pm.portHint and
   276  // iterates over all ephemeral ports, allowing the caller to decide whether a
   277  // given port is suitable for its needs and stopping when a port is found or an
   278  // error occurs.
   279  func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) {
   280  	pm.ephemeralMu.RLock()
   281  	firstEphemeral := pm.firstEphemeral
   282  	numEphemeral := pm.numEphemeral
   283  	pm.ephemeralMu.RUnlock()
   284  
   285  	p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort)
   286  	if err == nil {
   287  		pm.incPortHint()
   288  	}
   289  	return p, err
   290  }
   291  
   292  // pickEphemeralPort starts at the offset specified from the FirstEphemeral port
   293  // and iterates over the number of ports specified by count and allows the
   294  // caller to decide whether a given port is suitable for its needs, and stopping
   295  // when a port is found or an error occurs.
   296  func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
   297  	for i := uint32(0); i < uint32(count); i++ {
   298  		port := uint16(uint32(first) + (offset+i)%uint32(count))
   299  		ok, err := testPort(port)
   300  		if err != nil {
   301  			return 0, err
   302  		}
   303  
   304  		if ok {
   305  			return port, nil
   306  		}
   307  	}
   308  
   309  	return 0, &tcpip.ErrNoPortAvailable{}
   310  }
   311  
   312  // ReservePort marks a port/IP combination as reserved so that it cannot be
   313  // reserved by another endpoint. If port is zero, ReservePort will search for
   314  // an unreserved ephemeral port and reserve it, returning its value in the
   315  // "port" return value.
   316  //
   317  // An optional PortTester can be passed in which if provided will be used to
   318  // test if the picked port can be used. The function should return true if the
   319  // port is safe to use, false otherwise.
   320  func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
   321  	pm.mu.Lock()
   322  	defer pm.mu.Unlock()
   323  
   324  	// If a port is specified, just try to reserve it for all network
   325  	// protocols.
   326  	if res.Port != 0 {
   327  		if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) {
   328  			return 0, &tcpip.ErrPortInUse{}
   329  		}
   330  		if testPort != nil {
   331  			ok, err := testPort(res.Port)
   332  			if err != nil {
   333  				pm.releasePortLocked(res)
   334  				return 0, err
   335  			}
   336  			if !ok {
   337  				pm.releasePortLocked(res)
   338  				return 0, &tcpip.ErrPortInUse{}
   339  			}
   340  		}
   341  		return res.Port, nil
   342  	}
   343  
   344  	// A port wasn't specified, so try to find one.
   345  	return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
   346  		res.Port = p
   347  		if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) {
   348  			return false, nil
   349  		}
   350  		if testPort != nil {
   351  			ok, err := testPort(p)
   352  			if err != nil {
   353  				pm.releasePortLocked(res)
   354  				return false, err
   355  			}
   356  			if !ok {
   357  				pm.releasePortLocked(res)
   358  				return false, nil
   359  			}
   360  		}
   361  		return true, nil
   362  	})
   363  }
   364  
   365  // reserveSpecificPortLocked tries to reserve the given port on all given
   366  // protocols.
   367  func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool {
   368  	// Make sure the port is available.
   369  	for _, network := range res.Networks {
   370  		desc := portDescriptor{network, res.Transport, res.Port}
   371  		if addrs, ok := pm.allocatedPorts[desc]; ok {
   372  			if !addrs.isAvailable(res, portSpecified) {
   373  				return false
   374  			}
   375  		}
   376  	}
   377  
   378  	// Reserve port on all network protocols.
   379  	flagBits := res.Flags.Bits()
   380  	dst := res.dst()
   381  	for _, network := range res.Networks {
   382  		desc := portDescriptor{network, res.Transport, res.Port}
   383  		addrToDev, ok := pm.allocatedPorts[desc]
   384  		if !ok {
   385  			addrToDev = make(addrToDevice)
   386  			pm.allocatedPorts[desc] = addrToDev
   387  		}
   388  		devToDest, ok := addrToDev[res.Addr]
   389  		if !ok {
   390  			devToDest = make(deviceToDest)
   391  			addrToDev[res.Addr] = devToDest
   392  		}
   393  		destToCntr := devToDest[res.BindToDevice]
   394  		if destToCntr == nil {
   395  			destToCntr = make(destToCounter)
   396  		}
   397  		counter := destToCntr[dst]
   398  		counter.AddRef(flagBits)
   399  		destToCntr[dst] = counter
   400  		devToDest[res.BindToDevice] = destToCntr
   401  	}
   402  
   403  	return true
   404  }
   405  
   406  // ReserveTuple adds a port reservation for the tuple on all given protocol.
   407  func (pm *PortManager) ReserveTuple(res Reservation) bool {
   408  	flagBits := res.Flags.Bits()
   409  	dst := res.dst()
   410  
   411  	pm.mu.Lock()
   412  	defer pm.mu.Unlock()
   413  
   414  	// It is easier to undo the entire reservation, so if we find that the
   415  	// tuple can't be fully added, finish and undo the whole thing.
   416  	undo := false
   417  
   418  	// Reserve port on all network protocols.
   419  	for _, network := range res.Networks {
   420  		desc := portDescriptor{network, res.Transport, res.Port}
   421  		addrToDev, ok := pm.allocatedPorts[desc]
   422  		if !ok {
   423  			addrToDev = make(addrToDevice)
   424  			pm.allocatedPorts[desc] = addrToDev
   425  		}
   426  		devToDest, ok := addrToDev[res.Addr]
   427  		if !ok {
   428  			devToDest = make(deviceToDest)
   429  			addrToDev[res.Addr] = devToDest
   430  		}
   431  		destToCntr := devToDest[res.BindToDevice]
   432  		if destToCntr == nil {
   433  			destToCntr = make(destToCounter)
   434  		}
   435  
   436  		counter := destToCntr[dst]
   437  		if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 {
   438  			// Tuple already exists.
   439  			undo = true
   440  		}
   441  		counter.AddRef(flagBits)
   442  		destToCntr[dst] = counter
   443  		devToDest[res.BindToDevice] = destToCntr
   444  	}
   445  
   446  	if undo {
   447  		// releasePortLocked decrements the counts (rather than setting
   448  		// them to zero), so it will undo the incorrect incrementing
   449  		// above.
   450  		pm.releasePortLocked(res)
   451  		return false
   452  	}
   453  
   454  	return true
   455  }
   456  
   457  // ReleasePort releases the reservation on a port/IP combination so that it can
   458  // be reserved by other endpoints.
   459  func (pm *PortManager) ReleasePort(res Reservation) {
   460  	pm.mu.Lock()
   461  	defer pm.mu.Unlock()
   462  
   463  	pm.releasePortLocked(res)
   464  }
   465  
   466  func (pm *PortManager) releasePortLocked(res Reservation) {
   467  	dst := res.dst()
   468  	for _, network := range res.Networks {
   469  		desc := portDescriptor{network, res.Transport, res.Port}
   470  		addrToDev, ok := pm.allocatedPorts[desc]
   471  		if !ok {
   472  			continue
   473  		}
   474  		devToDest, ok := addrToDev[res.Addr]
   475  		if !ok {
   476  			continue
   477  		}
   478  		destToCounter, ok := devToDest[res.BindToDevice]
   479  		if !ok {
   480  			continue
   481  		}
   482  		counter, ok := destToCounter[dst]
   483  		if !ok {
   484  			continue
   485  		}
   486  		counter.DropRef(res.Flags.Bits())
   487  		if counter.TotalRefs() > 0 {
   488  			destToCounter[dst] = counter
   489  			continue
   490  		}
   491  		delete(destToCounter, dst)
   492  		if len(destToCounter) > 0 {
   493  			continue
   494  		}
   495  		delete(devToDest, res.BindToDevice)
   496  		if len(devToDest) > 0 {
   497  			continue
   498  		}
   499  		delete(addrToDev, res.Addr)
   500  		if len(addrToDev) > 0 {
   501  			continue
   502  		}
   503  		delete(pm.allocatedPorts, desc)
   504  	}
   505  }
   506  
   507  // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
   508  // both IPv4 and IPv6.
   509  func (pm *PortManager) PortRange() (uint16, uint16) {
   510  	pm.ephemeralMu.RLock()
   511  	defer pm.ephemeralMu.RUnlock()
   512  	return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1
   513  }
   514  
   515  // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
   516  // (inclusive).
   517  func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
   518  	if start > end {
   519  		return &tcpip.ErrInvalidPortRange{}
   520  	}
   521  	pm.ephemeralMu.Lock()
   522  	defer pm.ephemeralMu.Unlock()
   523  	pm.firstEphemeral = start
   524  	pm.numEphemeral = end - start + 1
   525  	return nil
   526  }