github.com/forest33/wtun@v0.3.1/tun/tun_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "errors" 10 "fmt" 11 "os" 12 "sync" 13 "sync/atomic" 14 "time" 15 _ "unsafe" 16 17 "golang.org/x/sys/windows" 18 19 "github.com/forest33/wtun/wintun" 20 ) 21 22 const ( 23 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 24 spinloopRateThreshold = 800000000 / 8 // 800mbps 25 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 26 ) 27 28 type rateJuggler struct { 29 current atomic.Uint64 30 nextByteCount atomic.Uint64 31 nextStartTime atomic.Int64 32 changing atomic.Bool 33 } 34 35 type NativeTun struct { 36 wt *wintun.Adapter 37 name string 38 handle windows.Handle 39 rate rateJuggler 40 session wintun.Session 41 readWait windows.Handle 42 events chan Event 43 running sync.WaitGroup 44 closeOnce sync.Once 45 close atomic.Bool 46 forcedMTU int 47 outSizes []int 48 } 49 50 var ( 51 WintunTunnelType = "WireGuard" 52 WintunStaticRequestedGUID *windows.GUID 53 ) 54 55 //go:linkname procyield runtime.procyield 56 func procyield(cycles uint32) 57 58 //go:linkname nanotime runtime.nanotime 59 func nanotime() int64 60 61 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 62 // interface with the same name exist, it is reused. 63 func CreateTUN(ifname, tunnelName string, mtu int) (Device, error) { 64 return CreateTUNWithRequestedGUID(ifname, tunnelName, WintunStaticRequestedGUID, mtu) 65 } 66 67 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 68 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 69 func CreateTUNWithRequestedGUID(ifname, tunnelName string, requestedGUID *windows.GUID, mtu int) (Device, error) { 70 if tunnelName == "" { 71 tunnelName = WintunTunnelType 72 } 73 wt, err := wintun.CreateAdapter(ifname, tunnelName, requestedGUID) 74 if err != nil { 75 return nil, fmt.Errorf("Error creating interface: %w", err) 76 } 77 78 forcedMTU := 1420 79 if mtu > 0 { 80 forcedMTU = mtu 81 } 82 83 tun := &NativeTun{ 84 wt: wt, 85 name: ifname, 86 handle: windows.InvalidHandle, 87 events: make(chan Event, 10), 88 forcedMTU: forcedMTU, 89 } 90 91 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 92 if err != nil { 93 tun.wt.Close() 94 close(tun.events) 95 return nil, fmt.Errorf("Error starting session: %w", err) 96 } 97 tun.readWait = tun.session.ReadWaitEvent() 98 return tun, nil 99 } 100 101 func (tun *NativeTun) Name() (string, error) { 102 return tun.name, nil 103 } 104 105 func (tun *NativeTun) File() *os.File { 106 return nil 107 } 108 109 func (tun *NativeTun) Events() <-chan Event { 110 return tun.events 111 } 112 113 func (tun *NativeTun) Close() error { 114 var err error 115 tun.closeOnce.Do(func() { 116 tun.close.Store(true) 117 windows.SetEvent(tun.readWait) 118 tun.running.Wait() 119 tun.session.End() 120 if tun.wt != nil { 121 tun.wt.Close() 122 } 123 close(tun.events) 124 }) 125 return err 126 } 127 128 func (tun *NativeTun) MTU() (int, error) { 129 return tun.forcedMTU, nil 130 } 131 132 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 133 func (tun *NativeTun) ForceMTU(mtu int) { 134 update := tun.forcedMTU != mtu 135 tun.forcedMTU = mtu 136 if update { 137 tun.events <- EventMTUUpdate 138 } 139 } 140 141 func (tun *NativeTun) BatchSize() int { 142 // TODO: implement batching with wintun 143 return 1 144 } 145 146 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 147 148 func (tun *NativeTun) ReadPackets(bufs [][]byte, sizes []int, offset int) (int, error) { 149 tun.running.Add(1) 150 defer tun.running.Done() 151 retry: 152 if tun.close.Load() { 153 return 0, os.ErrClosed 154 } 155 start := nanotime() 156 shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 157 for { 158 if tun.close.Load() { 159 return 0, os.ErrClosed 160 } 161 packet, err := tun.session.ReceivePacket() 162 switch err { 163 case nil: 164 packetSize := len(packet) 165 copy(bufs[0][offset:], packet) 166 sizes[0] = packetSize 167 tun.session.ReleaseReceivePacket(packet) 168 tun.rate.update(uint64(packetSize)) 169 return 1, nil 170 case windows.ERROR_NO_MORE_ITEMS: 171 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 172 windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 173 goto retry 174 } 175 procyield(1) 176 continue 177 case windows.ERROR_HANDLE_EOF: 178 return 0, os.ErrClosed 179 case windows.ERROR_INVALID_DATA: 180 return 0, errors.New("Send ring corrupt") 181 } 182 return 0, fmt.Errorf("Read failed: %w", err) 183 } 184 } 185 186 func (tun *NativeTun) WritePackets(bufs [][]byte, offset int) (int, error) { 187 tun.running.Add(1) 188 defer tun.running.Done() 189 if tun.close.Load() { 190 return 0, os.ErrClosed 191 } 192 193 var total int 194 195 for i, buf := range bufs { 196 packetSize := len(buf) - offset 197 tun.rate.update(uint64(packetSize)) 198 199 packet, err := tun.session.AllocateSendPacket(packetSize) 200 switch err { 201 case nil: 202 // TODO: Explore options to eliminate this copy. 203 copy(packet, buf[offset:]) 204 tun.session.SendPacket(packet) 205 total += len(packet) 206 continue 207 case windows.ERROR_HANDLE_EOF: 208 return i, os.ErrClosed 209 case windows.ERROR_BUFFER_OVERFLOW: 210 continue // Dropping when ring is full. 211 default: 212 return i, fmt.Errorf("Write failed: %w", err) 213 } 214 } 215 return total, nil 216 } 217 218 // LUID returns Windows interface instance ID. 219 func (tun *NativeTun) LUID() uint64 { 220 tun.running.Add(1) 221 defer tun.running.Done() 222 if tun.close.Load() { 223 return 0 224 } 225 return tun.wt.LUID() 226 } 227 228 // RunningVersion returns the running version of the Wintun driver. 229 func (tun *NativeTun) RunningVersion() (version uint32, err error) { 230 return wintun.RunningVersion() 231 } 232 233 func (rate *rateJuggler) update(packetLen uint64) { 234 now := nanotime() 235 total := rate.nextByteCount.Add(packetLen) 236 period := uint64(now - rate.nextStartTime.Load()) 237 if period >= rateMeasurementGranularity { 238 if !rate.changing.CompareAndSwap(false, true) { 239 return 240 } 241 rate.nextStartTime.Store(now) 242 rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 243 rate.nextByteCount.Store(0) 244 rate.changing.Store(false) 245 } 246 } 247 248 func (tun *NativeTun) Read(p []byte) (n int, err error) { 249 var ( 250 bufs = make([][]byte, 1) 251 sizes = make([]int, 1) 252 ) 253 254 bufs[0] = make([]byte, len(p)) 255 n, err = tun.ReadPackets(bufs, sizes, 0) 256 if err != nil { 257 return 0, err 258 } 259 if sizes[0] < 1 { 260 return 0, nil 261 } 262 263 copy(p, bufs[0][:sizes[0]]) 264 265 return sizes[0], nil 266 } 267 268 func (tun *NativeTun) Write(p []byte) (n int, err error) { 269 return tun.WritePackets([][]byte{p}, 0) 270 }