github.com/borderzero/water@v0.0.1/syscalls_tun_windows.go (about)

     1  package water
     2  
     3  import (
     4  	"crypto/rand"
     5  	_ "embed"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"os"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  	_ "unsafe"
    14  
    15  	"golang.org/x/sys/windows"
    16  	"golang.zx2c4.com/wintun"
    17  )
    18  
    19  const (
    20  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
    21  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
    22  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
    23  )
    24  
    25  type Event int
    26  
    27  const (
    28  	EventUp = 1 << iota
    29  	EventDown
    30  	EventMTUUpdate
    31  )
    32  
    33  type rateJuggler struct {
    34  	current       uint64
    35  	nextByteCount uint64
    36  	nextStartTime int64
    37  	changing      int32
    38  }
    39  
    40  type NativeTun struct {
    41  	wt        *wintun.Adapter
    42  	name      string
    43  	handle    windows.Handle
    44  	rate      rateJuggler
    45  	session   wintun.Session
    46  	readWait  windows.Handle
    47  	events    chan Event
    48  	running   sync.WaitGroup
    49  	closeOnce sync.Once
    50  	close     int32
    51  	forcedMTU int
    52  }
    53  
    54  type WTun struct {
    55  	dev *NativeTun
    56  }
    57  
    58  func (w *WTun) Close() error {
    59  	return w.dev.Close()
    60  }
    61  
    62  func (w *WTun) Write(b []byte) (int, error) {
    63  	return w.dev.Write(b, 0)
    64  }
    65  
    66  func (w *WTun) Read(b []byte) (int, error) {
    67  	return w.dev.Read(b, 0)
    68  }
    69  
    70  var (
    71  	WintunTunnelType          = "Wintun"
    72  	WintunStaticRequestedGUID *windows.GUID
    73  )
    74  
    75  //go:linkname procyield runtime.procyield
    76  func procyield(cycles uint32)
    77  
    78  //go:linkname nanotime runtime.nanotime
    79  func nanotime() int64
    80  
    81  func generateRandomGUID() (*windows.GUID, error) {
    82  	var guid windows.GUID
    83  
    84  	// Generate random values for the Data1, Data2, and Data3 fields
    85  	if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data1); err != nil {
    86  		return nil, err
    87  	}
    88  	if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data2); err != nil {
    89  		return nil, err
    90  	}
    91  	if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data3); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	// Generate random bytes for the Data4 field
    96  	if _, err := rand.Read(guid.Data4[:]); err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	return &guid, nil
   101  }
   102  
   103  func openTunDev(config Config) (ifce *Interface, err error) {
   104  	/*
   105  		gUID := &windows.GUID{
   106  			0x0000000,
   107  			0xFFFF,
   108  			0xFFFF,
   109  			[8]byte{0xFF, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
   110  		}
   111  	*/
   112  
   113  	// We'll geneerate a random GUID for the Wintun interface
   114  	// This is to work around some issue in the Wintun driver that causes
   115  	// it to fail to create an interface as claims it already exists
   116  	gUID, _ := generateRandomGUID()
   117  
   118  	if config.PlatformSpecificParams.Name == "" {
   119  		config.PlatformSpecificParams.Name = "WaterIface"
   120  	}
   121  	nativeTunDevice, err := CreateTUNWithRequestedGUID(config.PlatformSpecificParams.Name, gUID, 0)
   122  	if err != nil {
   123  		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
   124  		// Trying a second time resolves the issue.
   125  		time.Sleep(1 * time.Second)
   126  		nativeTunDevice, err = CreateTUNWithRequestedGUID(config.PlatformSpecificParams.Name, gUID, 0)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  	}
   131  	ifce = &Interface{
   132  		isTAP:           config.DeviceType == TAP,
   133  		ReadWriteCloser: &WTun{dev: nativeTunDevice},
   134  		name:            config.PlatformSpecificParams.Name,
   135  	}
   136  	return ifce, nil
   137  }
   138  
   139  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
   140  // interface with the same name exist, it is reused.
   141  func CreateTUN(ifname string, mtu int) (*NativeTun, error) {
   142  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
   143  }
   144  
   145  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
   146  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
   147  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (*NativeTun, error) {
   148  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
   149  	if err != nil {
   150  		return nil, fmt.Errorf("Error creating interface: %w", err)
   151  	}
   152  
   153  	forcedMTU := 1420
   154  	if mtu > 0 {
   155  		forcedMTU = mtu
   156  	}
   157  
   158  	tun := &NativeTun{
   159  		wt:        wt,
   160  		name:      ifname,
   161  		handle:    windows.InvalidHandle,
   162  		events:    make(chan Event, 10),
   163  		forcedMTU: forcedMTU,
   164  	}
   165  
   166  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
   167  	if err != nil {
   168  		tun.wt.Close()
   169  		close(tun.events)
   170  		return nil, fmt.Errorf("Error starting session: %w", err)
   171  	}
   172  	tun.readWait = tun.session.ReadWaitEvent()
   173  	return tun, nil
   174  }
   175  
   176  func (tun *NativeTun) Name() (string, error) {
   177  	return tun.name, nil
   178  }
   179  
   180  func (tun *NativeTun) File() *os.File {
   181  	return nil
   182  }
   183  
   184  func (tun *NativeTun) Events() chan Event {
   185  	return tun.events
   186  }
   187  
   188  func (tun *NativeTun) Close() error {
   189  	var err error
   190  	tun.closeOnce.Do(func() {
   191  		atomic.StoreInt32(&tun.close, 1)
   192  		windows.SetEvent(tun.readWait)
   193  		tun.running.Wait()
   194  		tun.session.End()
   195  		if tun.wt != nil {
   196  			tun.wt.Close()
   197  		}
   198  		close(tun.events)
   199  	})
   200  	return err
   201  }
   202  
   203  func (tun *NativeTun) MTU() (int, error) {
   204  	return tun.forcedMTU, nil
   205  }
   206  
   207  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   208  func (tun *NativeTun) ForceMTU(mtu int) {
   209  	update := tun.forcedMTU != mtu
   210  	tun.forcedMTU = mtu
   211  	if update {
   212  		tun.events <- EventMTUUpdate
   213  	}
   214  }
   215  
   216  // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
   217  
   218  func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
   219  	tun.running.Add(1)
   220  	defer tun.running.Done()
   221  retry:
   222  	if atomic.LoadInt32(&tun.close) == 1 {
   223  		return 0, os.ErrClosed
   224  	}
   225  	start := nanotime()
   226  	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
   227  	for {
   228  		if atomic.LoadInt32(&tun.close) == 1 {
   229  			return 0, os.ErrClosed
   230  		}
   231  		packet, err := tun.session.ReceivePacket()
   232  		switch err {
   233  		case nil:
   234  			packetSize := len(packet)
   235  			copy(buff[offset:], packet)
   236  			tun.session.ReleaseReceivePacket(packet)
   237  			tun.rate.update(uint64(packetSize))
   238  			return packetSize, nil
   239  		case windows.ERROR_NO_MORE_ITEMS:
   240  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   241  				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
   242  				goto retry
   243  			}
   244  			procyield(1)
   245  			continue
   246  		case windows.ERROR_HANDLE_EOF:
   247  			return 0, os.ErrClosed
   248  		case windows.ERROR_INVALID_DATA:
   249  			return 0, errors.New("Send ring corrupt")
   250  		}
   251  		return 0, fmt.Errorf("Read failed: %w", err)
   252  	}
   253  }
   254  
   255  func (tun *NativeTun) Flush() error {
   256  	return nil
   257  }
   258  
   259  func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
   260  	tun.running.Add(1)
   261  	defer tun.running.Done()
   262  	if atomic.LoadInt32(&tun.close) == 1 {
   263  		return 0, os.ErrClosed
   264  	}
   265  
   266  	packetSize := len(buff) - offset
   267  	tun.rate.update(uint64(packetSize))
   268  
   269  	packet, err := tun.session.AllocateSendPacket(packetSize)
   270  	if err == nil {
   271  		copy(packet, buff[offset:])
   272  		tun.session.SendPacket(packet)
   273  		return packetSize, nil
   274  	}
   275  	switch err {
   276  	case windows.ERROR_HANDLE_EOF:
   277  		return 0, os.ErrClosed
   278  	case windows.ERROR_BUFFER_OVERFLOW:
   279  		return 0, nil // Dropping when ring is full.
   280  	}
   281  	return 0, fmt.Errorf("Write failed: %w", err)
   282  }
   283  
   284  // LUID returns Windows interface instance ID.
   285  func (tun *NativeTun) LUID() uint64 {
   286  	tun.running.Add(1)
   287  	defer tun.running.Done()
   288  	if atomic.LoadInt32(&tun.close) == 1 {
   289  		return 0
   290  	}
   291  	return tun.wt.LUID()
   292  }
   293  
   294  // RunningVersion returns the running version of the Wintun driver.
   295  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
   296  	return wintun.RunningVersion()
   297  }
   298  
   299  func (rate *rateJuggler) update(packetLen uint64) {
   300  	now := nanotime()
   301  	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
   302  	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
   303  	if period >= rateMeasurementGranularity {
   304  		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
   305  			return
   306  		}
   307  		atomic.StoreInt64(&rate.nextStartTime, now)
   308  		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
   309  		atomic.StoreUint64(&rate.nextByteCount, 0)
   310  		atomic.StoreInt32(&rate.changing, 0)
   311  	}
   312  }