github.com/metacubex/mihomo@v1.18.5/listener/tproxy/udp_linux.go (about)

     1  //go:build linux
     2  
     3  package tproxy
     4  
     5  import (
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  	"strconv"
    11  	"syscall"
    12  
    13  	"golang.org/x/sys/unix"
    14  )
    15  
    16  const (
    17  	IPV6_TRANSPARENT     = 0x4b
    18  	IPV6_RECVORIGDSTADDR = 0x4a
    19  )
    20  
    21  // dialUDP acts like net.DialUDP for transparent proxy.
    22  // It binds to a non-local address(`lAddr`).
    23  func dialUDP(network string, lAddr, rAddr netip.AddrPort) (uc *net.UDPConn, err error) {
    24  	rSockAddr, err := udpAddrToSockAddr(rAddr)
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  
    29  	lSockAddr, err := udpAddrToSockAddr(lAddr)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	fd, err := syscall.Socket(udpAddrFamily(network, lAddr, rAddr), syscall.SOCK_DGRAM, 0)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	defer func() {
    40  		if err != nil {
    41  			syscall.Close(fd)
    42  		}
    43  	}()
    44  
    45  	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	if err = syscall.SetsockoptInt(fd, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if err = syscall.Bind(fd, lSockAddr); err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	if err = syscall.Connect(fd, rSockAddr); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	fdFile := os.NewFile(uintptr(fd), fmt.Sprintf("net-udp-dial-%s", rAddr.String()))
    62  	defer fdFile.Close()
    63  
    64  	c, err := net.FileConn(fdFile)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	return c.(*net.UDPConn), nil
    70  }
    71  
    72  func udpAddrToSockAddr(addr netip.AddrPort) (syscall.Sockaddr, error) {
    73  	if addr.Addr().Is4() {
    74  		return &syscall.SockaddrInet4{Addr: addr.Addr().As4(), Port: int(addr.Port())}, nil
    75  	}
    76  
    77  	zoneID, err := strconv.ParseUint(addr.Addr().Zone(), 10, 32)
    78  	if err != nil {
    79  		zoneID = 0
    80  	}
    81  
    82  	return &syscall.SockaddrInet6{Addr: addr.Addr().As16(), Port: int(addr.Port()), ZoneId: uint32(zoneID)}, nil
    83  }
    84  
    85  func udpAddrFamily(net string, lAddr, rAddr netip.AddrPort) int {
    86  	switch net[len(net)-1] {
    87  	case '4':
    88  		return syscall.AF_INET
    89  	case '6':
    90  		return syscall.AF_INET6
    91  	}
    92  
    93  	if lAddr.Addr().Is4() && rAddr.Addr().Is4() {
    94  		return syscall.AF_INET
    95  	}
    96  	return syscall.AF_INET6
    97  }
    98  
    99  func getOrigDst(oob []byte) (netip.AddrPort, error) {
   100  	// oob contains socket control messages which we need to parse.
   101  	scms, err := unix.ParseSocketControlMessage(oob)
   102  	if err != nil {
   103  		return netip.AddrPort{}, fmt.Errorf("parse control message: %w", err)
   104  	}
   105  
   106  	// retrieve the destination address from the SCM.
   107  	var sa unix.Sockaddr
   108  	for i := range scms {
   109  		sa, err = unix.ParseOrigDstAddr(&scms[i])
   110  		if err == nil {
   111  			break
   112  		}
   113  	}
   114  
   115  	if err != nil {
   116  		return netip.AddrPort{}, fmt.Errorf("retrieve destination: %w", err)
   117  	}
   118  
   119  	// encode the destination address into a cmsg.
   120  	var rAddr netip.AddrPort
   121  	switch v := sa.(type) {
   122  	case *unix.SockaddrInet4:
   123  		rAddr = netip.AddrPortFrom(netip.AddrFrom4(v.Addr), uint16(v.Port))
   124  	case *unix.SockaddrInet6:
   125  		rAddr = netip.AddrPortFrom(netip.AddrFrom16(v.Addr), uint16(v.Port))
   126  	default:
   127  		return netip.AddrPort{}, fmt.Errorf("unsupported address type: %T", v)
   128  	}
   129  
   130  	return rAddr, nil
   131  }
   132  
   133  func getDSCP(oob []byte) (uint8, error) {
   134  	scms, err := unix.ParseSocketControlMessage(oob)
   135  	if err != nil {
   136  		return 0, fmt.Errorf("parse control message: %w", err)
   137  	}
   138  	var dscp uint8
   139  	for i := range scms {
   140  		dscp, err = parseDSCP(&scms[i])
   141  		if err == nil {
   142  			break
   143  		}
   144  	}
   145  
   146  	if err != nil {
   147  		return 0, fmt.Errorf("retrieve DSCP: %w", err)
   148  	}
   149  	return dscp, nil
   150  }
   151  
   152  func parseDSCP(m *unix.SocketControlMessage) (uint8, error) {
   153  	switch {
   154  	case m.Header.Level == unix.SOL_IP && m.Header.Type == unix.IP_TOS:
   155  		dscp := uint8(m.Data[0] >> 2)
   156  		return dscp, nil
   157  
   158  	case m.Header.Level == unix.SOL_IPV6 && m.Header.Type == unix.IPV6_TCLASS:
   159  		dscp := uint8(m.Data[0] >> 2)
   160  		return dscp, nil
   161  
   162  	default:
   163  		return 0, nil
   164  	}
   165  }