github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/tun/tun_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "errors" 10 "fmt" 11 "os" 12 "sync" 13 "sync/atomic" 14 "time" 15 _ "unsafe" 16 17 "golang.org/x/sys/windows" 18 wintun "github.com/koomox/wintun-go" 19 ) 20 21 const ( 22 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 23 spinloopRateThreshold = 800000000 / 8 // 800mbps 24 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 25 ) 26 27 type rateJuggler struct { 28 current atomic.Uint64 29 nextByteCount atomic.Uint64 30 nextStartTime atomic.Int64 31 changing atomic.Bool 32 } 33 34 type NativeTun struct { 35 wt *wintun.Adapter 36 name string 37 handle windows.Handle 38 rate rateJuggler 39 session wintun.Session 40 readWait windows.Handle 41 events chan Event 42 running sync.WaitGroup 43 closeOnce sync.Once 44 close atomic.Bool 45 forcedMTU int 46 outSizes []int 47 } 48 49 var ( 50 WintunTunnelType = "WireGuard" 51 WintunStaticRequestedGUID *windows.GUID 52 ) 53 54 //go:linkname procyield runtime.procyield 55 func procyield(cycles uint32) 56 57 //go:linkname nanotime runtime.nanotime 58 func nanotime() int64 59 60 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 61 // interface with the same name exist, it is reused. 62 func CreateTUN(ifname string, mtu int) (Device, error) { 63 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 } 65 66 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 67 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 
68 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 69 wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 70 if err != nil { 71 return nil, fmt.Errorf("Error creating interface: %w", err) 72 } 73 74 forcedMTU := 1420 75 if mtu > 0 { 76 forcedMTU = mtu 77 } 78 79 tun := &NativeTun{ 80 wt: wt, 81 name: ifname, 82 handle: windows.InvalidHandle, 83 events: make(chan Event, 10), 84 forcedMTU: forcedMTU, 85 } 86 87 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 88 if err != nil { 89 tun.wt.Close() 90 close(tun.events) 91 return nil, fmt.Errorf("Error starting session: %w", err) 92 } 93 tun.readWait = tun.session.ReadWaitEvent() 94 return tun, nil 95 } 96 97 func (tun *NativeTun) Name() (string, error) { 98 return tun.name, nil 99 } 100 101 func (tun *NativeTun) File() *os.File { 102 return nil 103 } 104 105 func (tun *NativeTun) Events() <-chan Event { 106 return tun.events 107 } 108 109 func (tun *NativeTun) Close() error { 110 var err error 111 tun.closeOnce.Do(func() { 112 tun.close.Store(true) 113 windows.SetEvent(tun.readWait) 114 tun.running.Wait() 115 tun.session.End() 116 if tun.wt != nil { 117 tun.wt.Close() 118 } 119 close(tun.events) 120 }) 121 return err 122 } 123 124 func (tun *NativeTun) MTU() (int, error) { 125 return tun.forcedMTU, nil 126 } 127 128 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 129 func (tun *NativeTun) ForceMTU(mtu int) { 130 update := tun.forcedMTU != mtu 131 tun.forcedMTU = mtu 132 if update { 133 tun.events <- EventMTUUpdate 134 } 135 } 136 137 func (tun *NativeTun) BatchSize() int { 138 // TODO: implement batching with wintun 139 return 1 140 } 141 142 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 
// Read receives one packet from the Wintun session into bufs[0][offset:],
// stores the number of bytes written there in sizes[0], and returns 1.
// When the measured rate is high and recent, it busy-polls an empty ring for
// up to spinloopDuration. After that it blocks on the session's read-wait
// event. It returns os.ErrClosed once the device is closed.
//
// A packet longer than bufs[0][offset:] is truncated; sizes[0] then reports
// only the bytes that were actually copied.
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	// The age of the last sample is taken as a signed value. A concurrent
	// update may have moved nextStartTime past start, which means the sample
	// is fresh. Converted to uint64, that negative age would wrap and wrongly
	// disable spinning.
	sampleAge := start - tun.rate.nextStartTime.Load()
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && sampleAge <= int64(rateMeasurementGranularity*2)
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			// Report what actually fit. Reporting len(packet) after a
			// truncating copy would let the caller slice beyond the copied
			// data, or beyond the buffer itself.
			n := copy(bufs[0][offset:], packet)
			sizes[0] = n
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return 1, nil
		case windows.ERROR_NO_MORE_ITEMS:
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				if _, werr := windows.WaitForSingleObject(tun.readWait, windows.INFINITE); werr != nil {
					// Retrying after a failed wait would poll the empty ring
					// in a tight loop forever.
					if tun.close.Load() {
						return 0, os.ErrClosed
					}
					return 0, fmt.Errorf("Read failed: %w", werr)
				}
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}

// Write sends each bufs[i][offset:] as one packet on the Wintun session and
// returns how many buffers were handled. A packet that does not fit in a full
// send ring is dropped silently and still counted. On ERROR_HANDLE_EOF it
// returns the count so far with os.ErrClosed. Any other allocation error is
// returned wrapped, also with the count so far.
func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
	if tun.close.Load() {
		return 0, os.ErrClosed
	}

	for i, buf := range bufs {
		packetSize := len(buf) - offset
		// Counted toward the rate even if the packet is dropped below.
		tun.rate.update(uint64(packetSize))

		packet, err := tun.session.AllocateSendPacket(packetSize)
		switch err {
		case nil:
			// TODO: Explore options to eliminate this copy.
			copy(packet, buf[offset:])
			tun.session.SendPacket(packet)
			continue
		case windows.ERROR_HANDLE_EOF:
			return i, os.ErrClosed
		case windows.ERROR_BUFFER_OVERFLOW:
			continue // Dropping when ring is full.
		default:
			return i, fmt.Errorf("Write failed: %w", err)
		}
	}
	return len(bufs), nil
}

// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
	tun.running.Add(1)
	defer tun.running.Done()
	// After Close the adapter may already be closed; report 0 instead.
	if tun.close.Load() {
		return 0
	}
	return tun.wt.LUID()
}

// RunningVersion returns the running version of the Wintun driver, as
// reported by wintun.RunningVersion.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
	return wintun.RunningVersion()
}

// update adds packetLen bytes to the current measurement window. Once the
// window has lasted at least rateMeasurementGranularity, the one caller that
// wins the changing flag does three things. It publishes the window's
// bytes-per-second figure into current, starts a new window at now, and
// zeroes the byte count. Every other caller just returns.
//
// Elapsed time is computed as a signed difference. Another goroutine can
// store a nextStartTime later than the now sampled here. Converted straight
// to uint64, that negative span would look enormous. A caller could then
// publish a rate of about zero and move the window start backwards.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := rate.nextByteCount.Add(packetLen)
	elapsed := now - rate.nextStartTime.Load()
	if elapsed < int64(rateMeasurementGranularity) {
		// The window is still open, or a newer one was just started.
		return
	}
	if !rate.changing.CompareAndSwap(false, true) {
		return
	}
	rate.nextStartTime.Store(now)
	rate.current.Store(total * uint64(time.Second/time.Nanosecond) / uint64(elapsed))
	rate.nextByteCount.Store(0)
	rate.changing.Store(false)
}