github.com/sagernet/sing-box@v1.2.7/common/mux/client.go (about) 1 package mux 2 3 import ( 4 "context" 5 "encoding/binary" 6 "io" 7 "net" 8 "sync" 9 10 "github.com/sagernet/sing-box/option" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/bufio" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/x/list" 18 ) 19 20 var _ N.Dialer = (*Client)(nil) 21 22 type Client struct { 23 access sync.Mutex 24 connections list.List[abstractSession] 25 ctx context.Context 26 dialer N.Dialer 27 protocol Protocol 28 maxConnections int 29 minStreams int 30 maxStreams int 31 } 32 33 func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client { 34 return &Client{ 35 ctx: ctx, 36 dialer: dialer, 37 protocol: protocol, 38 maxConnections: maxConnections, 39 minStreams: minStreams, 40 maxStreams: maxStreams, 41 } 42 } 43 44 func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) { 45 if !options.Enabled { 46 return nil, nil 47 } 48 if options.MaxConnections == 0 && options.MaxStreams == 0 { 49 options.MinStreams = 8 50 } 51 protocol, err := ParseProtocol(options.Protocol) 52 if err != nil { 53 return nil, err 54 } 55 return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil 56 } 57 58 func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 59 switch N.NetworkName(network) { 60 case N.NetworkTCP: 61 stream, err := c.openStream() 62 if err != nil { 63 return nil, err 64 } 65 return &ClientConn{Conn: stream, destination: destination}, nil 66 case N.NetworkUDP: 67 stream, err := c.openStream() 68 if err != nil { 69 return nil, err 70 } 71 return bufio.NewUnbindPacketConn(&ClientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil 72 default: 73 return nil, E.Extend(N.ErrUnknownNetwork, network) 74 } 75 } 76 77 func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 78 stream, err := c.openStream() 79 if err != nil { 80 return nil, err 81 } 82 return &ClientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil 83 } 84 85 func (c *Client) openStream() (net.Conn, error) { 86 var ( 87 session abstractSession 88 stream net.Conn 89 err error 90 ) 91 for attempts := 0; attempts < 2; attempts++ { 92 session, err = c.offer() 93 if err != nil { 94 continue 95 } 96 stream, err = session.Open() 97 if err != nil { 98 continue 99 } 100 break 101 } 102 if err != nil { 103 return nil, err 104 } 105 return &wrapStream{stream}, nil 106 } 107 108 func (c *Client) offer() (abstractSession, error) { 109 c.access.Lock() 110 defer c.access.Unlock() 111 112 sessions := make([]abstractSession, 0, c.maxConnections) 113 for element := c.connections.Front(); element != nil; { 114 if element.Value.IsClosed() { 115 nextElement := element.Next() 116 c.connections.Remove(element) 117 element = nextElement 118 continue 119 } 120 sessions = append(sessions, element.Value) 121 element = element.Next() 122 } 123 sLen := len(sessions) 124 if sLen == 0 { 125 return c.offerNew() 126 } 127 session := common.MinBy(sessions, abstractSession.NumStreams) 128 numStreams := session.NumStreams() 129 if numStreams == 0 { 130 return session, nil 131 } 132 if c.maxConnections > 0 { 133 if sLen >= c.maxConnections || numStreams < c.minStreams { 134 return session, nil 135 } 136 } else { 137 if c.maxStreams > 0 && numStreams < c.maxStreams { 138 return session, nil 139 } 140 } 141 return c.offerNew() 142 } 143 144 func (c *Client) offerNew() (abstractSession, error) { 145 conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination) 146 if err != nil { 147 return nil, err 148 } 149 if vectorisedWriter, isVectorised := bufio.CreateVectorisedWriter(conn); isVectorised { 150 conn = &vectorisedProtocolConn{protocolConn{Conn: conn, protocol: c.protocol}, vectorisedWriter} 151 } else { 152 conn = &protocolConn{Conn: conn, protocol: c.protocol} 153 } 154 session, err := c.protocol.newClient(conn) 155 if err != nil { 156 return nil, err 157 } 158 c.connections.PushBack(session) 159 return session, nil 160 } 161 162 func (c *Client) Close() error { 163 c.access.Lock() 164 defer c.access.Unlock() 165 for _, session := range c.connections.Array() { 166 session.Close() 167 } 168 return nil 169 } 170 171 type ClientConn struct { 172 net.Conn 173 destination M.Socksaddr 174 requestWrite bool 175 responseRead bool 176 } 177 178 func (c *ClientConn) readResponse() error { 179 response, err := ReadStreamResponse(c.Conn) 180 if err != nil { 181 return err 182 } 183 if response.Status == statusError { 184 return E.New("remote error: ", response.Message) 185 } 186 return nil 187 } 188 189 func (c *ClientConn) Read(b []byte) (n int, err error) { 190 if !c.responseRead { 191 err = c.readResponse() 192 if err != nil { 193 return 194 } 195 c.responseRead = true 196 } 197 return c.Conn.Read(b) 198 } 199 200 func (c *ClientConn) Write(b []byte) (n int, err error) { 201 if c.requestWrite { 202 return c.Conn.Write(b) 203 } 204 request := StreamRequest{ 205 Network: N.NetworkTCP, 206 Destination: c.destination, 207 } 208 _buffer := buf.StackNewSize(requestLen(request) + len(b)) 209 defer common.KeepAlive(_buffer) 210 buffer := common.Dup(_buffer) 211 defer buffer.Release() 212 EncodeStreamRequest(request, buffer) 213 buffer.Write(b) 214 _, err = c.Conn.Write(buffer.Bytes()) 215 if err != nil { 216 return 217 } 218 c.requestWrite = true 219 return len(b), nil 220 } 221 222 func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { 223 if !c.requestWrite { 224 return bufio.ReadFrom0(c, r) 225 } 226 return bufio.Copy(c.Conn, r) 227 } 228 229 func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { 230 if !c.responseRead { 231 return bufio.WriteTo0(c, w) 232 } 233 return bufio.Copy(w, c.Conn) 234 } 235 236 func (c *ClientConn) LocalAddr() net.Addr { 237 return c.Conn.LocalAddr() 238 } 239 240 func (c *ClientConn) RemoteAddr() net.Addr { 241 return c.destination.TCPAddr() 242 } 243 244 func (c *ClientConn) ReaderReplaceable() bool { 245 return c.responseRead 246 } 247 248 func (c *ClientConn) WriterReplaceable() bool { 249 return c.requestWrite 250 } 251 252 func (c *ClientConn) NeedAdditionalReadDeadline() bool { 253 return true 254 } 255 256 func (c *ClientConn) Upstream() any { 257 return c.Conn 258 } 259 260 type ClientPacketConn struct { 261 N.ExtendedConn 262 destination M.Socksaddr 263 requestWrite bool 264 responseRead bool 265 } 266 267 func (c *ClientPacketConn) readResponse() error { 268 response, err := ReadStreamResponse(c.ExtendedConn) 269 if err != nil { 270 return err 271 } 272 if response.Status == statusError { 273 return E.New("remote error: ", response.Message) 274 } 275 return nil 276 } 277 278 func (c *ClientPacketConn) Read(b []byte) (n int, err error) { 279 if !c.responseRead { 280 err = c.readResponse() 281 if err != nil { 282 return 283 } 284 c.responseRead = true 285 } 286 var length uint16 287 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 288 if err != nil { 289 return 290 } 291 if cap(b) < int(length) { 292 return 0, io.ErrShortBuffer 293 } 294 return io.ReadFull(c.ExtendedConn, b[:length]) 295 } 296 297 func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) { 298 request := StreamRequest{ 299 Network: N.NetworkUDP, 300 Destination: c.destination, 301 } 302 rLen := requestLen(request) 303 if len(payload) > 0 { 304 rLen += 2 + len(payload) 305 } 306 _buffer := buf.StackNewSize(rLen) 307 defer common.KeepAlive(_buffer) 308 buffer := common.Dup(_buffer) 309 defer buffer.Release() 310 EncodeStreamRequest(request, buffer) 311 if len(payload) > 0 { 312 common.Must( 313 binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 314 common.Error(buffer.Write(payload)), 315 ) 316 } 317 _, err = c.ExtendedConn.Write(buffer.Bytes()) 318 if err != nil { 319 return 320 } 321 c.requestWrite = true 322 return len(payload), nil 323 } 324 325 func (c *ClientPacketConn) Write(b []byte) (n int, err error) { 326 if !c.requestWrite { 327 return c.writeRequest(b) 328 } 329 err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b))) 330 if err != nil { 331 return 332 } 333 return c.ExtendedConn.Write(b) 334 } 335 336 func (c *ClientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { 337 if !c.responseRead { 338 err = c.readResponse() 339 if err != nil { 340 return 341 } 342 c.responseRead = true 343 } 344 var length uint16 345 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 346 if err != nil { 347 return 348 } 349 _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 350 return 351 } 352 353 func (c *ClientPacketConn) WriteBuffer(buffer *buf.Buffer) error { 354 if !c.requestWrite { 355 defer buffer.Release() 356 return common.Error(c.writeRequest(buffer.Bytes())) 357 } 358 bLen := buffer.Len() 359 binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) 360 return c.ExtendedConn.WriteBuffer(buffer) 361 } 362 363 func (c *ClientPacketConn) FrontHeadroom() int { 364 return 2 365 } 366 367 func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 368 err = c.ReadBuffer(buffer) 369 return 370 } 371 372 func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 373 return c.WriteBuffer(buffer) 374 } 375 376 func (c *ClientPacketConn) LocalAddr() net.Addr { 377 return c.ExtendedConn.LocalAddr() 378 } 379 380 func (c *ClientPacketConn) RemoteAddr() net.Addr { 381 return c.destination.UDPAddr() 382 } 383 384 func (c *ClientPacketConn) NeedAdditionalReadDeadline() bool { 385 return true 386 } 387 388 func (c *ClientPacketConn) Upstream() any { 389 return c.ExtendedConn 390 } 391 392 var _ N.NetPacketConn = (*ClientPacketAddrConn)(nil) 393 394 type ClientPacketAddrConn struct { 395 N.ExtendedConn 396 destination M.Socksaddr 397 requestWrite bool 398 responseRead bool 399 } 400 401 func (c *ClientPacketAddrConn) readResponse() error { 402 response, err := ReadStreamResponse(c.ExtendedConn) 403 if err != nil { 404 return err 405 } 406 if response.Status == statusError { 407 return E.New("remote error: ", response.Message) 408 } 409 return nil 410 } 411 412 func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 413 if !c.responseRead { 414 err = c.readResponse() 415 if err != nil { 416 return 417 } 418 c.responseRead = true 419 } 420 destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) 421 if err != nil { 422 return 423 } 424 if destination.IsFqdn() { 425 addr = destination 426 } else { 427 addr = destination.UDPAddr() 428 } 429 var length uint16 430 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 431 if err != nil { 432 return 433 } 434 if cap(p) < int(length) { 435 return 0, nil, io.ErrShortBuffer 436 } 437 n, err = io.ReadFull(c.ExtendedConn, p[:length]) 438 return 439 } 440 441 func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { 442 request := StreamRequest{ 443 Network: N.NetworkUDP, 444 Destination: c.destination, 445 PacketAddr: true, 446 } 447 rLen := requestLen(request) 448 if len(payload) > 0 { 449 rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) 450 } 451 _buffer := buf.StackNewSize(rLen) 452 defer common.KeepAlive(_buffer) 453 buffer := common.Dup(_buffer) 454 defer buffer.Release() 455 EncodeStreamRequest(request, buffer) 456 if len(payload) > 0 { 457 common.Must( 458 M.SocksaddrSerializer.WriteAddrPort(buffer, destination), 459 binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 460 common.Error(buffer.Write(payload)), 461 ) 462 } 463 _, err = c.ExtendedConn.Write(buffer.Bytes()) 464 if err != nil { 465 return 466 } 467 c.requestWrite = true 468 return len(payload), nil 469 } 470 471 func (c *ClientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 472 if !c.requestWrite { 473 return c.writeRequest(p, M.SocksaddrFromNet(addr)) 474 } 475 err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr)) 476 if err != nil { 477 return 478 } 479 err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p))) 480 if err != nil { 481 return 482 } 483 return c.ExtendedConn.Write(p) 484 } 485 486 func (c *ClientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 487 if !c.responseRead { 488 err = c.readResponse() 489 if err != nil { 490 return 491 } 492 c.responseRead = true 493 } 494 destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) 495 if err != nil { 496 return 497 } 498 var length uint16 499 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 500 if err != nil { 501 return 502 } 503 _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 504 return 505 } 506 507 func (c *ClientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 508 if !c.requestWrite { 509 defer buffer.Release() 510 return common.Error(c.writeRequest(buffer.Bytes(), destination)) 511 } 512 bLen := buffer.Len() 513 header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) 514 common.Must( 515 M.SocksaddrSerializer.WriteAddrPort(header, destination), 516 binary.Write(header, binary.BigEndian, uint16(bLen)), 517 ) 518 return c.ExtendedConn.WriteBuffer(buffer) 519 } 520 521 func (c *ClientPacketAddrConn) LocalAddr() net.Addr { 522 return c.ExtendedConn.LocalAddr() 523 } 524 525 func (c *ClientPacketAddrConn) FrontHeadroom() int { 526 return 2 + M.MaxSocksaddrLength 527 } 528 529 func (c *ClientPacketAddrConn) NeedAdditionalReadDeadline() bool { 530 return true 531 } 532 533 func (c *ClientPacketAddrConn) Upstream() any { 534 return c.ExtendedConn 535 }