github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/tun/tun_windows.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tun
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/sys/windows"
    18  
    19  	"golang.zx2c4.com/wintun"
    20  )
    21  
    22  const (
    23  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
    24  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
    25  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
    26  )
    27  
// rateJuggler keeps a running estimate of packet throughput, in bytes per
// second, for traffic passing through Read and Write. All fields are read
// and written with sync/atomic.
//
// NOTE(review): on 32-bit platforms 64-bit atomic operations require 8-byte
// alignment. That depends both on the order below (64-bit fields first) and
// on where this struct sits inside NativeTun, so check alignment before
// reordering either struct.
type rateJuggler struct {
	// current is the most recently published rate, in bytes per second.
	current       uint64
	// nextByteCount accumulates bytes seen since nextStartTime.
	nextByteCount uint64
	// nextStartTime is the nanotime() value at which the current window began.
	nextStartTime int64
	// changing is CAS'd from 0 to 1 by the caller publishing a new rate and
	// reset to 0 when it is done.
	changing      int32
}
    34  
// NativeTun is the Windows Device implementation. It is backed by a Wintun
// adapter and a single Wintun session, both created in
// CreateTUNWithRequestedGUID.
//
// NOTE(review): rate contains 64-bit fields updated with sync/atomic. On
// 32-bit targets those need 8-byte alignment, which depends on the fields
// declared before rate, so confirm alignment before reordering this struct.
type NativeTun struct {
	// wt is the Wintun adapter. Close closes it.
	wt        *wintun.Adapter
	// name is the interface name passed at creation. Name returns it.
	name      string
	// handle is initialized to windows.InvalidHandle.
	// NOTE(review): no other use of it appears in this file.
	handle    windows.Handle
	// rate is the throughput estimate that steers Read's spin-vs-block choice.
	rate      rateJuggler
	// session is the Wintun packet session. Close ends it.
	session   wintun.Session
	// readWait is the session's read-wait event. Read blocks on it and Close
	// signals it to wake a blocked Read.
	readWait  windows.Handle
	// events is the channel returned by Events (buffer of 10). Close closes it.
	events    chan Event
	// running counts in-flight Read/Write/LUID calls so Close can wait for them.
	running   sync.WaitGroup
	// closeOnce makes the body of Close run at most once.
	closeOnce sync.Once
	// close is atomically set to 1 when Close begins. Other methods poll it.
	close     int32
	// forcedMTU is the value MTU reports. It is set at creation and by ForceMTU.
	forcedMTU int
}
    48  
    49  var (
    50  	WintunTunnelType          = "WireGuard"
    51  	WintunStaticRequestedGUID *windows.GUID
    52  )
    53  
// procyield is bound to the Go runtime's internal runtime.procyield through
// go:linkname, which is why the file blank-imports "unsafe". Read calls it
// with 1 between polls while spinning on an empty receive ring.
// NOTE(review): presumably a brief CPU pause/yield hint. Its exact behavior
// is defined by the runtime, not here.
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is bound to runtime.nanotime through go:linkname. Every timing
// decision in this file uses it: Read's spin budget and rateJuggler's windows.
// NOTE(review): presumably the runtime's monotonic clock in nanoseconds.
// Only differences between readings are used here.
//go:linkname nanotime runtime.nanotime
func nanotime() int64
    59  
    60  //
    61  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
    62  // interface with the same name exist, it is reused.
    63  //
    64  func CreateTUN(ifname string, mtu int, nopi bool) (Device, error) {
    65  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, nopi)
    66  }
    67  
    68  //
    69  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
    70  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
    71  //
    72  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, nopi bool) (Device, error) {
    73  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
    74  	if err != nil {
    75  		return nil, fmt.Errorf("Error creating interface: %w", err)
    76  	}
    77  
    78  	forcedMTU := 1420
    79  	if mtu > 0 {
    80  		forcedMTU = mtu
    81  	}
    82  
    83  	tun := &NativeTun{
    84  		wt:        wt,
    85  		name:      ifname,
    86  		handle:    windows.InvalidHandle,
    87  		events:    make(chan Event, 10),
    88  		forcedMTU: forcedMTU,
    89  	}
    90  
    91  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
    92  	if err != nil {
    93  		tun.wt.Close()
    94  		close(tun.events)
    95  		return nil, fmt.Errorf("Error starting session: %w", err)
    96  	}
    97  	tun.readWait = tun.session.ReadWaitEvent()
    98  	return tun, nil
    99  }
   100  
   101  func (tun *NativeTun) Name() (string, error) {
   102  	return tun.name, nil
   103  }
   104  
   105  func (tun *NativeTun) File() *os.File {
   106  	return nil
   107  }
   108  
   109  func (tun *NativeTun) Events() chan Event {
   110  	return tun.events
   111  }
   112  
   113  func (tun *NativeTun) Close() error {
   114  	var err error
   115  	tun.closeOnce.Do(func() {
   116  		atomic.StoreInt32(&tun.close, 1)
   117  		windows.SetEvent(tun.readWait)
   118  		tun.running.Wait()
   119  		tun.session.End()
   120  		if tun.wt != nil {
   121  			tun.wt.Close()
   122  		}
   123  		close(tun.events)
   124  	})
   125  	return err
   126  }
   127  
   128  func (tun *NativeTun) MTU() (int, error) {
   129  	return tun.forcedMTU, nil
   130  }
   131  
   132  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
   133  func (tun *NativeTun) ForceMTU(mtu int) {
   134  	update := tun.forcedMTU != mtu
   135  	tun.forcedMTU = mtu
   136  	if update {
   137  		tun.events <- EventMTUUpdate
   138  	}
   139  }
   140  
// Note: Read() and Write() each assume they are called from only one goroutine at a time; there's no locking.
   142  
// Read receives one packet from the Wintun session into buff[offset:] and
// returns the packet's length.
//
// Spinning: Read busy-polls the ring with procyield for up to
// spinloopDuration before blocking on the read-wait event. It does so only
// when the measured throughput (tun.rate.current) is at least
// spinloopRateThreshold and the rate estimate is at most two measurement
// windows old.
//
// Wake-up: after every wake-up, Read re-checks the closed flag. That is how
// Close interrupts a blocked Read: it signals the read-wait event.
//
// Errors:
//   - os.ErrClosed once the device is closing, or when the session reports
//     ERROR_HANDLE_EOF.
//   - ERROR_INVALID_DATA is reported as a corrupt ring.
//   - Any other receive error is wrapped and returned.
//
// NOTE(review): if buff[offset:] is shorter than the packet, copy truncates
// the packet, but the full packet length is still returned.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
	// Register as in-flight so Close waits for this call before ending the session.
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if atomic.LoadInt32(&tun.close) == 1 {
		return 0, os.ErrClosed
	}
	start := nanotime()
	// Spin only when traffic is heavy and the rate was refreshed recently.
	// nextStartTime is the start of the current window, i.e. the time of
	// the last rate update.
	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
	for {
		if atomic.LoadInt32(&tun.close) == 1 {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			copy(buff[offset:], packet)
			// The data has been copied out, so the ring slot can be released
			// before rate accounting.
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return packetSize, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Ring is empty. Block when not spinning or once the spin budget
			// is spent, then start over (re-checking the closed flag).
			// Otherwise pause briefly and poll again.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}
   179  
// Flush does nothing and always returns nil. Write hands each packet to the
// Wintun session within the call, so nothing is buffered at this level.
func (tun *NativeTun) Flush() error {
	return nil
}
   183  
   184  func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
   185  	tun.running.Add(1)
   186  	defer tun.running.Done()
   187  	if atomic.LoadInt32(&tun.close) == 1 {
   188  		return 0, os.ErrClosed
   189  	}
   190  
   191  	packetSize := len(buff) - offset
   192  	tun.rate.update(uint64(packetSize))
   193  
   194  	packet, err := tun.session.AllocateSendPacket(packetSize)
   195  	if err == nil {
   196  		copy(packet, buff[offset:])
   197  		tun.session.SendPacket(packet)
   198  		return packetSize, nil
   199  	}
   200  	switch err {
   201  	case windows.ERROR_HANDLE_EOF:
   202  		return 0, os.ErrClosed
   203  	case windows.ERROR_BUFFER_OVERFLOW:
   204  		return 0, nil // Dropping when ring is full.
   205  	}
   206  	return 0, fmt.Errorf("Write failed: %w", err)
   207  }
   208  
   209  // LUID returns Windows interface instance ID.
   210  func (tun *NativeTun) LUID() uint64 {
   211  	tun.running.Add(1)
   212  	defer tun.running.Done()
   213  	if atomic.LoadInt32(&tun.close) == 1 {
   214  		return 0
   215  	}
   216  	return tun.wt.LUID()
   217  }
   218  
   219  // RunningVersion returns the running version of the Wintun driver.
   220  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
   221  	return wintun.RunningVersion()
   222  }
   223  
   224  func (rate *rateJuggler) update(packetLen uint64) {
   225  	now := nanotime()
   226  	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
   227  	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
   228  	if period >= rateMeasurementGranularity {
   229  		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
   230  			return
   231  		}
   232  		atomic.StoreInt64(&rate.nextStartTime, now)
   233  		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
   234  		atomic.StoreUint64(&rate.nextByteCount, 0)
   235  		atomic.StoreInt32(&rate.changing, 0)
   236  	}
   237  }