github.com/kelleygo/clashcore@v1.0.2/transport/tuic/v5/client.go (about)

     1  package v5
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"net"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	atomic2 "github.com/kelleygo/clashcore/common/atomic"
    16  	N "github.com/kelleygo/clashcore/common/net"
    17  	"github.com/kelleygo/clashcore/common/pool"
    18  	C "github.com/kelleygo/clashcore/constant"
    19  	"github.com/kelleygo/clashcore/log"
    20  	"github.com/kelleygo/clashcore/transport/tuic/common"
    21  
    22  	"github.com/metacubex/quic-go"
    23  	"github.com/puzpuzpuz/xsync/v3"
    24  	"github.com/zhangyunhao116/fastrand"
    25  )
    26  
// ClientOption carries the configuration shared by every connection a TUIC v5
// client creates.
type ClientOption struct {
	// TlsConfig and QuicConfig are handed unchanged to the QUIC dial call.
	TlsConfig             *tls.Config
	QuicConfig            *quic.Config
	// Uuid and Password are the credentials used to build the authentication
	// token sent right after a connection is established.
	Uuid                  [16]byte
	Password              string
	// UdpRelayMode selects how relayed UDP packets are received: the client
	// reads them from unidirectional streams when it equals common.QUIC and
	// from datagrams when it equals common.NATIVE.
	UdpRelayMode          common.UdpRelayMode
	// CongestionController (together with CWND) is passed to
	// common.SetCongestionController for every new connection.
	CongestionController  string
	// ReduceRtt makes the client dial with DialEarly instead of Dial.
	ReduceRtt             bool
	// MaxUdpRelayPacketSize is forwarded to each packet connection.
	MaxUdpRelayPacketSize int
	// MaxOpenStreams bounds the open-stream counter: a new stream or packet
	// connection is refused once the counter, including the new entry,
	// reaches this value.
	MaxOpenStreams        int64
	CWND                  int
}
    39  
// clientImpl is the working state behind Client: it owns at most one QUIC
// connection at a time and tracks the streams and UDP associations opened on it.
type clientImpl struct {
	*ClientOption
	// udp enables handling of relayed UDP packets on this client.
	udp bool

	// quicConn is the current connection (nil until dialed or after a forced
	// close); getQuicConn and forceClose access it while holding connMutex.
	quicConn  quic.Connection
	connMutex sync.Mutex

	// openStreams counts streams and packet connections handed out; closed is
	// set by Close so the last stream to finish can tear the connection down.
	openStreams atomic.Int64
	closed      atomic.Bool

	// udpInputMap maps a UDP association ID to the pipe end that receives
	// packets arriving for that association.
	udpInputMap *xsync.MapOf[uint16, net.Conn]

	// only ready for PoolClient
	dialerRef   C.Dialer
	lastVisited atomic2.TypedValue[time.Time]
}
    56  
// OpenStreams returns the current value of the open-stream counter.
func (t *clientImpl) OpenStreams() int64 {
	return t.openStreams.Load()
}
    60  
// DialerRef returns the dialer reference supplied when the client was created.
func (t *clientImpl) DialerRef() C.Dialer {
	return t.dialerRef
}
    64  
// LastVisited returns the timestamp most recently stored by SetLastVisited.
func (t *clientImpl) LastVisited() time.Time {
	return t.lastVisited.Load()
}
    68  
// SetLastVisited records last as the client's last-visited timestamp.
func (t *clientImpl) SetLastVisited(last time.Time) {
	t.lastVisited.Store(last)
}
    72  
    73  func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) {
    74  	t.connMutex.Lock()
    75  	defer t.connMutex.Unlock()
    76  	if t.quicConn != nil {
    77  		return t.quicConn, nil
    78  	}
    79  	transport, addr, err := dialFn(ctx, dialer)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	var quicConn quic.Connection
    84  	if t.ReduceRtt {
    85  		quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig)
    86  	} else {
    87  		quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig)
    88  	}
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	common.SetCongestionController(quicConn, t.CongestionController, t.CWND)
    94  
    95  	go func() {
    96  		_ = t.sendAuthentication(quicConn)
    97  	}()
    98  
    99  	if t.udp && t.UdpRelayMode == common.QUIC {
   100  		go func() {
   101  			_ = t.handleUniStream(quicConn)
   102  		}()
   103  	}
   104  	go func() {
   105  		_ = t.handleMessage(quicConn) // always handleMessage because tuicV5 using datagram to send the Heartbeat
   106  	}()
   107  
   108  	t.quicConn = quicConn
   109  	t.openStreams.Store(0)
   110  	return quicConn, nil
   111  }
   112  
   113  func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
   114  	defer func() {
   115  		t.deferQuicConn(quicConn, err)
   116  	}()
   117  	stream, err := quicConn.OpenUniStream()
   118  	if err != nil {
   119  		return err
   120  	}
   121  	buf := pool.GetBuffer()
   122  	defer pool.PutBuffer(buf)
   123  	token, err := GenToken(quicConn.ConnectionState(), t.Uuid, t.Password)
   124  	if err != nil {
   125  		return err
   126  	}
   127  	err = NewAuthenticate(t.Uuid, token).WriteTo(buf)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	_, err = buf.WriteTo(stream)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	err = stream.Close()
   136  	if err != nil {
   137  		return
   138  	}
   139  	return nil
   140  }
   141  
   142  func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) {
   143  	defer func() {
   144  		t.deferQuicConn(quicConn, err)
   145  	}()
   146  	for {
   147  		var stream quic.ReceiveStream
   148  		stream, err = quicConn.AcceptUniStream(context.Background())
   149  		if err != nil {
   150  			return err
   151  		}
   152  		go func() (err error) {
   153  			var assocId uint16
   154  			defer func() {
   155  				t.deferQuicConn(quicConn, err)
   156  				if err != nil && assocId != 0 {
   157  					if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
   158  						if conn, ok := val.(net.Conn); ok {
   159  							_ = conn.Close()
   160  						}
   161  					}
   162  				}
   163  				stream.CancelRead(0)
   164  			}()
   165  			reader := bufio.NewReader(stream)
   166  			commandHead, err := ReadCommandHead(reader)
   167  			if err != nil {
   168  				return
   169  			}
   170  			switch commandHead.TYPE {
   171  			case PacketType:
   172  				var packet Packet
   173  				packet, err = ReadPacketWithHead(commandHead, reader)
   174  				if err != nil {
   175  					return
   176  				}
   177  				if t.udp && t.UdpRelayMode == common.QUIC {
   178  					assocId = packet.ASSOC_ID
   179  					if val, ok := t.udpInputMap.Load(assocId); ok {
   180  						if conn, ok := val.(net.Conn); ok {
   181  							writer := bufio.NewWriterSize(conn, packet.BytesLen())
   182  							_ = packet.WriteTo(writer)
   183  							_ = writer.Flush()
   184  						}
   185  					}
   186  				}
   187  			}
   188  			return
   189  		}()
   190  	}
   191  }
   192  
   193  func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) {
   194  	defer func() {
   195  		t.deferQuicConn(quicConn, err)
   196  	}()
   197  	for {
   198  		var message []byte
   199  		message, err = quicConn.ReceiveDatagram(context.Background())
   200  		if err != nil {
   201  			return err
   202  		}
   203  		go func() (err error) {
   204  			var assocId uint16
   205  			defer func() {
   206  				t.deferQuicConn(quicConn, err)
   207  				if err != nil && assocId != 0 {
   208  					if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
   209  						if conn, ok := val.(net.Conn); ok {
   210  							_ = conn.Close()
   211  						}
   212  					}
   213  				}
   214  			}()
   215  			reader := bytes.NewBuffer(message)
   216  			commandHead, err := ReadCommandHead(reader)
   217  			if err != nil {
   218  				return
   219  			}
   220  			switch commandHead.TYPE {
   221  			case PacketType:
   222  				var packet Packet
   223  				packet, err = ReadPacketWithHead(commandHead, reader)
   224  				if err != nil {
   225  					return
   226  				}
   227  				if t.udp && t.UdpRelayMode == common.NATIVE {
   228  					assocId = packet.ASSOC_ID
   229  					if val, ok := t.udpInputMap.Load(assocId); ok {
   230  						if conn, ok := val.(net.Conn); ok {
   231  							_, _ = conn.Write(message)
   232  						}
   233  					}
   234  				}
   235  			case HeartbeatType:
   236  				var heartbeat Heartbeat
   237  				heartbeat, err = ReadHeartbeatWithHead(commandHead, reader)
   238  				if err != nil {
   239  					return
   240  				}
   241  				heartbeat.BytesLen()
   242  			}
   243  			return
   244  		}()
   245  	}
   246  }
   247  
   248  func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {
   249  	var netError net.Error
   250  	if err != nil && errors.As(err, &netError) {
   251  		t.forceClose(quicConn, err)
   252  	}
   253  }
   254  
   255  func (t *clientImpl) forceClose(quicConn quic.Connection, err error) {
   256  	t.connMutex.Lock()
   257  	defer t.connMutex.Unlock()
   258  	if quicConn == nil {
   259  		quicConn = t.quicConn
   260  	}
   261  	if quicConn != nil {
   262  		if quicConn == t.quicConn {
   263  			t.quicConn = nil
   264  		}
   265  	}
   266  	errStr := ""
   267  	if err != nil {
   268  		errStr = err.Error()
   269  	}
   270  	if quicConn != nil {
   271  		_ = quicConn.CloseWithError(ProtocolError, errStr)
   272  	}
   273  	udpInputMap := t.udpInputMap
   274  	udpInputMap.Range(func(key uint16, value net.Conn) bool {
   275  		conn := value
   276  		_ = conn.Close()
   277  		udpInputMap.Delete(key)
   278  		return true
   279  	})
   280  }
   281  
   282  func (t *clientImpl) Close() {
   283  	t.closed.Store(true)
   284  	if t.openStreams.Load() == 0 {
   285  		t.forceClose(nil, common.ClientClosed)
   286  	}
   287  }
   288  
   289  func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
   290  	quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  	openStreams := t.openStreams.Add(1)
   295  	if openStreams >= t.MaxOpenStreams {
   296  		t.openStreams.Add(-1)
   297  		return nil, common.TooManyOpenStreams
   298  	}
   299  	stream, err := func() (stream net.Conn, err error) {
   300  		defer func() {
   301  			t.deferQuicConn(quicConn, err)
   302  		}()
   303  		buf := pool.GetBuffer()
   304  		defer pool.PutBuffer(buf)
   305  		err = NewConnect(NewAddress(metadata)).WriteTo(buf)
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  		quicStream, err := quicConn.OpenStream()
   310  		if err != nil {
   311  			return nil, err
   312  		}
   313  		stream = common.NewQuicStreamConn(
   314  			quicStream,
   315  			quicConn.LocalAddr(),
   316  			quicConn.RemoteAddr(),
   317  			func() {
   318  				time.AfterFunc(C.DefaultTCPTimeout, func() {
   319  					openStreams := t.openStreams.Add(-1)
   320  					if openStreams == 0 && t.closed.Load() {
   321  						t.forceClose(quicConn, common.ClientClosed)
   322  					}
   323  				})
   324  			},
   325  		)
   326  		_, err = buf.WriteTo(stream)
   327  		if err != nil {
   328  			_ = stream.Close()
   329  			return nil, err
   330  		}
   331  		return stream, err
   332  	}()
   333  	if err != nil {
   334  		return nil, err
   335  	}
   336  
   337  	return stream, nil
   338  }
   339  
   340  func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
   341  	quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  	openStreams := t.openStreams.Add(1)
   346  	if openStreams >= t.MaxOpenStreams {
   347  		t.openStreams.Add(-1)
   348  		return nil, common.TooManyOpenStreams
   349  	}
   350  
   351  	pipe1, pipe2 := N.Pipe()
   352  	var connId uint16
   353  	for {
   354  		connId = uint16(fastrand.Intn(0xFFFF))
   355  		_, loaded := t.udpInputMap.LoadOrStore(connId, pipe1)
   356  		if !loaded {
   357  			break
   358  		}
   359  	}
   360  	pc := &quicStreamPacketConn{
   361  		connId:                connId,
   362  		quicConn:              quicConn,
   363  		inputConn:             N.NewBufferedConn(pipe2),
   364  		udpRelayMode:          t.UdpRelayMode,
   365  		maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize,
   366  		deferQuicConnFn:       t.deferQuicConn,
   367  		closeDeferFn: func() {
   368  			t.udpInputMap.Delete(connId)
   369  			time.AfterFunc(C.DefaultUDPTimeout, func() {
   370  				openStreams := t.openStreams.Add(-1)
   371  				if openStreams == 0 && t.closed.Load() {
   372  					t.forceClose(quicConn, common.ClientClosed)
   373  				}
   374  			})
   375  		},
   376  	}
   377  	return pc, nil
   378  }
   379  
// Client is the handle returned by NewClient. NewClient registers a finalizer
// on it that force-closes the underlying clientImpl.
type Client struct {
	// clientImpl is kept behind its own pointer so the finalizer on Client can
	// still run even when something holds a reference to the inner clientImpl.
	*clientImpl
}
   383  
   384  func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
   385  	conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn)
   386  	if err != nil {
   387  		return nil, err
   388  	}
   389  	return N.NewRefConn(conn, t), err
   390  }
   391  
   392  func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
   393  	pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
   394  	if err != nil {
   395  		return nil, err
   396  	}
   397  	return N.NewRefPacketConn(pc, t), nil
   398  }
   399  
// forceClose closes the client's current connection and all UDP input pipes
// immediately, reporting common.ClientClosed as the reason.
func (t *Client) forceClose() {
	t.clientImpl.forceClose(nil, common.ClientClosed)
}
   403  
   404  func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client {
   405  	ci := &clientImpl{
   406  		ClientOption: clientOption,
   407  		udp:          udp,
   408  		dialerRef:    dialerRef,
   409  		udpInputMap:  xsync.NewMapOf[uint16, net.Conn](),
   410  	}
   411  	c := &Client{ci}
   412  	runtime.SetFinalizer(c, closeClient)
   413  	log.Debugln("New TuicV5 Client at %p", c)
   414  	return c
   415  }
   416  
// closeClient is the finalizer NewClient registers on every Client: it logs the
// event and force-closes the client.
func closeClient(client *Client) {
	log.Debugln("Close TuicV5 Client at %p", client)
	client.forceClose()
}