github.com/apernet/sing-tun@v0.2.6-0.20240323130332-b9f6511036ad/tun_windows.go (about)

     1  package tun
     2  
     3  import (
     4  	"crypto/md5"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"net"
     9  	"net/netip"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  	"unsafe"
    14  
    15  	"github.com/apernet/sing-tun/internal/winipcfg"
    16  	"github.com/apernet/sing-tun/internal/winsys"
    17  	"github.com/apernet/sing-tun/internal/wintun"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/atomic"
    20  	"github.com/sagernet/sing/common/buf"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	"github.com/sagernet/sing/common/windnsapi"
    23  
    24  	"golang.org/x/sys/windows"
    25  )
    26  
    27  var TunnelType = "sing-tun"
    28  
    29  type NativeTun struct {
    30  	adapter     *wintun.Adapter
    31  	options     Options
    32  	session     wintun.Session
    33  	readWait    windows.Handle
    34  	rate        rateJuggler
    35  	running     sync.WaitGroup
    36  	closeOnce   sync.Once
    37  	close       atomic.Int32
    38  	fwpmSession uintptr
    39  }
    40  
    41  func New(options Options) (WinTun, error) {
    42  	if options.FileDescriptor != 0 {
    43  		return nil, os.ErrInvalid
    44  	}
    45  	adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name))
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	nativeTun := &NativeTun{
    50  		adapter: adapter,
    51  		options: options,
    52  	}
    53  	session, err := adapter.StartSession(0x800000)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	nativeTun.session = session
    58  	nativeTun.readWait = session.ReadWaitEvent()
    59  	err = nativeTun.configure()
    60  	if err != nil {
    61  		session.End()
    62  		adapter.Close()
    63  		return nil, err
    64  	}
    65  	return nativeTun, nil
    66  }
    67  
    68  func (t *NativeTun) configure() error {
    69  	luid := winipcfg.LUID(t.adapter.LUID())
    70  	if len(t.options.Inet4Address) > 0 {
    71  		err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET), t.options.Inet4Address)
    72  		if err != nil {
    73  			return E.Cause(err, "set ipv4 address")
    74  		}
    75  	}
    76  	if len(t.options.Inet6Address) > 0 {
    77  		err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address)
    78  		if err != nil {
    79  			return E.Cause(err, "set ipv6 address")
    80  		}
    81  	}
    82  	if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
    83  		_ = luid.DisableDNSRegistration()
    84  	}
    85  	if t.options.AutoRoute {
    86  		routeRanges, err := t.options.BuildAutoRouteRanges(false)
    87  		if err != nil {
    88  			return err
    89  		}
    90  		for _, routeRange := range routeRanges {
    91  			if routeRange.Addr().Is4() {
    92  				err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0)
    93  			} else {
    94  				err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0)
    95  			}
    96  		}
    97  		err = windnsapi.FlushResolverCache()
    98  		if err != nil {
    99  			return err
   100  		}
   101  	}
   102  	if len(t.options.Inet4Address) > 0 {
   103  		inetIf, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET))
   104  		if err != nil {
   105  			return err
   106  		}
   107  		inetIf.ForwardingEnabled = true
   108  		inetIf.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
   109  		inetIf.DadTransmits = 0
   110  		inetIf.ManagedAddressConfigurationSupported = false
   111  		inetIf.OtherStatefulConfigurationSupported = false
   112  		inetIf.NLMTU = t.options.MTU
   113  		if t.options.AutoRoute {
   114  			inetIf.UseAutomaticMetric = false
   115  			inetIf.Metric = 0
   116  		}
   117  		err = inetIf.Set()
   118  		if err != nil {
   119  			return E.Cause(err, "set ipv4 options")
   120  		}
   121  	}
   122  	if len(t.options.Inet6Address) > 0 {
   123  		inet6If, err := luid.IPInterface(winipcfg.AddressFamily(windows.AF_INET6))
   124  		if err != nil {
   125  			return err
   126  		}
   127  		inet6If.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
   128  		inet6If.DadTransmits = 0
   129  		inet6If.ManagedAddressConfigurationSupported = false
   130  		inet6If.OtherStatefulConfigurationSupported = false
   131  		inet6If.NLMTU = t.options.MTU
   132  		if t.options.AutoRoute {
   133  			inet6If.UseAutomaticMetric = false
   134  			inet6If.Metric = 0
   135  		}
   136  		err = inet6If.Set()
   137  		if err != nil {
   138  			return E.Cause(err, "set ipv6 options")
   139  		}
   140  	}
   141  
   142  	if t.options.AutoRoute && t.options.StrictRoute {
   143  		var engine uintptr
   144  		session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
   145  		err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
   146  		if err != nil {
   147  			return os.NewSyscallError("FwpmEngineOpen0", err)
   148  		}
   149  		t.fwpmSession = engine
   150  
   151  		subLayerKey, err := windows.GenerateGUID()
   152  		if err != nil {
   153  			return os.NewSyscallError("CoCreateGuid", err)
   154  		}
   155  
   156  		subLayer := winsys.FWPM_SUBLAYER0{}
   157  		subLayer.SubLayerKey = subLayerKey
   158  		subLayer.DisplayData = winsys.CreateDisplayData(TunnelType, "auto-route rules")
   159  		subLayer.Weight = math.MaxUint16
   160  		err = winsys.FwpmSubLayerAdd0(engine, &subLayer, 0)
   161  		if err != nil {
   162  			return os.NewSyscallError("FwpmSubLayerAdd0", err)
   163  		}
   164  
   165  		processAppID, err := winsys.GetCurrentProcessAppID()
   166  		if err != nil {
   167  			return err
   168  		}
   169  		defer winsys.FwpmFreeMemory0(unsafe.Pointer(&processAppID))
   170  
   171  		var filterId uint64
   172  		permitCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1)
   173  		permitCondition[0].FieldKey = winsys.FWPM_CONDITION_ALE_APP_ID
   174  		permitCondition[0].MatchType = winsys.FWP_MATCH_EQUAL
   175  		permitCondition[0].ConditionValue.Type = winsys.FWP_BYTE_BLOB_TYPE
   176  		permitCondition[0].ConditionValue.Value = uintptr(unsafe.Pointer(processAppID))
   177  
   178  		permitFilter4 := winsys.FWPM_FILTER0{}
   179  		permitFilter4.FilterCondition = &permitCondition[0]
   180  		permitFilter4.NumFilterConditions = 1
   181  		permitFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv4")
   182  		permitFilter4.SubLayerKey = subLayerKey
   183  		permitFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
   184  		permitFilter4.Action.Type = winsys.FWP_ACTION_PERMIT
   185  		permitFilter4.Weight.Type = winsys.FWP_UINT8
   186  		permitFilter4.Weight.Value = uintptr(13)
   187  		permitFilter4.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT
   188  		err = winsys.FwpmFilterAdd0(engine, &permitFilter4, 0, &filterId)
   189  		if err != nil {
   190  			return os.NewSyscallError("FwpmFilterAdd0", err)
   191  		}
   192  
   193  		permitFilter6 := winsys.FWPM_FILTER0{}
   194  		permitFilter6.FilterCondition = &permitCondition[0]
   195  		permitFilter6.NumFilterConditions = 1
   196  		permitFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "protect ipv6")
   197  		permitFilter6.SubLayerKey = subLayerKey
   198  		permitFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
   199  		permitFilter6.Action.Type = winsys.FWP_ACTION_PERMIT
   200  		permitFilter6.Weight.Type = winsys.FWP_UINT8
   201  		permitFilter6.Weight.Value = uintptr(13)
   202  		permitFilter6.Flags = winsys.FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT
   203  		err = winsys.FwpmFilterAdd0(engine, &permitFilter6, 0, &filterId)
   204  		if err != nil {
   205  			return os.NewSyscallError("FwpmFilterAdd0", err)
   206  		}
   207  
   208  		/*if len(t.options.Inet4Address) == 0 {
   209  			blockFilter := winsys.FWPM_FILTER0{}
   210  			blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4")
   211  			blockFilter.SubLayerKey = subLayerKey
   212  			blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
   213  			blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK
   214  			blockFilter.Weight.Type = winsys.FWP_UINT8
   215  			blockFilter.Weight.Value = uintptr(12)
   216  			err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId)
   217  			if err != nil {
   218  				return os.NewSyscallError("FwpmFilterAdd0", err)
   219  			}
   220  		}*/
   221  
   222  		if len(t.options.Inet6Address) == 0 {
   223  			blockFilter := winsys.FWPM_FILTER0{}
   224  			blockFilter.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6")
   225  			blockFilter.SubLayerKey = subLayerKey
   226  			blockFilter.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
   227  			blockFilter.Action.Type = winsys.FWP_ACTION_BLOCK
   228  			blockFilter.Weight.Type = winsys.FWP_UINT8
   229  			blockFilter.Weight.Value = uintptr(12)
   230  			err = winsys.FwpmFilterAdd0(engine, &blockFilter, 0, &filterId)
   231  			if err != nil {
   232  				return os.NewSyscallError("FwpmFilterAdd0", err)
   233  			}
   234  		}
   235  
   236  		netInterface, err := net.InterfaceByName(t.options.Name)
   237  		if err != nil {
   238  			return err
   239  		}
   240  
   241  		tunCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1)
   242  		tunCondition[0].FieldKey = winsys.FWPM_CONDITION_LOCAL_INTERFACE_INDEX
   243  		tunCondition[0].MatchType = winsys.FWP_MATCH_EQUAL
   244  		tunCondition[0].ConditionValue.Type = winsys.FWP_UINT32
   245  		tunCondition[0].ConditionValue.Value = uintptr(uint32(netInterface.Index))
   246  
   247  		if len(t.options.Inet4Address) > 0 {
   248  			tunFilter4 := winsys.FWPM_FILTER0{}
   249  			tunFilter4.FilterCondition = &tunCondition[0]
   250  			tunFilter4.NumFilterConditions = 1
   251  			tunFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv4")
   252  			tunFilter4.SubLayerKey = subLayerKey
   253  			tunFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
   254  			tunFilter4.Action.Type = winsys.FWP_ACTION_PERMIT
   255  			tunFilter4.Weight.Type = winsys.FWP_UINT8
   256  			tunFilter4.Weight.Value = uintptr(11)
   257  			err = winsys.FwpmFilterAdd0(engine, &tunFilter4, 0, &filterId)
   258  			if err != nil {
   259  				return os.NewSyscallError("FwpmFilterAdd0", err)
   260  			}
   261  		}
   262  
   263  		if len(t.options.Inet6Address) > 0 {
   264  			tunFilter6 := winsys.FWPM_FILTER0{}
   265  			tunFilter6.FilterCondition = &tunCondition[0]
   266  			tunFilter6.NumFilterConditions = 1
   267  			tunFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "allow ipv6")
   268  			tunFilter6.SubLayerKey = subLayerKey
   269  			tunFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
   270  			tunFilter6.Action.Type = winsys.FWP_ACTION_PERMIT
   271  			tunFilter6.Weight.Type = winsys.FWP_UINT8
   272  			tunFilter6.Weight.Value = uintptr(11)
   273  			err = winsys.FwpmFilterAdd0(engine, &tunFilter6, 0, &filterId)
   274  			if err != nil {
   275  				return os.NewSyscallError("FwpmFilterAdd0", err)
   276  			}
   277  		}
   278  
   279  		blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2)
   280  		blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL
   281  		blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL
   282  		blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8
   283  		blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP))
   284  		blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT
   285  		blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL
   286  		blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16
   287  		blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53))
   288  
   289  		blockDNSFilter4 := winsys.FWPM_FILTER0{}
   290  		blockDNSFilter4.FilterCondition = &blockDNSCondition[0]
   291  		blockDNSFilter4.NumFilterConditions = 2
   292  		blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns")
   293  		blockDNSFilter4.SubLayerKey = subLayerKey
   294  		blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
   295  		blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK
   296  		blockDNSFilter4.Weight.Type = winsys.FWP_UINT8
   297  		blockDNSFilter4.Weight.Value = uintptr(10)
   298  		err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId)
   299  		if err != nil {
   300  			return os.NewSyscallError("FwpmFilterAdd0", err)
   301  		}
   302  
   303  		blockDNSFilter6 := winsys.FWPM_FILTER0{}
   304  		blockDNSFilter6.FilterCondition = &blockDNSCondition[0]
   305  		blockDNSFilter6.NumFilterConditions = 2
   306  		blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns")
   307  		blockDNSFilter6.SubLayerKey = subLayerKey
   308  		blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
   309  		blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK
   310  		blockDNSFilter6.Weight.Type = winsys.FWP_UINT8
   311  		blockDNSFilter6.Weight.Value = uintptr(10)
   312  		err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId)
   313  		if err != nil {
   314  			return os.NewSyscallError("FwpmFilterAdd0", err)
   315  		}
   316  	}
   317  
   318  	return nil
   319  }
   320  
   321  func (t *NativeTun) Read(p []byte) (n int, err error) {
   322  	return 0, os.ErrInvalid
   323  }
   324  
   325  func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
   326  	t.running.Add(1)
   327  	defer t.running.Done()
   328  retry:
   329  	if t.close.Load() == 1 {
   330  		return nil, nil, os.ErrClosed
   331  	}
   332  	start := nanotime()
   333  	shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
   334  	for {
   335  		if t.close.Load() == 1 {
   336  			return nil, nil, os.ErrClosed
   337  		}
   338  		packet, err := t.session.ReceivePacket()
   339  		switch err {
   340  		case nil:
   341  			packetSize := len(packet)
   342  			t.rate.update(uint64(packetSize))
   343  			return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil
   344  		case windows.ERROR_NO_MORE_ITEMS:
   345  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   346  				windows.WaitForSingleObject(t.readWait, windows.INFINITE)
   347  				goto retry
   348  			}
   349  			procyield(1)
   350  			continue
   351  		case windows.ERROR_HANDLE_EOF:
   352  			return nil, nil, os.ErrClosed
   353  		case windows.ERROR_INVALID_DATA:
   354  			return nil, nil, errors.New("send ring corrupt")
   355  		}
   356  		return nil, nil, fmt.Errorf("read failed: %w", err)
   357  	}
   358  }
   359  
   360  func (t *NativeTun) ReadFunc(block func(b []byte)) error {
   361  	t.running.Add(1)
   362  	defer t.running.Done()
   363  retry:
   364  	if t.close.Load() == 1 {
   365  		return os.ErrClosed
   366  	}
   367  	start := nanotime()
   368  	shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
   369  	for {
   370  		if t.close.Load() == 1 {
   371  			return os.ErrClosed
   372  		}
   373  		packet, err := t.session.ReceivePacket()
   374  		switch err {
   375  		case nil:
   376  			packetSize := len(packet)
   377  			block(packet)
   378  			t.session.ReleaseReceivePacket(packet)
   379  			t.rate.update(uint64(packetSize))
   380  			return nil
   381  		case windows.ERROR_NO_MORE_ITEMS:
   382  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   383  				windows.WaitForSingleObject(t.readWait, windows.INFINITE)
   384  				goto retry
   385  			}
   386  			procyield(1)
   387  			continue
   388  		case windows.ERROR_HANDLE_EOF:
   389  			return os.ErrClosed
   390  		case windows.ERROR_INVALID_DATA:
   391  			return errors.New("send ring corrupt")
   392  		}
   393  		return fmt.Errorf("read failed: %w", err)
   394  	}
   395  }
   396  
   397  func (t *NativeTun) Write(p []byte) (n int, err error) {
   398  	t.running.Add(1)
   399  	defer t.running.Done()
   400  	if t.close.Load() == 1 {
   401  		return 0, os.ErrClosed
   402  	}
   403  	t.rate.update(uint64(len(p)))
   404  	packet, err := t.session.AllocateSendPacket(len(p))
   405  	copy(packet, p)
   406  	if err == nil {
   407  		t.session.SendPacket(packet)
   408  		return len(p), nil
   409  	}
   410  	switch err {
   411  	case windows.ERROR_HANDLE_EOF:
   412  		return 0, os.ErrClosed
   413  	case windows.ERROR_BUFFER_OVERFLOW:
   414  		return 0, nil // Dropping when ring is full.
   415  	}
   416  	return 0, fmt.Errorf("write failed: %w", err)
   417  }
   418  
   419  func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
   420  	t.running.Add(1)
   421  	defer t.running.Done()
   422  	if t.close.Load() == 1 {
   423  		return 0, os.ErrClosed
   424  	}
   425  	var packetSize int
   426  	for _, packetElement := range packetElementList {
   427  		packetSize += len(packetElement)
   428  	}
   429  	t.rate.update(uint64(packetSize))
   430  	packet, err := t.session.AllocateSendPacket(packetSize)
   431  	if err == nil {
   432  		var index int
   433  		for _, packetElement := range packetElementList {
   434  			index += copy(packet[index:], packetElement)
   435  		}
   436  		t.session.SendPacket(packet)
   437  		return
   438  	}
   439  	switch err {
   440  	case windows.ERROR_HANDLE_EOF:
   441  		return 0, os.ErrClosed
   442  	case windows.ERROR_BUFFER_OVERFLOW:
   443  		return 0, nil // Dropping when ring is full.
   444  	}
   445  	return 0, fmt.Errorf("write failed: %w", err)
   446  }
   447  
   448  func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
   449  	defer buf.ReleaseMulti(buffers)
   450  	return common.Error(t.write(buf.ToSliceMulti(buffers)))
   451  }
   452  
   453  func (t *NativeTun) Close() error {
   454  	var err error
   455  	t.closeOnce.Do(func() {
   456  		t.close.Store(1)
   457  		windows.SetEvent(t.readWait)
   458  		t.running.Wait()
   459  		t.session.End()
   460  		t.adapter.Close()
   461  		if t.fwpmSession != 0 {
   462  			winsys.FwpmEngineClose0(t.fwpmSession)
   463  		}
   464  		if t.options.AutoRoute {
   465  			windnsapi.FlushResolverCache()
   466  		}
   467  	})
   468  	return err
   469  }
   470  
   471  func generateGUIDByDeviceName(name string) *windows.GUID {
   472  	hash := md5.New()
   473  	hash.Write([]byte("wintun"))
   474  	hash.Write([]byte(name))
   475  	sum := hash.Sum(nil)
   476  	return (*windows.GUID)(unsafe.Pointer(&sum[0]))
   477  }
   478  
// procyield is linked to the Go runtime's internal procyield. It is used here
// as a CPU spin hint while busy-waiting for packets in ReadPacket/ReadFunc.
// NOTE(review): relies on runtime-internal linkname access; confirm it still
// links on the Go version in use.
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is linked to the Go runtime's internal nanosecond clock. In this
// file its values are only compared as differences (spin timing and rate
// measurement windows).
//go:linkname nanotime runtime.nanotime
func nanotime() int64
   484  
// rateJuggler keeps a coarse, lock-free throughput estimate. update adds byte
// counts and, once at least rateMeasurementGranularity has passed since the
// window start, stores a bytes-per-second figure in current. The readers use
// that figure to decide whether to spin before blocking.
//
// NOTE(review): field order is left untouched because the struct holds 64-bit
// atomics. On 32-bit targets their alignment may depend on layout unless the
// atomic package guarantees it; confirm before reordering.
type rateJuggler struct {
	current       atomic.Uint64 // last computed rate, in bytes per second
	nextByteCount atomic.Uint64 // bytes accumulated since nextStartTime
	nextStartTime atomic.Int64  // nanotime() at which the current window began
	changing      atomic.Int32  // 1 while one goroutine is publishing a new rate
}
   491  
   492  func (rate *rateJuggler) update(packetLen uint64) {
   493  	now := nanotime()
   494  	total := rate.nextByteCount.Add(packetLen)
   495  	period := uint64(now - rate.nextStartTime.Load())
   496  	if period >= rateMeasurementGranularity {
   497  		if !rate.changing.CompareAndSwap(0, 1) {
   498  			return
   499  		}
   500  		rate.nextStartTime.Store(now)
   501  		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
   502  		rate.nextByteCount.Store(0)
   503  		rate.changing.Store(0)
   504  	}
   505  }
   506  
   507  const (
   508  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
   509  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
   510  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
   511  )