github.com/yaling888/clash@v1.53.0/transport/quic/quic.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/quic-go/quic-go"
    14  
    15  	"github.com/yaling888/clash/common/pool"
    16  	"github.com/yaling888/clash/component/resolver"
    17  	C "github.com/yaling888/clash/constant"
    18  	"github.com/yaling888/clash/transport/crypto"
    19  	"github.com/yaling888/clash/transport/header"
    20  )
    21  
    22  var defaultALPN = []string{"h3", "h3-29", "h3-Q050", "h3-Q046", "h3-Q043", "hq-interop", "quic"}
    23  
    24  type Config struct {
    25  	Header         string
    26  	AEAD           *crypto.AEAD
    27  	Host           string
    28  	Port           int
    29  	ALPN           []string
    30  	ServerName     string
    31  	SkipCertVerify bool
    32  }
    33  
    34  var _ net.PacketConn = (*rawConn)(nil)
    35  
    36  type rawConn struct {
    37  	net.PacketConn
    38  	header header.Header
    39  	cipher *crypto.AEAD
    40  }
    41  
    42  func (rc *rawConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    43  	obfs := rc.header
    44  	cipher := rc.cipher
    45  	if obfs == nil && cipher == nil {
    46  		return rc.PacketConn.ReadFrom(p)
    47  	}
    48  
    49  	bufP := pool.GetBufferWriter()
    50  	defer pool.PutBufferWriter(bufP)
    51  
    52  	offset := 0
    53  	if obfs != nil {
    54  		offset = obfs.Size()
    55  	}
    56  
    57  	bufP.Grow(offset + len(p))
    58  	if cipher != nil {
    59  		bufP.Grow(cipher.NonceSize() + cipher.Overhead())
    60  	}
    61  
    62  	for {
    63  		n, addr, err = rc.PacketConn.ReadFrom(*bufP)
    64  		if n <= offset {
    65  			if err != nil {
    66  				return
    67  			}
    68  			continue
    69  		}
    70  
    71  		if cipher == nil {
    72  			nr := n - offset
    73  			n = copy(p, bufP.Bytes()[offset:n])
    74  			if n < nr && err == nil {
    75  				err = io.ErrShortBuffer
    76  			}
    77  			return
    78  		}
    79  
    80  		b, er := cipher.Decrypt(bufP.Bytes()[offset:n])
    81  		if er != nil {
    82  			if err != nil {
    83  				return
    84  			}
    85  			continue
    86  		}
    87  
    88  		n = copy(p, b)
    89  		if n < len(b) {
    90  			err = io.ErrShortBuffer
    91  		}
    92  		return
    93  	}
    94  }
    95  
    96  func (rc *rawConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    97  	obfs := rc.header
    98  	cipher := rc.cipher
    99  	if obfs == nil && cipher == nil {
   100  		return rc.PacketConn.WriteTo(p, addr)
   101  	}
   102  
   103  	bufP := pool.GetBufferWriter()
   104  	defer pool.PutBufferWriter(bufP)
   105  
   106  	if obfs != nil {
   107  		bufP.Grow(obfs.Size())
   108  		obfs.Fill(bufP.Bytes())
   109  	}
   110  
   111  	if cipher != nil {
   112  		_, err = cipher.Encrypt(bufP, p)
   113  		if err != nil {
   114  			return
   115  		}
   116  	} else {
   117  		bufP.PutSlice(p)
   118  	}
   119  
   120  	lenP := len(p)
   121  	delta := bufP.Len() - lenP
   122  	nw, err := rc.PacketConn.WriteTo(bufP.Bytes(), addr)
   123  	n = max(nw-delta, 0)
   124  	if n < lenP && err == nil {
   125  		err = io.ErrShortWrite
   126  	}
   127  	return
   128  }
   129  
   130  func (rc *rawConn) Close() error {
   131  	rc.header = nil
   132  	rc.cipher = nil
   133  	return rc.PacketConn.Close()
   134  }
   135  
   136  func (rc *rawConn) SyscallConn() (syscall.RawConn, error) {
   137  	if c, ok := rc.PacketConn.(*net.UDPConn); ok {
   138  		return c.SyscallConn()
   139  	}
   140  	return nil, syscall.EINVAL
   141  }
   142  
   143  var _ net.Conn = (*quicConn)(nil)
   144  
   145  type quicConn struct {
   146  	conn      quic.Connection
   147  	stream    quic.Stream
   148  	transport *quic.Transport
   149  }
   150  
   151  func (qc *quicConn) Read(b []byte) (n int, err error) {
   152  	return qc.stream.Read(b)
   153  }
   154  
   155  func (qc *quicConn) Write(b []byte) (n int, err error) {
   156  	return qc.stream.Write(b)
   157  }
   158  
   159  func (qc *quicConn) Close() error {
   160  	_ = qc.stream.Close()
   161  	_ = qc.conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
   162  
   163  	if err := qc.transport.Close(); err != nil {
   164  		return err
   165  	}
   166  	if err := qc.transport.Conn.Close(); err != nil {
   167  		return err
   168  	}
   169  	return nil
   170  }
   171  
   172  func (qc *quicConn) LocalAddr() net.Addr {
   173  	return qc.conn.LocalAddr()
   174  }
   175  
   176  func (qc *quicConn) RemoteAddr() net.Addr {
   177  	return qc.conn.RemoteAddr()
   178  }
   179  
   180  func (qc *quicConn) SetDeadline(t time.Time) error {
   181  	return qc.stream.SetDeadline(t)
   182  }
   183  
   184  func (qc *quicConn) SetReadDeadline(t time.Time) error {
   185  	return qc.stream.SetReadDeadline(t)
   186  }
   187  
   188  func (qc *quicConn) SetWriteDeadline(t time.Time) error {
   189  	return qc.stream.SetWriteDeadline(t)
   190  }
   191  
   192  func StreamQUICConn(conn net.Conn, cfg *Config) (net.Conn, error) {
   193  	pc, ok := conn.(net.PacketConn)
   194  	if !ok {
   195  		return nil, errors.New("conn is not a net.PacketConn")
   196  	}
   197  
   198  	hd, err := header.New(cfg.Header)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  
   203  	ip, err := resolver.ResolveProxyServerHost(cfg.Host)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	alpn := defaultALPN
   209  	if len(cfg.ALPN) != 0 {
   210  		alpn = cfg.ALPN
   211  	}
   212  
   213  	serverName := cfg.Host
   214  	if cfg.ServerName != "" {
   215  		serverName = cfg.ServerName
   216  	}
   217  
   218  	tlsConfig := &tls.Config{
   219  		NextProtos:         alpn,
   220  		ServerName:         serverName,
   221  		InsecureSkipVerify: cfg.SkipCertVerify,
   222  		MinVersion:         tls.VersionTLS13,
   223  	}
   224  
   225  	quicConfig := &quic.Config{
   226  		// Allow0RTT:               true,
   227  		// EnableDatagrams:         true,
   228  		// DisablePathMTUDiscovery: true,
   229  		MaxIdleTimeout:       60 * time.Second,
   230  		KeepAlivePeriod:      15 * time.Second,
   231  		HandshakeIdleTimeout: C.DefaultTLSTimeout,
   232  	}
   233  
   234  	var rConn net.PacketConn
   235  	if cfg.AEAD == nil && hd == nil {
   236  		rConn = pc
   237  	} else {
   238  		rConn = &rawConn{
   239  			PacketConn: pc,
   240  			header:     hd,
   241  			cipher:     cfg.AEAD,
   242  		}
   243  	}
   244  
   245  	transport := &quic.Transport{
   246  		Conn:               rConn,
   247  		ConnectionIDLength: 12,
   248  	}
   249  
   250  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
   251  	defer cancel()
   252  
   253  	qConn, err := transport.Dial(ctx, &net.UDPAddr{IP: ip.AsSlice(), Port: cfg.Port}, tlsConfig, quicConfig)
   254  	if err != nil {
   255  		return nil, fmt.Errorf("quic dial -> %s:%d error: %w", ip, cfg.Port, err)
   256  	}
   257  
   258  	stream, err := qConn.OpenStream()
   259  	if err != nil {
   260  		return nil, fmt.Errorf("quic open stream -> %s:%d error: %w", ip, cfg.Port, err)
   261  	}
   262  
   263  	return &quicConn{
   264  		conn:      qConn,
   265  		stream:    stream,
   266  		transport: transport,
   267  	}, nil
   268  }