github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/protocol/quic/sniff.go (about)

     1  package quic
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/aes"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"io"
     9  
    10  	"github.com/quic-go/quic-go/quicvarint"
    11  	"github.com/xtls/xray-core/common"
    12  	"github.com/xtls/xray-core/common/buf"
    13  	"github.com/xtls/xray-core/common/bytespool"
    14  	"github.com/xtls/xray-core/common/errors"
    15  	ptls "github.com/xtls/xray-core/common/protocol/tls"
    16  	"golang.org/x/crypto/hkdf"
    17  )
    18  
    19  type SniffHeader struct {
    20  	domain string
    21  }
    22  
    23  func (s SniffHeader) Protocol() string {
    24  	return "quic"
    25  }
    26  
    27  func (s SniffHeader) Domain() string {
    28  	return s.domain
    29  }
    30  
    31  const (
    32  	versionDraft29 uint32 = 0xff00001d
    33  	version1       uint32 = 0x1
    34  )
    35  
    36  var (
    37  	quicSaltOld  = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
    38  	quicSalt     = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
    39  	initialSuite = &CipherSuiteTLS13{
    40  		ID:     tls.TLS_AES_128_GCM_SHA256,
    41  		KeyLen: 16,
    42  		AEAD:   AEADAESGCMTLS13,
    43  		Hash:   crypto.SHA256,
    44  	}
    45  	errNotQuic        = errors.New("not quic")
    46  	errNotQuicInitial = errors.New("not initial packet")
    47  )
    48  
    49  func SniffQUIC(b []byte) (*SniffHeader, error) {
    50  	buffer := buf.FromBytes(b)
    51  	typeByte, err := buffer.ReadByte()
    52  	if err != nil {
    53  		return nil, errNotQuic
    54  	}
    55  	isLongHeader := typeByte&0x80 > 0
    56  	if !isLongHeader || typeByte&0x40 == 0 {
    57  		return nil, errNotQuicInitial
    58  	}
    59  
    60  	vb, err := buffer.ReadBytes(4)
    61  	if err != nil {
    62  		return nil, errNotQuic
    63  	}
    64  
    65  	versionNumber := binary.BigEndian.Uint32(vb)
    66  
    67  	if versionNumber != 0 && typeByte&0x40 == 0 {
    68  		return nil, errNotQuic
    69  	} else if versionNumber != versionDraft29 && versionNumber != version1 {
    70  		return nil, errNotQuic
    71  	}
    72  
    73  	if (typeByte&0x30)>>4 != 0x0 {
    74  		return nil, errNotQuicInitial
    75  	}
    76  
    77  	var destConnID []byte
    78  	if l, err := buffer.ReadByte(); err != nil {
    79  		return nil, errNotQuic
    80  	} else if destConnID, err = buffer.ReadBytes(int32(l)); err != nil {
    81  		return nil, errNotQuic
    82  	}
    83  
    84  	if l, err := buffer.ReadByte(); err != nil {
    85  		return nil, errNotQuic
    86  	} else if common.Error2(buffer.ReadBytes(int32(l))) != nil {
    87  		return nil, errNotQuic
    88  	}
    89  
    90  	tokenLen, err := quicvarint.Read(buffer)
    91  	if err != nil || tokenLen > uint64(len(b)) {
    92  		return nil, errNotQuic
    93  	}
    94  
    95  	if _, err = buffer.ReadBytes(int32(tokenLen)); err != nil {
    96  		return nil, errNotQuic
    97  	}
    98  
    99  	packetLen, err := quicvarint.Read(buffer)
   100  	if err != nil {
   101  		return nil, errNotQuic
   102  	}
   103  
   104  	hdrLen := len(b) - int(buffer.Len())
   105  
   106  	origPNBytes := make([]byte, 4)
   107  	copy(origPNBytes, b[hdrLen:hdrLen+4])
   108  
   109  	var salt []byte
   110  	if versionNumber == version1 {
   111  		salt = quicSalt
   112  	} else {
   113  		salt = quicSaltOld
   114  	}
   115  	initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
   116  	secret := hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
   117  	hpKey := hkdfExpandLabel(initialSuite.Hash, secret, []byte{}, "quic hp", initialSuite.KeyLen)
   118  	block, err := aes.NewCipher(hpKey)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	cache := buf.New()
   124  	defer cache.Release()
   125  
   126  	mask := cache.Extend(int32(block.BlockSize()))
   127  	block.Encrypt(mask, b[hdrLen+4:hdrLen+4+16])
   128  	b[0] ^= mask[0] & 0xf
   129  	for i := range b[hdrLen : hdrLen+4] {
   130  		b[hdrLen+i] ^= mask[i+1]
   131  	}
   132  	packetNumberLength := b[0]&0x3 + 1
   133  	if packetNumberLength != 1 {
   134  		return nil, errNotQuicInitial
   135  	}
   136  	var packetNumber uint32
   137  	{
   138  		n, err := buffer.ReadByte()
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		packetNumber = uint32(n)
   143  	}
   144  
   145  	if packetNumber != 0 && packetNumber != 1 {
   146  		return nil, errNotQuicInitial
   147  	}
   148  
   149  	extHdrLen := hdrLen + int(packetNumberLength)
   150  	copy(b[extHdrLen:hdrLen+4], origPNBytes[packetNumberLength:])
   151  	data := b[extHdrLen : int(packetLen)+hdrLen]
   152  
   153  	key := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
   154  	iv := hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
   155  	cipher := AEADAESGCMTLS13(key, iv)
   156  	nonce := cache.Extend(int32(cipher.NonceSize()))
   157  	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
   158  	decrypted, err := cipher.Open(b[extHdrLen:extHdrLen], nonce, data, b[:extHdrLen])
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	buffer = buf.FromBytes(decrypted)
   163  
   164  	cryptoLen := uint(0)
   165  	cryptoData := bytespool.Alloc(buffer.Len())
   166  	defer bytespool.Free(cryptoData)
   167  	for i := 0; !buffer.IsEmpty(); i++ {
   168  		frameType := byte(0x0) // Default to PADDING frame
   169  		for frameType == 0x0 && !buffer.IsEmpty() {
   170  			frameType, _ = buffer.ReadByte()
   171  		}
   172  		switch frameType {
   173  		case 0x00: // PADDING frame
   174  		case 0x01: // PING frame
   175  		case 0x02, 0x03: // ACK frame
   176  			if _, err = quicvarint.Read(buffer); err != nil { // Field: Largest Acknowledged
   177  				return nil, io.ErrUnexpectedEOF
   178  			}
   179  			if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Delay
   180  				return nil, io.ErrUnexpectedEOF
   181  			}
   182  			ackRangeCount, err := quicvarint.Read(buffer) // Field: ACK Range Count
   183  			if err != nil {
   184  				return nil, io.ErrUnexpectedEOF
   185  			}
   186  			if _, err = quicvarint.Read(buffer); err != nil { // Field: First ACK Range
   187  				return nil, io.ErrUnexpectedEOF
   188  			}
   189  			for i := 0; i < int(ackRangeCount); i++ { // Field: ACK Range
   190  				if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> Gap
   191  					return nil, io.ErrUnexpectedEOF
   192  				}
   193  				if _, err = quicvarint.Read(buffer); err != nil { // Field: ACK Range -> ACK Range Length
   194  					return nil, io.ErrUnexpectedEOF
   195  				}
   196  			}
   197  			if frameType == 0x03 {
   198  				if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT0 Count
   199  					return nil, io.ErrUnexpectedEOF
   200  				}
   201  				if _, err = quicvarint.Read(buffer); err != nil { // Field: ECN Counts -> ECT1 Count
   202  					return nil, io.ErrUnexpectedEOF
   203  				}
   204  				if _, err = quicvarint.Read(buffer); err != nil { //nolint:misspell // Field: ECN Counts -> ECT-CE Count
   205  					return nil, io.ErrUnexpectedEOF
   206  				}
   207  			}
   208  		case 0x06: // CRYPTO frame, we will use this frame
   209  			offset, err := quicvarint.Read(buffer) // Field: Offset
   210  			if err != nil {
   211  				return nil, io.ErrUnexpectedEOF
   212  			}
   213  			length, err := quicvarint.Read(buffer) // Field: Length
   214  			if err != nil || length > uint64(buffer.Len()) {
   215  				return nil, io.ErrUnexpectedEOF
   216  			}
   217  			if cryptoLen < uint(offset+length) {
   218  				cryptoLen = uint(offset + length)
   219  			}
   220  			if _, err := buffer.Read(cryptoData[offset : offset+length]); err != nil { // Field: Crypto Data
   221  				return nil, io.ErrUnexpectedEOF
   222  			}
   223  		case 0x1c: // CONNECTION_CLOSE frame, only 0x1c is permitted in initial packet
   224  			if _, err = quicvarint.Read(buffer); err != nil { // Field: Error Code
   225  				return nil, io.ErrUnexpectedEOF
   226  			}
   227  			if _, err = quicvarint.Read(buffer); err != nil { // Field: Frame Type
   228  				return nil, io.ErrUnexpectedEOF
   229  			}
   230  			length, err := quicvarint.Read(buffer) // Field: Reason Phrase Length
   231  			if err != nil {
   232  				return nil, io.ErrUnexpectedEOF
   233  			}
   234  			if _, err := buffer.ReadBytes(int32(length)); err != nil { // Field: Reason Phrase
   235  				return nil, io.ErrUnexpectedEOF
   236  			}
   237  		default:
   238  			// Only above frame types are permitted in initial packet.
   239  			// See https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2.2-8
   240  			return nil, errNotQuicInitial
   241  		}
   242  	}
   243  
   244  	tlsHdr := &ptls.SniffHeader{}
   245  	err = ptls.ReadClientHello(cryptoData[:cryptoLen], tlsHdr)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	return &SniffHeader{domain: tlsHdr.Domain()}, nil
   250  }
   251  
   252  func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
   253  	b := make([]byte, 3, 3+6+len(label)+1+len(context))
   254  	binary.BigEndian.PutUint16(b, uint16(length))
   255  	b[2] = uint8(6 + len(label))
   256  	b = append(b, []byte("tls13 ")...)
   257  	b = append(b, []byte(label)...)
   258  	b = b[:3+6+len(label)+1]
   259  	b[3+6+len(label)] = uint8(len(context))
   260  	b = append(b, context...)
   261  
   262  	out := make([]byte, length)
   263  	n, err := hkdf.Expand(hash.New, secret, b).Read(out)
   264  	if err != nil || n != length {
   265  		panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
   266  	}
   267  	return out
   268  }