github.com/xmplusdev/xray-core@v1.8.10/proxy/proxy.go (about)

     1  // Package proxy contains all proxies used by Xray.
     2  //
     3  // To implement an inbound or outbound proxy, one needs to do the following:
     4  // 1. Implement the interface(s) below.
     5  // 2. Register a config creator through common.RegisterConfig.
     6  package proxy
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/rand"
    12  	"io"
    13  	"math/big"
    14  	"runtime"
    15  	"strconv"
    16  	"time"
    17  
    18  	"github.com/pires/go-proxyproto"
    19  	"github.com/xmplusdev/xray-core/app/dispatcher"
    20  	"github.com/xmplusdev/xray-core/common/buf"
    21  	"github.com/xmplusdev/xray-core/common/errors"
    22  	"github.com/xmplusdev/xray-core/common/net"
    23  	"github.com/xmplusdev/xray-core/common/protocol"
    24  	"github.com/xmplusdev/xray-core/common/session"
    25  	"github.com/xmplusdev/xray-core/common/signal"
    26  	"github.com/xmplusdev/xray-core/features/routing"
    27  	"github.com/xmplusdev/xray-core/features/stats"
    28  	"github.com/xmplusdev/xray-core/transport"
    29  	"github.com/xmplusdev/xray-core/transport/internet"
    30  	"github.com/xmplusdev/xray-core/transport/internet/reality"
    31  	"github.com/xmplusdev/xray-core/transport/internet/stat"
    32  	"github.com/xmplusdev/xray-core/transport/internet/tls"
    33  )
    34  
// Byte patterns and lookup tables used by the Vision reader/writer and
// XtlsFilterTls to recognize TLS records inside proxied traffic.
var (
	// Tls13SupportedVersions is the byte sequence XtlsFilterTls searches for
	// inside a server hello; finding it is treated as "this is TLS 1.3".
	Tls13SupportedVersions  = []byte{0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}
	// TlsClientHandShakeStart is compared with the first two bytes of a buffer
	// when looking for a client hello.
	TlsClientHandShakeStart = []byte{0x16, 0x03}
	// TlsServerHandShakeStart is compared with the first three bytes of a
	// buffer when looking for a server hello.
	TlsServerHandShakeStart = []byte{0x16, 0x03, 0x03}
	// TlsApplicationDataStart is the three-byte prefix used to spot TLS
	// application data records (see VisionWriter and ReshapeMultiBuffer).
	TlsApplicationDataStart = []byte{0x17, 0x03, 0x03}

	// Tls13CipherSuiteDic maps the cipher suite id parsed from a server hello
	// to its name. XtlsFilterTls uses it for logging and to decide whether
	// XTLS may be enabled.
	Tls13CipherSuiteDic = map[uint16]string{
		0x1301: "TLS_AES_128_GCM_SHA256",
		0x1302: "TLS_AES_256_GCM_SHA384",
		0x1303: "TLS_CHACHA20_POLY1305_SHA256",
		0x1304: "TLS_AES_128_CCM_SHA256",
		0x1305: "TLS_AES_128_CCM_8_SHA256",
	}
)
    49  
const (
	// Handshake type values compared against the sixth byte of a buffer by
	// XtlsFilterTls to tell a client hello from a server hello.
	TlsHandshakeTypeClientHello byte = 0x01
	TlsHandshakeTypeServerHello byte = 0x02

	// Command byte written by XtlsPadding at the start of every padded block.
	// Continue: another padded block follows.
	CommandPaddingContinue byte = 0x00
	// End: this is the last padded block; the reader stops unpadding.
	CommandPaddingEnd      byte = 0x01
	// Direct: last padded block, and the reader additionally switches to direct copy.
	CommandPaddingDirect   byte = 0x02
)
    58  
// An Inbound processes inbound connections.
type Inbound interface {
	// Network returns the list of networks that this inbound supports.
	// Connections on networks that are not listed will not be passed into Process().
	Network() []net.Network

	// Process processes a connection of the given network. If necessary, the
	// Inbound can dispatch the connection to an Outbound through the supplied
	// routing.Dispatcher.
	Process(context.Context, net.Network, stat.Connection, routing.Dispatcher) error
}
    67  
// An Outbound processes outbound connections.
type Outbound interface {
	// Process processes the given connection (link). The given dialer may be
	// used to dial a system outbound connection.
	Process(context.Context, *transport.Link, internet.Dialer) error
}
    73  
// UserManager is the interface for Inbounds and Outbounds that can manage their users.
type UserManager interface {
	// AddUser adds a new user.
	AddUser(context.Context, *protocol.MemoryUser) error

	// RemoveUser removes a user by email (the string argument).
	RemoveUser(context.Context, string) error
}
    82  
// GetInbound is implemented by types that can hand out an Inbound.
type GetInbound interface {
	// GetInbound returns the Inbound held by the implementer.
	GetInbound() Inbound
}
    86  
// GetOutbound is implemented by types that can hand out an Outbound.
type GetOutbound interface {
	// GetOutbound returns the Outbound held by the implementer.
	GetOutbound() Outbound
}
    90  
    91  // TrafficState is used to track uplink and downlink of one connection
    92  // It is used by XTLS to determine if switch to raw copy mode, It is used by Vision to calculate padding
    93  type TrafficState struct {
    94  	UserUUID               []byte
    95  	NumberOfPacketToFilter int
    96  	EnableXtls             bool
    97  	IsTLS12orAbove         bool
    98  	IsTLS                  bool
    99  	Cipher                 uint16
   100  	RemainingServerHello   int32
   101  
   102  	// reader link state
   103  	WithinPaddingBuffers     bool
   104  	ReaderSwitchToDirectCopy bool
   105  	RemainingCommand		 int32
   106  	RemainingContent         int32
   107  	RemainingPadding         int32
   108  	CurrentCommand           int
   109  
   110  	// write link state
   111  	IsPadding                bool
   112  	WriterSwitchToDirectCopy bool
   113  }
   114  
   115  func NewTrafficState(userUUID []byte) *TrafficState {
   116  	return &TrafficState{
   117  		UserUUID:                 userUUID,
   118  		NumberOfPacketToFilter:   8,
   119  		EnableXtls:               false,
   120  		IsTLS12orAbove:           false,
   121  		IsTLS:                    false,
   122  		Cipher:                   0,
   123  		RemainingServerHello:     -1,
   124  		WithinPaddingBuffers:     true,
   125  		ReaderSwitchToDirectCopy: false,
   126  		RemainingCommand:         -1,
   127  		RemainingContent:         -1,
   128  		RemainingPadding:         -1,
   129  		CurrentCommand:           0,
   130  		IsPadding:                true,
   131  		WriterSwitchToDirectCopy: false,
   132  	}
   133  }
   134  
// VisionReader is used to read the xtls vision protocol.
// Note: Vision probably only makes sense as the innermost layer of reader,
// since it needs to assess traffic state from the original proxy traffic.
type VisionReader struct {
	buf.Reader
	// trafficState is the per-connection state updated while unpadding and filtering.
	trafficState *TrafficState
	// ctx is passed to the unpadding/filter helpers, which use it for logging.
	ctx          context.Context
}
   142  
   143  func NewVisionReader(reader buf.Reader, state *TrafficState, context context.Context) *VisionReader {
   144  	return &VisionReader{
   145  		Reader:       reader,
   146  		trafficState: state,
   147  		ctx:          context,
   148  	}
   149  }
   150  
   151  func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   152  	buffer, err := w.Reader.ReadMultiBuffer()
   153  	if !buffer.IsEmpty() {
   154  		if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
   155  			mb2 := make(buf.MultiBuffer, 0, len(buffer))
   156  			for _, b := range buffer {
   157  				newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx)
   158  				if newbuffer.Len() > 0 {
   159  					mb2 = append(mb2, newbuffer)
   160  				}
   161  			}
   162  			buffer = mb2
   163  			if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 || w.trafficState.CurrentCommand == 0 {
   164  				w.trafficState.WithinPaddingBuffers = true
   165  			} else if w.trafficState.CurrentCommand == 1 {
   166  				w.trafficState.WithinPaddingBuffers = false
   167  			} else if w.trafficState.CurrentCommand == 2 {
   168  				w.trafficState.WithinPaddingBuffers = false
   169  				w.trafficState.ReaderSwitchToDirectCopy = true
   170  			} else {
   171  				newError("XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len()).WriteToLog(session.ExportIDToError(w.ctx))
   172  			}
   173  		}
   174  		if w.trafficState.NumberOfPacketToFilter > 0 {
   175  			XtlsFilterTls(buffer, w.trafficState, w.ctx)
   176  		}
   177  	}
   178  	return buffer, err
   179  }
   180  
// VisionWriter is used to write the xtls vision protocol.
// Note: Vision probably only makes sense as the innermost layer of writer,
// since it needs to assess traffic state from the original proxy traffic.
type VisionWriter struct {
	buf.Writer
	// trafficState is the per-connection state consulted and updated while padding.
	trafficState      *TrafficState
	// ctx is passed to the padding/filter helpers, which use it for logging.
	ctx               context.Context
	// writeOnceUserUUID is a private copy of the user UUID. XtlsPadding writes
	// it in front of the first padded block and then clears it.
	writeOnceUserUUID []byte
}
   189  
   190  func NewVisionWriter(writer buf.Writer, state *TrafficState, context context.Context) *VisionWriter {
   191  	w := make([]byte, len(state.UserUUID))
   192  	copy(w, state.UserUUID)
   193  	return &VisionWriter{
   194  		Writer:            writer,
   195  		trafficState:      state,
   196  		ctx:               context,
   197  		writeOnceUserUUID: w,
   198  	}
   199  }
   200  
   201  func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   202  	if w.trafficState.NumberOfPacketToFilter > 0 {
   203  		XtlsFilterTls(mb, w.trafficState, w.ctx)
   204  	}
   205  	if w.trafficState.IsPadding {
   206  		if len(mb) == 1 && mb[0] == nil {
   207  			mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header
   208  			return w.Writer.WriteMultiBuffer(mb)
   209  		}
   210  		mb = ReshapeMultiBuffer(w.ctx, mb)
   211  		longPadding := w.trafficState.IsTLS
   212  		for i, b := range mb {
   213  			if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
   214  				if w.trafficState.EnableXtls {
   215  					w.trafficState.WriterSwitchToDirectCopy = true
   216  				}
   217  				var command byte = CommandPaddingContinue
   218  				if i == len(mb) - 1 {
   219  					command = CommandPaddingEnd
   220  					if w.trafficState.EnableXtls {
   221  						command = CommandPaddingDirect
   222  					}
   223  				}
   224  				mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx)
   225  				w.trafficState.IsPadding = false // padding going to end
   226  				longPadding = false
   227  				continue
   228  			} else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
   229  				w.trafficState.IsPadding = false
   230  				mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx)
   231  				break
   232  			}
   233  			var command byte = CommandPaddingContinue
   234  			if i == len(mb) - 1 && !w.trafficState.IsPadding {
   235  				command = CommandPaddingEnd
   236  				if w.trafficState.EnableXtls {
   237  					command = CommandPaddingDirect
   238  				}
   239  			}
   240  			mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, longPadding, w.ctx)
   241  		}
   242  	}
   243  	return w.Writer.WriteMultiBuffer(mb)
   244  }
   245  
   246  // ReshapeMultiBuffer prepare multi buffer for padding stucture (max 21 bytes)
   247  func ReshapeMultiBuffer(ctx context.Context, buffer buf.MultiBuffer) buf.MultiBuffer {
   248  	needReshape := 0
   249  	for _, b := range buffer {
   250  		if b.Len() >= buf.Size-21 {
   251  			needReshape += 1
   252  		}
   253  	}
   254  	if needReshape == 0 {
   255  		return buffer
   256  	}
   257  	mb2 := make(buf.MultiBuffer, 0, len(buffer)+needReshape)
   258  	toPrint := ""
   259  	for i, buffer1 := range buffer {
   260  		if buffer1.Len() >= buf.Size-21 {
   261  			index := int32(bytes.LastIndex(buffer1.Bytes(), TlsApplicationDataStart))
   262  			if index < 21 || index > buf.Size-21 {
   263  				index = buf.Size / 2
   264  			}
   265  			buffer2 := buf.New()
   266  			buffer2.Write(buffer1.BytesFrom(index))
   267  			buffer1.Resize(0, index)
   268  			mb2 = append(mb2, buffer1, buffer2)
   269  			toPrint += " " + strconv.Itoa(int(buffer1.Len())) + " " + strconv.Itoa(int(buffer2.Len()))
   270  		} else {
   271  			mb2 = append(mb2, buffer1)
   272  			toPrint += " " + strconv.Itoa(int(buffer1.Len()))
   273  		}
   274  		buffer[i] = nil
   275  	}
   276  	buffer = buffer[:0]
   277  	newError("ReshapeMultiBuffer ", toPrint).WriteToLog(session.ExportIDToError(ctx))
   278  	return mb2
   279  }
   280  
   281  // XtlsPadding add padding to eliminate length siganature during tls handshake
   282  func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool, ctx context.Context) *buf.Buffer {
   283  	var contentLen int32 = 0
   284  	var paddingLen int32 = 0
   285  	if b != nil {
   286  		contentLen = b.Len()
   287  	}
   288  	if contentLen < 900 && longPadding {
   289  		l, err := rand.Int(rand.Reader, big.NewInt(500))
   290  		if err != nil {
   291  			newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
   292  		}
   293  		paddingLen = int32(l.Int64()) + 900 - contentLen
   294  	} else {
   295  		l, err := rand.Int(rand.Reader, big.NewInt(256))
   296  		if err != nil {
   297  			newError("failed to generate padding").Base(err).WriteToLog(session.ExportIDToError(ctx))
   298  		}
   299  		paddingLen = int32(l.Int64())
   300  	}
   301  	if paddingLen > buf.Size-21-contentLen {
   302  		paddingLen = buf.Size - 21 - contentLen
   303  	}
   304  	newbuffer := buf.New()
   305  	if userUUID != nil {
   306  		newbuffer.Write(*userUUID)
   307  		*userUUID = nil
   308  	}
   309  	newbuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)})
   310  	if b != nil {
   311  		newbuffer.Write(b.Bytes())
   312  		b.Release()
   313  		b = nil
   314  	}
   315  	newbuffer.Extend(paddingLen)
   316  	newError("XtlsPadding ", contentLen, " ", paddingLen, " ", command).WriteToLog(session.ExportIDToError(ctx))
   317  	return newbuffer
   318  }
   319  
   320  // XtlsUnpadding remove padding and parse command
   321  func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer {
   322  	if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // inital state
   323  		if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) {
   324  			b.Advance(16)
   325  			s.RemainingCommand = 5
   326  		} else {
   327  			return b
   328  		}
   329  	}
   330  	newbuffer := buf.New()
   331  	for b.Len() > 0 {
   332  		if s.RemainingCommand > 0 {
   333  			data, err := b.ReadByte()
   334  			if err != nil {
   335  				return newbuffer
   336  			}
   337  			switch s.RemainingCommand {
   338  			case 5:
   339  				s.CurrentCommand = int(data)
   340  			case 4:
   341  				s.RemainingContent = int32(data)<<8
   342  			case 3:
   343  				s.RemainingContent = s.RemainingContent | int32(data)
   344  			case 2:
   345  				s.RemainingPadding = int32(data)<<8
   346  			case 1:
   347  				s.RemainingPadding = s.RemainingPadding | int32(data)
   348  				newError("Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand).WriteToLog(session.ExportIDToError(ctx))
   349  			}
   350  			s.RemainingCommand--
   351  		} else if s.RemainingContent > 0 {
   352  			len := s.RemainingContent
   353  			if b.Len() < len {
   354  				len = b.Len()
   355  			}
   356  			data, err := b.ReadBytes(len)
   357  			if err != nil {
   358  				return newbuffer
   359  			}
   360  			newbuffer.Write(data)
   361  			s.RemainingContent -= len
   362  		} else { // remainingPadding > 0
   363  			len := s.RemainingPadding
   364  			if b.Len() < len {
   365  				len = b.Len()
   366  			}
   367  			b.Advance(len)
   368  			s.RemainingPadding -= len
   369  		}
   370  		if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done
   371  			if s.CurrentCommand == 0 {
   372  				s.RemainingCommand = 5
   373  			} else {
   374  				s.RemainingCommand = -1 // set to initial state
   375  				s.RemainingContent = -1
   376  				s.RemainingPadding = -1
   377  				if b.Len() > 0 { // shouldn't happen
   378  					newbuffer.Write(b.Bytes())
   379  				}
   380  				break
   381  			}
   382  		}
   383  	}
   384  	b.Release()
   385  	b = nil
   386  	return newbuffer
   387  }
   388  
   389  // XtlsFilterTls filter and recognize tls 1.3 and other info
   390  func XtlsFilterTls(buffer buf.MultiBuffer, trafficState *TrafficState, ctx context.Context) {
   391  	for _, b := range buffer {
   392  		if b == nil {
   393  			continue
   394  		}
   395  		trafficState.NumberOfPacketToFilter--
   396  		if b.Len() >= 6 {
   397  			startsBytes := b.BytesTo(6)
   398  			if bytes.Equal(TlsServerHandShakeStart, startsBytes[:3]) && startsBytes[5] == TlsHandshakeTypeServerHello {
   399  				trafficState.RemainingServerHello = (int32(startsBytes[3])<<8 | int32(startsBytes[4])) + 5
   400  				trafficState.IsTLS12orAbove = true
   401  				trafficState.IsTLS = true
   402  				if b.Len() >= 79 && trafficState.RemainingServerHello >= 79 {
   403  					sessionIdLen := int32(b.Byte(43))
   404  					cipherSuite := b.BytesRange(43+sessionIdLen+1, 43+sessionIdLen+3)
   405  					trafficState.Cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
   406  				} else {
   407  					newError("XtlsFilterTls short server hello, tls 1.2 or older? ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx))
   408  				}
   409  			} else if bytes.Equal(TlsClientHandShakeStart, startsBytes[:2]) && startsBytes[5] == TlsHandshakeTypeClientHello {
   410  				trafficState.IsTLS = true
   411  				newError("XtlsFilterTls found tls client hello! ", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
   412  			}
   413  		}
   414  		if trafficState.RemainingServerHello > 0 {
   415  			end := trafficState.RemainingServerHello
   416  			if end > b.Len() {
   417  				end = b.Len()
   418  			}
   419  			trafficState.RemainingServerHello -= b.Len()
   420  			if bytes.Contains(b.BytesTo(end), Tls13SupportedVersions) {
   421  				v, ok := Tls13CipherSuiteDic[trafficState.Cipher]
   422  				if !ok {
   423  					v = "Old cipher: " + strconv.FormatUint(uint64(trafficState.Cipher), 16)
   424  				} else if v != "TLS_AES_128_CCM_8_SHA256" {
   425  					trafficState.EnableXtls = true
   426  				}
   427  				newError("XtlsFilterTls found tls 1.3! ", b.Len(), " ", v).WriteToLog(session.ExportIDToError(ctx))
   428  				trafficState.NumberOfPacketToFilter = 0
   429  				return
   430  			} else if trafficState.RemainingServerHello <= 0 {
   431  				newError("XtlsFilterTls found tls 1.2! ", b.Len()).WriteToLog(session.ExportIDToError(ctx))
   432  				trafficState.NumberOfPacketToFilter = 0
   433  				return
   434  			}
   435  			newError("XtlsFilterTls inconclusive server hello ", b.Len(), " ", trafficState.RemainingServerHello).WriteToLog(session.ExportIDToError(ctx))
   436  		}
   437  		if trafficState.NumberOfPacketToFilter <= 0 {
   438  			newError("XtlsFilterTls stop filtering", buffer.Len()).WriteToLog(session.ExportIDToError(ctx))
   439  		}
   440  	}
   441  }
   442  
   443  // UnwrapRawConn support unwrap stats, tls, utls, reality and proxyproto conn and get raw tcp conn from it
   444  func UnwrapRawConn(conn net.Conn) (net.Conn, stats.Counter, stats.Counter) {
   445  	var readCounter, writerCounter stats.Counter
   446  	if conn != nil {
   447  		statConn, ok := conn.(*stat.CounterConnection)
   448  		if ok {
   449  			conn = statConn.Connection
   450  			readCounter = statConn.ReadCounter
   451  			writerCounter = statConn.WriteCounter
   452  		}
   453  		if xc, ok := conn.(*tls.Conn); ok {
   454  			conn = xc.NetConn()
   455  		} else if utlsConn, ok := conn.(*tls.UConn); ok {
   456  			conn = utlsConn.NetConn()
   457  		} else if realityConn, ok := conn.(*reality.Conn); ok {
   458  			conn = realityConn.NetConn()
   459  		} else if realityUConn, ok := conn.(*reality.UConn); ok {
   460  			conn = realityUConn.NetConn()
   461  		}
   462  		if pc, ok := conn.(*proxyproto.Conn); ok {
   463  			conn = pc.Raw()
   464  			// 8192 > 4096, there is no need to process pc's bufReader
   465  		}
   466  	}
   467  	return conn, readCounter, writerCounter
   468  }
   469  
// CopyRawConnIfExist copies from readerConn to writer using the most efficient
// method available.
//   - If the caller doesn't want to turn on splice, it should not pass in both
//     reader conn and writer conn.
//   - writer is from *transport.Link.
//
// Both connections are unwrapped with UnwrapRawConn first. When the context
// carries an inbound, the unwrapped writer is a *net.TCPConn, both connections
// are non-nil and the OS is linux or android, the function loops until
// inbound.CanSpliceCopy becomes 1 (hand the rest of the copy to
// tc.ReadFrom) or 3 (give up and use the generic copy below), relaying one
// multi buffer per iteration in the meantime. In every other case it falls
// back to buf.Copy.
func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net.Conn, writer buf.Writer, timer signal.ActivityUpdater) error {
	readerConn, readCounter, _ := UnwrapRawConn(readerConn)
	writerConn, _, writeCounter := UnwrapRawConn(writerConn)
	reader := buf.NewReader(readerConn)
	if inbound := session.InboundFromContext(ctx); inbound != nil {
		if tc, ok := writerConn.(*net.TCPConn); ok && readerConn != nil && writerConn != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
			// Keep relaying through writer until the splice decision is final.
			for inbound.CanSpliceCopy != 3 {
				if inbound.CanSpliceCopy == 1 {
					newError("CopyRawConn splice").WriteToLog(session.ExportIDToError(ctx))
					// statWriter stays nil when writer is not a *dispatcher.SizeStatWriter.
					statWriter, _ := writer.(*dispatcher.SizeStatWriter)
					//runtime.Gosched() // necessary
					time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
					// Copies directly between the raw connections, bypassing writer;
					// w is the number of bytes transferred.
					w, err := tc.ReadFrom(readerConn)
					if readCounter != nil {
						readCounter.Add(w) // outbound stats
					}
					if writeCounter != nil {
						writeCounter.Add(w) // inbound stats
					}
					if statWriter != nil {
						statWriter.Counter.Add(w) // user stats
					}
					// EOF is the normal end of the copy, not a failure.
					if err != nil && errors.Cause(err) != io.EOF {
						return err
					}
					return nil
				}
				// Not (yet) allowed to splice: relay a single multi buffer through writer.
				buffer, err := reader.ReadMultiBuffer()
				if !buffer.IsEmpty() {
					if readCounter != nil {
						readCounter.Add(int64(buffer.Len()))
					}
					timer.Update()
					if werr := writer.WriteMultiBuffer(buffer); werr != nil {
						return werr
					}
				}
				if err != nil {
					return err
				}
			}
		}
	}
	// Generic path: buffered copy with activity tracking and read statistics.
	newError("CopyRawConn readv").WriteToLog(session.ExportIDToError(ctx))
	if err := buf.Copy(reader, writer, buf.UpdateActivity(timer), buf.AddToStatCounter(readCounter)); err != nil {
		return newError("failed to process response").Base(err)
	}
	return nil
}