github.com/fumiama/water@v0.0.0-20211231134027-da391938d6ac/syscalls_windows.go (about)

     1  package water
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"os"
     9  	"sync"
    10  	"sync/atomic"
    11  	"syscall"
    12  	"time"
    13  	"unsafe"
    14  
    15  	"golang.org/x/sys/windows"
    16  	"golang.org/x/sys/windows/registry"
    17  
    18  	"github.com/fumiama/wintun"
    19  )
    20  
// TAP mode on Windows requires a TAP driver, available standalone from
// https://github.com/OpenVPN/tap-windows6
// or installed together with OpenVPN:
// https://github.com/OpenVPN/openvpn
// TUN mode (see openDev) uses Wintun instead.
    25  
    26  const (
    27  	// tapDriverKey is the location of the TAP driver key.
    28  	tapDriverKey = `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`
    29  	// netConfigKey is the location of the TAP adapter's network config.
    30  	netConfigKey = `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}`
    31  )
    32  
    33  var (
    34  	errIfceNameNotFound = errors.New("failed to find the name of interface")
    35  	// Device Control Codes
    36  	tap_win_ioctl_get_mac = tap_control_code(1, 0)
    37  	// tap_win_ioctl_get_version         = tap_control_code(2, 0)
    38  	// tap_win_ioctl_get_mtu             = tap_control_code(3, 0)
    39  	// tap_win_ioctl_get_info            = tap_control_code(4, 0)
    40  	// tap_ioctl_config_point_to_point   = tap_control_code(5, 0)
    41  	tap_ioctl_set_media_status = tap_control_code(6, 0)
    42  	// tap_win_ioctl_config_dhcp_masq    = tap_control_code(7, 0)
    43  	// tap_win_ioctl_get_log_line        = tap_control_code(8, 0)
    44  	// tap_win_ioctl_config_dhcp_set_opt = tap_control_code(9, 0)
    45  	// tap_ioctl_config_tun              = tap_control_code(10, 0)
    46  	// w32 api
    47  	file_device_unknown = uint32(0x00000022)
    48  	nCreateEvent,
    49  	nResetEvent,
    50  	nGetOverlappedResult uintptr
    51  )
    52  
    53  func init() {
    54  	k32, err := syscall.LoadLibrary("kernel32.dll")
    55  	if err != nil {
    56  		panic("LoadLibrary " + err.Error())
    57  	}
    58  	defer syscall.FreeLibrary(k32)
    59  
    60  	nCreateEvent = getProcAddr(k32, "CreateEventW")
    61  	nResetEvent = getProcAddr(k32, "ResetEvent")
    62  	nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
    63  }
    64  
    65  func getProcAddr(lib syscall.Handle, name string) uintptr {
    66  	addr, err := syscall.GetProcAddress(lib, name)
    67  	if err != nil {
    68  		panic(name + " " + err.Error())
    69  	}
    70  	return addr
    71  }
    72  
    73  func resetEvent(h syscall.Handle) error {
    74  	r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
    75  	if r == 0 {
    76  		return err
    77  	}
    78  	return nil
    79  }
    80  
    81  func getOverlappedResult(h syscall.Handle, overlapped *syscall.Overlapped) (int, error) {
    82  	var n int
    83  	r, _, err := syscall.Syscall6(nGetOverlappedResult, 4,
    84  		uintptr(h),
    85  		uintptr(unsafe.Pointer(overlapped)),
    86  		uintptr(unsafe.Pointer(&n)), 1, 0, 0)
    87  	if r == 0 {
    88  		return n, err
    89  	}
    90  
    91  	return n, nil
    92  }
    93  
    94  func newOverlapped() (*syscall.Overlapped, error) {
    95  	var overlapped syscall.Overlapped
    96  	r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
    97  	if r == 0 {
    98  		return nil, err
    99  	}
   100  	overlapped.HEvent = syscall.Handle(r)
   101  	return &overlapped, nil
   102  }
   103  
   104  type wfile struct {
   105  	fd syscall.Handle
   106  	rl sync.Mutex
   107  	wl sync.Mutex
   108  	ro *syscall.Overlapped
   109  	wo *syscall.Overlapped
   110  }
   111  
   112  func (f *wfile) Close() error {
   113  	return syscall.Close(f.fd)
   114  }
   115  
   116  func (f *wfile) Write(b []byte) (int, error) {
   117  	f.wl.Lock()
   118  	defer f.wl.Unlock()
   119  
   120  	if err := resetEvent(f.wo.HEvent); err != nil {
   121  		return 0, err
   122  	}
   123  	var n uint32
   124  	err := syscall.WriteFile(f.fd, b, &n, f.wo)
   125  	if err != nil && err != syscall.ERROR_IO_PENDING {
   126  		return int(n), err
   127  	}
   128  	return getOverlappedResult(f.fd, f.wo)
   129  }
   130  
   131  func (f *wfile) Read(b []byte) (int, error) {
   132  	f.rl.Lock()
   133  	defer f.rl.Unlock()
   134  
   135  	if err := resetEvent(f.ro.HEvent); err != nil {
   136  		return 0, err
   137  	}
   138  	var done uint32
   139  	err := syscall.ReadFile(f.fd, b, &done, f.ro)
   140  	if err != nil && err != syscall.ERROR_IO_PENDING {
   141  		return int(done), err
   142  	}
   143  	return getOverlappedResult(f.fd, f.ro)
   144  }
   145  
   146  func ctl_code(device_type, function, method, access uint32) uint32 {
   147  	return (device_type << 16) | (access << 14) | (function << 2) | method
   148  }
   149  
   150  func tap_control_code(request, method uint32) uint32 {
   151  	return ctl_code(file_device_unknown, request, method, 0)
   152  }
   153  
   154  // getdeviceid finds out a TAP device from registry, it *may* requires privileged right to prevent some weird issue.
   155  func getdeviceid(componentID string, interfaceName string) (deviceid string, err error) {
   156  	k, err := registry.OpenKey(registry.LOCAL_MACHINE, tapDriverKey, registry.READ)
   157  	if err != nil {
   158  		return "", errors.New("Failed to open the adapter registry, TAP driver may be not installed" + err.Error())
   159  	}
   160  	defer k.Close()
   161  	// read all subkeys, it should not return an err here
   162  	keys, err := k.ReadSubKeyNames(-1)
   163  	if err != nil {
   164  		return "", err
   165  	}
   166  	// find the one matched ComponentId
   167  	for _, v := range keys {
   168  		key, err := registry.OpenKey(registry.LOCAL_MACHINE, tapDriverKey+"\\"+v, registry.READ)
   169  		if err != nil {
   170  			continue
   171  		}
   172  		val, _, err := key.GetStringValue("ComponentId")
   173  		if err != nil {
   174  			key.Close()
   175  			continue
   176  		}
   177  		if val == componentID {
   178  			val, _, err = key.GetStringValue("NetCfgInstanceId")
   179  			if err != nil {
   180  				key.Close()
   181  				continue
   182  			}
   183  			if len(interfaceName) > 0 {
   184  				key2 := fmt.Sprintf("%s\\%s\\Connection", netConfigKey, val)
   185  				k2, err := registry.OpenKey(registry.LOCAL_MACHINE, key2, registry.READ)
   186  				if err != nil {
   187  					continue
   188  				}
   189  				defer k2.Close()
   190  				val, _, err := k2.GetStringValue("Name")
   191  				if err != nil || val != interfaceName {
   192  					continue
   193  				}
   194  			}
   195  			key.Close()
   196  			return val, nil
   197  		}
   198  		key.Close()
   199  	}
   200  	if len(interfaceName) > 0 {
   201  		return "", errors.New("Failed to find the tap device in registry with specified ComponentId '" + componentID + "' and InterfaceName '" + interfaceName + "', TAP driver may be not installed or you may have specified an interface name that doesn't exist")
   202  	}
   203  
   204  	return "", errors.New("Failed to find the tap device in registry with specified ComponentId '" + componentID + "', TAP driver may be not installed")
   205  }
   206  
   207  // setStatus is used to bring up or bring down the interface
   208  func setStatus(fd syscall.Handle, status bool) error {
   209  	var bytesReturned uint32
   210  	rdbbuf := make([]byte, syscall.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
   211  	code := []byte{0x00, 0x00, 0x00, 0x00}
   212  	if status {
   213  		code[0] = 0x01
   214  	}
   215  	return syscall.DeviceIoControl(fd, tap_ioctl_set_media_status, &code[0], uint32(4), &rdbbuf[0], uint32(len(rdbbuf)), &bytesReturned, nil)
   216  }
   217  
   218  func openTap(config Config) (ifce *Interface, err error) {
   219  	if config.ComponentID == "" {
   220  		config.ComponentID = "root\\tap0901"
   221  	}
   222  	// find the device in registry.
   223  	deviceid, err := getdeviceid(config.PlatformSpecificParams.ComponentID, config.PlatformSpecificParams.InterfaceName)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	path := "\\\\.\\Global\\" + deviceid + ".tap"
   228  	pathp, err := syscall.UTF16PtrFromString(path)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	// type Handle uintptr
   233  	file, err := syscall.CreateFile(pathp, syscall.GENERIC_READ|syscall.GENERIC_WRITE, uint32(syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE), nil, syscall.OPEN_EXISTING, syscall.FILE_ATTRIBUTE_SYSTEM|syscall.FILE_FLAG_OVERLAPPED, 0)
   234  	// if err hanppens, close the interface.
   235  	defer func() {
   236  		if err != nil {
   237  			syscall.Close(file)
   238  		}
   239  		if err := recover(); err != nil {
   240  			syscall.Close(file)
   241  		}
   242  	}()
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  	var bytesReturned uint32
   247  
   248  	// find the mac address of tap device, use this to find the name of interface
   249  	mac := make([]byte, 6)
   250  	err = syscall.DeviceIoControl(file, tap_win_ioctl_get_mac, &mac[0], uint32(len(mac)), &mac[0], uint32(len(mac)), &bytesReturned, nil)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	// fd := os.NewFile(uintptr(file), path)
   256  	ro, err := newOverlapped()
   257  	if err != nil {
   258  		return
   259  	}
   260  	wo, err := newOverlapped()
   261  	if err != nil {
   262  		return
   263  	}
   264  	fd := &wfile{fd: file, ro: ro, wo: wo}
   265  	ifce = &Interface{isTAP: (config.DeviceType == TAP), ReadWriteCloser: fd}
   266  
   267  	// bring up device.
   268  	if err := setStatus(file, true); err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	// find the name of tap interface(u need it to set the ip or other command)
   273  	ifces, err := net.Interfaces()
   274  	if err != nil {
   275  		return
   276  	}
   277  
   278  	for _, v := range ifces {
   279  		if len(v.HardwareAddr) < 6 {
   280  			continue
   281  		}
   282  		if bytes.Equal(v.HardwareAddr[:6], mac[:6]) {
   283  			ifce.name = v.Name
   284  			return
   285  		}
   286  	}
   287  
   288  	return nil, errIfceNameNotFound
   289  }
   290  
   291  // https://github.com/WireGuard/wireguard-go/blob/master/tun/tun_windows.go
   292  const (
   293  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
   294  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
   295  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
   296  )
   297  
   298  //go:linkname procyield runtime.procyield
   299  func procyield(cycles uint32)
   300  
   301  //go:linkname nanotime runtime.nanotime
   302  func nanotime() int64
   303  
   304  type rateJuggler struct {
   305  	current       uint64
   306  	nextByteCount uint64
   307  	nextStartTime int64
   308  	changing      int32
   309  }
   310  
   311  func (rate *rateJuggler) update(packetLen uint64) {
   312  	now := nanotime()
   313  	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
   314  	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
   315  	if period >= rateMeasurementGranularity {
   316  		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
   317  			return
   318  		}
   319  		atomic.StoreInt64(&rate.nextStartTime, now)
   320  		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
   321  		atomic.StoreUint64(&rate.nextByteCount, 0)
   322  		atomic.StoreInt32(&rate.changing, 0)
   323  	}
   324  }
   325  
   326  type wintunRWC struct {
   327  	ad       *wintun.Adapter
   328  	s        wintun.Session
   329  	rate     rateJuggler
   330  	readwait windows.Handle
   331  	mu       sync.Mutex
   332  	readbuf  []byte
   333  	isclosed bool
   334  }
   335  
   336  func (w *wintunRWC) Close() error {
   337  	w.isclosed = true
   338  	w.s.End()
   339  	return w.ad.Close()
   340  }
   341  
   342  func (w *wintunRWC) Write(b []byte) (int, error) {
   343  	w.rate.update(uint64(len(b)))
   344  	w.mu.Lock()
   345  	defer w.mu.Unlock()
   346  ALLOC:
   347  	packet, err := w.s.AllocateSendPacket(len(b))
   348  	switch err {
   349  	case nil:
   350  		copy(packet, b)
   351  		w.s.SendPacket(packet)
   352  		return len(b), nil
   353  	case windows.ERROR_HANDLE_EOF:
   354  		w.s.End()
   355  		w.s, err = w.ad.StartSession(0x800000) // Ring capacity, 8 MiB
   356  		if err == nil {
   357  			goto ALLOC
   358  		}
   359  		return 0, os.ErrClosed
   360  	case windows.ERROR_BUFFER_OVERFLOW:
   361  		return 0, nil // Dropping when ring is full.
   362  	default:
   363  		return 0, err
   364  	}
   365  }
   366  
   367  func (w *wintunRWC) Read(b []byte) (int, error) {
   368  	w.mu.Lock()
   369  	defer w.mu.Unlock()
   370  
   371  	n := 0
   372  
   373  	if w.readbuf != nil {
   374  		n = copy(b, w.readbuf)
   375  		if len(w.readbuf) >= len(b) {
   376  			w.readbuf = w.readbuf[len(b):]
   377  			if len(w.readbuf) == 0 {
   378  				w.readbuf = nil
   379  			}
   380  			return n, nil
   381  		}
   382  		b = b[len(w.readbuf):]
   383  		w.readbuf = nil
   384  	}
   385  
   386  RETRY:
   387  	if w.isclosed {
   388  		return 0, errors.New("wintun is closed")
   389  	}
   390  	start := nanotime()
   391  	shouldSpin := atomic.LoadUint64(&w.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&w.rate.nextStartTime)) <= rateMeasurementGranularity*2
   392  	for {
   393  		packet, err := w.s.ReceivePacket()
   394  		switch err {
   395  		case nil:
   396  			packetSize := len(packet)
   397  			n += copy(b, packet)
   398  			if len(packet) > len(b) {
   399  				w.readbuf = make([]byte, len(packet)-len(b))
   400  				copy(w.readbuf, packet[len(b):])
   401  			}
   402  			w.s.ReleaseReceivePacket(packet)
   403  			w.rate.update(uint64(packetSize))
   404  			return n, nil
   405  		case windows.ERROR_NO_MORE_ITEMS:
   406  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
   407  				w.mu.Unlock()
   408  				windows.WaitForSingleObject(w.readwait, windows.INFINITE)
   409  				w.mu.Lock()
   410  				goto RETRY
   411  			}
   412  			w.mu.Unlock()
   413  			procyield(1)
   414  			w.mu.Lock()
   415  			continue
   416  		}
   417  		w.s.End()
   418  		w.s, err = w.ad.StartSession(0x800000) // Ring capacity, 8 MiB
   419  		if err == nil {
   420  			continue
   421  		}
   422  		return n, err
   423  	}
   424  }
   425  
   426  // openDev find and open an interface.
   427  func openDev(config Config) (ifce *Interface, err error) {
   428  	// TAP
   429  	if config.DeviceType == TAP {
   430  		return openTap(config)
   431  	}
   432  	// TUN
   433  	var ad *wintun.Adapter
   434  	if config.InterfaceName == "" {
   435  		config.InterfaceName = "WaterWinTunInterface"
   436  	}
   437  	if config.ComponentID == "" {
   438  		config.ComponentID = "WaterWintun"
   439  	}
   440  	ad, err = wintun.OpenAdapter(config.InterfaceName)
   441  	if err != nil {
   442  		ad, err = wintun.CreateAdapter(config.InterfaceName, config.ComponentID, nil)
   443  	}
   444  
   445  	if err != nil {
   446  		return
   447  	}
   448  	s, err := ad.StartSession(0x800000) // Ring capacity, 8 MiB
   449  	if err != nil {
   450  		ad.Close()
   451  		return
   452  	}
   453  	return &Interface{ReadWriteCloser: &wintunRWC{s: s, ad: ad, readwait: s.ReadWaitEvent()}, name: config.InterfaceName}, nil
   454  }