github.com/sagernet/sing-box@v1.9.0-rc.20/transport/vless/vision.go (about) 1 package vless 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "crypto/tls" 7 "io" 8 "math/big" 9 "net" 10 "reflect" 11 "time" 12 "unsafe" 13 14 C "github.com/sagernet/sing-box/constant" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 "github.com/sagernet/sing/common/logger" 20 N "github.com/sagernet/sing/common/network" 21 ) 22 23 var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) 24 25 func init() { 26 tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) { 27 tlsConn, loaded := common.Cast[*tls.Conn](conn) 28 if !loaded { 29 return 30 } 31 return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn)) 32 }) 33 } 34 35 const xrayChunkSize = 8192 36 37 type VisionConn struct { 38 net.Conn 39 reader *bufio.ChunkReader 40 writer N.VectorisedWriter 41 input *bytes.Reader 42 rawInput *bytes.Buffer 43 netConn net.Conn 44 logger logger.Logger 45 46 userUUID [16]byte 47 isTLS bool 48 numberOfPacketToFilter int 49 isTLS12orAbove bool 50 remainingServerHello int32 51 cipher uint16 52 enableXTLS bool 53 isPadding bool 54 directWrite bool 55 writeUUID bool 56 withinPaddingBuffers bool 57 remainingContent int 58 remainingPadding int 59 currentCommand byte 60 directRead bool 61 remainingReader io.Reader 62 } 63 64 func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) { 65 var ( 66 loaded bool 67 reflectType reflect.Type 68 reflectPointer uintptr 69 netConn net.Conn 70 ) 71 for _, tlsCreator := range tlsRegistry { 72 loaded, netConn, reflectType, reflectPointer = tlsCreator(tlsConn) 73 if loaded { 74 break 75 } 76 } 77 if !loaded { 78 return nil, C.ErrTLSRequired 79 } 80 input, _ := reflectType.FieldByName("input") 81 rawInput, _ := reflectType.FieldByName("rawInput") 82 return &VisionConn{ 83 Conn: conn, 84 reader: bufio.NewChunkReader(conn, xrayChunkSize), 85 writer: bufio.NewVectorisedWriter(conn), 86 input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)), 87 rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)), 88 netConn: netConn, 89 logger: logger, 90 91 userUUID: userUUID, 92 numberOfPacketToFilter: 8, 93 remainingServerHello: -1, 94 isPadding: true, 95 writeUUID: true, 96 withinPaddingBuffers: true, 97 remainingContent: -1, 98 remainingPadding: -1, 99 }, nil 100 } 101 102 func (c *VisionConn) Read(p []byte) (n int, err error) { 103 if c.remainingReader != nil { 104 n, err = c.remainingReader.Read(p) 105 if err == io.EOF { 106 err = nil 107 c.remainingReader = nil 108 } 109 if n > 0 { 110 return 111 } 112 } 113 if c.directRead { 114 return c.netConn.Read(p) 115 } 116 var bufferBytes []byte 117 var chunkBuffer *buf.Buffer 118 if len(p) > xrayChunkSize { 119 n, err = c.Conn.Read(p) 120 if err != nil { 121 return 122 } 123 bufferBytes = p[:n] 124 } else { 125 chunkBuffer, err = c.reader.ReadChunk() 126 if err != nil { 127 return 0, err 128 } 129 bufferBytes = chunkBuffer.Bytes() 130 } 131 if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 { 132 buffers := c.unPadding(bufferBytes) 133 if chunkBuffer != nil { 134 buffers = common.Map(buffers, func(it *buf.Buffer) *buf.Buffer { 135 return it.ToOwned() 136 }) 137 chunkBuffer.Reset() 138 } 139 if c.remainingContent == 0 && c.remainingPadding == 0 { 140 if c.currentCommand == commandPaddingEnd { 141 c.withinPaddingBuffers = false 142 c.remainingContent = -1 143 c.remainingPadding = -1 144 } else if c.currentCommand == commandPaddingDirect { 145 c.withinPaddingBuffers = false 146 c.directRead = true 147 148 inputBuffer, err := io.ReadAll(c.input) 149 if err != nil { 150 return 0, err 151 } 152 buffers = append(buffers, buf.As(inputBuffer)) 153 154 rawInputBuffer, err := io.ReadAll(c.rawInput) 155 if err != nil { 156 return 0, err 157 } 158 159 buffers = append(buffers, buf.As(rawInputBuffer)) 160 161 c.logger.Trace("XtlsRead readV") 162 } else if c.currentCommand == commandPaddingContinue { 163 c.withinPaddingBuffers = true 164 } else { 165 return 0, E.New("unknown command ", c.currentCommand) 166 } 167 } else if c.remainingContent > 0 || c.remainingPadding > 0 { 168 c.withinPaddingBuffers = true 169 } else { 170 c.withinPaddingBuffers = false 171 } 172 if c.numberOfPacketToFilter > 0 { 173 c.filterTLS(buf.ToSliceMulti(buffers)) 174 } 175 c.remainingReader = io.MultiReader(common.Map(buffers, func(it *buf.Buffer) io.Reader { return it })...) 176 return c.Read(p) 177 } else { 178 if c.numberOfPacketToFilter > 0 { 179 c.filterTLS([][]byte{bufferBytes}) 180 } 181 if chunkBuffer != nil { 182 n = copy(p, bufferBytes) 183 chunkBuffer.Advance(n) 184 } 185 return 186 } 187 } 188 189 func (c *VisionConn) Write(p []byte) (n int, err error) { 190 if c.numberOfPacketToFilter > 0 { 191 c.filterTLS([][]byte{p}) 192 } 193 if c.isPadding { 194 inputLen := len(p) 195 buffers := reshapeBuffer(p) 196 var specIndex int 197 for i, buffer := range buffers { 198 if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) { 199 var command byte = commandPaddingEnd 200 if c.enableXTLS { 201 c.directWrite = true 202 specIndex = i 203 command = commandPaddingDirect 204 } 205 c.isPadding = false 206 buffers[i] = c.padding(buffer, command) 207 break 208 } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 { 209 c.isPadding = false 210 buffers[i] = c.padding(buffer, commandPaddingEnd) 211 break 212 } 213 buffers[i] = c.padding(buffer, commandPaddingContinue) 214 } 215 if c.directWrite { 216 encryptedBuffer := buffers[:specIndex+1] 217 err = c.writer.WriteVectorised(encryptedBuffer) 218 if err != nil { 219 return 220 } 221 buffers = buffers[specIndex+1:] 222 c.writer = bufio.NewVectorisedWriter(c.netConn) 223 c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers)) 224 time.Sleep(5 * time.Millisecond) // wtf 225 } 226 err = c.writer.WriteVectorised(buffers) 227 if err == nil { 228 n = inputLen 229 } 230 return 231 } 232 if c.directWrite { 233 return c.netConn.Write(p) 234 } else { 235 return c.Conn.Write(p) 236 } 237 } 238 239 func (c *VisionConn) filterTLS(buffers [][]byte) { 240 for _, buffer := range buffers { 241 c.numberOfPacketToFilter-- 242 if len(buffer) > 6 { 243 if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { 244 c.isTLS = true 245 if buffer[5] == 2 { 246 c.isTLS12orAbove = true 247 c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5 248 if len(buffer) >= 79 && c.remainingServerHello >= 79 { 249 sessionIdLen := int32(buffer[43]) 250 cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] 251 c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) 252 } else { 253 c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello) 254 } 255 } 256 } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { 257 c.isTLS = true 258 c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer)) 259 } 260 } 261 if c.remainingServerHello > 0 { 262 end := int(c.remainingServerHello) 263 if end > len(buffer) { 264 end = len(buffer) 265 } 266 c.remainingServerHello -= int32(end) 267 if bytes.Contains(buffer[:end], tls13SupportedVersions) { 268 cipher, ok := tls13CipherSuiteDic[c.cipher] 269 if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { 270 c.enableXTLS = true 271 } 272 c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS) 273 c.numberOfPacketToFilter = 0 274 return 275 } else if c.remainingServerHello == 0 { 276 c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer)) 277 c.numberOfPacketToFilter = 0 278 return 279 } 280 } 281 if c.numberOfPacketToFilter == 0 { 282 c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer)) 283 } 284 } 285 } 286 287 func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer { 288 contentLen := 0 289 paddingLen := 0 290 if buffer != nil { 291 contentLen = buffer.Len() 292 } 293 if contentLen < 900 && c.isTLS { 294 l, _ := rand.Int(rand.Reader, big.NewInt(500)) 295 paddingLen = int(l.Int64()) + 900 - contentLen 296 } else { 297 l, _ := rand.Int(rand.Reader, big.NewInt(256)) 298 paddingLen = int(l.Int64()) 299 } 300 var bufferLen int 301 if c.writeUUID { 302 bufferLen += 16 303 } 304 bufferLen += 5 305 if buffer != nil { 306 bufferLen += buffer.Len() 307 } 308 bufferLen += paddingLen 309 newBuffer := buf.NewSize(bufferLen) 310 if c.writeUUID { 311 common.Must1(newBuffer.Write(c.userUUID[:])) 312 c.writeUUID = false 313 } 314 common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})) 315 if buffer != nil { 316 common.Must1(newBuffer.Write(buffer.Bytes())) 317 buffer.Release() 318 } 319 newBuffer.Extend(paddingLen) 320 c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command) 321 return newBuffer 322 } 323 324 func (c *VisionConn) unPadding(buffer []byte) []*buf.Buffer { 325 var bufferIndex int 326 if c.remainingContent == -1 && c.remainingPadding == -1 { 327 if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) { 328 bufferIndex = 16 329 c.remainingContent = 0 330 c.remainingPadding = 0 331 c.currentCommand = 0 332 } 333 } 334 if c.remainingContent == -1 && c.remainingPadding == -1 { 335 return []*buf.Buffer{buf.As(buffer)} 336 } 337 var buffers []*buf.Buffer 338 for bufferIndex < len(buffer) { 339 if c.remainingContent <= 0 && c.remainingPadding <= 0 { 340 if c.currentCommand == 1 { 341 buffers = append(buffers, buf.As(buffer[bufferIndex:])) 342 break 343 } else { 344 paddingInfo := buffer[bufferIndex : bufferIndex+5] 345 c.currentCommand = paddingInfo[0] 346 c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) 347 c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) 348 bufferIndex += 5 349 c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand) 350 } 351 } else if c.remainingContent > 0 { 352 end := c.remainingContent 353 if end > len(buffer)-bufferIndex { 354 end = len(buffer) - bufferIndex 355 } 356 buffers = append(buffers, buf.As(buffer[bufferIndex:bufferIndex+end])) 357 c.remainingContent -= end 358 bufferIndex += end 359 } else { 360 end := c.remainingPadding 361 if end > len(buffer)-bufferIndex { 362 end = len(buffer) - bufferIndex 363 } 364 c.remainingPadding -= end 365 bufferIndex += end 366 } 367 if bufferIndex == len(buffer) { 368 break 369 } 370 } 371 return buffers 372 } 373 374 func (c *VisionConn) NeedAdditionalReadDeadline() bool { 375 return true 376 } 377 378 func (c *VisionConn) Upstream() any { 379 return c.Conn 380 }