github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/tun_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  	"golang.zx2c4.com/wintun"
    19  )
    20  
    21  const (
    22  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
    23  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
    24  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
    25  )
    26  
// rateJuggler keeps a lock-free, approximate measurement of packet throughput.
// Read and Write feed it packet sizes via update; Read consults current to
// decide whether busy-spinning on an empty receive ring is worthwhile.
type rateJuggler struct {
	// current is the throughput, in bytes per second, computed at the end of
	// the most recent completed sampling window.
	current atomic.Uint64

	// nextByteCount accumulates bytes seen in the window now in progress; it
	// is reset to 0 whenever a window is closed out.
	nextByteCount atomic.Uint64

	// nextStartTime is the nanotime() reading at which the current window began.
	nextStartTime atomic.Int64

	// changing is claimed with CompareAndSwap so that only one caller at a
	// time closes out a window and publishes a new current value.
	changing atomic.Bool
}
    33  
// NativeTun is the Windows Device implementation returned by CreateTUN and
// CreateTUNWithRequestedGUID. It wraps one Wintun adapter and one Wintun
// session started on that adapter.
type NativeTun struct {
	// wt is the adapter returned by wintun.CreateAdapter; Close closes it.
	wt *wintun.Adapter

	// name is the interface name given at creation; returned by Name.
	name string

	// handle is initialised to windows.InvalidHandle at creation.
	// NOTE(review): not otherwise referenced in this file — possibly vestigial.
	handle windows.Handle

	// rate tracks throughput so Read can decide whether to busy-spin.
	rate rateJuggler

	// session is the Wintun session (8 MiB ring) that Read/Write use; Close ends it.
	session wintun.Session

	// readWait is the session's read-wait event: Read blocks on it when the
	// ring is empty, and Close signals it to wake a blocked Read.
	readWait windows.Handle

	// events is buffered (capacity 10); ForceMTU sends EventMTUUpdate on it
	// and Close closes it.
	events chan Event

	// running counts in-flight Read, Write and LUID calls; Close waits for
	// it to drain before ending the session.
	running sync.WaitGroup

	// closeOnce makes Close's teardown run at most once.
	closeOnce sync.Once

	// close is set by Close and polled by Read, Write, LUID and ForceMTU.
	close atomic.Bool

	// forcedMTU is the value reported by MTU (1420 unless overridden at
	// creation or via ForceMTU).
	forcedMTU int

	// outSizes is not referenced in this file.
	// NOTE(review): possibly reserved for future batching — confirm or remove.
	outSizes []int
}
    48  
    49  var (
    50  	WintunTunnelType          = "WireGuard"
    51  	WintunStaticRequestedGUID *windows.GUID
    52  )
    53  
// procyield is the Go runtime's internal procyield, pulled in via linkname
// (the blank "unsafe" import enables this). Read calls procyield(1) as a brief
// pause between polls while busy-spinning on an empty receive ring.
// NOTE(review): Go 1.23+ restricts pull-only linknames into the runtime; this
// symbol is presumably still permitted — re-check on toolchain upgrades.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is the Go runtime's internal nanosecond clock, pulled in via
// linkname. Only differences between readings are used here (in Read and
// rateJuggler.update), compared against nanosecond-based constants.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
    59  
    60  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
    61  // interface with the same name exist, it is reused.
    62  func CreateTUN(ifname string, mtu int) (Device, error) {
    63  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
    64  }
    65  
    66  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
    67  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
    68  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
    69  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("Error creating interface: %w", err)
    72  	}
    73  
    74  	forcedMTU := 1420
    75  	if mtu > 0 {
    76  		forcedMTU = mtu
    77  	}
    78  
    79  	tun := &NativeTun{
    80  		wt:        wt,
    81  		name:      ifname,
    82  		handle:    windows.InvalidHandle,
    83  		events:    make(chan Event, 10),
    84  		forcedMTU: forcedMTU,
    85  	}
    86  
    87  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
    88  	if err != nil {
    89  		tun.wt.Close()
    90  		close(tun.events)
    91  		return nil, fmt.Errorf("Error starting session: %w", err)
    92  	}
    93  	tun.readWait = tun.session.ReadWaitEvent()
    94  	return tun, nil
    95  }
    96  
    97  func (tun *NativeTun) Name() (string, error) {
    98  	return tun.name, nil
    99  }
   100  
   101  func (tun *NativeTun) File() *os.File {
   102  	return nil
   103  }
   104  
   105  func (tun *NativeTun) Events() <-chan Event {
   106  	return tun.events
   107  }
   108  
   109  func (tun *NativeTun) Close() error {
   110  	var err error
   111  	tun.closeOnce.Do(func() {
   112  		tun.close.Store(true)
   113  		windows.SetEvent(tun.readWait)
   114  		tun.running.Wait()
   115  		tun.session.End()
   116  		if tun.wt != nil {
   117  			tun.wt.Close()
   118  		}
   119  		close(tun.events)
   120  	})
   121  	return err
   122  }
   123  
   124  func (tun *NativeTun) MTU() (int, error) {
   125  	return tun.forcedMTU, nil
   126  }
   127  
   128  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   129  func (tun *NativeTun) ForceMTU(mtu int) {
   130  	if tun.close.Load() {
   131  		return
   132  	}
   133  	update := tun.forcedMTU != mtu
   134  	tun.forcedMTU = mtu
   135  	if update {
   136  		tun.events <- EventMTUUpdate
   137  	}
   138  }
   139  
   140  func (tun *NativeTun) BatchSize() int {
   141  	// TODO: implement batching with wintun
   142  	return 1
   143  }
   144  
// Note: Read() and Write() each assume they are called from only one goroutine at a time; neither does any locking of its own.
   146  
   147  func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
   148  	tun.running.Add(1)
   149  	defer tun.running.Done()
   150  retry:
   151  	if tun.close.Load() {
   152  		return 0, os.ErrClosed
   153  	}
   154  	start := nanotime()
   155  	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
   156  	for {
   157  		if tun.close.Load() {
   158  			return 0, os.ErrClosed
   159  		}
   160  		packet, err := tun.session.ReceivePacket()
   161  		switch err {
   162  		case nil:
   163  			n := copy(bufs[0][offset:], packet)
   164  			sizes[0] = n
   165  			tun.session.ReleaseReceivePacket(packet)
   166  			tun.rate.update(uint64(n))
   167  			return 1, nil
   168  		case windows.ERROR_NO_MORE_ITEMS:
   169  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   170  				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
   171  				goto retry
   172  			}
   173  			procyield(1)
   174  			continue
   175  		case windows.ERROR_HANDLE_EOF:
   176  			return 0, os.ErrClosed
   177  		case windows.ERROR_INVALID_DATA:
   178  			return 0, errors.New("Send ring corrupt")
   179  		}
   180  		return 0, fmt.Errorf("Read failed: %w", err)
   181  	}
   182  }
   183  
   184  func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
   185  	tun.running.Add(1)
   186  	defer tun.running.Done()
   187  	if tun.close.Load() {
   188  		return 0, os.ErrClosed
   189  	}
   190  
   191  	for i, buf := range bufs {
   192  		packetSize := len(buf) - offset
   193  		tun.rate.update(uint64(packetSize))
   194  
   195  		packet, err := tun.session.AllocateSendPacket(packetSize)
   196  		switch err {
   197  		case nil:
   198  			// TODO: Explore options to eliminate this copy.
   199  			copy(packet, buf[offset:])
   200  			tun.session.SendPacket(packet)
   201  			continue
   202  		case windows.ERROR_HANDLE_EOF:
   203  			return i, os.ErrClosed
   204  		case windows.ERROR_BUFFER_OVERFLOW:
   205  			continue // Dropping when ring is full.
   206  		default:
   207  			return i, fmt.Errorf("Write failed: %w", err)
   208  		}
   209  	}
   210  	return len(bufs), nil
   211  }
   212  
   213  // LUID returns Windows interface instance ID.
   214  func (tun *NativeTun) LUID() uint64 {
   215  	tun.running.Add(1)
   216  	defer tun.running.Done()
   217  	if tun.close.Load() {
   218  		return 0
   219  	}
   220  	return tun.wt.LUID()
   221  }
   222  
   223  // RunningVersion returns the running version of the Wintun driver.
   224  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
   225  	return wintun.RunningVersion()
   226  }
   227  
   228  func (rate *rateJuggler) update(packetLen uint64) {
   229  	now := nanotime()
   230  	total := rate.nextByteCount.Add(packetLen)
   231  	period := uint64(now - rate.nextStartTime.Load())
   232  	if period >= rateMeasurementGranularity {
   233  		if !rate.changing.CompareAndSwap(false, true) {
   234  			return
   235  		}
   236  		rate.nextStartTime.Store(now)
   237  		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
   238  		rate.nextByteCount.Store(0)
   239  		rate.changing.Store(false)
   240  	}
   241  }