github.com/borderzero/water@v0.0.1/syscalls_tun_windows.go (about) 1 package water 2 3 import ( 4 "crypto/rand" 5 _ "embed" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "os" 10 "sync" 11 "sync/atomic" 12 "time" 13 _ "unsafe" 14 15 "golang.org/x/sys/windows" 16 "golang.zx2c4.com/wintun" 17 ) 18 19 const ( 20 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 21 spinloopRateThreshold = 800000000 / 8 // 800mbps 22 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 23 ) 24 25 type Event int 26 27 const ( 28 EventUp = 1 << iota 29 EventDown 30 EventMTUUpdate 31 ) 32 33 type rateJuggler struct { 34 current uint64 35 nextByteCount uint64 36 nextStartTime int64 37 changing int32 38 } 39 40 type NativeTun struct { 41 wt *wintun.Adapter 42 name string 43 handle windows.Handle 44 rate rateJuggler 45 session wintun.Session 46 readWait windows.Handle 47 events chan Event 48 running sync.WaitGroup 49 closeOnce sync.Once 50 close int32 51 forcedMTU int 52 } 53 54 type WTun struct { 55 dev *NativeTun 56 } 57 58 func (w *WTun) Close() error { 59 return w.dev.Close() 60 } 61 62 func (w *WTun) Write(b []byte) (int, error) { 63 return w.dev.Write(b, 0) 64 } 65 66 func (w *WTun) Read(b []byte) (int, error) { 67 return w.dev.Read(b, 0) 68 } 69 70 var ( 71 WintunTunnelType = "Wintun" 72 WintunStaticRequestedGUID *windows.GUID 73 ) 74 75 //go:linkname procyield runtime.procyield 76 func procyield(cycles uint32) 77 78 //go:linkname nanotime runtime.nanotime 79 func nanotime() int64 80 81 func generateRandomGUID() (*windows.GUID, error) { 82 var guid windows.GUID 83 84 // Generate random values for the Data1, Data2, and Data3 fields 85 if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data1); err != nil { 86 return nil, err 87 } 88 if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data2); err != nil { 89 return nil, err 90 } 91 if err := binary.Read(rand.Reader, binary.LittleEndian, &guid.Data3); err != nil { 92 return nil, err 93 } 94 95 // Generate random bytes for the Data4 field 96 if _, err := rand.Read(guid.Data4[:]); err != nil { 97 return nil, err 98 } 99 100 return &guid, nil 101 } 102 103 func openTunDev(config Config) (ifce *Interface, err error) { 104 /* 105 gUID := &windows.GUID{ 106 0x0000000, 107 0xFFFF, 108 0xFFFF, 109 [8]byte{0xFF, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e}, 110 } 111 */ 112 113 // We'll geneerate a random GUID for the Wintun interface 114 // This is to work around some issue in the Wintun driver that causes 115 // it to fail to create an interface as claims it already exists 116 gUID, _ := generateRandomGUID() 117 118 if config.PlatformSpecificParams.Name == "" { 119 config.PlatformSpecificParams.Name = "WaterIface" 120 } 121 nativeTunDevice, err := CreateTUNWithRequestedGUID(config.PlatformSpecificParams.Name, gUID, 0) 122 if err != nil { 123 // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. 124 // Trying a second time resolves the issue. 125 time.Sleep(1 * time.Second) 126 nativeTunDevice, err = CreateTUNWithRequestedGUID(config.PlatformSpecificParams.Name, gUID, 0) 127 if err != nil { 128 return nil, err 129 } 130 } 131 ifce = &Interface{ 132 isTAP: config.DeviceType == TAP, 133 ReadWriteCloser: &WTun{dev: nativeTunDevice}, 134 name: config.PlatformSpecificParams.Name, 135 } 136 return ifce, nil 137 } 138 139 // CreateTUN creates a Wintun interface with the given name. Should a Wintun 140 // interface with the same name exist, it is reused. 141 func CreateTUN(ifname string, mtu int) (*NativeTun, error) { 142 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 143 } 144 145 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 146 // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 147 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (*NativeTun, error) { 148 wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 149 if err != nil { 150 return nil, fmt.Errorf("Error creating interface: %w", err) 151 } 152 153 forcedMTU := 1420 154 if mtu > 0 { 155 forcedMTU = mtu 156 } 157 158 tun := &NativeTun{ 159 wt: wt, 160 name: ifname, 161 handle: windows.InvalidHandle, 162 events: make(chan Event, 10), 163 forcedMTU: forcedMTU, 164 } 165 166 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 167 if err != nil { 168 tun.wt.Close() 169 close(tun.events) 170 return nil, fmt.Errorf("Error starting session: %w", err) 171 } 172 tun.readWait = tun.session.ReadWaitEvent() 173 return tun, nil 174 } 175 176 func (tun *NativeTun) Name() (string, error) { 177 return tun.name, nil 178 } 179 180 func (tun *NativeTun) File() *os.File { 181 return nil 182 } 183 184 func (tun *NativeTun) Events() chan Event { 185 return tun.events 186 } 187 188 func (tun *NativeTun) Close() error { 189 var err error 190 tun.closeOnce.Do(func() { 191 atomic.StoreInt32(&tun.close, 1) 192 windows.SetEvent(tun.readWait) 193 tun.running.Wait() 194 tun.session.End() 195 if tun.wt != nil { 196 tun.wt.Close() 197 } 198 close(tun.events) 199 }) 200 return err 201 } 202 203 func (tun *NativeTun) MTU() (int, error) { 204 return tun.forcedMTU, nil 205 } 206 207 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 208 func (tun *NativeTun) ForceMTU(mtu int) { 209 update := tun.forcedMTU != mtu 210 tun.forcedMTU = mtu 211 if update { 212 tun.events <- EventMTUUpdate 213 } 214 } 215 216 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 217 218 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 219 tun.running.Add(1) 220 defer tun.running.Done() 221 retry: 222 if atomic.LoadInt32(&tun.close) == 1 { 223 return 0, os.ErrClosed 224 } 225 start := nanotime() 226 shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 227 for { 228 if atomic.LoadInt32(&tun.close) == 1 { 229 return 0, os.ErrClosed 230 } 231 packet, err := tun.session.ReceivePacket() 232 switch err { 233 case nil: 234 packetSize := len(packet) 235 copy(buff[offset:], packet) 236 tun.session.ReleaseReceivePacket(packet) 237 tun.rate.update(uint64(packetSize)) 238 return packetSize, nil 239 case windows.ERROR_NO_MORE_ITEMS: 240 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 241 windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 242 goto retry 243 } 244 procyield(1) 245 continue 246 case windows.ERROR_HANDLE_EOF: 247 return 0, os.ErrClosed 248 case windows.ERROR_INVALID_DATA: 249 return 0, errors.New("Send ring corrupt") 250 } 251 return 0, fmt.Errorf("Read failed: %w", err) 252 } 253 } 254 255 func (tun *NativeTun) Flush() error { 256 return nil 257 } 258 259 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 260 tun.running.Add(1) 261 defer tun.running.Done() 262 if atomic.LoadInt32(&tun.close) == 1 { 263 return 0, os.ErrClosed 264 } 265 266 packetSize := len(buff) - offset 267 tun.rate.update(uint64(packetSize)) 268 269 packet, err := tun.session.AllocateSendPacket(packetSize) 270 if err == nil { 271 copy(packet, buff[offset:]) 272 tun.session.SendPacket(packet) 273 return packetSize, nil 274 } 275 switch err { 276 case windows.ERROR_HANDLE_EOF: 277 return 0, os.ErrClosed 278 case windows.ERROR_BUFFER_OVERFLOW: 279 return 0, nil // Dropping when ring is full. 280 } 281 return 0, fmt.Errorf("Write failed: %w", err) 282 } 283 284 // LUID returns Windows interface instance ID. 285 func (tun *NativeTun) LUID() uint64 { 286 tun.running.Add(1) 287 defer tun.running.Done() 288 if atomic.LoadInt32(&tun.close) == 1 { 289 return 0 290 } 291 return tun.wt.LUID() 292 } 293 294 // RunningVersion returns the running version of the Wintun driver. 295 func (tun *NativeTun) RunningVersion() (version uint32, err error) { 296 return wintun.RunningVersion() 297 } 298 299 func (rate *rateJuggler) update(packetLen uint64) { 300 now := nanotime() 301 total := atomic.AddUint64(&rate.nextByteCount, packetLen) 302 period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) 303 if period >= rateMeasurementGranularity { 304 if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { 305 return 306 } 307 atomic.StoreInt64(&rate.nextStartTime, now) 308 atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) 309 atomic.StoreUint64(&rate.nextByteCount, 0) 310 atomic.StoreInt32(&rate.changing, 0) 311 } 312 }