github.com/slackhq/nebula@v1.9.0/overlay/tun_linux.go (about)

     1  //go:build !android && !e2e_testing
     2  // +build !android,!e2e_testing
     3  
     4  package overlay
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"os"
    12  	"strings"
    13  	"sync/atomic"
    14  	"unsafe"
    15  
    16  	"github.com/sirupsen/logrus"
    17  	"github.com/slackhq/nebula/cidr"
    18  	"github.com/slackhq/nebula/config"
    19  	"github.com/slackhq/nebula/iputil"
    20  	"github.com/slackhq/nebula/util"
    21  	"github.com/vishvananda/netlink"
    22  	"golang.org/x/sys/unix"
    23  )
    24  
// tun is the Linux tun device handle used throughout this file. It embeds the
// io.ReadWriteCloser backing the device and carries the addressing, MTU and
// routing state that the methods below read and update.
type tun struct {
	io.ReadWriteCloser
	// fd is the raw descriptor of the tun file; Write uses it directly.
	fd int
	// Device is the interface name (set by newTun / newTunFromFd).
	Device string
	// cidr is the overlay address and mask assigned to the device in Activate.
	cidr *net.IPNet
	// MaxMTU is the MTU applied to the device itself (see setMTU).
	MaxMTU int
	// DefaultMTU is the MTU used for routes that do not specify their own.
	DefaultMTU int
	// TXQueueLen is the transmit queue length applied in Activate.
	TXQueueLen int
	// deviceIndex is the netlink link index, resolved in Activate.
	deviceIndex int
	// ioctlFd is the socket used for interface ioctls, opened in Activate.
	ioctlFd uintptr

	// Routes holds the configured routes; swapped atomically on reload.
	Routes atomic.Pointer[[]Route]
	// routeTree is the lookup structure consulted by RouteFor.
	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
	// routeChan is the done channel of the system route watcher; nil when
	// no watcher was started.
	routeChan chan struct{}
	// useSystemRoutes enables watchRoutes during Activate.
	useSystemRoutes bool

	l *logrus.Logger
}
    43  
// ifReq is the request buffer passed by pointer to the TUNSETIFF,
// SIOCGIFFLAGS and SIOCSIFFLAGS ioctls in this file: an interface name
// followed by a 16 bit flags word.
// NOTE(review): the layout presumably mirrors the kernel's struct ifreq; do
// not reorder or resize fields without confirming against it.
type ifReq struct {
	Name  [16]byte
	Flags uint16
	pad   [8]byte
}
    49  
// ifreqAddr is the request buffer passed by pointer to the SIOCSIFADDR and
// SIOCSIFNETMASK ioctls in Activate: an interface name followed by an IPv4
// socket address.
// NOTE(review): the layout presumably mirrors the kernel's struct ifreq; do
// not reorder or resize fields without confirming against it.
type ifreqAddr struct {
	Name [16]byte
	Addr unix.RawSockaddrInet4
	pad  [8]byte
}
    55  
// ifreqMTU is the request buffer passed by pointer to the SIOCSIFMTU ioctl in
// setMTU: an interface name followed by the MTU value.
// NOTE(review): the layout presumably mirrors the kernel's struct ifreq; do
// not reorder or resize fields without confirming against it.
type ifreqMTU struct {
	Name [16]byte
	MTU  int32
	pad  [8]byte
}
    61  
// ifreqQLEN is the request buffer passed by pointer to the SIOCSIFTXQLEN
// ioctl in Activate: an interface name followed by the queue length.
// NOTE(review): the layout presumably mirrors the kernel's struct ifreq; do
// not reorder or resize fields without confirming against it.
type ifreqQLEN struct {
	Name  [16]byte
	Value int32
	pad   [8]byte
}
    67  
    68  func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
    69  	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
    70  
    71  	t, err := newTunGeneric(c, l, file, cidr)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	t.Device = "tun0"
    77  
    78  	return t, nil
    79  }
    80  
    81  func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
    82  	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
    83  	if err != nil {
    84  		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
    85  		if os.IsNotExist(err) {
    86  			err = os.MkdirAll("/dev/net", 0755)
    87  			if err != nil {
    88  				return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err)
    89  			}
    90  			err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200)))
    91  			if err != nil {
    92  				return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err)
    93  			}
    94  
    95  			fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0)
    96  			if err != nil {
    97  				return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err)
    98  			}
    99  		} else {
   100  			return nil, err
   101  		}
   102  	}
   103  
   104  	var req ifReq
   105  	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI)
   106  	if multiqueue {
   107  		req.Flags |= unix.IFF_MULTI_QUEUE
   108  	}
   109  	copy(req.Name[:], c.GetString("tun.dev", ""))
   110  	if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
   111  		return nil, err
   112  	}
   113  	name := strings.Trim(string(req.Name[:]), "\x00")
   114  
   115  	file := os.NewFile(uintptr(fd), "/dev/net/tun")
   116  	t, err := newTunGeneric(c, l, file, cidr)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	t.Device = name
   122  
   123  	return t, nil
   124  }
   125  
   126  func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
   127  	t := &tun{
   128  		ReadWriteCloser: file,
   129  		fd:              int(file.Fd()),
   130  		cidr:            cidr,
   131  		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
   132  		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
   133  		l:               l,
   134  	}
   135  
   136  	err := t.reload(c, true)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	c.RegisterReloadCallback(func(c *config.C) {
   142  		err := t.reload(c, false)
   143  		if err != nil {
   144  			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
   145  		}
   146  	})
   147  
   148  	return t, nil
   149  }
   150  
   151  func (t *tun) reload(c *config.C, initial bool) error {
   152  	routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	if !initial && !routeChange && !c.HasChanged("tun.mtu") {
   158  		return nil
   159  	}
   160  
   161  	routeTree, err := makeRouteTree(t.l, routes, true)
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	oldDefaultMTU := t.DefaultMTU
   167  	oldMaxMTU := t.MaxMTU
   168  	newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU)
   169  	newMaxMTU := newDefaultMTU
   170  	for i, r := range routes {
   171  		if r.MTU == 0 {
   172  			routes[i].MTU = newDefaultMTU
   173  		}
   174  
   175  		if r.MTU > t.MaxMTU {
   176  			newMaxMTU = r.MTU
   177  		}
   178  	}
   179  
   180  	t.MaxMTU = newMaxMTU
   181  	t.DefaultMTU = newDefaultMTU
   182  
   183  	// Teach nebula how to handle the routes before establishing them in the system table
   184  	oldRoutes := t.Routes.Swap(&routes)
   185  	t.routeTree.Store(routeTree)
   186  
   187  	if !initial {
   188  		if oldMaxMTU != newMaxMTU {
   189  			t.setMTU()
   190  			t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU)
   191  		}
   192  
   193  		if oldDefaultMTU != newDefaultMTU {
   194  			err := t.setDefaultRoute()
   195  			if err != nil {
   196  				t.l.Warn(err)
   197  			} else {
   198  				t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
   199  			}
   200  		}
   201  
   202  		// Remove first, if the system removes a wanted route hopefully it will be re-added next
   203  		t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
   204  
   205  		// Ensure any routes we actually want are installed
   206  		err = t.addRoutes(true)
   207  		if err != nil {
   208  			// This should never be called since addRoutes should log its own errors in a reload condition
   209  			util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l)
   210  		}
   211  	}
   212  
   213  	return nil
   214  }
   215  
   216  func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
   217  	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  
   222  	var req ifReq
   223  	req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
   224  	copy(req.Name[:], t.Device)
   225  	if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	file := os.NewFile(uintptr(fd), "/dev/net/tun")
   230  
   231  	return file, nil
   232  }
   233  
   234  func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
   235  	_, r := t.routeTree.Load().MostSpecificContains(ip)
   236  	return r
   237  }
   238  
   239  func (t *tun) Write(b []byte) (int, error) {
   240  	var nn int
   241  	max := len(b)
   242  
   243  	for {
   244  		n, err := unix.Write(t.fd, b[nn:max])
   245  		if n > 0 {
   246  			nn += n
   247  		}
   248  		if nn == len(b) {
   249  			return nn, err
   250  		}
   251  
   252  		if err != nil {
   253  			return nn, err
   254  		}
   255  
   256  		if n == 0 {
   257  			return nn, io.ErrUnexpectedEOF
   258  		}
   259  	}
   260  }
   261  
   262  func (t *tun) deviceBytes() (o [16]byte) {
   263  	for i, c := range t.Device {
   264  		o[i] = byte(c)
   265  	}
   266  	return
   267  }
   268  
   269  func (t *tun) Activate() error {
   270  	devName := t.deviceBytes()
   271  
   272  	if t.useSystemRoutes {
   273  		t.watchRoutes()
   274  	}
   275  
   276  	var addr, mask [4]byte
   277  
   278  	copy(addr[:], t.cidr.IP.To4())
   279  	copy(mask[:], t.cidr.Mask)
   280  
   281  	s, err := unix.Socket(
   282  		unix.AF_INET,
   283  		unix.SOCK_DGRAM,
   284  		unix.IPPROTO_IP,
   285  	)
   286  	if err != nil {
   287  		return err
   288  	}
   289  	t.ioctlFd = uintptr(s)
   290  
   291  	ifra := ifreqAddr{
   292  		Name: devName,
   293  		Addr: unix.RawSockaddrInet4{
   294  			Family: unix.AF_INET,
   295  			Addr:   addr,
   296  		},
   297  	}
   298  
   299  	// Set the device ip address
   300  	if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
   301  		return fmt.Errorf("failed to set tun address: %s", err)
   302  	}
   303  
   304  	// Set the device network
   305  	ifra.Addr.Addr = mask
   306  	if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
   307  		return fmt.Errorf("failed to set tun netmask: %s", err)
   308  	}
   309  
   310  	// Set the device name
   311  	ifrf := ifReq{Name: devName}
   312  	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
   313  		return fmt.Errorf("failed to set tun device name: %s", err)
   314  	}
   315  
   316  	// Setup our default MTU
   317  	t.setMTU()
   318  
   319  	// Set the transmit queue length
   320  	ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
   321  	if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
   322  		// If we can't set the queue length nebula will still work but it may lead to packet loss
   323  		t.l.WithError(err).Error("Failed to set tun tx queue length")
   324  	}
   325  
   326  	// Bring up the interface
   327  	ifrf.Flags = ifrf.Flags | unix.IFF_UP
   328  	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
   329  		return fmt.Errorf("failed to bring the tun device up: %s", err)
   330  	}
   331  
   332  	link, err := netlink.LinkByName(t.Device)
   333  	if err != nil {
   334  		return fmt.Errorf("failed to get tun device link: %s", err)
   335  	}
   336  	t.deviceIndex = link.Attrs().Index
   337  
   338  	if err = t.setDefaultRoute(); err != nil {
   339  		return err
   340  	}
   341  
   342  	// Set the routes
   343  	if err = t.addRoutes(false); err != nil {
   344  		return err
   345  	}
   346  
   347  	// Run the interface
   348  	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
   349  	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
   350  		return fmt.Errorf("failed to run tun device: %s", err)
   351  	}
   352  
   353  	return nil
   354  }
   355  
   356  func (t *tun) setMTU() {
   357  	// Set the MTU on the device
   358  	ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
   359  	if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
   360  		// This is currently a non fatal condition because the route table must have the MTU set appropriately as well
   361  		t.l.WithError(err).Error("Failed to set tun mtu")
   362  	}
   363  }
   364  
   365  func (t *tun) setDefaultRoute() error {
   366  	// Default route
   367  	dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
   368  	nr := netlink.Route{
   369  		LinkIndex: t.deviceIndex,
   370  		Dst:       dr,
   371  		MTU:       t.DefaultMTU,
   372  		AdvMSS:    t.advMSS(Route{}),
   373  		Scope:     unix.RT_SCOPE_LINK,
   374  		Src:       t.cidr.IP,
   375  		Protocol:  unix.RTPROT_KERNEL,
   376  		Table:     unix.RT_TABLE_MAIN,
   377  		Type:      unix.RTN_UNICAST,
   378  	}
   379  	err := netlink.RouteReplace(&nr)
   380  	if err != nil {
   381  		return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
   382  	}
   383  
   384  	return nil
   385  }
   386  
   387  func (t *tun) addRoutes(logErrors bool) error {
   388  	// Path routes
   389  	routes := *t.Routes.Load()
   390  	for _, r := range routes {
   391  		if !r.Install {
   392  			continue
   393  		}
   394  
   395  		nr := netlink.Route{
   396  			LinkIndex: t.deviceIndex,
   397  			Dst:       r.Cidr,
   398  			MTU:       r.MTU,
   399  			AdvMSS:    t.advMSS(r),
   400  			Scope:     unix.RT_SCOPE_LINK,
   401  		}
   402  
   403  		if r.Metric > 0 {
   404  			nr.Priority = r.Metric
   405  		}
   406  
   407  		err := netlink.RouteReplace(&nr)
   408  		if err != nil {
   409  			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
   410  			if logErrors {
   411  				retErr.Log(t.l)
   412  			} else {
   413  				return retErr
   414  			}
   415  		} else {
   416  			t.l.WithField("route", r).Info("Added route")
   417  		}
   418  	}
   419  
   420  	return nil
   421  }
   422  
   423  func (t *tun) removeRoutes(routes []Route) {
   424  	for _, r := range routes {
   425  		if !r.Install {
   426  			continue
   427  		}
   428  
   429  		nr := netlink.Route{
   430  			LinkIndex: t.deviceIndex,
   431  			Dst:       r.Cidr,
   432  			MTU:       r.MTU,
   433  			AdvMSS:    t.advMSS(r),
   434  			Scope:     unix.RT_SCOPE_LINK,
   435  		}
   436  
   437  		if r.Metric > 0 {
   438  			nr.Priority = r.Metric
   439  		}
   440  
   441  		err := netlink.RouteDel(&nr)
   442  		if err != nil {
   443  			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
   444  		} else {
   445  			t.l.WithField("route", r).Info("Removed route")
   446  		}
   447  	}
   448  }
   449  
   450  func (t *tun) Cidr() *net.IPNet {
   451  	return t.cidr
   452  }
   453  
   454  func (t *tun) Name() string {
   455  	return t.Device
   456  }
   457  
   458  func (t *tun) advMSS(r Route) int {
   459  	mtu := r.MTU
   460  	if r.MTU == 0 {
   461  		mtu = t.DefaultMTU
   462  	}
   463  
   464  	// We only need to set advmss if the route MTU does not match the device MTU
   465  	if mtu != t.MaxMTU {
   466  		return mtu - 40
   467  	}
   468  	return 0
   469  }
   470  
   471  func (t *tun) watchRoutes() {
   472  	rch := make(chan netlink.RouteUpdate)
   473  	doneChan := make(chan struct{})
   474  
   475  	if err := netlink.RouteSubscribe(rch, doneChan); err != nil {
   476  		t.l.WithError(err).Errorf("failed to subscribe to system route changes")
   477  		return
   478  	}
   479  
   480  	t.routeChan = doneChan
   481  
   482  	go func() {
   483  		for {
   484  			select {
   485  			case r := <-rch:
   486  				t.updateRoutes(r)
   487  			case <-doneChan:
   488  				// netlink.RouteSubscriber will close the rch for us
   489  				return
   490  			}
   491  		}
   492  	}()
   493  }
   494  
   495  func (t *tun) updateRoutes(r netlink.RouteUpdate) {
   496  	if r.Gw == nil {
   497  		// Not a gateway route, ignore
   498  		t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
   499  		return
   500  	}
   501  
   502  	if !t.cidr.Contains(r.Gw) {
   503  		// Gateway isn't in our overlay network, ignore
   504  		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
   505  		return
   506  	}
   507  
   508  	if x := r.Dst.IP.To4(); x == nil {
   509  		// Nebula only handles ipv4 on the overlay currently
   510  		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
   511  		return
   512  	}
   513  
   514  	newTree := cidr.NewTree4[iputil.VpnIp]()
   515  	if r.Type == unix.RTM_NEWROUTE {
   516  		for _, oldR := range t.routeTree.Load().List() {
   517  			newTree.AddCIDR(oldR.CIDR, oldR.Value)
   518  		}
   519  
   520  		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
   521  		newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
   522  
   523  	} else {
   524  		gw := iputil.Ip2VpnIp(r.Gw)
   525  		for _, oldR := range t.routeTree.Load().List() {
   526  			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
   527  				// This is the record to delete
   528  				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
   529  				continue
   530  			}
   531  
   532  			newTree.AddCIDR(oldR.CIDR, oldR.Value)
   533  		}
   534  	}
   535  
   536  	t.routeTree.Store(newTree)
   537  }
   538  
   539  func (t *tun) Close() error {
   540  	if t.routeChan != nil {
   541  		close(t.routeChan)
   542  	}
   543  
   544  	if t.ReadWriteCloser != nil {
   545  		t.ReadWriteCloser.Close()
   546  	}
   547  
   548  	if t.ioctlFd > 0 {
   549  		os.NewFile(t.ioctlFd, "ioctlFd").Close()
   550  	}
   551  
   552  	return nil
   553  }