github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/tun/tun_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "errors" 10 "fmt" 11 "os" 12 "sync" 13 "sync/atomic" 14 "time" 15 _ "unsafe" 16 17 "golang.org/x/sys/windows" 18 19 "github.com/bugfan/wireguard-go/tun/wintun" 20 ) 21 22 const ( 23 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 24 spinloopRateThreshold = 800000000 / 8 // 800mbps 25 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 26 ) 27 28 type rateJuggler struct { 29 current uint64 30 nextByteCount uint64 31 nextStartTime int64 32 changing int32 33 } 34 35 type NativeTun struct { 36 wt *wintun.Adapter 37 name string 38 handle windows.Handle 39 rate rateJuggler 40 session wintun.Session 41 readWait windows.Handle 42 events chan Event 43 running sync.WaitGroup 44 closeOnce sync.Once 45 close int32 46 forcedMTU int 47 } 48 49 var WintunTunnelType = "WireGuard" 50 var WintunStaticRequestedGUID *windows.GUID 51 52 //go:linkname procyield runtime.procyield 53 func procyield(cycles uint32) 54 55 //go:linkname nanotime runtime.nanotime 56 func nanotime() int64 57 58 // 59 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 60 // interface with the same name exist, it is reused. 61 // 62 func CreateTUN(ifname string, mtu int) (Device, error) { 63 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 } 65 66 // 67 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 68 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 
//
// A positive mtu becomes the value reported by MTU; otherwise 1420 is used.
// If the session cannot be started, the adapter is closed again and the error
// is returned wrapped.
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
	if err != nil {
		return nil, fmt.Errorf("Error creating interface: %w", err)
	}

	forcedMTU := 1420
	if mtu > 0 {
		forcedMTU = mtu
	}

	tun := &NativeTun{
		wt:        wt,
		name:      ifname,
		handle:    windows.InvalidHandle,
		events:    make(chan Event, 10),
		forcedMTU: forcedMTU,
	}

	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
	if err != nil {
		tun.wt.Close()
		close(tun.events)
		return nil, fmt.Errorf("Error starting session: %w", err)
	}
	tun.readWait = tun.session.ReadWaitEvent()
	return tun, nil
}

// Name returns the interface name given at creation. It never fails.
func (tun *NativeTun) Name() (string, error) {
	return tun.name, nil
}

// File returns nil: this implementation exposes no *os.File.
func (tun *NativeTun) File() *os.File {
	return nil
}

// Events returns the device event channel. It is closed by Close.
func (tun *NativeTun) Events() chan Event {
	return tun.events
}

// Close shuts the device down exactly once. It does the following, in order:
//   - marks the device closed;
//   - signals the read-wait event so a blocked Read wakes up and observes the flag;
//   - waits for in-flight Read, Write and LUID calls to return;
//   - ends the session and closes the adapter;
//   - closes the Events channel.
//
// It always returns nil; calls after the first are no-ops.
func (tun *NativeTun) Close() error {
	tun.closeOnce.Do(func() {
		atomic.StoreInt32(&tun.close, 1)
		windows.SetEvent(tun.readWait)
		tun.running.Wait()
		tun.session.End()
		if tun.wt != nil {
			tun.wt.Close()
		}
		close(tun.events)
	})
	return nil
}

// MTU returns the MTU chosen at creation or most recently set by ForceMTU.
// Access to forcedMTU is not synchronised with ForceMTU.
func (tun *NativeTun) MTU() (int, error) {
	return tun.forcedMTU, nil
}

// ForceMTU overrides the value reported by MTU. When the value actually
// changes, it sends EventMTUUpdate on the Events channel; that send blocks if
// the channel's 10-slot buffer is full.
//
// Once the device is closed, ForceMTU does nothing. Close closes the Events
// channel, and sending on a closed channel panics. The check narrows, but
// does not fully eliminate, the window against a concurrent Close.
//
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
	if atomic.LoadInt32(&tun.close) == 1 {
		return
	}
	update := tun.forcedMTU != mtu
	tun.forcedMTU = mtu
	if update {
		tun.events <- EventMTUUpdate
	}
}

// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
// Read blocks until a packet is available on the Wintun session and copies it
// into buff[offset:]. It returns the number of bytes actually copied, so a
// packet larger than the remaining space is truncated.
//
// Waiting strategy:
//   - The measured throughput must be at least spinloopRateThreshold, and the
//     last rate sample must be at most two measurement windows old.
//   - If both hold, Read busy-polls an empty ring for up to spinloopDuration.
//   - Otherwise, or once that time runs out, it blocks on the session's
//     read-wait event and starts over.
//
// Errors:
//   - os.ErrClosed after Close, or when the session reports ERROR_HANDLE_EOF.
//   - "Send ring corrupt" on ERROR_INVALID_DATA.
//   - An error when offset lies outside buff; this is checked before touching
//     the ring.
//   - Any other receive error, wrapped.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
	if offset < 0 || offset > len(buff) {
		return 0, fmt.Errorf("Read failed: offset %d out of range for buffer of length %d", offset, len(buff))
	}
retry:
	if atomic.LoadInt32(&tun.close) == 1 {
		return 0, os.ErrClosed
	}
	start := nanotime()
	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
	for {
		if atomic.LoadInt32(&tun.close) == 1 {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			// Report what was really written into buff, not the packet length,
			// so callers never slice past the data they were given.
			n := copy(buff[offset:], packet)
			packetSize := len(packet)
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return n, nil
		case windows.ERROR_NO_MORE_ITEMS:
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				// Close signals readWait, so this wait also ends on shutdown;
				// the retry re-checks the closed flag.
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}

// Flush is a no-op: Write hands each packet to the ring immediately.
func (tun *NativeTun) Flush() error {
	return nil
}

// Write copies buff[offset:] into a newly allocated send-ring packet and
// submits it, returning the packet size on success.
//
// Outcomes other than success:
//   - Full ring: the packet is dropped and (0, nil) is returned.
//   - Device closed, or session reports ERROR_HANDLE_EOF: os.ErrClosed.
//   - offset outside buff: an error is returned before anything else is done.
//     Otherwise the negative size would corrupt the rate estimate.
//   - Any other allocation error: returned wrapped.
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
	if atomic.LoadInt32(&tun.close) == 1 {
		return 0, os.ErrClosed
	}
	if offset < 0 || offset > len(buff) {
		return 0, fmt.Errorf("Write failed: offset %d out of range for buffer of length %d", offset, len(buff))
	}

	packetSize := len(buff) - offset
	tun.rate.update(uint64(packetSize))

	packet, err := tun.session.AllocateSendPacket(packetSize)
	switch err {
	case nil:
		copy(packet, buff[offset:])
		tun.session.SendPacket(packet)
		return packetSize, nil
	case windows.ERROR_HANDLE_EOF:
		return 0, os.ErrClosed
	case windows.ERROR_BUFFER_OVERFLOW:
		return 0, nil // Dropping when ring is full.
	}
	return 0, fmt.Errorf("Write failed: %w", err)
}

// LUID returns Windows interface instance ID, or 0 once the device is closed.
func (tun *NativeTun) LUID() uint64 {
	tun.running.Add(1)
	defer tun.running.Done()
	if atomic.LoadInt32(&tun.close) == 1 {
		return 0
	}
	return tun.wt.LUID()
}

// RunningVersion returns the running version of the Wintun driver, as reported
// by wintun.RunningVersion.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
	return wintun.RunningVersion()
}

// update adds packetLen bytes to the current measurement window.
//
// Once the window has lasted at least rateMeasurementGranularity, one caller
// publishes the new estimate; callers that lose the `changing` CAS return
// without publishing. The winner:
//   - stores total*1e9/period (bytes per second) in current;
//   - starts a new window.
//
// Bytes added by other callers between the CAS and the reset of nextByteCount
// are lost from the estimate, so the figure is approximate.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
	if period >= rateMeasurementGranularity {
		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
			return
		}
		atomic.StoreInt64(&rate.nextStartTime, now)
		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
		atomic.StoreUint64(&rate.nextByteCount, 0)
		atomic.StoreInt32(&rate.changing, 0)
	}
}