github.com/forest33/wtun@v0.3.1/tun/tun_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"github.com/forest33/wtun/wintun"
    20  )
    21  
// Tuning constants for the adaptive spin-wait used by ReadPackets.
const (
	// rateMeasurementGranularity is the length, in nanoseconds, of the window
	// over which rateJuggler accumulates bytes before publishing a new rate
	// (half a second).
	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
	// spinloopRateThreshold is the throughput, in bytes per second, at or above
	// which ReadPackets busy-polls the ring before blocking on the read event.
	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
	// spinloopDuration is the maximum time, in nanoseconds, that ReadPackets
	// busy-polls before falling back to a blocking wait (12.5 microseconds).
	// NOTE(review): the "~1gbit/s" remark presumably refers to the time one
	// full-size packet takes at that line rate — confirm.
	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)
    27  
// rateJuggler keeps a coarse, lock-free estimate of the byte rate flowing
// through the device. All fields are atomics and are updated by update.
type rateJuggler struct {
	// current is the most recently published rate, in bytes per second.
	current       atomic.Uint64
	// nextByteCount accumulates the bytes seen in the window being measured.
	nextByteCount atomic.Uint64
	// nextStartTime is the nanotime() value at which the current window began.
	nextStartTime atomic.Int64
	// changing is set while one caller publishes a new rate, so that
	// concurrent callers skip the publication instead of repeating it.
	changing      atomic.Bool
}
    34  
// NativeTun is the Wintun-backed TUN device returned by CreateTUN and
// CreateTUNWithRequestedGUID.
type NativeTun struct {
	// wt is the Wintun adapter backing this device.
	wt        *wintun.Adapter
	// name is the interface name passed at creation time.
	name      string
	// handle is initialised to windows.InvalidHandle at creation and is not
	// otherwise used in the code visible here.
	handle    windows.Handle
	// rate tracks recent throughput to decide whether ReadPackets should spin.
	rate      rateJuggler
	// session is the Wintun session used to receive and send packets.
	session   wintun.Session
	// readWait is the session's read-wait event; ReadPackets blocks on it and
	// Close signals it to wake a blocked reader.
	readWait  windows.Handle
	// events carries device events such as EventMTUUpdate; closed by Close.
	events    chan Event
	// running counts in-flight calls so that Close can wait for them to finish.
	running   sync.WaitGroup
	// closeOnce makes Close idempotent.
	closeOnce sync.Once
	// close is set once Close has started.
	close     atomic.Bool
	// forcedMTU is the MTU reported by MTU and changed by ForceMTU.
	forcedMTU int
	// outSizes is not referenced in the code visible here.
	outSizes  []int
}
    49  
// Package-level defaults consulted when creating adapters.
var (
	// WintunTunnelType is the tunnel type passed to wintun.CreateAdapter when
	// the caller supplies an empty tunnelName.
	WintunTunnelType          = "WireGuard"
	// WintunStaticRequestedGUID is the GUID that CreateTUN requests for new
	// adapters; it is nil unless set by the application.
	WintunStaticRequestedGUID *windows.GUID
)
    54  
// procyield is bound to the Go runtime's internal procyield and is called in
// the ReadPackets spin loop between polls of the receive ring.
// NOTE(review): presumably a short CPU-level pause repeated `cycles` times —
// confirm against the runtime sources for the targeted Go version.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is bound to the Go runtime's internal nanotime. Its result is
// treated throughout this file as a timestamp in nanoseconds.
// NOTE(review): presumably a monotonic clock — confirm.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
    60  
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
// It is a convenience wrapper around CreateTUNWithRequestedGUID that requests
// the package-level WintunStaticRequestedGUID. An empty tunnelName selects
// WintunTunnelType, and a non-positive mtu selects the default of 1420.
func CreateTUN(ifname, tunnelName string, mtu int) (Device, error) {
	return CreateTUNWithRequestedGUID(ifname, tunnelName, WintunStaticRequestedGUID, mtu)
}
    66  
    67  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
    68  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
    69  func CreateTUNWithRequestedGUID(ifname, tunnelName string, requestedGUID *windows.GUID, mtu int) (Device, error) {
    70  	if tunnelName == "" {
    71  		tunnelName = WintunTunnelType
    72  	}
    73  	wt, err := wintun.CreateAdapter(ifname, tunnelName, requestedGUID)
    74  	if err != nil {
    75  		return nil, fmt.Errorf("Error creating interface: %w", err)
    76  	}
    77  
    78  	forcedMTU := 1420
    79  	if mtu > 0 {
    80  		forcedMTU = mtu
    81  	}
    82  
    83  	tun := &NativeTun{
    84  		wt:        wt,
    85  		name:      ifname,
    86  		handle:    windows.InvalidHandle,
    87  		events:    make(chan Event, 10),
    88  		forcedMTU: forcedMTU,
    89  	}
    90  
    91  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
    92  	if err != nil {
    93  		tun.wt.Close()
    94  		close(tun.events)
    95  		return nil, fmt.Errorf("Error starting session: %w", err)
    96  	}
    97  	tun.readWait = tun.session.ReadWaitEvent()
    98  	return tun, nil
    99  }
   100  
// Name returns the interface name the device was created with. The error is
// always nil.
func (tun *NativeTun) Name() (string, error) {
	return tun.name, nil
}
   104  
// File always returns nil: this implementation does not expose an *os.File
// for the device.
func (tun *NativeTun) File() *os.File {
	return nil
}
   108  
// Events returns the receive-only channel on which device events (such as
// EventMTUUpdate) are delivered. The channel is closed by Close.
func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}
   112  
   113  func (tun *NativeTun) Close() error {
   114  	var err error
   115  	tun.closeOnce.Do(func() {
   116  		tun.close.Store(true)
   117  		windows.SetEvent(tun.readWait)
   118  		tun.running.Wait()
   119  		tun.session.End()
   120  		if tun.wt != nil {
   121  			tun.wt.Close()
   122  		}
   123  		close(tun.events)
   124  	})
   125  	return err
   126  }
   127  
// MTU returns the MTU value stored on the device (set at creation time or by
// ForceMTU); it does not query the operating system. The error is always nil.
func (tun *NativeTun) MTU() (int, error) {
	return tun.forcedMTU, nil
}
   131  
   132  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   133  func (tun *NativeTun) ForceMTU(mtu int) {
   134  	update := tun.forcedMTU != mtu
   135  	tun.forcedMTU = mtu
   136  	if update {
   137  		tun.events <- EventMTUUpdate
   138  	}
   139  }
   140  
// BatchSize returns the number of packets handled per read or write call.
// It is always 1: ReadPackets only ever fills bufs[0].
func (tun *NativeTun) BatchSize() int {
	// TODO: implement batching with wintun
	return 1
}
   145  
   146  // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
   147  
// ReadPackets receives a single packet from the Wintun session, copies it into
// bufs[0] starting at offset, stores its length in sizes[0] and returns 1.
//
// It blocks until a packet is available or the device is closed, in which
// case it returns os.ErrClosed. When the recently measured throughput is at
// or above spinloopRateThreshold it busy-polls the ring for up to
// spinloopDuration before blocking on the read-wait event.
//
// Only bufs[0] and sizes[0] are used; both slices must be non-empty. If
// bufs[0][offset:] is shorter than the packet, the copy is truncated but
// sizes[0] still reports the full packet length.
func (tun *NativeTun) ReadPackets(bufs [][]byte, sizes []int, offset int) (int, error) {
	// Register as an in-flight call so Close waits for us before ending the session.
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	// Spin only when the last published rate is high and was measured recently
	// (within two measurement windows), i.e. traffic is likely still flowing.
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			// Copy out before releasing: packet is handed back to the session below.
			copy(bufs[0][offset:], packet)
			sizes[0] = packetSize
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return 1, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Ring is empty. Either block on the read event and start over
			// (re-evaluating the spin decision), or yield briefly and poll again.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				// Close signals readWait, so a blocked reader wakes up and sees the closed flag.
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		// Any other error is unexpected and is returned wrapped.
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}
   185  
   186  func (tun *NativeTun) WritePackets(bufs [][]byte, offset int) (int, error) {
   187  	tun.running.Add(1)
   188  	defer tun.running.Done()
   189  	if tun.close.Load() {
   190  		return 0, os.ErrClosed
   191  	}
   192  
   193  	var total int
   194  
   195  	for i, buf := range bufs {
   196  		packetSize := len(buf) - offset
   197  		tun.rate.update(uint64(packetSize))
   198  
   199  		packet, err := tun.session.AllocateSendPacket(packetSize)
   200  		switch err {
   201  		case nil:
   202  			// TODO: Explore options to eliminate this copy.
   203  			copy(packet, buf[offset:])
   204  			tun.session.SendPacket(packet)
   205  			total += len(packet)
   206  			continue
   207  		case windows.ERROR_HANDLE_EOF:
   208  			return i, os.ErrClosed
   209  		case windows.ERROR_BUFFER_OVERFLOW:
   210  			continue // Dropping when ring is full.
   211  		default:
   212  			return i, fmt.Errorf("Write failed: %w", err)
   213  		}
   214  	}
   215  	return total, nil
   216  }
   217  
   218  // LUID returns Windows interface instance ID.
   219  func (tun *NativeTun) LUID() uint64 {
   220  	tun.running.Add(1)
   221  	defer tun.running.Done()
   222  	if tun.close.Load() {
   223  		return 0
   224  	}
   225  	return tun.wt.LUID()
   226  }
   227  
// RunningVersion returns the running version of the Wintun driver.
// It delegates to wintun.RunningVersion and does not use any state of tun.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
	return wintun.RunningVersion()
}
   232  
   233  func (rate *rateJuggler) update(packetLen uint64) {
   234  	now := nanotime()
   235  	total := rate.nextByteCount.Add(packetLen)
   236  	period := uint64(now - rate.nextStartTime.Load())
   237  	if period >= rateMeasurementGranularity {
   238  		if !rate.changing.CompareAndSwap(false, true) {
   239  			return
   240  		}
   241  		rate.nextStartTime.Store(now)
   242  		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
   243  		rate.nextByteCount.Store(0)
   244  		rate.changing.Store(false)
   245  	}
   246  }
   247  
   248  func (tun *NativeTun) Read(p []byte) (n int, err error) {
   249  	var (
   250  		bufs  = make([][]byte, 1)
   251  		sizes = make([]int, 1)
   252  	)
   253  
   254  	bufs[0] = make([]byte, len(p))
   255  	n, err = tun.ReadPackets(bufs, sizes, 0)
   256  	if err != nil {
   257  		return 0, err
   258  	}
   259  	if sizes[0] < 1 {
   260  		return 0, nil
   261  	}
   262  
   263  	copy(p, bufs[0][:sizes[0]])
   264  
   265  	return sizes[0], nil
   266  }
   267  
// Write submits p as a single packet by calling WritePackets with offset 0
// and returns its result unchanged.
func (tun *NativeTun) Write(p []byte) (n int, err error) {
	return tun.WritePackets([][]byte{p}, 0)
}