github.com/xraypb/xray-core@v1.6.6/app/proxyman/inbound/worker.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/xraypb/xray-core/app/proxyman"
    10  	"github.com/xraypb/xray-core/common"
    11  	"github.com/xraypb/xray-core/common/buf"
    12  	"github.com/xraypb/xray-core/common/net"
    13  	"github.com/xraypb/xray-core/common/serial"
    14  	"github.com/xraypb/xray-core/common/session"
    15  	"github.com/xraypb/xray-core/common/signal/done"
    16  	"github.com/xraypb/xray-core/common/task"
    17  	"github.com/xraypb/xray-core/features/routing"
    18  	"github.com/xraypb/xray-core/features/stats"
    19  	"github.com/xraypb/xray-core/proxy"
    20  	"github.com/xraypb/xray-core/transport/internet"
    21  	"github.com/xraypb/xray-core/transport/internet/stat"
    22  	"github.com/xraypb/xray-core/transport/internet/tcp"
    23  	"github.com/xraypb/xray-core/transport/internet/udp"
    24  	"github.com/xraypb/xray-core/transport/pipe"
    25  )
    26  
    27  type worker interface {
    28  	Start() error
    29  	Close() error
    30  	Port() net.Port
    31  	Proxy() proxy.Inbound
    32  }
    33  
    34  type tcpWorker struct {
    35  	address         net.Address
    36  	port            net.Port
    37  	proxy           proxy.Inbound
    38  	stream          *internet.MemoryStreamConfig
    39  	recvOrigDest    bool
    40  	tag             string
    41  	dispatcher      routing.Dispatcher
    42  	sniffingConfig  *proxyman.SniffingConfig
    43  	uplinkCounter   stats.Counter
    44  	downlinkCounter stats.Counter
    45  
    46  	hub internet.Listener
    47  
    48  	ctx context.Context
    49  }
    50  
    51  func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
    52  	if s == nil || s.SocketSettings == nil {
    53  		return internet.SocketConfig_Off
    54  	}
    55  	return s.SocketSettings.Tproxy
    56  }
    57  
    58  func (w *tcpWorker) callback(conn stat.Connection) {
    59  	ctx, cancel := context.WithCancel(w.ctx)
    60  	sid := session.NewID()
    61  	ctx = session.ContextWithID(ctx, sid)
    62  
    63  	if w.recvOrigDest {
    64  		var dest net.Destination
    65  		switch getTProxyType(w.stream) {
    66  		case internet.SocketConfig_Redirect:
    67  			d, err := tcp.GetOriginalDestination(conn)
    68  			if err != nil {
    69  				newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
    70  			} else {
    71  				dest = d
    72  			}
    73  		case internet.SocketConfig_TProxy:
    74  			dest = net.DestinationFromAddr(conn.LocalAddr())
    75  		}
    76  		if dest.IsValid() {
    77  			ctx = session.ContextWithOutbound(ctx, &session.Outbound{
    78  				Target: dest,
    79  			})
    80  		}
    81  	}
    82  
    83  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
    84  		conn = &stat.CounterConnection{
    85  			Connection:   conn,
    86  			ReadCounter:  w.uplinkCounter,
    87  			WriteCounter: w.downlinkCounter,
    88  		}
    89  	}
    90  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
    91  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
    92  		Gateway: net.TCPDestination(w.address, w.port),
    93  		Tag:     w.tag,
    94  		Conn:    conn,
    95  	})
    96  
    97  	content := new(session.Content)
    98  	if w.sniffingConfig != nil {
    99  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   100  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   101  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   102  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   103  		content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   104  	}
   105  	ctx = session.ContextWithContent(ctx, content)
   106  
   107  	if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
   108  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   109  	}
   110  	cancel()
   111  	conn.Close()
   112  }
   113  
   114  func (w *tcpWorker) Proxy() proxy.Inbound {
   115  	return w.proxy
   116  }
   117  
   118  func (w *tcpWorker) Start() error {
   119  	ctx := context.Background()
   120  	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
   121  		go w.callback(conn)
   122  	})
   123  	if err != nil {
   124  		return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
   125  	}
   126  	w.hub = hub
   127  	return nil
   128  }
   129  
   130  func (w *tcpWorker) Close() error {
   131  	var errors []interface{}
   132  	if w.hub != nil {
   133  		if err := common.Close(w.hub); err != nil {
   134  			errors = append(errors, err)
   135  		}
   136  		if err := common.Close(w.proxy); err != nil {
   137  			errors = append(errors, err)
   138  		}
   139  	}
   140  	if len(errors) > 0 {
   141  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (w *tcpWorker) Port() net.Port {
   148  	return w.port
   149  }
   150  
   151  type udpConn struct {
   152  	lastActivityTime int64 // in seconds
   153  	reader           buf.Reader
   154  	writer           buf.Writer
   155  	output           func([]byte) (int, error)
   156  	remote           net.Addr
   157  	local            net.Addr
   158  	done             *done.Instance
   159  	uplink           stats.Counter
   160  	downlink         stats.Counter
   161  	inactive         bool
   162  }
   163  
   164  func (c *udpConn) setInactive() {
   165  	c.inactive = true
   166  }
   167  
   168  func (c *udpConn) updateActivity() {
   169  	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
   170  }
   171  
   172  // ReadMultiBuffer implements buf.Reader
   173  func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
   174  	mb, err := c.reader.ReadMultiBuffer()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	c.updateActivity()
   179  
   180  	if c.uplink != nil {
   181  		c.uplink.Add(int64(mb.Len()))
   182  	}
   183  
   184  	return mb, nil
   185  }
   186  
   187  func (c *udpConn) Read(buf []byte) (int, error) {
   188  	panic("not implemented")
   189  }
   190  
   191  // Write implements io.Writer.
   192  func (c *udpConn) Write(buf []byte) (int, error) {
   193  	n, err := c.output(buf)
   194  	if c.downlink != nil {
   195  		c.downlink.Add(int64(n))
   196  	}
   197  	if err == nil {
   198  		c.updateActivity()
   199  	}
   200  	return n, err
   201  }
   202  
   203  func (c *udpConn) Close() error {
   204  	common.Must(c.done.Close())
   205  	common.Must(common.Close(c.writer))
   206  	return nil
   207  }
   208  
   209  func (c *udpConn) RemoteAddr() net.Addr {
   210  	return c.remote
   211  }
   212  
   213  func (c *udpConn) LocalAddr() net.Addr {
   214  	return c.local
   215  }
   216  
   217  func (*udpConn) SetDeadline(time.Time) error {
   218  	return nil
   219  }
   220  
   221  func (*udpConn) SetReadDeadline(time.Time) error {
   222  	return nil
   223  }
   224  
   225  func (*udpConn) SetWriteDeadline(time.Time) error {
   226  	return nil
   227  }
   228  
   229  type connID struct {
   230  	src  net.Destination
   231  	dest net.Destination
   232  }
   233  
   234  type udpWorker struct {
   235  	sync.RWMutex
   236  
   237  	proxy           proxy.Inbound
   238  	hub             *udp.Hub
   239  	address         net.Address
   240  	port            net.Port
   241  	tag             string
   242  	stream          *internet.MemoryStreamConfig
   243  	dispatcher      routing.Dispatcher
   244  	sniffingConfig  *proxyman.SniffingConfig
   245  	uplinkCounter   stats.Counter
   246  	downlinkCounter stats.Counter
   247  
   248  	checker    *task.Periodic
   249  	activeConn map[connID]*udpConn
   250  
   251  	ctx  context.Context
   252  	cone bool
   253  }
   254  
   255  func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
   256  	w.Lock()
   257  	defer w.Unlock()
   258  
   259  	if conn, found := w.activeConn[id]; found && !conn.done.Done() {
   260  		return conn, true
   261  	}
   262  
   263  	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
   264  	conn := &udpConn{
   265  		reader: pReader,
   266  		writer: pWriter,
   267  		output: func(b []byte) (int, error) {
   268  			return w.hub.WriteTo(b, id.src)
   269  		},
   270  		remote: &net.UDPAddr{
   271  			IP:   id.src.Address.IP(),
   272  			Port: int(id.src.Port),
   273  		},
   274  		local: &net.UDPAddr{
   275  			IP:   w.address.IP(),
   276  			Port: int(w.port),
   277  		},
   278  		done:     done.New(),
   279  		uplink:   w.uplinkCounter,
   280  		downlink: w.downlinkCounter,
   281  	}
   282  	w.activeConn[id] = conn
   283  
   284  	conn.updateActivity()
   285  	return conn, false
   286  }
   287  
   288  func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
   289  	id := connID{
   290  		src: source,
   291  	}
   292  	if originalDest.IsValid() {
   293  		if !w.cone {
   294  			id.dest = originalDest
   295  		}
   296  		b.UDP = &originalDest
   297  	}
   298  	conn, existing := w.getConnection(id)
   299  
   300  	// payload will be discarded in pipe is full.
   301  	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
   302  
   303  	if !existing {
   304  		common.Must(w.checker.Start())
   305  
   306  		go func() {
   307  			ctx := w.ctx
   308  			sid := session.NewID()
   309  			ctx = session.ContextWithID(ctx, sid)
   310  
   311  			if originalDest.IsValid() {
   312  				ctx = session.ContextWithOutbound(ctx, &session.Outbound{
   313  					Target: originalDest,
   314  				})
   315  			}
   316  			ctx = session.ContextWithInbound(ctx, &session.Inbound{
   317  				Source:  source,
   318  				Gateway: net.UDPDestination(w.address, w.port),
   319  				Tag:     w.tag,
   320  			})
   321  			content := new(session.Content)
   322  			if w.sniffingConfig != nil {
   323  				content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   324  				content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   325  				content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   326  				content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   327  			}
   328  			ctx = session.ContextWithContent(ctx, content)
   329  			if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
   330  				newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   331  			}
   332  			conn.Close()
   333  			// conn not removed by checker TODO may be lock worker here is better
   334  			if !conn.inactive {
   335  				conn.setInactive()
   336  				w.removeConn(id)
   337  			}
   338  		}()
   339  	}
   340  }
   341  
   342  func (w *udpWorker) removeConn(id connID) {
   343  	w.Lock()
   344  	delete(w.activeConn, id)
   345  	w.Unlock()
   346  }
   347  
   348  func (w *udpWorker) handlePackets() {
   349  	receive := w.hub.Receive()
   350  	for payload := range receive {
   351  		w.callback(payload.Payload, payload.Source, payload.Target)
   352  	}
   353  }
   354  
   355  func (w *udpWorker) clean() error {
   356  	nowSec := time.Now().Unix()
   357  	w.Lock()
   358  	defer w.Unlock()
   359  
   360  	if len(w.activeConn) == 0 {
   361  		return newError("no more connections. stopping...")
   362  	}
   363  
   364  	for addr, conn := range w.activeConn {
   365  		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 5*60 { // TODO Timeout too small
   366  			if !conn.inactive {
   367  				conn.setInactive()
   368  				delete(w.activeConn, addr)
   369  			}
   370  			conn.Close()
   371  		}
   372  	}
   373  
   374  	if len(w.activeConn) == 0 {
   375  		w.activeConn = make(map[connID]*udpConn, 16)
   376  	}
   377  
   378  	return nil
   379  }
   380  
   381  func (w *udpWorker) Start() error {
   382  	w.activeConn = make(map[connID]*udpConn, 16)
   383  	ctx := context.Background()
   384  	h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
   385  	if err != nil {
   386  		return err
   387  	}
   388  
   389  	w.cone = w.ctx.Value("cone").(bool)
   390  
   391  	w.checker = &task.Periodic{
   392  		Interval: time.Minute,
   393  		Execute:  w.clean,
   394  	}
   395  
   396  	w.hub = h
   397  	go w.handlePackets()
   398  	return nil
   399  }
   400  
   401  func (w *udpWorker) Close() error {
   402  	w.Lock()
   403  	defer w.Unlock()
   404  
   405  	var errors []interface{}
   406  
   407  	if w.hub != nil {
   408  		if err := w.hub.Close(); err != nil {
   409  			errors = append(errors, err)
   410  		}
   411  	}
   412  
   413  	if w.checker != nil {
   414  		if err := w.checker.Close(); err != nil {
   415  			errors = append(errors, err)
   416  		}
   417  	}
   418  
   419  	if err := common.Close(w.proxy); err != nil {
   420  		errors = append(errors, err)
   421  	}
   422  
   423  	if len(errors) > 0 {
   424  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   425  	}
   426  	return nil
   427  }
   428  
   429  func (w *udpWorker) Port() net.Port {
   430  	return w.port
   431  }
   432  
   433  func (w *udpWorker) Proxy() proxy.Inbound {
   434  	return w.proxy
   435  }
   436  
   437  type dsWorker struct {
   438  	address         net.Address
   439  	proxy           proxy.Inbound
   440  	stream          *internet.MemoryStreamConfig
   441  	tag             string
   442  	dispatcher      routing.Dispatcher
   443  	sniffingConfig  *proxyman.SniffingConfig
   444  	uplinkCounter   stats.Counter
   445  	downlinkCounter stats.Counter
   446  
   447  	hub internet.Listener
   448  
   449  	ctx context.Context
   450  }
   451  
   452  func (w *dsWorker) callback(conn stat.Connection) {
   453  	ctx, cancel := context.WithCancel(w.ctx)
   454  	sid := session.NewID()
   455  	ctx = session.ContextWithID(ctx, sid)
   456  
   457  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
   458  		conn = &stat.CounterConnection{
   459  			Connection:   conn,
   460  			ReadCounter:  w.uplinkCounter,
   461  			WriteCounter: w.downlinkCounter,
   462  		}
   463  	}
   464  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
   465  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
   466  		Gateway: net.UnixDestination(w.address),
   467  		Tag:     w.tag,
   468  		Conn:    conn,
   469  	})
   470  
   471  	content := new(session.Content)
   472  	if w.sniffingConfig != nil {
   473  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   474  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   475  		content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
   476  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   477  		content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
   478  	}
   479  	ctx = session.ContextWithContent(ctx, content)
   480  
   481  	if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
   482  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   483  	}
   484  	cancel()
   485  	if err := conn.Close(); err != nil {
   486  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   487  	}
   488  }
   489  
   490  func (w *dsWorker) Proxy() proxy.Inbound {
   491  	return w.proxy
   492  }
   493  
   494  func (w *dsWorker) Port() net.Port {
   495  	return net.Port(0)
   496  }
   497  
   498  func (w *dsWorker) Start() error {
   499  	ctx := context.Background()
   500  	hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
   501  		go w.callback(conn)
   502  	})
   503  	if err != nil {
   504  		return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
   505  	}
   506  	w.hub = hub
   507  	return nil
   508  }
   509  
   510  func (w *dsWorker) Close() error {
   511  	var errors []interface{}
   512  	if w.hub != nil {
   513  		if err := common.Close(w.hub); err != nil {
   514  			errors = append(errors, err)
   515  		}
   516  		if err := common.Close(w.proxy); err != nil {
   517  			errors = append(errors, err)
   518  		}
   519  	}
   520  	if len(errors) > 0 {
   521  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   522  	}
   523  
   524  	return nil
   525  }