github.com/chwjbn/xclash@v0.2.0/component/process/process_windows.go (about)

     1  package process
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"path/filepath"
     7  	"sync"
     8  	"syscall"
     9  	"unsafe"
    10  
    11  	"github.com/chwjbn/xclash/log"
    12  
    13  	"golang.org/x/sys/windows"
    14  )
    15  
    16  const (
    17  	tcpTableFunc      = "GetExtendedTcpTable"
    18  	tcpTablePidConn   = 4
    19  	udpTableFunc      = "GetExtendedUdpTable"
    20  	udpTablePid       = 1
    21  	queryProcNameFunc = "QueryFullProcessImageNameW"
    22  )
    23  
    24  var (
    25  	getExTCPTable uintptr
    26  	getExUDPTable uintptr
    27  	queryProcName uintptr
    28  
    29  	once sync.Once
    30  )
    31  
    32  func initWin32API() error {
    33  	h, err := windows.LoadLibrary("iphlpapi.dll")
    34  	if err != nil {
    35  		return fmt.Errorf("LoadLibrary iphlpapi.dll failed: %s", err.Error())
    36  	}
    37  
    38  	getExTCPTable, err = windows.GetProcAddress(h, tcpTableFunc)
    39  	if err != nil {
    40  		return fmt.Errorf("GetProcAddress of %s failed: %s", tcpTableFunc, err.Error())
    41  	}
    42  
    43  	getExUDPTable, err = windows.GetProcAddress(h, udpTableFunc)
    44  	if err != nil {
    45  		return fmt.Errorf("GetProcAddress of %s failed: %s", udpTableFunc, err.Error())
    46  	}
    47  
    48  	h, err = windows.LoadLibrary("kernel32.dll")
    49  	if err != nil {
    50  		return fmt.Errorf("LoadLibrary kernel32.dll failed: %s", err.Error())
    51  	}
    52  
    53  	queryProcName, err = windows.GetProcAddress(h, queryProcNameFunc)
    54  	if err != nil {
    55  		return fmt.Errorf("GetProcAddress of %s failed: %s", queryProcNameFunc, err.Error())
    56  	}
    57  
    58  	return nil
    59  }
    60  
    61  func findProcessName(network string, ip net.IP, srcPort int) (string, error) {
    62  	once.Do(func() {
    63  		err := initWin32API()
    64  		if err != nil {
    65  			log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error())
    66  			log.Warnln("All PROCESS-NAMES rules will be skiped")
    67  			return
    68  		}
    69  	})
    70  	family := windows.AF_INET
    71  	if ip.To4() == nil {
    72  		family = windows.AF_INET6
    73  	}
    74  
    75  	var class int
    76  	var fn uintptr
    77  	switch network {
    78  	case TCP:
    79  		fn = getExTCPTable
    80  		class = tcpTablePidConn
    81  	case UDP:
    82  		fn = getExUDPTable
    83  		class = udpTablePid
    84  	default:
    85  		return "", ErrInvalidNetwork
    86  	}
    87  
    88  	buf, err := getTransportTable(fn, family, class)
    89  	if err != nil {
    90  		return "", err
    91  	}
    92  
    93  	s := newSearcher(family == windows.AF_INET, network == TCP)
    94  
    95  	pid, err := s.Search(buf, ip, uint16(srcPort))
    96  	if err != nil {
    97  		return "", err
    98  	}
    99  	return getExecPathFromPID(pid)
   100  }
   101  
   102  type searcher struct {
   103  	itemSize int
   104  	port     int
   105  	ip       int
   106  	ipSize   int
   107  	pid      int
   108  	tcpState int
   109  }
   110  
   111  func (s *searcher) Search(b []byte, ip net.IP, port uint16) (uint32, error) {
   112  	n := int(readNativeUint32(b[:4]))
   113  	itemSize := s.itemSize
   114  	for i := 0; i < n; i++ {
   115  		row := b[4+itemSize*i : 4+itemSize*(i+1)]
   116  
   117  		if s.tcpState >= 0 {
   118  			tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4])
   119  			// MIB_TCP_STATE_ESTAB, only check established connections for TCP
   120  			if tcpState != 5 {
   121  				continue
   122  			}
   123  		}
   124  
   125  		// according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian.
   126  		// this field can be illustrated as follows depends on different machine endianess:
   127  		//     little endian: [ MSB LSB  0   0  ]   interpret as native uint32 is ((LSB<<8)|MSB)
   128  		//       big  endian: [  0   0  MSB LSB ]   interpret as native uint32 is ((MSB<<8)|LSB)
   129  		// so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32
   130  		srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4])))
   131  		if srcPort != port {
   132  			continue
   133  		}
   134  
   135  		srcIP := net.IP(row[s.ip : s.ip+s.ipSize])
   136  		// windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto
   137  		if !ip.Equal(srcIP) && (!srcIP.IsUnspecified() || s.tcpState != -1) {
   138  			continue
   139  		}
   140  
   141  		pid := readNativeUint32(row[s.pid : s.pid+4])
   142  		return pid, nil
   143  	}
   144  	return 0, ErrNotFound
   145  }
   146  
   147  func newSearcher(isV4, isTCP bool) *searcher {
   148  	var itemSize, port, ip, ipSize, pid int
   149  	tcpState := -1
   150  	switch {
   151  	case isV4 && isTCP:
   152  		// struct MIB_TCPROW_OWNER_PID
   153  		itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0
   154  	case isV4 && !isTCP:
   155  		// struct MIB_UDPROW_OWNER_PID
   156  		itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8
   157  	case !isV4 && isTCP:
   158  		// struct MIB_TCP6ROW_OWNER_PID
   159  		itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48
   160  	case !isV4 && !isTCP:
   161  		// struct MIB_UDP6ROW_OWNER_PID
   162  		itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24
   163  	}
   164  
   165  	return &searcher{
   166  		itemSize: itemSize,
   167  		port:     port,
   168  		ip:       ip,
   169  		ipSize:   ipSize,
   170  		pid:      pid,
   171  		tcpState: tcpState,
   172  	}
   173  }
   174  
   175  func getTransportTable(fn uintptr, family int, class int) ([]byte, error) {
   176  	for size, buf := uint32(8), make([]byte, 8); ; {
   177  		ptr := unsafe.Pointer(&buf[0])
   178  		err, _, _ := syscall.Syscall6(fn, 6, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0)
   179  
   180  		switch err {
   181  		case 0:
   182  			return buf, nil
   183  		case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER):
   184  			buf = make([]byte, size)
   185  		default:
   186  			return nil, fmt.Errorf("syscall error: %d", err)
   187  		}
   188  	}
   189  }
   190  
   191  func readNativeUint32(b []byte) uint32 {
   192  	return *(*uint32)(unsafe.Pointer(&b[0]))
   193  }
   194  
   195  func getExecPathFromPID(pid uint32) (string, error) {
   196  	// kernel process starts with a colon in order to distinguish with normal processes
   197  	switch pid {
   198  	case 0:
   199  		// reserved pid for system idle process
   200  		return ":System Idle Process", nil
   201  	case 4:
   202  		// reserved pid for windows kernel image
   203  		return ":System", nil
   204  	}
   205  	h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
   206  	if err != nil {
   207  		return "", err
   208  	}
   209  	defer windows.CloseHandle(h)
   210  
   211  	buf := make([]uint16, syscall.MAX_LONG_PATH)
   212  	size := uint32(len(buf))
   213  	r1, _, err := syscall.Syscall6(
   214  		queryProcName, 4,
   215  		uintptr(h),
   216  		uintptr(1),
   217  		uintptr(unsafe.Pointer(&buf[0])),
   218  		uintptr(unsafe.Pointer(&size)),
   219  		0, 0)
   220  	if r1 == 0 {
   221  		return "", err
   222  	}
   223  	return filepath.Base(syscall.UTF16ToString(buf[:size])), nil
   224  }