github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/conn/conn_windows.go (about)

     1  package conn
     2  
     3  import (
     4  	"fmt"
     5  	"net/netip"
     6  	"unsafe"
     7  
     8  	"golang.org/x/sys/windows"
     9  )
    10  
    11  const (
    12  	IP_MTU_DISCOVER   = 71
    13  	IPV6_MTU_DISCOVER = 71
    14  )
    15  
    16  // enum PMTUD_STATE from ws2ipdef.h
    17  const (
    18  	IP_PMTUDISC_NOT_SET = iota
    19  	IP_PMTUDISC_DO
    20  	IP_PMTUDISC_DONT
    21  	IP_PMTUDISC_PROBE
    22  	IP_PMTUDISC_MAX
    23  )
    24  
    25  func setPMTUD(fd int, network string) error {
    26  	// Set IP_MTU_DISCOVER for both v4 and v6.
    27  	if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
    28  		return fmt.Errorf("failed to set socket option IP_MTU_DISCOVER: %w", err)
    29  	}
    30  
    31  	switch network {
    32  	case "tcp4", "udp4":
    33  	case "tcp6", "udp6":
    34  		if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_MTU_DISCOVER, IP_PMTUDISC_DO); err != nil {
    35  			return fmt.Errorf("failed to set socket option IPV6_MTU_DISCOVER: %w", err)
    36  		}
    37  	default:
    38  		return fmt.Errorf("unsupported network: %s", network)
    39  	}
    40  
    41  	return nil
    42  }
    43  
    44  func setRecvPktinfo(fd int, network string) error {
    45  	// Set IP_PKTINFO for both v4 and v6.
    46  	if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, windows.IP_PKTINFO, 1); err != nil {
    47  		return fmt.Errorf("failed to set socket option IP_PKTINFO: %w", err)
    48  	}
    49  
    50  	switch network {
    51  	case "udp4":
    52  	case "udp6":
    53  		if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, windows.IPV6_PKTINFO, 1); err != nil {
    54  			return fmt.Errorf("failed to set socket option IPV6_PKTINFO: %w", err)
    55  		}
    56  	default:
    57  		return fmt.Errorf("unsupported network: %s", network)
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (lso ListenerSocketOptions) buildSetFns() setFuncSlice {
    64  	return setFuncSlice{}.
    65  		appendSetPMTUDFunc(lso.PathMTUDiscovery).
    66  		appendSetRecvPktinfoFunc(lso.ReceivePacketInfo)
    67  }
    68  
    69  // Structure CMSGHDR from ws2def.h
    70  type Cmsghdr struct {
    71  	Len   uint
    72  	Level int32
    73  	Type  int32
    74  }
    75  
    76  // Structure IN_PKTINFO from ws2ipdef.h
    77  type Inet4Pktinfo struct {
    78  	Addr    [4]byte
    79  	Ifindex uint32
    80  }
    81  
    82  // Structure IN6_PKTINFO from ws2ipdef.h
    83  type Inet6Pktinfo struct {
    84  	Addr    [16]byte
    85  	Ifindex uint32
    86  }
    87  
    88  const (
    89  	SizeofCmsghdr      = unsafe.Sizeof(Cmsghdr{})
    90  	SizeofInet4Pktinfo = unsafe.Sizeof(Inet4Pktinfo{})
    91  	SizeofInet6Pktinfo = unsafe.Sizeof(Inet6Pktinfo{})
    92  )
    93  
    94  const SizeofPtr = unsafe.Sizeof(uintptr(0))
    95  
    96  // SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages.
    97  const SocketControlMessageBufferSize = SizeofCmsghdr + (SizeofInet6Pktinfo+SizeofPtr-1) & ^(SizeofPtr-1)
    98  
    99  // ParsePktinfoCmsg parses a single socket control message of type IP_PKTINFO or IPV6_PKTINFO,
   100  // and returns the IP address and index of the network interface the packet was received from,
   101  // or an error.
   102  //
   103  // This function is only implemented for Linux, macOS and Windows. On other platforms, this is a no-op.
   104  func ParsePktinfoCmsg(cmsg []byte) (netip.Addr, uint32, error) {
   105  	if len(cmsg) < int(SizeofCmsghdr) {
   106  		return netip.Addr{}, 0, fmt.Errorf("control message length %d is shorter than cmsghdr length", len(cmsg))
   107  	}
   108  
   109  	cmsghdr := (*Cmsghdr)(unsafe.Pointer(&cmsg[0]))
   110  
   111  	switch {
   112  	case cmsghdr.Level == windows.IPPROTO_IP && cmsghdr.Type == windows.IP_PKTINFO && len(cmsg) >= int(SizeofCmsghdr+SizeofInet4Pktinfo):
   113  		pktinfo := (*Inet4Pktinfo)(unsafe.Pointer(&cmsg[SizeofCmsghdr]))
   114  		return netip.AddrFrom4(pktinfo.Addr), pktinfo.Ifindex, nil
   115  
   116  	case cmsghdr.Level == windows.IPPROTO_IPV6 && cmsghdr.Type == windows.IPV6_PKTINFO && len(cmsg) >= int(SizeofCmsghdr+SizeofInet6Pktinfo):
   117  		pktinfo := (*Inet6Pktinfo)(unsafe.Pointer(&cmsg[SizeofCmsghdr]))
   118  		return netip.AddrFrom16(pktinfo.Addr), pktinfo.Ifindex, nil
   119  
   120  	default:
   121  		return netip.Addr{}, 0, fmt.Errorf("unknown control message level %d type %d", cmsghdr.Level, cmsghdr.Type)
   122  	}
   123  }