github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vision/vision.go (about)

     1  package vision
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"fmt"
     9  	"io"
    10  	"math/big"
    11  	"net"
    12  	"reflect"
    13  	"time"
    14  	"unsafe"
    15  
    16  	"github.com/Asutorufa/yuhaiin/pkg/log"
    17  	utls "github.com/refraction-networking/utls"
    18  )
    19  
    20  var (
    21  	tls13SupportedVersions  = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}
    22  	tlsClientHandShakeStart = []byte{0x16, 0x03}
    23  	tlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
    24  	tlsApplicationDataStart = []byte{0x17, 0x03, 0x03}
    25  )
    26  
    27  const (
    28  	commandPaddingContinue byte = iota
    29  	commandPaddingEnd
    30  	commandPaddingDirect
    31  )
    32  
    33  var tls13CipherSuiteDic = map[uint16]string{
    34  	0x1301: "TLS_AES_128_GCM_SHA256",
    35  	0x1302: "TLS_AES_256_GCM_SHA384",
    36  	0x1303: "TLS_CHACHA20_POLY1305_SHA256",
    37  	0x1304: "TLS_AES_128_CCM_SHA256",
    38  	0x1305: "TLS_AES_128_CCM_8_SHA256",
    39  }
    40  
    41  func reshapeBuffer(b []byte) [][]byte {
    42  	const bufferLimit = 8192 - 21
    43  	if len(b) < bufferLimit {
    44  		return [][]byte{b}
    45  	}
    46  	index := int32(bytes.LastIndex(b, tlsApplicationDataStart))
    47  	if index <= 0 {
    48  		index = 8192 / 2
    49  	}
    50  
    51  	return [][]byte{b[:index], b[index:]}
    52  }
    53  
    54  const xrayChunkSize = 8192
    55  
    56  type VisionConn struct {
    57  	net.Conn
    58  	reader   *bufio.Reader
    59  	writer   net.Conn
    60  	input    *bytes.Reader
    61  	rawInput *bytes.Buffer
    62  	netConn  net.Conn
    63  
    64  	userUUID               [16]byte
    65  	isTLS                  bool
    66  	numberOfPacketToFilter int
    67  	isTLS12orAbove         bool
    68  	remainingServerHello   int32
    69  	cipher                 uint16
    70  	enableXTLS             bool
    71  	isPadding              bool
    72  	directWrite            bool
    73  	writeUUID              bool
    74  	withinPaddingBuffers   bool
    75  	remainingContent       int
    76  	remainingPadding       int
    77  	currentCommand         byte
    78  	directRead             bool
    79  	remainingReader        io.Reader
    80  }
    81  
    82  func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte) (*VisionConn, error) {
    83  	var (
    84  		reflectType    reflect.Type
    85  		reflectPointer unsafe.Pointer
    86  		netConn        net.Conn
    87  	)
    88  
    89  	switch underlying := tlsConn.(type) {
    90  	case *tls.Conn:
    91  		netConn = underlying.NetConn()
    92  		reflectType = reflect.TypeOf(underlying).Elem()
    93  		reflectPointer = unsafe.Pointer(underlying)
    94  	case *utls.UConn:
    95  		netConn = underlying.NetConn()
    96  		reflectType = reflect.TypeOf(underlying.Conn).Elem()
    97  		reflectPointer = unsafe.Pointer(underlying.Conn)
    98  	default:
    99  		return nil, fmt.Errorf(`failed to use vision, maybe "security" is not "tls" or "utls"`)
   100  	}
   101  
   102  	input, _ := reflectType.FieldByName("input")
   103  	rawInput, _ := reflectType.FieldByName("rawInput")
   104  
   105  	return &VisionConn{
   106  		Conn:     conn,
   107  		reader:   bufio.NewReaderSize(conn, xrayChunkSize),
   108  		writer:   conn,
   109  		input:    (*bytes.Reader)(unsafe.Add(reflectPointer, input.Offset)),
   110  		rawInput: (*bytes.Buffer)(unsafe.Add(reflectPointer, rawInput.Offset)),
   111  		netConn:  netConn,
   112  
   113  		userUUID:               userUUID,
   114  		numberOfPacketToFilter: 8,
   115  		remainingServerHello:   -1,
   116  		isPadding:              true,
   117  		writeUUID:              true,
   118  		withinPaddingBuffers:   true,
   119  		remainingContent:       -1,
   120  		remainingPadding:       -1,
   121  	}, nil
   122  }
   123  
   124  func (c *VisionConn) Read(p []byte) (n int, err error) {
   125  	if c.remainingReader != nil {
   126  		n, err = c.remainingReader.Read(p)
   127  		if err == io.EOF {
   128  			err = nil
   129  			c.remainingReader = nil
   130  		}
   131  		if n > 0 {
   132  			return
   133  		}
   134  	}
   135  
   136  	if c.directRead {
   137  		return c.netConn.Read(p)
   138  	}
   139  
   140  	var bufferBytes []byte
   141  	var chunkBuffer []byte
   142  	if len(p) > xrayChunkSize {
   143  		n, err = c.Conn.Read(p)
   144  		if err != nil {
   145  			return
   146  		}
   147  		bufferBytes = p[:n]
   148  	} else {
   149  		buf := make([]byte, xrayChunkSize)
   150  		n, err = c.reader.Read(buf)
   151  		if err != nil {
   152  			return 0, err
   153  		}
   154  		chunkBuffer = buf[:n]
   155  		bufferBytes = chunkBuffer
   156  	}
   157  	if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
   158  		buffers := c.unPadding(bufferBytes)
   159  
   160  		if c.remainingContent == 0 && c.remainingPadding == 0 {
   161  			if c.currentCommand == commandPaddingEnd {
   162  				c.withinPaddingBuffers = false
   163  				c.remainingContent = -1
   164  				c.remainingPadding = -1
   165  			} else if c.currentCommand == commandPaddingDirect {
   166  				c.withinPaddingBuffers = false
   167  				c.directRead = true
   168  
   169  				inputBuffer, err := io.ReadAll(c.input)
   170  				if err != nil {
   171  					return 0, err
   172  				}
   173  				buffers = append(buffers, inputBuffer)
   174  
   175  				rawInputBuffer, err := io.ReadAll(c.rawInput)
   176  				if err != nil {
   177  					return 0, err
   178  				}
   179  
   180  				buffers = append(buffers, rawInputBuffer)
   181  
   182  				log.Debug("XtlsRead readV")
   183  			} else if c.currentCommand == commandPaddingContinue {
   184  				c.withinPaddingBuffers = true
   185  			} else {
   186  				return 0, fmt.Errorf("unknown command %v", c.currentCommand)
   187  			}
   188  		} else if c.remainingContent > 0 || c.remainingPadding > 0 {
   189  			c.withinPaddingBuffers = true
   190  		} else {
   191  			c.withinPaddingBuffers = false
   192  		}
   193  		if c.numberOfPacketToFilter > 0 {
   194  			c.filterTLS(buffers)
   195  		}
   196  		nBuffers := net.Buffers(buffers)
   197  		c.remainingReader = &nBuffers
   198  		return c.Read(p)
   199  	} else {
   200  		if c.numberOfPacketToFilter > 0 {
   201  			c.filterTLS([][]byte{bufferBytes})
   202  		}
   203  		if chunkBuffer != nil {
   204  			n = copy(p, bufferBytes)
   205  		}
   206  		return
   207  	}
   208  }
   209  
   210  func (c *VisionConn) Write(p []byte) (n int, err error) {
   211  	if c.numberOfPacketToFilter > 0 {
   212  		c.filterTLS([][]byte{p})
   213  	}
   214  	if c.isPadding {
   215  		inputLen := len(p)
   216  		buffers := reshapeBuffer(p)
   217  		var specIndex int
   218  		for i, buffer := range buffers {
   219  			if c.isTLS && len(buffer) > 6 && bytes.Equal(tlsApplicationDataStart, buffer[:3]) {
   220  				var command byte = commandPaddingEnd
   221  				if c.enableXTLS {
   222  					c.directWrite = true
   223  					specIndex = i
   224  					command = commandPaddingDirect
   225  				}
   226  				c.isPadding = false
   227  				buffers[i] = c.padding(buffer, command)
   228  				break
   229  			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
   230  				c.isPadding = false
   231  				buffers[i] = c.padding(buffer, commandPaddingEnd)
   232  				break
   233  			}
   234  			buffers[i] = c.padding(buffer, commandPaddingContinue)
   235  		}
   236  
   237  		if c.directWrite {
   238  			encryptedBuffer := buffers[:specIndex+1]
   239  
   240  			for _, v := range encryptedBuffer {
   241  				_, err = c.writer.Write(v)
   242  				if err != nil {
   243  					return
   244  				}
   245  			}
   246  			buffers = buffers[specIndex+1:]
   247  			c.writer = c.netConn
   248  			time.Sleep(5 * time.Millisecond) // wtf
   249  		}
   250  
   251  		for _, v := range buffers {
   252  			_, err = c.writer.Write(v)
   253  			if err != nil {
   254  				return
   255  			}
   256  		}
   257  		n = inputLen
   258  		return
   259  	}
   260  
   261  	if c.directWrite {
   262  		return c.netConn.Write(p)
   263  	} else {
   264  		return c.Conn.Write(p)
   265  	}
   266  }
   267  
   268  func (c *VisionConn) filterTLS(buffers [][]byte) {
   269  	for _, buffer := range buffers {
   270  		c.numberOfPacketToFilter--
   271  		if len(buffer) > 6 {
   272  			if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
   273  				c.isTLS = true
   274  				if buffer[5] == 2 {
   275  					c.isTLS12orAbove = true
   276  					c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
   277  					if len(buffer) >= 79 && c.remainingServerHello >= 79 {
   278  						sessionIdLen := int32(buffer[43])
   279  						cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
   280  						c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
   281  					} else {
   282  						log.Info("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
   283  					}
   284  				}
   285  			} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
   286  				c.isTLS = true
   287  				log.Debug("XtlsFilterTls found tls client hello! ", len(buffer))
   288  			}
   289  		}
   290  		if c.remainingServerHello > 0 {
   291  			end := int(c.remainingServerHello)
   292  			if end > len(buffer) {
   293  				end = len(buffer)
   294  			}
   295  			c.remainingServerHello -= int32(end)
   296  			if bytes.Contains(buffer[:end], tls13SupportedVersions) {
   297  				cipher, ok := tls13CipherSuiteDic[c.cipher]
   298  				if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
   299  					c.enableXTLS = true
   300  				}
   301  				log.Debug("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
   302  				c.numberOfPacketToFilter = 0
   303  				return
   304  			} else if c.remainingServerHello == 0 {
   305  				log.Debug("XtlsFilterTls found tls 1.2! ", len(buffer))
   306  				c.numberOfPacketToFilter = 0
   307  				return
   308  			}
   309  		}
   310  		if c.numberOfPacketToFilter == 0 {
   311  			log.Debug("XtlsFilterTls stop filtering ", len(buffer))
   312  		}
   313  	}
   314  }
   315  
   316  func (c *VisionConn) padding(buffer []byte, command byte) []byte {
   317  	contentLen := 0
   318  	paddingLen := 0
   319  	if buffer != nil {
   320  		contentLen = len(buffer)
   321  	}
   322  	if contentLen < 900 && c.isTLS {
   323  		l, _ := rand.Int(rand.Reader, big.NewInt(500))
   324  		paddingLen = int(l.Int64()) + 900 - contentLen
   325  	} else {
   326  		l, _ := rand.Int(rand.Reader, big.NewInt(256))
   327  		paddingLen = int(l.Int64())
   328  	}
   329  	var bufferLen int
   330  	if c.writeUUID {
   331  		bufferLen += 16
   332  	}
   333  	bufferLen += 5
   334  	if buffer != nil {
   335  		bufferLen += len(buffer)
   336  	}
   337  	bufferLen += paddingLen
   338  
   339  	newBuffer := bytes.NewBuffer(nil)
   340  	if c.writeUUID {
   341  		newBuffer.Write(c.userUUID[:])
   342  		c.writeUUID = false
   343  	}
   344  	newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})
   345  	if buffer != nil {
   346  		newBuffer.Write(buffer)
   347  	}
   348  	newBuffer.Write(make([]byte, paddingLen))
   349  	// newBuffer.Extend(paddingLen)
   350  	log.Debug("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
   351  	return newBuffer.Bytes()
   352  }
   353  
// unPadding strips Vision padding from one chunk read from the TLS
// connection. It returns the content it contains as sub-slices of buffer
// (no copying).
//
// State carried across calls in c:
//   - remainingContent and remainingPadding both -1: no padded block seen
//     yet. The chunk is checked for the 16-byte user UUID followed by room
//     for a 5-byte header (len >= 21). If the UUID is absent, the chunk is
//     returned unchanged.
//   - Otherwise they count the content and padding bytes of the current
//     block still to come, and currentCommand holds that block's command.
//
// Each block is a 5-byte header (command, 2-byte big-endian content length,
// 2-byte big-endian padding length), followed by content and then padding.
// Once a block whose command is 1 (commandPaddingEnd) is exhausted, the rest
// of the chunk is returned as-is.
//
// NOTE(review): the header is sliced as buffer[bufferIndex:bufferIndex+5]
// without a length check. A chunk that ends fewer than 5 bytes after a block
// boundary therefore panics. This relies on every header arriving whole
// within one read — confirm, or carry partial headers across calls.
// NOTE(review): only command 1 passes trailing bytes through. Bytes after a
// finished commandPaddingDirect block would be parsed as another header.
func (c *VisionConn) unPadding(buffer []byte) [][]byte {
	var bufferIndex int
	if c.remainingContent == -1 && c.remainingPadding == -1 {
		// First padded chunk: skip the UUID prefix and start parsing headers.
		if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
			bufferIndex = 16
			c.remainingContent = 0
			c.remainingPadding = 0
			c.currentCommand = 0
		}
	}
	// Still no UUID seen: the peer is not padding, pass the chunk through.
	if c.remainingContent == -1 && c.remainingPadding == -1 {
		return [][]byte{buffer}
	}
	var buffers [][]byte
	for bufferIndex < len(buffer) {
		if c.remainingContent <= 0 && c.remainingPadding <= 0 {
			if c.currentCommand == 1 {
				// Padding ended: everything after the last block is raw data.
				buffers = append(buffers, buffer[bufferIndex:])
				break
			} else {
				// Start of a new block: parse its 5-byte header.
				paddingInfo := buffer[bufferIndex : bufferIndex+5]
				c.currentCommand = paddingInfo[0]
				c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
				c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
				bufferIndex += 5
				log.Debug("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
			}
		} else if c.remainingContent > 0 {
			// Collect as much of the block's content as this chunk holds.
			end := c.remainingContent
			if end > len(buffer)-bufferIndex {
				end = len(buffer) - bufferIndex
			}
			buffers = append(buffers, buffer[bufferIndex:bufferIndex+end])
			c.remainingContent -= end
			bufferIndex += end
		} else {
			// Skip as much of the block's padding as this chunk holds.
			end := c.remainingPadding
			if end > len(buffer)-bufferIndex {
				end = len(buffer) - bufferIndex
			}
			c.remainingPadding -= end
			bufferIndex += end
		}
		if bufferIndex == len(buffer) {
			break
		}
	}
	return buffers
}
   403  
   404  func (c *VisionConn) NeedAdditionalReadDeadline() bool {
   405  	return true
   406  }
   407  
   408  func (c *VisionConn) Upstream() any {
   409  	return c.Conn
   410  }