github.com/slackhq/nebula@v1.9.0/overlay/tun_wintun_windows.go

     1  package overlay
     2  
     3  import (
     4  	"crypto"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"sync/atomic"
    10  	"unsafe"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"github.com/slackhq/nebula/cidr"
    14  	"github.com/slackhq/nebula/config"
    15  	"github.com/slackhq/nebula/iputil"
    16  	"github.com/slackhq/nebula/util"
    17  	"github.com/slackhq/nebula/wintun"
    18  	"golang.org/x/sys/windows"
    19  	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
    20  )
    21  
    22  const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
    23  
    24  type winTun struct {
    25  	Device    string
    26  	cidr      *net.IPNet
    27  	prefix    netip.Prefix
    28  	MTU       int
    29  	Routes    atomic.Pointer[[]Route]
    30  	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
    31  	l         *logrus.Logger
    32  
    33  	tun *wintun.NativeTun
    34  }
    35  
    36  func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
    37  	// GUID is 128 bit
    38  	hash := crypto.MD5.New()
    39  
    40  	_, err := hash.Write([]byte(tunGUIDLabel))
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	_, err = hash.Write([]byte(name))
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	sum := hash.Sum(nil)
    51  
    52  	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
    53  }
    54  
    55  func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
    56  	deviceName := c.GetString("tun.dev", "")
    57  	guid, err := generateGUIDByDeviceName(deviceName)
    58  	if err != nil {
    59  		return nil, fmt.Errorf("generate GUID failed: %w", err)
    60  	}
    61  
    62  	prefix, err := iputil.ToNetIpPrefix(*cidr)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	t := &winTun{
    68  		Device: deviceName,
    69  		cidr:   cidr,
    70  		prefix: prefix,
    71  		MTU:    c.GetInt("tun.mtu", DefaultMTU),
    72  		l:      l,
    73  	}
    74  
    75  	err = t.reload(c, true)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	var tunDevice wintun.Device
    81  	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
    82  	if err != nil {
    83  		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
    84  		// Trying a second time resolves the issue.
    85  		l.WithError(err).Debug("Failed to create wintun device, retrying")
    86  		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
    87  		if err != nil {
    88  			return nil, fmt.Errorf("create TUN device failed: %w", err)
    89  		}
    90  	}
    91  	t.tun = tunDevice.(*wintun.NativeTun)
    92  
    93  	c.RegisterReloadCallback(func(c *config.C) {
    94  		err := t.reload(c, false)
    95  		if err != nil {
    96  			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
    97  		}
    98  	})
    99  
   100  	return t, nil
   101  }
   102  
   103  func (t *winTun) reload(c *config.C, initial bool) error {
   104  	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
   105  	if err != nil {
   106  		return err
   107  	}
   108  
   109  	if !initial && !change {
   110  		return nil
   111  	}
   112  
   113  	routeTree, err := makeRouteTree(t.l, routes, false)
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	// Teach nebula how to handle the routes before establishing them in the system table
   119  	oldRoutes := t.Routes.Swap(&routes)
   120  	t.routeTree.Store(routeTree)
   121  
   122  	if !initial {
   123  		// Remove first, if the system removes a wanted route hopefully it will be re-added next
   124  		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
   125  		if err != nil {
   126  			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
   127  		}
   128  
   129  		// Ensure any routes we actually want are installed
   130  		err = t.addRoutes(true)
   131  		if err != nil {
   132  			// Catch any stray logs
   133  			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
   134  		}
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  func (t *winTun) Activate() error {
   141  	luid := winipcfg.LUID(t.tun.LUID())
   142  
   143  	err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
   144  	if err != nil {
   145  		return fmt.Errorf("failed to set address: %w", err)
   146  	}
   147  
   148  	err = t.addRoutes(false)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  func (t *winTun) addRoutes(logErrors bool) error {
   157  	luid := winipcfg.LUID(t.tun.LUID())
   158  	routes := *t.Routes.Load()
   159  	foundDefault4 := false
   160  
   161  	for _, r := range routes {
   162  		if r.Via == nil || !r.Install {
   163  			// We don't allow route MTUs so only install routes with a via
   164  			continue
   165  		}
   166  
   167  		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
   168  		if err != nil {
   169  			retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
   170  			if logErrors {
   171  				retErr.Log(t.l)
   172  				continue
   173  			} else {
   174  				return retErr
   175  			}
   176  		}
   177  
   178  		// Add our unsafe route
   179  		err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
   180  		if err != nil {
   181  			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
   182  			if logErrors {
   183  				retErr.Log(t.l)
   184  				continue
   185  			} else {
   186  				return retErr
   187  			}
   188  		} else {
   189  			t.l.WithField("route", r).Info("Added route")
   190  		}
   191  
   192  		if !foundDefault4 {
   193  			if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
   194  				foundDefault4 = true
   195  			}
   196  		}
   197  	}
   198  
   199  	ipif, err := luid.IPInterface(windows.AF_INET)
   200  	if err != nil {
   201  		return fmt.Errorf("failed to get ip interface: %w", err)
   202  	}
   203  
   204  	ipif.NLMTU = uint32(t.MTU)
   205  	if foundDefault4 {
   206  		ipif.UseAutomaticMetric = false
   207  		ipif.Metric = 0
   208  	}
   209  
   210  	if err := ipif.Set(); err != nil {
   211  		return fmt.Errorf("failed to set ip interface: %w", err)
   212  	}
   213  	return nil
   214  }
   215  
   216  func (t *winTun) removeRoutes(routes []Route) error {
   217  	luid := winipcfg.LUID(t.tun.LUID())
   218  
   219  	for _, r := range routes {
   220  		if !r.Install {
   221  			continue
   222  		}
   223  
   224  		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
   225  		if err != nil {
   226  			t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
   227  			continue
   228  		}
   229  
   230  		err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
   231  		if err != nil {
   232  			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
   233  		} else {
   234  			t.l.WithField("route", r).Info("Removed route")
   235  		}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
   241  	_, r := t.routeTree.Load().MostSpecificContains(ip)
   242  	return r
   243  }
   244  
   245  func (t *winTun) Cidr() *net.IPNet {
   246  	return t.cidr
   247  }
   248  
   249  func (t *winTun) Name() string {
   250  	return t.Device
   251  }
   252  
   253  func (t *winTun) Read(b []byte) (int, error) {
   254  	return t.tun.Read(b, 0)
   255  }
   256  
   257  func (t *winTun) Write(b []byte) (int, error) {
   258  	return t.tun.Write(b, 0)
   259  }
   260  
   261  func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
   262  	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
   263  }
   264  
   265  func (t *winTun) Close() error {
   266  	// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
   267  	// so to be certain, just remove everything before destroying.
   268  	luid := winipcfg.LUID(t.tun.LUID())
   269  	_ = luid.FlushRoutes(windows.AF_INET)
   270  	_ = luid.FlushIPAddresses(windows.AF_INET)
   271  	/* We don't support IPV6 yet
   272  	_ = luid.FlushRoutes(windows.AF_INET6)
   273  	_ = luid.FlushIPAddresses(windows.AF_INET6)
   274  	*/
   275  	_ = luid.FlushDNS(windows.AF_INET)
   276  
   277  	return t.tun.Close()
   278  }