github.com/sagernet/sing-box@v1.2.7/common/mux/client.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/sagernet/sing-box/option"
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/bufio"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/x/list"
    18  )
    19  
    20  var _ N.Dialer = (*Client)(nil)
    21  
    22  type Client struct {
    23  	access         sync.Mutex
    24  	connections    list.List[abstractSession]
    25  	ctx            context.Context
    26  	dialer         N.Dialer
    27  	protocol       Protocol
    28  	maxConnections int
    29  	minStreams     int
    30  	maxStreams     int
    31  }
    32  
    33  func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client {
    34  	return &Client{
    35  		ctx:            ctx,
    36  		dialer:         dialer,
    37  		protocol:       protocol,
    38  		maxConnections: maxConnections,
    39  		minStreams:     minStreams,
    40  		maxStreams:     maxStreams,
    41  	}
    42  }
    43  
    44  func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) {
    45  	if !options.Enabled {
    46  		return nil, nil
    47  	}
    48  	if options.MaxConnections == 0 && options.MaxStreams == 0 {
    49  		options.MinStreams = 8
    50  	}
    51  	protocol, err := ParseProtocol(options.Protocol)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil
    56  }
    57  
    58  func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    59  	switch N.NetworkName(network) {
    60  	case N.NetworkTCP:
    61  		stream, err := c.openStream()
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		return &ClientConn{Conn: stream, destination: destination}, nil
    66  	case N.NetworkUDP:
    67  		stream, err := c.openStream()
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		return bufio.NewUnbindPacketConn(&ClientPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}), nil
    72  	default:
    73  		return nil, E.Extend(N.ErrUnknownNetwork, network)
    74  	}
    75  }
    76  
    77  func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    78  	stream, err := c.openStream()
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	return &ClientPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: destination}, nil
    83  }
    84  
    85  func (c *Client) openStream() (net.Conn, error) {
    86  	var (
    87  		session abstractSession
    88  		stream  net.Conn
    89  		err     error
    90  	)
    91  	for attempts := 0; attempts < 2; attempts++ {
    92  		session, err = c.offer()
    93  		if err != nil {
    94  			continue
    95  		}
    96  		stream, err = session.Open()
    97  		if err != nil {
    98  			continue
    99  		}
   100  		break
   101  	}
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	return &wrapStream{stream}, nil
   106  }
   107  
   108  func (c *Client) offer() (abstractSession, error) {
   109  	c.access.Lock()
   110  	defer c.access.Unlock()
   111  
   112  	sessions := make([]abstractSession, 0, c.maxConnections)
   113  	for element := c.connections.Front(); element != nil; {
   114  		if element.Value.IsClosed() {
   115  			nextElement := element.Next()
   116  			c.connections.Remove(element)
   117  			element = nextElement
   118  			continue
   119  		}
   120  		sessions = append(sessions, element.Value)
   121  		element = element.Next()
   122  	}
   123  	sLen := len(sessions)
   124  	if sLen == 0 {
   125  		return c.offerNew()
   126  	}
   127  	session := common.MinBy(sessions, abstractSession.NumStreams)
   128  	numStreams := session.NumStreams()
   129  	if numStreams == 0 {
   130  		return session, nil
   131  	}
   132  	if c.maxConnections > 0 {
   133  		if sLen >= c.maxConnections || numStreams < c.minStreams {
   134  			return session, nil
   135  		}
   136  	} else {
   137  		if c.maxStreams > 0 && numStreams < c.maxStreams {
   138  			return session, nil
   139  		}
   140  	}
   141  	return c.offerNew()
   142  }
   143  
   144  func (c *Client) offerNew() (abstractSession, error) {
   145  	conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	if vectorisedWriter, isVectorised := bufio.CreateVectorisedWriter(conn); isVectorised {
   150  		conn = &vectorisedProtocolConn{protocolConn{Conn: conn, protocol: c.protocol}, vectorisedWriter}
   151  	} else {
   152  		conn = &protocolConn{Conn: conn, protocol: c.protocol}
   153  	}
   154  	session, err := c.protocol.newClient(conn)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	c.connections.PushBack(session)
   159  	return session, nil
   160  }
   161  
   162  func (c *Client) Close() error {
   163  	c.access.Lock()
   164  	defer c.access.Unlock()
   165  	for _, session := range c.connections.Array() {
   166  		session.Close()
   167  	}
   168  	return nil
   169  }
   170  
   171  type ClientConn struct {
   172  	net.Conn
   173  	destination  M.Socksaddr
   174  	requestWrite bool
   175  	responseRead bool
   176  }
   177  
   178  func (c *ClientConn) readResponse() error {
   179  	response, err := ReadStreamResponse(c.Conn)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	if response.Status == statusError {
   184  		return E.New("remote error: ", response.Message)
   185  	}
   186  	return nil
   187  }
   188  
   189  func (c *ClientConn) Read(b []byte) (n int, err error) {
   190  	if !c.responseRead {
   191  		err = c.readResponse()
   192  		if err != nil {
   193  			return
   194  		}
   195  		c.responseRead = true
   196  	}
   197  	return c.Conn.Read(b)
   198  }
   199  
   200  func (c *ClientConn) Write(b []byte) (n int, err error) {
   201  	if c.requestWrite {
   202  		return c.Conn.Write(b)
   203  	}
   204  	request := StreamRequest{
   205  		Network:     N.NetworkTCP,
   206  		Destination: c.destination,
   207  	}
   208  	_buffer := buf.StackNewSize(requestLen(request) + len(b))
   209  	defer common.KeepAlive(_buffer)
   210  	buffer := common.Dup(_buffer)
   211  	defer buffer.Release()
   212  	EncodeStreamRequest(request, buffer)
   213  	buffer.Write(b)
   214  	_, err = c.Conn.Write(buffer.Bytes())
   215  	if err != nil {
   216  		return
   217  	}
   218  	c.requestWrite = true
   219  	return len(b), nil
   220  }
   221  
   222  func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
   223  	if !c.requestWrite {
   224  		return bufio.ReadFrom0(c, r)
   225  	}
   226  	return bufio.Copy(c.Conn, r)
   227  }
   228  
   229  func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
   230  	if !c.responseRead {
   231  		return bufio.WriteTo0(c, w)
   232  	}
   233  	return bufio.Copy(w, c.Conn)
   234  }
   235  
   236  func (c *ClientConn) LocalAddr() net.Addr {
   237  	return c.Conn.LocalAddr()
   238  }
   239  
   240  func (c *ClientConn) RemoteAddr() net.Addr {
   241  	return c.destination.TCPAddr()
   242  }
   243  
   244  func (c *ClientConn) ReaderReplaceable() bool {
   245  	return c.responseRead
   246  }
   247  
   248  func (c *ClientConn) WriterReplaceable() bool {
   249  	return c.requestWrite
   250  }
   251  
   252  func (c *ClientConn) NeedAdditionalReadDeadline() bool {
   253  	return true
   254  }
   255  
   256  func (c *ClientConn) Upstream() any {
   257  	return c.Conn
   258  }
   259  
   260  type ClientPacketConn struct {
   261  	N.ExtendedConn
   262  	destination  M.Socksaddr
   263  	requestWrite bool
   264  	responseRead bool
   265  }
   266  
   267  func (c *ClientPacketConn) readResponse() error {
   268  	response, err := ReadStreamResponse(c.ExtendedConn)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	if response.Status == statusError {
   273  		return E.New("remote error: ", response.Message)
   274  	}
   275  	return nil
   276  }
   277  
   278  func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
   279  	if !c.responseRead {
   280  		err = c.readResponse()
   281  		if err != nil {
   282  			return
   283  		}
   284  		c.responseRead = true
   285  	}
   286  	var length uint16
   287  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   288  	if err != nil {
   289  		return
   290  	}
   291  	if cap(b) < int(length) {
   292  		return 0, io.ErrShortBuffer
   293  	}
   294  	return io.ReadFull(c.ExtendedConn, b[:length])
   295  }
   296  
   297  func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) {
   298  	request := StreamRequest{
   299  		Network:     N.NetworkUDP,
   300  		Destination: c.destination,
   301  	}
   302  	rLen := requestLen(request)
   303  	if len(payload) > 0 {
   304  		rLen += 2 + len(payload)
   305  	}
   306  	_buffer := buf.StackNewSize(rLen)
   307  	defer common.KeepAlive(_buffer)
   308  	buffer := common.Dup(_buffer)
   309  	defer buffer.Release()
   310  	EncodeStreamRequest(request, buffer)
   311  	if len(payload) > 0 {
   312  		common.Must(
   313  			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
   314  			common.Error(buffer.Write(payload)),
   315  		)
   316  	}
   317  	_, err = c.ExtendedConn.Write(buffer.Bytes())
   318  	if err != nil {
   319  		return
   320  	}
   321  	c.requestWrite = true
   322  	return len(payload), nil
   323  }
   324  
   325  func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
   326  	if !c.requestWrite {
   327  		return c.writeRequest(b)
   328  	}
   329  	err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(b)))
   330  	if err != nil {
   331  		return
   332  	}
   333  	return c.ExtendedConn.Write(b)
   334  }
   335  
   336  func (c *ClientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
   337  	if !c.responseRead {
   338  		err = c.readResponse()
   339  		if err != nil {
   340  			return
   341  		}
   342  		c.responseRead = true
   343  	}
   344  	var length uint16
   345  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   346  	if err != nil {
   347  		return
   348  	}
   349  	_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
   350  	return
   351  }
   352  
   353  func (c *ClientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
   354  	if !c.requestWrite {
   355  		defer buffer.Release()
   356  		return common.Error(c.writeRequest(buffer.Bytes()))
   357  	}
   358  	bLen := buffer.Len()
   359  	binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
   360  	return c.ExtendedConn.WriteBuffer(buffer)
   361  }
   362  
   363  func (c *ClientPacketConn) FrontHeadroom() int {
   364  	return 2
   365  }
   366  
   367  func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   368  	err = c.ReadBuffer(buffer)
   369  	return
   370  }
   371  
   372  func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   373  	return c.WriteBuffer(buffer)
   374  }
   375  
   376  func (c *ClientPacketConn) LocalAddr() net.Addr {
   377  	return c.ExtendedConn.LocalAddr()
   378  }
   379  
   380  func (c *ClientPacketConn) RemoteAddr() net.Addr {
   381  	return c.destination.UDPAddr()
   382  }
   383  
   384  func (c *ClientPacketConn) NeedAdditionalReadDeadline() bool {
   385  	return true
   386  }
   387  
   388  func (c *ClientPacketConn) Upstream() any {
   389  	return c.ExtendedConn
   390  }
   391  
   392  var _ N.NetPacketConn = (*ClientPacketAddrConn)(nil)
   393  
   394  type ClientPacketAddrConn struct {
   395  	N.ExtendedConn
   396  	destination  M.Socksaddr
   397  	requestWrite bool
   398  	responseRead bool
   399  }
   400  
   401  func (c *ClientPacketAddrConn) readResponse() error {
   402  	response, err := ReadStreamResponse(c.ExtendedConn)
   403  	if err != nil {
   404  		return err
   405  	}
   406  	if response.Status == statusError {
   407  		return E.New("remote error: ", response.Message)
   408  	}
   409  	return nil
   410  }
   411  
   412  func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   413  	if !c.responseRead {
   414  		err = c.readResponse()
   415  		if err != nil {
   416  			return
   417  		}
   418  		c.responseRead = true
   419  	}
   420  	destination, err := M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
   421  	if err != nil {
   422  		return
   423  	}
   424  	if destination.IsFqdn() {
   425  		addr = destination
   426  	} else {
   427  		addr = destination.UDPAddr()
   428  	}
   429  	var length uint16
   430  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   431  	if err != nil {
   432  		return
   433  	}
   434  	if cap(p) < int(length) {
   435  		return 0, nil, io.ErrShortBuffer
   436  	}
   437  	n, err = io.ReadFull(c.ExtendedConn, p[:length])
   438  	return
   439  }
   440  
   441  func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
   442  	request := StreamRequest{
   443  		Network:     N.NetworkUDP,
   444  		Destination: c.destination,
   445  		PacketAddr:  true,
   446  	}
   447  	rLen := requestLen(request)
   448  	if len(payload) > 0 {
   449  		rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
   450  	}
   451  	_buffer := buf.StackNewSize(rLen)
   452  	defer common.KeepAlive(_buffer)
   453  	buffer := common.Dup(_buffer)
   454  	defer buffer.Release()
   455  	EncodeStreamRequest(request, buffer)
   456  	if len(payload) > 0 {
   457  		common.Must(
   458  			M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
   459  			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
   460  			common.Error(buffer.Write(payload)),
   461  		)
   462  	}
   463  	_, err = c.ExtendedConn.Write(buffer.Bytes())
   464  	if err != nil {
   465  		return
   466  	}
   467  	c.requestWrite = true
   468  	return len(payload), nil
   469  }
   470  
   471  func (c *ClientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   472  	if !c.requestWrite {
   473  		return c.writeRequest(p, M.SocksaddrFromNet(addr))
   474  	}
   475  	err = M.SocksaddrSerializer.WriteAddrPort(c.ExtendedConn, M.SocksaddrFromNet(addr))
   476  	if err != nil {
   477  		return
   478  	}
   479  	err = binary.Write(c.ExtendedConn, binary.BigEndian, uint16(len(p)))
   480  	if err != nil {
   481  		return
   482  	}
   483  	return c.ExtendedConn.Write(p)
   484  }
   485  
   486  func (c *ClientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   487  	if !c.responseRead {
   488  		err = c.readResponse()
   489  		if err != nil {
   490  			return
   491  		}
   492  		c.responseRead = true
   493  	}
   494  	destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
   495  	if err != nil {
   496  		return
   497  	}
   498  	var length uint16
   499  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   500  	if err != nil {
   501  		return
   502  	}
   503  	_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
   504  	return
   505  }
   506  
   507  func (c *ClientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   508  	if !c.requestWrite {
   509  		defer buffer.Release()
   510  		return common.Error(c.writeRequest(buffer.Bytes(), destination))
   511  	}
   512  	bLen := buffer.Len()
   513  	header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
   514  	common.Must(
   515  		M.SocksaddrSerializer.WriteAddrPort(header, destination),
   516  		binary.Write(header, binary.BigEndian, uint16(bLen)),
   517  	)
   518  	return c.ExtendedConn.WriteBuffer(buffer)
   519  }
   520  
   521  func (c *ClientPacketAddrConn) LocalAddr() net.Addr {
   522  	return c.ExtendedConn.LocalAddr()
   523  }
   524  
   525  func (c *ClientPacketAddrConn) FrontHeadroom() int {
   526  	return 2 + M.MaxSocksaddrLength
   527  }
   528  
   529  func (c *ClientPacketAddrConn) NeedAdditionalReadDeadline() bool {
   530  	return true
   531  }
   532  
   533  func (c *ClientPacketAddrConn) Upstream() any {
   534  	return c.ExtendedConn
   535  }