github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/proxyman/inbound/worker.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/app/proxyman"
    10  	"github.com/xtls/xray-core/common"
    11  	"github.com/xtls/xray-core/common/buf"
    12  	"github.com/xtls/xray-core/common/net"
    13  	"github.com/xtls/xray-core/common/serial"
    14  	"github.com/xtls/xray-core/common/session"
    15  	"github.com/xtls/xray-core/common/signal/done"
    16  	"github.com/xtls/xray-core/common/task"
    17  	"github.com/xtls/xray-core/features/routing"
    18  	"github.com/xtls/xray-core/features/stats"
    19  	"github.com/xtls/xray-core/proxy"
    20  	"github.com/xtls/xray-core/transport/internet"
    21  	"github.com/xtls/xray-core/transport/internet/stat"
    22  	"github.com/xtls/xray-core/transport/internet/tcp"
    23  	"github.com/xtls/xray-core/transport/internet/udp"
    24  	"github.com/xtls/xray-core/transport/pipe"
    25  )
    26  
    27  type worker interface {
    28  	Start() error
    29  	Close() error
    30  	Port() net.Port
    31  	Proxy() proxy.Inbound
    32  }
    33  
// tcpWorker listens for TCP connections on address:port and serves each
// accepted connection on its own goroutine via callback.
type tcpWorker struct {
	address         net.Address
	port            net.Port
	proxy           proxy.Inbound
	stream          *internet.MemoryStreamConfig
	// When true, callback tries to recover the connection's original
	// destination (Redirect or TProxy mode, taken from stream's socket
	// settings) and uses it as the outbound target.
	recvOrigDest    bool
	tag             string
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	// Optional traffic counters; when either is non-nil, accepted connections
	// are wrapped so reads count as uplink and writes as downlink.
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// Listener created by Start; nil until Start succeeds.
	hub internet.Listener

	// Parent context for every per-connection session context.
	ctx context.Context
}
    50  
    51  func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
    52  	if s == nil || s.SocketSettings == nil {
    53  		return internet.SocketConfig_Off
    54  	}
    55  	return s.SocketSettings.Tproxy
    56  }
    57  
    58  func (w *tcpWorker) callback(conn stat.Connection) {
    59  	ctx, cancel := context.WithCancel(w.ctx)
    60  	sid := session.NewID()
    61  	ctx = session.ContextWithID(ctx, sid)
    62  
    63  	outbounds := []*session.Outbound{{}}
    64  	if w.recvOrigDest {
    65  		var dest net.Destination
    66  		switch getTProxyType(w.stream) {
    67  		case internet.SocketConfig_Redirect:
    68  			d, err := tcp.GetOriginalDestination(conn)
    69  			if err != nil {
    70  				newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
    71  			} else {
    72  				dest = d
    73  			}
    74  		case internet.SocketConfig_TProxy:
    75  			dest = net.DestinationFromAddr(conn.LocalAddr())
    76  		}
    77  		if dest.IsValid() {
    78  			outbounds[0].Target = dest
    79  		}
    80  	}
    81  	ctx = session.ContextWithOutbounds(ctx, outbounds)
    82  
    83  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
    84  		conn = &stat.CounterConnection{
    85  			Connection:   conn,
    86  			ReadCounter:  w.uplinkCounter,
    87  			WriteCounter: w.downlinkCounter,
    88  		}
    89  	}
    90  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
    91  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
    92  		Gateway: net.TCPDestination(w.address, w.port),
    93  		Tag:     w.tag,
    94  		Conn:    conn,
    95  	})
    96  
    97  	content := new(session.Content)
    98  	if w.sniffingConfig != nil {
    99  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   100  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   101  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   102  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   103  		content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   104  	}
   105  	ctx = session.ContextWithContent(ctx, content)
   106  
   107  	if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
   108  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   109  	}
   110  	cancel()
   111  	conn.Close()
   112  }
   113  
   114  func (w *tcpWorker) Proxy() proxy.Inbound {
   115  	return w.proxy
   116  }
   117  
   118  func (w *tcpWorker) Start() error {
   119  	ctx := context.Background()
   120  	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
   121  		go w.callback(conn)
   122  	})
   123  	if err != nil {
   124  		return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
   125  	}
   126  	w.hub = hub
   127  	return nil
   128  }
   129  
   130  func (w *tcpWorker) Close() error {
   131  	var errors []interface{}
   132  	if w.hub != nil {
   133  		if err := common.Close(w.hub); err != nil {
   134  			errors = append(errors, err)
   135  		}
   136  		if err := common.Close(w.proxy); err != nil {
   137  			errors = append(errors, err)
   138  		}
   139  	}
   140  	if len(errors) > 0 {
   141  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (w *tcpWorker) Port() net.Port {
   148  	return w.port
   149  }
   150  
// udpConn presents one UDP flow (one connID) as a connection to the inbound
// proxy. Incoming payloads are read from a pipe that udpWorker.callback
// writes into, and replies are sent through output.
type udpConn struct {
	lastActivityTime int64 // Unix time in seconds; written/read with sync/atomic (updateActivity, clean)
	reader           buf.Reader
	writer           buf.Writer
	// Sends one reply datagram to the flow's peer (getConnection wires it to
	// hub.WriteTo toward the packet source).
	output           func([]byte) (int, error)
	remote           net.Addr
	local            net.Addr
	// Fired by Close; getConnection replaces a registered conn whose done has fired.
	done             *done.Instance
	// Optional counters: uplink counts bytes read, downlink bytes written.
	uplink           stats.Counter
	downlink         stats.Counter
	// Set once the conn is retired from the worker's activeConn table; clean
	// accesses it while holding the worker lock.
	inactive         bool
}
   163  
   164  func (c *udpConn) setInactive() {
   165  	c.inactive = true
   166  }
   167  
   168  func (c *udpConn) updateActivity() {
   169  	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
   170  }
   171  
   172  // ReadMultiBuffer implements buf.Reader
   173  func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
   174  	mb, err := c.reader.ReadMultiBuffer()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	c.updateActivity()
   179  
   180  	if c.uplink != nil {
   181  		c.uplink.Add(int64(mb.Len()))
   182  	}
   183  
   184  	return mb, nil
   185  }
   186  
   187  func (c *udpConn) Read(buf []byte) (int, error) {
   188  	panic("not implemented")
   189  }
   190  
   191  // Write implements io.Writer.
   192  func (c *udpConn) Write(buf []byte) (int, error) {
   193  	n, err := c.output(buf)
   194  	if c.downlink != nil {
   195  		c.downlink.Add(int64(n))
   196  	}
   197  	if err == nil {
   198  		c.updateActivity()
   199  	}
   200  	return n, err
   201  }
   202  
   203  func (c *udpConn) Close() error {
   204  	common.Must(c.done.Close())
   205  	common.Must(common.Close(c.writer))
   206  	return nil
   207  }
   208  
   209  func (c *udpConn) RemoteAddr() net.Addr {
   210  	return c.remote
   211  }
   212  
   213  func (c *udpConn) LocalAddr() net.Addr {
   214  	return c.local
   215  }
   216  
   217  func (*udpConn) SetDeadline(time.Time) error {
   218  	return nil
   219  }
   220  
   221  func (*udpConn) SetReadDeadline(time.Time) error {
   222  	return nil
   223  }
   224  
   225  func (*udpConn) SetWriteDeadline(time.Time) error {
   226  	return nil
   227  }
   228  
// connID is the key identifying a UDP flow in udpWorker.activeConn.
type connID struct {
	// Source of the packets.
	src  net.Destination
	// Original destination; left as the zero value when the worker is in cone
	// mode or the packet carried no valid original destination.
	dest net.Destination
}
   233  
// udpWorker serves one UDP listening socket. Datagrams are grouped into
// per-flow udpConn sessions keyed by connID, and a periodic checker (clean)
// retires sessions that have been idle too long.
type udpWorker struct {
	// Guards activeConn (taken by getConnection, removeConn, clean and Close).
	sync.RWMutex

	proxy           proxy.Inbound
	// Set by Start; source of incoming packets and sink for replies.
	hub             *udp.Hub
	address         net.Address
	port            net.Port
	tag             string
	stream          *internet.MemoryStreamConfig
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// Runs clean every Interval (one minute, set in Start).
	checker    *task.Periodic
	// Live sessions by flow key; guarded by the embedded mutex.
	activeConn map[connID]*udpConn

	// Parent context for per-session contexts; Start also reads the "cone" flag from it.
	ctx  context.Context
	// When true, flows are keyed by source only (connID.dest stays zero).
	cone bool
}
   254  
   255  func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
   256  	w.Lock()
   257  	defer w.Unlock()
   258  
   259  	if conn, found := w.activeConn[id]; found && !conn.done.Done() {
   260  		return conn, true
   261  	}
   262  
   263  	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
   264  	conn := &udpConn{
   265  		reader: pReader,
   266  		writer: pWriter,
   267  		output: func(b []byte) (int, error) {
   268  			return w.hub.WriteTo(b, id.src)
   269  		},
   270  		remote: &net.UDPAddr{
   271  			IP:   id.src.Address.IP(),
   272  			Port: int(id.src.Port),
   273  		},
   274  		local: &net.UDPAddr{
   275  			IP:   w.address.IP(),
   276  			Port: int(w.port),
   277  		},
   278  		done:     done.New(),
   279  		uplink:   w.uplinkCounter,
   280  		downlink: w.downlinkCounter,
   281  	}
   282  	w.activeConn[id] = conn
   283  
   284  	conn.updateActivity()
   285  	return conn, false
   286  }
   287  
   288  func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
   289  	id := connID{
   290  		src: source,
   291  	}
   292  	if originalDest.IsValid() {
   293  		if !w.cone {
   294  			id.dest = originalDest
   295  		}
   296  		b.UDP = &originalDest
   297  	}
   298  	conn, existing := w.getConnection(id)
   299  
   300  	// payload will be discarded in pipe is full.
   301  	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
   302  
   303  	if !existing {
   304  		common.Must(w.checker.Start())
   305  
   306  		go func() {
   307  			ctx := w.ctx
   308  			sid := session.NewID()
   309  			ctx = session.ContextWithID(ctx, sid)
   310  
   311  			if originalDest.IsValid() {
   312  				outbounds := []*session.Outbound{{
   313  					Target: originalDest,
   314  				}}
   315  				ctx = session.ContextWithOutbounds(ctx, outbounds)
   316  			}
   317  			ctx = session.ContextWithInbound(ctx, &session.Inbound{
   318  				Source:  source,
   319  				Gateway: net.UDPDestination(w.address, w.port),
   320  				Tag:     w.tag,
   321  			})
   322  			content := new(session.Content)
   323  			if w.sniffingConfig != nil {
   324  				content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   325  				content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   326  				content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   327  				content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   328  			}
   329  			ctx = session.ContextWithContent(ctx, content)
   330  			if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
   331  				newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   332  			}
   333  			conn.Close()
   334  			// conn not removed by checker TODO may be lock worker here is better
   335  			if !conn.inactive {
   336  				conn.setInactive()
   337  				w.removeConn(id)
   338  			}
   339  		}()
   340  	}
   341  }
   342  
   343  func (w *udpWorker) removeConn(id connID) {
   344  	w.Lock()
   345  	delete(w.activeConn, id)
   346  	w.Unlock()
   347  }
   348  
   349  func (w *udpWorker) handlePackets() {
   350  	receive := w.hub.Receive()
   351  	for payload := range receive {
   352  		w.callback(payload.Payload, payload.Source, payload.Target)
   353  	}
   354  }
   355  
   356  func (w *udpWorker) clean() error {
   357  	nowSec := time.Now().Unix()
   358  	w.Lock()
   359  	defer w.Unlock()
   360  
   361  	if len(w.activeConn) == 0 {
   362  		return newError("no more connections. stopping...")
   363  	}
   364  
   365  	for addr, conn := range w.activeConn {
   366  		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 2*60 {
   367  			if !conn.inactive {
   368  				conn.setInactive()
   369  				delete(w.activeConn, addr)
   370  			}
   371  			conn.Close()
   372  		}
   373  	}
   374  
   375  	if len(w.activeConn) == 0 {
   376  		w.activeConn = make(map[connID]*udpConn, 16)
   377  	}
   378  
   379  	return nil
   380  }
   381  
   382  func (w *udpWorker) Start() error {
   383  	w.activeConn = make(map[connID]*udpConn, 16)
   384  	ctx := context.Background()
   385  	h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
   386  	if err != nil {
   387  		return err
   388  	}
   389  
   390  	w.cone = w.ctx.Value("cone").(bool)
   391  
   392  	w.checker = &task.Periodic{
   393  		Interval: time.Minute,
   394  		Execute:  w.clean,
   395  	}
   396  
   397  	w.hub = h
   398  	go w.handlePackets()
   399  	return nil
   400  }
   401  
   402  func (w *udpWorker) Close() error {
   403  	w.Lock()
   404  	defer w.Unlock()
   405  
   406  	var errors []interface{}
   407  
   408  	if w.hub != nil {
   409  		if err := w.hub.Close(); err != nil {
   410  			errors = append(errors, err)
   411  		}
   412  	}
   413  
   414  	if w.checker != nil {
   415  		if err := w.checker.Close(); err != nil {
   416  			errors = append(errors, err)
   417  		}
   418  	}
   419  
   420  	if err := common.Close(w.proxy); err != nil {
   421  		errors = append(errors, err)
   422  	}
   423  
   424  	if len(errors) > 0 {
   425  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   426  	}
   427  	return nil
   428  }
   429  
   430  func (w *udpWorker) Port() net.Port {
   431  	return w.port
   432  }
   433  
   434  func (w *udpWorker) Proxy() proxy.Inbound {
   435  	return w.proxy
   436  }
   437  
// dsWorker listens on a Unix domain socket (see Start) and serves each
// accepted connection on its own goroutine via callback.
type dsWorker struct {
	address         net.Address
	proxy           proxy.Inbound
	stream          *internet.MemoryStreamConfig
	tag             string
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	// Optional traffic counters; when either is non-nil, accepted connections
	// are wrapped so reads count as uplink and writes as downlink.
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// Listener created by Start; nil until Start succeeds.
	hub internet.Listener

	// Parent context for every per-connection session context.
	ctx context.Context
}
   452  
   453  func (w *dsWorker) callback(conn stat.Connection) {
   454  	ctx, cancel := context.WithCancel(w.ctx)
   455  	sid := session.NewID()
   456  	ctx = session.ContextWithID(ctx, sid)
   457  
   458  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
   459  		conn = &stat.CounterConnection{
   460  			Connection:   conn,
   461  			ReadCounter:  w.uplinkCounter,
   462  			WriteCounter: w.downlinkCounter,
   463  		}
   464  	}
   465  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
   466  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
   467  		Gateway: net.UnixDestination(w.address),
   468  		Tag:     w.tag,
   469  		Conn:    conn,
   470  	})
   471  
   472  	content := new(session.Content)
   473  	if w.sniffingConfig != nil {
   474  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   475  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   476  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   477  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   478  		content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   479  	}
   480  	ctx = session.ContextWithContent(ctx, content)
   481  
   482  	if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
   483  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   484  	}
   485  	cancel()
   486  	if err := conn.Close(); err != nil {
   487  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   488  	}
   489  }
   490  
   491  func (w *dsWorker) Proxy() proxy.Inbound {
   492  	return w.proxy
   493  }
   494  
   495  func (w *dsWorker) Port() net.Port {
   496  	return net.Port(0)
   497  }
   498  
   499  func (w *dsWorker) Start() error {
   500  	ctx := context.Background()
   501  	hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
   502  		go w.callback(conn)
   503  	})
   504  	if err != nil {
   505  		return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
   506  	}
   507  	w.hub = hub
   508  	return nil
   509  }
   510  
   511  func (w *dsWorker) Close() error {
   512  	var errors []interface{}
   513  	if w.hub != nil {
   514  		if err := common.Close(w.hub); err != nil {
   515  			errors = append(errors, err)
   516  		}
   517  		if err := common.Close(w.proxy); err != nil {
   518  			errors = append(errors, err)
   519  		}
   520  	}
   521  	if len(errors) > 0 {
   522  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   523  	}
   524  
   525  	return nil
   526  }