github.com/sagernet/sing-box@v1.9.0-rc.20/transport/vless/vision.go (about)

     1  package vless
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/tls"
     7  	"io"
     8  	"math/big"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  	"unsafe"
    13  
    14  	C "github.com/sagernet/sing-box/constant"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	"github.com/sagernet/sing/common/logger"
    20  	N "github.com/sagernet/sing/common/network"
    21  )
    22  
    23  var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)
    24  
    25  func init() {
    26  	tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr) {
    27  		tlsConn, loaded := common.Cast[*tls.Conn](conn)
    28  		if !loaded {
    29  			return
    30  		}
    31  		return true, tlsConn.NetConn(), reflect.TypeOf(tlsConn).Elem(), uintptr(unsafe.Pointer(tlsConn))
    32  	})
    33  }
    34  
    35  const xrayChunkSize = 8192
    36  
// VisionConn carries the VLESS Vision flow over the embedded Conn: outbound
// data is framed with padding, inbound frames are unpadded, and once the
// inner TLS session allows it, reads and/or writes switch to netConn.
type VisionConn struct {
	// Conn is the connection Vision frames are exchanged over.
	net.Conn
	// reader reads Conn in chunks of xrayChunkSize; used by Read when the
	// caller's slice is no larger than that.
	reader *bufio.ChunkReader
	// writer sends framed output; it starts on Conn and is replaced by one
	// on netConn when direct writing begins.
	writer N.VectorisedWriter
	// input and rawInput point at the TLS implementation's unexported
	// "input" and "rawInput" fields (located via reflection in
	// NewVisionConn). Read drains them when switching to direct reads.
	input    *bytes.Reader
	rawInput *bytes.Buffer
	// netConn is the connection beneath the TLS layer, as reported by the
	// tlsRegistry probe; used directly once directRead/directWrite are set.
	netConn net.Conn
	logger  logger.Logger

	// userUUID is written before the first outbound frame and expected at
	// the start of the first inbound frame.
	userUUID [16]byte
	// isTLS is set by filterTLS once a TLS handshake record is seen.
	isTLS bool
	// numberOfPacketToFilter counts how many more buffers filterTLS
	// inspects (starts at 8; set to 0 once the ServerHello is resolved).
	numberOfPacketToFilter int
	// isTLS12orAbove is set by filterTLS when a ServerHello is seen.
	isTLS12orAbove bool
	// remainingServerHello is the number of ServerHello record bytes still
	// to be scanned; -1 until a ServerHello is seen.
	remainingServerHello int32
	// cipher is the cipher suite parsed from the ServerHello.
	cipher uint16
	// enableXTLS allows the writer to switch to direct mode.
	enableXTLS bool
	// isPadding is true while outbound data is still being framed.
	isPadding bool
	// directWrite routes writes straight to netConn.
	directWrite bool
	// writeUUID is true until userUUID has been written.
	writeUUID bool
	// withinPaddingBuffers is true while inbound data is still framed.
	withinPaddingBuffers bool
	// remainingContent and remainingPadding are the bytes left in the
	// current inbound frame; -1 means frame parsing has not started (or
	// was reset after commandPaddingEnd).
	remainingContent int
	remainingPadding int
	// currentCommand is the command byte of the current inbound frame.
	currentCommand byte
	// directRead routes reads straight to netConn.
	directRead bool
	// remainingReader holds unpadded payload not yet returned by Read.
	remainingReader io.Reader
}
    63  
    64  func NewVisionConn(conn net.Conn, tlsConn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
    65  	var (
    66  		loaded         bool
    67  		reflectType    reflect.Type
    68  		reflectPointer uintptr
    69  		netConn        net.Conn
    70  	)
    71  	for _, tlsCreator := range tlsRegistry {
    72  		loaded, netConn, reflectType, reflectPointer = tlsCreator(tlsConn)
    73  		if loaded {
    74  			break
    75  		}
    76  	}
    77  	if !loaded {
    78  		return nil, C.ErrTLSRequired
    79  	}
    80  	input, _ := reflectType.FieldByName("input")
    81  	rawInput, _ := reflectType.FieldByName("rawInput")
    82  	return &VisionConn{
    83  		Conn:     conn,
    84  		reader:   bufio.NewChunkReader(conn, xrayChunkSize),
    85  		writer:   bufio.NewVectorisedWriter(conn),
    86  		input:    (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
    87  		rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
    88  		netConn:  netConn,
    89  		logger:   logger,
    90  
    91  		userUUID:               userUUID,
    92  		numberOfPacketToFilter: 8,
    93  		remainingServerHello:   -1,
    94  		isPadding:              true,
    95  		writeUUID:              true,
    96  		withinPaddingBuffers:   true,
    97  		remainingContent:       -1,
    98  		remainingPadding:       -1,
    99  	}, nil
   100  }
   101  
   102  func (c *VisionConn) Read(p []byte) (n int, err error) {
   103  	if c.remainingReader != nil {
   104  		n, err = c.remainingReader.Read(p)
   105  		if err == io.EOF {
   106  			err = nil
   107  			c.remainingReader = nil
   108  		}
   109  		if n > 0 {
   110  			return
   111  		}
   112  	}
   113  	if c.directRead {
   114  		return c.netConn.Read(p)
   115  	}
   116  	var bufferBytes []byte
   117  	var chunkBuffer *buf.Buffer
   118  	if len(p) > xrayChunkSize {
   119  		n, err = c.Conn.Read(p)
   120  		if err != nil {
   121  			return
   122  		}
   123  		bufferBytes = p[:n]
   124  	} else {
   125  		chunkBuffer, err = c.reader.ReadChunk()
   126  		if err != nil {
   127  			return 0, err
   128  		}
   129  		bufferBytes = chunkBuffer.Bytes()
   130  	}
   131  	if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
   132  		buffers := c.unPadding(bufferBytes)
   133  		if chunkBuffer != nil {
   134  			buffers = common.Map(buffers, func(it *buf.Buffer) *buf.Buffer {
   135  				return it.ToOwned()
   136  			})
   137  			chunkBuffer.Reset()
   138  		}
   139  		if c.remainingContent == 0 && c.remainingPadding == 0 {
   140  			if c.currentCommand == commandPaddingEnd {
   141  				c.withinPaddingBuffers = false
   142  				c.remainingContent = -1
   143  				c.remainingPadding = -1
   144  			} else if c.currentCommand == commandPaddingDirect {
   145  				c.withinPaddingBuffers = false
   146  				c.directRead = true
   147  
   148  				inputBuffer, err := io.ReadAll(c.input)
   149  				if err != nil {
   150  					return 0, err
   151  				}
   152  				buffers = append(buffers, buf.As(inputBuffer))
   153  
   154  				rawInputBuffer, err := io.ReadAll(c.rawInput)
   155  				if err != nil {
   156  					return 0, err
   157  				}
   158  
   159  				buffers = append(buffers, buf.As(rawInputBuffer))
   160  
   161  				c.logger.Trace("XtlsRead readV")
   162  			} else if c.currentCommand == commandPaddingContinue {
   163  				c.withinPaddingBuffers = true
   164  			} else {
   165  				return 0, E.New("unknown command ", c.currentCommand)
   166  			}
   167  		} else if c.remainingContent > 0 || c.remainingPadding > 0 {
   168  			c.withinPaddingBuffers = true
   169  		} else {
   170  			c.withinPaddingBuffers = false
   171  		}
   172  		if c.numberOfPacketToFilter > 0 {
   173  			c.filterTLS(buf.ToSliceMulti(buffers))
   174  		}
   175  		c.remainingReader = io.MultiReader(common.Map(buffers, func(it *buf.Buffer) io.Reader { return it })...)
   176  		return c.Read(p)
   177  	} else {
   178  		if c.numberOfPacketToFilter > 0 {
   179  			c.filterTLS([][]byte{bufferBytes})
   180  		}
   181  		if chunkBuffer != nil {
   182  			n = copy(p, bufferBytes)
   183  			chunkBuffer.Advance(n)
   184  		}
   185  		return
   186  	}
   187  }
   188  
   189  func (c *VisionConn) Write(p []byte) (n int, err error) {
   190  	if c.numberOfPacketToFilter > 0 {
   191  		c.filterTLS([][]byte{p})
   192  	}
   193  	if c.isPadding {
   194  		inputLen := len(p)
   195  		buffers := reshapeBuffer(p)
   196  		var specIndex int
   197  		for i, buffer := range buffers {
   198  			if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
   199  				var command byte = commandPaddingEnd
   200  				if c.enableXTLS {
   201  					c.directWrite = true
   202  					specIndex = i
   203  					command = commandPaddingDirect
   204  				}
   205  				c.isPadding = false
   206  				buffers[i] = c.padding(buffer, command)
   207  				break
   208  			} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
   209  				c.isPadding = false
   210  				buffers[i] = c.padding(buffer, commandPaddingEnd)
   211  				break
   212  			}
   213  			buffers[i] = c.padding(buffer, commandPaddingContinue)
   214  		}
   215  		if c.directWrite {
   216  			encryptedBuffer := buffers[:specIndex+1]
   217  			err = c.writer.WriteVectorised(encryptedBuffer)
   218  			if err != nil {
   219  				return
   220  			}
   221  			buffers = buffers[specIndex+1:]
   222  			c.writer = bufio.NewVectorisedWriter(c.netConn)
   223  			c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
   224  			time.Sleep(5 * time.Millisecond) // wtf
   225  		}
   226  		err = c.writer.WriteVectorised(buffers)
   227  		if err == nil {
   228  			n = inputLen
   229  		}
   230  		return
   231  	}
   232  	if c.directWrite {
   233  		return c.netConn.Write(p)
   234  	} else {
   235  		return c.Conn.Write(p)
   236  	}
   237  }
   238  
   239  func (c *VisionConn) filterTLS(buffers [][]byte) {
   240  	for _, buffer := range buffers {
   241  		c.numberOfPacketToFilter--
   242  		if len(buffer) > 6 {
   243  			if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
   244  				c.isTLS = true
   245  				if buffer[5] == 2 {
   246  					c.isTLS12orAbove = true
   247  					c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
   248  					if len(buffer) >= 79 && c.remainingServerHello >= 79 {
   249  						sessionIdLen := int32(buffer[43])
   250  						cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
   251  						c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
   252  					} else {
   253  						c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
   254  					}
   255  				}
   256  			} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
   257  				c.isTLS = true
   258  				c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
   259  			}
   260  		}
   261  		if c.remainingServerHello > 0 {
   262  			end := int(c.remainingServerHello)
   263  			if end > len(buffer) {
   264  				end = len(buffer)
   265  			}
   266  			c.remainingServerHello -= int32(end)
   267  			if bytes.Contains(buffer[:end], tls13SupportedVersions) {
   268  				cipher, ok := tls13CipherSuiteDic[c.cipher]
   269  				if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
   270  					c.enableXTLS = true
   271  				}
   272  				c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
   273  				c.numberOfPacketToFilter = 0
   274  				return
   275  			} else if c.remainingServerHello == 0 {
   276  				c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
   277  				c.numberOfPacketToFilter = 0
   278  				return
   279  			}
   280  		}
   281  		if c.numberOfPacketToFilter == 0 {
   282  			c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
   283  		}
   284  	}
   285  }
   286  
   287  func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
   288  	contentLen := 0
   289  	paddingLen := 0
   290  	if buffer != nil {
   291  		contentLen = buffer.Len()
   292  	}
   293  	if contentLen < 900 && c.isTLS {
   294  		l, _ := rand.Int(rand.Reader, big.NewInt(500))
   295  		paddingLen = int(l.Int64()) + 900 - contentLen
   296  	} else {
   297  		l, _ := rand.Int(rand.Reader, big.NewInt(256))
   298  		paddingLen = int(l.Int64())
   299  	}
   300  	var bufferLen int
   301  	if c.writeUUID {
   302  		bufferLen += 16
   303  	}
   304  	bufferLen += 5
   305  	if buffer != nil {
   306  		bufferLen += buffer.Len()
   307  	}
   308  	bufferLen += paddingLen
   309  	newBuffer := buf.NewSize(bufferLen)
   310  	if c.writeUUID {
   311  		common.Must1(newBuffer.Write(c.userUUID[:]))
   312  		c.writeUUID = false
   313  	}
   314  	common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
   315  	if buffer != nil {
   316  		common.Must1(newBuffer.Write(buffer.Bytes()))
   317  		buffer.Release()
   318  	}
   319  	newBuffer.Extend(paddingLen)
   320  	c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
   321  	return newBuffer
   322  }
   323  
   324  func (c *VisionConn) unPadding(buffer []byte) []*buf.Buffer {
   325  	var bufferIndex int
   326  	if c.remainingContent == -1 && c.remainingPadding == -1 {
   327  		if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
   328  			bufferIndex = 16
   329  			c.remainingContent = 0
   330  			c.remainingPadding = 0
   331  			c.currentCommand = 0
   332  		}
   333  	}
   334  	if c.remainingContent == -1 && c.remainingPadding == -1 {
   335  		return []*buf.Buffer{buf.As(buffer)}
   336  	}
   337  	var buffers []*buf.Buffer
   338  	for bufferIndex < len(buffer) {
   339  		if c.remainingContent <= 0 && c.remainingPadding <= 0 {
   340  			if c.currentCommand == 1 {
   341  				buffers = append(buffers, buf.As(buffer[bufferIndex:]))
   342  				break
   343  			} else {
   344  				paddingInfo := buffer[bufferIndex : bufferIndex+5]
   345  				c.currentCommand = paddingInfo[0]
   346  				c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
   347  				c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
   348  				bufferIndex += 5
   349  				c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
   350  			}
   351  		} else if c.remainingContent > 0 {
   352  			end := c.remainingContent
   353  			if end > len(buffer)-bufferIndex {
   354  				end = len(buffer) - bufferIndex
   355  			}
   356  			buffers = append(buffers, buf.As(buffer[bufferIndex:bufferIndex+end]))
   357  			c.remainingContent -= end
   358  			bufferIndex += end
   359  		} else {
   360  			end := c.remainingPadding
   361  			if end > len(buffer)-bufferIndex {
   362  				end = len(buffer) - bufferIndex
   363  			}
   364  			c.remainingPadding -= end
   365  			bufferIndex += end
   366  		}
   367  		if bufferIndex == len(buffer) {
   368  			break
   369  		}
   370  	}
   371  	return buffers
   372  }
   373  
// NeedAdditionalReadDeadline always reports true for VisionConn.
// NOTE(review): presumably a hint to the caller's copy/deadline machinery
// that it must manage read deadlines itself; confirm against the interface
// this satisfies.
func (c *VisionConn) NeedAdditionalReadDeadline() bool {
	return true
}
   377  
// Upstream returns the wrapped connection (the conn given to NewVisionConn).
func (c *VisionConn) Upstream() any {
	return c.Conn
}