github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/tcp.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"os"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  	"github.com/database64128/shadowsocks-go/direct"
    13  	"github.com/database64128/shadowsocks-go/router"
    14  	"github.com/database64128/shadowsocks-go/stats"
    15  	"github.com/database64128/shadowsocks-go/zerocopy"
    16  	"go.uber.org/zap"
    17  )
    18  
    19  const (
    20  	defaultInitialPayloadWaitBufferSize = 1440
    21  	defaultInitialPayloadWaitTimeout    = 250 * time.Millisecond
    22  )
    23  
    24  // tcpRelayListener configures the TCP listener for a relay service.
    25  type tcpRelayListener struct {
    26  	logger                       *zap.Logger
    27  	listener                     *net.TCPListener
    28  	listenConfig                 conn.ListenConfig
    29  	waitForInitialPayload        bool
    30  	initialPayloadWaitTimeout    time.Duration
    31  	initialPayloadWaitBufferSize int
    32  	network                      string
    33  	address                      string
    34  }
    35  
    36  // TCPRelay is a relay service for TCP traffic.
    37  //
    38  // When started, the relay service accepts incoming TCP connections on the server,
    39  // and dispatches them to a client selected by the router.
    40  //
    41  // TCPRelay implements the Service interface.
    42  type TCPRelay struct {
    43  	serverIndex     int
    44  	serverName      string
    45  	listeners       []tcpRelayListener
    46  	acceptWg        sync.WaitGroup
    47  	server          zerocopy.TCPServer
    48  	connCloser      zerocopy.TCPConnCloser
    49  	fallbackAddress conn.Addr
    50  	collector       stats.Collector
    51  	router          *router.Router
    52  	logger          *zap.Logger
    53  }
    54  
    55  func NewTCPRelay(
    56  	serverIndex int,
    57  	serverName string,
    58  	listeners []tcpRelayListener,
    59  	server zerocopy.TCPServer,
    60  	connCloser zerocopy.TCPConnCloser,
    61  	fallbackAddress conn.Addr,
    62  	collector stats.Collector,
    63  	router *router.Router,
    64  	logger *zap.Logger,
    65  ) *TCPRelay {
    66  	return &TCPRelay{
    67  		serverIndex:     serverIndex,
    68  		serverName:      serverName,
    69  		listeners:       listeners,
    70  		server:          server,
    71  		connCloser:      connCloser,
    72  		fallbackAddress: fallbackAddress,
    73  		collector:       collector,
    74  		router:          router,
    75  		logger:          logger,
    76  	}
    77  }
    78  
    79  // String implements the Service String method.
    80  func (s *TCPRelay) String() string {
    81  	return "TCP relay service for " + s.serverName
    82  }
    83  
    84  // Start implements the Service Start method.
    85  func (s *TCPRelay) Start(ctx context.Context) error {
    86  	for i := range s.listeners {
    87  		index := i
    88  		lnc := &s.listeners[index]
    89  
    90  		l, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address)
    91  		if err != nil {
    92  			return err
    93  		}
    94  		lnc.listener = l
    95  		lnc.address = l.Addr().String()
    96  		lnc.logger = s.logger.With(
    97  			zap.String("server", s.serverName),
    98  			zap.Int("listener", index),
    99  			zap.String("listenAddress", lnc.address),
   100  		)
   101  
   102  		s.acceptWg.Add(1)
   103  
   104  		go func() {
   105  			for {
   106  				clientConn, err := lnc.listener.AcceptTCP()
   107  				if err != nil {
   108  					if errors.Is(err, os.ErrDeadlineExceeded) {
   109  						break
   110  					}
   111  					lnc.logger.Warn("Failed to accept TCP connection", zap.Error(err))
   112  					continue
   113  				}
   114  
   115  				go s.handleConn(ctx, lnc, clientConn)
   116  			}
   117  
   118  			s.acceptWg.Done()
   119  		}()
   120  
   121  		lnc.logger.Info("Started TCP relay service listener")
   122  	}
   123  	return nil
   124  }
   125  
   126  // handleConn handles an accepted TCP connection.
   127  func (s *TCPRelay) handleConn(ctx context.Context, lnc *tcpRelayListener, clientConn *net.TCPConn) {
   128  	// Get client address.
   129  	clientAddrPort := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort()
   130  	clientAddress := clientAddrPort.String()
   131  
   132  	// Handshake.
   133  	clientRW, targetAddr, payload, username, err := s.server.Accept(clientConn)
   134  	if err != nil {
   135  		if err == zerocopy.ErrAcceptDoneNoRelay {
   136  			if ce := lnc.logger.Check(zap.DebugLevel, "The accepted connection has been handled without relaying"); ce != nil {
   137  				ce.Write(
   138  					zap.String("clientAddress", clientAddress),
   139  				)
   140  			}
   141  			clientConn.Close()
   142  			return
   143  		}
   144  
   145  		logger := lnc.logger.With(
   146  			zap.String("clientAddress", clientAddress),
   147  		)
   148  
   149  		logger.Warn("Failed to complete handshake with client", zap.Error(err))
   150  
   151  		if !s.fallbackAddress.IsValid() || len(payload) == 0 {
   152  			s.connCloser(clientConn, logger)
   153  			clientConn.Close()
   154  			return
   155  		}
   156  
   157  		clientRW = direct.NewDirectStreamReadWriter(clientConn)
   158  		targetAddr = s.fallbackAddress
   159  	}
   160  	defer clientRW.Close()
   161  
   162  	// Convert target address to string once for log messages.
   163  	targetAddress := targetAddr.String()
   164  
   165  	// Route.
   166  	c, err := s.router.GetTCPClient(ctx, router.RequestInfo{
   167  		ServerIndex:    s.serverIndex,
   168  		Username:       username,
   169  		SourceAddrPort: clientAddrPort,
   170  		TargetAddr:     targetAddr,
   171  	})
   172  	if err != nil {
   173  		lnc.logger.Warn("Failed to get TCP client for client connection",
   174  			zap.String("clientAddress", clientAddress),
   175  			zap.String("username", username),
   176  			zap.String("targetAddress", targetAddress),
   177  			zap.Error(err),
   178  		)
   179  		return
   180  	}
   181  
   182  	// Get client info.
   183  	clientInfo := c.Info()
   184  
   185  	// Create logger with new fields.
   186  	logger := lnc.logger.With(
   187  		zap.String("clientAddress", clientAddress),
   188  		zap.String("username", username),
   189  		zap.String("targetAddress", targetAddress),
   190  		zap.String("client", clientInfo.Name),
   191  	)
   192  
   193  	// Wait for initial payload if all of the following are true:
   194  	// 1. not disabled
   195  	// 2. server does not have native support
   196  	// 3. client has native support
   197  	if lnc.waitForInitialPayload && clientInfo.NativeInitialPayload {
   198  		clientReaderInfo := clientRW.ReaderInfo()
   199  		payloadBufSize := max(clientReaderInfo.MinPayloadBufferSizePerRead, lnc.initialPayloadWaitBufferSize)
   200  		payload = make([]byte, clientReaderInfo.Headroom.Front+payloadBufSize+clientReaderInfo.Headroom.Rear)
   201  
   202  		err = clientConn.SetReadDeadline(time.Now().Add(lnc.initialPayloadWaitTimeout))
   203  		if err != nil {
   204  			logger.Warn("Failed to set read deadline to initial payload wait timeout", zap.Error(err))
   205  			return
   206  		}
   207  
   208  		payloadLength, err := clientRW.ReadZeroCopy(payload, clientReaderInfo.Headroom.Front, payloadBufSize)
   209  		switch {
   210  		case err == nil:
   211  			payload = payload[clientReaderInfo.Headroom.Front : clientReaderInfo.Headroom.Front+payloadLength]
   212  			if ce := logger.Check(zap.DebugLevel, "Got initial payload"); ce != nil {
   213  				ce.Write(
   214  					zap.Int("payloadLength", payloadLength),
   215  				)
   216  			}
   217  
   218  		case errors.Is(err, os.ErrDeadlineExceeded):
   219  			if ce := logger.Check(zap.DebugLevel, "Initial payload wait timed out"); ce != nil {
   220  				ce.Write()
   221  			}
   222  
   223  		default:
   224  			logger.Warn("Failed to read initial payload", zap.Error(err))
   225  			return
   226  		}
   227  
   228  		err = clientConn.SetReadDeadline(time.Time{})
   229  		if err != nil {
   230  			logger.Warn("Failed to reset read deadline", zap.Error(err))
   231  			return
   232  		}
   233  	}
   234  
   235  	// Create remote connection.
   236  	remoteRawRW, remoteRW, err := c.Dial(ctx, targetAddr, payload)
   237  	if err != nil {
   238  		logger.Warn("Failed to create remote connection",
   239  			zap.Int("initialPayloadLength", len(payload)),
   240  			zap.Error(err),
   241  		)
   242  		return
   243  	}
   244  	defer remoteRawRW.Close()
   245  
   246  	logger.Info("Two-way relay started",
   247  		zap.Int("initialPayloadLength", len(payload)),
   248  	)
   249  
   250  	// Two-way relay.
   251  	nl2r, nr2l, err := zerocopy.TwoWayRelay(clientRW, remoteRW)
   252  	nl2r += int64(len(payload))
   253  	s.collector.CollectTCPSession(username, uint64(nr2l), uint64(nl2r))
   254  	if err != nil {
   255  		logger.Warn("Two-way relay failed",
   256  			zap.Int64("nl2r", nl2r),
   257  			zap.Int64("nr2l", nr2l),
   258  			zap.Error(err),
   259  		)
   260  		return
   261  	}
   262  
   263  	logger.Info("Two-way relay completed",
   264  		zap.Int64("nl2r", nl2r),
   265  		zap.Int64("nr2l", nr2l),
   266  	)
   267  }
   268  
   269  // Stop implements the Service Stop method.
   270  func (s *TCPRelay) Stop() error {
   271  	for i := range s.listeners {
   272  		lnc := &s.listeners[i]
   273  		if err := lnc.listener.SetDeadline(conn.ALongTimeAgo); err != nil {
   274  			lnc.logger.Warn("Failed to set deadline on listener", zap.Error(err))
   275  		}
   276  	}
   277  
   278  	s.acceptWg.Wait()
   279  
   280  	for i := range s.listeners {
   281  		lnc := &s.listeners[i]
   282  		if err := lnc.listener.Close(); err != nil {
   283  			lnc.logger.Warn("Failed to close listener", zap.Error(err))
   284  		}
   285  	}
   286  
   287  	return nil
   288  }