github.com/sagernet/sing-box@v1.2.7/transport/vless/vision.go (about) 1 package vless 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "crypto/tls" 7 "io" 8 "math/big" 9 "net" 10 "reflect" 11 "time" 12 "unsafe" 13 14 C "github.com/sagernet/sing-box/constant" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 "github.com/sagernet/sing/common/logger" 20 N "github.com/sagernet/sing/common/network" 21 ) 22 23 var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) 24 25 func init() { 26 tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) { 27 tlsConn, loaded := conn.(*tls.Conn) 28 if !loaded { 29 return 30 } 31 return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn)) 32 }) 33 } 34 35 const xrayChunkSize = 8192 36 37 type VisionConn struct { 38 net.Conn 39 reader *bufio.ChunkReader 40 writer N.VectorisedWriter 41 input *bytes.Reader 42 rawInput *bytes.Buffer 43 netConn net.Conn 44 logger logger.Logger 45 46 userUUID [16]byte 47 isTLS bool 48 numberOfPacketToFilter int 49 isTLS12orAbove bool 50 remainingServerHello int32 51 cipher uint16 52 enableXTLS bool 53 isPadding bool 54 directWrite bool 55 writeUUID bool 56 withinPaddingBuffers bool 57 remainingContent int 58 remainingPadding int 59 currentCommand int 60 directRead bool 61 remainingReader io.Reader 62 } 63 64 func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) { 65 var ( 66 loaded bool 67 reflectType reflect.Type 68 reflectPointer uintptr 69 netConn net.Conn 70 ) 71 for _, tlsCreator := range tlsRegistry { 72 loaded, netConn, reflectType, reflectPointer = tlsCreator(conn) 73 if loaded { 74 break 75 } 76 } 77 if !loaded { 78 return nil, C.ErrTLSRequired 79 } 80 input, _ := reflectType.FieldByName("input") 81 rawInput, _ := reflectType.FieldByName("rawInput") 82 return &VisionConn{ 83 Conn: conn, 84 reader: bufio.NewChunkReader(conn, xrayChunkSize), 85 writer: bufio.NewVectorisedWriter(conn), 86 input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), 87 rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), 88 netConn: netConn, 89 logger: logger, 90 91 userUUID: userUUID, 92 numberOfPacketToFilter: 8, 93 remainingServerHello: -1, 94 isPadding: true, 95 writeUUID: true, 96 withinPaddingBuffers: true, 97 remainingContent: -1, 98 remainingPadding: -1, 99 }, nil 100 } 101 102 func (c *VisionConn) Read(p []byte) (n int, err error) { 103 if c.remainingReader != nil { 104 n, err = c.remainingReader.Read(p) 105 if err == io.EOF { 106 c.remainingReader = nil 107 } 108 if n > 0 { 109 return 110 } 111 } 112 if c.directRead { 113 return c.netConn.Read(p) 114 } 115 var bufferBytes []byte 116 if len(p) > xrayChunkSize { 117 n, err = c.Conn.Read(p) 118 if err != nil { 119 return 120 } 121 bufferBytes = p[:n] 122 } else { 123 buffer, err := c.reader.ReadChunk() 124 if err != nil { 125 return 0, err 126 } 127 defer buffer.FullReset() 128 bufferBytes = buffer.Bytes() 129 } 130 if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 { 131 buffers := c.unPadding(bufferBytes) 132 if c.remainingContent == 0 && c.remainingPadding == 0 { 133 if c.currentCommand == 1 { 134 c.withinPaddingBuffers = false 135 c.remainingContent = -1 136 c.remainingPadding = -1 137 } else if c.currentCommand == 2 { 138 c.withinPaddingBuffers = false 139 c.directRead = true 140 141 inputBuffer, err := io.ReadAll(c.input) 142 if err != nil { 143 return 0, err 144 } 145 buffers = append(buffers, inputBuffer) 146 147 rawInputBuffer, err := io.ReadAll(c.rawInput) 148 if err != nil { 149 return 0, err 150 } 151 152 buffers = append(buffers, rawInputBuffer) 153 154 c.logger.Trace("XtlsRead readV") 155 } else if c.currentCommand == 0 { 156 c.withinPaddingBuffers = true 157 } else { 158 return 0, E.New("unknown command ", c.currentCommand) 159 } 160 } else if c.remainingContent > 0 || c.remainingPadding > 0 { 161 c.withinPaddingBuffers = true 162 } else { 163 c.withinPaddingBuffers = false 164 } 165 if c.numberOfPacketToFilter > 0 { 166 c.filterTLS(buffers) 167 } 168 c.remainingReader = io.MultiReader(common.Map(buffers, func(it []byte) io.Reader { return bytes.NewReader(it) })...) 169 return c.Read(p) 170 } else { 171 if c.numberOfPacketToFilter > 0 { 172 c.filterTLS([][]byte{bufferBytes}) 173 } 174 return 175 } 176 } 177 178 func (c *VisionConn) Write(p []byte) (n int, err error) { 179 if c.numberOfPacketToFilter > 0 { 180 c.filterTLS([][]byte{p}) 181 } 182 if c.isPadding { 183 inputLen := len(p) 184 buffers := reshapeBuffer(p) 185 var specIndex int 186 for i, buffer := range buffers { 187 if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { 188 var command byte = commandPaddingEnd 189 if c.enableXTLS { 190 c.directWrite = true 191 specIndex = i 192 command = commandPaddingDirect 193 } 194 c.isPadding = false 195 buffers[i] = c.padding(buffer, command) 196 break 197 } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 { 198 c.isPadding = false 199 buffers[i] = c.padding(buffer, commandPaddingEnd) 200 break 201 } 202 buffers[i] = c.padding(buffer, commandPaddingContinue) 203 } 204 if c.directWrite { 205 encryptedBuffer := buffers[:specIndex+1] 206 err = c.writer.WriteVectorised(encryptedBuffer) 207 if err != nil { 208 return 209 } 210 buffers = buffers[specIndex+1:] 211 c.writer = bufio.NewVectorisedWriter(c.netConn) 212 c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers)) 213 time.Sleep(5 * time.Millisecond) // wtf 214 } 215 err = c.writer.WriteVectorised(buffers) 216 if err == nil { 217 n = inputLen 218 } 219 return 220 } 221 if c.directWrite { 222 return c.netConn.Write(p) 223 } else { 224 return c.Conn.Write(p) 225 } 226 } 227 228 func (c *VisionConn) filterTLS(buffers [][]byte) { 229 for _, buffer := range buffers { 230 c.numberOfPacketToFilter-- 231 if len(buffer) > 6 { 232 if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { 233 c.isTLS = true 234 if buffer[5] == 2 { 235 c.isTLS12orAbove = true 236 c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5 237 if len(buffer) >= 79 && c.remainingServerHello >= 79 { 238 sessionIdLen := int32(buffer[43]) 239 cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] 240 c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) 241 } else { 242 c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello) 243 } 244 } 245 } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { 246 c.isTLS = true 247 c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer)) 248 } 249 } 250 if c.remainingServerHello > 0 { 251 end := int(c.remainingServerHello) 252 if end > len(buffer) { 253 end = len(buffer) 254 } 255 c.remainingServerHello -= int32(end) 256 if bytes.Contains(buffer[:end], tls13SupportedVersions) { 257 cipher, ok := tls13CipherSuiteDic[c.cipher] 258 if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { 259 c.enableXTLS = true 260 } 261 c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS) 262 c.numberOfPacketToFilter = 0 263 return 264 } else if c.remainingServerHello == 0 { 265 c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer)) 266 c.numberOfPacketToFilter = 0 267 return 268 } 269 } 270 if c.numberOfPacketToFilter == 0 { 271 c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer)) 272 } 273 } 274 } 275 276 func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer { 277 contentLen := 0 278 paddingLen := 0 279 if buffer != nil { 280 contentLen = buffer.Len() 281 } 282 if contentLen < 900 && c.isTLS { 283 l, _ := rand.Int(rand.Reader, big.NewInt(500)) 284 paddingLen = int(l.Int64()) + 900 - contentLen 285 } else { 286 l, _ := rand.Int(rand.Reader, big.NewInt(256)) 287 paddingLen = int(l.Int64()) 288 } 289 var bufferLen int 290 if c.writeUUID { 291 bufferLen += 16 292 } 293 bufferLen += 5 294 if buffer != nil { 295 bufferLen += buffer.Len() 296 } 297 bufferLen += paddingLen 298 newBuffer := buf.NewSize(bufferLen) 299 if c.writeUUID { 300 common.Must1(newBuffer.Write(c.userUUID[:])) 301 c.writeUUID = false 302 } 303 common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})) 304 if buffer != nil { 305 common.Must1(newBuffer.Write(buffer.Bytes())) 306 buffer.Release() 307 } 308 newBuffer.Extend(paddingLen) 309 c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command) 310 return newBuffer 311 } 312 313 func (c *VisionConn) unPadding(buffer []byte) [][]byte { 314 var bufferIndex int 315 if c.remainingContent == -1 && c.remainingPadding == -1 { 316 if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) { 317 bufferIndex = 16 318 c.remainingContent = 0 319 c.remainingPadding = 0 320 c.currentCommand = 0 321 } 322 } 323 if c.remainingContent == -1 && c.remainingPadding == -1 { 324 return [][]byte{buffer} 325 } 326 var buffers [][]byte 327 for bufferIndex < len(buffer) { 328 if c.remainingContent <= 0 && c.remainingPadding <= 0 { 329 if c.currentCommand == 1 { 330 buffers = append(buffers, buffer[bufferIndex:]) 331 break 332 } else { 333 paddingInfo := buffer[bufferIndex : bufferIndex+5] 334 c.currentCommand = int(paddingInfo[0]) 335 c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) 336 c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) 337 bufferIndex += 5 338 c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand) 339 } 340 } else if c.remainingContent > 0 { 341 end := c.remainingContent 342 if end > len(buffer)-bufferIndex { 343 end = len(buffer) - bufferIndex 344 } 345 buffers = append(buffers, buffer[bufferIndex:bufferIndex+end]) 346 c.remainingContent -= end 347 bufferIndex += end 348 } else { 349 end := c.remainingPadding 350 if end > len(buffer)-bufferIndex { 351 end = len(buffer) - bufferIndex 352 } 353 c.remainingPadding -= end 354 bufferIndex += end 355 } 356 if bufferIndex == len(buffer) { 357 break 358 } 359 } 360 return buffers 361 } 362 363 func (c *VisionConn) NeedAdditionalReadDeadline() bool { 364 return true 365 } 366 367 func (c *VisionConn) Upstream() any { 368 return c.Conn 369 }