github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/tun_windows.go (about) 1 package tun 2 3 import ( 4 "crypto/md5" 5 "errors" 6 "fmt" 7 "math" 8 "net" 9 "net/netip" 10 "os" 11 "sync" 12 "time" 13 "unsafe" 14 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/atomic" 17 "github.com/sagernet/sing/common/buf" 18 E "github.com/sagernet/sing/common/exceptions" 19 "github.com/sagernet/sing/common/windnsapi" 20 21 "github.com/metacubex/sing-tun/internal/winipcfg" 22 "github.com/metacubex/sing-tun/internal/winsys" 23 "github.com/metacubex/sing-tun/internal/wintun" 24 25 "golang.org/x/sys/windows" 26 ) 27 28 var TunnelType = "sing-tun" 29 30 type NativeTun struct { 31 adapter *wintun.Adapter 32 options Options 33 session wintun.Session 34 readWait windows.Handle 35 rate rateJuggler 36 running sync.WaitGroup 37 closeOnce sync.Once 38 close atomic.Int32 39 fwpmSession uintptr 40 } 41 42 func New(options Options) (WinTun, error) { 43 if options.FileDescriptor != 0 { 44 return nil, os.ErrInvalid 45 } 46 adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name)) 47 if err != nil { 48 return nil, err 49 } 50 nativeTun := &NativeTun{ 51 adapter: adapter, 52 options: options, 53 } 54 session, err := adapter.StartSession(0x800000) 55 if err != nil { 56 return nil, err 57 } 58 nativeTun.session = session 59 nativeTun.readWait = session.ReadWaitEvent() 60 err = nativeTun.configure() 61 if err != nil { 62 session.End() 63 adapter.Close() 64 return nil, err 65 } 66 return nativeTun, nil 67 } 68 69 func (t *NativeTun) configure() error { 70 luid := winipcfg.LUID(t.adapter.LUID()) 71 if len(t.options.Inet4Address) > 0 { 72 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address) 73 if err != nil { 74 return E.Cause(err, "set ipv4 address") 75 } 76 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) 77 if err != nil { 78 return E.Cause(err, "set ipv4 dns") 79 } 80 } 81 if len(t.options.Inet6Address) > 0 { 82 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address) 83 if err != nil { 84 return E.Cause(err, "set ipv6 address") 85 } 86 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) 87 if err != nil { 88 return E.Cause(err, "set ipv6 dns") 89 } 90 } 91 if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { 92 _ = luid.DisableDNSRegistration() 93 } 94 if t.options.AutoRoute { 95 routeRanges, err := t.options.BuildAutoRouteRanges(false) 96 if err != nil { 97 return err 98 } 99 for _, routeRange := range routeRanges { 100 if routeRange.Addr().Is4() { 101 err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0) 102 } else { 103 err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0) 104 } 105 } 106 err = windnsapi.FlushResolverCache() 107 if err != nil { 108 return err 109 } 110 } 111 if len(t.options.Inet4Address) > 0 { 112 inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) 113 if err != nil { 114 return err 115 } 116 inetIf.ForwardingEnabled = true 117 inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 118 inetIf.DadTransmits = 0 119 inetIf.ManagedAddressConfigurationSupported = false 120 inetIf.OtherStatefulConfigurationSupported = false 121 inetIf.NLMTU = t.options.MTU 122 if t.options.AutoRoute { 123 inetIf.UseAutomaticMetric = false 124 inetIf.Metric = 0 125 } 126 err = inetIf.Set() 127 if err != nil { 128 return E.Cause(err, "set ipv4 options") 129 } 130 } 131 if len(t.options.Inet6Address) > 0 { 132 inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) 133 if err != nil { 134 return err 135 } 136 inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 137 inet6If.DadTransmits = 0 138 inet6If.ManagedAddressConfigurationSupported = false 139 inet6If.OtherStatefulConfigurationSupported = false 140 inet6If.NLMTU = t.options.MTU 141 if t.options.AutoRoute { 142 inet6If.UseAutomaticMetric = false 143 inet6If.Metric = 0 144 } 145 err = inet6If.Set() 146 if err != nil { 147 return E.Cause(err, "set ipv6 options") 148 } 149 } 150 151 if t.options.AutoRoute && t.options.StrictRoute { 152 var engine uintptr 153 session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} 154 err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) 155 if err != nil { 156 return os.NewSyscallError("FwpmEngineOpen0", err) 157 } 158 t.fwpmSession = engine 159 160 subLayerKey, err := windows.GenerateGUID() 161 if err != nil { 162 return os.NewSyscallError("CoCreateGuid", err) 163 } 164 165 subLayer := winsys.FWPM_SUBLAYER0{} 166 subLayer.SubLayerKey = subLayerKey 167 subLayer.DisplayData = winsys.CreateDisplayData(TunnelType, "auto-route rules") 168 subLayer.Weight = math.MaxUint16 169 err = winsys.FwpmSubLayerAdd0(engine, &subLayer, 0) 170 if err != nil { 171 return os.NewSyscallError("FwpmSubLayerAdd0", err) 172 } 173 174 processAppID, err := winsys.GetCurrentProcessAppID() 175 if err != nil { 176 return err 177 } 178 defer winsys.FwpmFreeMemory0(unsafe.Pointer(&processAppID)) 179 180 var filterId uint64 181 permitCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 182 permitCondition[0].FieldKey = winsys.FWPM_CONDITION_ALE_APP_ID 183 permitCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 184 permitCondition[0].ConditionValue.Type = winsys.FWP_BYTE_BLOB_TYPE 185 permitCondition[0].ConditionValue.Value = uintptr(unsafe.Pointer(processAppID)) 186 187 permitFilter4 := winsys.FWPM_FILTER0{} 188 permitFilter4.FilterCondition = &permitCondition[0] 189 permitFilter4.NumFilterConditions = 1 190 permitFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv4") 191 permitFilter4.SubLayerKey = subLayerKey 192 permitFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 193 permitFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 194 permitFilter4.Weight.Type = winsys.FWP_UINT8 195 permitFilter4.Weight.Value = uintptr(13) 196 permitFilter4.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 197 err = winsys.FwpmFilterAdd0(engine, &permitFilter4, 0, &filterId) 198 if err != nil { 199 return os.NewSyscallError("FwpmFilterAdd0", err) 200 } 201 202 permitFilter6 := winsys.FWPM_FILTER0{} 203 permitFilter6.FilterCondition = &permitCondition[0] 204 permitFilter6.NumFilterConditions = 1 205 permitFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv6") 206 permitFilter6.SubLayerKey = subLayerKey 207 permitFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 208 permitFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 209 permitFilter6.Weight.Type = winsys.FWP_UINT8 210 permitFilter6.Weight.Value = uintptr(13) 211 permitFilter6.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 212 err = winsys.FwpmFilterAdd0(engine, &permitFilter6, 0, &filterId) 213 if err != nil { 214 return os.NewSyscallError("FwpmFilterAdd0", err) 215 } 216 217 /*if len(t.options.Inet4Address) == 0 { 218 blockFilter := winsys.FWPM_FILTER0{} 219 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4") 220 blockFilter.SubLayerKey = subLayerKey 221 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 222 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 223 blockFilter.Weight.Type = winsys.FWP_UINT8 224 blockFilter.Weight.Value = uintptr(12) 225 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 226 if err != nil { 227 return os.NewSyscallError("FwpmFilterAdd0", err) 228 } 229 }*/ 230 231 if len(t.options.Inet6Address) == 0 { 232 blockFilter := winsys.FWPM_FILTER0{} 233 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6") 234 blockFilter.SubLayerKey = subLayerKey 235 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 236 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 237 blockFilter.Weight.Type = winsys.FWP_UINT8 238 blockFilter.Weight.Value = uintptr(12) 239 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 240 if err != nil { 241 return os.NewSyscallError("FwpmFilterAdd0", err) 242 } 243 } 244 245 netInterface, err := net.InterfaceByName(t.options.Name) 246 if err != nil { 247 return err 248 } 249 250 tunCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 251 tunCondition[0].FieldKey = winsys.FWPM_CONDITION_LOCAL_INTERFACE_INDEX 252 tunCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 253 tunCondition[0].ConditionValue.Type = winsys.FWP_UINT32 254 tunCondition[0].ConditionValue.Value = uintptr(uint32(netInterface.Index)) 255 256 if len(t.options.Inet4Address) > 0 { 257 tunFilter4 := winsys.FWPM_FILTER0{} 258 tunFilter4.FilterCondition = &tunCondition[0] 259 tunFilter4.NumFilterConditions = 1 260 tunFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv4") 261 tunFilter4.SubLayerKey = subLayerKey 262 tunFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 263 tunFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 264 tunFilter4.Weight.Type = winsys.FWP_UINT8 265 tunFilter4.Weight.Value = uintptr(11) 266 err = winsys.FwpmFilterAdd0(engine, &tunFilter4, 0, &filterId) 267 if err != nil { 268 return os.NewSyscallError("FwpmFilterAdd0", err) 269 } 270 } 271 272 if len(t.options.Inet6Address) > 0 { 273 tunFilter6 := winsys.FWPM_FILTER0{} 274 tunFilter6.FilterCondition = &tunCondition[0] 275 tunFilter6.NumFilterConditions = 1 276 tunFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv6") 277 tunFilter6.SubLayerKey = subLayerKey 278 tunFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 279 tunFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 280 tunFilter6.Weight.Type = winsys.FWP_UINT8 281 tunFilter6.Weight.Value = uintptr(11) 282 err = winsys.FwpmFilterAdd0(engine, &tunFilter6, 0, &filterId) 283 if err != nil { 284 return os.NewSyscallError("FwpmFilterAdd0", err) 285 } 286 } 287 288 blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2) 289 blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL 290 blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 291 blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8 292 blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP)) 293 blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT 294 blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL 295 blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16 296 blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53)) 297 298 blockDNSFilter4 := winsys.FWPM_FILTER0{} 299 blockDNSFilter4.FilterCondition = &blockDNSCondition[0] 300 blockDNSFilter4.NumFilterConditions = 2 301 blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") 302 blockDNSFilter4.SubLayerKey = subLayerKey 303 blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 304 blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK 305 blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 306 blockDNSFilter4.Weight.Value = uintptr(10) 307 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) 308 if err != nil { 309 return os.NewSyscallError("FwpmFilterAdd0", err) 310 } 311 312 blockDNSFilter6 := winsys.FWPM_FILTER0{} 313 blockDNSFilter6.FilterCondition = &blockDNSCondition[0] 314 blockDNSFilter6.NumFilterConditions = 2 315 blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") 316 blockDNSFilter6.SubLayerKey = subLayerKey 317 blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 318 blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK 319 blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 320 blockDNSFilter6.Weight.Value = uintptr(10) 321 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) 322 if err != nil { 323 return os.NewSyscallError("FwpmFilterAdd0", err) 324 } 325 } 326 327 return nil 328 } 329 330 func (t *NativeTun) Read(p []byte) (n int, err error) { 331 return 0, os.ErrInvalid 332 } 333 334 func (t *NativeTun) ReadPacket() ([]byte, func(), error) { 335 t.running.Add(1) 336 defer t.running.Done() 337 retry: 338 if t.close.Load() == 1 { 339 return nil, nil, os.ErrClosed 340 } 341 start := nanotime() 342 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 343 for { 344 if t.close.Load() == 1 { 345 return nil, nil, os.ErrClosed 346 } 347 packet, err := t.session.ReceivePacket() 348 switch err { 349 case nil: 350 packetSize := len(packet) 351 t.rate.update(uint64(packetSize)) 352 return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil 353 case windows.ERROR_NO_MORE_ITEMS: 354 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 355 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 356 goto retry 357 } 358 procyield(1) 359 continue 360 case windows.ERROR_HANDLE_EOF: 361 return nil, nil, os.ErrClosed 362 case windows.ERROR_INVALID_DATA: 363 return nil, nil, errors.New("send ring corrupt") 364 } 365 return nil, nil, fmt.Errorf("read failed: %w", err) 366 } 367 } 368 369 func (t *NativeTun) ReadFunc(block func(b []byte)) error { 370 t.running.Add(1) 371 defer t.running.Done() 372 retry: 373 if t.close.Load() == 1 { 374 return os.ErrClosed 375 } 376 start := nanotime() 377 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 378 for { 379 if t.close.Load() == 1 { 380 return os.ErrClosed 381 } 382 packet, err := t.session.ReceivePacket() 383 switch err { 384 case nil: 385 packetSize := len(packet) 386 block(packet) 387 t.session.ReleaseReceivePacket(packet) 388 t.rate.update(uint64(packetSize)) 389 return nil 390 case windows.ERROR_NO_MORE_ITEMS: 391 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 392 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 393 goto retry 394 } 395 procyield(1) 396 continue 397 case windows.ERROR_HANDLE_EOF: 398 return os.ErrClosed 399 case windows.ERROR_INVALID_DATA: 400 return errors.New("send ring corrupt") 401 } 402 return fmt.Errorf("read failed: %w", err) 403 } 404 } 405 406 func (t *NativeTun) Write(p []byte) (n int, err error) { 407 t.running.Add(1) 408 defer t.running.Done() 409 if t.close.Load() == 1 { 410 return 0, os.ErrClosed 411 } 412 t.rate.update(uint64(len(p))) 413 packet, err := t.session.AllocateSendPacket(len(p)) 414 copy(packet, p) 415 if err == nil { 416 t.session.SendPacket(packet) 417 return len(p), nil 418 } 419 switch err { 420 case windows.ERROR_HANDLE_EOF: 421 return 0, os.ErrClosed 422 case windows.ERROR_BUFFER_OVERFLOW: 423 return 0, nil // Dropping when ring is full. 424 } 425 return 0, fmt.Errorf("write failed: %w", err) 426 } 427 428 func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { 429 t.running.Add(1) 430 defer t.running.Done() 431 if t.close.Load() == 1 { 432 return 0, os.ErrClosed 433 } 434 var packetSize int 435 for _, packetElement := range packetElementList { 436 packetSize += len(packetElement) 437 } 438 t.rate.update(uint64(packetSize)) 439 packet, err := t.session.AllocateSendPacket(packetSize) 440 if err == nil { 441 var index int 442 for _, packetElement := range packetElementList { 443 index += copy(packet[index:], packetElement) 444 } 445 t.session.SendPacket(packet) 446 return 447 } 448 switch err { 449 case windows.ERROR_HANDLE_EOF: 450 return 0, os.ErrClosed 451 case windows.ERROR_BUFFER_OVERFLOW: 452 return 0, nil // Dropping when ring is full. 453 } 454 return 0, fmt.Errorf("write failed: %w", err) 455 } 456 457 func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { 458 defer buf.ReleaseMulti(buffers) 459 return common.Error(t.write(buf.ToSliceMulti(buffers))) 460 } 461 462 func (t *NativeTun) Close() error { 463 var err error 464 t.closeOnce.Do(func() { 465 t.close.Store(1) 466 windows.SetEvent(t.readWait) 467 t.running.Wait() 468 t.session.End() 469 t.adapter.Close() 470 if t.fwpmSession != 0 { 471 winsys.FwpmEngineClose0(t.fwpmSession) 472 } 473 if t.options.AutoRoute { 474 windnsapi.FlushResolverCache() 475 } 476 }) 477 return err 478 } 479 480 func generateGUIDByDeviceName(name string) *windows.GUID { 481 hash := md5.New() 482 hash.Write([]byte("wintun")) 483 hash.Write([]byte(name)) 484 sum := hash.Sum(nil) 485 return (*windows.GUID)(unsafe.Pointer(&sum[0])) 486 } 487 488 //go:linkname procyield runtime.procyield 489 func procyield(cycles uint32) 490 491 //go:linkname nanotime runtime.nanotime 492 func nanotime() int64 493 494 type rateJuggler struct { 495 current atomic.Uint64 496 nextByteCount atomic.Uint64 497 nextStartTime atomic.Int64 498 changing atomic.Int32 499 } 500 501 func (rate *rateJuggler) update(packetLen uint64) { 502 now := nanotime() 503 total := rate.nextByteCount.Add(packetLen) 504 period := uint64(now - rate.nextStartTime.Load()) 505 if period >= rateMeasurementGranularity { 506 if !rate.changing.CompareAndSwap(0, 1) { 507 return 508 } 509 rate.nextStartTime.Store(now) 510 rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 511 rate.nextByteCount.Store(0) 512 rate.changing.Store(0) 513 } 514 } 515 516 const ( 517 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 518 spinloopRateThreshold = 800000000 / 8 // 800mbps 519 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 520 )