github.com/EagleQL/Xray-core@v1.4.3/app/proxyman/inbound/worker.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/app/proxyman"
    10  	"github.com/xtls/xray-core/common"
    11  	"github.com/xtls/xray-core/common/buf"
    12  	"github.com/xtls/xray-core/common/net"
    13  	"github.com/xtls/xray-core/common/serial"
    14  	"github.com/xtls/xray-core/common/session"
    15  	"github.com/xtls/xray-core/common/signal/done"
    16  	"github.com/xtls/xray-core/common/task"
    17  	"github.com/xtls/xray-core/features/routing"
    18  	"github.com/xtls/xray-core/features/stats"
    19  	"github.com/xtls/xray-core/proxy"
    20  	"github.com/xtls/xray-core/transport/internet"
    21  	"github.com/xtls/xray-core/transport/internet/tcp"
    22  	"github.com/xtls/xray-core/transport/internet/udp"
    23  	"github.com/xtls/xray-core/transport/pipe"
    24  )
    25  
    26  type worker interface {
    27  	Start() error
    28  	Close() error
    29  	Port() net.Port
    30  	Proxy() proxy.Inbound
    31  }
    32  
    33  type tcpWorker struct {
    34  	address         net.Address
    35  	port            net.Port
    36  	proxy           proxy.Inbound
    37  	stream          *internet.MemoryStreamConfig
    38  	recvOrigDest    bool
    39  	tag             string
    40  	dispatcher      routing.Dispatcher
    41  	sniffingConfig  *proxyman.SniffingConfig
    42  	uplinkCounter   stats.Counter
    43  	downlinkCounter stats.Counter
    44  
    45  	hub internet.Listener
    46  
    47  	ctx context.Context
    48  }
    49  
    50  func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
    51  	if s == nil || s.SocketSettings == nil {
    52  		return internet.SocketConfig_Off
    53  	}
    54  	return s.SocketSettings.Tproxy
    55  }
    56  
    57  func (w *tcpWorker) callback(conn internet.Connection) {
    58  	ctx, cancel := context.WithCancel(w.ctx)
    59  	sid := session.NewID()
    60  	ctx = session.ContextWithID(ctx, sid)
    61  
    62  	if w.recvOrigDest {
    63  		var dest net.Destination
    64  		switch getTProxyType(w.stream) {
    65  		case internet.SocketConfig_Redirect:
    66  			d, err := tcp.GetOriginalDestination(conn)
    67  			if err != nil {
    68  				newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
    69  			} else {
    70  				dest = d
    71  			}
    72  		case internet.SocketConfig_TProxy:
    73  			dest = net.DestinationFromAddr(conn.LocalAddr())
    74  		}
    75  		if dest.IsValid() {
    76  			ctx = session.ContextWithOutbound(ctx, &session.Outbound{
    77  				Target: dest,
    78  			})
    79  		}
    80  	}
    81  
    82  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
    83  		conn = &internet.StatCouterConnection{
    84  			Connection:   conn,
    85  			ReadCounter:  w.uplinkCounter,
    86  			WriteCounter: w.downlinkCounter,
    87  		}
    88  	}
    89  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
    90  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
    91  		Gateway: net.TCPDestination(w.address, w.port),
    92  		Tag:     w.tag,
    93  		Conn:    conn,
    94  	})
    95  
    96  	content := new(session.Content)
    97  	if w.sniffingConfig != nil {
    98  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
    99  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   100  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   101  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   102  	}
   103  	ctx = session.ContextWithContent(ctx, content)
   104  
   105  	if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
   106  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   107  	}
   108  	cancel()
   109  	if err := conn.Close(); err != nil {
   110  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   111  	}
   112  }
   113  
   114  func (w *tcpWorker) Proxy() proxy.Inbound {
   115  	return w.proxy
   116  }
   117  
   118  func (w *tcpWorker) Start() error {
   119  	ctx := context.Background()
   120  	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
   121  		go w.callback(conn)
   122  	})
   123  	if err != nil {
   124  		return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
   125  	}
   126  	w.hub = hub
   127  	return nil
   128  }
   129  
   130  func (w *tcpWorker) Close() error {
   131  	var errors []interface{}
   132  	if w.hub != nil {
   133  		if err := common.Close(w.hub); err != nil {
   134  			errors = append(errors, err)
   135  		}
   136  		if err := common.Close(w.proxy); err != nil {
   137  			errors = append(errors, err)
   138  		}
   139  	}
   140  	if len(errors) > 0 {
   141  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (w *tcpWorker) Port() net.Port {
   148  	return w.port
   149  }
   150  
   151  type udpConn struct {
   152  	lastActivityTime int64 // in seconds
   153  	reader           buf.Reader
   154  	writer           buf.Writer
   155  	output           func([]byte) (int, error)
   156  	remote           net.Addr
   157  	local            net.Addr
   158  	done             *done.Instance
   159  	uplink           stats.Counter
   160  	downlink         stats.Counter
   161  }
   162  
   163  func (c *udpConn) updateActivity() {
   164  	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
   165  }
   166  
   167  // ReadMultiBuffer implements buf.Reader
   168  func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
   169  	mb, err := c.reader.ReadMultiBuffer()
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	c.updateActivity()
   174  
   175  	if c.uplink != nil {
   176  		c.uplink.Add(int64(mb.Len()))
   177  	}
   178  
   179  	return mb, nil
   180  }
   181  
   182  func (c *udpConn) Read(buf []byte) (int, error) {
   183  	panic("not implemented")
   184  }
   185  
   186  // Write implements io.Writer.
   187  func (c *udpConn) Write(buf []byte) (int, error) {
   188  	n, err := c.output(buf)
   189  	if c.downlink != nil {
   190  		c.downlink.Add(int64(n))
   191  	}
   192  	if err == nil {
   193  		c.updateActivity()
   194  	}
   195  	return n, err
   196  }
   197  
   198  func (c *udpConn) Close() error {
   199  	common.Must(c.done.Close())
   200  	common.Must(common.Close(c.writer))
   201  	return nil
   202  }
   203  
   204  func (c *udpConn) RemoteAddr() net.Addr {
   205  	return c.remote
   206  }
   207  
   208  func (c *udpConn) LocalAddr() net.Addr {
   209  	return c.local
   210  }
   211  
   212  func (*udpConn) SetDeadline(time.Time) error {
   213  	return nil
   214  }
   215  
   216  func (*udpConn) SetReadDeadline(time.Time) error {
   217  	return nil
   218  }
   219  
   220  func (*udpConn) SetWriteDeadline(time.Time) error {
   221  	return nil
   222  }
   223  
   224  type connID struct {
   225  	src  net.Destination
   226  	dest net.Destination
   227  }
   228  
   229  type udpWorker struct {
   230  	sync.RWMutex
   231  
   232  	proxy           proxy.Inbound
   233  	hub             *udp.Hub
   234  	address         net.Address
   235  	port            net.Port
   236  	tag             string
   237  	stream          *internet.MemoryStreamConfig
   238  	dispatcher      routing.Dispatcher
   239  	sniffingConfig  *proxyman.SniffingConfig
   240  	uplinkCounter   stats.Counter
   241  	downlinkCounter stats.Counter
   242  
   243  	checker    *task.Periodic
   244  	activeConn map[connID]*udpConn
   245  
   246  	ctx  context.Context
   247  	cone bool
   248  }
   249  
   250  func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
   251  	w.Lock()
   252  	defer w.Unlock()
   253  
   254  	if conn, found := w.activeConn[id]; found && !conn.done.Done() {
   255  		return conn, true
   256  	}
   257  
   258  	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
   259  	conn := &udpConn{
   260  		reader: pReader,
   261  		writer: pWriter,
   262  		output: func(b []byte) (int, error) {
   263  			return w.hub.WriteTo(b, id.src)
   264  		},
   265  		remote: &net.UDPAddr{
   266  			IP:   id.src.Address.IP(),
   267  			Port: int(id.src.Port),
   268  		},
   269  		local: &net.UDPAddr{
   270  			IP:   w.address.IP(),
   271  			Port: int(w.port),
   272  		},
   273  		done:     done.New(),
   274  		uplink:   w.uplinkCounter,
   275  		downlink: w.downlinkCounter,
   276  	}
   277  	w.activeConn[id] = conn
   278  
   279  	conn.updateActivity()
   280  	return conn, false
   281  }
   282  
   283  func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
   284  	id := connID{
   285  		src: source,
   286  	}
   287  	if originalDest.IsValid() {
   288  		if !w.cone {
   289  			id.dest = originalDest
   290  		}
   291  		b.UDP = &originalDest
   292  	}
   293  	conn, existing := w.getConnection(id)
   294  
   295  	// payload will be discarded in pipe is full.
   296  	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
   297  
   298  	if !existing {
   299  		common.Must(w.checker.Start())
   300  
   301  		go func() {
   302  			ctx := w.ctx
   303  			sid := session.NewID()
   304  			ctx = session.ContextWithID(ctx, sid)
   305  
   306  			if originalDest.IsValid() {
   307  				ctx = session.ContextWithOutbound(ctx, &session.Outbound{
   308  					Target: originalDest,
   309  				})
   310  			}
   311  			ctx = session.ContextWithInbound(ctx, &session.Inbound{
   312  				Source:  source,
   313  				Gateway: net.UDPDestination(w.address, w.port),
   314  				Tag:     w.tag,
   315  			})
   316  			content := new(session.Content)
   317  			if w.sniffingConfig != nil {
   318  				content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   319  				content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   320  				content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   321  			}
   322  			ctx = session.ContextWithContent(ctx, content)
   323  			if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
   324  				newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   325  			}
   326  			conn.Close()
   327  			w.removeConn(id)
   328  		}()
   329  	}
   330  }
   331  
   332  func (w *udpWorker) removeConn(id connID) {
   333  	w.Lock()
   334  	delete(w.activeConn, id)
   335  	w.Unlock()
   336  }
   337  
   338  func (w *udpWorker) handlePackets() {
   339  	receive := w.hub.Receive()
   340  	for payload := range receive {
   341  		w.callback(payload.Payload, payload.Source, payload.Target)
   342  	}
   343  }
   344  
   345  func (w *udpWorker) clean() error {
   346  	nowSec := time.Now().Unix()
   347  	w.Lock()
   348  	defer w.Unlock()
   349  
   350  	if len(w.activeConn) == 0 {
   351  		return newError("no more connections. stopping...")
   352  	}
   353  
   354  	for addr, conn := range w.activeConn {
   355  		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 300 {
   356  			delete(w.activeConn, addr)
   357  			conn.Close()
   358  		}
   359  	}
   360  
   361  	if len(w.activeConn) == 0 {
   362  		w.activeConn = make(map[connID]*udpConn, 16)
   363  	}
   364  
   365  	return nil
   366  }
   367  
   368  func (w *udpWorker) Start() error {
   369  	w.activeConn = make(map[connID]*udpConn, 16)
   370  	ctx := context.Background()
   371  	h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
   372  	if err != nil {
   373  		return err
   374  	}
   375  
   376  	w.cone = w.ctx.Value("cone").(bool)
   377  
   378  	w.checker = &task.Periodic{
   379  		Interval: time.Minute,
   380  		Execute:  w.clean,
   381  	}
   382  
   383  	w.hub = h
   384  	go w.handlePackets()
   385  	return nil
   386  }
   387  
   388  func (w *udpWorker) Close() error {
   389  	w.Lock()
   390  	defer w.Unlock()
   391  
   392  	var errors []interface{}
   393  
   394  	if w.hub != nil {
   395  		if err := w.hub.Close(); err != nil {
   396  			errors = append(errors, err)
   397  		}
   398  	}
   399  
   400  	if w.checker != nil {
   401  		if err := w.checker.Close(); err != nil {
   402  			errors = append(errors, err)
   403  		}
   404  	}
   405  
   406  	if err := common.Close(w.proxy); err != nil {
   407  		errors = append(errors, err)
   408  	}
   409  
   410  	if len(errors) > 0 {
   411  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   412  	}
   413  	return nil
   414  }
   415  
   416  func (w *udpWorker) Port() net.Port {
   417  	return w.port
   418  }
   419  
   420  func (w *udpWorker) Proxy() proxy.Inbound {
   421  	return w.proxy
   422  }
   423  
   424  type dsWorker struct {
   425  	address         net.Address
   426  	proxy           proxy.Inbound
   427  	stream          *internet.MemoryStreamConfig
   428  	tag             string
   429  	dispatcher      routing.Dispatcher
   430  	sniffingConfig  *proxyman.SniffingConfig
   431  	uplinkCounter   stats.Counter
   432  	downlinkCounter stats.Counter
   433  
   434  	hub internet.Listener
   435  
   436  	ctx context.Context
   437  }
   438  
   439  func (w *dsWorker) callback(conn internet.Connection) {
   440  	ctx, cancel := context.WithCancel(w.ctx)
   441  	sid := session.NewID()
   442  	ctx = session.ContextWithID(ctx, sid)
   443  
   444  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
   445  		conn = &internet.StatCouterConnection{
   446  			Connection:   conn,
   447  			ReadCounter:  w.uplinkCounter,
   448  			WriteCounter: w.downlinkCounter,
   449  		}
   450  	}
   451  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
   452  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
   453  		Gateway: net.UnixDestination(w.address),
   454  		Tag:     w.tag,
   455  		Conn:    conn,
   456  	})
   457  
   458  	content := new(session.Content)
   459  	if w.sniffingConfig != nil {
   460  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   461  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   462  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   463  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   464  	}
   465  	ctx = session.ContextWithContent(ctx, content)
   466  
   467  	if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
   468  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   469  	}
   470  	cancel()
   471  	if err := conn.Close(); err != nil {
   472  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   473  	}
   474  }
   475  
   476  func (w *dsWorker) Proxy() proxy.Inbound {
   477  	return w.proxy
   478  }
   479  
   480  func (w *dsWorker) Port() net.Port {
   481  	return net.Port(0)
   482  }
   483  func (w *dsWorker) Start() error {
   484  	ctx := context.Background()
   485  	hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
   486  		go w.callback(conn)
   487  	})
   488  	if err != nil {
   489  		return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
   490  	}
   491  	w.hub = hub
   492  	return nil
   493  }
   494  
   495  func (w *dsWorker) Close() error {
   496  	var errors []interface{}
   497  	if w.hub != nil {
   498  		if err := common.Close(w.hub); err != nil {
   499  			errors = append(errors, err)
   500  		}
   501  		if err := common.Close(w.proxy); err != nil {
   502  			errors = append(errors, err)
   503  		}
   504  	}
   505  	if len(errors) > 0 {
   506  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   507  	}
   508  
   509  	return nil
   510  }