github.com/sagernet/sing-box@v1.2.7/transport/vless/vision.go (about)

     1  package vless
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"io"
     8  	"math/big"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  	"unsafe"
    13  
    14  	C "github.com/sagernet/sing-box/constant"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	"github.com/sagernet/sing/common/logger"
    20  	N "github.com/sagernet/sing/common/network"
    21  )
    22  
    23  var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)
    24  
    25  func init() {
    26  	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) {
    27  		tlsConn, loaded := conn.(*tls.Conn)
    28  		if !loaded {
    29  			return
    30  		}
    31  		return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn))
    32  	})
    33  }
    34  
    35  const xrayChunkSize = 8192
    36  
// VisionConn wraps a TLS connection and applies XTLS Vision framing on top of
// it: early traffic is wrapped in padding blocks (padding / unPadding), the
// tunnelled TLS handshake is sniffed (filterTLS), and, when enabled, reads and
// writes switch from the TLS connection to the connection underneath it
// (netConn).
//
// Fields:
//   - Conn: the TLS connection passed to NewVisionConn.
//   - reader: chunk reader over Conn (chunks of xrayChunkSize), used by Read
//     for small destination buffers.
//   - writer: vectorised writer; starts on Conn and is replaced by one on
//     netConn when direct writing begins.
//   - input, rawInput: pointers into the TLS struct's unexported "input" and
//     "rawInput" fields (located by reflection in NewVisionConn); drained by
//     Read when the peer switches to direct mode.
//   - netConn: the connection underneath the TLS layer.
//   - userUUID: written before the first outgoing padding block and expected
//     before the first incoming one.
//   - numberOfPacketToFilter: how many more buffers filterTLS inspects
//     (starts at 8).
//   - isTLS, isTLS12orAbove, remainingServerHello, cipher, enableXTLS: what
//     filterTLS has learned about the tunnelled TLS session;
//     remainingServerHello is -1 until a ServerHello is seen.
//   - isPadding, directWrite, writeUUID: write-side state.
//   - withinPaddingBuffers, remainingContent, remainingPadding,
//     currentCommand, directRead: read-side state; remainingContent and
//     remainingPadding are -1 while no padding block is being parsed.
//   - remainingReader: payload already read from the connection but not yet
//     returned by Read.
type VisionConn struct {
	net.Conn
	reader   *bufio.ChunkReader
	writer   N.VectorisedWriter
	input    *bytes.Reader
	rawInput *bytes.Buffer
	netConn  net.Conn
	logger   logger.Logger

	userUUID               [16]byte
	isTLS                  bool
	numberOfPacketToFilter int
	isTLS12orAbove         bool
	remainingServerHello   int32
	cipher                 uint16
	enableXTLS             bool
	isPadding              bool
	directWrite            bool
	writeUUID              bool
	withinPaddingBuffers   bool
	remainingContent       int
	remainingPadding       int
	currentCommand         int
	directRead             bool
	remainingReader        io.Reader
}
    63  
    64  func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
    65  	var (
    66  		loaded         bool
    67  		reflectType    reflect.Type
    68  		reflectPointer uintptr
    69  		netConn        net.Conn
    70  	)
    71  	for _, tlsCreator := range tlsRegistry {
    72  		loaded, netConn, reflectType, reflectPointer = tlsCreator(conn)
    73  		if loaded {
    74  			break
    75  		}
    76  	}
    77  	if !loaded {
    78  		return nil, C.ErrTLSRequired
    79  	}
    80  	input, _ := reflectType.FieldByName("input")
    81  	rawInput, _ := reflectType.FieldByName("rawInput")
    82  	return &VisionConn{
    83  		Conn:     conn,
    84  		reader:   bufio.NewChunkReader(conn, xrayChunkSize),
    85  		writer:   bufio.NewVectorisedWriter(conn),
    86  		input:    (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
    87  		rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
    88  		netConn:  netConn,
    89  		logger:   logger,
    90  
    91  		userUUID:               userUUID,
    92  		numberOfPacketToFilter: 8,
    93  		remainingServerHello:   -1,
    94  		isPadding:              true,
    95  		writeUUID:              true,
    96  		withinPaddingBuffers:   true,
    97  		remainingContent:       -1,
    98  		remainingPadding:       -1,
    99  	}, nil
   100  }
   101  
   102  func (c *VisionConn) Read(p []byte) (n int, err error) {
   103  	if c.remainingReader != nil {
   104  		n, err = c.remainingReader.Read(p)
   105  		if err == io.EOF {
   106  			c.remainingReader = nil
   107  		}
   108  		if n > 0 {
   109  			return
   110  		}
   111  	}
   112  	if c.directRead {
   113  		return c.netConn.Read(p)
   114  	}
   115  	var bufferBytes []byte
   116  	if len(p) > xrayChunkSize {
   117  		n, err = c.Conn.Read(p)
   118  		if err != nil {
   119  			return
   120  		}
   121  		bufferBytes = p[:n]
   122  	} else {
   123  		buffer, err := c.reader.ReadChunk()
   124  		if err != nil {
   125  			return 0, err
   126  		}
   127  		defer buffer.FullReset()
   128  		bufferBytes = buffer.Bytes()
   129  	}
   130  	if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
   131  		buffers := c.unPadding(bufferBytes)
   132  		if c.remainingContent == 0 && c.remainingPadding == 0 {
   133  			if c.currentCommand == 1 {
   134  				c.withinPaddingBuffers = false
   135  				c.remainingContent = -1
   136  				c.remainingPadding = -1
   137  			} else if c.currentCommand == 2 {
   138  				c.withinPaddingBuffers = false
   139  				c.directRead = true
   140  
   141  				inputBuffer, err := io.ReadAll(c.input)
   142  				if err != nil {
   143  					return 0, err
   144  				}
   145  				buffers = append(buffers, inputBuffer)
   146  
   147  				rawInputBuffer, err := io.ReadAll(c.rawInput)
   148  				if err != nil {
   149  					return 0, err
   150  				}
   151  
   152  				buffers = append(buffers, rawInputBuffer)
   153  
   154  				c.logger.Trace("XtlsRead readV")
   155  			} else if c.currentCommand == 0 {
   156  				c.withinPaddingBuffers = true
   157  			} else {
   158  				return 0, E.New("unknown command ", c.currentCommand)
   159  			}
   160  		} else if c.remainingContent > 0 || c.remainingPadding > 0 {
   161  			c.withinPaddingBuffers = true
   162  		} else {
   163  			c.withinPaddingBuffers = false
   164  		}
   165  		if c.numberOfPacketToFilter > 0 {
   166  			c.filterTLS(buffers)
   167  		}
   168  		c.remainingReader = io.MultiReader(common.Map(buffers, func(it []byte) io.Reader { return bytes.NewReader(it) })...)
   169  		return c.Read(p)
   170  	} else {
   171  		if c.numberOfPacketToFilter > 0 {
   172  			c.filterTLS([][]byte{bufferBytes})
   173  		}
   174  		return
   175  	}
   176  }
   177  
   178  func (c *VisionConn) Write(p []byte) (n int, err error) {
   179  	if c.numberOfPacketToFilter > 0 {
   180  		c.filterTLS([][]byte{p})
   181  	}
   182  	if c.isPadding {
   183  		inputLen := len(p)
   184  		buffers := reshapeBuffer(p)
   185  		var specIndex int
   186  		for i, buffer := range buffers {
   187  			if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
   188  				var command byte = commandPaddingEnd
   189  				if c.enableXTLS {
   190  					c.directWrite = true
   191  					specIndex = i
   192  					command = commandPaddingDirect
   193  				}
   194  				c.isPadding = false
   195  				buffers[i] = c.padding(buffer, command)
   196  				break
   197  			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
   198  				c.isPadding = false
   199  				buffers[i] = c.padding(buffer, commandPaddingEnd)
   200  				break
   201  			}
   202  			buffers[i] = c.padding(buffer, commandPaddingContinue)
   203  		}
   204  		if c.directWrite {
   205  			encryptedBuffer := buffers[:specIndex+1]
   206  			err = c.writer.WriteVectorised(encryptedBuffer)
   207  			if err != nil {
   208  				return
   209  			}
   210  			buffers = buffers[specIndex+1:]
   211  			c.writer = bufio.NewVectorisedWriter(c.netConn)
   212  			c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
   213  			time.Sleep(5 * time.Millisecond) // wtf
   214  		}
   215  		err = c.writer.WriteVectorised(buffers)
   216  		if err == nil {
   217  			n = inputLen
   218  		}
   219  		return
   220  	}
   221  	if c.directWrite {
   222  		return c.netConn.Write(p)
   223  	} else {
   224  		return c.Conn.Write(p)
   225  	}
   226  }
   227  
   228  func (c *VisionConn) filterTLS(buffers [][]byte) {
   229  	for _, buffer := range buffers {
   230  		c.numberOfPacketToFilter--
   231  		if len(buffer) > 6 {
   232  			if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
   233  				c.isTLS = true
   234  				if buffer[5] == 2 {
   235  					c.isTLS12orAbove = true
   236  					c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
   237  					if len(buffer) >= 79 && c.remainingServerHello >= 79 {
   238  						sessionIdLen := int32(buffer[43])
   239  						cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
   240  						c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
   241  					} else {
   242  						c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
   243  					}
   244  				}
   245  			} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
   246  				c.isTLS = true
   247  				c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
   248  			}
   249  		}
   250  		if c.remainingServerHello > 0 {
   251  			end := int(c.remainingServerHello)
   252  			if end > len(buffer) {
   253  				end = len(buffer)
   254  			}
   255  			c.remainingServerHello -= int32(end)
   256  			if bytes.Contains(buffer[:end], tls13SupportedVersions) {
   257  				cipher, ok := tls13CipherSuiteDic[c.cipher]
   258  				if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
   259  					c.enableXTLS = true
   260  				}
   261  				c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
   262  				c.numberOfPacketToFilter = 0
   263  				return
   264  			} else if c.remainingServerHello == 0 {
   265  				c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
   266  				c.numberOfPacketToFilter = 0
   267  				return
   268  			}
   269  		}
   270  		if c.numberOfPacketToFilter == 0 {
   271  			c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
   272  		}
   273  	}
   274  }
   275  
   276  func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
   277  	contentLen := 0
   278  	paddingLen := 0
   279  	if buffer != nil {
   280  		contentLen = buffer.Len()
   281  	}
   282  	if contentLen < 900 && c.isTLS {
   283  		l, _ := rand.Int(rand.Reader, big.NewInt(500))
   284  		paddingLen = int(l.Int64()) + 900 - contentLen
   285  	} else {
   286  		l, _ := rand.Int(rand.Reader, big.NewInt(256))
   287  		paddingLen = int(l.Int64())
   288  	}
   289  	var bufferLen int
   290  	if c.writeUUID {
   291  		bufferLen += 16
   292  	}
   293  	bufferLen += 5
   294  	if buffer != nil {
   295  		bufferLen += buffer.Len()
   296  	}
   297  	bufferLen += paddingLen
   298  	newBuffer := buf.NewSize(bufferLen)
   299  	if c.writeUUID {
   300  		common.Must1(newBuffer.Write(c.userUUID[:]))
   301  		c.writeUUID = false
   302  	}
   303  	common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
   304  	if buffer != nil {
   305  		common.Must1(newBuffer.Write(buffer.Bytes()))
   306  		buffer.Release()
   307  	}
   308  	newBuffer.Extend(paddingLen)
   309  	c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
   310  	return newBuffer
   311  }
   312  
   313  func (c *VisionConn) unPadding(buffer []byte) [][]byte {
   314  	var bufferIndex int
   315  	if c.remainingContent == -1 && c.remainingPadding == -1 {
   316  		if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
   317  			bufferIndex = 16
   318  			c.remainingContent = 0
   319  			c.remainingPadding = 0
   320  			c.currentCommand = 0
   321  		}
   322  	}
   323  	if c.remainingContent == -1 && c.remainingPadding == -1 {
   324  		return [][]byte{buffer}
   325  	}
   326  	var buffers [][]byte
   327  	for bufferIndex < len(buffer) {
   328  		if c.remainingContent <= 0 && c.remainingPadding <= 0 {
   329  			if c.currentCommand == 1 {
   330  				buffers = append(buffers, buffer[bufferIndex:])
   331  				break
   332  			} else {
   333  				paddingInfo := buffer[bufferIndex : bufferIndex+5]
   334  				c.currentCommand = int(paddingInfo[0])
   335  				c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
   336  				c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
   337  				bufferIndex += 5
   338  				c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
   339  			}
   340  		} else if c.remainingContent > 0 {
   341  			end := c.remainingContent
   342  			if end > len(buffer)-bufferIndex {
   343  				end = len(buffer) - bufferIndex
   344  			}
   345  			buffers = append(buffers, buffer[bufferIndex:bufferIndex+end])
   346  			c.remainingContent -= end
   347  			bufferIndex += end
   348  		} else {
   349  			end := c.remainingPadding
   350  			if end > len(buffer)-bufferIndex {
   351  				end = len(buffer) - bufferIndex
   352  			}
   353  			c.remainingPadding -= end
   354  			bufferIndex += end
   355  		}
   356  		if bufferIndex == len(buffer) {
   357  			break
   358  		}
   359  	}
   360  	return buffers
   361  }
   362  
   363  func (c *VisionConn) NeedAdditionalReadDeadline() bool {
   364  	return true
   365  }
   366  
   367  func (c *VisionConn) Upstream() any {
   368  	return c.Conn
   369  }