github.com/kelleygo/clashcore@v1.0.2/transport/hysteria/core/client.go (about)

     1  package core
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/kelleygo/clashcore/transport/hysteria/obfs"
    15  	"github.com/kelleygo/clashcore/transport/hysteria/pmtud_fix"
    16  	"github.com/kelleygo/clashcore/transport/hysteria/transport"
    17  	"github.com/kelleygo/clashcore/transport/hysteria/utils"
    18  
    19  	"github.com/lunixbochs/struc"
    20  	"github.com/metacubex/quic-go"
    21  	"github.com/metacubex/quic-go/congestion"
    22  	"github.com/zhangyunhao116/fastrand"
    23  )
    24  
    25  var (
    26  	ErrClosed = errors.New("closed")
    27  )
    28  
    29  type CongestionFactory func(refBPS uint64) congestion.CongestionControl
    30  
    31  type Client struct {
    32  	transport         *transport.ClientTransport
    33  	serverAddr        string
    34  	serverPorts       string
    35  	protocol          string
    36  	sendBPS, recvBPS  uint64
    37  	auth              []byte
    38  	congestionFactory CongestionFactory
    39  	obfuscator        obfs.Obfuscator
    40  
    41  	tlsConfig  *tls.Config
    42  	quicConfig *quic.Config
    43  
    44  	quicSession    quic.Connection
    45  	reconnectMutex sync.Mutex
    46  	closed         bool
    47  
    48  	udpSessionMutex sync.RWMutex
    49  	udpSessionMap   map[uint32]chan *udpMessage
    50  	udpDefragger    defragger
    51  	hopInterval     time.Duration
    52  	fastOpen        bool
    53  }
    54  
    55  func NewClient(serverAddr string, serverPorts string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
    56  	transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
    57  	obfuscator obfs.Obfuscator, hopInterval time.Duration, fastOpen bool) (*Client, error) {
    58  	quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
    59  	c := &Client{
    60  		transport:         transport,
    61  		serverAddr:        serverAddr,
    62  		serverPorts:       serverPorts,
    63  		protocol:          protocol,
    64  		sendBPS:           sendBPS,
    65  		recvBPS:           recvBPS,
    66  		auth:              auth,
    67  		congestionFactory: congestionFactory,
    68  		obfuscator:        obfuscator,
    69  		tlsConfig:         tlsConfig,
    70  		quicConfig:        quicConfig,
    71  		hopInterval:       hopInterval,
    72  		fastOpen:          fastOpen,
    73  	}
    74  	return c, nil
    75  }
    76  
    77  func (c *Client) connectToServer(dialer utils.PacketDialer) error {
    78  	qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.serverPorts, c.tlsConfig, c.quicConfig, c.obfuscator, c.hopInterval, dialer)
    79  	if err != nil {
    80  		return err
    81  	}
    82  	// Control stream
    83  	ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
    84  	stream, err := qs.OpenStreamSync(ctx)
    85  	ctxCancel()
    86  	if err != nil {
    87  		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
    88  		return err
    89  	}
    90  	ok, msg, err := c.handleControlStream(qs, stream)
    91  	if err != nil {
    92  		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
    93  		return err
    94  	}
    95  	if !ok {
    96  		_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
    97  		return fmt.Errorf("auth error: %s", msg)
    98  	}
    99  	// All good
   100  	c.udpSessionMap = make(map[uint32]chan *udpMessage)
   101  	go c.handleMessage(qs)
   102  	c.quicSession = qs
   103  	return nil
   104  }
   105  
   106  func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
   107  	// Send protocol version
   108  	_, err := stream.Write([]byte{protocolVersion})
   109  	if err != nil {
   110  		return false, "", err
   111  	}
   112  	// Send client hello
   113  	err = struc.Pack(stream, &clientHello{
   114  		Rate: transmissionRate{
   115  			SendBPS: c.sendBPS,
   116  			RecvBPS: c.recvBPS,
   117  		},
   118  		Auth: c.auth,
   119  	})
   120  	if err != nil {
   121  		return false, "", err
   122  	}
   123  	// Receive server hello
   124  	var sh serverHello
   125  	err = struc.Unpack(stream, &sh)
   126  	if err != nil {
   127  		return false, "", err
   128  	}
   129  	// Set the congestion accordingly
   130  	if sh.OK && c.congestionFactory != nil {
   131  		qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
   132  	}
   133  	return sh.OK, sh.Message, nil
   134  }
   135  
   136  func (c *Client) handleMessage(qs quic.Connection) {
   137  	for {
   138  		msg, err := qs.ReceiveDatagram(context.Background())
   139  		if err != nil {
   140  			break
   141  		}
   142  		var udpMsg udpMessage
   143  		err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
   144  		if err != nil {
   145  			continue
   146  		}
   147  		dfMsg := c.udpDefragger.Feed(udpMsg)
   148  		if dfMsg == nil {
   149  			continue
   150  		}
   151  		c.udpSessionMutex.RLock()
   152  		ch, ok := c.udpSessionMap[dfMsg.SessionID]
   153  		if ok {
   154  			select {
   155  			case ch <- dfMsg:
   156  				// OK
   157  			default:
   158  				// Silently drop the message when the channel is full
   159  			}
   160  		}
   161  		c.udpSessionMutex.RUnlock()
   162  	}
   163  }
   164  
   165  func (c *Client) openStreamWithReconnect(dialer utils.PacketDialer) (quic.Connection, quic.Stream, error) {
   166  	c.reconnectMutex.Lock()
   167  	defer c.reconnectMutex.Unlock()
   168  	if c.closed {
   169  		return nil, nil, ErrClosed
   170  	}
   171  	if c.quicSession == nil {
   172  		if err := c.connectToServer(dialer); err != nil {
   173  			// Still error, oops
   174  			return nil, nil, err
   175  		}
   176  	}
   177  	stream, err := c.quicSession.OpenStream()
   178  	if err == nil {
   179  		// All good
   180  		return c.quicSession, &wrappedQUICStream{stream}, nil
   181  	}
   182  	// Something is wrong
   183  	if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
   184  		// Temporary error, just return
   185  		return nil, nil, err
   186  	}
   187  	// Permanent error, need to reconnect
   188  	if err := c.connectToServer(dialer); err != nil {
   189  		// Still error, oops
   190  		return nil, nil, err
   191  	}
   192  	// We are not going to try again even if it still fails the second time
   193  	stream, err = c.quicSession.OpenStream()
   194  	return c.quicSession, &wrappedQUICStream{stream}, err
   195  }
   196  
   197  func (c *Client) DialTCP(host string, port uint16, dialer utils.PacketDialer) (net.Conn, error) {
   198  	session, stream, err := c.openStreamWithReconnect(dialer)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	// Send request
   203  	err = struc.Pack(stream, &clientRequest{
   204  		UDP:  false,
   205  		Host: host,
   206  		Port: port,
   207  	})
   208  	if err != nil {
   209  		_ = stream.Close()
   210  		return nil, err
   211  	}
   212  	// If fast open is enabled, we return the stream immediately
   213  	// and defer the response handling to the first Read() call
   214  	if !c.fastOpen {
   215  		// Read response
   216  		var sr serverResponse
   217  		err = struc.Unpack(stream, &sr)
   218  		if err != nil {
   219  			_ = stream.Close()
   220  			return nil, err
   221  		}
   222  		if !sr.OK {
   223  			_ = stream.Close()
   224  			return nil, fmt.Errorf("connection rejected: %s", sr.Message)
   225  		}
   226  	}
   227  
   228  	return &quicConn{
   229  		Orig:             stream,
   230  		PseudoLocalAddr:  session.LocalAddr(),
   231  		PseudoRemoteAddr: session.RemoteAddr(),
   232  		Established:      !c.fastOpen,
   233  	}, nil
   234  }
   235  
   236  func (c *Client) DialUDP(dialer utils.PacketDialer) (UDPConn, error) {
   237  	session, stream, err := c.openStreamWithReconnect(dialer)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	// Send request
   242  	err = struc.Pack(stream, &clientRequest{
   243  		UDP: true,
   244  	})
   245  	if err != nil {
   246  		_ = stream.Close()
   247  		return nil, err
   248  	}
   249  	// Read response
   250  	var sr serverResponse
   251  	err = struc.Unpack(stream, &sr)
   252  	if err != nil {
   253  		_ = stream.Close()
   254  		return nil, err
   255  	}
   256  	if !sr.OK {
   257  		_ = stream.Close()
   258  		return nil, fmt.Errorf("connection rejected: %s", sr.Message)
   259  	}
   260  
   261  	// Create a session in the map
   262  	c.udpSessionMutex.Lock()
   263  	nCh := make(chan *udpMessage, 1024)
   264  	// Store the current session map for CloseFunc below
   265  	// to ensures that we are adding and removing sessions on the same map,
   266  	// as reconnecting will reassign the map
   267  	sessionMap := c.udpSessionMap
   268  	sessionMap[sr.UDPSessionID] = nCh
   269  	c.udpSessionMutex.Unlock()
   270  
   271  	pktConn := &quicPktConn{
   272  		Session: session,
   273  		Stream:  stream,
   274  		CloseFunc: func() {
   275  			c.udpSessionMutex.Lock()
   276  			if ch, ok := sessionMap[sr.UDPSessionID]; ok {
   277  				close(ch)
   278  				delete(sessionMap, sr.UDPSessionID)
   279  			}
   280  			c.udpSessionMutex.Unlock()
   281  		},
   282  		UDPSessionID: sr.UDPSessionID,
   283  		MsgCh:        nCh,
   284  	}
   285  	go pktConn.Hold()
   286  	return pktConn, nil
   287  }
   288  
   289  func (c *Client) Close() error {
   290  	c.reconnectMutex.Lock()
   291  	defer c.reconnectMutex.Unlock()
   292  	err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
   293  	c.closed = true
   294  	return err
   295  }
   296  
   297  type quicConn struct {
   298  	Orig             quic.Stream
   299  	PseudoLocalAddr  net.Addr
   300  	PseudoRemoteAddr net.Addr
   301  	Established      bool
   302  }
   303  
   304  func (w *quicConn) Read(b []byte) (n int, err error) {
   305  	if !w.Established {
   306  		var sr serverResponse
   307  		err := struc.Unpack(w.Orig, &sr)
   308  		if err != nil {
   309  			_ = w.Close()
   310  			return 0, err
   311  		}
   312  		if !sr.OK {
   313  			_ = w.Close()
   314  			return 0, fmt.Errorf("connection rejected: %s", sr.Message)
   315  		}
   316  		w.Established = true
   317  	}
   318  	return w.Orig.Read(b)
   319  }
   320  
   321  func (w *quicConn) Write(b []byte) (n int, err error) {
   322  	return w.Orig.Write(b)
   323  }
   324  
   325  func (w *quicConn) Close() error {
   326  	return w.Orig.Close()
   327  }
   328  
   329  func (w *quicConn) LocalAddr() net.Addr {
   330  	return w.PseudoLocalAddr
   331  }
   332  
   333  func (w *quicConn) RemoteAddr() net.Addr {
   334  	return w.PseudoRemoteAddr
   335  }
   336  
   337  func (w *quicConn) SetDeadline(t time.Time) error {
   338  	return w.Orig.SetDeadline(t)
   339  }
   340  
   341  func (w *quicConn) SetReadDeadline(t time.Time) error {
   342  	return w.Orig.SetReadDeadline(t)
   343  }
   344  
   345  func (w *quicConn) SetWriteDeadline(t time.Time) error {
   346  	return w.Orig.SetWriteDeadline(t)
   347  }
   348  
   349  type UDPConn interface {
   350  	ReadFrom() ([]byte, string, error)
   351  	WriteTo([]byte, string) error
   352  	Close() error
   353  	LocalAddr() net.Addr
   354  	SetDeadline(t time.Time) error
   355  	SetReadDeadline(t time.Time) error
   356  	SetWriteDeadline(t time.Time) error
   357  }
   358  
   359  type quicPktConn struct {
   360  	Session      quic.Connection
   361  	Stream       quic.Stream
   362  	CloseFunc    func()
   363  	UDPSessionID uint32
   364  	MsgCh        <-chan *udpMessage
   365  }
   366  
   367  func (c *quicPktConn) Hold() {
   368  	// Hold the stream until it's closed
   369  	buf := make([]byte, 1024)
   370  	for {
   371  		_, err := c.Stream.Read(buf)
   372  		if err != nil {
   373  			break
   374  		}
   375  	}
   376  	_ = c.Close()
   377  }
   378  
   379  func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
   380  	msg := <-c.MsgCh
   381  	if msg == nil {
   382  		// Closed
   383  		return nil, "", ErrClosed
   384  	}
   385  	return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
   386  }
   387  
   388  func (c *quicPktConn) WriteTo(p []byte, addr string) error {
   389  	host, port, err := utils.SplitHostPort(addr)
   390  	if err != nil {
   391  		return err
   392  	}
   393  	msg := udpMessage{
   394  		SessionID: c.UDPSessionID,
   395  		Host:      host,
   396  		Port:      port,
   397  		FragCount: 1,
   398  		Data:      p,
   399  	}
   400  	// try no frag first
   401  	var msgBuf bytes.Buffer
   402  	_ = struc.Pack(&msgBuf, &msg)
   403  	err = c.Session.SendDatagram(msgBuf.Bytes())
   404  	if err != nil {
   405  		var errSize *quic.DatagramTooLargeError
   406  		if errors.As(err, &errSize) {
   407  			// need to frag
   408  			msg.MsgID = uint16(fastrand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
   409  			fragMsgs := fragUDPMessage(msg, int(errSize.PeerMaxDatagramFrameSize))
   410  			for _, fragMsg := range fragMsgs {
   411  				msgBuf.Reset()
   412  				_ = struc.Pack(&msgBuf, &fragMsg)
   413  				err = c.Session.SendDatagram(msgBuf.Bytes())
   414  				if err != nil {
   415  					return err
   416  				}
   417  			}
   418  			return nil
   419  		} else {
   420  			// some other error
   421  			return err
   422  		}
   423  	} else {
   424  		return nil
   425  	}
   426  }
   427  
   428  func (c *quicPktConn) Close() error {
   429  	c.CloseFunc()
   430  	return c.Stream.Close()
   431  }
   432  
   433  func (c *quicPktConn) LocalAddr() net.Addr {
   434  	return c.Session.LocalAddr()
   435  }
   436  
   437  func (c *quicPktConn) SetDeadline(t time.Time) error {
   438  	return c.Stream.SetDeadline(t)
   439  }
   440  
   441  func (c *quicPktConn) SetReadDeadline(t time.Time) error {
   442  	return c.Stream.SetReadDeadline(t)
   443  }
   444  
   445  func (c *quicPktConn) SetWriteDeadline(t time.Time) error {
   446  	return c.Stream.SetWriteDeadline(t)
   447  }