github.com/apernet/sing-tun@v0.2.6-0.20240323130332-b9f6511036ad/tun_windows.go (about) 1 package tun 2 3 import ( 4 "crypto/md5" 5 "errors" 6 "fmt" 7 "math" 8 "net" 9 "net/netip" 10 "os" 11 "sync" 12 "time" 13 "unsafe" 14 15 "github.com/apernet/sing-tun/internal/winipcfg" 16 "github.com/apernet/sing-tun/internal/winsys" 17 "github.com/apernet/sing-tun/internal/wintun" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/atomic" 20 "github.com/sagernet/sing/common/buf" 21 E "github.com/sagernet/sing/common/exceptions" 22 "github.com/sagernet/sing/common/windnsapi" 23 24 "golang.org/x/sys/windows" 25 ) 26 27 var TunnelType = "sing-tun" 28 29 type NativeTun struct { 30 adapter *wintun.Adapter 31 options Options 32 session wintun.Session 33 readWait windows.Handle 34 rate rateJuggler 35 running sync.WaitGroup 36 closeOnce sync.Once 37 close atomic.Int32 38 fwpmSession uintptr 39 } 40 41 func New(options Options) (WinTun, error) { 42 if options.FileDescriptor != 0 { 43 return nil, os.ErrInvalid 44 } 45 adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name)) 46 if err != nil { 47 return nil, err 48 } 49 nativeTun := &NativeTun{ 50 adapter: adapter, 51 options: options, 52 } 53 session, err := adapter.StartSession(0x800000) 54 if err != nil { 55 return nil, err 56 } 57 nativeTun.session = session 58 nativeTun.readWait = session.ReadWaitEvent() 59 err = nativeTun.configure() 60 if err != nil { 61 session.End() 62 adapter.Close() 63 return nil, err 64 } 65 return nativeTun, nil 66 } 67 68 func (t *NativeTun) configure() error { 69 luid := winipcfg.LUID(t.adapter.LUID()) 70 if len(t.options.Inet4Address) > 0 { 71 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address) 72 if err != nil { 73 return E.Cause(err, "set ipv4 address") 74 } 75 } 76 if len(t.options.Inet6Address) > 0 { 77 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address) 78 if err != nil { 79 return E.Cause(err, "set ipv6 address") 80 } 81 } 82 if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { 83 _ = luid.DisableDNSRegistration() 84 } 85 if t.options.AutoRoute { 86 routeRanges, err := t.options.BuildAutoRouteRanges(false) 87 if err != nil { 88 return err 89 } 90 for _, routeRange := range routeRanges { 91 if routeRange.Addr().Is4() { 92 err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0) 93 } else { 94 err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0) 95 } 96 } 97 err = windnsapi.FlushResolverCache() 98 if err != nil { 99 return err 100 } 101 } 102 if len(t.options.Inet4Address) > 0 { 103 inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) 104 if err != nil { 105 return err 106 } 107 inetIf.ForwardingEnabled = true 108 inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 109 inetIf.DadTransmits = 0 110 inetIf.ManagedAddressConfigurationSupported = false 111 inetIf.OtherStatefulConfigurationSupported = false 112 inetIf.NLMTU = t.options.MTU 113 if t.options.AutoRoute { 114 inetIf.UseAutomaticMetric = false 115 inetIf.Metric = 0 116 } 117 err = inetIf.Set() 118 if err != nil { 119 return E.Cause(err, "set ipv4 options") 120 } 121 } 122 if len(t.options.Inet6Address) > 0 { 123 inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) 124 if err != nil { 125 return err 126 } 127 inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 128 inet6If.DadTransmits = 0 129 inet6If.ManagedAddressConfigurationSupported = false 130 inet6If.OtherStatefulConfigurationSupported = false 131 inet6If.NLMTU = t.options.MTU 132 if t.options.AutoRoute { 133 inet6If.UseAutomaticMetric = false 134 inet6If.Metric = 0 135 } 136 err = inet6If.Set() 137 if err != nil { 138 return E.Cause(err, "set ipv6 options") 139 } 140 } 141 142 if t.options.AutoRoute && t.options.StrictRoute { 143 var engine uintptr 144 session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} 145 err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) 146 if err != nil { 147 return os.NewSyscallError("FwpmEngineOpen0", err) 148 } 149 t.fwpmSession = engine 150 151 subLayerKey, err := windows.GenerateGUID() 152 if err != nil { 153 return os.NewSyscallError("CoCreateGuid", err) 154 } 155 156 subLayer := winsys.FWPM_SUBLAYER0{} 157 subLayer.SubLayerKey = subLayerKey 158 subLayer.DisplayData = winsys.CreateDisplayData(TunnelType, "auto-route rules") 159 subLayer.Weight = math.MaxUint16 160 err = winsys.FwpmSubLayerAdd0(engine, &subLayer, 0) 161 if err != nil { 162 return os.NewSyscallError("FwpmSubLayerAdd0", err) 163 } 164 165 processAppID, err := winsys.GetCurrentProcessAppID() 166 if err != nil { 167 return err 168 } 169 defer winsys.FwpmFreeMemory0(unsafe.Pointer(&processAppID)) 170 171 var filterId uint64 172 permitCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 173 permitCondition[0].FieldKey = winsys.FWPM_CONDITION_ALE_APP_ID 174 permitCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 175 permitCondition[0].ConditionValue.Type = winsys.FWP_BYTE_BLOB_TYPE 176 permitCondition[0].ConditionValue.Value = uintptr(unsafe.Pointer(processAppID)) 177 178 permitFilter4 := winsys.FWPM_FILTER0{} 179 permitFilter4.FilterCondition = &permitCondition[0] 180 permitFilter4.NumFilterConditions = 1 181 permitFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv4") 182 permitFilter4.SubLayerKey = subLayerKey 183 permitFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 184 permitFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 185 permitFilter4.Weight.Type = winsys.FWP_UINT8 186 permitFilter4.Weight.Value = uintptr(13) 187 permitFilter4.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 188 err = winsys.FwpmFilterAdd0(engine, &permitFilter4, 0, &filterId) 189 if err != nil { 190 return os.NewSyscallError("FwpmFilterAdd0", err) 191 } 192 193 permitFilter6 := winsys.FWPM_FILTER0{} 194 permitFilter6.FilterCondition = &permitCondition[0] 195 permitFilter6.NumFilterConditions = 1 196 permitFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv6") 197 permitFilter6.SubLayerKey = subLayerKey 198 permitFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 199 permitFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 200 permitFilter6.Weight.Type = winsys.FWP_UINT8 201 permitFilter6.Weight.Value = uintptr(13) 202 permitFilter6.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 203 err = winsys.FwpmFilterAdd0(engine, &permitFilter6, 0, &filterId) 204 if err != nil { 205 return os.NewSyscallError("FwpmFilterAdd0", err) 206 } 207 208 /*if len(t.options.Inet4Address) == 0 { 209 blockFilter := winsys.FWPM_FILTER0{} 210 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4") 211 blockFilter.SubLayerKey = subLayerKey 212 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 213 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 214 blockFilter.Weight.Type = winsys.FWP_UINT8 215 blockFilter.Weight.Value = uintptr(12) 216 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 217 if err != nil { 218 return os.NewSyscallError("FwpmFilterAdd0", err) 219 } 220 }*/ 221 222 if len(t.options.Inet6Address) == 0 { 223 blockFilter := winsys.FWPM_FILTER0{} 224 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6") 225 blockFilter.SubLayerKey = subLayerKey 226 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 227 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 228 blockFilter.Weight.Type = winsys.FWP_UINT8 229 blockFilter.Weight.Value = uintptr(12) 230 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 231 if err != nil { 232 return os.NewSyscallError("FwpmFilterAdd0", err) 233 } 234 } 235 236 netInterface, err := net.InterfaceByName(t.options.Name) 237 if err != nil { 238 return err 239 } 240 241 tunCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 242 tunCondition[0].FieldKey = winsys.FWPM_CONDITION_LOCAL_INTERFACE_INDEX 243 tunCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 244 tunCondition[0].ConditionValue.Type = winsys.FWP_UINT32 245 tunCondition[0].ConditionValue.Value = uintptr(uint32(netInterface.Index)) 246 247 if len(t.options.Inet4Address) > 0 { 248 tunFilter4 := winsys.FWPM_FILTER0{} 249 tunFilter4.FilterCondition = &tunCondition[0] 250 tunFilter4.NumFilterConditions = 1 251 tunFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv4") 252 tunFilter4.SubLayerKey = subLayerKey 253 tunFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 254 tunFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 255 tunFilter4.Weight.Type = winsys.FWP_UINT8 256 tunFilter4.Weight.Value = uintptr(11) 257 err = winsys.FwpmFilterAdd0(engine, &tunFilter4, 0, &filterId) 258 if err != nil { 259 return os.NewSyscallError("FwpmFilterAdd0", err) 260 } 261 } 262 263 if len(t.options.Inet6Address) > 0 { 264 tunFilter6 := winsys.FWPM_FILTER0{} 265 tunFilter6.FilterCondition = &tunCondition[0] 266 tunFilter6.NumFilterConditions = 1 267 tunFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv6") 268 tunFilter6.SubLayerKey = subLayerKey 269 tunFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 270 tunFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 271 tunFilter6.Weight.Type = winsys.FWP_UINT8 272 tunFilter6.Weight.Value = uintptr(11) 273 err = winsys.FwpmFilterAdd0(engine, &tunFilter6, 0, &filterId) 274 if err != nil { 275 return os.NewSyscallError("FwpmFilterAdd0", err) 276 } 277 } 278 279 blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2) 280 blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL 281 blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 282 blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8 283 blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP)) 284 blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT 285 blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL 286 blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16 287 blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53)) 288 289 blockDNSFilter4 := winsys.FWPM_FILTER0{} 290 blockDNSFilter4.FilterCondition = &blockDNSCondition[0] 291 blockDNSFilter4.NumFilterConditions = 2 292 blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") 293 blockDNSFilter4.SubLayerKey = subLayerKey 294 blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 295 blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK 296 blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 297 blockDNSFilter4.Weight.Value = uintptr(10) 298 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) 299 if err != nil { 300 return os.NewSyscallError("FwpmFilterAdd0", err) 301 } 302 303 blockDNSFilter6 := winsys.FWPM_FILTER0{} 304 blockDNSFilter6.FilterCondition = &blockDNSCondition[0] 305 blockDNSFilter6.NumFilterConditions = 2 306 blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") 307 blockDNSFilter6.SubLayerKey = subLayerKey 308 blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 309 blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK 310 blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 311 blockDNSFilter6.Weight.Value = uintptr(10) 312 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) 313 if err != nil { 314 return os.NewSyscallError("FwpmFilterAdd0", err) 315 } 316 } 317 318 return nil 319 } 320 321 func (t *NativeTun) Read(p []byte) (n int, err error) { 322 return 0, os.ErrInvalid 323 } 324 325 func (t *NativeTun) ReadPacket() ([]byte, func(), error) { 326 t.running.Add(1) 327 defer t.running.Done() 328 retry: 329 if t.close.Load() == 1 { 330 return nil, nil, os.ErrClosed 331 } 332 start := nanotime() 333 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 334 for { 335 if t.close.Load() == 1 { 336 return nil, nil, os.ErrClosed 337 } 338 packet, err := t.session.ReceivePacket() 339 switch err { 340 case nil: 341 packetSize := len(packet) 342 t.rate.update(uint64(packetSize)) 343 return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil 344 case windows.ERROR_NO_MORE_ITEMS: 345 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 346 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 347 goto retry 348 } 349 procyield(1) 350 continue 351 case windows.ERROR_HANDLE_EOF: 352 return nil, nil, os.ErrClosed 353 case windows.ERROR_INVALID_DATA: 354 return nil, nil, errors.New("send ring corrupt") 355 } 356 return nil, nil, fmt.Errorf("read failed: %w", err) 357 } 358 } 359 360 func (t *NativeTun) ReadFunc(block func(b []byte)) error { 361 t.running.Add(1) 362 defer t.running.Done() 363 retry: 364 if t.close.Load() == 1 { 365 return os.ErrClosed 366 } 367 start := nanotime() 368 shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 369 for { 370 if t.close.Load() == 1 { 371 return os.ErrClosed 372 } 373 packet, err := t.session.ReceivePacket() 374 switch err { 375 case nil: 376 packetSize := len(packet) 377 block(packet) 378 t.session.ReleaseReceivePacket(packet) 379 t.rate.update(uint64(packetSize)) 380 return nil 381 case windows.ERROR_NO_MORE_ITEMS: 382 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 383 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 384 goto retry 385 } 386 procyield(1) 387 continue 388 case windows.ERROR_HANDLE_EOF: 389 return os.ErrClosed 390 case windows.ERROR_INVALID_DATA: 391 return errors.New("send ring corrupt") 392 } 393 return fmt.Errorf("read failed: %w", err) 394 } 395 } 396 397 func (t *NativeTun) Write(p []byte) (n int, err error) { 398 t.running.Add(1) 399 defer t.running.Done() 400 if t.close.Load() == 1 { 401 return 0, os.ErrClosed 402 } 403 t.rate.update(uint64(len(p))) 404 packet, err := t.session.AllocateSendPacket(len(p)) 405 copy(packet, p) 406 if err == nil { 407 t.session.SendPacket(packet) 408 return len(p), nil 409 } 410 switch err { 411 case windows.ERROR_HANDLE_EOF: 412 return 0, os.ErrClosed 413 case windows.ERROR_BUFFER_OVERFLOW: 414 return 0, nil // Dropping when ring is full. 415 } 416 return 0, fmt.Errorf("write failed: %w", err) 417 } 418 419 func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { 420 t.running.Add(1) 421 defer t.running.Done() 422 if t.close.Load() == 1 { 423 return 0, os.ErrClosed 424 } 425 var packetSize int 426 for _, packetElement := range packetElementList { 427 packetSize += len(packetElement) 428 } 429 t.rate.update(uint64(packetSize)) 430 packet, err := t.session.AllocateSendPacket(packetSize) 431 if err == nil { 432 var index int 433 for _, packetElement := range packetElementList { 434 index += copy(packet[index:], packetElement) 435 } 436 t.session.SendPacket(packet) 437 return 438 } 439 switch err { 440 case windows.ERROR_HANDLE_EOF: 441 return 0, os.ErrClosed 442 case windows.ERROR_BUFFER_OVERFLOW: 443 return 0, nil // Dropping when ring is full. 444 } 445 return 0, fmt.Errorf("write failed: %w", err) 446 } 447 448 func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { 449 defer buf.ReleaseMulti(buffers) 450 return common.Error(t.write(buf.ToSliceMulti(buffers))) 451 } 452 453 func (t *NativeTun) Close() error { 454 var err error 455 t.closeOnce.Do(func() { 456 t.close.Store(1) 457 windows.SetEvent(t.readWait) 458 t.running.Wait() 459 t.session.End() 460 t.adapter.Close() 461 if t.fwpmSession != 0 { 462 winsys.FwpmEngineClose0(t.fwpmSession) 463 } 464 if t.options.AutoRoute { 465 windnsapi.FlushResolverCache() 466 } 467 }) 468 return err 469 } 470 471 func generateGUIDByDeviceName(name string) *windows.GUID { 472 hash := md5.New() 473 hash.Write([]byte("wintun")) 474 hash.Write([]byte(name)) 475 sum := hash.Sum(nil) 476 return (*windows.GUID)(unsafe.Pointer(&sum[0])) 477 } 478 479 //go:linkname procyield runtime.procyield 480 func procyield(cycles uint32) 481 482 //go:linkname nanotime runtime.nanotime 483 func nanotime() int64 484 485 type rateJuggler struct { 486 current atomic.Uint64 487 nextByteCount atomic.Uint64 488 nextStartTime atomic.Int64 489 changing atomic.Int32 490 } 491 492 func (rate *rateJuggler) update(packetLen uint64) { 493 now := nanotime() 494 total := rate.nextByteCount.Add(packetLen) 495 period := uint64(now - rate.nextStartTime.Load()) 496 if period >= rateMeasurementGranularity { 497 if !rate.changing.CompareAndSwap(0, 1) { 498 return 499 } 500 rate.nextStartTime.Store(now) 501 rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 502 rate.nextByteCount.Store(0) 503 rate.changing.Store(0) 504 } 505 } 506 507 const ( 508 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 509 spinloopRateThreshold = 800000000 / 8 // 800mbps 510 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 511 )