github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vision/vision.go (about) 1 package vision 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/rand" 7 "crypto/tls" 8 "fmt" 9 "io" 10 "math/big" 11 "net" 12 "reflect" 13 "time" 14 "unsafe" 15 16 "github.com/Asutorufa/yuhaiin/pkg/log" 17 utls "github.com/refraction-networking/utls" 18 ) 19 20 var ( 21 tls13SupportedVersions = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04} 22 tlsClientHandShakeStart = []byte{0x16, 0x03} 23 tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03} 24 tlsApplicationDataStart = []byte{0x17, 0x03, 0x03} 25 ) 26 27 const ( 28 commandPaddingContinue byte = iota 29 commandPaddingEnd 30 commandPaddingDirect 31 ) 32 33 var tls13CipherSuiteDic = map[uint16]string{ 34 0x1301: "TLS_AES_128_GCM_SHA256", 35 0x1302: "TLS_AES_256_GCM_SHA384", 36 0x1303: "TLS_CHACHA20_POLY1305_SHA256", 37 0x1304: "TLS_AES_128_CCM_SHA256", 38 0x1305: "TLS_AES_128_CCM_8_SHA256", 39 } 40 41 func reshapeBuffer(b []byte) [][]byte { 42 const bufferLimit = 8192 - 21 43 if len(b) < bufferLimit { 44 return [][]byte{b} 45 } 46 index := int32(bytes.LastIndex(b, tlsApplicationDataStart)) 47 if index <= 0 { 48 index = 8192 / 2 49 } 50 51 return [][]byte{b[:index], b[index:]} 52 } 53 54 const xrayChunkSize = 8192 55 56 type VisionConn struct { 57 net.Conn 58 reader *bufio.Reader 59 writer net.Conn 60 input *bytes.Reader 61 rawInput *bytes.Buffer 62 netConn net.Conn 63 64 userUUID [16]byte 65 isTLS bool 66 numberOfPacketToFilter int 67 isTLS12orAbove bool 68 remainingServerHello int32 69 cipher uint16 70 enableXTLS bool 71 isPadding bool 72 directWrite bool 73 writeUUID bool 74 withinPaddingBuffers bool 75 remainingContent int 76 remainingPadding int 77 currentCommand byte 78 directRead bool 79 remainingReader io.Reader 80 } 81 82 func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte) (*VisionConn, error) { 83 var ( 84 reflectType reflect.Type 85 reflectPointer unsafe.Pointer 86 netConn net.Conn 87 ) 88 89 switch underlying := tlsConn.(type) { 90 case *tls.Conn: 91 netConn = underlying.NetConn() 92 reflectType = reflect.TypeOf(underlying).Elem() 93 reflectPointer = unsafe.Pointer(underlying) 94 case *utls.UConn: 95 netConn = underlying.NetConn() 96 reflectType = reflect.TypeOf(underlying.Conn).Elem() 97 reflectPointer = unsafe.Pointer(underlying.Conn) 98 default: 99 return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`) 100 } 101 102 input, _ := reflectType.FieldByName("input") 103 rawInput, _ := reflectType.FieldByName("rawInput") 104 105 return &VisionConn{ 106 Conn: conn, 107 reader: bufio.NewReaderSize(conn, xrayChunkSize), 108 writer: conn, 109 input: (*bytes.Reader)(unsafe.Add(reflectPointer, input.Offset)), 110 rawInput: (*bytes.Buffer)(unsafe.Add(reflectPointer, rawInput.Offset)), 111 netConn: netConn, 112 113 userUUID: userUUID, 114 numberOfPacketToFilter: 8, 115 remainingServerHello: -1, 116 isPadding: true, 117 writeUUID: true, 118 withinPaddingBuffers: true, 119 remainingContent: -1, 120 remainingPadding: -1, 121 }, nil 122 } 123 124 func (c *VisionConn) Read(p []byte) (n int, err error) { 125 if c.remainingReader != nil { 126 n, err = c.remainingReader.Read(p) 127 if err == io.EOF { 128 err = nil 129 c.remainingReader = nil 130 } 131 if n > 0 { 132 return 133 } 134 } 135 136 if c.directRead { 137 return c.netConn.Read(p) 138 } 139 140 var bufferBytes []byte 141 var chunkBuffer []byte 142 if len(p) > xrayChunkSize { 143 n, err = c.Conn.Read(p) 144 if err != nil { 145 return 146 } 147 bufferBytes = p[:n] 148 } else { 149 buf := make([]byte, xrayChunkSize) 150 n, err = c.reader.Read(buf) 151 if err != nil { 152 return 0, err 153 } 154 chunkBuffer = buf[:n] 155 bufferBytes = chunkBuffer 156 } 157 if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 { 158 buffers := c.unPadding(bufferBytes) 159 160 if c.remainingContent == 0 && c.remainingPadding == 0 { 161 if c.currentCommand == commandPaddingEnd { 162 c.withinPaddingBuffers = false 163 c.remainingContent = -1 164 c.remainingPadding = -1 165 } else if c.currentCommand == commandPaddingDirect { 166 c.withinPaddingBuffers = false 167 c.directRead = true 168 169 inputBuffer, err := io.ReadAll(c.input) 170 if err != nil { 171 return 0, err 172 } 173 buffers = append(buffers, inputBuffer) 174 175 rawInputBuffer, err := io.ReadAll(c.rawInput) 176 if err != nil { 177 return 0, err 178 } 179 180 buffers = append(buffers, rawInputBuffer) 181 182 log.Debug("XtlsRead readV") 183 } else if c.currentCommand == commandPaddingContinue { 184 c.withinPaddingBuffers = true 185 } else { 186 return 0, fmt.Errorf("unknown command %v", c.currentCommand) 187 } 188 } else if c.remainingContent > 0 || c.remainingPadding > 0 { 189 c.withinPaddingBuffers = true 190 } else { 191 c.withinPaddingBuffers = false 192 } 193 if c.numberOfPacketToFilter > 0 { 194 c.filterTLS(buffers) 195 } 196 nBuffers := net.Buffers(buffers) 197 c.remainingReader = &nBuffers 198 return c.Read(p) 199 } else { 200 if c.numberOfPacketToFilter > 0 { 201 c.filterTLS([][]byte{bufferBytes}) 202 } 203 if chunkBuffer != nil { 204 n = copy(p, bufferBytes) 205 } 206 return 207 } 208 } 209 210 func (c *VisionConn) Write(p []byte) (n int, err error) { 211 if c.numberOfPacketToFilter > 0 { 212 c.filterTLS([][]byte{p}) 213 } 214 if c.isPadding { 215 inputLen := len(p) 216 buffers := reshapeBuffer(p) 217 var specIndex int 218 for i, buffer := range buffers { 219 if c.isTLS && len(buffer) > 6 && bytes.Equal(tlsApplicationDataStart, buffer[:3]) { 220 var command byte = commandPaddingEnd 221 if c.enableXTLS { 222 c.directWrite = true 223 specIndex = i 224 command = commandPaddingDirect 225 } 226 c.isPadding = false 227 buffers[i] = c.padding(buffer, command) 228 break 229 } else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 { 230 c.isPadding = false 231 buffers[i] = c.padding(buffer, commandPaddingEnd) 232 break 233 } 234 buffers[i] = c.padding(buffer, commandPaddingContinue) 235 } 236 237 if c.directWrite { 238 encryptedBuffer := buffers[:specIndex+1] 239 240 for _, v := range encryptedBuffer { 241 _, err = c.writer.Write(v) 242 if err != nil { 243 return 244 } 245 } 246 buffers = buffers[specIndex+1:] 247 c.writer = c.netConn 248 time.Sleep(5 * time.Millisecond) // wtf 249 } 250 251 for _, v := range buffers { 252 _, err = c.writer.Write(v) 253 if err != nil { 254 return 255 } 256 } 257 n = inputLen 258 return 259 } 260 261 if c.directWrite { 262 return c.netConn.Write(p) 263 } else { 264 return c.Conn.Write(p) 265 } 266 } 267 268 func (c *VisionConn) filterTLS(buffers [][]byte) { 269 for _, buffer := range buffers { 270 c.numberOfPacketToFilter-- 271 if len(buffer) > 6 { 272 if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 { 273 c.isTLS = true 274 if buffer[5] == 2 { 275 c.isTLS12orAbove = true 276 c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5 277 if len(buffer) >= 79 && c.remainingServerHello >= 79 { 278 sessionIdLen := int32(buffer[43]) 279 cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3] 280 c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1]) 281 } else { 282 log.Info("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello) 283 } 284 } 285 } else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 { 286 c.isTLS = true 287 log.Debug("XtlsFilterTls found tls client hello! ", len(buffer)) 288 } 289 } 290 if c.remainingServerHello > 0 { 291 end := int(c.remainingServerHello) 292 if end > len(buffer) { 293 end = len(buffer) 294 } 295 c.remainingServerHello -= int32(end) 296 if bytes.Contains(buffer[:end], tls13SupportedVersions) { 297 cipher, ok := tls13CipherSuiteDic[c.cipher] 298 if ok && cipher != "TLS_AES_128_CCM_8_SHA256" { 299 c.enableXTLS = true 300 } 301 log.Debug("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS) 302 c.numberOfPacketToFilter = 0 303 return 304 } else if c.remainingServerHello == 0 { 305 log.Debug("XtlsFilterTls found tls 1.2! ", len(buffer)) 306 c.numberOfPacketToFilter = 0 307 return 308 } 309 } 310 if c.numberOfPacketToFilter == 0 { 311 log.Debug("XtlsFilterTls stop filtering ", len(buffer)) 312 } 313 } 314 } 315 316 func (c *VisionConn) padding(buffer []byte, command byte) []byte { 317 contentLen := 0 318 paddingLen := 0 319 if buffer != nil { 320 contentLen = len(buffer) 321 } 322 if contentLen < 900 && c.isTLS { 323 l, _ := rand.Int(rand.Reader, big.NewInt(500)) 324 paddingLen = int(l.Int64()) + 900 - contentLen 325 } else { 326 l, _ := rand.Int(rand.Reader, big.NewInt(256)) 327 paddingLen = int(l.Int64()) 328 } 329 var bufferLen int 330 if c.writeUUID { 331 bufferLen += 16 332 } 333 bufferLen += 5 334 if buffer != nil { 335 bufferLen += len(buffer) 336 } 337 bufferLen += paddingLen 338 339 newBuffer := bytes.NewBuffer(nil) 340 if c.writeUUID { 341 newBuffer.Write(c.userUUID[:]) 342 c.writeUUID = false 343 } 344 newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}) 345 if buffer != nil { 346 newBuffer.Write(buffer) 347 } 348 newBuffer.Write(make([]byte, paddingLen)) 349 // newBuffer.Extend(paddingLen) 350 log.Debug("XtlsPadding ", contentLen, " ", paddingLen, " ", command) 351 return newBuffer.Bytes() 352 } 353 354 func (c *VisionConn) unPadding(buffer []byte) [][]byte { 355 var bufferIndex int 356 if c.remainingContent == -1 && c.remainingPadding == -1 { 357 if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) { 358 bufferIndex = 16 359 c.remainingContent = 0 360 c.remainingPadding = 0 361 c.currentCommand = 0 362 } 363 } 364 if c.remainingContent == -1 && c.remainingPadding == -1 { 365 return [][]byte{buffer} 366 } 367 var buffers [][]byte 368 for bufferIndex < len(buffer) { 369 if c.remainingContent <= 0 && c.remainingPadding <= 0 { 370 if c.currentCommand == 1 { 371 buffers = append(buffers, buffer[bufferIndex:]) 372 break 373 } else { 374 paddingInfo := buffer[bufferIndex : bufferIndex+5] 375 c.currentCommand = paddingInfo[0] 376 c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2]) 377 c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4]) 378 bufferIndex += 5 379 log.Debug("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand) 380 } 381 } else if c.remainingContent > 0 { 382 end := c.remainingContent 383 if end > len(buffer)-bufferIndex { 384 end = len(buffer) - bufferIndex 385 } 386 buffers = append(buffers, buffer[bufferIndex:bufferIndex+end]) 387 c.remainingContent -= end 388 bufferIndex += end 389 } else { 390 end := c.remainingPadding 391 if end > len(buffer)-bufferIndex { 392 end = len(buffer) - bufferIndex 393 } 394 c.remainingPadding -= end 395 bufferIndex += end 396 } 397 if bufferIndex == len(buffer) { 398 break 399 } 400 } 401 return buffers 402 } 403 404 func (c *VisionConn) NeedAdditionalReadDeadline() bool { 405 return true 406 } 407 408 func (c *VisionConn) Upstream() any { 409 return c.Conn 410 }