github.com/ipfans/trojan-go@v0.11.0/tunnel/tproxy/tcp.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package tproxy
     5  
     6  import (
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"syscall"
    11  	"unsafe"
    12  )
    13  
// Listener describes a TCP Listener
// with the Linux IP_TRANSPARENT option defined
// on the listening socket.
//
// Values are produced by ListenTCP, which sets the option on the socket
// before wrapping it.
type Listener struct {
	// base is the wrapped listener that Accept, Addr and Close operate on.
	// ListenTCP always stores a *net.TCPListener here.
	base net.Listener
}
    20  
    21  // Accept waits for and returns
    22  // the next connection to the listener.
    23  //
    24  // This command wraps the AcceptTProxy
    25  // method of the Listener
    26  func (listener *Listener) Accept() (net.Conn, error) {
    27  	tcpConn, err := listener.base.(*net.TCPListener).AcceptTCP()
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	return tcpConn, nil
    33  }
    34  
    35  // Addr returns the network address
    36  // the listener is accepting connections
    37  // from
    38  func (listener *Listener) Addr() net.Addr {
    39  	return listener.base.Addr()
    40  }
    41  
    42  // Close will close the listener from accepting
    43  // any more connections. Any blocked connections
    44  // will unblock and close
    45  func (listener *Listener) Close() error {
    46  	return listener.base.Close()
    47  }
    48  
    49  // ListenTCP will construct a new TCP listener
    50  // socket with the Linux IP_TRANSPARENT option
    51  // set on the underlying socket
    52  func ListenTCP(network string, laddr *net.TCPAddr) (net.Listener, error) {
    53  	listener, err := net.ListenTCP(network, laddr)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	fileDescriptorSource, err := listener.File()
    59  	if err != nil {
    60  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("get file descriptor: %s", err)}
    61  	}
    62  	defer fileDescriptorSource.Close()
    63  
    64  	if err = syscall.SetsockoptInt(int(fileDescriptorSource.Fd()), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
    65  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)}
    66  	}
    67  
    68  	return &Listener{listener}, nil
    69  }
    70  
// Socket option names that getOriginalTCPDest passes to getsockopt to read
// the original destination of a connection. The identifiers keep the
// upper-case, underscore-separated spelling of the Linux netfilter options
// they stand for instead of Go MixedCaps.
const (
	IP6T_SO_ORIGINAL_DST = 80 // queried at the syscall.IPPROTO_IPV6 level
	SO_ORIGINAL_DST      = 80 // queried at the syscall.IPPROTO_IP level
)
    75  
    76  // getOriginalTCPDest retrieves the original destination address from
    77  // NATed connection.  Currently, only Linux iptables using DNAT/REDIRECT
    78  // is supported.  For other operating systems, this will just return
    79  // conn.LocalAddr().
    80  //
    81  // Note that this function only works when nf_conntrack_ipv4 and/or
    82  // nf_conntrack_ipv6 is loaded in the kernel.
    83  func getOriginalTCPDest(conn *net.TCPConn) (*net.TCPAddr, error) {
    84  	f, err := conn.File()
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	defer f.Close()
    89  
    90  	fd := int(f.Fd())
    91  	// revert to non-blocking mode.
    92  	// see http://stackoverflow.com/a/28968431/1493661
    93  	if err = syscall.SetNonblock(fd, true); err != nil {
    94  		return nil, os.NewSyscallError("setnonblock", err)
    95  	}
    96  
    97  	v6 := conn.LocalAddr().(*net.TCPAddr).IP.To4() == nil
    98  	if v6 {
    99  		var addr syscall.RawSockaddrInet6
   100  		var len uint32
   101  		len = uint32(unsafe.Sizeof(addr))
   102  		err = getsockopt(fd, syscall.IPPROTO_IPV6, IP6T_SO_ORIGINAL_DST,
   103  			unsafe.Pointer(&addr), &len)
   104  		if err != nil {
   105  			return nil, os.NewSyscallError("getsockopt", err)
   106  		}
   107  		ip := make([]byte, 16)
   108  		for i, b := range addr.Addr {
   109  			ip[i] = b
   110  		}
   111  		pb := *(*[2]byte)(unsafe.Pointer(&addr.Port))
   112  		return &net.TCPAddr{
   113  			IP:   ip,
   114  			Port: int(pb[0])*256 + int(pb[1]),
   115  		}, nil
   116  	}
   117  
   118  	// IPv4
   119  	var addr syscall.RawSockaddrInet4
   120  	var len uint32
   121  	len = uint32(unsafe.Sizeof(addr))
   122  	err = getsockopt(fd, syscall.IPPROTO_IP, SO_ORIGINAL_DST,
   123  		unsafe.Pointer(&addr), &len)
   124  	if err != nil {
   125  		return nil, os.NewSyscallError("getsockopt", err)
   126  	}
   127  	ip := make([]byte, 4)
   128  	for i, b := range addr.Addr {
   129  		ip[i] = b
   130  	}
   131  	pb := *(*[2]byte)(unsafe.Pointer(&addr.Port))
   132  	return &net.TCPAddr{
   133  		IP:   ip,
   134  		Port: int(pb[0])*256 + int(pb[1]),
   135  	}, nil
   136  }