github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/app/proxyman/inbound/worker.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"v2ray.com/core/app/proxyman"
    10  	"v2ray.com/core/common"
    11  	"v2ray.com/core/common/buf"
    12  	"v2ray.com/core/common/net"
    13  	"v2ray.com/core/common/serial"
    14  	"v2ray.com/core/common/session"
    15  	"v2ray.com/core/common/signal/done"
    16  	"v2ray.com/core/common/task"
    17  	"v2ray.com/core/features/routing"
    18  	"v2ray.com/core/features/stats"
    19  	"v2ray.com/core/proxy"
    20  	"v2ray.com/core/transport/internet"
    21  	"v2ray.com/core/transport/internet/tcp"
    22  	"v2ray.com/core/transport/internet/udp"
    23  	"v2ray.com/core/transport/pipe"
    24  )
    25  
    26  type worker interface {
    27  	Start() error
    28  	Close() error
    29  	Port() net.Port
    30  	Proxy() proxy.Inbound
    31  }
    32  
    33  type tcpWorker struct {
    34  	address         net.Address
    35  	port            net.Port
    36  	proxy           proxy.Inbound
    37  	stream          *internet.MemoryStreamConfig
    38  	recvOrigDest    bool
    39  	tag             string
    40  	dispatcher      routing.Dispatcher
    41  	sniffingConfig  *proxyman.SniffingConfig
    42  	uplinkCounter   stats.Counter
    43  	downlinkCounter stats.Counter
    44  
    45  	hub internet.Listener
    46  
    47  	ctx context.Context
    48  }
    49  
    50  func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
    51  	if s == nil || s.SocketSettings == nil {
    52  		return internet.SocketConfig_Off
    53  	}
    54  	return s.SocketSettings.Tproxy
    55  }
    56  
    57  func (w *tcpWorker) callback(conn internet.Connection) {
    58  	ctx, cancel := context.WithCancel(w.ctx)
    59  	sid := session.NewID()
    60  	ctx = session.ContextWithID(ctx, sid)
    61  
    62  	if w.recvOrigDest {
    63  		var dest net.Destination
    64  		switch getTProxyType(w.stream) {
    65  		case internet.SocketConfig_Redirect:
    66  			d, err := tcp.GetOriginalDestination(conn)
    67  			if err != nil {
    68  				newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx))
    69  			} else {
    70  				dest = d
    71  			}
    72  		case internet.SocketConfig_TProxy:
    73  			dest = net.DestinationFromAddr(conn.LocalAddr())
    74  		}
    75  		if dest.IsValid() {
    76  			ctx = session.ContextWithOutbound(ctx, &session.Outbound{
    77  				Target: dest,
    78  			})
    79  		}
    80  	}
    81  	ctx = session.ContextWithInbound(ctx, &session.Inbound{
    82  		Source:  net.DestinationFromAddr(conn.RemoteAddr()),
    83  		Gateway: net.TCPDestination(w.address, w.port),
    84  		Tag:     w.tag,
    85  	})
    86  	content := new(session.Content)
    87  	if w.sniffingConfig != nil {
    88  		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
    89  		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
    90  	}
    91  	ctx = session.ContextWithContent(ctx, content)
    92  	if w.uplinkCounter != nil || w.downlinkCounter != nil {
    93  		conn = &internet.StatCouterConnection{
    94  			Connection:   conn,
    95  			ReadCounter:  w.uplinkCounter,
    96  			WriteCounter: w.downlinkCounter,
    97  		}
    98  	}
    99  	if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
   100  		newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   101  	}
   102  	cancel()
   103  	if err := conn.Close(); err != nil {
   104  		newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   105  	}
   106  }
   107  
   108  func (w *tcpWorker) Proxy() proxy.Inbound {
   109  	return w.proxy
   110  }
   111  
   112  func (w *tcpWorker) Start() error {
   113  	ctx := context.Background()
   114  	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
   115  		go w.callback(conn)
   116  	})
   117  	if err != nil {
   118  		return newError("failed to listen TCP on ", w.port).AtWarning().Base(err)
   119  	}
   120  	w.hub = hub
   121  	return nil
   122  }
   123  
   124  func (w *tcpWorker) Close() error {
   125  	var errors []interface{}
   126  	if w.hub != nil {
   127  		if err := common.Close(w.hub); err != nil {
   128  			errors = append(errors, err)
   129  		}
   130  		if err := common.Close(w.proxy); err != nil {
   131  			errors = append(errors, err)
   132  		}
   133  	}
   134  	if len(errors) > 0 {
   135  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func (w *tcpWorker) Port() net.Port {
   142  	return w.port
   143  }
   144  
   145  type udpConn struct {
   146  	lastActivityTime int64 // in seconds
   147  	reader           buf.Reader
   148  	writer           buf.Writer
   149  	output           func([]byte) (int, error)
   150  	remote           net.Addr
   151  	local            net.Addr
   152  	done             *done.Instance
   153  	uplink           stats.Counter
   154  	downlink         stats.Counter
   155  }
   156  
   157  func (c *udpConn) updateActivity() {
   158  	atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
   159  }
   160  
   161  // ReadMultiBuffer implements buf.Reader
   162  func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
   163  	mb, err := c.reader.ReadMultiBuffer()
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  	c.updateActivity()
   168  
   169  	if c.uplink != nil {
   170  		c.uplink.Add(int64(mb.Len()))
   171  	}
   172  
   173  	return mb, nil
   174  }
   175  
   176  func (c *udpConn) Read(buf []byte) (int, error) {
   177  	panic("not implemented")
   178  }
   179  
   180  // Write implements io.Writer.
   181  func (c *udpConn) Write(buf []byte) (int, error) {
   182  	n, err := c.output(buf)
   183  	if c.downlink != nil {
   184  		c.downlink.Add(int64(n))
   185  	}
   186  	if err == nil {
   187  		c.updateActivity()
   188  	}
   189  	return n, err
   190  }
   191  
   192  func (c *udpConn) Close() error {
   193  	common.Must(c.done.Close())
   194  	common.Must(common.Close(c.writer))
   195  	return nil
   196  }
   197  
   198  func (c *udpConn) RemoteAddr() net.Addr {
   199  	return c.remote
   200  }
   201  
   202  func (c *udpConn) LocalAddr() net.Addr {
   203  	return c.local
   204  }
   205  
   206  func (*udpConn) SetDeadline(time.Time) error {
   207  	return nil
   208  }
   209  
   210  func (*udpConn) SetReadDeadline(time.Time) error {
   211  	return nil
   212  }
   213  
   214  func (*udpConn) SetWriteDeadline(time.Time) error {
   215  	return nil
   216  }
   217  
   218  type connID struct {
   219  	src  net.Destination
   220  	dest net.Destination
   221  }
   222  
   223  type udpWorker struct {
   224  	sync.RWMutex
   225  
   226  	proxy           proxy.Inbound
   227  	hub             *udp.Hub
   228  	address         net.Address
   229  	port            net.Port
   230  	tag             string
   231  	stream          *internet.MemoryStreamConfig
   232  	dispatcher      routing.Dispatcher
   233  	uplinkCounter   stats.Counter
   234  	downlinkCounter stats.Counter
   235  
   236  	checker    *task.Periodic
   237  	activeConn map[connID]*udpConn
   238  }
   239  
   240  func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
   241  	w.Lock()
   242  	defer w.Unlock()
   243  
   244  	if conn, found := w.activeConn[id]; found && !conn.done.Done() {
   245  		return conn, true
   246  	}
   247  
   248  	pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
   249  	conn := &udpConn{
   250  		reader: pReader,
   251  		writer: pWriter,
   252  		output: func(b []byte) (int, error) {
   253  			return w.hub.WriteTo(b, id.src)
   254  		},
   255  		remote: &net.UDPAddr{
   256  			IP:   id.src.Address.IP(),
   257  			Port: int(id.src.Port),
   258  		},
   259  		local: &net.UDPAddr{
   260  			IP:   w.address.IP(),
   261  			Port: int(w.port),
   262  		},
   263  		done:     done.New(),
   264  		uplink:   w.uplinkCounter,
   265  		downlink: w.downlinkCounter,
   266  	}
   267  	w.activeConn[id] = conn
   268  
   269  	conn.updateActivity()
   270  	return conn, false
   271  }
   272  
   273  func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
   274  	id := connID{
   275  		src: source,
   276  	}
   277  	if originalDest.IsValid() {
   278  		id.dest = originalDest
   279  	}
   280  	conn, existing := w.getConnection(id)
   281  
   282  	// payload will be discarded in pipe is full.
   283  	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) // nolint: errcheck
   284  
   285  	if !existing {
   286  		common.Must(w.checker.Start())
   287  
   288  		go func() {
   289  			ctx := context.Background()
   290  			sid := session.NewID()
   291  			ctx = session.ContextWithID(ctx, sid)
   292  
   293  			if originalDest.IsValid() {
   294  				ctx = session.ContextWithOutbound(ctx, &session.Outbound{
   295  					Target: originalDest,
   296  				})
   297  			}
   298  			ctx = session.ContextWithInbound(ctx, &session.Inbound{
   299  				Source:  source,
   300  				Gateway: net.UDPDestination(w.address, w.port),
   301  				Tag:     w.tag,
   302  			})
   303  			if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
   304  				newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx))
   305  			}
   306  			conn.Close() // nolint: errcheck
   307  			w.removeConn(id)
   308  		}()
   309  	}
   310  }
   311  
   312  func (w *udpWorker) removeConn(id connID) {
   313  	w.Lock()
   314  	delete(w.activeConn, id)
   315  	w.Unlock()
   316  }
   317  
   318  func (w *udpWorker) handlePackets() {
   319  	receive := w.hub.Receive()
   320  	for payload := range receive {
   321  		w.callback(payload.Payload, payload.Source, payload.Target)
   322  	}
   323  }
   324  
   325  func (w *udpWorker) clean() error {
   326  	nowSec := time.Now().Unix()
   327  	w.Lock()
   328  	defer w.Unlock()
   329  
   330  	if len(w.activeConn) == 0 {
   331  		return newError("no more connections. stopping...")
   332  	}
   333  
   334  	for addr, conn := range w.activeConn {
   335  		if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { //TODO Timeout too small
   336  			delete(w.activeConn, addr)
   337  			conn.Close() // nolint: errcheck
   338  		}
   339  	}
   340  
   341  	if len(w.activeConn) == 0 {
   342  		w.activeConn = make(map[connID]*udpConn, 16)
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  func (w *udpWorker) Start() error {
   349  	w.activeConn = make(map[connID]*udpConn, 16)
   350  	ctx := context.Background()
   351  	h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	w.checker = &task.Periodic{
   357  		Interval: time.Second * 16,
   358  		Execute:  w.clean,
   359  	}
   360  
   361  	w.hub = h
   362  	go w.handlePackets()
   363  	return nil
   364  }
   365  
   366  func (w *udpWorker) Close() error {
   367  	w.Lock()
   368  	defer w.Unlock()
   369  
   370  	var errors []interface{}
   371  
   372  	if w.hub != nil {
   373  		if err := w.hub.Close(); err != nil {
   374  			errors = append(errors, err)
   375  		}
   376  	}
   377  
   378  	if w.checker != nil {
   379  		if err := w.checker.Close(); err != nil {
   380  			errors = append(errors, err)
   381  		}
   382  	}
   383  
   384  	if err := common.Close(w.proxy); err != nil {
   385  		errors = append(errors, err)
   386  	}
   387  
   388  	if len(errors) > 0 {
   389  		return newError("failed to close all resources").Base(newError(serial.Concat(errors...)))
   390  	}
   391  	return nil
   392  }
   393  
   394  func (w *udpWorker) Port() net.Port {
   395  	return w.port
   396  }
   397  
   398  func (w *udpWorker) Proxy() proxy.Inbound {
   399  	return w.proxy
   400  }