github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/tuic/client.go (about)

     1  //go:build with_quic
     2  
     3  package tuic
     4  
     5  import (
     6  	"context"
     7  	"github.com/sagernet/quic-go"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"runtime"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/inazumav/sing-box/common/baderror"
    16  	"github.com/inazumav/sing-box/common/qtls"
    17  	"github.com/inazumav/sing-box/common/tls"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/buf"
    20  	"github.com/sagernet/sing/common/bufio"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  
    25  	"github.com/gofrs/uuid/v5"
    26  )
    27  
// ClientOptions configures a TUIC client created by NewClient.
//
// Field notes (as used in this file):
//   - Context is stored on the Client and passed to setCongestion for every new connection.
//   - Dialer opens the "udp" connection to ServerAddress that carries QUIC.
//   - TLSConfig is handed to qtls.Dial / qtls.DialEarly.
//   - UUID is sent in the authenticate command and is also the label for
//     exporting the auth token from the TLS session; Password is the context
//     for that export.
//   - CongestionControl must be "", "cubic", "new_reno" or "bbr"; "" selects "cubic".
//   - UDPStream is passed to newUDPPacketConn and, when true, makes every new
//     connection also run loopUniStreams.
//   - ZeroRTTHandshake selects qtls.DialEarly instead of qtls.Dial.
//   - Heartbeat is the interval between heartbeat datagrams; zero selects 10s.
type ClientOptions struct {
	Context           context.Context
	Dialer            N.Dialer
	ServerAddress     M.Socksaddr
	TLSConfig         tls.Config
	UUID              uuid.UUID
	Password          string
	CongestionControl string
	UDPStream         bool
	ZeroRTTHandshake  bool
	Heartbeat         time.Duration
}
    40  
// Client is a TUIC client. Stream relays (DialConn) and UDP sessions
// (ListenPacket) are multiplexed over one QUIC connection, which is dialed on
// first use and replaced by a new one once it is no longer active.
//
// The configuration fields are copied from ClientOptions by NewClient and are
// not modified afterwards. connAccess serializes creation and replacement of
// conn (see offer / offerNew).
type Client struct {
	ctx               context.Context
	dialer            N.Dialer
	serverAddr        M.Socksaddr
	tlsConfig         tls.Config
	quicConfig        *quic.Config
	uuid              uuid.UUID
	password          string
	congestionControl string
	udpStream         bool
	zeroRTTHandshake  bool
	heartbeat         time.Duration

	connAccess sync.RWMutex
	conn       *clientQUICConnection
}
    57  
    58  func NewClient(options ClientOptions) (*Client, error) {
    59  	if options.Heartbeat == 0 {
    60  		options.Heartbeat = 10 * time.Second
    61  	}
    62  	quicConfig := &quic.Config{
    63  		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
    64  		MaxDatagramFrameSize:    1400,
    65  		EnableDatagrams:         true,
    66  		MaxIncomingUniStreams:   1 << 60,
    67  	}
    68  	switch options.CongestionControl {
    69  	case "":
    70  		options.CongestionControl = "cubic"
    71  	case "cubic", "new_reno", "bbr":
    72  	default:
    73  		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
    74  	}
    75  	return &Client{
    76  		ctx:               options.Context,
    77  		dialer:            options.Dialer,
    78  		serverAddr:        options.ServerAddress,
    79  		tlsConfig:         options.TLSConfig,
    80  		quicConfig:        quicConfig,
    81  		uuid:              options.UUID,
    82  		password:          options.Password,
    83  		congestionControl: options.CongestionControl,
    84  		udpStream:         options.UDPStream,
    85  		zeroRTTHandshake:  options.ZeroRTTHandshake,
    86  		heartbeat:         options.Heartbeat,
    87  	}, nil
    88  }
    89  
    90  func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
    91  	conn := c.conn
    92  	if conn != nil && conn.active() {
    93  		return conn, nil
    94  	}
    95  	c.connAccess.Lock()
    96  	defer c.connAccess.Unlock()
    97  	conn = c.conn
    98  	if conn != nil && conn.active() {
    99  		return conn, nil
   100  	}
   101  	conn, err := c.offerNew(ctx)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	return conn, nil
   106  }
   107  
   108  func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
   109  	udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	var quicConn quic.Connection
   114  	if c.zeroRTTHandshake {
   115  		quicConn, err = qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
   116  	} else {
   117  		quicConn, err = qtls.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig)
   118  	}
   119  	if err != nil {
   120  		udpConn.Close()
   121  		return nil, E.Cause(err, "open connection")
   122  	}
   123  	setCongestion(c.ctx, quicConn, c.congestionControl)
   124  	conn := &clientQUICConnection{
   125  		quicConn:   quicConn,
   126  		rawConn:    udpConn,
   127  		connDone:   make(chan struct{}),
   128  		udpConnMap: make(map[uint16]*udpPacketConn),
   129  	}
   130  	go func() {
   131  		hErr := c.clientHandshake(quicConn)
   132  		if hErr != nil {
   133  			conn.closeWithError(hErr)
   134  		}
   135  	}()
   136  	if c.udpStream {
   137  		go c.loopUniStreams(conn)
   138  	}
   139  	go c.loopMessages(conn)
   140  	go c.loopHeartbeats(conn)
   141  	c.conn = conn
   142  	return conn, nil
   143  }
   144  
   145  func (c *Client) clientHandshake(conn quic.Connection) error {
   146  	authStream, err := conn.OpenUniStream()
   147  	if err != nil {
   148  		return E.Cause(err, "open handshake stream")
   149  	}
   150  	defer authStream.Close()
   151  	handshakeState := conn.ConnectionState()
   152  	tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32)
   153  	if err != nil {
   154  		return E.Cause(err, "export keying material")
   155  	}
   156  	authRequest := buf.NewSize(AuthenticateLen)
   157  	authRequest.WriteByte(Version)
   158  	authRequest.WriteByte(CommandAuthenticate)
   159  	authRequest.Write(c.uuid[:])
   160  	authRequest.Write(tuicAuthToken)
   161  	return common.Error(authStream.Write(authRequest.Bytes()))
   162  }
   163  
   164  func (c *Client) loopHeartbeats(conn *clientQUICConnection) {
   165  	ticker := time.NewTicker(c.heartbeat)
   166  	defer ticker.Stop()
   167  	for {
   168  		select {
   169  		case <-conn.connDone:
   170  			return
   171  		case <-ticker.C:
   172  			err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
   173  			if err != nil {
   174  				conn.closeWithError(E.Cause(err, "send heartbeat"))
   175  			}
   176  		}
   177  	}
   178  }
   179  
   180  func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
   181  	conn, err := c.offer(ctx)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	stream, err := conn.quicConn.OpenStream()
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	return &clientConn{
   190  		parent:      conn,
   191  		stream:      stream,
   192  		destination: destination,
   193  	}, nil
   194  }
   195  
   196  func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) {
   197  	conn, err := c.offer(ctx)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	var sessionID uint16
   202  	clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() {
   203  		conn.udpAccess.Lock()
   204  		delete(conn.udpConnMap, sessionID)
   205  		conn.udpAccess.Unlock()
   206  	})
   207  	conn.udpAccess.Lock()
   208  	sessionID = conn.udpSessionID
   209  	conn.udpSessionID++
   210  	conn.udpConnMap[sessionID] = clientPacketConn
   211  	conn.udpAccess.Unlock()
   212  	clientPacketConn.sessionID = sessionID
   213  	return clientPacketConn, nil
   214  }
   215  
   216  func (c *Client) CloseWithError(err error) error {
   217  	conn := c.conn
   218  	if conn != nil {
   219  		conn.closeWithError(err)
   220  	}
   221  	return nil
   222  }
   223  
// clientQUICConnection is one QUIC connection to the server together with its
// bookkeeping:
//   - rawConn is the dialed UDP socket, closed along with quicConn.
//   - closeOnce, connDone and connErr implement a one-shot shutdown
//     (see closeWithError). connDone is closed on shutdown.
//   - udpConnMap maps UDP session IDs to their packet conns. udpSessionID is
//     the next ID to hand out. Both are guarded by udpAccess.
type clientQUICConnection struct {
	quicConn     quic.Connection
	rawConn      io.Closer
	closeOnce    sync.Once
	connDone     chan struct{}
	connErr      error
	udpAccess    sync.RWMutex
	udpConnMap   map[uint16]*udpPacketConn
	udpSessionID uint16
}
   234  
   235  func (c *clientQUICConnection) active() bool {
   236  	select {
   237  	case <-c.quicConn.Context().Done():
   238  		return false
   239  	default:
   240  	}
   241  	select {
   242  	case <-c.connDone:
   243  		return false
   244  	default:
   245  	}
   246  	return true
   247  }
   248  
   249  func (c *clientQUICConnection) closeWithError(err error) {
   250  	c.closeOnce.Do(func() {
   251  		c.connErr = err
   252  		close(c.connDone)
   253  		_ = c.quicConn.CloseWithError(0, "")
   254  		_ = c.rawConn.Close()
   255  	})
   256  }
   257  
// clientConn is the net.Conn returned by DialConn and is backed by a single
// QUIC stream. The connect request (Version, CommandConnect, destination) is
// sent lazily, prefixed to the data of the first Write. requestWritten
// records whether that has happened.
type clientConn struct {
	parent         *clientQUICConnection
	stream         quic.Stream
	destination    M.Socksaddr
	requestWritten bool
}
   264  
   265  func (c *clientConn) NeedHandshake() bool {
   266  	return !c.requestWritten
   267  }
   268  
   269  func (c *clientConn) Read(b []byte) (n int, err error) {
   270  	n, err = c.stream.Read(b)
   271  	return n, baderror.WrapQUIC(err)
   272  }
   273  
   274  func (c *clientConn) Write(b []byte) (n int, err error) {
   275  	if !c.requestWritten {
   276  		request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b))
   277  		defer request.Release()
   278  		request.WriteByte(Version)
   279  		request.WriteByte(CommandConnect)
   280  		err = addressSerializer.WriteAddrPort(request, c.destination)
   281  		if err != nil {
   282  			return
   283  		}
   284  		request.Write(b)
   285  		_, err = c.stream.Write(request.Bytes())
   286  		if err != nil {
   287  			c.parent.closeWithError(E.Cause(err, "create new connection"))
   288  			return 0, baderror.WrapQUIC(err)
   289  		}
   290  		c.requestWritten = true
   291  		return len(b), nil
   292  	}
   293  	n, err = c.stream.Write(b)
   294  	return n, baderror.WrapQUIC(err)
   295  }
   296  
   297  func (c *clientConn) Close() error {
   298  	stream := c.stream
   299  	if stream == nil {
   300  		return nil
   301  	}
   302  	stream.CancelRead(0)
   303  	return stream.Close()
   304  }
   305  
   306  func (c *clientConn) LocalAddr() net.Addr {
   307  	return M.Socksaddr{}
   308  }
   309  
   310  func (c *clientConn) RemoteAddr() net.Addr {
   311  	return c.destination
   312  }
   313  
   314  func (c *clientConn) SetDeadline(t time.Time) error {
   315  	if c.stream == nil {
   316  		return os.ErrInvalid
   317  	}
   318  	return c.stream.SetDeadline(t)
   319  }
   320  
   321  func (c *clientConn) SetReadDeadline(t time.Time) error {
   322  	if c.stream == nil {
   323  		return os.ErrInvalid
   324  	}
   325  	return c.stream.SetReadDeadline(t)
   326  }
   327  
   328  func (c *clientConn) SetWriteDeadline(t time.Time) error {
   329  	if c.stream == nil {
   330  		return os.ErrInvalid
   331  	}
   332  	return c.stream.SetWriteDeadline(t)
   333  }