github.com/kelleygo/clashcore@v1.0.2/transport/hysteria/conns/faketcp/tcp_linux.go (about)

     1  //go:build linux && !no_fake_tcp
     2  // +build linux,!no_fake_tcp
     3  
     4  package faketcp
     5  
     6  import (
     7  	"crypto/rand"
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  	"syscall"
    17  	"time"
    18  
    19  	"github.com/coreos/go-iptables/iptables"
    20  	"github.com/metacubex/gopacket"
    21  	"github.com/metacubex/gopacket/layers"
    22  
    23  	"github.com/kelleygo/clashcore/component/dialer"
    24  )
    25  
var (
	// errOpNotImplemented is a sentinel for unsupported operations.
	// NOTE(review): nothing in the visible part of this file returns it — confirm it is still needed.
	errOpNotImplemented = errors.New("operation not implemented")
	// errTimeout is returned by ReadFrom and WriteTo when their deadline fires.
	errTimeout          = errors.New("timeout")
	// expire is how long a flow may go without an inbound packet before
	// cleaner removes it from the flow table.
	expire              = time.Minute
)
    31  
// message is one captured payload handed from captureFlow to ReadFrom
// through TCPConn.chMessage.
type message struct {
	bts  []byte   // private copy of the TCP payload
	addr net.Addr // sender: source IP of the packet plus its TCP source port
}
    37  
// tcpFlow holds the per-peer state of one connection pair. Entries live in
// TCPConn.flowTable, keyed by the peer address string, and are only touched
// while TCPConn.flowsLock is held (see lockflow).
type tcpFlow struct {
	conn         *net.TCPConn               // the related system TCP connection of this flow; nil until one is recorded
	handle       *net.IPConn                // the raw IP handle used to send packets; set when a packet is captured for this flow
	seq          uint32                     // TCP sequence number for the next outgoing segment
	ack          uint32                     // TCP acknowledgement number for the next outgoing segment
	networkLayer gopacket.SerializableLayer // network layer header for tx
	ts           time.Time                  // time the last packet arrived; compared against expire by cleaner
	buf          gopacket.SerializeBuffer   // reusable serialization buffer for writes
	tcpHeader    layers.TCP                 // TCP header reused and refilled by every WriteTo
}
    49  
// TCPConn defines a TCP-packet oriented connection.
//
// It is created by Dial (client: tcpconn is set) or Listen (server: listener
// is set). Payloads are exchanged as hand-built TCP segments over raw IP
// handles, while the real kernel sockets are only kept open and drained.
type TCPConn struct {
	die     chan struct{} // closed by Close to stop goroutines and unblock ReadFrom/WriteTo
	dieOnce sync.Once     // guarantees the shutdown sequence in Close runs once

	// the main golang sockets
	tcpconn  *net.TCPConn     // from net.Dial
	listener *net.TCPListener // from net.Listen

	// raw IP handles used for capturing and sending packets
	handles []*net.IPConn

	// packets captured from all related NICs will be delivered to this channel
	chMessage chan message

	// all TCP flows, keyed by the peer address string; guarded by flowsLock
	flowTable map[string]*tcpFlow
	flowsLock sync.Mutex

	// IPv4 iptables handle and the rule this connection appended;
	// both stay unset when no rule was added
	iptables *iptables.IPTables
	iprule   []string

	// IPv6 counterpart of iptables/iprule
	ip6tables *iptables.IPTables
	ip6rule   []string

	// deadlines, stored as time.Time values
	readDeadline  atomic.Value
	writeDeadline atomic.Value

	// serialization options used when building outgoing segments
	opts gopacket.SerializeOptions
}
    83  
    84  // lockflow locks the flow table and apply function `f` to the entry, and create one if not exist
    85  func (conn *TCPConn) lockflow(addr net.Addr, f func(e *tcpFlow)) {
    86  	key := addr.String()
    87  	conn.flowsLock.Lock()
    88  	e := conn.flowTable[key]
    89  	if e == nil { // entry first visit
    90  		e = new(tcpFlow)
    91  		e.ts = time.Now()
    92  		e.buf = gopacket.NewSerializeBuffer()
    93  	}
    94  	f(e)
    95  	conn.flowTable[key] = e
    96  	conn.flowsLock.Unlock()
    97  }
    98  
    99  // clean expired flows
   100  func (conn *TCPConn) cleaner() {
   101  	ticker := time.NewTicker(time.Minute)
   102  	select {
   103  	case <-conn.die:
   104  		return
   105  	case <-ticker.C:
   106  		conn.flowsLock.Lock()
   107  		for k, v := range conn.flowTable {
   108  			if time.Now().Sub(v.ts) > expire {
   109  				if v.conn != nil {
   110  					setTTL(v.conn, 64)
   111  					v.conn.Close()
   112  				}
   113  				delete(conn.flowTable, k)
   114  			}
   115  		}
   116  		conn.flowsLock.Unlock()
   117  	}
   118  }
   119  
   120  // captureFlow capture every inbound packets based on rules of BPF
   121  func (conn *TCPConn) captureFlow(handle *net.IPConn, port int) {
   122  	buf := make([]byte, 2048)
   123  	opt := gopacket.DecodeOptions{NoCopy: true, Lazy: true}
   124  	for {
   125  		n, addr, err := handle.ReadFromIP(buf)
   126  		if err != nil {
   127  			return
   128  		}
   129  
   130  		// try decoding TCP frame from buf[:n]
   131  		packet := gopacket.NewPacket(buf[:n], layers.LayerTypeTCP, opt)
   132  		transport := packet.TransportLayer()
   133  		tcp, ok := transport.(*layers.TCP)
   134  		if !ok {
   135  			continue
   136  		}
   137  
   138  		// port filtering
   139  		if int(tcp.DstPort) != port {
   140  			continue
   141  		}
   142  
   143  		// address building
   144  		var src net.TCPAddr
   145  		src.IP = addr.IP
   146  		src.Port = int(tcp.SrcPort)
   147  
   148  		var orphan bool
   149  		// flow maintaince
   150  		conn.lockflow(&src, func(e *tcpFlow) {
   151  			if e.conn == nil { // make sure it's related to net.TCPConn
   152  				orphan = true // mark as orphan if it's not related net.TCPConn
   153  			}
   154  
   155  			// to keep track of TCP header related to this source
   156  			e.ts = time.Now()
   157  			if tcp.ACK {
   158  				e.seq = tcp.Ack
   159  			}
   160  			if tcp.SYN {
   161  				e.ack = tcp.Seq + 1
   162  			}
   163  			if tcp.PSH {
   164  				if e.ack == tcp.Seq {
   165  					e.ack = tcp.Seq + uint32(len(tcp.Payload))
   166  				}
   167  			}
   168  			e.handle = handle
   169  		})
   170  
   171  		// push data if it's not orphan
   172  		if !orphan && tcp.PSH {
   173  			payload := make([]byte, len(tcp.Payload))
   174  			copy(payload, tcp.Payload)
   175  			select {
   176  			case conn.chMessage <- message{payload, &src}:
   177  			case <-conn.die:
   178  				return
   179  			}
   180  		}
   181  	}
   182  }
   183  
   184  // ReadFrom implements the PacketConn ReadFrom method.
   185  func (conn *TCPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   186  	var timer *time.Timer
   187  	var deadline <-chan time.Time
   188  	if d, ok := conn.readDeadline.Load().(time.Time); ok && !d.IsZero() {
   189  		timer = time.NewTimer(time.Until(d))
   190  		defer timer.Stop()
   191  		deadline = timer.C
   192  	}
   193  
   194  	select {
   195  	case <-deadline:
   196  		return 0, nil, errTimeout
   197  	case <-conn.die:
   198  		return 0, nil, io.EOF
   199  	case packet := <-conn.chMessage:
   200  		n = copy(p, packet.bts)
   201  		return n, packet.addr, nil
   202  	}
   203  }
   204  
   205  // WriteTo implements the PacketConn WriteTo method.
   206  func (conn *TCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   207  	var deadline <-chan time.Time
   208  	if d, ok := conn.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
   209  		timer := time.NewTimer(time.Until(d))
   210  		defer timer.Stop()
   211  		deadline = timer.C
   212  	}
   213  
   214  	select {
   215  	case <-deadline:
   216  		return 0, errTimeout
   217  	case <-conn.die:
   218  		return 0, io.EOF
   219  	default:
   220  		raddr, err := net.ResolveTCPAddr("tcp", addr.String())
   221  		if err != nil {
   222  			return 0, err
   223  		}
   224  
   225  		var lport int
   226  		if conn.tcpconn != nil {
   227  			lport = conn.tcpconn.LocalAddr().(*net.TCPAddr).Port
   228  		} else {
   229  			lport = conn.listener.Addr().(*net.TCPAddr).Port
   230  		}
   231  
   232  		conn.lockflow(addr, func(e *tcpFlow) {
   233  			// if the flow doesn't have handle , assume this packet has lost, without notification
   234  			if e.handle == nil {
   235  				n = len(p)
   236  				return
   237  			}
   238  
   239  			// build tcp header with local and remote port
   240  			e.tcpHeader.SrcPort = layers.TCPPort(lport)
   241  			e.tcpHeader.DstPort = layers.TCPPort(raddr.Port)
   242  			binary.Read(rand.Reader, binary.LittleEndian, &e.tcpHeader.Window)
   243  			e.tcpHeader.Window |= 0x8000 // make sure it's larger than 32768
   244  			e.tcpHeader.Ack = e.ack
   245  			e.tcpHeader.Seq = e.seq
   246  			e.tcpHeader.PSH = true
   247  			e.tcpHeader.ACK = true
   248  
   249  			// build IP header with src & dst ip for TCP checksum
   250  			if raddr.IP.To4() != nil {
   251  				ip := &layers.IPv4{
   252  					Protocol: layers.IPProtocolTCP,
   253  					SrcIP:    e.handle.LocalAddr().(*net.IPAddr).IP.To4(),
   254  					DstIP:    raddr.IP.To4(),
   255  				}
   256  				e.tcpHeader.SetNetworkLayerForChecksum(ip)
   257  			} else {
   258  				ip := &layers.IPv6{
   259  					NextHeader: layers.IPProtocolTCP,
   260  					SrcIP:      e.handle.LocalAddr().(*net.IPAddr).IP.To16(),
   261  					DstIP:      raddr.IP.To16(),
   262  				}
   263  				e.tcpHeader.SetNetworkLayerForChecksum(ip)
   264  			}
   265  
   266  			e.buf.Clear()
   267  			gopacket.SerializeLayers(e.buf, conn.opts, &e.tcpHeader, gopacket.Payload(p))
   268  			if conn.tcpconn != nil {
   269  				_, err = e.handle.Write(e.buf.Bytes())
   270  			} else {
   271  				_, err = e.handle.WriteToIP(e.buf.Bytes(), &net.IPAddr{IP: raddr.IP})
   272  			}
   273  			// increase seq in flow
   274  			e.seq += uint32(len(p))
   275  			n = len(p)
   276  		})
   277  	}
   278  	return
   279  }
   280  
   281  // Close closes the connection.
   282  func (conn *TCPConn) Close() error {
   283  	var err error
   284  	conn.dieOnce.Do(func() {
   285  		// signal closing
   286  		close(conn.die)
   287  
   288  		// close all established tcp connections
   289  		if conn.tcpconn != nil { // client
   290  			setTTL(conn.tcpconn, 64)
   291  			err = conn.tcpconn.Close()
   292  		} else if conn.listener != nil {
   293  			err = conn.listener.Close() // server
   294  			conn.flowsLock.Lock()
   295  			for k, v := range conn.flowTable {
   296  				if v.conn != nil {
   297  					setTTL(v.conn, 64)
   298  					v.conn.Close()
   299  				}
   300  				delete(conn.flowTable, k)
   301  			}
   302  			conn.flowsLock.Unlock()
   303  		}
   304  
   305  		// close handles
   306  		for k := range conn.handles {
   307  			conn.handles[k].Close()
   308  		}
   309  
   310  		// delete iptable
   311  		if conn.iptables != nil {
   312  			conn.iptables.Delete("filter", "OUTPUT", conn.iprule...)
   313  		}
   314  		if conn.ip6tables != nil {
   315  			conn.ip6tables.Delete("filter", "OUTPUT", conn.ip6rule...)
   316  		}
   317  	})
   318  	return err
   319  }
   320  
   321  // LocalAddr returns the local network address.
   322  func (conn *TCPConn) LocalAddr() net.Addr {
   323  	if conn.tcpconn != nil {
   324  		return conn.tcpconn.LocalAddr()
   325  	} else if conn.listener != nil {
   326  		return conn.listener.Addr()
   327  	}
   328  	return nil
   329  }
   330  
   331  // SetDeadline implements the Conn SetDeadline method.
   332  func (conn *TCPConn) SetDeadline(t time.Time) error {
   333  	if err := conn.SetReadDeadline(t); err != nil {
   334  		return err
   335  	}
   336  	if err := conn.SetWriteDeadline(t); err != nil {
   337  		return err
   338  	}
   339  	return nil
   340  }
   341  
// SetReadDeadline implements the Conn SetReadDeadline method.
// The deadline is stored atomically and picked up by ReadFrom calls that
// start afterwards; a zero t means no deadline. It always returns nil.
func (conn *TCPConn) SetReadDeadline(t time.Time) error {
	conn.readDeadline.Store(t)
	return nil
}
   347  
// SetWriteDeadline implements the Conn SetWriteDeadline method.
// The deadline is stored atomically and picked up by WriteTo calls that
// start afterwards; a zero t means no deadline. It always returns nil.
func (conn *TCPConn) SetWriteDeadline(t time.Time) error {
	conn.writeDeadline.Store(t)
	return nil
}
   353  
   354  // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
   355  func (conn *TCPConn) SetDSCP(dscp int) error {
   356  	for k := range conn.handles {
   357  		if err := setDSCP(conn.handles[k], dscp); err != nil {
   358  			return err
   359  		}
   360  	}
   361  	return nil
   362  }
   363  
   364  // SetReadBuffer sets the size of the operating system's receive buffer associated with the connection.
   365  func (conn *TCPConn) SetReadBuffer(bytes int) error {
   366  	var err error
   367  	for k := range conn.handles {
   368  		if err := conn.handles[k].SetReadBuffer(bytes); err != nil {
   369  			return err
   370  		}
   371  	}
   372  	return err
   373  }
   374  
   375  // SetWriteBuffer sets the size of the operating system's transmit buffer associated with the connection.
   376  func (conn *TCPConn) SetWriteBuffer(bytes int) error {
   377  	var err error
   378  	for k := range conn.handles {
   379  		if err := conn.handles[k].SetWriteBuffer(bytes); err != nil {
   380  			return err
   381  		}
   382  	}
   383  	return err
   384  }
   385  
   386  func (conn *TCPConn) SyscallConn() (syscall.RawConn, error) {
   387  	if len(conn.handles) == 0 {
   388  		return nil, errors.New("no handles")
   389  		// How is it possible?
   390  	}
   391  	return conn.handles[0].SyscallConn()
   392  }
   393  
   394  // Dial connects to the remote TCP port,
   395  // and returns a single packet-oriented connection
   396  func Dial(network, address string) (*TCPConn, error) {
   397  	// init gopacket.layers
   398  	layers.Init()
   399  	// remote address resolve
   400  	raddr, err := net.ResolveTCPAddr(network, address)
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	var lTcpAddr *net.TCPAddr
   406  	var lIpAddr *net.IPAddr
   407  	if ifaceName := dialer.DefaultInterface.Load(); len(ifaceName) > 0 {
   408  		rAddrPort := raddr.AddrPort()
   409  		addr, err := dialer.LookupLocalAddrFromIfaceName(ifaceName, network, rAddrPort.Addr(), int(rAddrPort.Port()))
   410  		if err != nil {
   411  			return nil, err
   412  		}
   413  		lTcpAddr = addr.(*net.TCPAddr)
   414  		lIpAddr = &net.IPAddr{IP: lTcpAddr.IP}
   415  	}
   416  
   417  	// AF_INET
   418  	handle, err := net.DialIP("ip:tcp", lIpAddr, &net.IPAddr{IP: raddr.IP})
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  
   423  	// create an established tcp connection
   424  	// will hack this tcp connection for packet transmission
   425  	tcpconn, err := net.DialTCP(network, lTcpAddr, raddr)
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	// fields
   431  	conn := new(TCPConn)
   432  	conn.die = make(chan struct{})
   433  	conn.flowTable = make(map[string]*tcpFlow)
   434  	conn.tcpconn = tcpconn
   435  	conn.chMessage = make(chan message)
   436  	conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
   437  	conn.handles = append(conn.handles, handle)
   438  	conn.opts = gopacket.SerializeOptions{
   439  		FixLengths:       true,
   440  		ComputeChecksums: true,
   441  	}
   442  	go conn.captureFlow(handle, tcpconn.LocalAddr().(*net.TCPAddr).Port)
   443  	go conn.cleaner()
   444  
   445  	// iptables
   446  	err = setTTL(tcpconn, 1)
   447  	if err != nil {
   448  		return nil, err
   449  	}
   450  
   451  	if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
   452  		rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
   453  		if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
   454  			if !exists {
   455  				if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
   456  					conn.iprule = rule
   457  					conn.iptables = ipt
   458  				}
   459  			}
   460  		}
   461  	}
   462  	if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
   463  		rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"}
   464  		if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
   465  			if !exists {
   466  				if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
   467  					conn.ip6rule = rule
   468  					conn.ip6tables = ipt
   469  				}
   470  			}
   471  		}
   472  	}
   473  
   474  	// discard everything
   475  	go io.Copy(ioutil.Discard, tcpconn)
   476  
   477  	return conn, nil
   478  }
   479  
   480  // Listen acts like net.ListenTCP,
   481  // and returns a single packet-oriented connection
   482  func Listen(network, address string) (*TCPConn, error) {
   483  	// init gopacket.layers
   484  	layers.Init()
   485  	// fields
   486  	conn := new(TCPConn)
   487  	conn.flowTable = make(map[string]*tcpFlow)
   488  	conn.die = make(chan struct{})
   489  	conn.chMessage = make(chan message)
   490  	conn.opts = gopacket.SerializeOptions{
   491  		FixLengths:       true,
   492  		ComputeChecksums: true,
   493  	}
   494  
   495  	// resolve address
   496  	laddr, err := net.ResolveTCPAddr(network, address)
   497  	if err != nil {
   498  		return nil, err
   499  	}
   500  
   501  	// AF_INET
   502  	ifaces, err := net.Interfaces()
   503  	if err != nil {
   504  		return nil, err
   505  	}
   506  
   507  	if laddr.IP == nil || laddr.IP.IsUnspecified() { // if address is not specified, capture on all ifaces
   508  		var lasterr error
   509  		for _, iface := range ifaces {
   510  			if addrs, err := iface.Addrs(); err == nil {
   511  				for _, addr := range addrs {
   512  					if ipaddr, ok := addr.(*net.IPNet); ok {
   513  						if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: ipaddr.IP}); err == nil {
   514  							conn.handles = append(conn.handles, handle)
   515  							go conn.captureFlow(handle, laddr.Port)
   516  						} else {
   517  							lasterr = err
   518  						}
   519  					}
   520  				}
   521  			}
   522  		}
   523  		if len(conn.handles) == 0 {
   524  			return nil, lasterr
   525  		}
   526  	} else {
   527  		if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: laddr.IP}); err == nil {
   528  			conn.handles = append(conn.handles, handle)
   529  			go conn.captureFlow(handle, laddr.Port)
   530  		} else {
   531  			return nil, err
   532  		}
   533  	}
   534  
   535  	// start listening
   536  	l, err := net.ListenTCP(network, laddr)
   537  	if err != nil {
   538  		return nil, err
   539  	}
   540  
   541  	conn.listener = l
   542  
   543  	// start cleaner
   544  	go conn.cleaner()
   545  
   546  	// iptables drop packets marked with TTL = 1
   547  	// TODO: what if iptables is not available, the next hop will send back ICMP Time Exceeded,
   548  	// is this still an acceptable behavior?
   549  	if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil {
   550  		rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
   551  		if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
   552  			if !exists {
   553  				if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
   554  					conn.iprule = rule
   555  					conn.iptables = ipt
   556  				}
   557  			}
   558  		}
   559  	}
   560  	if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil {
   561  		rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"}
   562  		if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil {
   563  			if !exists {
   564  				if err = ipt.Append("filter", "OUTPUT", rule...); err == nil {
   565  					conn.ip6rule = rule
   566  					conn.ip6tables = ipt
   567  				}
   568  			}
   569  		}
   570  	}
   571  
   572  	// discard everything in original connection
   573  	go func() {
   574  		for {
   575  			tcpconn, err := l.AcceptTCP()
   576  			if err != nil {
   577  				return
   578  			}
   579  
   580  			// if we cannot set TTL = 1, the only thing reasonable is panic
   581  			if err := setTTL(tcpconn, 1); err != nil {
   582  				panic(err)
   583  			}
   584  
   585  			// record net.Conn
   586  			conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn })
   587  
   588  			// discard everything
   589  			go io.Copy(ioutil.Discard, tcpconn)
   590  		}
   591  	}()
   592  
   593  	return conn, nil
   594  }
   595  
   596  // setTTL sets the Time-To-Live field on a given connection
   597  func setTTL(c *net.TCPConn, ttl int) error {
   598  	raw, err := c.SyscallConn()
   599  	if err != nil {
   600  		return err
   601  	}
   602  	addr := c.LocalAddr().(*net.TCPAddr)
   603  
   604  	if addr.IP.To4() == nil {
   605  		raw.Control(func(fd uintptr) {
   606  			err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, ttl)
   607  		})
   608  	} else {
   609  		raw.Control(func(fd uintptr) {
   610  			err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, ttl)
   611  		})
   612  	}
   613  	return err
   614  }
   615  
   616  // setDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
   617  func setDSCP(c *net.IPConn, dscp int) error {
   618  	raw, err := c.SyscallConn()
   619  	if err != nil {
   620  		return err
   621  	}
   622  	addr := c.LocalAddr().(*net.IPAddr)
   623  
   624  	if addr.IP.To4() == nil {
   625  		raw.Control(func(fd uintptr) {
   626  			err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, dscp)
   627  		})
   628  	} else {
   629  		raw.Control(func(fd uintptr) {
   630  			err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, dscp<<2)
   631  		})
   632  	}
   633  	return err
   634  }