github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/sniff/quic.go (about) 1 package sniff 2 3 import ( 4 "bytes" 5 "context" 6 "crypto" 7 "crypto/aes" 8 "encoding/binary" 9 "io" 10 "os" 11 12 "github.com/inazumav/sing-box/adapter" 13 "github.com/inazumav/sing-box/common/sniff/internal/qtls" 14 C "github.com/inazumav/sing-box/constant" 15 E "github.com/sagernet/sing/common/exceptions" 16 17 "golang.org/x/crypto/hkdf" 18 ) 19 20 func QUICClientHello(ctx context.Context, packet []byte) (*adapter.InboundContext, error) { 21 reader := bytes.NewReader(packet) 22 23 typeByte, err := reader.ReadByte() 24 if err != nil { 25 return nil, err 26 } 27 if typeByte&0x40 == 0 { 28 return nil, E.New("bad type byte") 29 } 30 var versionNumber uint32 31 err = binary.Read(reader, binary.BigEndian, &versionNumber) 32 if err != nil { 33 return nil, err 34 } 35 if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 { 36 return nil, E.New("bad version") 37 } 38 packetType := (typeByte & 0x30) >> 4 39 if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 { 40 return nil, E.New("bad packet type") 41 } 42 43 destConnIDLen, err := reader.ReadByte() 44 if err != nil { 45 return nil, err 46 } 47 48 if destConnIDLen == 0 || destConnIDLen > 20 { 49 return nil, E.New("bad destination connection id length") 50 } 51 52 destConnID := make([]byte, destConnIDLen) 53 _, err = io.ReadFull(reader, destConnID) 54 if err != nil { 55 return nil, err 56 } 57 58 srcConnIDLen, err := reader.ReadByte() 59 if err != nil { 60 return nil, err 61 } 62 63 _, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen)) 64 if err != nil { 65 return nil, err 66 } 67 68 tokenLen, err := qtls.ReadUvarint(reader) 69 if err != nil { 70 return nil, err 71 } 72 73 _, err = io.CopyN(io.Discard, reader, int64(tokenLen)) 74 if err != nil { 75 return nil, err 76 } 77 78 packetLen, err := qtls.ReadUvarint(reader) 79 if err != nil { 80 return nil, err 81 } 82 83 hdrLen := int(reader.Size()) - reader.Len() 84 if hdrLen+int(packetLen) > len(packet) { 85 return nil, os.ErrInvalid 86 } 87 88 _, err = io.CopyN(io.Discard, reader, 4) 89 if err != nil { 90 return nil, err 91 } 92 93 pnBytes := make([]byte, aes.BlockSize) 94 _, err = io.ReadFull(reader, pnBytes) 95 if err != nil { 96 return nil, err 97 } 98 99 var salt []byte 100 switch versionNumber { 101 case qtls.Version1: 102 salt = qtls.SaltV1 103 case qtls.Version2: 104 salt = qtls.SaltV2 105 default: 106 salt = qtls.SaltOld 107 } 108 var hkdfHeaderProtectionLabel string 109 switch versionNumber { 110 case qtls.Version2: 111 hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2 112 default: 113 hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1 114 } 115 initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt) 116 secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) 117 hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16) 118 block, err := aes.NewCipher(hpKey) 119 if err != nil { 120 return nil, err 121 } 122 mask := make([]byte, aes.BlockSize) 123 block.Encrypt(mask, pnBytes) 124 newPacket := make([]byte, len(packet)) 125 copy(newPacket, packet) 126 newPacket[0] ^= mask[0] & 0xf 127 for i := range newPacket[hdrLen : hdrLen+4] { 128 newPacket[hdrLen+i] ^= mask[i+1] 129 } 130 packetNumberLength := newPacket[0]&0x3 + 1 131 if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen { 132 return nil, os.ErrInvalid 133 } 134 var packetNumber uint32 135 switch packetNumberLength { 136 case 1: 137 packetNumber = uint32(newPacket[hdrLen]) 138 case 2: 139 packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:])) 140 case 3: 141 packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16 142 case 4: 143 packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:]) 144 default: 145 return nil, E.New("bad packet number length") 146 } 147 extHdrLen := hdrLen + int(packetNumberLength) 148 copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:]) 149 data := newPacket[extHdrLen : int(packetLen)+hdrLen] 150 151 var keyLabel string 152 var ivLabel string 153 switch versionNumber { 154 case qtls.Version2: 155 keyLabel = qtls.HKDFLabelKeyV2 156 ivLabel = qtls.HKDFLabelIVV2 157 default: 158 keyLabel = qtls.HKDFLabelKeyV1 159 ivLabel = qtls.HKDFLabelIVV1 160 } 161 162 key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) 163 iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) 164 cipher := qtls.AEADAESGCMTLS13(key, iv) 165 nonce := make([]byte, int32(cipher.NonceSize())) 166 binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber)) 167 decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen]) 168 if err != nil { 169 return nil, err 170 } 171 var frameType byte 172 var frameLen uint64 173 var fragments []struct { 174 offset uint64 175 length uint64 176 payload []byte 177 } 178 decryptedReader := bytes.NewReader(decrypted) 179 for { 180 frameType, err = decryptedReader.ReadByte() 181 if err == io.EOF { 182 break 183 } 184 switch frameType { 185 case 0x0: 186 continue 187 case 0x1: 188 continue 189 case 0x6: 190 var offset uint64 191 offset, err = qtls.ReadUvarint(decryptedReader) 192 if err != nil { 193 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 194 } 195 var length uint64 196 length, err = qtls.ReadUvarint(decryptedReader) 197 if err != nil { 198 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 199 } 200 index := len(decrypted) - decryptedReader.Len() 201 fragments = append(fragments, struct { 202 offset uint64 203 length uint64 204 payload []byte 205 }{offset, length, decrypted[index : index+int(length)]}) 206 frameLen += length 207 _, err = decryptedReader.Seek(int64(length), io.SeekCurrent) 208 if err != nil { 209 return nil, err 210 } 211 default: 212 // ignore unknown frame type 213 } 214 } 215 tlsHdr := make([]byte, 5) 216 tlsHdr[0] = 0x16 217 binary.BigEndian.PutUint16(tlsHdr[1:], uint16(0x0303)) 218 binary.BigEndian.PutUint16(tlsHdr[3:], uint16(frameLen)) 219 var index uint64 220 var length int 221 var readers []io.Reader 222 readers = append(readers, bytes.NewReader(tlsHdr)) 223 find: 224 for { 225 for _, fragment := range fragments { 226 if fragment.offset == index { 227 readers = append(readers, bytes.NewReader(fragment.payload)) 228 index = fragment.offset + fragment.length 229 length++ 230 continue find 231 } 232 } 233 if length == len(fragments) { 234 break 235 } 236 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, E.New("bad fragments") 237 } 238 metadata, err := TLSClientHello(ctx, io.MultiReader(readers...)) 239 if err != nil { 240 return &adapter.InboundContext{Protocol: C.ProtocolQUIC}, err 241 } 242 metadata.Protocol = C.ProtocolQUIC 243 return metadata, nil 244 }