github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/tun/tun_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "errors" 10 "fmt" 11 "os" 12 "sync" 13 "sync/atomic" 14 "time" 15 _ "unsafe" 16 17 "golang.org/x/sys/windows" 18 "golang.zx2c4.com/wintun" 19 ) 20 21 const ( 22 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 23 spinloopRateThreshold = 800000000 / 8 // 800mbps 24 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 25 ) 26 27 type rateJuggler struct { 28 current atomic.Uint64 29 nextByteCount atomic.Uint64 30 nextStartTime atomic.Int64 31 changing atomic.Bool 32 } 33 34 type NativeTun struct { 35 wt *wintun.Adapter 36 name string 37 handle windows.Handle 38 rate rateJuggler 39 session wintun.Session 40 readWait windows.Handle 41 events chan Event 42 running sync.WaitGroup 43 closeOnce sync.Once 44 close atomic.Bool 45 forcedMTU int 46 outSizes []int 47 } 48 49 var ( 50 WintunTunnelType = "WireGuard" 51 WintunStaticRequestedGUID *windows.GUID 52 ) 53 54 //go:linkname procyield runtime.procyield 55 func procyield(cycles uint32) 56 57 //go:linkname nanotime runtime.nanotime 58 func nanotime() int64 59 60 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 61 // interface with the same name exist, it is reused. 62 func CreateTUN(ifname string, mtu int) (Device, error) { 63 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 } 65 66 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 67 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 68 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 69 wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 70 if err != nil { 71 return nil, fmt.Errorf("Error creating interface: %w", err) 72 } 73 74 forcedMTU := 1420 75 if mtu > 0 { 76 forcedMTU = mtu 77 } 78 79 tun := &NativeTun{ 80 wt: wt, 81 name: ifname, 82 handle: windows.InvalidHandle, 83 events: make(chan Event, 10), 84 forcedMTU: forcedMTU, 85 } 86 87 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 88 if err != nil { 89 tun.wt.Close() 90 close(tun.events) 91 return nil, fmt.Errorf("Error starting session: %w", err) 92 } 93 tun.readWait = tun.session.ReadWaitEvent() 94 return tun, nil 95 } 96 97 func (tun *NativeTun) Name() (string, error) { 98 return tun.name, nil 99 } 100 101 func (tun *NativeTun) File() *os.File { 102 return nil 103 } 104 105 func (tun *NativeTun) Events() <-chan Event { 106 return tun.events 107 } 108 109 func (tun *NativeTun) Close() error { 110 var err error 111 tun.closeOnce.Do(func() { 112 tun.close.Store(true) 113 windows.SetEvent(tun.readWait) 114 tun.running.Wait() 115 tun.session.End() 116 if tun.wt != nil { 117 tun.wt.Close() 118 } 119 close(tun.events) 120 }) 121 return err 122 } 123 124 func (tun *NativeTun) MTU() (int, error) { 125 return tun.forcedMTU, nil 126 } 127 128 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 129 func (tun *NativeTun) ForceMTU(mtu int) { 130 if tun.close.Load() { 131 return 132 } 133 update := tun.forcedMTU != mtu 134 tun.forcedMTU = mtu 135 if update { 136 tun.events <- EventMTUUpdate 137 } 138 } 139 140 func (tun *NativeTun) BatchSize() int { 141 // TODO: implement batching with wintun 142 return 1 143 } 144 145 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 146 147 func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 148 tun.running.Add(1) 149 defer tun.running.Done() 150 retry: 151 if tun.close.Load() { 152 return 0, os.ErrClosed 153 } 154 start := nanotime() 155 shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 156 for { 157 if tun.close.Load() { 158 return 0, os.ErrClosed 159 } 160 packet, err := tun.session.ReceivePacket() 161 switch err { 162 case nil: 163 n := copy(bufs[0][offset:], packet) 164 sizes[0] = n 165 tun.session.ReleaseReceivePacket(packet) 166 tun.rate.update(uint64(n)) 167 return 1, nil 168 case windows.ERROR_NO_MORE_ITEMS: 169 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 170 windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 171 goto retry 172 } 173 procyield(1) 174 continue 175 case windows.ERROR_HANDLE_EOF: 176 return 0, os.ErrClosed 177 case windows.ERROR_INVALID_DATA: 178 return 0, errors.New("Send ring corrupt") 179 } 180 return 0, fmt.Errorf("Read failed: %w", err) 181 } 182 } 183 184 func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { 185 tun.running.Add(1) 186 defer tun.running.Done() 187 if tun.close.Load() { 188 return 0, os.ErrClosed 189 } 190 191 for i, buf := range bufs { 192 packetSize := len(buf) - offset 193 tun.rate.update(uint64(packetSize)) 194 195 packet, err := tun.session.AllocateSendPacket(packetSize) 196 switch err { 197 case nil: 198 // TODO: Explore options to eliminate this copy. 199 copy(packet, buf[offset:]) 200 tun.session.SendPacket(packet) 201 continue 202 case windows.ERROR_HANDLE_EOF: 203 return i, os.ErrClosed 204 case windows.ERROR_BUFFER_OVERFLOW: 205 continue // Dropping when ring is full. 206 default: 207 return i, fmt.Errorf("Write failed: %w", err) 208 } 209 } 210 return len(bufs), nil 211 } 212 213 // LUID returns Windows interface instance ID. 214 func (tun *NativeTun) LUID() uint64 { 215 tun.running.Add(1) 216 defer tun.running.Done() 217 if tun.close.Load() { 218 return 0 219 } 220 return tun.wt.LUID() 221 } 222 223 // RunningVersion returns the running version of the Wintun driver. 224 func (tun *NativeTun) RunningVersion() (version uint32, err error) { 225 return wintun.RunningVersion() 226 } 227 228 func (rate *rateJuggler) update(packetLen uint64) { 229 now := nanotime() 230 total := rate.nextByteCount.Add(packetLen) 231 period := uint64(now - rate.nextStartTime.Load()) 232 if period >= rateMeasurementGranularity { 233 if !rate.changing.CompareAndSwap(false, true) { 234 return 235 } 236 rate.nextStartTime.Store(now) 237 rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 238 rate.nextByteCount.Store(0) 239 rate.changing.Store(false) 240 } 241 }