github.com/sagernet/sing-tun@v0.3.0-beta.5/tun_windows.go (about) 1 package tun 2 3 import ( 4 "crypto/md5" 5 "errors" 6 "fmt" 7 "math" 8 "net" 9 "net/netip" 10 "os" 11 "sync" 12 "time" 13 "unsafe" 14 15 "github.com/sagernet/sing-tun/internal/winipcfg" 16 "github.com/sagernet/sing-tun/internal/winsys" 17 "github.com/sagernet/sing-tun/internal/wintun" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/atomic" 20 "github.com/sagernet/sing/common/buf" 21 E "github.com/sagernet/sing/common/exceptions" 22 "github.com/sagernet/sing/common/windnsapi" 23 24 "golang.org/x/sys/windows" 25 ) 26 27 var TunnelType = "sing-tun" 28 29 type NativeTun struct { 30 adapter *wintun.Adapter 31 options Options 32 session wintun.Session 33 readWait windows.Handle 34 rate rateJuggler 35 running sync.WaitGroup 36 closeOnce sync.Once 37 close atomic.Int32 38 fwpmSession uintptr 39 } 40 41 func New(options Options) (WinTun, error) { 42 if options.FileDescriptor != 0 { 43 return nil, os.ErrInvalid 44 } 45 adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name)) 46 if err != nil { 47 return nil, err 48 } 49 nativeTun := &NativeTun{ 50 adapter: adapter, 51 options: options, 52 } 53 session, err := adapter.StartSession(0x800000) 54 if err != nil { 55 return nil, err 56 } 57 nativeTun.session = session 58 nativeTun.readWait = session.ReadWaitEvent() 59 err = nativeTun.configure() 60 if err != nil { 61 session.End() 62 adapter.Close() 63 return nil, err 64 } 65 return nativeTun, nil 66 } 67 68 func (t *NativeTun) configure() error { 69 luid := winipcfg.LUID(t.adapter.LUID()) 70 if len(t.options.Inet4Address) > 0 { 71 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address) 72 if err != nil { 73 return E.Cause(err, "set ipv4 address") 74 } 75 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) 76 if err != nil { 77 return E.Cause(err, "set ipv4 dns") 78 } 79 } 80 if len(t.options.Inet6Address) > 0 { 81 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address) 82 if err != nil { 83 return E.Cause(err, "set ipv6 address") 84 } 85 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) 86 if err != nil { 87 return E.Cause(err, "set ipv6 dns") 88 } 89 } 90 if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { 91 _ = luid.DisableDNSRegistration() 92 } 93 if t.options.AutoRoute { 94 routeRanges, err := t.options.BuildAutoRouteRanges(false) 95 if err != nil { 96 return err 97 } 98 for _, routeRange := range routeRanges { 99 if routeRange.Addr().Is4() { 100 err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0) 101 } else { 102 err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0) 103 } 104 } 105 err = windnsapi.FlushResolverCache() 106 if err != nil { 107 return err 108 } 109 } 110 if len(t.options.Inet4Address) > 0 { 111 inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) 112 if err != nil { 113 return err 114 } 115 inetIf.ForwardingEnabled = true 116 inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 117 inetIf.DadTransmits = 0 118 inetIf.ManagedAddressConfigurationSupported = false 119 inetIf.OtherStatefulConfigurationSupported = false 120 inetIf.NLMTU = t.options.MTU 121 if t.options.AutoRoute { 122 inetIf.UseAutomaticMetric = false 123 inetIf.Metric = 0 124 } 125 err = inetIf.Set() 126 if err != nil { 127 return E.Cause(err, "set ipv4 options") 128 } 129 } 130 if len(t.options.Inet6Address) > 0 { 131 inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) 132 if err != nil { 133 return err 134 } 135 inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 136 inet6If.DadTransmits = 0 137 inet6If.ManagedAddressConfigurationSupported = false 138 inet6If.OtherStatefulConfigurationSupported = false 139 inet6If.NLMTU = t.options.MTU 140 if t.options.AutoRoute { 141 inet6If.UseAutomaticMetric = false 142 inet6If.Metric = 0 143 } 144 err = inet6If.Set() 145 if err != nil { 146 return E.Cause(err, "set ipv6 options") 147 } 148 } 149 150 if t.options.AutoRoute && t.options.StrictRoute { 151 var engine uintptr 152 session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} 153 err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) 154 if err != nil { 155 return os.NewSyscallError("FwpmEngineOpen0", err) 156 } 157 t.fwpmSession = engine 158 159 subLayerKey, err := windows.GenerateGUID() 160 if err != nil { 161 return os.NewSyscallError("CoCreateGuid", err) 162 } 163 164 subLayer := winsys.FWPM_SUBLAYER0{} 165 subLayer.SubLayerKey = subLayerKey 166 subLayer.DisplayData = winsys.CreateDisplayData(TunnelType, "auto-route rules") 167 subLayer.Weight = math.MaxUint16 168 err = winsys.FwpmSubLayerAdd0(engine, &subLayer, 0) 169 if err != nil { 170 return os.NewSyscallError("FwpmSubLayerAdd0", err) 171 } 172 173 processAppID, err := winsys.GetCurrentProcessAppID() 174 if err != nil { 175 return err 176 } 177 defer winsys.FwpmFreeMemory0(unsafe.Pointer(&processAppID)) 178 179 var filterId uint64 180 permitCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 181 permitCondition[0].FieldKey = winsys.FWPM_CONDITION_ALE_APP_ID 182 permitCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 183 permitCondition[0].ConditionValue.Type = winsys.FWP_BYTE_BLOB_TYPE 184 permitCondition[0].ConditionValue.Value = uintptr(unsafe.Pointer(processAppID)) 185 186 permitFilter4 := winsys.FWPM_FILTER0{} 187 permitFilter4.FilterCondition = &permitCondition[0] 188 permitFilter4.NumFilterConditions = 1 189 permitFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv4") 190 permitFilter4.SubLayerKey = subLayerKey 191 permitFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 192 permitFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 193 permitFilter4.Weight.Type = winsys.FWP_UINT8 194 permitFilter4.Weight.Value = uintptr(13) 195 permitFilter4.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 196 err = winsys.FwpmFilterAdd0(engine, &permitFilter4, 0, &filterId) 197 if err != nil { 198 return os.NewSyscallError("FwpmFilterAdd0", err) 199 } 200 201 permitFilter6 := winsys.FWPM_FILTER0{} 202 permitFilter6.FilterCondition = &permitCondition[0] 203 permitFilter6.NumFilterConditions = 1 204 permitFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv6") 205 permitFilter6.SubLayerKey = subLayerKey 206 permitFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 207 permitFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 208 permitFilter6.Weight.Type = winsys.FWP_UINT8 209 permitFilter6.Weight.Value = uintptr(13) 210 permitFilter6.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 211 err = winsys.FwpmFilterAdd0(engine, &permitFilter6, 0, &filterId) 212 if err != nil { 213 return os.NewSyscallError("FwpmFilterAdd0", err) 214 } 215 216 /*if len(t.options.Inet4Address) == 0 { 217 blockFilter := winsys.FWPM_FILTER0{} 218 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4") 219 blockFilter.SubLayerKey = subLayerKey 220 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 221 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 222 blockFilter.Weight.Type = winsys.FWP_UINT8 223 blockFilter.Weight.Value = uintptr(12) 224 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 225 if err != nil { 226 return os.NewSyscallError("FwpmFilterAdd0", err) 227 } 228 }*/ 229 230 if len(t.options.Inet6Address) == 0 { 231 blockFilter := winsys.FWPM_FILTER0{} 232 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6") 233 blockFilter.SubLayerKey = subLayerKey 234 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 235 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 236 blockFilter.Weight.Type = winsys.FWP_UINT8 237 blockFilter.Weight.Value = uintptr(12) 238 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 239 if err != nil { 240 return os.NewSyscallError("FwpmFilterAdd0", err) 241 } 242 } 243 244 netInterface, err := net.InterfaceByName(t.options.Name) 245 if err != nil { 246 return err 247 } 248 249 tunCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 250 tunCondition[0].FieldKey = winsys.FWPM_CONDITION_LOCAL_INTERFACE_INDEX 251 tunCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 252 tunCondition[0].ConditionValue.Type = winsys.FWP_UINT32 253 tunCondition[0].ConditionValue.Value = uintptr(uint32(netInterface.Index)) 254 255 if len(t.options.Inet4Address) > 0 { 256 tunFilter4 := winsys.FWPM_FILTER0{} 257 tunFilter4.FilterCondition = &tunCondition[0] 258 tunFilter4.NumFilterConditions = 1 259 tunFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv4") 260 tunFilter4.SubLayerKey = subLayerKey 261 tunFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 262 tunFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 263 tunFilter4.Weight.Type = winsys.FWP_UINT8 264 tunFilter4.Weight.Value = uintptr(11) 265 err = winsys.FwpmFilterAdd0(engine, &tunFilter4, 0, &filterId) 266 if err != nil { 267 return os.NewSyscallError("FwpmFilterAdd0", err) 268 } 269 } 270 271 if len(t.options.Inet6Address) > 0 { 272 tunFilter6 := winsys.FWPM_FILTER0{} 273 tunFilter6.FilterCondition = &tunCondition[0] 274 tunFilter6.NumFilterConditions = 1 275 tunFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv6") 276 tunFilter6.SubLayerKey = subLayerKey 277 tunFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 278 tunFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 279 tunFilter6.Weight.Type = winsys.FWP_UINT8 280 tunFilter6.Weight.Value = uintptr(11) 281 err = winsys.FwpmFilterAdd0(engine, &tunFilter6, 0, &filterId) 282 if err != nil { 283 return os.NewSyscallError("FwpmFilterAdd0", err) 284 } 285 } 286 287 blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2) 288 blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL 289 blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 290 blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8 291 blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP)) 292 blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT 293 blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL 294 blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16 295 blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53)) 296 297 blockDNSFilter4 := winsys.FWPM_FILTER0{} 298 blockDNSFilter4.FilterCondition = &blockDNSCondition[0] 299 blockDNSFilter4.NumFilterConditions = 2 300 blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") 301 blockDNSFilter4.SubLayerKey = subLayerKey 302 blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 303 blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK 304 blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 305 blockDNSFilter4.Weight.Value = uintptr(10) 306 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) 307 if err != nil { 308 return os.NewSyscallError("FwpmFilterAdd0", err) 309 } 310 311 blockDNSFilter6 := winsys.FWPM_FILTER0{} 312 blockDNSFilter6.FilterCondition = &blockDNSCondition[0] 313 blockDNSFilter6.NumFilterConditions = 2 314 blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") 315 blockDNSFilter6.SubLayerKey = subLayerKey 316 blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 317 blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK 318 blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 319 blockDNSFilter6.Weight.Value = uintptr(10) 320 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) 321 if err != nil { 322 return os.NewSyscallError("FwpmFilterAdd0", err) 323 } 324 } 325 326 return nil 327 } 328 329 func (t *NativeTun) Read(p []byte) (n int, err error) { 330 return 0, os.ErrInvalid 331 } 332 333 func (t *NativeTun) ReadPacket() ([]byte, func(), error) { 334 t.running.Add(1) 335 defer t.running.Done() 336 retry: 337 if t.close.Load() == 1 { 338 return nil, nil, os.ErrClosed 339 } 340 start := nanotime() 341 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 342 for { 343 if t.close.Load() == 1 { 344 return nil, nil, os.ErrClosed 345 } 346 packet, err := t.session.ReceivePacket() 347 switch err { 348 case nil: 349 packetSize := len(packet) 350 t.rate.update(uint64(packetSize)) 351 return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil 352 case windows.ERROR_NO_MORE_ITEMS: 353 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 354 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 355 goto retry 356 } 357 procyield(1) 358 continue 359 case windows.ERROR_HANDLE_EOF: 360 return nil, nil, os.ErrClosed 361 case windows.ERROR_INVALID_DATA: 362 return nil, nil, errors.New("send ring corrupt") 363 } 364 return nil, nil, fmt.Errorf("read failed: %w", err) 365 } 366 } 367 368 func (t *NativeTun) ReadFunc(block func(b []byte)) error { 369 t.running.Add(1) 370 defer t.running.Done() 371 retry: 372 if t.close.Load() == 1 { 373 return os.ErrClosed 374 } 375 start := nanotime() 376 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 377 for { 378 if t.close.Load() == 1 { 379 return os.ErrClosed 380 } 381 packet, err := t.session.ReceivePacket() 382 switch err { 383 case nil: 384 packetSize := len(packet) 385 block(packet) 386 t.session.ReleaseReceivePacket(packet) 387 t.rate.update(uint64(packetSize)) 388 return nil 389 case windows.ERROR_NO_MORE_ITEMS: 390 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 391 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 392 goto retry 393 } 394 procyield(1) 395 continue 396 case windows.ERROR_HANDLE_EOF: 397 return os.ErrClosed 398 case windows.ERROR_INVALID_DATA: 399 return errors.New("send ring corrupt") 400 } 401 return fmt.Errorf("read failed: %w", err) 402 } 403 } 404 405 func (t *NativeTun) Write(p []byte) (n int, err error) { 406 t.running.Add(1) 407 defer t.running.Done() 408 if t.close.Load() == 1 { 409 return 0, os.ErrClosed 410 } 411 t.rate.update(uint64(len(p))) 412 packet, err := t.session.AllocateSendPacket(len(p)) 413 copy(packet, p) 414 if err == nil { 415 t.session.SendPacket(packet) 416 return len(p), nil 417 } 418 switch err { 419 case windows.ERROR_HANDLE_EOF: 420 return 0, os.ErrClosed 421 case windows.ERROR_BUFFER_OVERFLOW: 422 return 0, nil // Dropping when ring is full. 423 } 424 return 0, fmt.Errorf("write failed: %w", err) 425 } 426 427 func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { 428 t.running.Add(1) 429 defer t.running.Done() 430 if t.close.Load() == 1 { 431 return 0, os.ErrClosed 432 } 433 var packetSize int 434 for _, packetElement := range packetElementList { 435 packetSize += len(packetElement) 436 } 437 t.rate.update(uint64(packetSize)) 438 packet, err := t.session.AllocateSendPacket(packetSize) 439 if err == nil { 440 var index int 441 for _, packetElement := range packetElementList { 442 index += copy(packet[index:], packetElement) 443 } 444 t.session.SendPacket(packet) 445 return 446 } 447 switch err { 448 case windows.ERROR_HANDLE_EOF: 449 return 0, os.ErrClosed 450 case windows.ERROR_BUFFER_OVERFLOW: 451 return 0, nil // Dropping when ring is full. 452 } 453 return 0, fmt.Errorf("write failed: %w", err) 454 } 455 456 func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { 457 defer buf.ReleaseMulti(buffers) 458 return common.Error(t.write(buf.ToSliceMulti(buffers))) 459 } 460 461 func (t *NativeTun) Close() error { 462 var err error 463 t.closeOnce.Do(func() { 464 t.close.Store(1) 465 windows.SetEvent(t.readWait) 466 t.running.Wait() 467 t.session.End() 468 t.adapter.Close() 469 if t.fwpmSession != 0 { 470 winsys.FwpmEngineClose0(t.fwpmSession) 471 } 472 if t.options.AutoRoute { 473 windnsapi.FlushResolverCache() 474 } 475 }) 476 return err 477 } 478 479 func generateGUIDByDeviceName(name string) *windows.GUID { 480 hash := md5.New() 481 hash.Write([]byte("wintun")) 482 hash.Write([]byte(name)) 483 sum := hash.Sum(nil) 484 return (*windows.GUID)(unsafe.Pointer(&sum[0])) 485 } 486 487 //go:linkname procyield runtime.procyield 488 func procyield(cycles uint32) 489 490 //go:linkname nanotime runtime.nanotime 491 func nanotime() int64 492 493 type rateJuggler struct { 494 current atomic.Uint64 495 nextByteCount atomic.Uint64 496 nextStartTime atomic.Int64 497 changing atomic.Int32 498 } 499 500 func (rate *rateJuggler) update(packetLen uint64) { 501 now := nanotime() 502 total := rate.nextByteCount.Add(packetLen) 503 period := uint64(now - rate.nextStartTime.Load()) 504 if period >= rateMeasurementGranularity { 505 if !rate.changing.CompareAndSwap(0, 1) { 506 return 507 } 508 rate.nextStartTime.Store(now) 509 rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 510 rate.nextByteCount.Store(0) 511 rate.changing.Store(0) 512 } 513 } 514 515 const ( 516 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 517 spinloopRateThreshold = 800000000 / 8 // 800mbps 518 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 519 )