github.com/ipfans/trojan-go@v0.11.0/tunnel/tproxy/udp.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package tproxy
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"fmt"
    10  	"net"
    11  	"os"
    12  	"strconv"
    13  	"syscall"
    14  	"unsafe"
    15  )
    16  
    17  // ListenUDP will construct a new UDP listener
    18  // socket with the Linux IP_TRANSPARENT option
    19  // set on the underlying socket
    20  func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
    21  	listener, err := net.ListenUDP(network, laddr)
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	fileDescriptorSource, err := listener.File()
    27  	if err != nil {
    28  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("get file descriptor: %s", err)}
    29  	}
    30  	defer fileDescriptorSource.Close()
    31  
    32  	fileDescriptor := int(fileDescriptorSource.Fd())
    33  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
    34  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)}
    35  	}
    36  
    37  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1); err != nil {
    38  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("set socket option: IP_RECVORIGDSTADDR: %s", err)}
    39  	}
    40  
    41  	return listener, nil
    42  }
    43  
    44  // ReadFromUDP reads a UDP packet from c, copying the payload into b.
    45  // It returns the number of bytes copied into b and the return address
    46  // that was on the packet.
    47  //
    48  // Out-of-band data is also read in so that the original destination
    49  // address can be identified and parsed.
    50  func ReadFromUDP(conn *net.UDPConn, b []byte) (int, *net.UDPAddr, *net.UDPAddr, error) {
    51  	oob := make([]byte, 1024)
    52  	n, oobn, _, addr, err := conn.ReadMsgUDP(b, oob)
    53  	if err != nil {
    54  		return 0, nil, nil, err
    55  	}
    56  
    57  	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
    58  	if err != nil {
    59  		return 0, nil, nil, fmt.Errorf("parsing socket control message: %s", err)
    60  	}
    61  
    62  	var originalDst *net.UDPAddr
    63  	for _, msg := range msgs {
    64  		if (msg.Header.Level == syscall.SOL_IP || msg.Header.Level == syscall.SOL_IPV6) && msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
    65  			originalDstRaw := &syscall.RawSockaddrInet4{}
    66  			if err = binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originalDstRaw); err != nil {
    67  				return 0, nil, nil, fmt.Errorf("reading original destination address: %s", err)
    68  			}
    69  
    70  			switch originalDstRaw.Family {
    71  			case syscall.AF_INET:
    72  				pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw))
    73  				p := (*[2]byte)(unsafe.Pointer(&pp.Port))
    74  				originalDst = &net.UDPAddr{
    75  					IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
    76  					Port: int(p[0])<<8 + int(p[1]),
    77  				}
    78  
    79  			case syscall.AF_INET6:
    80  				pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw))
    81  				p := (*[2]byte)(unsafe.Pointer(&pp.Port))
    82  				originalDst = &net.UDPAddr{
    83  					IP:   net.IP(pp.Addr[:]),
    84  					Port: int(p[0])<<8 + int(p[1]),
    85  					Zone: strconv.Itoa(int(pp.Scope_id)),
    86  				}
    87  
    88  			default:
    89  				return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family")
    90  			}
    91  		}
    92  	}
    93  
    94  	if originalDst == nil {
    95  		return 0, nil, nil, fmt.Errorf("unable to obtain original destination: %s", err)
    96  	}
    97  
    98  	return n, addr, originalDst, nil
    99  }
   100  
   101  // DialUDP connects to the remote address raddr on the network net,
   102  // which must be "udp", "udp4", or "udp6".  If laddr is not nil, it is
   103  // used as the local address for the connection.
   104  func DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
   105  	remoteSocketAddress, err := udpAddrToSocketAddr(raddr)
   106  	if err != nil {
   107  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %s", err)}
   108  	}
   109  
   110  	localSocketAddress, err := udpAddrToSocketAddr(laddr)
   111  	if err != nil {
   112  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %s", err)}
   113  	}
   114  
   115  	fileDescriptor, err := syscall.Socket(udpAddrFamily(network, laddr, raddr), syscall.SOCK_DGRAM, 0)
   116  	if err != nil {
   117  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %s", err)}
   118  	}
   119  
   120  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
   121  		syscall.Close(fileDescriptor)
   122  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %s", err)}
   123  	}
   124  
   125  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
   126  		syscall.Close(fileDescriptor)
   127  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)}
   128  	}
   129  
   130  	if err = syscall.Bind(fileDescriptor, localSocketAddress); err != nil {
   131  		syscall.Close(fileDescriptor)
   132  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind: %s", err)}
   133  	}
   134  
   135  	if err = syscall.Connect(fileDescriptor, remoteSocketAddress); err != nil {
   136  		syscall.Close(fileDescriptor)
   137  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %s", err)}
   138  	}
   139  
   140  	fdFile := os.NewFile(uintptr(fileDescriptor), fmt.Sprintf("net-udp-dial-%s", raddr.String()))
   141  	defer fdFile.Close()
   142  
   143  	remoteConn, err := net.FileConn(fdFile)
   144  	if err != nil {
   145  		syscall.Close(fileDescriptor)
   146  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %s", err)}
   147  	}
   148  
   149  	return remoteConn.(*net.UDPConn), nil
   150  }
   151  
   152  // udpAddToSockerAddr will convert a UDPAddr
   153  // into a Sockaddr that may be used when
   154  // connecting and binding sockets
   155  func udpAddrToSocketAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) {
   156  	switch {
   157  	case addr.IP.To4() != nil:
   158  		ip := [4]byte{}
   159  		copy(ip[:], addr.IP.To4())
   160  
   161  		return &syscall.SockaddrInet4{Addr: ip, Port: addr.Port}, nil
   162  
   163  	default:
   164  		ip := [16]byte{}
   165  		copy(ip[:], addr.IP.To16())
   166  
   167  		zoneID, err := strconv.ParseUint(addr.Zone, 10, 32)
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  
   172  		return &syscall.SockaddrInet6{Addr: ip, Port: addr.Port, ZoneId: uint32(zoneID)}, nil
   173  	}
   174  }
   175  
   176  // udpAddrFamily will attempt to work
   177  // out the address family based on the
   178  // network and UDP addresses
   179  func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int {
   180  	switch net[len(net)-1] {
   181  	case '4':
   182  		return syscall.AF_INET
   183  	case '6':
   184  		return syscall.AF_INET6
   185  	}
   186  
   187  	if (laddr == nil || laddr.IP.To4() != nil) &&
   188  		(raddr == nil || laddr.IP.To4() != nil) {
   189  		return syscall.AF_INET
   190  	}
   191  	return syscall.AF_INET6
   192  }