github.com/slackhq/nebula@v1.9.0/connection_manager.go

package nebula

import (
	"bytes"
	"context"
	"sync"
	"time"

	"github.com/rcrowley/go-metrics"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/header"
	"github.com/slackhq/nebula/iputil"
	"github.com/slackhq/nebula/udp"
)

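// A trafficDecision is the outcome of one traffic check for a single tunnel.
// makeTrafficDecision picks the decision while holding the hostmap read lock
// and doTrafficCheck carries it out afterwards.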
type trafficDecision int

const (
	doNothing      trafficDecision = 0
	deleteTunnel   trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
	closeTunnel    trafficDecision = 2 // delete the hostinfo and notify the remote
	swapPrimary    trafficDecision = 3
	migrateRelays  trafficDecision = 4
	tryRehandshake trafficDecision = 5
	sendTestPacket trafficDecision = 6
)

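// connectionManager watches established tunnels for traffic, probes quiet
// tunnels with authenticated test packets, keeps NAT mappings warm with
// punch packets, and tears down tunnels that stop answering.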
type connectionManager struct {
	in     map[uint32]struct{}
	inLock *sync.RWMutex

	out     map[uint32]struct{}
	outLock *sync.RWMutex

	// relayUsed holds which relay localIndexes are in use
	relayUsed     map[uint32]struct{}
	relayUsedLock *sync.RWMutex

	hostMap                 *HostMap
	trafficTimer            *LockingTimerWheel[uint32]
	intf                    *Interface
	pendingDeletion         map[uint32]struct{}
	punchy                  *Punchy
	checkInterval           time.Duration
	pendingDeletionInterval time.Duration
	metricsTxPunchy         metrics.Counter

	l *logrus.Logger
}

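// newConnectionManager wires up a connection manager for intf and immediately
// starts its run loop on a background goroutine via Start. A hypothetical call
// site (the interval values are illustrative, not project defaults):
//
//	cm := newConnectionManager(ctx, l, ifce, 10*time.Second, 10*time.Second, punchy)
//	_ = cm // Start(ctx) was already called by the constructor.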
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
	var max time.Duration
	if checkInterval < pendingDeletionInterval {
		max = pendingDeletionInterval
	} else {
		max = checkInterval
	}

	nc := &connectionManager{
		hostMap:                 intf.hostMap,
		in:                      make(map[uint32]struct{}),
		inLock:                  &sync.RWMutex{},
		out:                     make(map[uint32]struct{}),
		outLock:                 &sync.RWMutex{},
		relayUsed:               make(map[uint32]struct{}),
		relayUsedLock:           &sync.RWMutex{},
		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
		intf:                    intf,
		pendingDeletion:         make(map[uint32]struct{}),
		checkInterval:           checkInterval,
		pendingDeletionInterval: pendingDeletionInterval,
		punchy:                  punchy,
		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
		l:                       l,
	}

	nc.Start(ctx)
	return nc
}

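// In marks the tunnel identified by localIndex as having received traffic
// since the last check. In, Out, and RelayUsed below share the same
// check-under-read-lock-then-write pattern: the read lock serves the common
// already-marked case cheaply, and the unsynchronized lock upgrade is safe
// because the write is an idempotent set insert.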
func (n *connectionManager) In(localIndex uint32) {
	n.inLock.RLock()
	// If this already exists, return
	if _, ok := n.in[localIndex]; ok {
		n.inLock.RUnlock()
		return
	}
	n.inLock.RUnlock()
	n.inLock.Lock()
	n.in[localIndex] = struct{}{}
	n.inLock.Unlock()
}

func (n *connectionManager) Out(localIndex uint32) {
	n.outLock.RLock()
	// If this already exists, return
	if _, ok := n.out[localIndex]; ok {
		n.outLock.RUnlock()
		return
	}
	n.outLock.RUnlock()
	n.outLock.Lock()
	n.out[localIndex] = struct{}{}
	n.outLock.Unlock()
}

func (n *connectionManager) RelayUsed(localIndex uint32) {
	n.relayUsedLock.RLock()
	// If this already exists, return
	if _, ok := n.relayUsed[localIndex]; ok {
		n.relayUsedLock.RUnlock()
		return
	}
	n.relayUsedLock.RUnlock()
	n.relayUsedLock.Lock()
	n.relayUsed[localIndex] = struct{}{}
	n.relayUsedLock.Unlock()
}

// getAndResetTrafficCheck reports whether there was any inbound or outbound
// traffic within the last tick and resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
	n.inLock.Lock()
	n.outLock.Lock()
	_, in := n.in[localIndex]
	_, out := n.out[localIndex]
	delete(n.in, localIndex)
	delete(n.out, localIndex)
	n.inLock.Unlock()
	n.outLock.Unlock()
	return in, out
}

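// AddTrafficWatch begins tracking localIndex: the index is marked as having
// sent traffic and its first check is scheduled on the traffic timer wheel.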
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
	n.outLock.Lock()
	if _, ok := n.out[localIndex]; ok {
		n.outLock.Unlock()
		return
	}
	n.out[localIndex] = struct{}{}
	n.trafficTimer.Add(localIndex, n.checkInterval)
	n.outLock.Unlock()
}

func (n *connectionManager) Start(ctx context.Context) {
	go n.Run(ctx)
}

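// Run drives the traffic timer wheel from a 500ms ticker, matching the
// wheel's tick duration, until ctx is cancelled. Each tick it advances the
// wheel and runs a traffic check for every index whose timer has expired.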
func (n *connectionManager) Run(ctx context.Context) {
	//TODO: this tick should be based on the min wheel tick? Check firewall
	clockSource := time.NewTicker(500 * time.Millisecond)
	defer clockSource.Stop()

	p := []byte("")
	nb := make([]byte, 12)
	out := make([]byte, mtu)

	for {
		select {
		case <-ctx.Done():
			return

		case now := <-clockSource.C:
			n.trafficTimer.Advance(now)
			for {
				localIndex, has := n.trafficTimer.Purge()
				if !has {
					break
				}

				n.doTrafficCheck(localIndex, p, nb, out, now)
			}
		}
	}
}

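// doTrafficCheck performs one check for localIndex: it asks
// makeTrafficDecision what to do with the tunnel and then carries out the
// resulting action. p, nb, and out are reusable scratch buffers for sending.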
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)

	switch decision {
	case deleteTunnel:
		if n.hostMap.DeleteHostInfo(hostinfo) {
			// Only clear the lighthouse cache if this was the last hostinfo for this vpn ip in the hostmap
			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
		}

	case closeTunnel:
		n.intf.sendCloseTunnel(hostinfo)
		n.intf.closeTunnel(hostinfo)

	case swapPrimary:
		n.swapPrimary(hostinfo, primary)

	case migrateRelays:
		n.migrateRelayUsed(hostinfo, primary)

	case tryRehandshake:
		n.tryRehandshake(hostinfo)

	case sendTestPacket:
		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
	}

	n.resetRelayTrafficCheck(hostinfo)
}

func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
	if hostinfo != nil {
		n.relayUsedLock.Lock()
		defer n.relayUsedLock.Unlock()
		// No need to migrate any relays, delete usage info now.
		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
			delete(n.relayUsed, idx)
		}
	}
}

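// migrateRelayUsed re-creates, on newhostinfo, any relays that oldhostinfo
// was using, sending a CreateRelayRequest to the peer for each relay that
// needs to be requested or re-requested. Relays that were never used are not
// migrated.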
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
	relayFor := oldhostinfo.relayState.CopyAllRelayFor()

	for _, r := range relayFor {
		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)

		var index uint32
		var relayFrom iputil.VpnIp
		var relayTo iputil.VpnIp
		switch {
		case ok && existing.State == Established:
			// This relay already exists in newhostinfo; nothing to do.
			continue
		case ok && existing.State == Requested:
			// The relay exists in a Requested state; re-send the request
			index = existing.LocalIndex
			switch r.Type {
			case TerminalType:
				relayFrom = n.intf.myVpnIp
				relayTo = existing.PeerIp
			case ForwardingType:
				relayFrom = existing.PeerIp
				relayTo = newhostinfo.vpnIp
			default:
				// should never happen
			}
		case !ok:
			n.relayUsedLock.RLock()
			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
				// The relay hasn't been used; don't migrate it.
				n.relayUsedLock.RUnlock()
				continue
			}
			n.relayUsedLock.RUnlock()
			// The relay doesn't exist at all; create some relay state and send the request.
			var err error
			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
			if err != nil {
				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
				continue
			}
			switch r.Type {
			case TerminalType:
				relayFrom = n.intf.myVpnIp
				relayTo = r.PeerIp
			case ForwardingType:
				relayFrom = r.PeerIp
				relayTo = newhostinfo.vpnIp
			default:
				// should never happen
			}
		}

		// Send a CreateRelayRequest to the peer.
		req := NebulaControl{
			Type:                NebulaControl_CreateRelayRequest,
			InitiatorRelayIndex: index,
			RelayFromIp:         uint32(relayFrom),
			RelayToIp:           uint32(relayTo),
		}
		msg, err := req.Marshal()
		if err != nil {
			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
		} else {
			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
			n.l.WithFields(logrus.Fields{
				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
				"relayTo":             iputil.VpnIp(req.RelayToIp),
				"initiatorRelayIndex": req.InitiatorRelayIndex,
				"responderRelayIndex": req.ResponderRelayIndex,
				"vpnIp":               newhostinfo.vpnIp}).
				Info("send CreateRelayRequest")
		}
	}
}

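// makeTrafficDecision examines one hostinfo under the hostmap read lock and
// decides what to do with it. A rough summary of the branches below:
//
//	certificate invalid                  -> closeTunnel
//	inbound traffic, primary tunnel      -> tryRehandshake
//	inbound traffic, extra tunnel        -> swapPrimary or migrateRelays
//	no traffic in either direction       -> doNothing (punch to keep NAT warm)
//	outbound traffic only, first miss    -> sendTestPacket, arm pendingDeletion
//	outbound traffic only, second miss   -> deleteTunnel
//
// It also returns the hostinfo in question and, when relevant, the current
// primary hostinfo for the same vpn ip.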
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
	n.hostMap.RLock()
	defer n.hostMap.RUnlock()

	hostinfo := n.hostMap.Indexes[localIndex]
	if hostinfo == nil {
		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
		delete(n.pendingDeletion, localIndex)
		return doNothing, nil, nil
	}

	if n.isInvalidCertificate(now, hostinfo) {
		delete(n.pendingDeletion, hostinfo.localIndexId)
		return closeTunnel, hostinfo, nil
	}

	primary := n.hostMap.Hosts[hostinfo.vpnIp]
	mainHostInfo := true
	if primary != nil && primary != hostinfo {
		mainHostInfo = false
	}

	// Check for traffic on this hostinfo
	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

	// A hostinfo is considered alive if there is incoming traffic
	if inTraffic {
		decision := doNothing
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
				Debug("Tunnel status")
		}
		delete(n.pendingDeletion, hostinfo.localIndexId)

		if mainHostInfo {
			decision = tryRehandshake

		} else {
			if n.shouldSwapPrimary(hostinfo, primary) {
				decision = swapPrimary
			} else {
				// migrate the relays to the primary, if in use.
				decision = migrateRelays
			}
		}

		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

		if !outTraffic {
			// Send a punch packet to keep the NAT state alive
			n.sendPunch(hostinfo)
		}

		return decision, hostinfo, primary
	}

	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
		// We have already sent a test packet and nothing was returned; this hostinfo is dead
		hostinfo.logger(n.l).
			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
			Info("Tunnel status")

		delete(n.pendingDeletion, hostinfo.localIndexId)
		return deleteTunnel, hostinfo, nil
	}

	decision := doNothing
	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
		if !outTraffic {
			// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test it.
			// Just maintain NAT state if configured to do so.
			n.sendPunch(hostinfo)
			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
			return doNothing, nil, nil

		}

		if n.punchy.GetTargetEverything() {
			// This is similar to the old punchy behavior with a slight optimization.
			// We aren't receiving traffic but we are sending it; punch on all known
			// ips in case we need to re-prime NAT state
			n.sendPunch(hostinfo)
		}

		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).
				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
				Debug("Tunnel status")
		}

		// Send a test packet to trigger an authenticated tunnel test; this should suss out any lingering tunnel issues
		decision = sendTestPacket

	} else {
		if n.l.Level >= logrus.DebugLevel {
			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
		}
	}

	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
	return decision, hostinfo, nil
}

func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
	// Let's sort this out.

	if current.vpnIp < n.intf.myVpnIp {
		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
		// vpn ip is static across all tunnels for this host pair so let's use that to determine who is flipping.
		// The remote's vpn ip is lower than mine. I will not flip.
		return false
	}

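	// Only flip if this tunnel was built with our current local certificate;
	// presumably a tunnel handshaked with a stale cert should not be promoted
	// to primary.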
	certState := n.intf.pki.GetCertState()
	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
}

func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
	n.hostMap.Lock()
	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
	if n.hostMap.Hosts[current.vpnIp] == primary {
		n.hostMap.unlockedMakePrimary(current)
	}
	n.hostMap.Unlock()
}

// isInvalidCertificate reports whether we should destroy the tunnel: true when
// pki.disconnect_invalid is set and the remote certificate is no longer valid.
// Block listed certificates skip the pki.disconnect_invalid check and are
// always torn down.
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
	remoteCert := hostinfo.GetCert()
	if remoteCert == nil {
		return false
	}

	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
	if valid {
		return false
	}

	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
		// Block listed certificates should always be disconnected
		return false
	}

	fingerprint, _ := remoteCert.Sha256Sum()
	hostinfo.logger(n.l).WithError(err).
		WithField("fingerprint", fingerprint).
		Info("Remote certificate is no longer valid, tearing down the tunnel")

	return true
}

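// sendPunch sends a single-byte UDP packet to refresh NAT mappings. When
// punchy is configured to target everything, every known remote address for
// the host is punched; otherwise only the current remote address is. No-op
// when punching is disabled.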
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
	if !n.punchy.GetPunch() {
		// Punching is disabled
		return
	}

	if n.punchy.GetTargetEverything() {
		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
			n.metricsTxPunchy.Inc(1)
			n.intf.outside.WriteTo([]byte{1}, addr)
		})

	} else if hostinfo.remote != nil {
		n.metricsTxPunchy.Inc(1)
		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
	}
}

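// tryRehandshake starts a new handshake with the remote if this tunnel was
// built with a local certificate that has since been replaced (for example
// after a certificate reload).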
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
	certState := n.intf.pki.GetCertState()
	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
		return
	}

	n.l.WithField("vpnIp", hostinfo.vpnIp).
		WithField("reason", "local certificate is not current").
		Info("Re-handshaking with remote")

	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
}