github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/tun/tun_windows.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package tun 7 8 import ( 9 "errors" 10 "fmt" 11 "os" 12 "sync" 13 "sync/atomic" 14 "time" 15 _ "unsafe" 16 17 "golang.org/x/sys/windows" 18 19 "golang.zx2c4.com/wintun" 20 ) 21 22 const ( 23 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 24 spinloopRateThreshold = 800000000 / 8 // 800mbps 25 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 26 ) 27 28 type rateJuggler struct { 29 current uint64 30 nextByteCount uint64 31 nextStartTime int64 32 changing int32 33 } 34 35 type NativeTun struct { 36 wt *wintun.Adapter 37 name string 38 handle windows.Handle 39 rate rateJuggler 40 session wintun.Session 41 readWait windows.Handle 42 events chan Event 43 running sync.WaitGroup 44 closeOnce sync.Once 45 close int32 46 forcedMTU int 47 } 48 49 var ( 50 WintunTunnelType = "WireGuard" 51 WintunStaticRequestedGUID *windows.GUID 52 ) 53 54 //go:linkname procyield runtime.procyield 55 func procyield(cycles uint32) 56 57 //go:linkname nanotime runtime.nanotime 58 func nanotime() int64 59 60 // 61 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 62 // interface with the same name exist, it is reused. 63 // 64 func CreateTUN(ifname string, mtu int, nopi bool) (Device, error) { 65 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu, nopi) 66 } 67 68 // 69 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 70 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 71 // 72 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int, nopi bool) (Device, error) { 73 wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 74 if err != nil { 75 return nil, fmt.Errorf("Error creating interface: %w", err) 76 } 77 78 forcedMTU := 1420 79 if mtu > 0 { 80 forcedMTU = mtu 81 } 82 83 tun := &NativeTun{ 84 wt: wt, 85 name: ifname, 86 handle: windows.InvalidHandle, 87 events: make(chan Event, 10), 88 forcedMTU: forcedMTU, 89 } 90 91 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 92 if err != nil { 93 tun.wt.Close() 94 close(tun.events) 95 return nil, fmt.Errorf("Error starting session: %w", err) 96 } 97 tun.readWait = tun.session.ReadWaitEvent() 98 return tun, nil 99 } 100 101 func (tun *NativeTun) Name() (string, error) { 102 return tun.name, nil 103 } 104 105 func (tun *NativeTun) File() *os.File { 106 return nil 107 } 108 109 func (tun *NativeTun) Events() chan Event { 110 return tun.events 111 } 112 113 func (tun *NativeTun) Close() error { 114 var err error 115 tun.closeOnce.Do(func() { 116 atomic.StoreInt32(&tun.close, 1) 117 windows.SetEvent(tun.readWait) 118 tun.running.Wait() 119 tun.session.End() 120 if tun.wt != nil { 121 tun.wt.Close() 122 } 123 close(tun.events) 124 }) 125 return err 126 } 127 128 func (tun *NativeTun) MTU() (int, error) { 129 return tun.forcedMTU, nil 130 } 131 132 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 133 func (tun *NativeTun) ForceMTU(mtu int) { 134 update := tun.forcedMTU != mtu 135 tun.forcedMTU = mtu 136 if update { 137 tun.events <- EventMTUUpdate 138 } 139 } 140 141 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 142 143 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 144 tun.running.Add(1) 145 defer tun.running.Done() 146 retry: 147 if atomic.LoadInt32(&tun.close) == 1 { 148 return 0, os.ErrClosed 149 } 150 start := nanotime() 151 shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 152 for { 153 if atomic.LoadInt32(&tun.close) == 1 { 154 return 0, os.ErrClosed 155 } 156 packet, err := tun.session.ReceivePacket() 157 switch err { 158 case nil: 159 packetSize := len(packet) 160 copy(buff[offset:], packet) 161 tun.session.ReleaseReceivePacket(packet) 162 tun.rate.update(uint64(packetSize)) 163 return packetSize, nil 164 case windows.ERROR_NO_MORE_ITEMS: 165 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 166 windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 167 goto retry 168 } 169 procyield(1) 170 continue 171 case windows.ERROR_HANDLE_EOF: 172 return 0, os.ErrClosed 173 case windows.ERROR_INVALID_DATA: 174 return 0, errors.New("Send ring corrupt") 175 } 176 return 0, fmt.Errorf("Read failed: %w", err) 177 } 178 } 179 180 func (tun *NativeTun) Flush() error { 181 return nil 182 } 183 184 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 185 tun.running.Add(1) 186 defer tun.running.Done() 187 if atomic.LoadInt32(&tun.close) == 1 { 188 return 0, os.ErrClosed 189 } 190 191 packetSize := len(buff) - offset 192 tun.rate.update(uint64(packetSize)) 193 194 packet, err := tun.session.AllocateSendPacket(packetSize) 195 if err == nil { 196 copy(packet, buff[offset:]) 197 tun.session.SendPacket(packet) 198 return packetSize, nil 199 } 200 switch err { 201 case windows.ERROR_HANDLE_EOF: 202 return 0, os.ErrClosed 203 case windows.ERROR_BUFFER_OVERFLOW: 204 return 0, nil // Dropping when ring is full. 205 } 206 return 0, fmt.Errorf("Write failed: %w", err) 207 } 208 209 // LUID returns Windows interface instance ID. 210 func (tun *NativeTun) LUID() uint64 { 211 tun.running.Add(1) 212 defer tun.running.Done() 213 if atomic.LoadInt32(&tun.close) == 1 { 214 return 0 215 } 216 return tun.wt.LUID() 217 } 218 219 // RunningVersion returns the running version of the Wintun driver. 220 func (tun *NativeTun) RunningVersion() (version uint32, err error) { 221 return wintun.RunningVersion() 222 } 223 224 func (rate *rateJuggler) update(packetLen uint64) { 225 now := nanotime() 226 total := atomic.AddUint64(&rate.nextByteCount, packetLen) 227 period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) 228 if period >= rateMeasurementGranularity { 229 if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { 230 return 231 } 232 atomic.StoreInt64(&rate.nextStartTime, now) 233 atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) 234 atomic.StoreUint64(&rate.nextByteCount, 0) 235 atomic.StoreInt32(&rate.changing, 0) 236 } 237 }