github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tun_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  	wintun "github.com/koomox/wintun-go"
    19  )
    20  
    21  const (
    22  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
    23  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
    24  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
    25  )
    26  
    27  type rateJuggler struct {
    28  	current       atomic.Uint64
    29  	nextByteCount atomic.Uint64
    30  	nextStartTime atomic.Int64
    31  	changing      atomic.Bool
    32  }
    33  
    34  type NativeTun struct {
    35  	wt        *wintun.Adapter
    36  	name      string
    37  	handle    windows.Handle
    38  	rate      rateJuggler
    39  	session   wintun.Session
    40  	readWait  windows.Handle
    41  	events    chan Event
    42  	running   sync.WaitGroup
    43  	closeOnce sync.Once
    44  	close     atomic.Bool
    45  	forcedMTU int
    46  	outSizes  []int
    47  }
    48  
    49  var (
    50  	WintunTunnelType          = "WireGuard"
    51  	WintunStaticRequestedGUID *windows.GUID
    52  )
    53  
// procyield is bound to the Go runtime's internal runtime.procyield via
// go:linkname (which relies on the blank "unsafe" import in this file).
// Read calls procyield(1) between polls of the receive ring in its spin loop.
// NOTE(review): presumably executes `cycles` CPU spin-wait hints without
// descheduling the goroutine — confirm against the runtime source.
// NOTE(review): newer Go toolchains restrict go:linkname access to runtime
// internals; confirm this symbol stays reachable with the toolchain in use.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is bound to the Go runtime's internal runtime.nanotime via
// go:linkname and returns a clock reading in nanoseconds. Callers in this file
// (Read and rateJuggler.update) only use differences between readings.
// NOTE(review): presumably a monotonic clock — confirm.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
    59  
    60  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
    61  // interface with the same name exist, it is reused.
    62  func CreateTUN(ifname string, mtu int) (Device, error) {
    63  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
    64  }
    65  
    66  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
    67  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
    68  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
    69  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("Error creating interface: %w", err)
    72  	}
    73  
    74  	forcedMTU := 1420
    75  	if mtu > 0 {
    76  		forcedMTU = mtu
    77  	}
    78  
    79  	tun := &NativeTun{
    80  		wt:        wt,
    81  		name:      ifname,
    82  		handle:    windows.InvalidHandle,
    83  		events:    make(chan Event, 10),
    84  		forcedMTU: forcedMTU,
    85  	}
    86  
    87  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
    88  	if err != nil {
    89  		tun.wt.Close()
    90  		close(tun.events)
    91  		return nil, fmt.Errorf("Error starting session: %w", err)
    92  	}
    93  	tun.readWait = tun.session.ReadWaitEvent()
    94  	return tun, nil
    95  }
    96  
    97  func (tun *NativeTun) Name() (string, error) {
    98  	return tun.name, nil
    99  }
   100  
   101  func (tun *NativeTun) File() *os.File {
   102  	return nil
   103  }
   104  
   105  func (tun *NativeTun) Events() <-chan Event {
   106  	return tun.events
   107  }
   108  
   109  func (tun *NativeTun) Close() error {
   110  	var err error
   111  	tun.closeOnce.Do(func() {
   112  		tun.close.Store(true)
   113  		windows.SetEvent(tun.readWait)
   114  		tun.running.Wait()
   115  		tun.session.End()
   116  		if tun.wt != nil {
   117  			tun.wt.Close()
   118  		}
   119  		close(tun.events)
   120  	})
   121  	return err
   122  }
   123  
   124  func (tun *NativeTun) MTU() (int, error) {
   125  	return tun.forcedMTU, nil
   126  }
   127  
   128  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   129  func (tun *NativeTun) ForceMTU(mtu int) {
   130  	update := tun.forcedMTU != mtu
   131  	tun.forcedMTU = mtu
   132  	if update {
   133  		tun.events <- EventMTUUpdate
   134  	}
   135  }
   136  
   137  func (tun *NativeTun) BatchSize() int {
   138  	// TODO: implement batching with wintun
   139  	return 1
   140  }
   141  
   142  // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
   143  
// Read waits for the next packet from the Wintun session and copies it into
// bufs[0] starting at offset. It stores the packet length in sizes[0] and
// returns 1. Only bufs[0] and sizes[0] are used (see BatchSize).
//
// If the measured throughput in tun.rate is at or above
// spinloopRateThreshold, and was measured within the last two measurement
// windows, Read busy-polls the ring with procyield. It does so for up to
// spinloopDuration before falling back to blocking on the readWait event.
//
// Errors:
//   - os.ErrClosed once the device is closed, or when the session reports
//     ERROR_HANDLE_EOF;
//   - "Send ring corrupt" for ERROR_INVALID_DATA;
//   - any other error is wrapped as "Read failed: ...".
//
// NOTE(review): if len(packet) exceeds len(bufs[0])-offset, copy truncates
// the packet but sizes[0] still reports len(packet). Confirm that callers
// always pass buffers large enough for the largest packet the session can
// deliver.
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	// Register as in flight so Close waits for this call before ending the session.
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	// Spin only when the rate is high and its measurement is recent; an old
	// nextStartTime makes the difference exceed two windows, disabling spinning.
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			copy(bufs[0][offset:], packet)
			sizes[0] = packetSize
			// Release the packet back to the session once its bytes are copied out.
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return 1, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Nothing queued. Block unless still within the spin budget;
			// Close signals readWait, which wakes this wait.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				// Re-check the close flag and recompute the spin decision.
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}
   181  
   182  func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
   183  	tun.running.Add(1)
   184  	defer tun.running.Done()
   185  	if tun.close.Load() {
   186  		return 0, os.ErrClosed
   187  	}
   188  
   189  	for i, buf := range bufs {
   190  		packetSize := len(buf) - offset
   191  		tun.rate.update(uint64(packetSize))
   192  
   193  		packet, err := tun.session.AllocateSendPacket(packetSize)
   194  		switch err {
   195  		case nil:
   196  			// TODO: Explore options to eliminate this copy.
   197  			copy(packet, buf[offset:])
   198  			tun.session.SendPacket(packet)
   199  			continue
   200  		case windows.ERROR_HANDLE_EOF:
   201  			return i, os.ErrClosed
   202  		case windows.ERROR_BUFFER_OVERFLOW:
   203  			continue // Dropping when ring is full.
   204  		default:
   205  			return i, fmt.Errorf("Write failed: %w", err)
   206  		}
   207  	}
   208  	return len(bufs), nil
   209  }
   210  
   211  // LUID returns Windows interface instance ID.
   212  func (tun *NativeTun) LUID() uint64 {
   213  	tun.running.Add(1)
   214  	defer tun.running.Done()
   215  	if tun.close.Load() {
   216  		return 0
   217  	}
   218  	return tun.wt.LUID()
   219  }
   220  
   221  // RunningVersion returns the running version of the Wintun driver.
   222  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
   223  	return wintun.RunningVersion()
   224  }
   225  
   226  func (rate *rateJuggler) update(packetLen uint64) {
   227  	now := nanotime()
   228  	total := rate.nextByteCount.Add(packetLen)
   229  	period := uint64(now - rate.nextStartTime.Load())
   230  	if period >= rateMeasurementGranularity {
   231  		if !rate.changing.CompareAndSwap(false, true) {
   232  			return
   233  		}
   234  		rate.nextStartTime.Store(now)
   235  		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
   236  		rate.nextByteCount.Store(0)
   237  		rate.changing.Store(false)
   238  	}
   239  }