github.com/slackhq/nebula@v1.9.0/interface.go (about)

     1  package nebula
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"runtime"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/rcrowley/go-metrics"
    15  	"github.com/sirupsen/logrus"
    16  	"github.com/slackhq/nebula/config"
    17  	"github.com/slackhq/nebula/firewall"
    18  	"github.com/slackhq/nebula/header"
    19  	"github.com/slackhq/nebula/iputil"
    20  	"github.com/slackhq/nebula/overlay"
    21  	"github.com/slackhq/nebula/udp"
    22  )
    23  
// mtu is the size, in bytes, of the packet and output buffers that listenIn
// allocates for each tun reader.
const mtu = 9001
    25  
// InterfaceConfig carries everything NewInterface needs to build an Interface.
//
// Outside, Inside, pki and Firewall are mandatory: NewInterface returns an
// error when any of them is nil. The remaining fields are copied into the
// resulting Interface as-is, except for checkInterval, pendingDeletionInterval
// and punchy, which NewInterface hands to newConnectionManager.
type InterfaceConfig struct {
	HostMap                 *HostMap
	Outside                 udp.Conn
	Inside                  overlay.Device
	pki                     *PKI
	Cipher                  string
	Firewall                *Firewall
	ServeDns                bool
	HandshakeManager        *HandshakeManager
	lightHouse              *LightHouse
	checkInterval           time.Duration
	pendingDeletionInterval time.Duration
	DropLocalBroadcast      bool
	DropMulticast           bool
	routines                int
	MessageMetrics          *MessageMetrics
	version                 string
	relayManager            *relayManager
	punchy                  *Punchy

	// Initial values for the Interface's atomically stored counters/timer of
	// the same names; reloadMisc may replace them later on config reload.
	tryPromoteEvery uint32
	reQueryEvery    uint32
	reQueryWait     time.Duration

	// ConntrackCacheTimeout is passed to firewall.NewConntrackCacheTicker by
	// each listenIn/listenOut routine. l is the logger used by the Interface.
	ConntrackCacheTimeout time.Duration
	l                     *logrus.Logger
}
    53  
// Interface ties the inside overlay device to the outside UDP connection(s).
// It is built by NewInterface, brought up by activate, starts its reader
// goroutines in run, and is torn down by Close.
//
// Notes on fields whose handling is visible in this file:
//   - createTime is set to time.Now() in NewInterface.
//   - myVpnIp is derived from the first IP in the certificate; localBroadcast
//     is myVpnIp OR-ed with the inverted mask of that same entry.
//   - routines is the number of listenOut and listenIn goroutines run starts.
//   - disconnectInvalid mirrors the pki.disconnect_invalid setting
//     (see reloadDisconnectInvalid).
//   - closed is set by Close and consulted by listenIn to recognise an
//     expected os.ErrClosed read error.
type Interface struct {
	hostMap            *HostMap
	outside            udp.Conn
	inside             overlay.Device
	pki                *PKI
	cipher             string
	firewall           *Firewall
	connectionManager  *connectionManager
	handshakeManager   *HandshakeManager
	serveDns           bool
	createTime         time.Time
	lightHouse         *LightHouse
	localBroadcast     iputil.VpnIp
	myVpnIp            iputil.VpnIp
	dropLocalBroadcast bool
	dropMulticast      bool
	routines           int
	disconnectInvalid  atomic.Bool
	closed             atomic.Bool
	relayManager       *relayManager

	// Seeded from InterfaceConfig in NewInterface and replaced by reloadMisc
	// on config reload. reQueryWait stores a time.Duration as an int64.
	tryPromoteEvery atomic.Uint32
	reQueryEvery    atomic.Uint32
	reQueryWait     atomic.Int64

	// Set by reloadSendRecvError from the listen.send_recv_error setting.
	sendRecvErrorConfig sendRecvErrorConfig

	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
	rebindCount int8
	version     string

	conntrackCacheTimeout time.Duration

	// Both slices are allocated with length routines in NewInterface.
	// listenOut(i) uses writers[i] for i > 0 (and outside for i == 0);
	// readers is filled in by activate.
	writers []udp.Conn
	readers []io.ReadWriteCloser

	metricHandshakes    metrics.Histogram
	messageMetrics      *MessageMetrics
	cachedPacketMetrics *cachedPacketMetrics

	l *logrus.Logger
}
    96  
// EncWriter is the set of send-side operations declared below. The
// implementations are not part of this section of the file; judging by the
// method names they cover sending through a relay, sending a typed message to
// a vpn IP or to an already known HostInfo, and starting a handshake
// (NOTE(review): confirm against the implementing type).
type EncWriter interface {
	SendVia(via *HostInfo,
		relay *Relay,
		ad,
		nb,
		out []byte,
		nocopy bool,
	)
	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
	Handshake(vpnIp iputil.VpnIp)
}
   109  
   110  type sendRecvErrorConfig uint8
   111  
   112  const (
   113  	sendRecvErrorAlways sendRecvErrorConfig = iota
   114  	sendRecvErrorNever
   115  	sendRecvErrorPrivate
   116  )
   117  
   118  func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
   119  	switch s {
   120  	case sendRecvErrorPrivate:
   121  		return ip.IsPrivate()
   122  	case sendRecvErrorAlways:
   123  		return true
   124  	case sendRecvErrorNever:
   125  		return false
   126  	default:
   127  		panic(fmt.Errorf("invalid sendRecvErrorConfig value: %d", s))
   128  	}
   129  }
   130  
   131  func (s sendRecvErrorConfig) String() string {
   132  	switch s {
   133  	case sendRecvErrorAlways:
   134  		return "always"
   135  	case sendRecvErrorNever:
   136  		return "never"
   137  	case sendRecvErrorPrivate:
   138  		return "private"
   139  	default:
   140  		return fmt.Sprintf("invalid(%d)", s)
   141  	}
   142  }
   143  
   144  func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
   145  	if c.Outside == nil {
   146  		return nil, errors.New("no outside connection")
   147  	}
   148  	if c.Inside == nil {
   149  		return nil, errors.New("no inside interface (tun)")
   150  	}
   151  	if c.pki == nil {
   152  		return nil, errors.New("no certificate state")
   153  	}
   154  	if c.Firewall == nil {
   155  		return nil, errors.New("no firewall rules")
   156  	}
   157  
   158  	certificate := c.pki.GetCertState().Certificate
   159  	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
   160  	ifce := &Interface{
   161  		pki:                c.pki,
   162  		hostMap:            c.HostMap,
   163  		outside:            c.Outside,
   164  		inside:             c.Inside,
   165  		cipher:             c.Cipher,
   166  		firewall:           c.Firewall,
   167  		serveDns:           c.ServeDns,
   168  		handshakeManager:   c.HandshakeManager,
   169  		createTime:         time.Now(),
   170  		lightHouse:         c.lightHouse,
   171  		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
   172  		dropLocalBroadcast: c.DropLocalBroadcast,
   173  		dropMulticast:      c.DropMulticast,
   174  		routines:           c.routines,
   175  		version:            c.version,
   176  		writers:            make([]udp.Conn, c.routines),
   177  		readers:            make([]io.ReadWriteCloser, c.routines),
   178  		myVpnIp:            myVpnIp,
   179  		relayManager:       c.relayManager,
   180  
   181  		conntrackCacheTimeout: c.ConntrackCacheTimeout,
   182  
   183  		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
   184  		messageMetrics:   c.MessageMetrics,
   185  		cachedPacketMetrics: &cachedPacketMetrics{
   186  			sent:    metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil),
   187  			dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil),
   188  		},
   189  
   190  		l: c.l,
   191  	}
   192  
   193  	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
   194  	ifce.reQueryEvery.Store(c.reQueryEvery)
   195  	ifce.reQueryWait.Store(int64(c.reQueryWait))
   196  
   197  	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
   198  
   199  	return ifce, nil
   200  }
   201  
   202  // activate creates the interface on the host. After the interface is created, any
   203  // other services that want to bind listeners to its IP may do so successfully. However,
   204  // the interface isn't going to process anything until run() is called.
   205  func (f *Interface) activate() {
   206  	// actually turn on tun dev
   207  
   208  	addr, err := f.outside.LocalAddr()
   209  	if err != nil {
   210  		f.l.WithError(err).Error("Failed to get udp listen address")
   211  	}
   212  
   213  	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
   214  		WithField("build", f.version).WithField("udpAddr", addr).
   215  		WithField("boringcrypto", boringEnabled()).
   216  		Info("Nebula interface is active")
   217  
   218  	metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
   219  
   220  	// Prepare n tun queues
   221  	var reader io.ReadWriteCloser = f.inside
   222  	for i := 0; i < f.routines; i++ {
   223  		if i > 0 {
   224  			reader, err = f.inside.NewMultiQueueReader()
   225  			if err != nil {
   226  				f.l.Fatal(err)
   227  			}
   228  		}
   229  		f.readers[i] = reader
   230  	}
   231  
   232  	if err := f.inside.Activate(); err != nil {
   233  		f.inside.Close()
   234  		f.l.Fatal(err)
   235  	}
   236  }
   237  
   238  func (f *Interface) run() {
   239  	// Launch n queues to read packets from udp
   240  	for i := 0; i < f.routines; i++ {
   241  		go f.listenOut(i)
   242  	}
   243  
   244  	// Launch n queues to read packets from tun dev
   245  	for i := 0; i < f.routines; i++ {
   246  		go f.listenIn(f.readers[i], i)
   247  	}
   248  }
   249  
   250  func (f *Interface) listenOut(i int) {
   251  	runtime.LockOSThread()
   252  
   253  	var li udp.Conn
   254  	// TODO clean this up with a coherent interface for each outside connection
   255  	if i > 0 {
   256  		li = f.writers[i]
   257  	} else {
   258  		li = f.outside
   259  	}
   260  
   261  	lhh := f.lightHouse.NewRequestHandler()
   262  	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
   263  	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
   264  }
   265  
   266  func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
   267  	runtime.LockOSThread()
   268  
   269  	packet := make([]byte, mtu)
   270  	out := make([]byte, mtu)
   271  	fwPacket := &firewall.Packet{}
   272  	nb := make([]byte, 12, 12)
   273  
   274  	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
   275  
   276  	for {
   277  		n, err := reader.Read(packet)
   278  		if err != nil {
   279  			if errors.Is(err, os.ErrClosed) && f.closed.Load() {
   280  				return
   281  			}
   282  
   283  			f.l.WithError(err).Error("Error while reading outbound packet")
   284  			// This only seems to happen when something fatal happens to the fd, so exit.
   285  			os.Exit(2)
   286  		}
   287  
   288  		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l))
   289  	}
   290  }
   291  
   292  func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
   293  	c.RegisterReloadCallback(f.reloadFirewall)
   294  	c.RegisterReloadCallback(f.reloadSendRecvError)
   295  	c.RegisterReloadCallback(f.reloadDisconnectInvalid)
   296  	c.RegisterReloadCallback(f.reloadMisc)
   297  
   298  	for _, udpConn := range f.writers {
   299  		c.RegisterReloadCallback(udpConn.ReloadConfig)
   300  	}
   301  }
   302  
   303  func (f *Interface) reloadDisconnectInvalid(c *config.C) {
   304  	initial := c.InitialLoad()
   305  	if initial || c.HasChanged("pki.disconnect_invalid") {
   306  		f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
   307  		if !initial {
   308  			f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
   309  		}
   310  	}
   311  }
   312  
   313  func (f *Interface) reloadFirewall(c *config.C) {
   314  	//TODO: need to trigger/detect if the certificate changed too
   315  	if c.HasChanged("firewall") == false {
   316  		f.l.Debug("No firewall config change detected")
   317  		return
   318  	}
   319  
   320  	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
   321  	if err != nil {
   322  		f.l.WithError(err).Error("Error while creating firewall during reload")
   323  		return
   324  	}
   325  
   326  	oldFw := f.firewall
   327  	conntrack := oldFw.Conntrack
   328  	conntrack.Lock()
   329  	defer conntrack.Unlock()
   330  
   331  	fw.rulesVersion = oldFw.rulesVersion + 1
   332  	// If rulesVersion is back to zero, we have wrapped all the way around. Be
   333  	// safe and just reset conntrack in this case.
   334  	if fw.rulesVersion == 0 {
   335  		f.l.WithField("firewallHashes", fw.GetRuleHashes()).
   336  			WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
   337  			WithField("rulesVersion", fw.rulesVersion).
   338  			Warn("firewall rulesVersion has overflowed, resetting conntrack")
   339  	} else {
   340  		fw.Conntrack = conntrack
   341  	}
   342  
   343  	f.firewall = fw
   344  
   345  	oldFw.Destroy()
   346  	f.l.WithField("firewallHashes", fw.GetRuleHashes()).
   347  		WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
   348  		WithField("rulesVersion", fw.rulesVersion).
   349  		Info("New firewall has been installed")
   350  }
   351  
   352  func (f *Interface) reloadSendRecvError(c *config.C) {
   353  	if c.InitialLoad() || c.HasChanged("listen.send_recv_error") {
   354  		stringValue := c.GetString("listen.send_recv_error", "always")
   355  
   356  		switch stringValue {
   357  		case "always":
   358  			f.sendRecvErrorConfig = sendRecvErrorAlways
   359  		case "never":
   360  			f.sendRecvErrorConfig = sendRecvErrorNever
   361  		case "private":
   362  			f.sendRecvErrorConfig = sendRecvErrorPrivate
   363  		default:
   364  			if c.GetBool("listen.send_recv_error", true) {
   365  				f.sendRecvErrorConfig = sendRecvErrorAlways
   366  			} else {
   367  				f.sendRecvErrorConfig = sendRecvErrorNever
   368  			}
   369  		}
   370  
   371  		f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()).
   372  			Info("Loaded send_recv_error config")
   373  	}
   374  }
   375  
   376  func (f *Interface) reloadMisc(c *config.C) {
   377  	if c.HasChanged("counters.try_promote") {
   378  		n := c.GetUint32("counters.try_promote", defaultPromoteEvery)
   379  		f.tryPromoteEvery.Store(n)
   380  		f.l.Info("counters.try_promote has changed")
   381  	}
   382  
   383  	if c.HasChanged("counters.requery_every_packets") {
   384  		n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery)
   385  		f.reQueryEvery.Store(n)
   386  		f.l.Info("counters.requery_every_packets has changed")
   387  	}
   388  
   389  	if c.HasChanged("timers.requery_wait_duration") {
   390  		n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait)
   391  		f.reQueryWait.Store(int64(n))
   392  		f.l.Info("timers.requery_wait_duration has changed")
   393  	}
   394  }
   395  
   396  func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
   397  	ticker := time.NewTicker(i)
   398  	defer ticker.Stop()
   399  
   400  	udpStats := udp.NewUDPStatsEmitter(f.writers)
   401  
   402  	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
   403  
   404  	for {
   405  		select {
   406  		case <-ctx.Done():
   407  			return
   408  		case <-ticker.C:
   409  			f.firewall.EmitStats()
   410  			f.handshakeManager.EmitStats()
   411  			udpStats()
   412  			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
   413  		}
   414  	}
   415  }
   416  
   417  func (f *Interface) Close() error {
   418  	f.closed.Store(true)
   419  
   420  	for _, u := range f.writers {
   421  		err := u.Close()
   422  		if err != nil {
   423  			f.l.WithError(err).Error("Error while closing udp socket")
   424  		}
   425  	}
   426  
   427  	// Release the tun device
   428  	return f.inside.Close()
   429  }