github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/netlink/netlink_windows.go (about)

     1  package netlink
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  	"unsafe"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    11  	"golang.org/x/sys/windows"
    12  )
    13  
    14  var (
    15  	modIphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
    16  
    17  	procGetExtendedTcpTable = modIphlpapi.NewProc("GetExtendedTcpTable")
    18  	procGetExtendedUdpTable = modIphlpapi.NewProc("GetExtendedUdpTable")
    19  )
    20  
    21  func FindProcessName(network string, ip net.IP, srcPort uint16, to net.IP, toPort uint16) (string, error) {
    22  	family := uint32(windows.AF_INET)
    23  	if ip.To4() == nil {
    24  		family = windows.AF_INET6
    25  	}
    26  
    27  	var protocol uint32
    28  	switch network {
    29  	case "tcp":
    30  		protocol = windows.IPPROTO_TCP
    31  	case "udp":
    32  		protocol = windows.IPPROTO_UDP
    33  	default:
    34  		return "", errors.New("ErrInvalidNetwork")
    35  	}
    36  
    37  	saddr, _ := netip.AddrFromSlice(ip)
    38  	daddr, _ := netip.AddrFromSlice(to)
    39  
    40  	pid, err := findPidByConnectionEndpoint(family,
    41  		protocol,
    42  		netip.AddrPortFrom(saddr, srcPort),
    43  		netip.AddrPortFrom(daddr, toPort),
    44  	)
    45  	if err != nil {
    46  		return "", err
    47  	}
    48  
    49  	return getExecPathFromPID(pid)
    50  }
    51  
    52  func findPidByConnectionEndpoint(family uint32, protocol uint32, from netip.AddrPort, to netip.AddrPort) (uint32, error) {
    53  	buf := pool.GetBytes(0)
    54  	defer pool.PutBytes(buf)
    55  
    56  	bufSize := uint32(len(buf))
    57  
    58  loop:
    59  	for {
    60  		var ret uintptr
    61  
    62  		switch protocol {
    63  		case windows.IPPROTO_TCP:
    64  			ret, _, _ = procGetExtendedTcpTable.Call(
    65  				uintptr(unsafe.Pointer(unsafe.SliceData(buf))),
    66  				uintptr(unsafe.Pointer(&bufSize)),
    67  				0,
    68  				uintptr(family),
    69  				4, // TCP_TABLE_OWNER_PID_CONNECTIONS
    70  				0,
    71  			)
    72  		case windows.IPPROTO_UDP:
    73  			ret, _, _ = procGetExtendedUdpTable.Call(
    74  				uintptr(unsafe.Pointer(unsafe.SliceData(buf))),
    75  				uintptr(unsafe.Pointer(&bufSize)),
    76  				0,
    77  				uintptr(family),
    78  				1, // UDP_TABLE_OWNER_PID
    79  				0,
    80  			)
    81  		default:
    82  			return 0, errors.New("unsupported network")
    83  		}
    84  
    85  		switch ret {
    86  		case 0:
    87  			buf = buf[:bufSize]
    88  
    89  			break loop
    90  		case uintptr(windows.ERROR_INSUFFICIENT_BUFFER):
    91  			pool.PutBytes(buf)
    92  			buf = pool.GetBytes(int(bufSize))
    93  
    94  			continue loop
    95  		default:
    96  			return 0, fmt.Errorf("syscall error: %d", ret)
    97  		}
    98  	}
    99  
   100  	if len(buf) < int(unsafe.Sizeof(uint32(0))) {
   101  		return 0, fmt.Errorf("invalid table size: %d", len(buf))
   102  	}
   103  
   104  	entriesSize := *(*uint32)(unsafe.Pointer(&buf[0]))
   105  
   106  	switch protocol {
   107  	case windows.IPPROTO_TCP:
   108  		if family == windows.AF_INET {
   109  			type MibTcpRowOwnerPid struct {
   110  				State      uint32
   111  				LocalAddr  [4]byte
   112  				LocalPort  uint32
   113  				RemoteAddr [4]byte
   114  				RemotePort uint32
   115  				OwningPid  uint32
   116  			}
   117  
   118  			if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcpRowOwnerPid{})) {
   119  				return 0, fmt.Errorf("invalid tables size: %d", len(buf))
   120  			}
   121  
   122  			entries := unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
   123  			for _, entry := range entries {
   124  				localAddr := netip.AddrFrom4(entry.LocalAddr)
   125  				localPort := windows.Ntohs(uint16(entry.LocalPort))
   126  				remoteAddr := netip.AddrFrom4(entry.RemoteAddr)
   127  				remotePort := windows.Ntohs(uint16(entry.RemotePort))
   128  
   129  				if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() {
   130  					return entry.OwningPid, nil
   131  				}
   132  			}
   133  		} else {
   134  			type MibTcp6RowOwnerPid struct {
   135  				LocalAddr     [16]byte
   136  				LocalScopeID  uint32
   137  				LocalPort     uint32
   138  				RemoteAddr    [16]byte
   139  				RemoteScopeID uint32
   140  				RemotePort    uint32
   141  				State         uint32
   142  				OwningPid     uint32
   143  			}
   144  
   145  			if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcp6RowOwnerPid{})) {
   146  				return 0, fmt.Errorf("invalid tables size: %d", len(buf))
   147  			}
   148  
   149  			entries := unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
   150  			for _, entry := range entries {
   151  				localAddr := netip.AddrFrom16(entry.LocalAddr)
   152  				localPort := windows.Ntohs(uint16(entry.LocalPort))
   153  				remoteAddr := netip.AddrFrom16(entry.RemoteAddr)
   154  				remotePort := windows.Ntohs(uint16(entry.RemotePort))
   155  
   156  				if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() {
   157  					return entry.OwningPid, nil
   158  				}
   159  			}
   160  		}
   161  	case windows.IPPROTO_UDP:
   162  		if family == windows.AF_INET {
   163  			type MibUdpRowOwnerPid struct {
   164  				LocalAddr [4]byte
   165  				LocalPort uint32
   166  				OwningPid uint32
   167  			}
   168  
   169  			if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdpRowOwnerPid{})) {
   170  				return 0, fmt.Errorf("invalid tables size: %d", len(buf))
   171  			}
   172  
   173  			entries := unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
   174  			for _, entry := range entries {
   175  				localAddr := netip.AddrFrom4(entry.LocalAddr)
   176  				localPort := windows.Ntohs(uint16(entry.LocalPort))
   177  
   178  				if (localAddr == from.Addr() || localAddr.IsUnspecified()) && localPort == from.Port() {
   179  					return entry.OwningPid, nil
   180  				}
   181  			}
   182  		} else {
   183  			type MibUdp6RowOwnerPid struct {
   184  				LocalAddr    [16]byte
   185  				LocalScopeId uint32
   186  				LocalPort    uint32
   187  				OwningPid    uint32
   188  			}
   189  
   190  			if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdp6RowOwnerPid{})) {
   191  				return 0, fmt.Errorf("invalid tables size: %d", len(buf))
   192  			}
   193  
   194  			entries := unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize)
   195  			for _, entry := range entries {
   196  				localAddr := netip.AddrFrom16(entry.LocalAddr)
   197  				localPort := windows.Ntohs(uint16(entry.LocalPort))
   198  
   199  				if (localAddr == from.Addr() || localAddr.IsUnspecified()) && localPort == from.Port() {
   200  					return entry.OwningPid, nil
   201  				}
   202  			}
   203  		}
   204  	default:
   205  		return 0, errors.New("ErrInvalidNetwork")
   206  	}
   207  
   208  	return 0, errors.New("ErrNotFound")
   209  }
   210  
   211  func getExecPathFromPID(pid uint32) (string, error) {
   212  	// kernel process starts with a colon in order to distinguish with normal processes
   213  	switch pid {
   214  	case 0:
   215  		// reserved pid for system idle process
   216  		return ":System Idle Process", nil
   217  	case 4:
   218  		// reserved pid for windows kernel image
   219  		return ":System", nil
   220  	}
   221  	h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid)
   222  	if err != nil {
   223  		return "", err
   224  	}
   225  	defer windows.CloseHandle(h)
   226  
   227  	buf := make([]uint16, windows.MAX_LONG_PATH)
   228  	size := uint32(len(buf))
   229  
   230  	err = windows.QueryFullProcessImageName(h, 0, &buf[0], &size)
   231  	if err != nil {
   232  		return "", err
   233  	}
   234  
   235  	return windows.UTF16ToString(buf[:size]), nil
   236  }