github.com/sagernet/sing-box@v1.9.0-rc.20/common/sniff/quic.go (about) 1 package sniff 2 3 import ( 4 "bytes" 5 "context" 6 "crypto" 7 "crypto/aes" 8 "encoding/binary" 9 "io" 10 "os" 11 12 "github.com/sagernet/sing-box/adapter" 13 "github.com/sagernet/sing-box/common/sniff/internal/qtls" 14 C "github.com/sagernet/sing-box/constant" 15 E "github.com/sagernet/sing/common/exceptions" 16 17 "golang.org/x/crypto/hkdf" 18 ) 19 20 func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) { 21 reader := bytes.NewReader(packet) 22 23 typeByte, err := reader.ReadByte() 24 if err != nil { 25 return nil, err 26 } 27 if typeByte&0x40 == 0 { 28 return nil, E.New("bad type byte") 29 } 30 var versionNumber uint32 31 err = binary.Read(reader, binary.BigEndian, &versionNumber) 32 if err != nil { 33 return nil, err 34 } 35 if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 { 36 return nil, E.New("bad version") 37 } 38 packetType := (typeByte & 0x30) >> 4 39 if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 { 40 return nil, E.New("bad packet type") 41 } 42 43 destConnIDLen, err := reader.ReadByte() 44 if err != nil { 45 return nil, err 46 } 47 48 if destConnIDLen == 0 || destConnIDLen > 20 { 49 return nil, E.New("bad destination connection id length") 50 } 51 52 destConnID := make([]byte, destConnIDLen) 53 _, err = io.ReadFull(reader, destConnID) 54 if err != nil { 55 return nil, err 56 } 57 58 srcConnIDLen, err := reader.ReadByte() 59 if err != nil { 60 return nil, err 61 } 62 63 _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen)) 64 if err != nil { 65 return nil, err 66 } 67 68 tokenLen, err := qtls.ReadUvarint(reader) 69 if err != nil { 70 return nil, err 71 } 72 73 _, err = io.CopyN(io.Discard, reader, int64(tokenLen)) 74 if err != nil { 75 return nil, err 76 } 77 78 packetLen, err := qtls.ReadUvarint(reader) 79 if err != nil { 80 return nil, err 81 } 82 83 hdrLen := int(reader.Size()) - reader.Len() 84 if hdrLen+int(packetLen) > len(packet) { 85 return nil, os.ErrInvalid 86 } 87 88 _, err = io.CopyN(io.Discard, reader, 4) 89 if err != nil { 90 return nil, err 91 } 92 93 pnBytes := make([]byte, aes.BlockSize) 94 _, err = io.ReadFull(reader, pnBytes) 95 if err != nil { 96 return nil, err 97 } 98 99 var salt []byte 100 switch versionNumber { 101 case qtls.Version1: 102 salt = qtls.SaltV1 103 case qtls.Version2: 104 salt = qtls.SaltV2 105 default: 106 salt = qtls.SaltOld 107 } 108 var hkdfHeaderProtectionLabel string 109 switch versionNumber { 110 case qtls.Version2: 111 hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2 112 default: 113 hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1 114 } 115 initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt) 116 secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) 117 hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16) 118 block, err := aes.NewCipher(hpKey) 119 if err != nil { 120 return nil, err 121 } 122 mask := make([]byte, aes.BlockSize) 123 block.Encrypt(mask, pnBytes) 124 newPacket := make([]byte, len(packet)) 125 copy(newPacket, packet) 126 newPacket[0] ^= mask[0] & 0xf 127 for i := range newPacket[hdrLen : hdrLen+4] { 128 newPacket[hdrLen+i] ^= mask[i+1] 129 } 130 packetNumberLength := newPacket[0]&0x3 + 1 131 if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen { 132 return nil, os.ErrInvalid 133 } 134 var packetNumber uint32 135 switch packetNumberLength { 136 case 1: 137 packetNumber = uint32(newPacket[hdrLen]) 138 case 2: 139 packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:])) 140 case 3: 141 packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16 142 case 4: 143 packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:]) 144 default: 145 return nil, E.New("bad packet number length") 146 } 147 extHdrLen := hdrLen + int(packetNumberLength) 148 copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:]) 149 data := newPacket[extHdrLen : int(packetLen)+hdrLen] 150 151 var keyLabel string 152 var ivLabel string 153 switch versionNumber { 154 case qtls.Version2: 155 keyLabel = qtls.HKDFLabelKeyV2 156 ivLabel = qtls.HKDFLabelIVV2 157 default: 158 keyLabel = qtls.HKDFLabelKeyV1 159 ivLabel = qtls.HKDFLabelIVV1 160 } 161 162 key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) 163 iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) 164 cipher := qtls.AEADAESGCMTLS13(key, iv) 165 nonce := make([]byte, int32(cipher.NonceSize())) 166 binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber)) 167 decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen]) 168 if err != nil { 169 return nil, err 170 } 171 var frameType byte 172 var frameLen uint64 173 var fragments []struct { 174 offset uint64 175 length uint64 176 payload []byte 177 } 178 decryptedReader := bytes.NewReader(decrypted) 179 for { 180 frameType, err = decryptedReader.ReadByte() 181 if err == io.EOF { 182 break 183 } 184 switch frameType { 185 case 0x00: // PADDING 186 continue 187 case 0x01: // PING 188 continue 189 case 0x02, 0x03: // ACK 190 _, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged 191 if err != nil { 192 return nil, err 193 } 194 _, err = qtls.ReadUvarint(decryptedReader) // ACK Delay 195 if err != nil { 196 return nil, err 197 } 198 ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count 199 if err != nil { 200 return nil, err 201 } 202 _, err = qtls.ReadUvarint(decryptedReader) // First ACK Range 203 if err != nil { 204 return nil, err 205 } 206 for i := 0; i < int(ackRangeCount); i++ { 207 _, err = qtls.ReadUvarint(decryptedReader) // Gap 208 if err != nil { 209 return nil, err 210 } 211 _, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length 212 if err != nil { 213 return nil, err 214 } 215 } 216 if frameType == 0x03 { 217 _, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count 218 if err != nil { 219 return nil, err 220 } 221 _, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count 222 if err != nil { 223 return nil, err 224 } 225 _, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count 226 if err != nil { 227 return nil, err 228 } 229 } 230 case 0x06: // CRYPTO 231 var offset uint64 232 offset, err = qtls.ReadUvarint(decryptedReader) 233 if err != nil { 234 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 235 } 236 var length uint64 237 length, err = qtls.ReadUvarint(decryptedReader) 238 if err != nil { 239 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 240 } 241 index := len(decrypted) - decryptedReader.Len() 242 fragments = append(fragments, struct { 243 offset uint64 244 length uint64 245 payload []byte 246 }{offset, length, decrypted[index : index+int(length)]}) 247 frameLen += length 248 _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) 249 if err != nil { 250 return nil, err 251 } 252 case 0x1c: // CONNECTION_CLOSE 253 _, err = qtls.ReadUvarint(decryptedReader) // Error Code 254 if err != nil { 255 return nil, err 256 } 257 _, err = qtls.ReadUvarint(decryptedReader) // Frame Type 258 if err != nil { 259 return nil, err 260 } 261 var length uint64 262 length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length 263 if err != nil { 264 return nil, err 265 } 266 _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase 267 if err != nil { 268 return nil, err 269 } 270 default: 271 return nil, os.ErrInvalid 272 } 273 } 274 tlsHdr := make([]byte, 5) 275 tlsHdr[0] = 0x16 276 binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303)) 277 binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen)) 278 var index uint64 279 var length int 280 var readers []io.Reader 281 readers = append(readers, bytes.NewReader(tlsHdr)) 282 find: 283 for { 284 for _, fragment := range fragments { 285 if fragment.offset == index { 286 readers = append(readers, bytes.NewReader(fragment.payload)) 287 index = fragment.offset + fragment.length 288 length++ 289 continue find 290 } 291 } 292 if length == len(fragments) { 293 break 294 } 295 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments") 296 } 297 metadata, err := TLSClientHello(ctx, io.MultiReader(readers...)) 298 if err != nil { 299 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 300 } 301 metadata.Protocol = C.ProtocolQUIC 302 return metadata, nil 303 }