github.com/MerlinKodo/sing-tun@v0.1.15/tun_windows.go (about) 1 package tun 2 3 import ( 4 "crypto/md5" 5 "errors" 6 "fmt" 7 "math" 8 "net" 9 "net/netip" 10 "os" 11 "sync" 12 "sync/atomic" 13 "time" 14 "unsafe" 15 16 "github.com/sagernet/sing/common" 17 "github.com/sagernet/sing/common/buf" 18 E "github.com/sagernet/sing/common/exceptions" 19 N "github.com/sagernet/sing/common/network" 20 "github.com/sagernet/sing/common/windnsapi" 21 22 "github.com/MerlinKodo/sing-tun/internal/winipcfg" 23 "github.com/MerlinKodo/sing-tun/internal/winsys" 24 "github.com/MerlinKodo/sing-tun/internal/wintun" 25 26 "golang.org/x/sys/windows" 27 ) 28 29 var TunnelType = "sing-tun" 30 31 type NativeTun struct { 32 adapter *wintun.Adapter 33 options Options 34 session wintun.Session 35 readWait windows.Handle 36 rate rateJuggler 37 running sync.WaitGroup 38 closeOnce sync.Once 39 close int32 40 fwpmSession uintptr 41 } 42 43 func New(options Options) (WinTun, error) { 44 if options.FileDescriptor != 0 { 45 return nil, os.ErrInvalid 46 } 47 adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name)) 48 if err != nil { 49 return nil, err 50 } 51 nativeTun := &NativeTun{ 52 adapter: adapter, 53 options: options, 54 } 55 session, err := adapter.StartSession(0x800000) 56 if err != nil { 57 return nil, err 58 } 59 nativeTun.session = session 60 nativeTun.readWait = session.ReadWaitEvent() 61 err = nativeTun.configure() 62 if err != nil { 63 session.End() 64 adapter.Close() 65 return nil, err 66 } 67 return nativeTun, nil 68 } 69 70 func (t *NativeTun) configure() error { 71 luid := winipcfg.LUID(t.adapter.LUID()) 72 if len(t.options.Inet4Address) > 0 { 73 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address) 74 if err != nil { 75 return E.Cause(err, "set ipv4 address") 76 } 77 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) 78 if err != nil { 79 return E.Cause(err, "set ipv4 dns") 80 } 81 } 82 if len(t.options.Inet6Address) > 0 { 83 err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address) 84 if err != nil { 85 return E.Cause(err, "set ipv6 address") 86 } 87 err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) 88 if err != nil { 89 return E.Cause(err, "set ipv6 dns") 90 } 91 } 92 if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { 93 _ = luid.DisableDNSRegistration() 94 } 95 if t.options.AutoRoute { 96 routeRanges, err := t.options.BuildAutoRouteRanges() 97 if err != nil { 98 return err 99 } 100 for _, routeRange := range routeRanges { 101 if routeRange.Addr().Is4() { 102 err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0) 103 } else { 104 err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0) 105 } 106 } 107 err = windnsapi.FlushResolverCache() 108 if err != nil { 109 return err 110 } 111 } 112 if len(t.options.Inet4Address) > 0 { 113 inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET)) 114 if err != nil { 115 return err 116 } 117 inetIf.ForwardingEnabled = true 118 inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 119 inetIf.DadTransmits = 0 120 inetIf.ManagedAddressConfigurationSupported = false 121 inetIf.OtherStatefulConfigurationSupported = false 122 inetIf.NLMTU = t.options.MTU 123 if t.options.AutoRoute { 124 inetIf.UseAutomaticMetric = false 125 inetIf.Metric = 0 126 } 127 err = inetIf.Set() 128 if err != nil { 129 return E.Cause(err, "set ipv4 options") 130 } 131 } 132 if len(t.options.Inet6Address) > 0 { 133 inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6)) 134 if err != nil { 135 return err 136 } 137 inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled 138 inet6If.DadTransmits = 0 139 inet6If.ManagedAddressConfigurationSupported = false 140 inet6If.OtherStatefulConfigurationSupported = false 141 inet6If.NLMTU = t.options.MTU 142 if t.options.AutoRoute { 143 inet6If.UseAutomaticMetric = false 144 inet6If.Metric = 0 145 } 146 err = inet6If.Set() 147 if err != nil { 148 return E.Cause(err, "set ipv6 options") 149 } 150 } 151 152 if t.options.AutoRoute && t.options.StrictRoute { 153 var engine uintptr 154 session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} 155 err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) 156 if err != nil { 157 return os.NewSyscallError("FwpmEngineOpen0", err) 158 } 159 t.fwpmSession = engine 160 161 subLayerKey, err := windows.GenerateGUID() 162 if err != nil { 163 return os.NewSyscallError("CoCreateGuid", err) 164 } 165 166 subLayer := winsys.FWPM_SUBLAYER0{} 167 subLayer.SubLayerKey = subLayerKey 168 subLayer.DisplayData = winsys.CreateDisplayData(TunnelType, "auto-route rules") 169 subLayer.Weight = math.MaxUint16 170 err = winsys.FwpmSubLayerAdd0(engine, &subLayer, 0) 171 if err != nil { 172 return os.NewSyscallError("FwpmSubLayerAdd0", err) 173 } 174 175 processAppID, err := winsys.GetCurrentProcessAppID() 176 if err != nil { 177 return err 178 } 179 defer winsys.FwpmFreeMemory0(unsafe.Pointer(&processAppID)) 180 181 var filterId uint64 182 permitCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 183 permitCondition[0].FieldKey = winsys.FWPM_CONDITION_ALE_APP_ID 184 permitCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 185 permitCondition[0].ConditionValue.Type = winsys.FWP_BYTE_BLOB_TYPE 186 permitCondition[0].ConditionValue.Value = uintptr(unsafe.Pointer(processAppID)) 187 188 permitFilter4 := winsys.FWPM_FILTER0{} 189 permitFilter4.FilterCondition = &permitCondition[0] 190 permitFilter4.NumFilterConditions = 1 191 permitFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv4") 192 permitFilter4.SubLayerKey = subLayerKey 193 permitFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 194 permitFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 195 permitFilter4.Weight.Type = winsys.FWP_UINT8 196 permitFilter4.Weight.Value = uintptr(13) 197 permitFilter4.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 198 err = winsys.FwpmFilterAdd0(engine, &permitFilter4, 0, &filterId) 199 if err != nil { 200 return os.NewSyscallError("FwpmFilterAdd0", err) 201 } 202 203 permitFilter6 := winsys.FWPM_FILTER0{} 204 permitFilter6.FilterCondition = &permitCondition[0] 205 permitFilter6.NumFilterConditions = 1 206 permitFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv6") 207 permitFilter6.SubLayerKey = subLayerKey 208 permitFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 209 permitFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 210 permitFilter6.Weight.Type = winsys.FWP_UINT8 211 permitFilter6.Weight.Value = uintptr(13) 212 permitFilter6.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT 213 err = winsys.FwpmFilterAdd0(engine, &permitFilter6, 0, &filterId) 214 if err != nil { 215 return os.NewSyscallError("FwpmFilterAdd0", err) 216 } 217 218 /*if len(t.options.Inet4Address) == 0 { 219 blockFilter := winsys.FWPM_FILTER0{} 220 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4") 221 blockFilter.SubLayerKey = subLayerKey 222 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 223 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 224 blockFilter.Weight.Type = winsys.FWP_UINT8 225 blockFilter.Weight.Value = uintptr(12) 226 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 227 if err != nil { 228 return os.NewSyscallError("FwpmFilterAdd0", err) 229 } 230 }*/ 231 232 if len(t.options.Inet6Address) == 0 { 233 blockFilter := winsys.FWPM_FILTER0{} 234 blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6") 235 blockFilter.SubLayerKey = subLayerKey 236 blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 237 blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK 238 blockFilter.Weight.Type = winsys.FWP_UINT8 239 blockFilter.Weight.Value = uintptr(12) 240 err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId) 241 if err != nil { 242 return os.NewSyscallError("FwpmFilterAdd0", err) 243 } 244 } 245 246 netInterface, err := net.InterfaceByName(t.options.Name) 247 if err != nil { 248 return err 249 } 250 251 tunCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) 252 tunCondition[0].FieldKey = winsys.FWPM_CONDITION_LOCAL_INTERFACE_INDEX 253 tunCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 254 tunCondition[0].ConditionValue.Type = winsys.FWP_UINT32 255 tunCondition[0].ConditionValue.Value = uintptr(uint32(netInterface.Index)) 256 257 if len(t.options.Inet4Address) > 0 { 258 tunFilter4 := winsys.FWPM_FILTER0{} 259 tunFilter4.FilterCondition = &tunCondition[0] 260 tunFilter4.NumFilterConditions = 1 261 tunFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv4") 262 tunFilter4.SubLayerKey = subLayerKey 263 tunFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 264 tunFilter4.Action.Type = winsys.FWP_ACTION_PERMIT 265 tunFilter4.Weight.Type = winsys.FWP_UINT8 266 tunFilter4.Weight.Value = uintptr(11) 267 err = winsys.FwpmFilterAdd0(engine, &tunFilter4, 0, &filterId) 268 if err != nil { 269 return os.NewSyscallError("FwpmFilterAdd0", err) 270 } 271 } 272 273 if len(t.options.Inet6Address) > 0 { 274 tunFilter6 := winsys.FWPM_FILTER0{} 275 tunFilter6.FilterCondition = &tunCondition[0] 276 tunFilter6.NumFilterConditions = 1 277 tunFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv6") 278 tunFilter6.SubLayerKey = subLayerKey 279 tunFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 280 tunFilter6.Action.Type = winsys.FWP_ACTION_PERMIT 281 tunFilter6.Weight.Type = winsys.FWP_UINT8 282 tunFilter6.Weight.Value = uintptr(11) 283 err = winsys.FwpmFilterAdd0(engine, &tunFilter6, 0, &filterId) 284 if err != nil { 285 return os.NewSyscallError("FwpmFilterAdd0", err) 286 } 287 } 288 289 blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2) 290 blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL 291 blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL 292 blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8 293 blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP)) 294 blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT 295 blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL 296 blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16 297 blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53)) 298 299 blockDNSFilter4 := winsys.FWPM_FILTER0{} 300 blockDNSFilter4.FilterCondition = &blockDNSCondition[0] 301 blockDNSFilter4.NumFilterConditions = 2 302 blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") 303 blockDNSFilter4.SubLayerKey = subLayerKey 304 blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 305 blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK 306 blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 307 blockDNSFilter4.Weight.Value = uintptr(10) 308 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) 309 if err != nil { 310 return os.NewSyscallError("FwpmFilterAdd0", err) 311 } 312 313 blockDNSFilter6 := winsys.FWPM_FILTER0{} 314 blockDNSFilter6.FilterCondition = &blockDNSCondition[0] 315 blockDNSFilter6.NumFilterConditions = 2 316 blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") 317 blockDNSFilter6.SubLayerKey = subLayerKey 318 blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 319 blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK 320 blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 321 blockDNSFilter6.Weight.Value = uintptr(10) 322 err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) 323 if err != nil { 324 return os.NewSyscallError("FwpmFilterAdd0", err) 325 } 326 } 327 328 return nil 329 } 330 331 func (t *NativeTun) Read(p []byte) (n int, err error) { 332 return 0, os.ErrInvalid 333 } 334 335 func (t *NativeTun) ReadPacket() ([]byte, func(), error) { 336 t.running.Add(1) 337 defer t.running.Done() 338 retry: 339 if atomic.LoadInt32(&t.close) == 1 { 340 return nil, nil, os.ErrClosed 341 } 342 start := nanotime() 343 shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2 344 for { 345 if atomic.LoadInt32(&t.close) == 1 { 346 return nil, nil, os.ErrClosed 347 } 348 packet, err := t.session.ReceivePacket() 349 switch err { 350 case nil: 351 packetSize := len(packet) 352 t.rate.update(uint64(packetSize)) 353 return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil 354 case windows.ERROR_NO_MORE_ITEMS: 355 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 356 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 357 goto retry 358 } 359 procyield(1) 360 continue 361 case windows.ERROR_HANDLE_EOF: 362 return nil, nil, os.ErrClosed 363 case windows.ERROR_INVALID_DATA: 364 return nil, nil, errors.New("send ring corrupt") 365 } 366 return nil, nil, fmt.Errorf("read failed: %w", err) 367 } 368 } 369 370 func (t *NativeTun) ReadFunc(block func(b []byte)) error { 371 t.running.Add(1) 372 defer t.running.Done() 373 retry: 374 if atomic.LoadInt32(&t.close) == 1 { 375 return os.ErrClosed 376 } 377 start := nanotime() 378 shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2 379 for { 380 if atomic.LoadInt32(&t.close) == 1 { 381 return os.ErrClosed 382 } 383 packet, err := t.session.ReceivePacket() 384 switch err { 385 case nil: 386 packetSize := len(packet) 387 block(packet) 388 t.session.ReleaseReceivePacket(packet) 389 t.rate.update(uint64(packetSize)) 390 return nil 391 case windows.ERROR_NO_MORE_ITEMS: 392 if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 393 windows.WaitForSingleObject(t.readWait, windows.INFINITE) 394 goto retry 395 } 396 procyield(1) 397 continue 398 case windows.ERROR_HANDLE_EOF: 399 return os.ErrClosed 400 case windows.ERROR_INVALID_DATA: 401 return errors.New("send ring corrupt") 402 } 403 return fmt.Errorf("read failed: %w", err) 404 } 405 } 406 407 func (t *NativeTun) Write(p []byte) (n int, err error) { 408 t.running.Add(1) 409 defer t.running.Done() 410 if atomic.LoadInt32(&t.close) == 1 { 411 return 0, os.ErrClosed 412 } 413 t.rate.update(uint64(len(p))) 414 packet, err := t.session.AllocateSendPacket(len(p)) 415 copy(packet, p) 416 if err == nil { 417 t.session.SendPacket(packet) 418 return len(p), nil 419 } 420 switch err { 421 case windows.ERROR_HANDLE_EOF: 422 return 0, os.ErrClosed 423 case windows.ERROR_BUFFER_OVERFLOW: 424 return 0, nil // Dropping when ring is full. 425 } 426 return 0, fmt.Errorf("write failed: %w", err) 427 } 428 429 func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { 430 t.running.Add(1) 431 defer t.running.Done() 432 if atomic.LoadInt32(&t.close) == 1 { 433 return 0, os.ErrClosed 434 } 435 var packetSize int 436 for _, packetElement := range packetElementList { 437 packetSize += len(packetElement) 438 } 439 t.rate.update(uint64(packetSize)) 440 packet, err := t.session.AllocateSendPacket(packetSize) 441 if err == nil { 442 var index int 443 for _, packetElement := range packetElementList { 444 index += copy(packet[index:], packetElement) 445 } 446 t.session.SendPacket(packet) 447 return 448 } 449 switch err { 450 case windows.ERROR_HANDLE_EOF: 451 return 0, os.ErrClosed 452 case windows.ERROR_BUFFER_OVERFLOW: 453 return 0, nil // Dropping when ring is full. 454 } 455 return 0, fmt.Errorf("write failed: %w", err) 456 } 457 458 func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { 459 return t 460 } 461 462 func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { 463 defer buf.ReleaseMulti(buffers) 464 return common.Error(t.write(buf.ToSliceMulti(buffers))) 465 } 466 467 func (t *NativeTun) Close() error { 468 var err error 469 t.closeOnce.Do(func() { 470 atomic.StoreInt32(&t.close, 1) 471 windows.SetEvent(t.readWait) 472 t.running.Wait() 473 t.session.End() 474 t.adapter.Close() 475 if t.fwpmSession != 0 { 476 winsys.FwpmEngineClose0(t.fwpmSession) 477 } 478 if t.options.AutoRoute { 479 windnsapi.FlushResolverCache() 480 } 481 }) 482 return err 483 } 484 485 func generateGUIDByDeviceName(name string) *windows.GUID { 486 hash := md5.New() 487 hash.Write([]byte("wintun")) 488 hash.Write([]byte(name)) 489 sum := hash.Sum(nil) 490 return (*windows.GUID)(unsafe.Pointer(&sum[0])) 491 } 492 493 //go:linkname procyield runtime.procyield 494 func procyield(cycles uint32) 495 496 //go:linkname nanotime runtime.nanotime 497 func nanotime() int64 498 499 type rateJuggler struct { 500 current uint64 501 nextByteCount uint64 502 nextStartTime int64 503 changing int32 504 } 505 506 func (rate *rateJuggler) update(packetLen uint64) { 507 now := nanotime() 508 total := atomic.AddUint64(&rate.nextByteCount, packetLen) 509 period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) 510 if period >= rateMeasurementGranularity { 511 if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { 512 return 513 } 514 atomic.StoreInt64(&rate.nextStartTime, now) 515 atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) 516 atomic.StoreUint64(&rate.nextByteCount, 0) 517 atomic.StoreInt32(&rate.changing, 0) 518 } 519 } 520 521 const ( 522 rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 523 spinloopRateThreshold = 800000000 / 8 // 800mbps 524 spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 525 )