github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/tcp.go (about) 1 package service 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "os" 8 "sync" 9 "time" 10 11 "github.com/database64128/shadowsocks-go/conn" 12 "github.com/database64128/shadowsocks-go/direct" 13 "github.com/database64128/shadowsocks-go/router" 14 "github.com/database64128/shadowsocks-go/stats" 15 "github.com/database64128/shadowsocks-go/zerocopy" 16 "go.uber.org/zap" 17 ) 18 19 const ( 20 defaultInitialPayloadWaitBufferSize = 1440 21 defaultInitialPayloadWaitTimeout = 250 * time.Millisecond 22 ) 23 24 // tcpRelayListener configures the TCP listener for a relay service. 25 type tcpRelayListener struct { 26 logger *zap.Logger 27 listener *net.TCPListener 28 listenConfig conn.ListenConfig 29 waitForInitialPayload bool 30 initialPayloadWaitTimeout time.Duration 31 initialPayloadWaitBufferSize int 32 network string 33 address string 34 } 35 36 // TCPRelay is a relay service for TCP traffic. 37 // 38 // When started, the relay service accepts incoming TCP connections on the server, 39 // and dispatches them to a client selected by the router. 40 // 41 // TCPRelay implements the Service interface. 42 type TCPRelay struct { 43 serverIndex int 44 serverName string 45 listeners []tcpRelayListener 46 acceptWg sync.WaitGroup 47 server zerocopy.TCPServer 48 connCloser zerocopy.TCPConnCloser 49 fallbackAddress conn.Addr 50 collector stats.Collector 51 router *router.Router 52 logger *zap.Logger 53 } 54 55 func NewTCPRelay( 56 serverIndex int, 57 serverName string, 58 listeners []tcpRelayListener, 59 server zerocopy.TCPServer, 60 connCloser zerocopy.TCPConnCloser, 61 fallbackAddress conn.Addr, 62 collector stats.Collector, 63 router *router.Router, 64 logger *zap.Logger, 65 ) *TCPRelay { 66 return &TCPRelay{ 67 serverIndex: serverIndex, 68 serverName: serverName, 69 listeners: listeners, 70 server: server, 71 connCloser: connCloser, 72 fallbackAddress: fallbackAddress, 73 collector: collector, 74 router: router, 75 logger: logger, 76 } 77 } 78 79 // String implements the Service String method. 80 func (s *TCPRelay) String() string { 81 return "TCP relay service for " + s.serverName 82 } 83 84 // Start implements the Service Start method. 85 func (s *TCPRelay) Start(ctx context.Context) error { 86 for i := range s.listeners { 87 index := i 88 lnc := &s.listeners[index] 89 90 l, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address) 91 if err != nil { 92 return err 93 } 94 lnc.listener = l 95 lnc.address = l.Addr().String() 96 lnc.logger = s.logger.With( 97 zap.String("server", s.serverName), 98 zap.Int("listener", index), 99 zap.String("listenAddress", lnc.address), 100 ) 101 102 s.acceptWg.Add(1) 103 104 go func() { 105 for { 106 clientConn, err := lnc.listener.AcceptTCP() 107 if err != nil { 108 if errors.Is(err, os.ErrDeadlineExceeded) { 109 break 110 } 111 lnc.logger.Warn("Failed to accept TCP connection", zap.Error(err)) 112 continue 113 } 114 115 go s.handleConn(ctx, lnc, clientConn) 116 } 117 118 s.acceptWg.Done() 119 }() 120 121 lnc.logger.Info("Started TCP relay service listener") 122 } 123 return nil 124 } 125 126 // handleConn handles an accepted TCP connection. 127 func (s *TCPRelay) handleConn(ctx context.Context, lnc *tcpRelayListener, clientConn *net.TCPConn) { 128 // Get client address. 129 clientAddrPort := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort() 130 clientAddress := clientAddrPort.String() 131 132 // Handshake. 133 clientRW, targetAddr, payload, username, err := s.server.Accept(clientConn) 134 if err != nil { 135 if err == zerocopy.ErrAcceptDoneNoRelay { 136 if ce := lnc.logger.Check(zap.DebugLevel, "The accepted connection has been handled without relaying"); ce != nil { 137 ce.Write( 138 zap.String("clientAddress", clientAddress), 139 ) 140 } 141 clientConn.Close() 142 return 143 } 144 145 logger := lnc.logger.With( 146 zap.String("clientAddress", clientAddress), 147 ) 148 149 logger.Warn("Failed to complete handshake with client", zap.Error(err)) 150 151 if !s.fallbackAddress.IsValid() || len(payload) == 0 { 152 s.connCloser(clientConn, logger) 153 clientConn.Close() 154 return 155 } 156 157 clientRW = direct.NewDirectStreamReadWriter(clientConn) 158 targetAddr = s.fallbackAddress 159 } 160 defer clientRW.Close() 161 162 // Convert target address to string once for log messages. 163 targetAddress := targetAddr.String() 164 165 // Route. 166 c, err := s.router.GetTCPClient(ctx, router.RequestInfo{ 167 ServerIndex: s.serverIndex, 168 Username: username, 169 SourceAddrPort: clientAddrPort, 170 TargetAddr: targetAddr, 171 }) 172 if err != nil { 173 lnc.logger.Warn("Failed to get TCP client for client connection", 174 zap.String("clientAddress", clientAddress), 175 zap.String("username", username), 176 zap.String("targetAddress", targetAddress), 177 zap.Error(err), 178 ) 179 return 180 } 181 182 // Get client info. 183 clientInfo := c.Info() 184 185 // Create logger with new fields. 186 logger := lnc.logger.With( 187 zap.String("clientAddress", clientAddress), 188 zap.String("username", username), 189 zap.String("targetAddress", targetAddress), 190 zap.String("client", clientInfo.Name), 191 ) 192 193 // Wait for initial payload if all of the following are true: 194 // 1. not disabled 195 // 2. server does not have native support 196 // 3. client has native support 197 if lnc.waitForInitialPayload && clientInfo.NativeInitialPayload { 198 clientReaderInfo := clientRW.ReaderInfo() 199 payloadBufSize := max(clientReaderInfo.MinPayloadBufferSizePerRead, lnc.initialPayloadWaitBufferSize) 200 payload = make([]byte, clientReaderInfo.Headroom.Front+payloadBufSize+clientReaderInfo.Headroom.Rear) 201 202 err = clientConn.SetReadDeadline(time.Now().Add(lnc.initialPayloadWaitTimeout)) 203 if err != nil { 204 logger.Warn("Failed to set read deadline to initial payload wait timeout", zap.Error(err)) 205 return 206 } 207 208 payloadLength, err := clientRW.ReadZeroCopy(payload, clientReaderInfo.Headroom.Front, payloadBufSize) 209 switch { 210 case err == nil: 211 payload = payload[clientReaderInfo.Headroom.Front : clientReaderInfo.Headroom.Front+payloadLength] 212 if ce := logger.Check(zap.DebugLevel, "Got initial payload"); ce != nil { 213 ce.Write( 214 zap.Int("payloadLength", payloadLength), 215 ) 216 } 217 218 case errors.Is(err, os.ErrDeadlineExceeded): 219 if ce := logger.Check(zap.DebugLevel, "Initial payload wait timed out"); ce != nil { 220 ce.Write() 221 } 222 223 default: 224 logger.Warn("Failed to read initial payload", zap.Error(err)) 225 return 226 } 227 228 err = clientConn.SetReadDeadline(time.Time{}) 229 if err != nil { 230 logger.Warn("Failed to reset read deadline", zap.Error(err)) 231 return 232 } 233 } 234 235 // Create remote connection. 236 remoteRawRW, remoteRW, err := c.Dial(ctx, targetAddr, payload) 237 if err != nil { 238 logger.Warn("Failed to create remote connection", 239 zap.Int("initialPayloadLength", len(payload)), 240 zap.Error(err), 241 ) 242 return 243 } 244 defer remoteRawRW.Close() 245 246 logger.Info("Two-way relay started", 247 zap.Int("initialPayloadLength", len(payload)), 248 ) 249 250 // Two-way relay. 251 nl2r, nr2l, err := zerocopy.TwoWayRelay(clientRW, remoteRW) 252 nl2r += int64(len(payload)) 253 s.collector.CollectTCPSession(username, uint64(nr2l), uint64(nl2r)) 254 if err != nil { 255 logger.Warn("Two-way relay failed", 256 zap.Int64("nl2r", nl2r), 257 zap.Int64("nr2l", nr2l), 258 zap.Error(err), 259 ) 260 return 261 } 262 263 logger.Info("Two-way relay completed", 264 zap.Int64("nl2r", nl2r), 265 zap.Int64("nr2l", nr2l), 266 ) 267 } 268 269 // Stop implements the Service Stop method. 270 func (s *TCPRelay) Stop() error { 271 for i := range s.listeners { 272 lnc := &s.listeners[i] 273 if err := lnc.listener.SetDeadline(conn.ALongTimeAgo); err != nil { 274 lnc.logger.Warn("Failed to set deadline on listener", zap.Error(err)) 275 } 276 } 277 278 s.acceptWg.Wait() 279 280 for i := range s.listeners { 281 lnc := &s.listeners[i] 282 if err := lnc.listener.Close(); err != nil { 283 lnc.logger.Warn("Failed to close listener", zap.Error(err)) 284 } 285 } 286 287 return nil 288 }