github.com/kelleygo/clashcore@v1.0.2/transport/tuic/v4/client.go (about)

     1  package v4
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/tls"
     8  	"errors"
     9  	"net"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	atomic2 "github.com/kelleygo/clashcore/common/atomic"
    16  	N "github.com/kelleygo/clashcore/common/net"
    17  	"github.com/kelleygo/clashcore/common/pool"
    18  	C "github.com/kelleygo/clashcore/constant"
    19  	"github.com/kelleygo/clashcore/log"
    20  	"github.com/kelleygo/clashcore/transport/tuic/common"
    21  
    22  	"github.com/metacubex/quic-go"
    23  	"github.com/puzpuzpuz/xsync/v3"
    24  	"github.com/zhangyunhao116/fastrand"
    25  )
    26  
    27  type ClientOption struct {
    28  	TlsConfig             *tls.Config
    29  	QuicConfig            *quic.Config
    30  	Token                 [32]byte
    31  	UdpRelayMode          common.UdpRelayMode
    32  	CongestionController  string
    33  	ReduceRtt             bool
    34  	RequestTimeout        time.Duration
    35  	MaxUdpRelayPacketSize int
    36  	FastOpen              bool
    37  	MaxOpenStreams        int64
    38  	CWND                  int
    39  }
    40  
    41  type clientImpl struct {
    42  	*ClientOption
    43  	udp bool
    44  
    45  	quicConn  quic.Connection
    46  	connMutex sync.Mutex
    47  
    48  	openStreams atomic.Int64
    49  	closed      atomic.Bool
    50  
    51  	udpInputMap *xsync.MapOf[uint32, net.Conn]
    52  
    53  	// only ready for PoolClient
    54  	dialerRef   C.Dialer
    55  	lastVisited atomic2.TypedValue[time.Time]
    56  }
    57  
    58  func (t *clientImpl) OpenStreams() int64 {
    59  	return t.openStreams.Load()
    60  }
    61  
    62  func (t *clientImpl) DialerRef() C.Dialer {
    63  	return t.dialerRef
    64  }
    65  
    66  func (t *clientImpl) LastVisited() time.Time {
    67  	return t.lastVisited.Load()
    68  }
    69  
    70  func (t *clientImpl) SetLastVisited(last time.Time) {
    71  	t.lastVisited.Store(last)
    72  }
    73  
    74  func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) {
    75  	t.connMutex.Lock()
    76  	defer t.connMutex.Unlock()
    77  	if t.quicConn != nil {
    78  		return t.quicConn, nil
    79  	}
    80  	transport, addr, err := dialFn(ctx, dialer)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	var quicConn quic.Connection
    85  	if t.ReduceRtt {
    86  		quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig)
    87  	} else {
    88  		quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig)
    89  	}
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	common.SetCongestionController(quicConn, t.CongestionController, t.CWND)
    95  
    96  	go func() {
    97  		_ = t.sendAuthentication(quicConn)
    98  	}()
    99  
   100  	if t.udp {
   101  		go func() {
   102  			switch t.UdpRelayMode {
   103  			case common.QUIC:
   104  				_ = t.handleUniStream(quicConn)
   105  			default: // native
   106  				_ = t.handleMessage(quicConn)
   107  			}
   108  		}()
   109  	}
   110  
   111  	t.quicConn = quicConn
   112  	t.openStreams.Store(0)
   113  	return quicConn, nil
   114  }
   115  
   116  func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) {
   117  	defer func() {
   118  		t.deferQuicConn(quicConn, err)
   119  	}()
   120  	stream, err := quicConn.OpenUniStream()
   121  	if err != nil {
   122  		return err
   123  	}
   124  	buf := pool.GetBuffer()
   125  	defer pool.PutBuffer(buf)
   126  	err = NewAuthenticate(t.Token).WriteTo(buf)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	_, err = buf.WriteTo(stream)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	err = stream.Close()
   135  	if err != nil {
   136  		return
   137  	}
   138  	return nil
   139  }
   140  
   141  func (t *clientImpl) handleUniStream(quicConn quic.Connection) (err error) {
   142  	defer func() {
   143  		t.deferQuicConn(quicConn, err)
   144  	}()
   145  	for {
   146  		var stream quic.ReceiveStream
   147  		stream, err = quicConn.AcceptUniStream(context.Background())
   148  		if err != nil {
   149  			return err
   150  		}
   151  		go func() (err error) {
   152  			var assocId uint32
   153  			defer func() {
   154  				t.deferQuicConn(quicConn, err)
   155  				if err != nil && assocId != 0 {
   156  					if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
   157  						if conn, ok := val.(net.Conn); ok {
   158  							_ = conn.Close()
   159  						}
   160  					}
   161  				}
   162  				stream.CancelRead(0)
   163  			}()
   164  			reader := bufio.NewReader(stream)
   165  			commandHead, err := ReadCommandHead(reader)
   166  			if err != nil {
   167  				return
   168  			}
   169  			switch commandHead.TYPE {
   170  			case PacketType:
   171  				var packet Packet
   172  				packet, err = ReadPacketWithHead(commandHead, reader)
   173  				if err != nil {
   174  					return
   175  				}
   176  				if t.udp && t.UdpRelayMode == common.QUIC {
   177  					assocId = packet.ASSOC_ID
   178  					if val, ok := t.udpInputMap.Load(assocId); ok {
   179  						if conn, ok := val.(net.Conn); ok {
   180  							writer := bufio.NewWriterSize(conn, packet.BytesLen())
   181  							_ = packet.WriteTo(writer)
   182  							_ = writer.Flush()
   183  						}
   184  					}
   185  				}
   186  			}
   187  			return
   188  		}()
   189  	}
   190  }
   191  
   192  func (t *clientImpl) handleMessage(quicConn quic.Connection) (err error) {
   193  	defer func() {
   194  		t.deferQuicConn(quicConn, err)
   195  	}()
   196  	for {
   197  		var message []byte
   198  		message, err = quicConn.ReceiveDatagram(context.Background())
   199  		if err != nil {
   200  			return err
   201  		}
   202  		go func() (err error) {
   203  			var assocId uint32
   204  			defer func() {
   205  				t.deferQuicConn(quicConn, err)
   206  				if err != nil && assocId != 0 {
   207  					if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok {
   208  						if conn, ok := val.(net.Conn); ok {
   209  							_ = conn.Close()
   210  						}
   211  					}
   212  				}
   213  			}()
   214  			reader := bytes.NewBuffer(message)
   215  			commandHead, err := ReadCommandHead(reader)
   216  			if err != nil {
   217  				return
   218  			}
   219  			switch commandHead.TYPE {
   220  			case PacketType:
   221  				var packet Packet
   222  				packet, err = ReadPacketWithHead(commandHead, reader)
   223  				if err != nil {
   224  					return
   225  				}
   226  				if t.udp && t.UdpRelayMode == common.NATIVE {
   227  					assocId = packet.ASSOC_ID
   228  					if val, ok := t.udpInputMap.Load(assocId); ok {
   229  						if conn, ok := val.(net.Conn); ok {
   230  							_, _ = conn.Write(message)
   231  						}
   232  					}
   233  				}
   234  			}
   235  			return
   236  		}()
   237  	}
   238  }
   239  
   240  func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) {
   241  	var netError net.Error
   242  	if err != nil && errors.As(err, &netError) {
   243  		t.forceClose(quicConn, err)
   244  	}
   245  }
   246  
   247  func (t *clientImpl) forceClose(quicConn quic.Connection, err error) {
   248  	t.connMutex.Lock()
   249  	defer t.connMutex.Unlock()
   250  	if quicConn == nil {
   251  		quicConn = t.quicConn
   252  	}
   253  	if quicConn != nil {
   254  		if quicConn == t.quicConn {
   255  			t.quicConn = nil
   256  		}
   257  	}
   258  	errStr := ""
   259  	if err != nil {
   260  		errStr = err.Error()
   261  	}
   262  	if quicConn != nil {
   263  		_ = quicConn.CloseWithError(ProtocolError, errStr)
   264  	}
   265  	udpInputMap := t.udpInputMap
   266  	udpInputMap.Range(func(key uint32, value net.Conn) bool {
   267  		conn := value
   268  		_ = conn.Close()
   269  		udpInputMap.Delete(key)
   270  		return true
   271  	})
   272  }
   273  
   274  func (t *clientImpl) Close() {
   275  	t.closed.Store(true)
   276  	if t.openStreams.Load() == 0 {
   277  		t.forceClose(nil, common.ClientClosed)
   278  	}
   279  }
   280  
   281  func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
   282  	quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	openStreams := t.openStreams.Add(1)
   287  	if openStreams >= t.MaxOpenStreams {
   288  		t.openStreams.Add(-1)
   289  		return nil, common.TooManyOpenStreams
   290  	}
   291  	stream, err := func() (stream net.Conn, err error) {
   292  		defer func() {
   293  			t.deferQuicConn(quicConn, err)
   294  		}()
   295  		buf := pool.GetBuffer()
   296  		defer pool.PutBuffer(buf)
   297  		err = NewConnect(NewAddress(metadata)).WriteTo(buf)
   298  		if err != nil {
   299  			return nil, err
   300  		}
   301  		quicStream, err := quicConn.OpenStream()
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  		stream = common.NewQuicStreamConn(
   306  			quicStream,
   307  			quicConn.LocalAddr(),
   308  			quicConn.RemoteAddr(),
   309  			func() {
   310  				time.AfterFunc(C.DefaultTCPTimeout, func() {
   311  					openStreams := t.openStreams.Add(-1)
   312  					if openStreams == 0 && t.closed.Load() {
   313  						t.forceClose(quicConn, common.ClientClosed)
   314  					}
   315  				})
   316  			},
   317  		)
   318  		_, err = buf.WriteTo(stream)
   319  		if err != nil {
   320  			_ = stream.Close()
   321  			return nil, err
   322  		}
   323  		return stream, err
   324  	}()
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	bufConn := N.NewBufferedConn(stream)
   330  	response := func() error {
   331  		if t.RequestTimeout > 0 {
   332  			_ = bufConn.SetReadDeadline(time.Now().Add(t.RequestTimeout))
   333  		}
   334  		response, err := ReadResponse(bufConn)
   335  		if err != nil {
   336  			_ = bufConn.Close()
   337  			return err
   338  		}
   339  		if response.IsFailed() {
   340  			_ = bufConn.Close()
   341  			return errors.New("connect failed")
   342  		}
   343  		_ = bufConn.SetReadDeadline(time.Time{})
   344  		return nil
   345  	}
   346  	if t.FastOpen {
   347  		return N.NewEarlyConn(bufConn, response), nil
   348  	}
   349  	err = response()
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  	return bufConn, nil
   354  }
   355  
   356  func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
   357  	quicConn, err := t.getQuicConn(ctx, dialer, dialFn)
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  	openStreams := t.openStreams.Add(1)
   362  	if openStreams >= t.MaxOpenStreams {
   363  		t.openStreams.Add(-1)
   364  		return nil, common.TooManyOpenStreams
   365  	}
   366  
   367  	pipe1, pipe2 := N.Pipe()
   368  	var connId uint32
   369  	for {
   370  		connId = fastrand.Uint32()
   371  		_, loaded := t.udpInputMap.LoadOrStore(connId, pipe1)
   372  		if !loaded {
   373  			break
   374  		}
   375  	}
   376  	pc := &quicStreamPacketConn{
   377  		connId:                connId,
   378  		quicConn:              quicConn,
   379  		inputConn:             N.NewBufferedConn(pipe2),
   380  		udpRelayMode:          t.UdpRelayMode,
   381  		maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize,
   382  		deferQuicConnFn:       t.deferQuicConn,
   383  		closeDeferFn: func() {
   384  			t.udpInputMap.Delete(connId)
   385  			time.AfterFunc(C.DefaultUDPTimeout, func() {
   386  				openStreams := t.openStreams.Add(-1)
   387  				if openStreams == 0 && t.closed.Load() {
   388  					t.forceClose(quicConn, common.ClientClosed)
   389  				}
   390  			})
   391  		},
   392  	}
   393  	return pc, nil
   394  }
   395  
   396  type Client struct {
   397  	*clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner
   398  }
   399  
   400  func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) {
   401  	conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn)
   402  	if err != nil {
   403  		return nil, err
   404  	}
   405  	return N.NewRefConn(conn, t), err
   406  }
   407  
   408  func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) {
   409  	pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn)
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  	return N.NewRefPacketConn(pc, t), nil
   414  }
   415  
   416  func (t *Client) forceClose() {
   417  	t.clientImpl.forceClose(nil, common.ClientClosed)
   418  }
   419  
   420  func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client {
   421  	ci := &clientImpl{
   422  		ClientOption: clientOption,
   423  		udp:          udp,
   424  		dialerRef:    dialerRef,
   425  		udpInputMap:  xsync.NewMapOf[uint32, net.Conn](),
   426  	}
   427  	c := &Client{ci}
   428  	runtime.SetFinalizer(c, closeClient)
   429  	log.Debugln("New TuicV4 Client at %p", c)
   430  	return c
   431  }
   432  
   433  func closeClient(client *Client) {
   434  	log.Debugln("Close TuicV4 Client at %p", client)
   435  	client.forceClose()
   436  }