github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/tun/handler_tcp.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"gvisor.dev/gvisor/pkg/tcpip"
     8  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
     9  	"gvisor.dev/gvisor/pkg/tcpip/header"
    10  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    11  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    12  	"gvisor.dev/gvisor/pkg/waiter"
    13  
    14  	tun_net "github.com/v2fly/v2ray-core/v5/app/tun/net"
    15  	"github.com/v2fly/v2ray-core/v5/common"
    16  	"github.com/v2fly/v2ray-core/v5/common/buf"
    17  	"github.com/v2fly/v2ray-core/v5/common/log"
    18  	"github.com/v2fly/v2ray-core/v5/common/net"
    19  	"github.com/v2fly/v2ray-core/v5/common/session"
    20  	"github.com/v2fly/v2ray-core/v5/common/signal"
    21  	"github.com/v2fly/v2ray-core/v5/common/task"
    22  	"github.com/v2fly/v2ray-core/v5/features/policy"
    23  	"github.com/v2fly/v2ray-core/v5/features/routing"
    24  	internet "github.com/v2fly/v2ray-core/v5/transport/internet"
    25  )
    26  
    27  const (
    28  	rcvWnd      = 0 // default settings
    29  	maxInFlight = 2 << 10
    30  )
    31  
    32  type tcpConn struct {
    33  	*gonet.TCPConn
    34  	id stack.TransportEndpointID
    35  }
    36  
    37  func (c *tcpConn) ID() *stack.TransportEndpointID {
    38  	return &c.id
    39  }
    40  
    41  type TCPHandler struct {
    42  	ctx           context.Context
    43  	dispatcher    routing.Dispatcher
    44  	policyManager policy.Manager
    45  	config        *Config
    46  }
    47  
    48  func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption {
    49  	return func(s *stack.Stack) error {
    50  		tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) {
    51  			wg := new(waiter.Queue)
    52  			linkedEndpoint, err := r.CreateEndpoint(wg)
    53  			if err != nil {
    54  				r.Complete(true)
    55  				return
    56  			}
    57  			defer r.Complete(false)
    58  
    59  			if config.SocketSettings != nil {
    60  				if err := applySocketOptions(s, linkedEndpoint, config.SocketSettings); err != nil {
    61  					newError("failed to apply socket options: ", err).WriteToLog(session.ExportIDToError(ctx))
    62  				}
    63  			}
    64  
    65  			conn := &tcpConn{
    66  				TCPConn: gonet.NewTCPConn(wg, linkedEndpoint),
    67  				id:      r.ID(),
    68  			}
    69  
    70  			handler := &TCPHandler{
    71  				ctx:           ctx,
    72  				dispatcher:    dispatcher,
    73  				policyManager: policyManager,
    74  				config:        config,
    75  			}
    76  
    77  			go handler.Handle(conn)
    78  		})
    79  
    80  		s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
    81  
    82  		return nil
    83  	}
    84  }
    85  
    86  func (h *TCPHandler) Handle(conn tun_net.TCPConn) error {
    87  	defer conn.Close()
    88  	id := conn.ID()
    89  	ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag})
    90  	sessionPolicy := h.policyManager.ForLevel(h.config.UserLevel)
    91  
    92  	dest := net.TCPDestination(tun_net.AddressFromTCPIPAddr(id.LocalAddress), net.Port(id.LocalPort))
    93  	src := net.TCPDestination(tun_net.AddressFromTCPIPAddr(id.RemoteAddress), net.Port(id.RemotePort))
    94  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
    95  		From:   src,
    96  		To:     dest,
    97  		Status: log.AccessAccepted,
    98  		Reason: "",
    99  	})
   100  	content := new(session.Content)
   101  	if h.config.SniffingSettings != nil {
   102  		content.SniffingRequest.Enabled = h.config.SniffingSettings.Enabled
   103  		content.SniffingRequest.OverrideDestinationForProtocol = h.config.SniffingSettings.DestinationOverride
   104  		content.SniffingRequest.MetadataOnly = h.config.SniffingSettings.MetadataOnly
   105  	}
   106  	ctx = session.ContextWithContent(ctx, content)
   107  	ctx, cancel := context.WithCancel(ctx)
   108  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   109  	link, err := h.dispatcher.Dispatch(ctx, dest)
   110  	if err != nil {
   111  		return newError("failed to dispatch").Base(err)
   112  	}
   113  
   114  	responseDone := func() error {
   115  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   116  
   117  		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
   118  			return newError("failed to transport all TCP response").Base(err)
   119  		}
   120  
   121  		return nil
   122  	}
   123  
   124  	requestDone := func() error {
   125  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   126  
   127  		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
   128  			return newError("failed to transport all TCP request").Base(err)
   129  		}
   130  
   131  		return nil
   132  	}
   133  
   134  	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   135  	if err := task.Run(h.ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   136  		common.Interrupt(link.Reader)
   137  		common.Interrupt(link.Writer)
   138  		return newError("connection ends").Base(err)
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func applySocketOptions(s *stack.Stack, endpoint tcpip.Endpoint, config *internet.SocketConfig) tcpip.Error {
   145  	if config.TcpKeepAliveInterval > 0 {
   146  		interval := tcpip.KeepaliveIntervalOption(time.Duration(config.TcpKeepAliveInterval) * time.Second)
   147  		if err := endpoint.SetSockOpt(&interval); err != nil {
   148  			return err
   149  		}
   150  	}
   151  
   152  	if config.TcpKeepAliveIdle > 0 {
   153  		idle := tcpip.KeepaliveIdleOption(time.Duration(config.TcpKeepAliveIdle) * time.Second)
   154  		if err := endpoint.SetSockOpt(&idle); err != nil {
   155  			return err
   156  		}
   157  	}
   158  
   159  	if config.TcpKeepAliveInterval > 0 || config.TcpKeepAliveIdle > 0 {
   160  		endpoint.SocketOptions().SetKeepAlive(true)
   161  	}
   162  	{
   163  		var sendBufferSizeRangeOption tcpip.TCPSendBufferSizeRangeOption
   164  		if err := s.TransportProtocolOption(header.TCPProtocolNumber, &sendBufferSizeRangeOption); err == nil {
   165  			endpoint.SocketOptions().SetReceiveBufferSize(int64(sendBufferSizeRangeOption.Default), false)
   166  		}
   167  
   168  		var receiveBufferSizeRangeOption tcpip.TCPReceiveBufferSizeRangeOption
   169  		if err := s.TransportProtocolOption(header.TCPProtocolNumber, &receiveBufferSizeRangeOption); err == nil {
   170  			endpoint.SocketOptions().SetSendBufferSize(int64(receiveBufferSizeRangeOption.Default), false)
   171  		}
   172  	}
   173  
   174  	return nil
   175  }