github.com/v2fly/v2ray-core/v4@v4.45.2/app/proxyman/inbound/worker.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/v2fly/v2ray-core/v4/app/proxyman"
    10  	"github.com/v2fly/v2ray-core/v4/common"
    11  	"github.com/v2fly/v2ray-core/v4/common/buf"
    12  	"github.com/v2fly/v2ray-core/v4/common/net"
    13  	"github.com/v2fly/v2ray-core/v4/common/serial"
    14  	"github.com/v2fly/v2ray-core/v4/common/session"
    15  	"github.com/v2fly/v2ray-core/v4/common/signal/done"
    16  	"github.com/v2fly/v2ray-core/v4/common/task"
    17  	"github.com/v2fly/v2ray-core/v4/features/routing"
    18  	"github.com/v2fly/v2ray-core/v4/features/stats"
    19  	"github.com/v2fly/v2ray-core/v4/proxy"
    20  	"github.com/v2fly/v2ray-core/v4/transport/internet"
    21  	"github.com/v2fly/v2ray-core/v4/transport/internet/tcp"
    22  	"github.com/v2fly/v2ray-core/v4/transport/internet/udp"
    23  	"github.com/v2fly/v2ray-core/v4/transport/pipe"
    24  )
    25  
// worker is one listener of an inbound: it accepts traffic on a single
// address (a TCP port, a UDP port, or a Unix domain socket) and hands what it
// receives to the associated proxy.Inbound.
type worker interface {
	// Start begins listening.
	Start() error
	// Close stops listening and releases the worker's resources.
	Close() error
	// Port reports the port the worker is bound to (0 for Unix domain sockets).
	Port() net.Port
	// Proxy returns the inbound proxy that processes received traffic.
	Proxy() proxy.Inbound
}
    32  
// tcpWorker accepts TCP connections on address:port (socket options come from
// stream) and runs each accepted connection through proxy in its own
// goroutine, tagging sessions with tag and routing them via dispatcher.
//
// When recvOrigDest is set, callback recovers the pre-redirection destination
// (Redirect or TProxy socket mode, see getTProxyType) and attaches it to the
// session as the outbound target. sniffingConfig, when non-nil, is copied
// into every session's sniffing request. Non-nil uplinkCounter /
// downlinkCounter are attached to each connection as its read / write
// counters.
type tcpWorker struct {
	address         net.Address
	port            net.Port
	proxy           proxy.Inbound
	stream          *internet.MemoryStreamConfig
	recvOrigDest    bool
	tag             string
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// hub is the listener created by Start; nil until Start succeeds.
	hub internet.Listener

	// ctx is the parent of every per-connection session context.
	ctx context.Context
}
    49  
    50  func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
    51  	if s == nil || s.SocketSettings == nil {
    52  		return internet.SocketConfig_Off
    53  	}
    54  	return s.SocketSettings.Tproxy
    55  }
    56  
    57  func (w *tcpWorker) callback(conn internet.Connection) {
    58  	ctx, cancel := context.WithCancel(w.ctx)
    59  	sid := session.NewID()
    60  	ctx = session.ContextWithID(ctx, sid)
    61  
    62  	if w.recvOrigDest {
    63  		var dest net.Destination
    64  		switch getTProxyType(w.stream) {
    65  		case internet.SocketConfig_Redirect:
    66  			d, err := tcp.GetOriginalDestination(conn)
    67  			if err != nil {
    68  				newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
    69  			} else {
    70  				dest = d
    71  			}
    72  		case internet.SocketConfig_TProxy:
    73  			dest = net.DestinationFromAddr(conn.LocalAddr())
    74  		}
    75  		if dest.IsValid() {
    76  			ctx = session.ContextWithOutbound(ctx, &session.Outbound{
    77  				Target: dest,
    78  			})
    79  		}
    80  	}
    81  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
    82  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
    83  		Gateway: net.TCPDestination(w.address, w.port),
    84  		Tag:     w.tag,
    85  	})
    86  	content := new(session.Content)
    87  	if w.sniffingConfig != nil {
    88  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
    89  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
    90  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
    91  	}
    92  	ctx = session.ContextWithContent(ctx, content)
    93  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
    94  		conn = &internet.StatCouterConnection{
    95  			Connection:   conn,
    96  			ReadCounter:  w.uplinkCounter,
    97  			WriteCounter: w.downlinkCounter,
    98  		}
    99  	}
   100  	if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
   101  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   102  	}
   103  	cancel()
   104  	if err := conn.Close(); err != nil {
   105  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   106  	}
   107  }
   108  
   109  func (w *tcpWorker) Proxy() proxy.Inbound {
   110  	return w.proxy
   111  }
   112  
   113  func (w *tcpWorker) Start() error {
   114  	ctx := context.Background()
   115  	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
   116  		go w.callback(conn)
   117  	})
   118  	if err != nil {
   119  		return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
   120  	}
   121  	w.hub = hub
   122  	return nil
   123  }
   124  
   125  func (w *tcpWorker) Close() error {
   126  	var errors []interface{}
   127  	if w.hub != nil {
   128  		if err := common.Close(w.hub); err != nil {
   129  			errors = append(errors, err)
   130  		}
   131  		if err := common.Close(w.proxy); err != nil {
   132  			errors = append(errors, err)
   133  		}
   134  	}
   135  	if len(errors) > 0 {
   136  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  func (w *tcpWorker) Port() net.Port {
   143  	return w.port
   144  }
   145  
// udpConn is a virtual, per-peer connection layered over a udpWorker's shared
// UDP hub. Datagrams from the peer are written into a pipe (writer) by
// udpWorker.callback and read back through ReadMultiBuffer (reader). Data
// written to the connection goes out through output.
//
// The fields are used as follows:
//   - done: closed by Close; getConnection uses it to tell whether the
//     connection is still usable.
//   - uplink / downlink: optional stats counters for bytes read from and
//     written to the peer.
//   - inactive: set once the connection has been unregistered from
//     udpWorker.activeConn, so no other path removes it again. udpWorker.clean
//     accesses it while holding the worker's lock.
type udpConn struct {
	// lastActivityTime is accessed with sync/atomic. Keep it as the first
	// field so it stays 64-bit aligned on 32-bit platforms.
	lastActivityTime int64 // in seconds
	reader           buf.Reader
	writer           buf.Writer
	output           func([]byte) (int, error)
	remote           net.Addr
	local            net.Addr
	done             *done.Instance
	uplink           stats.Counter
	downlink         stats.Counter
	inactive         bool
}
   158  
   159  func (c *udpConn) setInactive() {
   160  	c.inactive = true
   161  }
   162  
   163  func (c *udpConn) updateActivity() {
   164  	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
   165  }
   166  
   167  // ReadMultiBuffer implements buf.Reader
   168  func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
   169  	mb, err := c.reader.ReadMultiBuffer()
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	c.updateActivity()
   174  
   175  	if c.uplink != nil {
   176  		c.uplink.Add(int64(mb.Len()))
   177  	}
   178  
   179  	return mb, nil
   180  }
   181  
   182  func (c *udpConn) Read(buf []byte) (int, error) {
   183  	panic("not implemented")
   184  }
   185  
   186  // Write implements io.Writer.
   187  func (c *udpConn) Write(buf []byte) (int, error) {
   188  	n, err := c.output(buf)
   189  	if c.downlink != nil {
   190  		c.downlink.Add(int64(n))
   191  	}
   192  	if err == nil {
   193  		c.updateActivity()
   194  	}
   195  	return n, err
   196  }
   197  
   198  func (c *udpConn) Close() error {
   199  	common.Must(c.done.Close())
   200  	common.Must(common.Close(c.writer))
   201  	return nil
   202  }
   203  
   204  func (c *udpConn) RemoteAddr() net.Addr {
   205  	return c.remote
   206  }
   207  
   208  func (c *udpConn) LocalAddr() net.Addr {
   209  	return c.local
   210  }
   211  
   212  func (*udpConn) SetDeadline(time.Time) error {
   213  	return nil
   214  }
   215  
   216  func (*udpConn) SetReadDeadline(time.Time) error {
   217  	return nil
   218  }
   219  
   220  func (*udpConn) SetWriteDeadline(time.Time) error {
   221  	return nil
   222  }
   223  
// connID keys udpWorker.activeConn. It holds the datagram's source and, when
// valid, its original destination; dest is left as the zero value otherwise.
type connID struct {
	src  net.Destination
	dest net.Destination
}
   228  
// udpWorker listens on address:port through a udp.Hub and demultiplexes the
// received datagrams into one udpConn per connID. Each new udpConn is served
// by proxy in its own goroutine. A periodic checker (clean) closes
// connections that have been idle for more than 8 seconds.
type udpWorker struct {
	// The embedded mutex guards activeConn; this file only takes it through
	// Lock/Unlock (getConnection, removeConn, clean, Close).
	sync.RWMutex

	proxy           proxy.Inbound
	hub             *udp.Hub
	address         net.Address
	port            net.Port
	tag             string
	stream          *internet.MemoryStreamConfig
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// checker runs clean periodically; it is created by Start and started by
	// callback whenever a new connection appears.
	checker    *task.Periodic
	activeConn map[connID]*udpConn

	// ctx is the parent of every per-connection session context.
	ctx context.Context
}
   248  
   249  func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
   250  	w.Lock()
   251  	defer w.Unlock()
   252  
   253  	if conn, found := w.activeConn[id]; found && !conn.done.Done() {
   254  		return conn, true
   255  	}
   256  
   257  	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
   258  	conn := &udpConn{
   259  		reader: pReader,
   260  		writer: pWriter,
   261  		output: func(b []byte) (int, error) {
   262  			return w.hub.WriteTo(b, id.src)
   263  		},
   264  		remote: &net.UDPAddr{
   265  			IP:   id.src.Address.IP(),
   266  			Port: int(id.src.Port),
   267  		},
   268  		local: &net.UDPAddr{
   269  			IP:   w.address.IP(),
   270  			Port: int(w.port),
   271  		},
   272  		done:     done.New(),
   273  		uplink:   w.uplinkCounter,
   274  		downlink: w.downlinkCounter,
   275  	}
   276  	w.activeConn[id] = conn
   277  
   278  	conn.updateActivity()
   279  	return conn, false
   280  }
   281  
   282  func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
   283  	id := connID{
   284  		src: source,
   285  	}
   286  	if originalDest.IsValid() {
   287  		id.dest = originalDest
   288  	}
   289  	conn, existing := w.getConnection(id)
   290  
   291  	// payload will be discarded in pipe is full.
   292  	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
   293  
   294  	if !existing {
   295  		common.Must(w.checker.Start())
   296  
   297  		go func() {
   298  			ctx := w.ctx
   299  			sid := session.NewID()
   300  			ctx = session.ContextWithID(ctx, sid)
   301  
   302  			if originalDest.IsValid() {
   303  				ctx = session.ContextWithOutbound(ctx, &session.Outbound{
   304  					Target: originalDest,
   305  				})
   306  			}
   307  			ctx = session.ContextWithInbound(ctx, &session.Inbound{
   308  				Source:  source,
   309  				Gateway: net.UDPDestination(w.address, w.port),
   310  				Tag:     w.tag,
   311  			})
   312  			content := new(session.Content)
   313  			if w.sniffingConfig != nil {
   314  				content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   315  				content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   316  				content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   317  			}
   318  			ctx = session.ContextWithContent(ctx, content)
   319  			if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
   320  				newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   321  			}
   322  			conn.Close()
   323  			// conn not removed by checker TODO may be lock worker here is better
   324  			if !conn.inactive {
   325  				conn.setInactive()
   326  				w.removeConn(id)
   327  			}
   328  		}()
   329  	}
   330  }
   331  
   332  func (w *udpWorker) removeConn(id connID) {
   333  	w.Lock()
   334  	delete(w.activeConn, id)
   335  	w.Unlock()
   336  }
   337  
   338  func (w *udpWorker) handlePackets() {
   339  	receive := w.hub.Receive()
   340  	for payload := range receive {
   341  		w.callback(payload.Payload, payload.Source, payload.Target)
   342  	}
   343  }
   344  
   345  func (w *udpWorker) clean() error {
   346  	nowSec := time.Now().Unix()
   347  	w.Lock()
   348  	defer w.Unlock()
   349  
   350  	if len(w.activeConn) == 0 {
   351  		return newError("no more connections. stopping...")
   352  	}
   353  
   354  	for addr, conn := range w.activeConn {
   355  		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { // TODO Timeout too small
   356  			if !conn.inactive {
   357  				conn.setInactive()
   358  				delete(w.activeConn, addr)
   359  			}
   360  			conn.Close()
   361  		}
   362  	}
   363  
   364  	if len(w.activeConn) == 0 {
   365  		w.activeConn = make(map[connID]*udpConn, 16)
   366  	}
   367  
   368  	return nil
   369  }
   370  
   371  func (w *udpWorker) Start() error {
   372  	w.activeConn = make(map[connID]*udpConn, 16)
   373  	ctx := context.Background()
   374  	h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
   375  	if err != nil {
   376  		return err
   377  	}
   378  
   379  	w.checker = &task.Periodic{
   380  		Interval: time.Second * 16,
   381  		Execute:  w.clean,
   382  	}
   383  
   384  	w.hub = h
   385  	go w.handlePackets()
   386  	return nil
   387  }
   388  
   389  func (w *udpWorker) Close() error {
   390  	w.Lock()
   391  	defer w.Unlock()
   392  
   393  	var errors []interface{}
   394  
   395  	if w.hub != nil {
   396  		if err := w.hub.Close(); err != nil {
   397  			errors = append(errors, err)
   398  		}
   399  	}
   400  
   401  	if w.checker != nil {
   402  		if err := w.checker.Close(); err != nil {
   403  			errors = append(errors, err)
   404  		}
   405  	}
   406  
   407  	if err := common.Close(w.proxy); err != nil {
   408  		errors = append(errors, err)
   409  	}
   410  
   411  	if len(errors) > 0 {
   412  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   413  	}
   414  	return nil
   415  }
   416  
   417  func (w *udpWorker) Port() net.Port {
   418  	return w.port
   419  }
   420  
   421  func (w *udpWorker) Proxy() proxy.Inbound {
   422  	return w.proxy
   423  }
   424  
// dsWorker accepts connections on the Unix domain socket at address (socket
// options come from stream). It runs each accepted connection through proxy
// in its own goroutine, tagging sessions with tag and routing them via
// dispatcher. sniffingConfig, when non-nil, is copied into every session's
// sniffing request. Non-nil uplinkCounter / downlinkCounter are attached to
// each connection as its read / write counters.
type dsWorker struct {
	address         net.Address
	proxy           proxy.Inbound
	stream          *internet.MemoryStreamConfig
	tag             string
	dispatcher      routing.Dispatcher
	sniffingConfig  *proxyman.SniffingConfig
	uplinkCounter   stats.Counter
	downlinkCounter stats.Counter

	// hub is the listener created by Start; nil until Start succeeds.
	hub internet.Listener

	// ctx is the parent of every per-connection session context.
	ctx context.Context
}
   439  
   440  func (w *dsWorker) callback(conn internet.Connection) {
   441  	ctx, cancel := context.WithCancel(w.ctx)
   442  	sid := session.NewID()
   443  	ctx = session.ContextWithID(ctx, sid)
   444  
   445  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
   446  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
   447  		Gateway: net.UnixDestination(w.address),
   448  		Tag:     w.tag,
   449  	})
   450  	content := new(session.Content)
   451  	if w.sniffingConfig != nil {
   452  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
   453  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
   454  		content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
   455  	}
   456  	ctx = session.ContextWithContent(ctx, content)
   457  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
   458  		conn = &internet.StatCouterConnection{
   459  			Connection:   conn,
   460  			ReadCounter:  w.uplinkCounter,
   461  			WriteCounter: w.downlinkCounter,
   462  		}
   463  	}
   464  	if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
   465  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   466  	}
   467  	cancel()
   468  	if err := conn.Close(); err != nil {
   469  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   470  	}
   471  }
   472  
   473  func (w *dsWorker) Proxy() proxy.Inbound {
   474  	return w.proxy
   475  }
   476  
   477  func (w *dsWorker) Port() net.Port {
   478  	return net.Port(0)
   479  }
   480  
   481  func (w *dsWorker) Start() error {
   482  	ctx := context.Background()
   483  	hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
   484  		go w.callback(conn)
   485  	})
   486  	if err != nil {
   487  		return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
   488  	}
   489  	w.hub = hub
   490  	return nil
   491  }
   492  
   493  func (w *dsWorker) Close() error {
   494  	var errors []interface{}
   495  	if w.hub != nil {
   496  		if err := common.Close(w.hub); err != nil {
   497  			errors = append(errors, err)
   498  		}
   499  		if err := common.Close(w.proxy); err != nil {
   500  			errors = append(errors, err)
   501  		}
   502  	}
   503  	if len(errors) > 0 {
   504  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   505  	}
   506  
   507  	return nil
   508  }