github.com/slackhq/nebula@v1.9.0/main.go (about)

     1  package nebula
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/sirupsen/logrus"
    11  	"github.com/slackhq/nebula/config"
    12  	"github.com/slackhq/nebula/overlay"
    13  	"github.com/slackhq/nebula/sshd"
    14  	"github.com/slackhq/nebula/udp"
    15  	"github.com/slackhq/nebula/util"
    16  	"gopkg.in/yaml.v2"
    17  )
    18  
    19  type m map[string]interface{}
    20  
    21  func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
    22  	ctx, cancel := context.WithCancel(context.Background())
    23  	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
    24  	defer func() {
    25  		if reterr != nil {
    26  			cancel()
    27  		}
    28  	}()
    29  
    30  	l := logger
    31  	l.Formatter = &logrus.TextFormatter{
    32  		FullTimestamp: true,
    33  	}
    34  
    35  	// Print the config if in test, the exit comes later
    36  	if configTest {
    37  		b, err := yaml.Marshal(c.Settings)
    38  		if err != nil {
    39  			return nil, err
    40  		}
    41  
    42  		// Print the final config
    43  		l.Println(string(b))
    44  	}
    45  
    46  	err := configLogger(l, c)
    47  	if err != nil {
    48  		return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
    49  	}
    50  
    51  	c.RegisterReloadCallback(func(c *config.C) {
    52  		err := configLogger(l, c)
    53  		if err != nil {
    54  			l.WithError(err).Error("Failed to configure the logger")
    55  		}
    56  	})
    57  
    58  	pki, err := NewPKIFromConfig(l, c)
    59  	if err != nil {
    60  		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
    61  	}
    62  
    63  	certificate := pki.GetCertState().Certificate
    64  	fw, err := NewFirewallFromConfig(l, certificate, c)
    65  	if err != nil {
    66  		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
    67  	}
    68  	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
    69  
    70  	// TODO: make sure mask is 4 bytes
    71  	tunCidr := certificate.Details.Ips[0]
    72  
    73  	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
    74  	if err != nil {
    75  		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
    76  	}
    77  	wireSSHReload(l, ssh, c)
    78  	var sshStart func()
    79  	if c.GetBool("sshd.enabled", false) {
    80  		sshStart, err = configSSH(l, ssh, c)
    81  		if err != nil {
    82  			return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
    83  		}
    84  	}
    85  
    86  	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    87  	// All non system modifying configuration consumption should live above this line
    88  	// tun config, listeners, anything modifying the computer should be below
    89  	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    90  
    91  	var routines int
    92  
    93  	// If `routines` is set, use that and ignore the specific values
    94  	if routines = c.GetInt("routines", 0); routines != 0 {
    95  		if routines < 1 {
    96  			routines = 1
    97  		}
    98  		if routines > 1 {
    99  			l.WithField("routines", routines).Info("Using multiple routines")
   100  		}
   101  	} else {
   102  		// deprecated and undocumented
   103  		tunQueues := c.GetInt("tun.routines", 1)
   104  		udpQueues := c.GetInt("listen.routines", 1)
   105  		if tunQueues > udpQueues {
   106  			routines = tunQueues
   107  		} else {
   108  			routines = udpQueues
   109  		}
   110  		if routines != 1 {
   111  			l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead")
   112  		}
   113  	}
   114  
   115  	// EXPERIMENTAL
   116  	// Intentionally not documented yet while we do more testing and determine
   117  	// a good default value.
   118  	conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
   119  	if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") {
   120  		// Use a different default if we are running with multiple routines
   121  		conntrackCacheTimeout = 1 * time.Second
   122  	}
   123  	if conntrackCacheTimeout > 0 {
   124  		l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
   125  	}
   126  
   127  	var tun overlay.Device
   128  	if !configTest {
   129  		c.CatchHUP(ctx)
   130  
   131  		if deviceFactory == nil {
   132  			deviceFactory = overlay.NewDeviceFromConfig
   133  		}
   134  
   135  		tun, err = deviceFactory(c, l, tunCidr, routines)
   136  		if err != nil {
   137  			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
   138  		}
   139  
   140  		defer func() {
   141  			if reterr != nil {
   142  				tun.Close()
   143  			}
   144  		}()
   145  	}
   146  
   147  	// set up our UDP listener
   148  	udpConns := make([]udp.Conn, routines)
   149  	port := c.GetInt("listen.port", 0)
   150  
   151  	if !configTest {
   152  		rawListenHost := c.GetString("listen.host", "0.0.0.0")
   153  		var listenHost *net.IPAddr
   154  		if rawListenHost == "[::]" {
   155  			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
   156  			listenHost = &net.IPAddr{IP: net.IPv6zero}
   157  
   158  		} else {
   159  			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
   160  			if err != nil {
   161  				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
   162  			}
   163  		}
   164  
   165  		for i := 0; i < routines; i++ {
   166  			l.Infof("listening %q %d", listenHost.IP, port)
   167  			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
   168  			if err != nil {
   169  				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
   170  			}
   171  			udpServer.ReloadConfig(c)
   172  			udpConns[i] = udpServer
   173  
   174  			// If port is dynamic, discover it before the next pass through the for loop
   175  			// This way all routines will use the same port correctly
   176  			if port == 0 {
   177  				uPort, err := udpServer.LocalAddr()
   178  				if err != nil {
   179  					return nil, util.NewContextualError("Failed to get listening port", nil, err)
   180  				}
   181  				port = int(uPort.Port)
   182  			}
   183  		}
   184  	}
   185  
   186  	hostMap := NewHostMapFromConfig(l, tunCidr, c)
   187  	punchy := NewPunchyFromConfig(l, c)
   188  	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
   189  	if err != nil {
   190  		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
   191  	}
   192  
   193  	var messageMetrics *MessageMetrics
   194  	if c.GetBool("stats.message_metrics", false) {
   195  		messageMetrics = newMessageMetrics()
   196  	} else {
   197  		messageMetrics = newMessageMetricsOnlyRecvError()
   198  	}
   199  
   200  	useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
   201  
   202  	handshakeConfig := HandshakeConfig{
   203  		tryInterval:   c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
   204  		retries:       c.GetInt("handshakes.retries", DefaultHandshakeRetries),
   205  		triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
   206  		useRelays:     useRelays,
   207  
   208  		messageMetrics: messageMetrics,
   209  	}
   210  
   211  	handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
   212  	lightHouse.handshakeTrigger = handshakeManager.trigger
   213  
   214  	serveDns := false
   215  	if c.GetBool("lighthouse.serve_dns", false) {
   216  		if c.GetBool("lighthouse.am_lighthouse", false) {
   217  			serveDns = true
   218  		} else {
   219  			l.Warn("DNS server refusing to run because this host is not a lighthouse.")
   220  		}
   221  	}
   222  
   223  	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
   224  	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
   225  
   226  	ifConfig := &InterfaceConfig{
   227  		HostMap:                 hostMap,
   228  		Inside:                  tun,
   229  		Outside:                 udpConns[0],
   230  		pki:                     pki,
   231  		Cipher:                  c.GetString("cipher", "aes"),
   232  		Firewall:                fw,
   233  		ServeDns:                serveDns,
   234  		HandshakeManager:        handshakeManager,
   235  		lightHouse:              lightHouse,
   236  		checkInterval:           time.Second * time.Duration(checkInterval),
   237  		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
   238  		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
   239  		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
   240  		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
   241  		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
   242  		DropMulticast:           c.GetBool("tun.drop_multicast", false),
   243  		routines:                routines,
   244  		MessageMetrics:          messageMetrics,
   245  		version:                 buildVersion,
   246  		relayManager:            NewRelayManager(ctx, l, hostMap, c),
   247  		punchy:                  punchy,
   248  
   249  		ConntrackCacheTimeout: conntrackCacheTimeout,
   250  		l:                     l,
   251  	}
   252  
   253  	switch ifConfig.Cipher {
   254  	case "aes":
   255  		noiseEndianness = binary.BigEndian
   256  	case "chachapoly":
   257  		noiseEndianness = binary.LittleEndian
   258  	default:
   259  		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
   260  	}
   261  
   262  	var ifce *Interface
   263  	if !configTest {
   264  		ifce, err = NewInterface(ctx, ifConfig)
   265  		if err != nil {
   266  			return nil, fmt.Errorf("failed to initialize interface: %s", err)
   267  		}
   268  
   269  		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
   270  		// I don't want to make this initial commit too far-reaching though
   271  		ifce.writers = udpConns
   272  		lightHouse.ifce = ifce
   273  
   274  		ifce.RegisterConfigChangeCallbacks(c)
   275  		ifce.reloadDisconnectInvalid(c)
   276  		ifce.reloadSendRecvError(c)
   277  
   278  		handshakeManager.f = ifce
   279  		go handshakeManager.Run(ctx)
   280  	}
   281  
   282  	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
   283  	// a context so that they can exit when the context is Done.
   284  	statsStart, err := startStats(l, c, buildVersion, configTest)
   285  	if err != nil {
   286  		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
   287  	}
   288  
   289  	if configTest {
   290  		return nil, nil
   291  	}
   292  
   293  	//TODO: check if we _should_ be emitting stats
   294  	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
   295  
   296  	attachCommands(l, c, ssh, ifce)
   297  
   298  	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
   299  	var dnsStart func()
   300  	if lightHouse.amLighthouse && serveDns {
   301  		l.Debugln("Starting dns server")
   302  		dnsStart = dnsMain(l, hostMap, c)
   303  	}
   304  
   305  	return &Control{
   306  		ifce,
   307  		l,
   308  		ctx,
   309  		cancel,
   310  		sshStart,
   311  		statsStart,
   312  		dnsStart,
   313  		lightHouse.StartUpdateWorker,
   314  	}, nil
   315  }