github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/tun/tun_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"github.com/bugfan/wireguard-go/tun/wintun"
    20  )
    21  
    22  const (
    23  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
    24  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
    25  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
    26  )
    27  
// rateJuggler holds a lock-free throughput estimate that NativeTun.Read and
// NativeTun.Write feed through update. Every field is read and written with
// sync/atomic operations.
//
// NOTE(review): the uint64/int64 fields are used with 64-bit atomics, which
// need 8-byte alignment on 32-bit platforms (see the sync/atomic package's
// bug note). Keep them first in this struct and keep the struct's offset
// inside NativeTun 8-byte aligned.
type rateJuggler struct {
	current       uint64 // last published throughput estimate, in bytes per second
	nextByteCount uint64 // bytes accumulated since nextStartTime
	nextStartTime int64  // nanotime() at which the current measurement window began
	changing      int32  // 1 while one caller is rolling the window over; taken via CAS in update
}
    34  
// NativeTun is the Windows tun Device implementation backed by a Wintun
// adapter and a running Wintun session.
type NativeTun struct {
	wt        *wintun.Adapter // adapter from wintun.CreateAdapter; closed by Close
	name      string          // interface name supplied at creation; returned by Name
	handle    windows.Handle  // set to windows.InvalidHandle at creation; NOTE(review): not otherwise used in this file
	rate      rateJuggler     // throughput estimate updated by Read/Write; drives Read's spin decision
	session   wintun.Session  // active Wintun session; ended by Close
	readWait  windows.Handle  // session.ReadWaitEvent(); Read waits on it, Close signals it to wake Read
	events    chan Event      // device events (buffer of 10); ForceMTU sends on it, Close closes it
	running   sync.WaitGroup  // counts in-flight Read/Write/LUID calls so Close can wait for them
	closeOnce sync.Once       // makes Close run its teardown only once
	close     int32           // set to 1 atomically by Close; checked by Read/Write/LUID
	forcedMTU int             // value reported by MTU; set at creation or by ForceMTU
}
    48  
    49  var WintunTunnelType = "WireGuard"
    50  var WintunStaticRequestedGUID *windows.GUID
    51  
// procyield is the runtime's internal busy-wait primitive, reached through
// go:linkname. The file blank-imports "unsafe" because go:linkname requires
// it. Read calls procyield(1) between polls of an empty receive ring while
// spinning.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is the runtime's internal nanosecond clock, reached through
// go:linkname. Read and rateJuggler.update use it for cheap timestamps and
// subtract readings to get elapsed nanoseconds.
//
// NOTE(review): Go 1.23+ restricts go:linkname references into the runtime.
// Confirm that both symbols remain linkable with the toolchain in use.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
    57  
    58  //
    59  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
    60  // interface with the same name exist, it is reused.
    61  //
    62  func CreateTUN(ifname string, mtu int) (Device, error) {
    63  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
    64  }
    65  
    66  //
    67  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
    68  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
    69  //
    70  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
    71  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
    72  	if err != nil {
    73  		return nil, fmt.Errorf("Error creating interface: %w", err)
    74  	}
    75  
    76  	forcedMTU := 1420
    77  	if mtu > 0 {
    78  		forcedMTU = mtu
    79  	}
    80  
    81  	tun := &NativeTun{
    82  		wt:        wt,
    83  		name:      ifname,
    84  		handle:    windows.InvalidHandle,
    85  		events:    make(chan Event, 10),
    86  		forcedMTU: forcedMTU,
    87  	}
    88  
    89  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
    90  	if err != nil {
    91  		tun.wt.Close()
    92  		close(tun.events)
    93  		return nil, fmt.Errorf("Error starting session: %w", err)
    94  	}
    95  	tun.readWait = tun.session.ReadWaitEvent()
    96  	return tun, nil
    97  }
    98  
    99  func (tun *NativeTun) Name() (string, error) {
   100  	return tun.name, nil
   101  }
   102  
   103  func (tun *NativeTun) File() *os.File {
   104  	return nil
   105  }
   106  
   107  func (tun *NativeTun) Events() chan Event {
   108  	return tun.events
   109  }
   110  
// Close shuts the device down. It may be called more than once: only the first
// call does any work, and every call returns nil.
//
// The teardown order matters:
//  1. mark the device closed so Read/Write/LUID return early,
//  2. signal readWait to wake a Read blocked in WaitForSingleObject,
//  3. wait for in-flight Read/Write/LUID calls to finish,
//  4. end the Wintun session and close the adapter,
//  5. close the events channel.
//
// NOTE(review): Read/Write/LUID call running.Add(1) on entry, and that can
// race with running.Wait below. The sync.WaitGroup docs require that an Add
// with a positive delta starting from zero happen before Wait. Confirm this
// is acceptable, or guard it (e.g. with an RWMutex).
func (tun *NativeTun) Close() error {
	var err error // never assigned; Close always reports success
	tun.closeOnce.Do(func() {
		atomic.StoreInt32(&tun.close, 1)
		windows.SetEvent(tun.readWait)
		tun.running.Wait()
		tun.session.End()
		if tun.wt != nil {
			tun.wt.Close()
		}
		close(tun.events)
	})
	return err
}
   125  
   126  func (tun *NativeTun) MTU() (int, error) {
   127  	return tun.forcedMTU, nil
   128  }
   129  
   130  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   131  func (tun *NativeTun) ForceMTU(mtu int) {
   132  	update := tun.forcedMTU != mtu
   133  	tun.forcedMTU = mtu
   134  	if update {
   135  		tun.events <- EventMTUUpdate
   136  	}
   137  }
   138  
// Note: Read() and Write() assume they are only ever called from a single goroutine at a time; there's no locking.
   140  
   141  func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
   142  	tun.running.Add(1)
   143  	defer tun.running.Done()
   144  retry:
   145  	if atomic.LoadInt32(&tun.close) == 1 {
   146  		return 0, os.ErrClosed
   147  	}
   148  	start := nanotime()
   149  	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
   150  	for {
   151  		if atomic.LoadInt32(&tun.close) == 1 {
   152  			return 0, os.ErrClosed
   153  		}
   154  		packet, err := tun.session.ReceivePacket()
   155  		switch err {
   156  		case nil:
   157  			packetSize := len(packet)
   158  			copy(buff[offset:], packet)
   159  			tun.session.ReleaseReceivePacket(packet)
   160  			tun.rate.update(uint64(packetSize))
   161  			return packetSize, nil
   162  		case windows.ERROR_NO_MORE_ITEMS:
   163  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   164  				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
   165  				goto retry
   166  			}
   167  			procyield(1)
   168  			continue
   169  		case windows.ERROR_HANDLE_EOF:
   170  			return 0, os.ErrClosed
   171  		case windows.ERROR_INVALID_DATA:
   172  			return 0, errors.New("Send ring corrupt")
   173  		}
   174  		return 0, fmt.Errorf("Read failed: %w", err)
   175  	}
   176  }
   177  
   178  func (tun *NativeTun) Flush() error {
   179  	return nil
   180  }
   181  
   182  func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
   183  	tun.running.Add(1)
   184  	defer tun.running.Done()
   185  	if atomic.LoadInt32(&tun.close) == 1 {
   186  		return 0, os.ErrClosed
   187  	}
   188  
   189  	packetSize := len(buff) - offset
   190  	tun.rate.update(uint64(packetSize))
   191  
   192  	packet, err := tun.session.AllocateSendPacket(packetSize)
   193  	if err == nil {
   194  		copy(packet, buff[offset:])
   195  		tun.session.SendPacket(packet)
   196  		return packetSize, nil
   197  	}
   198  	switch err {
   199  	case windows.ERROR_HANDLE_EOF:
   200  		return 0, os.ErrClosed
   201  	case windows.ERROR_BUFFER_OVERFLOW:
   202  		return 0, nil // Dropping when ring is full.
   203  	}
   204  	return 0, fmt.Errorf("Write failed: %w", err)
   205  }
   206  
   207  // LUID returns Windows interface instance ID.
   208  func (tun *NativeTun) LUID() uint64 {
   209  	tun.running.Add(1)
   210  	defer tun.running.Done()
   211  	if atomic.LoadInt32(&tun.close) == 1 {
   212  		return 0
   213  	}
   214  	return tun.wt.LUID()
   215  }
   216  
   217  // RunningVersion returns the running version of the Wintun driver.
   218  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
   219  	return wintun.RunningVersion()
   220  }
   221  
// update adds packetLen bytes to the current measurement window. Once at least
// rateMeasurementGranularity nanoseconds have passed since the window began,
// it publishes a new throughput estimate in rate.current (bytes per second)
// and starts a new window.
//
// It is called from both Read and Write and uses only atomic operations. The
// changing flag lets exactly one caller roll the window over; other callers
// that see an expired window while the flag is held just return. The estimate
// is approximate: bytes added by other callers between this caller's
// AddUint64 and the reset of nextByteCount are discarded.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
	if period >= rateMeasurementGranularity {
		// Someone else is already rolling the window over; our bytes were
		// already added to nextByteCount above.
		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
			return
		}
		atomic.StoreInt64(&rate.nextStartTime, now)
		// bytes * 1e9 / elapsed ns = bytes per second. period is non-zero here
		// because it is at least rateMeasurementGranularity.
		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
		atomic.StoreUint64(&rate.nextByteCount, 0)
		atomic.StoreInt32(&rate.changing, 0)
	}
}