github.com/xmplusdev/xray-core@v1.8.10/app/proxyman/inbound/worker.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "sync" 6 "sync/atomic" 7 "time" 8 9 "github.com/xmplusdev/xray-core/app/proxyman" 10 "github.com/xmplusdev/xray-core/common" 11 "github.com/xmplusdev/xray-core/common/buf" 12 "github.com/xmplusdev/xray-core/common/net" 13 "github.com/xmplusdev/xray-core/common/serial" 14 "github.com/xmplusdev/xray-core/common/session" 15 "github.com/xmplusdev/xray-core/common/signal/done" 16 "github.com/xmplusdev/xray-core/common/task" 17 "github.com/xmplusdev/xray-core/features/routing" 18 "github.com/xmplusdev/xray-core/features/stats" 19 "github.com/xmplusdev/xray-core/proxy" 20 "github.com/xmplusdev/xray-core/transport/internet" 21 "github.com/xmplusdev/xray-core/transport/internet/stat" 22 "github.com/xmplusdev/xray-core/transport/internet/tcp" 23 "github.com/xmplusdev/xray-core/transport/internet/udp" 24 "github.com/xmplusdev/xray-core/transport/pipe" 25 ) 26 27 type worker interface { 28 Start() error 29 Close() error 30 Port() net.Port 31 Proxy() proxy.Inbound 32 } 33 34 type tcpWorker struct { 35 address net.Address 36 port net.Port 37 proxy proxy.Inbound 38 stream *internet.MemoryStreamConfig 39 recvOrigDest bool 40 tag string 41 dispatcher routing.Dispatcher 42 sniffingConfig *proxyman.SniffingConfig 43 uplinkCounter stats.Counter 44 downlinkCounter stats.Counter 45 46 hub internet.Listener 47 48 ctx context.Context 49 } 50 51 func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode { 52 if s == nil || s.SocketSettings == nil { 53 return internet.SocketConfig_Off 54 } 55 return s.SocketSettings.Tproxy 56 } 57 58 func (w *tcpWorker) callback(conn stat.Connection) { 59 ctx, cancel := context.WithCancel(w.ctx) 60 sid := session.NewID() 61 ctx = session.ContextWithID(ctx, sid) 62 63 var outbound = &session.Outbound{} 64 if w.recvOrigDest { 65 var dest net.Destination 66 switch getTProxyType(w.stream) { 67 case internet.SocketConfig_Redirect: 68 d, err := tcp.GetOriginalDestination(conn) 69 if err != nil { 70 newError("failed to get original destination").Base(err).WriteToLog(session.ExportIDToError(ctx)) 71 } else { 72 dest = d 73 } 74 case internet.SocketConfig_TProxy: 75 dest = net.DestinationFromAddr(conn.LocalAddr()) 76 } 77 if dest.IsValid() { 78 outbound.Target = dest 79 } 80 } 81 ctx = session.ContextWithOutbound(ctx, outbound) 82 83 if w.uplinkCounter != nil || w.downlinkCounter != nil { 84 conn = &stat.CounterConnection{ 85 Connection: conn, 86 ReadCounter: w.uplinkCounter, 87 WriteCounter: w.downlinkCounter, 88 } 89 } 90 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 91 Source: net.DestinationFromAddr(conn.RemoteAddr()), 92 Gateway: net.TCPDestination(w.address, w.port), 93 Tag: w.tag, 94 Conn: conn, 95 }) 96 97 content := new(session.Content) 98 if w.sniffingConfig != nil { 99 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 100 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 101 content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded 102 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 103 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 104 } 105 ctx = session.ContextWithContent(ctx, content) 106 107 if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil { 108 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 109 } 110 cancel() 111 conn.Close() 112 } 113 114 func (w *tcpWorker) Proxy() proxy.Inbound { 115 return w.proxy 116 } 117 118 func (w *tcpWorker) Start() error { 119 ctx := context.Background() 120 hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) { 121 go w.callback(conn) 122 }) 123 if err != nil { 124 return newError("failed to listen TCP on ", w.port).AtWarning().Base(err) 125 } 126 w.hub = hub 127 return nil 128 } 129 130 func (w *tcpWorker) Close() error { 131 var errors []interface{} 132 if w.hub != nil { 133 if err := common.Close(w.hub); err != nil { 134 errors = append(errors, err) 135 } 136 if err := common.Close(w.proxy); err != nil { 137 errors = append(errors, err) 138 } 139 } 140 if len(errors) > 0 { 141 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 142 } 143 144 return nil 145 } 146 147 func (w *tcpWorker) Port() net.Port { 148 return w.port 149 } 150 151 type udpConn struct { 152 lastActivityTime int64 // in seconds 153 reader buf.Reader 154 writer buf.Writer 155 output func([]byte) (int, error) 156 remote net.Addr 157 local net.Addr 158 done *done.Instance 159 uplink stats.Counter 160 downlink stats.Counter 161 inactive bool 162 } 163 164 func (c *udpConn) setInactive() { 165 c.inactive = true 166 } 167 168 func (c *udpConn) updateActivity() { 169 atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix()) 170 } 171 172 // ReadMultiBuffer implements buf.Reader 173 func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) { 174 mb, err := c.reader.ReadMultiBuffer() 175 if err != nil { 176 return nil, err 177 } 178 c.updateActivity() 179 180 if c.uplink != nil { 181 c.uplink.Add(int64(mb.Len())) 182 } 183 184 return mb, nil 185 } 186 187 func (c *udpConn) Read(buf []byte) (int, error) { 188 panic("not implemented") 189 } 190 191 // Write implements io.Writer. 192 func (c *udpConn) Write(buf []byte) (int, error) { 193 n, err := c.output(buf) 194 if c.downlink != nil { 195 c.downlink.Add(int64(n)) 196 } 197 if err == nil { 198 c.updateActivity() 199 } 200 return n, err 201 } 202 203 func (c *udpConn) Close() error { 204 common.Must(c.done.Close()) 205 common.Must(common.Close(c.writer)) 206 return nil 207 } 208 209 func (c *udpConn) RemoteAddr() net.Addr { 210 return c.remote 211 } 212 213 func (c *udpConn) LocalAddr() net.Addr { 214 return c.local 215 } 216 217 func (*udpConn) SetDeadline(time.Time) error { 218 return nil 219 } 220 221 func (*udpConn) SetReadDeadline(time.Time) error { 222 return nil 223 } 224 225 func (*udpConn) SetWriteDeadline(time.Time) error { 226 return nil 227 } 228 229 type connID struct { 230 src net.Destination 231 dest net.Destination 232 } 233 234 type udpWorker struct { 235 sync.RWMutex 236 237 proxy proxy.Inbound 238 hub *udp.Hub 239 address net.Address 240 port net.Port 241 tag string 242 stream *internet.MemoryStreamConfig 243 dispatcher routing.Dispatcher 244 sniffingConfig *proxyman.SniffingConfig 245 uplinkCounter stats.Counter 246 downlinkCounter stats.Counter 247 248 checker *task.Periodic 249 activeConn map[connID]*udpConn 250 251 ctx context.Context 252 cone bool 253 } 254 255 func (w *udpWorker) getConnection(id connID) (*udpConn, bool) { 256 w.Lock() 257 defer w.Unlock() 258 259 if conn, found := w.activeConn[id]; found && !conn.done.Done() { 260 return conn, true 261 } 262 263 pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) 264 conn := &udpConn{ 265 reader: pReader, 266 writer: pWriter, 267 output: func(b []byte) (int, error) { 268 return w.hub.WriteTo(b, id.src) 269 }, 270 remote: &net.UDPAddr{ 271 IP: id.src.Address.IP(), 272 Port: int(id.src.Port), 273 }, 274 local: &net.UDPAddr{ 275 IP: w.address.IP(), 276 Port: int(w.port), 277 }, 278 done: done.New(), 279 uplink: w.uplinkCounter, 280 downlink: w.downlinkCounter, 281 } 282 w.activeConn[id] = conn 283 284 conn.updateActivity() 285 return conn, false 286 } 287 288 func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) { 289 id := connID{ 290 src: source, 291 } 292 if originalDest.IsValid() { 293 if !w.cone { 294 id.dest = originalDest 295 } 296 b.UDP = &originalDest 297 } 298 conn, existing := w.getConnection(id) 299 300 // payload will be discarded in pipe is full. 301 conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) 302 303 if !existing { 304 common.Must(w.checker.Start()) 305 306 go func() { 307 ctx := w.ctx 308 sid := session.NewID() 309 ctx = session.ContextWithID(ctx, sid) 310 311 if originalDest.IsValid() { 312 ctx = session.ContextWithOutbound(ctx, &session.Outbound{ 313 Target: originalDest, 314 }) 315 } 316 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 317 Source: source, 318 Gateway: net.UDPDestination(w.address, w.port), 319 Tag: w.tag, 320 }) 321 content := new(session.Content) 322 if w.sniffingConfig != nil { 323 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 324 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 325 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 326 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 327 } 328 ctx = session.ContextWithContent(ctx, content) 329 if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { 330 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 331 } 332 conn.Close() 333 // conn not removed by checker TODO may be lock worker here is better 334 if !conn.inactive { 335 conn.setInactive() 336 w.removeConn(id) 337 } 338 }() 339 } 340 } 341 342 func (w *udpWorker) removeConn(id connID) { 343 w.Lock() 344 delete(w.activeConn, id) 345 w.Unlock() 346 } 347 348 func (w *udpWorker) handlePackets() { 349 receive := w.hub.Receive() 350 for payload := range receive { 351 w.callback(payload.Payload, payload.Source, payload.Target) 352 } 353 } 354 355 func (w *udpWorker) clean() error { 356 nowSec := time.Now().Unix() 357 w.Lock() 358 defer w.Unlock() 359 360 if len(w.activeConn) == 0 { 361 return newError("no more connections. stopping...") 362 } 363 364 for addr, conn := range w.activeConn { 365 if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 2*60 { 366 if !conn.inactive { 367 conn.setInactive() 368 delete(w.activeConn, addr) 369 } 370 conn.Close() 371 } 372 } 373 374 if len(w.activeConn) == 0 { 375 w.activeConn = make(map[connID]*udpConn, 16) 376 } 377 378 return nil 379 } 380 381 func (w *udpWorker) Start() error { 382 w.activeConn = make(map[connID]*udpConn, 16) 383 ctx := context.Background() 384 h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256)) 385 if err != nil { 386 return err 387 } 388 389 w.cone = w.ctx.Value("cone").(bool) 390 391 w.checker = &task.Periodic{ 392 Interval: time.Minute, 393 Execute: w.clean, 394 } 395 396 w.hub = h 397 go w.handlePackets() 398 return nil 399 } 400 401 func (w *udpWorker) Close() error { 402 w.Lock() 403 defer w.Unlock() 404 405 var errors []interface{} 406 407 if w.hub != nil { 408 if err := w.hub.Close(); err != nil { 409 errors = append(errors, err) 410 } 411 } 412 413 if w.checker != nil { 414 if err := w.checker.Close(); err != nil { 415 errors = append(errors, err) 416 } 417 } 418 419 if err := common.Close(w.proxy); err != nil { 420 errors = append(errors, err) 421 } 422 423 if len(errors) > 0 { 424 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 425 } 426 return nil 427 } 428 429 func (w *udpWorker) Port() net.Port { 430 return w.port 431 } 432 433 func (w *udpWorker) Proxy() proxy.Inbound { 434 return w.proxy 435 } 436 437 type dsWorker struct { 438 address net.Address 439 proxy proxy.Inbound 440 stream *internet.MemoryStreamConfig 441 tag string 442 dispatcher routing.Dispatcher 443 sniffingConfig *proxyman.SniffingConfig 444 uplinkCounter stats.Counter 445 downlinkCounter stats.Counter 446 447 hub internet.Listener 448 449 ctx context.Context 450 } 451 452 func (w *dsWorker) callback(conn stat.Connection) { 453 ctx, cancel := context.WithCancel(w.ctx) 454 sid := session.NewID() 455 ctx = session.ContextWithID(ctx, sid) 456 457 if w.uplinkCounter != nil || w.downlinkCounter != nil { 458 conn = &stat.CounterConnection{ 459 Connection: conn, 460 ReadCounter: w.uplinkCounter, 461 WriteCounter: w.downlinkCounter, 462 } 463 } 464 ctx = session.ContextWithInbound(ctx, &session.Inbound{ 465 Source: net.DestinationFromAddr(conn.RemoteAddr()), 466 Gateway: net.UnixDestination(w.address), 467 Tag: w.tag, 468 Conn: conn, 469 }) 470 471 content := new(session.Content) 472 if w.sniffingConfig != nil { 473 content.SniffingRequest.Enabled = w.sniffingConfig.Enabled 474 content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride 475 content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded 476 content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly 477 content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly 478 } 479 ctx = session.ContextWithContent(ctx, content) 480 481 if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil { 482 newError("connection ends").Base(err).WriteToLog(session.ExportIDToError(ctx)) 483 } 484 cancel() 485 if err := conn.Close(); err != nil { 486 newError("failed to close connection").Base(err).WriteToLog(session.ExportIDToError(ctx)) 487 } 488 } 489 490 func (w *dsWorker) Proxy() proxy.Inbound { 491 return w.proxy 492 } 493 494 func (w *dsWorker) Port() net.Port { 495 return net.Port(0) 496 } 497 498 func (w *dsWorker) Start() error { 499 ctx := context.Background() 500 hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) { 501 go w.callback(conn) 502 }) 503 if err != nil { 504 return newError("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err) 505 } 506 w.hub = hub 507 return nil 508 } 509 510 func (w *dsWorker) Close() error { 511 var errors []interface{} 512 if w.hub != nil { 513 if err := common.Close(w.hub); err != nil { 514 errors = append(errors, err) 515 } 516 if err := common.Close(w.proxy); err != nil { 517 errors = append(errors, err) 518 } 519 } 520 if len(errors) > 0 { 521 return newError("failed to close all resources").Base(newError(serial.Concat(errors...))) 522 } 523 524 return nil 525 }