github.com/database64128/shadowsocks-go@v1.7.0/service/tcp.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"os"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/database64128/shadowsocks-go/conn"
    13  	"github.com/database64128/shadowsocks-go/direct"
    14  	"github.com/database64128/shadowsocks-go/router"
    15  	"github.com/database64128/shadowsocks-go/stats"
    16  	"github.com/database64128/shadowsocks-go/zerocopy"
    17  	"github.com/database64128/tfo-go/v2"
    18  	"go.uber.org/zap"
    19  )
    20  
    21  const (
    22  	initialPayloadWaitBufferSize = 1280
    23  	initialPayloadWaitTimeout    = 250 * time.Millisecond
    24  )
    25  
    26  // TCPRelay is a relay service for TCP traffic.
    27  //
    28  // When started, the relay service accepts incoming TCP connections on the server,
    29  // and dispatches them to a client selected by the router.
    30  //
    31  // TCPRelay implements the Service interface.
    32  type TCPRelay struct {
    33  	serverIndex           int
    34  	serverName            string
    35  	listenAddress         string
    36  	wg                    sync.WaitGroup
    37  	listenConfig          tfo.ListenConfig
    38  	waitForInitialPayload bool
    39  	server                zerocopy.TCPServer
    40  	connCloser            zerocopy.TCPConnCloser
    41  	fallbackAddress       *conn.Addr
    42  	collector             stats.Collector
    43  	router                *router.Router
    44  	logger                *zap.Logger
    45  	listener              *net.TCPListener
    46  }
    47  
    48  func NewTCPRelay(
    49  	serverIndex int,
    50  	serverName, listenAddress string, waitForInitialPayload bool,
    51  	listenConfig tfo.ListenConfig,
    52  	server zerocopy.TCPServer,
    53  	connCloser zerocopy.TCPConnCloser,
    54  	fallbackAddress *conn.Addr,
    55  	collector stats.Collector,
    56  	router *router.Router,
    57  	logger *zap.Logger,
    58  ) *TCPRelay {
    59  	return &TCPRelay{
    60  		serverIndex:           serverIndex,
    61  		serverName:            serverName,
    62  		listenAddress:         listenAddress,
    63  		listenConfig:          listenConfig,
    64  		waitForInitialPayload: waitForInitialPayload,
    65  		server:                server,
    66  		connCloser:            connCloser,
    67  		fallbackAddress:       fallbackAddress,
    68  		collector:             collector,
    69  		router:                router,
    70  		logger:                logger,
    71  	}
    72  }
    73  
    74  // String implements the Service String method.
    75  func (s *TCPRelay) String() string {
    76  	return fmt.Sprintf("TCP relay service for %s", s.serverName)
    77  }
    78  
    79  // Start implements the Service Start method.
    80  func (s *TCPRelay) Start() error {
    81  	l, err := s.listenConfig.Listen(context.Background(), "tcp", s.listenAddress)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	s.listener = l.(*net.TCPListener)
    86  
    87  	s.wg.Add(1)
    88  
    89  	go func() {
    90  		for {
    91  			clientConn, err := s.listener.AcceptTCP()
    92  			if err != nil {
    93  				if errors.Is(err, os.ErrDeadlineExceeded) {
    94  					break
    95  				}
    96  				s.logger.Warn("Failed to accept TCP connection",
    97  					zap.String("server", s.serverName),
    98  					zap.String("listenAddress", s.listenAddress),
    99  					zap.Error(err),
   100  				)
   101  				continue
   102  			}
   103  
   104  			go s.handleConn(clientConn)
   105  		}
   106  
   107  		s.wg.Done()
   108  	}()
   109  
   110  	s.logger.Info("Started TCP relay service",
   111  		zap.String("server", s.serverName),
   112  		zap.String("listenAddress", s.listenAddress),
   113  	)
   114  
   115  	return nil
   116  }
   117  
   118  // handleConn handles an accepted TCP connection.
   119  func (s *TCPRelay) handleConn(clientConn *net.TCPConn) {
   120  	defer clientConn.Close()
   121  
   122  	// Get client address.
   123  	clientAddrPort := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort()
   124  	clientAddress := clientAddrPort.String()
   125  
   126  	// Handshake.
   127  	clientRW, targetAddr, payload, username, err := s.server.Accept(clientConn)
   128  	if err != nil {
   129  		if err == zerocopy.ErrAcceptDoneNoRelay {
   130  			s.logger.Debug("The accepted connection has been handled without relaying",
   131  				zap.String("server", s.serverName),
   132  				zap.String("listenAddress", s.listenAddress),
   133  				zap.String("clientAddress", clientAddress),
   134  			)
   135  			return
   136  		}
   137  
   138  		s.logger.Warn("Failed to complete handshake with client",
   139  			zap.String("server", s.serverName),
   140  			zap.String("listenAddress", s.listenAddress),
   141  			zap.String("clientAddress", clientAddress),
   142  			zap.Error(err),
   143  		)
   144  
   145  		if s.fallbackAddress == nil || len(payload) == 0 {
   146  			s.connCloser(clientConn, s.serverName, s.listenAddress, clientAddress, s.logger)
   147  			return
   148  		}
   149  
   150  		clientRW = direct.NewDirectStreamReadWriter(clientConn)
   151  		targetAddr = *s.fallbackAddress
   152  	}
   153  
   154  	// Convert target address to string once for log messages.
   155  	targetAddress := targetAddr.String()
   156  
   157  	// Route.
   158  	c, err := s.router.GetTCPClient(router.RequestInfo{
   159  		ServerIndex:    s.serverIndex,
   160  		Username:       username,
   161  		SourceAddrPort: clientAddrPort,
   162  		TargetAddr:     targetAddr,
   163  	})
   164  	if err != nil {
   165  		s.logger.Warn("Failed to get TCP client for client connection",
   166  			zap.String("server", s.serverName),
   167  			zap.String("listenAddress", s.listenAddress),
   168  			zap.String("clientAddress", clientAddress),
   169  			zap.String("targetAddress", targetAddress),
   170  			zap.String("username", username),
   171  			zap.Error(err),
   172  		)
   173  		return
   174  	}
   175  
   176  	// Get client info.
   177  	clientInfo := c.Info()
   178  
   179  	// Wait for initial payload if all of the following are true:
   180  	// 1. not disabled
   181  	// 2. server does not have native support
   182  	// 3. client has native support
   183  	if s.waitForInitialPayload && clientInfo.NativeInitialPayload {
   184  		clientReaderInfo := clientRW.ReaderInfo()
   185  		payloadBufSize := clientReaderInfo.MinPayloadBufferSizePerRead
   186  		if payloadBufSize == 0 {
   187  			payloadBufSize = initialPayloadWaitBufferSize
   188  		}
   189  
   190  		payload = make([]byte, clientReaderInfo.Headroom.Front+payloadBufSize+clientReaderInfo.Headroom.Rear)
   191  
   192  		err = clientConn.SetReadDeadline(time.Now().Add(initialPayloadWaitTimeout))
   193  		if err != nil {
   194  			s.logger.Warn("Failed to set read deadline to initial payload wait timeout",
   195  				zap.String("server", s.serverName),
   196  				zap.String("client", clientInfo.Name),
   197  				zap.String("listenAddress", s.listenAddress),
   198  				zap.String("clientAddress", clientAddress),
   199  				zap.String("targetAddress", targetAddress),
   200  				zap.String("username", username),
   201  				zap.Error(err),
   202  			)
   203  			return
   204  		}
   205  
   206  		payloadLength, err := clientRW.ReadZeroCopy(payload, clientReaderInfo.Headroom.Front, payloadBufSize)
   207  		switch {
   208  		case err == nil:
   209  			payload = payload[clientReaderInfo.Headroom.Front : clientReaderInfo.Headroom.Front+payloadLength]
   210  			s.logger.Debug("Got initial payload",
   211  				zap.String("server", s.serverName),
   212  				zap.String("client", clientInfo.Name),
   213  				zap.String("listenAddress", s.listenAddress),
   214  				zap.String("clientAddress", clientAddress),
   215  				zap.String("targetAddress", targetAddress),
   216  				zap.String("username", username),
   217  				zap.Int("payloadLength", payloadLength),
   218  			)
   219  
   220  		case errors.Is(err, os.ErrDeadlineExceeded):
   221  			s.logger.Debug("Initial payload wait timed out",
   222  				zap.String("server", s.serverName),
   223  				zap.String("client", clientInfo.Name),
   224  				zap.String("listenAddress", s.listenAddress),
   225  				zap.String("clientAddress", clientAddress),
   226  				zap.String("targetAddress", targetAddress),
   227  				zap.String("username", username),
   228  			)
   229  
   230  		default:
   231  			s.logger.Warn("Failed to read initial payload",
   232  				zap.String("server", s.serverName),
   233  				zap.String("client", clientInfo.Name),
   234  				zap.String("listenAddress", s.listenAddress),
   235  				zap.String("clientAddress", clientAddress),
   236  				zap.String("targetAddress", targetAddress),
   237  				zap.String("username", username),
   238  				zap.Error(err),
   239  			)
   240  			return
   241  		}
   242  
   243  		err = clientConn.SetReadDeadline(time.Time{})
   244  		if err != nil {
   245  			s.logger.Warn("Failed to reset read deadline",
   246  				zap.String("server", s.serverName),
   247  				zap.String("client", clientInfo.Name),
   248  				zap.String("listenAddress", s.listenAddress),
   249  				zap.String("clientAddress", clientAddress),
   250  				zap.String("targetAddress", targetAddress),
   251  				zap.String("username", username),
   252  				zap.Error(err),
   253  			)
   254  			return
   255  		}
   256  	}
   257  
   258  	// Create remote connection.
   259  	remoteConn, remoteRW, err := c.Dial(targetAddr, payload)
   260  	if err != nil {
   261  		s.logger.Warn("Failed to create remote connection",
   262  			zap.String("server", s.serverName),
   263  			zap.String("client", clientInfo.Name),
   264  			zap.String("listenAddress", s.listenAddress),
   265  			zap.String("clientAddress", clientAddress),
   266  			zap.String("targetAddress", targetAddress),
   267  			zap.String("username", username),
   268  			zap.Int("initialPayloadLength", len(payload)),
   269  			zap.Error(err),
   270  		)
   271  		return
   272  	}
   273  	defer remoteConn.Close()
   274  
   275  	s.logger.Info("Two-way relay started",
   276  		zap.String("server", s.serverName),
   277  		zap.String("client", clientInfo.Name),
   278  		zap.String("listenAddress", s.listenAddress),
   279  		zap.String("clientAddress", clientAddress),
   280  		zap.String("targetAddress", targetAddress),
   281  		zap.String("username", username),
   282  		zap.Int("initialPayloadLength", len(payload)),
   283  	)
   284  
   285  	// Two-way relay.
   286  	nl2r, nr2l, err := zerocopy.TwoWayRelay(clientRW, remoteRW)
   287  	nl2r += int64(len(payload))
   288  	s.collector.CollectTCPSession(username, uint64(nr2l), uint64(nl2r))
   289  	if err != nil {
   290  		s.logger.Warn("Two-way relay failed",
   291  			zap.String("server", s.serverName),
   292  			zap.String("client", clientInfo.Name),
   293  			zap.String("listenAddress", s.listenAddress),
   294  			zap.String("clientAddress", clientAddress),
   295  			zap.String("targetAddress", targetAddress),
   296  			zap.String("username", username),
   297  			zap.Int64("nl2r", nl2r),
   298  			zap.Int64("nr2l", nr2l),
   299  			zap.Error(err),
   300  		)
   301  		return
   302  	}
   303  
   304  	s.logger.Info("Two-way relay completed",
   305  		zap.String("server", s.serverName),
   306  		zap.String("client", clientInfo.Name),
   307  		zap.String("listenAddress", s.listenAddress),
   308  		zap.String("clientAddress", clientAddress),
   309  		zap.String("targetAddress", targetAddress),
   310  		zap.String("username", username),
   311  		zap.Int64("nl2r", nl2r),
   312  		zap.Int64("nr2l", nr2l),
   313  	)
   314  }
   315  
   316  // Stop implements the Service Stop method.
   317  func (s *TCPRelay) Stop() error {
   318  	if s.listener == nil {
   319  		return nil
   320  	}
   321  	if err := s.listener.SetDeadline(time.Now()); err != nil {
   322  		return err
   323  	}
   324  	s.wg.Wait()
   325  	return s.listener.Close()
   326  }