github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/proxyman/inbound/worker.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "sync" 6 "sync/atomic" 7 "time" 8 9 "github.com/xtls/xray-core/app/proxyman" 10 "github.com/xtls/xray-core/common" 11 "github.com/xtls/xray-core/common/buf" 12 "github.com/xtls/xray-core/common/net" 13 "github.com/xtls/xray-core/common/serial" 14 "github.com/xtls/xray-core/common/session" 15 "github.com/xtls/xray-core/common/signal/done" 16 "github.com/xtls/xray-core/common/task" 17 "github.com/xtls/xray-core/features/routing" 18 "github.com/xtls/xray-core/features/stats" 19 "github.com/xtls/xray-core/proxy" 20 "github.com/xtls/xray-core/transport/internet" 21 "github.com/xtls/xray-core/transport/internet/stat" 22 "github.com/xtls/xray-core/transport/internet/tcp" 23 "github.com/xtls/xray-core/transport/internet/udp" 24 "github.com/xtls/xray-core/transport/pipe" 25 ) 26 27 type worker interface { 28 Start() error 29 Close() error 30 Port() net.Port 31 Proxy() proxy.Inbound 32 } 33 34 type tcpWorker struct { 35 address net.Address 36 port net.Port 37 proxy proxy.Inbound 38 stream *internet.MemoryStreamConfig 39 recvOrigDest bool 40 tag string 41 dispatcher routing.Dispatcher 42 sniffingConfig *proxyman.SniffingConfig 43 uplinkCounter stats.Counter 44 downlinkCounter stats.Counter 45 46 hub internet.Listener 47 48 ctx context.Context 49 } 50 51 func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode { 52 if s == nil || s.SocketSettings == nil { 53 return internet.SocketConfig_Off 54 } 55 return s.SocketSettings.Tproxy 56 } 57 58 func (w *tcpWorker) callback(conn stat.Connection) { 59 ctx, cancel := context.WithCancel(w.ctx) 60 sid := session.NewID() 61 ctx = session.ContextWithID(ctx, sid) 62 63 outbounds := []*session.Outbound{{}} 64 if w.recvOrigDest { 65 var dest net.Destination 66 switch getTProxyType(w.stream) { 67 case internet.SocketConfig_Redirect: 68 d, err := tcp.GetOriginalDestination(conn) 69 if err != nil { 70 newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx)) 71 } else { 72 dest = d 73 } 74 case internet.SocketConfig_TProxy: 75 dest = net.DestinationFromAddr(conn.LocalAddr()) 76 } 77 if dest.IsValid() { 78 outbounds[0].Target = dest 79 } 80 } 81 ctx = session.ContextWithOutbounds(ctx, outbounds) 82 83 if w.uplinkCounter != nil || w.downlinkCounter != nil { 84 conn = &stat.CounterConnection{ 85 Connection: conn, 86 ReadCounter: w.uplinkCounter, 87 WriteCounter: w.downlinkCounter, 88 } 89 } 90 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 91 Source: net.DestinationFromAddr(conn.RemoteAddr()), 92 Gateway: net.TCPDestination(w.address, w.port), 93 Tag: w.tag, 94 Conn: conn, 95 }) 96 97 content := new(session.Content) 98 if w.sniffingConfig != nil { 99 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 100 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 101 content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded 102 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 103 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 104 } 105 ctx = session.ContextWithContent(ctx, content) 106 107 if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil { 108 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 109 } 110 cancel() 111 conn.Close() 112 } 113 114 func (w *tcpWorker) Proxy() proxy.Inbound { 115 return w.proxy 116 } 117 118 func (w *tcpWorker) Start() error { 119 ctx := context.Background() 120 hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { 121 go w.callback(conn) 122 }) 123 if err != nil { 124 return newError("failed to listen TCP on ", w.port).AtWarning().Base(err) 125 } 126 w.hub = hub 127 return nil 128 } 129 130 func (w *tcpWorker) Close() error { 131 var errors []interface{} 132 if w.hub != nil { 133 if err := common.Close(w.hub); err != nil { 134 errors = append(errors, err) 135 } 136 if err := common.Close(w.proxy); err != nil { 137 errors = append(errors, err) 138 } 139 } 140 if len(errors) > 0 { 141 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 142 } 143 144 return nil 145 } 146 147 func (w *tcpWorker) Port() net.Port { 148 return w.port 149 } 150 151 type udpConn struct { 152 lastActivityTime int64 // in seconds 153 reader buf.Reader 154 writer buf.Writer 155 output func([]byte) (int, error) 156 remote net.Addr 157 local net.Addr 158 done *done.Instance 159 uplink stats.Counter 160 downlink stats.Counter 161 inactive bool 162 } 163 164 func (c *udpConn) setInactive() { 165 c.inactive = true 166 } 167 168 func (c *udpConn) updateActivity() { 169 atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix()) 170 } 171 172 // ReadMultiBuffer implements buf.Reader 173 func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { 174 mb, err := c.reader.ReadMultiBuffer() 175 if err != nil { 176 return nil, err 177 } 178 c.updateActivity() 179 180 if c.uplink != nil { 181 c.uplink.Add(int64(mb.Len())) 182 } 183 184 return mb, nil 185 } 186 187 func (c *udpConn) Read(buf []byte) (int, error) { 188 panic("not implemented") 189 } 190 191 // Write implements io.Writer. 192 func (c *udpConn) Write(buf []byte) (int, error) { 193 n, err := c.output(buf) 194 if c.downlink != nil { 195 c.downlink.Add(int64(n)) 196 } 197 if err == nil { 198 c.updateActivity() 199 } 200 return n, err 201 } 202 203 func (c *udpConn) Close() error { 204 common.Must(c.done.Close()) 205 common.Must(common.Close(c.writer)) 206 return nil 207 } 208 209 func (c *udpConn) RemoteAddr() net.Addr { 210 return c.remote 211 } 212 213 func (c *udpConn) LocalAddr() net.Addr { 214 return c.local 215 } 216 217 func (*udpConn) SetDeadline(time.Time) error { 218 return nil 219 } 220 221 func (*udpConn) SetReadDeadline(time.Time) error { 222 return nil 223 } 224 225 func (*udpConn) SetWriteDeadline(time.Time) error { 226 return nil 227 } 228 229 type connID struct { 230 src net.Destination 231 dest net.Destination 232 } 233 234 type udpWorker struct { 235 sync.RWMutex 236 237 proxy proxy.Inbound 238 hub *udp.Hub 239 address net.Address 240 port net.Port 241 tag string 242 stream *internet.MemoryStreamConfig 243 dispatcher routing.Dispatcher 244 sniffingConfig *proxyman.SniffingConfig 245 uplinkCounter stats.Counter 246 downlinkCounter stats.Counter 247 248 checker *task.Periodic 249 activeConn map[connID]*udpConn 250 251 ctx context.Context 252 cone bool 253 } 254 255 func (w *udpWorker) getConnection(id connID) (*udpConn, bool) { 256 w.Lock() 257 defer w.Unlock() 258 259 if conn, found := w.activeConn[id]; found && !conn.done.Done() { 260 return conn, true 261 } 262 263 pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) 264 conn := &udpConn{ 265 reader: pReader, 266 writer: pWriter, 267 output: func(b []byte) (int, error) { 268 return w.hub.WriteTo(b, id.src) 269 }, 270 remote: &net.UDPAddr{ 271 IP: id.src.Address.IP(), 272 Port: int(id.src.Port), 273 }, 274 local: &net.UDPAddr{ 275 IP: w.address.IP(), 276 Port: int(w.port), 277 }, 278 done: done.New(), 279 uplink: w.uplinkCounter, 280 downlink: w.downlinkCounter, 281 } 282 w.activeConn[id] = conn 283 284 conn.updateActivity() 285 return conn, false 286 } 287 288 func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) { 289 id := connID{ 290 src: source, 291 } 292 if originalDest.IsValid() { 293 if !w.cone { 294 id.dest = originalDest 295 } 296 b.UDP = &originalDest 297 } 298 conn, existing := w.getConnection(id) 299 300 // payload will be discarded in pipe is full. 301 conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) 302 303 if !existing { 304 common.Must(w.checker.Start()) 305 306 go func() { 307 ctx := w.ctx 308 sid := session.NewID() 309 ctx = session.ContextWithID(ctx, sid) 310 311 if originalDest.IsValid() { 312 outbounds := []*session.Outbound{{ 313 Target: originalDest, 314 }} 315 ctx = session.ContextWithOutbounds(ctx, outbounds) 316 } 317 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 318 Source: source, 319 Gateway: net.UDPDestination(w.address, w.port), 320 Tag: w.tag, 321 }) 322 content := new(session.Content) 323 if w.sniffingConfig != nil { 324 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 325 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 326 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 327 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 328 } 329 ctx = session.ContextWithContent(ctx, content) 330 if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { 331 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 332 } 333 conn.Close() 334 // conn not removed by checker TODO may be lock worker here is better 335 if !conn.inactive { 336 conn.setInactive() 337 w.removeConn(id) 338 } 339 }() 340 } 341 } 342 343 func (w *udpWorker) removeConn(id connID) { 344 w.Lock() 345 delete(w.activeConn, id) 346 w.Unlock() 347 } 348 349 func (w *udpWorker) handlePackets() { 350 receive := w.hub.Receive() 351 for payload := range receive { 352 w.callback(payload.Payload, payload.Source, payload.Target) 353 } 354 } 355 356 func (w *udpWorker) clean() error { 357 nowSec := time.Now().Unix() 358 w.Lock() 359 defer w.Unlock() 360 361 if len(w.activeConn) == 0 { 362 return newError("no more connections. stopping...") 363 } 364 365 for addr, conn := range w.activeConn { 366 if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 2*60 { 367 if !conn.inactive { 368 conn.setInactive() 369 delete(w.activeConn, addr) 370 } 371 conn.Close() 372 } 373 } 374 375 if len(w.activeConn) == 0 { 376 w.activeConn = make(map[connID]*udpConn, 16) 377 } 378 379 return nil 380 } 381 382 func (w *udpWorker) Start() error { 383 w.activeConn = make(map[connID]*udpConn, 16) 384 ctx := context.Background() 385 h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256)) 386 if err != nil { 387 return err 388 } 389 390 w.cone = w.ctx.Value("cone").(bool) 391 392 w.checker = &task.Periodic{ 393 Interval: time.Minute, 394 Execute: w.clean, 395 } 396 397 w.hub = h 398 go w.handlePackets() 399 return nil 400 } 401 402 func (w *udpWorker) Close() error { 403 w.Lock() 404 defer w.Unlock() 405 406 var errors []interface{} 407 408 if w.hub != nil { 409 if err := w.hub.Close(); err != nil { 410 errors = append(errors, err) 411 } 412 } 413 414 if w.checker != nil { 415 if err := w.checker.Close(); err != nil { 416 errors = append(errors, err) 417 } 418 } 419 420 if err := common.Close(w.proxy); err != nil { 421 errors = append(errors, err) 422 } 423 424 if len(errors) > 0 { 425 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 426 } 427 return nil 428 } 429 430 func (w *udpWorker) Port() net.Port { 431 return w.port 432 } 433 434 func (w *udpWorker) Proxy() proxy.Inbound { 435 return w.proxy 436 } 437 438 type dsWorker struct { 439 address net.Address 440 proxy proxy.Inbound 441 stream *internet.MemoryStreamConfig 442 tag string 443 dispatcher routing.Dispatcher 444 sniffingConfig *proxyman.SniffingConfig 445 uplinkCounter stats.Counter 446 downlinkCounter stats.Counter 447 448 hub internet.Listener 449 450 ctx context.Context 451 } 452 453 func (w *dsWorker) callback(conn stat.Connection) { 454 ctx, cancel := context.WithCancel(w.ctx) 455 sid := session.NewID() 456 ctx = session.ContextWithID(ctx, sid) 457 458 if w.uplinkCounter != nil || w.downlinkCounter != nil { 459 conn = &stat.CounterConnection{ 460 Connection: conn, 461 ReadCounter: w.uplinkCounter, 462 WriteCounter: w.downlinkCounter, 463 } 464 } 465 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 466 Source: net.DestinationFromAddr(conn.RemoteAddr()), 467 Gateway: net.UnixDestination(w.address), 468 Tag: w.tag, 469 Conn: conn, 470 }) 471 472 content := new(session.Content) 473 if w.sniffingConfig != nil { 474 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 475 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 476 content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded 477 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 478 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 479 } 480 ctx = session.ContextWithContent(ctx, content) 481 482 if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil { 483 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 484 } 485 cancel() 486 if err := conn.Close(); err != nil { 487 newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx)) 488 } 489 } 490 491 func (w *dsWorker) Proxy() proxy.Inbound { 492 return w.proxy 493 } 494 495 func (w *dsWorker) Port() net.Port { 496 return net.Port(0) 497 } 498 499 func (w *dsWorker) Start() error { 500 ctx := context.Background() 501 hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) { 502 go w.callback(conn) 503 }) 504 if err != nil { 505 return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err) 506 } 507 w.hub = hub 508 return nil 509 } 510 511 func (w *dsWorker) Close() error { 512 var errors []interface{} 513 if w.hub != nil { 514 if err := common.Close(w.hub); err != nil { 515 errors = append(errors, err) 516 } 517 if err := common.Close(w.proxy); err != nil { 518 errors = append(errors, err) 519 } 520 } 521 if len(errors) > 0 { 522 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 523 } 524 525 return nil 526 }