github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/client.go (about)

     1  package mux
     2  
import (
	"context"
	"net"
	"sync"
	"time"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/bufio"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/x/list"
)
    16  
// Client multiplexes outbound TCP and UDP flows over a pool of mux sessions,
// each carried by one underlying TCP connection dialed through dialer.
// Construct it with NewClient.
type Client struct {
	// dialer opens the underlying TCP connections for new sessions.
	dialer N.Dialer
	// logger receives debug messages (currently only TCP Brutal setup failures).
	logger logger.Logger
	// protocol is the negotiated multiplexer (ProtocolH2Mux, ProtocolSmux or ProtocolYAMux).
	protocol byte
	// maxConnections, minStreams and maxStreams are the pooling limits that
	// offer consults when deciding whether to reuse a session or dial a new one.
	maxConnections int
	minStreams     int
	maxStreams     int
	// padding selects protocol Version1 and wraps connections in the padding layer.
	padding bool
	// access guards connections.
	access sync.Mutex
	// connections holds the pooled sessions.
	connections list.List[abstractSession]
	// brutal configures the optional TCP Brutal rate negotiation per session.
	brutal BrutalOptions
}
    29  
// Options configures NewClient.
type Options struct {
	// Dialer opens underlying TCP connections; nil means N.SystemDialer.
	Dialer N.Dialer
	// Logger receives debug messages (currently only TCP Brutal setup failures).
	Logger logger.Logger
	// Protocol is "h2mux" (also the default when empty), "smux" or "yamux";
	// any other value makes NewClient return an error.
	Protocol string
	// MaxConnections, MinStreams and MaxStreams control session pooling.
	// When both MaxStreams and MaxConnections are 0, NewClient sets MinStreams to 8.
	MaxConnections int
	MinStreams     int
	MaxStreams     int
	// Padding enables the padding layer and protocol Version1.
	Padding bool
	// Brutal configures TCP Brutal rate negotiation.
	Brutal BrutalOptions
}
    40  
// BrutalOptions configures TCP Brutal negotiation for every new session.
type BrutalOptions struct {
	// Enabled turns on the exchange for every new session. It also makes the
	// client reuse the first live session instead of load-balancing.
	Enabled bool
	// SendBPS caps the client send rate. The applied rate is the smaller of
	// SendBPS and the receive rate reported by the server.
	// NOTE(review): unit presumably bytes per second — confirm.
	SendBPS uint64
	// ReceiveBPS is the receive rate advertised to the server in the exchange.
	ReceiveBPS uint64
}
    46  
    47  func NewClient(options Options) (*Client, error) {
    48  	client := &Client{
    49  		dialer:         options.Dialer,
    50  		logger:         options.Logger,
    51  		maxConnections: options.MaxConnections,
    52  		minStreams:     options.MinStreams,
    53  		maxStreams:     options.MaxStreams,
    54  		padding:        options.Padding,
    55  		brutal:         options.Brutal,
    56  	}
    57  	if client.dialer == nil {
    58  		client.dialer = N.SystemDialer
    59  	}
    60  	if client.maxStreams == 0 && client.maxConnections == 0 {
    61  		client.minStreams = 8
    62  	}
    63  	switch options.Protocol {
    64  	case "", "h2mux":
    65  		client.protocol = ProtocolH2Mux
    66  	case "smux":
    67  		client.protocol = ProtocolSmux
    68  	case "yamux":
    69  		client.protocol = ProtocolYAMux
    70  	default:
    71  		return nil, E.New("unknown protocol: " + options.Protocol)
    72  	}
    73  	return client, nil
    74  }
    75  
    76  func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    77  	switch N.NetworkName(network) {
    78  	case N.NetworkTCP:
    79  		stream, err := c.openStream(ctx)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  		return &clientConn{Conn: stream, destination: destination}, nil
    84  	case N.NetworkUDP:
    85  		stream, err := c.openStream(ctx)
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  		extendedConn := bufio.NewExtendedConn(stream)
    90  		return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
    91  	default:
    92  		return nil, E.Extend(N.ErrUnknownNetwork, network)
    93  	}
    94  }
    95  
    96  func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    97  	stream, err := c.openStream(ctx)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	extendedConn := bufio.NewExtendedConn(stream)
   102  	return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil
   103  }
   104  
   105  func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
   106  	var (
   107  		session abstractSession
   108  		stream  net.Conn
   109  		err     error
   110  	)
   111  	for attempts := 0; attempts < 2; attempts++ {
   112  		session, err = c.offer(ctx)
   113  		if err != nil {
   114  			continue
   115  		}
   116  		stream, err = session.Open()
   117  		if err != nil {
   118  			continue
   119  		}
   120  		break
   121  	}
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	return &wrapStream{stream}, nil
   126  }
   127  
   128  func (c *Client) offer(ctx context.Context) (abstractSession, error) {
   129  	c.access.Lock()
   130  	defer c.access.Unlock()
   131  
   132  	var sessions []abstractSession
   133  	for element := c.connections.Front(); element != nil; {
   134  		if element.Value.IsClosed() {
   135  			element.Value.Close()
   136  			nextElement := element.Next()
   137  			c.connections.Remove(element)
   138  			element = nextElement
   139  			continue
   140  		}
   141  		sessions = append(sessions, element.Value)
   142  		element = element.Next()
   143  	}
   144  	if c.brutal.Enabled {
   145  		if len(sessions) > 0 {
   146  			return sessions[0], nil
   147  		}
   148  		return c.offerNew(ctx)
   149  	}
   150  	session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams)
   151  	if session == nil {
   152  		return c.offerNew(ctx)
   153  	}
   154  	numStreams := session.NumStreams()
   155  	if numStreams == 0 {
   156  		return session, nil
   157  	}
   158  	if c.maxConnections > 0 {
   159  		if len(sessions) >= c.maxConnections || numStreams < c.minStreams {
   160  			return session, nil
   161  		}
   162  	} else {
   163  		if c.maxStreams > 0 && numStreams < c.maxStreams {
   164  			return session, nil
   165  		}
   166  	}
   167  	return c.offerNew(ctx)
   168  }
   169  
   170  func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
   171  	ctx, cancel := context.WithTimeout(ctx, TCPTimeout)
   172  	defer cancel()
   173  	conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	var version byte
   178  	if c.padding {
   179  		version = Version1
   180  	} else {
   181  		version = Version0
   182  	}
   183  	conn = newProtocolConn(conn, Request{
   184  		Version:  version,
   185  		Protocol: c.protocol,
   186  		Padding:  c.padding,
   187  	})
   188  	if c.padding {
   189  		conn = newPaddingConn(conn)
   190  	}
   191  	session, err := newClientSession(conn, c.protocol)
   192  	if err != nil {
   193  		conn.Close()
   194  		return nil, err
   195  	}
   196  	if c.brutal.Enabled {
   197  		err = c.brutalExchange(ctx, conn, session)
   198  		if err != nil {
   199  			conn.Close()
   200  			session.Close()
   201  			return nil, E.Cause(err, "brutal exchange")
   202  		}
   203  	}
   204  	c.connections.PushBack(session)
   205  	return session, nil
   206  }
   207  
   208  func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
   209  	stream, err := session.Open()
   210  	if err != nil {
   211  		return err
   212  	}
   213  	conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}}
   214  	err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS)
   215  	if err != nil {
   216  		return err
   217  	}
   218  	serverReceiveBPS, err := ReadBrutalResponse(conn)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	conn.Close()
   223  	sendBPS := c.brutal.SendBPS
   224  	if serverReceiveBPS < sendBPS {
   225  		sendBPS = serverReceiveBPS
   226  	}
   227  	clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS)
   228  	if clientBrutalErr != nil {
   229  		c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client"))
   230  	}
   231  	return nil
   232  }
   233  
   234  func (c *Client) Reset() {
   235  	c.access.Lock()
   236  	defer c.access.Unlock()
   237  	for _, session := range c.connections.Array() {
   238  		session.Close()
   239  	}
   240  	c.connections.Init()
   241  }
   242  
   243  func (c *Client) Close() error {
   244  	c.Reset()
   245  	return nil
   246  }