github.com/sagernet/sing-box@v1.9.0-rc.20/common/sniff/quic.go (about)

     1  package sniff
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto"
     7  	"crypto/aes"
     8  	"encoding/binary"
     9  	"io"
    10  	"os"
    11  
    12  	"github.com/sagernet/sing-box/adapter"
    13  	"github.com/sagernet/sing-box/common/sniff/internal/qtls"
    14  	C "github.com/sagernet/sing-box/constant"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  
    17  	"golang.org/x/crypto/hkdf"
    18  )
    19  
    20  func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
    21  	reader := bytes.NewReader(packet)
    22  
    23  	typeByte, err := reader.ReadByte()
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  	if typeByte&0x40 == 0 {
    28  		return nil, E.New("bad type byte")
    29  	}
    30  	var versionNumber uint32
    31  	err = binary.Read(reader, binary.BigEndian, &versionNumber)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  	if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
    36  		return nil, E.New("bad version")
    37  	}
    38  	packetType := (typeByte & 0x30) >> 4
    39  	if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 {
    40  		return nil, E.New("bad packet type")
    41  	}
    42  
    43  	destConnIDLen, err := reader.ReadByte()
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	if destConnIDLen == 0 || destConnIDLen > 20 {
    49  		return nil, E.New("bad destination connection id length")
    50  	}
    51  
    52  	destConnID := make([]byte, destConnIDLen)
    53  	_, err = io.ReadFull(reader, destConnID)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	srcConnIDLen, err := reader.ReadByte()
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	_, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	tokenLen, err := qtls.ReadUvarint(reader)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	_, err = io.CopyN(io.Discard, reader, int64(tokenLen))
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	packetLen, err := qtls.ReadUvarint(reader)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	hdrLen := int(reader.Size()) - reader.Len()
    84  	if hdrLen+int(packetLen) > len(packet) {
    85  		return nil, os.ErrInvalid
    86  	}
    87  
    88  	_, err = io.CopyN(io.Discard, reader, 4)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	pnBytes := make([]byte, aes.BlockSize)
    94  	_, err = io.ReadFull(reader, pnBytes)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	var salt []byte
   100  	switch versionNumber {
   101  	case qtls.Version1:
   102  		salt = qtls.SaltV1
   103  	case qtls.Version2:
   104  		salt = qtls.SaltV2
   105  	default:
   106  		salt = qtls.SaltOld
   107  	}
   108  	var hkdfHeaderProtectionLabel string
   109  	switch versionNumber {
   110  	case qtls.Version2:
   111  		hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
   112  	default:
   113  		hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
   114  	}
   115  	initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
   116  	secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
   117  	hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
   118  	block, err := aes.NewCipher(hpKey)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	mask := make([]byte, aes.BlockSize)
   123  	block.Encrypt(mask, pnBytes)
   124  	newPacket := make([]byte, len(packet))
   125  	copy(newPacket, packet)
   126  	newPacket[0] ^= mask[0] & 0xf
   127  	for i := range newPacket[hdrLen : hdrLen+4] {
   128  		newPacket[hdrLen+i] ^= mask[i+1]
   129  	}
   130  	packetNumberLength := newPacket[0]&0x3 + 1
   131  	if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen {
   132  		return nil, os.ErrInvalid
   133  	}
   134  	var packetNumber uint32
   135  	switch packetNumberLength {
   136  	case 1:
   137  		packetNumber = uint32(newPacket[hdrLen])
   138  	case 2:
   139  		packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:]))
   140  	case 3:
   141  		packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16
   142  	case 4:
   143  		packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:])
   144  	default:
   145  		return nil, E.New("bad packet number length")
   146  	}
   147  	extHdrLen := hdrLen + int(packetNumberLength)
   148  	copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
   149  	data := newPacket[extHdrLen : int(packetLen)+hdrLen]
   150  
   151  	var keyLabel string
   152  	var ivLabel string
   153  	switch versionNumber {
   154  	case qtls.Version2:
   155  		keyLabel = qtls.HKDFLabelKeyV2
   156  		ivLabel = qtls.HKDFLabelIVV2
   157  	default:
   158  		keyLabel = qtls.HKDFLabelKeyV1
   159  		ivLabel = qtls.HKDFLabelIVV1
   160  	}
   161  
   162  	key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
   163  	iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
   164  	cipher := qtls.AEADAESGCMTLS13(key, iv)
   165  	nonce := make([]byte, int32(cipher.NonceSize()))
   166  	binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
   167  	decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	var frameType byte
   172  	var frameLen uint64
   173  	var fragments []struct {
   174  		offset  uint64
   175  		length  uint64
   176  		payload []byte
   177  	}
   178  	decryptedReader := bytes.NewReader(decrypted)
   179  	for {
   180  		frameType, err = decryptedReader.ReadByte()
   181  		if err == io.EOF {
   182  			break
   183  		}
   184  		switch frameType {
   185  		case 0x00: // PADDING
   186  			continue
   187  		case 0x01: // PING
   188  			continue
   189  		case 0x02, 0x03: // ACK
   190  			_, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged
   191  			if err != nil {
   192  				return nil, err
   193  			}
   194  			_, err = qtls.ReadUvarint(decryptedReader) // ACK Delay
   195  			if err != nil {
   196  				return nil, err
   197  			}
   198  			ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count
   199  			if err != nil {
   200  				return nil, err
   201  			}
   202  			_, err = qtls.ReadUvarint(decryptedReader) // First ACK Range
   203  			if err != nil {
   204  				return nil, err
   205  			}
   206  			for i := 0; i < int(ackRangeCount); i++ {
   207  				_, err = qtls.ReadUvarint(decryptedReader) // Gap
   208  				if err != nil {
   209  					return nil, err
   210  				}
   211  				_, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length
   212  				if err != nil {
   213  					return nil, err
   214  				}
   215  			}
   216  			if frameType == 0x03 {
   217  				_, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count
   218  				if err != nil {
   219  					return nil, err
   220  				}
   221  				_, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count
   222  				if err != nil {
   223  					return nil, err
   224  				}
   225  				_, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count
   226  				if err != nil {
   227  					return nil, err
   228  				}
   229  			}
   230  		case 0x06: // CRYPTO
   231  			var offset uint64
   232  			offset, err = qtls.ReadUvarint(decryptedReader)
   233  			if err != nil {
   234  				return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
   235  			}
   236  			var length uint64
   237  			length, err = qtls.ReadUvarint(decryptedReader)
   238  			if err != nil {
   239  				return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
   240  			}
   241  			index := len(decrypted) - decryptedReader.Len()
   242  			fragments = append(fragments, struct {
   243  				offset  uint64
   244  				length  uint64
   245  				payload []byte
   246  			}{offset, length, decrypted[index : index+int(length)]})
   247  			frameLen += length
   248  			_, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
   249  			if err != nil {
   250  				return nil, err
   251  			}
   252  		case 0x1c: // CONNECTION_CLOSE
   253  			_, err = qtls.ReadUvarint(decryptedReader) // Error Code
   254  			if err != nil {
   255  				return nil, err
   256  			}
   257  			_, err = qtls.ReadUvarint(decryptedReader) // Frame Type
   258  			if err != nil {
   259  				return nil, err
   260  			}
   261  			var length uint64
   262  			length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length
   263  			if err != nil {
   264  				return nil, err
   265  			}
   266  			_, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase
   267  			if err != nil {
   268  				return nil, err
   269  			}
   270  		default:
   271  			return nil, os.ErrInvalid
   272  		}
   273  	}
   274  	tlsHdr := make([]byte, 5)
   275  	tlsHdr[0] = 0x16
   276  	binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303))
   277  	binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen))
   278  	var index uint64
   279  	var length int
   280  	var readers []io.Reader
   281  	readers = append(readers, bytes.NewReader(tlsHdr))
   282  find:
   283  	for {
   284  		for _, fragment := range fragments {
   285  			if fragment.offset == index {
   286  				readers = append(readers, bytes.NewReader(fragment.payload))
   287  				index = fragment.offset + fragment.length
   288  				length++
   289  				continue find
   290  			}
   291  		}
   292  		if length == len(fragments) {
   293  			break
   294  		}
   295  		return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments")
   296  	}
   297  	metadata, err := TLSClientHello(ctx, io.MultiReader(readers...))
   298  	if err != nil {
   299  		return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err
   300  	}
   301  	metadata.Protocol = C.ProtocolQUIC
   302  	return metadata, nil
   303  }