github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/client.go (about)

     1  package hysteria2
     2  
     3  import (
     4  	"context"
     5  	"github.com/sagernet/quic-go"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"runtime"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/inazumav/sing-box/common/qtls"
    16  	"github.com/inazumav/sing-box/common/tls"
    17  	"github.com/inazumav/sing-box/transport/hysteria2/congestion"
    18  	"github.com/inazumav/sing-box/transport/hysteria2/internal/protocol"
    19  	tuicCongestion "github.com/inazumav/sing-box/transport/tuic/congestion"
    20  	"github.com/sagernet/sing/common/bufio"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  )
    25  
    26  const (
    27  	defaultStreamReceiveWindow = 8388608                            // 8MB
    28  	defaultConnReceiveWindow   = defaultStreamReceiveWindow * 5 / 2 // 20MB
    29  	defaultMaxIdleTimeout      = 30 * time.Second
    30  	defaultKeepAlivePeriod     = 10 * time.Second
    31  )
    32  
// ClientOptions configures a Hysteria 2 Client created by NewClient.
type ClientOptions struct {
	// Context is retained by the Client (as Client.ctx).
	Context            context.Context
	// Dialer opens the "udp" socket to ServerAddress for each new QUIC connection.
	Dialer             N.Dialer
	// ServerAddress is the dial target and is also passed to the QUIC/TLS transport.
	ServerAddress      M.Socksaddr
	// SendBPS caps the send rate. The server's advertised receive rate is used
	// when it is lower and non-zero. If the resulting rate is 0, or the server
	// asks for automatic rate, BBR congestion control is used instead of Brutal.
	SendBPS            uint64
	// ReceiveBPS is advertised to the server as Rx in the authentication request.
	ReceiveBPS         uint64
	// SalamanderPassword, when non-empty, wraps the UDP socket in a Salamander conn.
	SalamanderPassword string
	// Password is sent as Auth in the authentication request.
	Password           string
	// TLSConfig is passed to the QUIC/TLS transport.
	TLSConfig          tls.Config
	// UDPDisabled makes ListenPacket fail with os.ErrInvalid.
	UDPDisabled        bool
}
    44  
// Client is a Hysteria 2 client. TCP streams (DialConn) and UDP sessions
// (ListenPacket) share one cached QUIC connection. That connection is
// established on demand and re-established once it is no longer active.
type Client struct {
	// The fields below are copied from ClientOptions by NewClient.
	ctx                context.Context
	dialer             N.Dialer
	serverAddr         M.Socksaddr
	sendBPS            uint64
	receiveBPS         uint64
	salamanderPassword string
	password           string
	tlsConfig          tls.Config
	// quicConfig is built once in NewClient and reused for every connection.
	quicConfig         *quic.Config
	udpDisabled        bool

	// connAccess serializes creation of conn: offer holds the write lock
	// while a new connection is established and stored.
	connAccess sync.RWMutex
	// conn is the cached connection handed out by offer while it is active.
	conn       *clientQUICConnection
}
    60  
    61  func NewClient(options ClientOptions) (*Client, error) {
    62  	quicConfig := &quic.Config{
    63  		DisablePathMTUDiscovery:        !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
    64  		MaxDatagramFrameSize:           1400,
    65  		EnableDatagrams:                true,
    66  		InitialStreamReceiveWindow:     defaultStreamReceiveWindow,
    67  		MaxStreamReceiveWindow:         defaultStreamReceiveWindow,
    68  		InitialConnectionReceiveWindow: defaultConnReceiveWindow,
    69  		MaxConnectionReceiveWindow:     defaultConnReceiveWindow,
    70  		MaxIdleTimeout:                 defaultMaxIdleTimeout,
    71  		KeepAlivePeriod:                defaultKeepAlivePeriod,
    72  	}
    73  	return &Client{
    74  		ctx:                options.Context,
    75  		dialer:             options.Dialer,
    76  		serverAddr:         options.ServerAddress,
    77  		sendBPS:            options.SendBPS,
    78  		receiveBPS:         options.ReceiveBPS,
    79  		salamanderPassword: options.SalamanderPassword,
    80  		password:           options.Password,
    81  		tlsConfig:          options.TLSConfig,
    82  		quicConfig:         quicConfig,
    83  		udpDisabled:        options.UDPDisabled,
    84  	}, nil
    85  }
    86  
    87  func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
    88  	conn := c.conn
    89  	if conn != nil && conn.active() {
    90  		return conn, nil
    91  	}
    92  	c.connAccess.Lock()
    93  	defer c.connAccess.Unlock()
    94  	conn = c.conn
    95  	if conn != nil && conn.active() {
    96  		return conn, nil
    97  	}
    98  	conn, err := c.offerNew(ctx)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	return conn, nil
   103  }
   104  
   105  func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
   106  	udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	var packetConn net.PacketConn
   111  	packetConn = bufio.NewUnbindPacketConn(udpConn)
   112  	if c.salamanderPassword != "" {
   113  		packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword))
   114  	}
   115  	var quicConn quic.EarlyConnection
   116  	http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig, true)
   117  	if err != nil {
   118  		udpConn.Close()
   119  		return nil, err
   120  	}
   121  	request := &http.Request{
   122  		Method: http.MethodPost,
   123  		URL: &url.URL{
   124  			Scheme: "https",
   125  			Host:   protocol.URLHost,
   126  			Path:   protocol.URLPath,
   127  		},
   128  		Header: make(http.Header),
   129  	}
   130  	protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS})
   131  	response, err := http3Transport.RoundTrip(request)
   132  	if err != nil {
   133  		if quicConn != nil {
   134  			quicConn.CloseWithError(0, "")
   135  		}
   136  		udpConn.Close()
   137  		return nil, err
   138  	}
   139  	if response.StatusCode != protocol.StatusAuthOK {
   140  		if quicConn != nil {
   141  			quicConn.CloseWithError(0, "")
   142  		}
   143  		udpConn.Close()
   144  		return nil, E.New("authentication failed, status code: ", response.StatusCode)
   145  	}
   146  	response.Body.Close()
   147  	authResponse := protocol.AuthResponseFromHeader(response.Header)
   148  	actualTx := authResponse.Rx
   149  	if actualTx == 0 || actualTx > c.sendBPS {
   150  		actualTx = c.sendBPS
   151  	}
   152  	if !authResponse.RxAuto && actualTx > 0 {
   153  		quicConn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
   154  	} else {
   155  		quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
   156  			tuicCongestion.DefaultClock{},
   157  			tuicCongestion.GetInitialPacketSize(quicConn.RemoteAddr()),
   158  			tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
   159  			tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
   160  		))
   161  	}
   162  	conn := &clientQUICConnection{
   163  		quicConn:    quicConn,
   164  		rawConn:     udpConn,
   165  		connDone:    make(chan struct{}),
   166  		udpDisabled: c.udpDisabled || !authResponse.UDPEnabled,
   167  		udpConnMap:  make(map[uint32]*udpPacketConn),
   168  	}
   169  	if !c.udpDisabled {
   170  		go c.loopMessages(conn)
   171  	}
   172  	c.conn = conn
   173  	return conn, nil
   174  }
   175  
   176  func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
   177  	conn, err := c.offer(ctx)
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	stream, err := conn.quicConn.OpenStream()
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	return &clientConn{
   186  		Stream:      stream,
   187  		destination: destination,
   188  	}, nil
   189  }
   190  
   191  func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
   192  	if c.udpDisabled {
   193  		return nil, os.ErrInvalid
   194  	}
   195  	conn, err := c.offer(ctx)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	if conn.udpDisabled {
   200  		return nil, E.New("UDP disabled by server")
   201  	}
   202  	var sessionID uint32
   203  	clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, func() {
   204  		conn.udpAccess.Lock()
   205  		delete(conn.udpConnMap, sessionID)
   206  		conn.udpAccess.Unlock()
   207  	})
   208  	conn.udpAccess.Lock()
   209  	sessionID = conn.udpSessionID
   210  	conn.udpSessionID++
   211  	conn.udpConnMap[sessionID] = clientPacketConn
   212  	conn.udpAccess.Unlock()
   213  	clientPacketConn.sessionID = sessionID
   214  	return clientPacketConn, nil
   215  }
   216  
   217  func (c *Client) CloseWithError(err error) error {
   218  	conn := c.conn
   219  	if conn != nil {
   220  		conn.closeWithError(err)
   221  	}
   222  	return nil
   223  }
   224  
// clientQUICConnection is one authenticated QUIC connection to the server,
// plus the per-connection UDP session bookkeeping.
type clientQUICConnection struct {
	// quicConn is the QUIC connection that streams and datagrams use.
	quicConn     quic.Connection
	// rawConn is the UDP socket dialed in offerNew for this connection.
	rawConn      io.Closer
	// closeOnce makes closeWithError take effect only once.
	closeOnce    sync.Once
	// connDone is closed by closeWithError; active() checks it.
	connDone     chan struct{}
	// connErr is the error passed to closeWithError.
	connErr      error
	// udpDisabled is true when UDP is disabled locally or not enabled by the server.
	udpDisabled  bool
	// udpAccess guards udpConnMap and udpSessionID.
	udpAccess    sync.RWMutex
	// udpConnMap maps session IDs to their packet conns; a session's close
	// callback deletes its entry.
	udpConnMap   map[uint32]*udpPacketConn
	// udpSessionID is the next session ID ListenPacket hands out.
	udpSessionID uint32
}
   236  
   237  func (c *clientQUICConnection) active() bool {
   238  	select {
   239  	case <-c.quicConn.Context().Done():
   240  		return false
   241  	default:
   242  	}
   243  	select {
   244  	case <-c.connDone:
   245  		return false
   246  	default:
   247  	}
   248  	return true
   249  }
   250  
   251  func (c *clientQUICConnection) closeWithError(err error) {
   252  	c.closeOnce.Do(func() {
   253  		c.connErr = err
   254  		close(c.connDone)
   255  		c.quicConn.CloseWithError(0, "")
   256  	})
   257  }
   258  
// clientConn adapts one QUIC stream to a proxied TCP connection. The first
// Write sends the TCP request, and the first Read consumes the server's TCP
// response before returning payload.
type clientConn struct {
	quic.Stream
	// destination is encoded into the TCP request sent by the first Write.
	destination    M.Socksaddr
	// requestWritten is set once the TCP request has been written.
	requestWritten bool
	// responseRead is set once a successful TCP response has been parsed.
	responseRead   bool
}
   265  
   266  func (c *clientConn) NeedHandshake() bool {
   267  	return !c.requestWritten
   268  }
   269  
   270  func (c *clientConn) Read(p []byte) (n int, err error) {
   271  	if c.responseRead {
   272  		return c.Stream.Read(p)
   273  	}
   274  	status, errorMessage, err := protocol.ReadTCPResponse(c.Stream)
   275  	if err != nil {
   276  		return
   277  	}
   278  	if !status {
   279  		err = E.New("remote error: ", errorMessage)
   280  		return
   281  	}
   282  	c.responseRead = true
   283  	return c.Stream.Read(p)
   284  }
   285  
   286  func (c *clientConn) Write(p []byte) (n int, err error) {
   287  	if !c.requestWritten {
   288  		buffer := protocol.WriteTCPRequest(c.destination.String(), p)
   289  		defer buffer.Release()
   290  		_, err = c.Stream.Write(buffer.Bytes())
   291  		if err != nil {
   292  			return
   293  		}
   294  		c.requestWritten = true
   295  		return len(p), nil
   296  	}
   297  	return c.Stream.Write(p)
   298  }
   299  
   300  func (c *clientConn) LocalAddr() net.Addr {
   301  	return M.Socksaddr{}
   302  }
   303  
   304  func (c *clientConn) RemoteAddr() net.Addr {
   305  	return M.Socksaddr{}
   306  }