github.com/EagleQL/Xray-core@v1.4.3/proxy/vless/inbound/inbound.go (about)

     1  package inbound
     2  
     3  //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"strconv"
     9  	"strings"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/xtls/xray-core/common"
    14  	"github.com/xtls/xray-core/common/buf"
    15  	"github.com/xtls/xray-core/common/errors"
    16  	"github.com/xtls/xray-core/common/log"
    17  	"github.com/xtls/xray-core/common/net"
    18  	"github.com/xtls/xray-core/common/platform"
    19  	"github.com/xtls/xray-core/common/protocol"
    20  	"github.com/xtls/xray-core/common/retry"
    21  	"github.com/xtls/xray-core/common/session"
    22  	"github.com/xtls/xray-core/common/signal"
    23  	"github.com/xtls/xray-core/common/task"
    24  	core "github.com/xtls/xray-core/core"
    25  	"github.com/xtls/xray-core/features/dns"
    26  	feature_inbound "github.com/xtls/xray-core/features/inbound"
    27  	"github.com/xtls/xray-core/features/policy"
    28  	"github.com/xtls/xray-core/features/routing"
    29  	"github.com/xtls/xray-core/features/stats"
    30  	"github.com/xtls/xray-core/proxy/vless"
    31  	"github.com/xtls/xray-core/proxy/vless/encoding"
    32  	"github.com/xtls/xray-core/transport/internet"
    33  	"github.com/xtls/xray-core/transport/internet/tls"
    34  	"github.com/xtls/xray-core/transport/internet/xtls"
    35  )
    36  
    37  var (
    38  	xtls_show = false
    39  )
    40  
    41  func init() {
    42  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    43  		var dc dns.Client
    44  		if err := core.RequireFeatures(ctx, func(d dns.Client) error {
    45  			dc = d
    46  			return nil
    47  		}); err != nil {
    48  			return nil, err
    49  		}
    50  		return New(ctx, config.(*Config), dc)
    51  	}))
    52  
    53  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
    54  
    55  	xtlsShow := platform.NewEnvFlag("xray.vless.xtls.show").GetValue(func() string { return defaultFlagValue })
    56  	if xtlsShow == "true" {
    57  		xtls_show = true
    58  	}
    59  }
    60  
    61  // Handler is an inbound connection handler that handles messages in VLess protocol.
    62  type Handler struct {
    63  	inboundHandlerManager feature_inbound.Manager
    64  	policyManager         policy.Manager
    65  	validator             *vless.Validator
    66  	dns                   dns.Client
    67  	fallbacks             map[string]map[string]map[string]*Fallback // or nil
    68  	// regexps               map[string]*regexp.Regexp       // or nil
    69  }
    70  
    71  // New creates a new VLess inbound handler.
    72  func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
    73  	v := core.MustFromContext(ctx)
    74  	handler := &Handler{
    75  		inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
    76  		policyManager:         v.GetFeature(policy.ManagerType()).(policy.Manager),
    77  		validator:             new(vless.Validator),
    78  		dns:                   dc,
    79  	}
    80  
    81  	for _, user := range config.Clients {
    82  		u, err := user.ToMemoryUser()
    83  		if err != nil {
    84  			return nil, newError("failed to get VLESS user").Base(err).AtError()
    85  		}
    86  		if err := handler.AddUser(ctx, u); err != nil {
    87  			return nil, newError("failed to initiate user").Base(err).AtError()
    88  		}
    89  	}
    90  
    91  	if config.Fallbacks != nil {
    92  		handler.fallbacks = make(map[string]map[string]map[string]*Fallback)
    93  		// handler.regexps = make(map[string]*regexp.Regexp)
    94  		for _, fb := range config.Fallbacks {
    95  			if handler.fallbacks[fb.Name] == nil {
    96  				handler.fallbacks[fb.Name] = make(map[string]map[string]*Fallback)
    97  			}
    98  			if handler.fallbacks[fb.Name][fb.Alpn] == nil {
    99  				handler.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback)
   100  			}
   101  			handler.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb
   102  			/*
   103  				if fb.Path != "" {
   104  					if r, err := regexp.Compile(fb.Path); err != nil {
   105  						return nil, newError("invalid path regexp").Base(err).AtError()
   106  					} else {
   107  						handler.regexps[fb.Path] = r
   108  					}
   109  				}
   110  			*/
   111  		}
   112  		if handler.fallbacks[""] != nil {
   113  			for name, apfb := range handler.fallbacks {
   114  				if name != "" {
   115  					for alpn := range handler.fallbacks[""] {
   116  						if apfb[alpn] == nil {
   117  							apfb[alpn] = make(map[string]*Fallback)
   118  						}
   119  					}
   120  				}
   121  			}
   122  		}
   123  		for _, apfb := range handler.fallbacks {
   124  			if apfb[""] != nil {
   125  				for alpn, pfb := range apfb {
   126  					if alpn != "" { // && alpn != "h2" {
   127  						for path, fb := range apfb[""] {
   128  							if pfb[path] == nil {
   129  								pfb[path] = fb
   130  							}
   131  						}
   132  					}
   133  				}
   134  			}
   135  		}
   136  		if handler.fallbacks[""] != nil {
   137  			for name, apfb := range handler.fallbacks {
   138  				if name != "" {
   139  					for alpn, pfb := range handler.fallbacks[""] {
   140  						for path, fb := range pfb {
   141  							if apfb[alpn][path] == nil {
   142  								apfb[alpn][path] = fb
   143  							}
   144  						}
   145  					}
   146  				}
   147  			}
   148  		}
   149  	}
   150  
   151  	return handler, nil
   152  }
   153  
   154  // Close implements common.Closable.Close().
   155  func (h *Handler) Close() error {
   156  	return errors.Combine(common.Close(h.validator))
   157  }
   158  
   159  // AddUser implements proxy.UserManager.AddUser().
   160  func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   161  	return h.validator.Add(u)
   162  }
   163  
   164  // RemoveUser implements proxy.UserManager.RemoveUser().
   165  func (h *Handler) RemoveUser(ctx context.Context, e string) error {
   166  	return h.validator.Del(e)
   167  }
   168  
   169  // Network implements proxy.Inbound.Network().
   170  func (*Handler) Network() []net.Network {
   171  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   172  }
   173  
   174  // Process implements proxy.Inbound.Process().
   175  func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error {
   176  	sid := session.ExportIDToError(ctx)
   177  
   178  	iConn := connection
   179  	statConn, ok := iConn.(*internet.StatCouterConnection)
   180  	if ok {
   181  		iConn = statConn.Connection
   182  	}
   183  
   184  	sessionPolicy := h.policyManager.ForLevel(0)
   185  	if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   186  		return newError("unable to set read deadline").Base(err).AtWarning()
   187  	}
   188  
   189  	first := buf.New()
   190  	defer first.Release()
   191  
   192  	firstLen, _ := first.ReadFrom(connection)
   193  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   194  
   195  	reader := &buf.BufferedReader{
   196  		Reader: buf.NewReader(connection),
   197  		Buffer: buf.MultiBuffer{first},
   198  	}
   199  
   200  	var request *protocol.RequestHeader
   201  	var requestAddons *encoding.Addons
   202  	var err error
   203  
   204  	napfb := h.fallbacks
   205  	isfb := napfb != nil
   206  
   207  	if isfb && firstLen < 18 {
   208  		err = newError("fallback directly")
   209  	} else {
   210  		request, requestAddons, isfb, err = encoding.DecodeRequestHeader(isfb, first, reader, h.validator)
   211  	}
   212  
   213  	if err != nil {
   214  		if isfb {
   215  			if err := connection.SetReadDeadline(time.Time{}); err != nil {
   216  				newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   217  			}
   218  			newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   219  
   220  			name := ""
   221  			alpn := ""
   222  			if tlsConn, ok := iConn.(*tls.Conn); ok {
   223  				cs := tlsConn.ConnectionState()
   224  				name = cs.ServerName
   225  				alpn = cs.NegotiatedProtocol
   226  				newError("realName = " + name).AtInfo().WriteToLog(sid)
   227  				newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   228  			} else if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   229  				cs := xtlsConn.ConnectionState()
   230  				name = cs.ServerName
   231  				alpn = cs.NegotiatedProtocol
   232  				newError("realName = " + name).AtInfo().WriteToLog(sid)
   233  				newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   234  			}
   235  			name = strings.ToLower(name)
   236  			alpn = strings.ToLower(alpn)
   237  
   238  			if len(napfb) > 1 || napfb[""] == nil {
   239  				if name != "" && napfb[name] == nil {
   240  					match := ""
   241  					for n := range napfb {
   242  						if n != "" && strings.Contains(name, n) && len(n) > len(match) {
   243  							match = n
   244  						}
   245  					}
   246  					name = match
   247  				}
   248  			}
   249  
   250  			if napfb[name] == nil {
   251  				name = ""
   252  			}
   253  			apfb := napfb[name]
   254  			if apfb == nil {
   255  				return newError(`failed to find the default "name" config`).AtWarning()
   256  			}
   257  
   258  			if apfb[alpn] == nil {
   259  				alpn = ""
   260  			}
   261  			pfb := apfb[alpn]
   262  			if pfb == nil {
   263  				return newError(`failed to find the default "alpn" config`).AtWarning()
   264  			}
   265  
   266  			path := ""
   267  			if len(pfb) > 1 || pfb[""] == nil {
   268  				/*
   269  					if lines := bytes.Split(firstBytes, []byte{'\r', '\n'}); len(lines) > 1 {
   270  						if s := bytes.Split(lines[0], []byte{' '}); len(s) == 3 {
   271  							if len(s[0]) < 8 && len(s[1]) > 0 && len(s[2]) == 8 {
   272  								newError("realPath = " + string(s[1])).AtInfo().WriteToLog(sid)
   273  								for _, fb := range pfb {
   274  									if fb.Path != "" && h.regexps[fb.Path].Match(s[1]) {
   275  										path = fb.Path
   276  										break
   277  									}
   278  								}
   279  							}
   280  						}
   281  					}
   282  				*/
   283  				if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   284  					firstBytes := first.Bytes()
   285  					for i := 4; i <= 8; i++ { // 5 -> 9
   286  						if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   287  							search := len(firstBytes)
   288  							if search > 64 {
   289  								search = 64 // up to about 60
   290  							}
   291  							for j := i + 1; j < search; j++ {
   292  								k := firstBytes[j]
   293  								if k == '\r' || k == '\n' { // avoid logging \r or \n
   294  									break
   295  								}
   296  								if k == '?' || k == ' ' {
   297  									path = string(firstBytes[i:j])
   298  									newError("realPath = " + path).AtInfo().WriteToLog(sid)
   299  									if pfb[path] == nil {
   300  										path = ""
   301  									}
   302  									break
   303  								}
   304  							}
   305  							break
   306  						}
   307  					}
   308  				}
   309  			}
   310  			fb := pfb[path]
   311  			if fb == nil {
   312  				return newError(`failed to find the default "path" config`).AtWarning()
   313  			}
   314  
   315  			ctx, cancel := context.WithCancel(ctx)
   316  			timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   317  			ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   318  
   319  			var conn net.Conn
   320  			if err := retry.ExponentialBackoff(5, 100).On(func() error {
   321  				var dialer net.Dialer
   322  				conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   323  				if err != nil {
   324  					return err
   325  				}
   326  				return nil
   327  			}); err != nil {
   328  				return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   329  			}
   330  			defer conn.Close()
   331  
   332  			serverReader := buf.NewReader(conn)
   333  			serverWriter := buf.NewWriter(conn)
   334  
   335  			postRequest := func() error {
   336  				defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   337  				if fb.Xver != 0 {
   338  					ipType := 4
   339  					remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   340  					if err != nil {
   341  						ipType = 0
   342  					}
   343  					localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   344  					if err != nil {
   345  						ipType = 0
   346  					}
   347  					if ipType == 4 {
   348  						for i := 0; i < len(remoteAddr); i++ {
   349  							if remoteAddr[i] == ':' {
   350  								ipType = 6
   351  								break
   352  							}
   353  						}
   354  					}
   355  					pro := buf.New()
   356  					defer pro.Release()
   357  					switch fb.Xver {
   358  					case 1:
   359  						if ipType == 0 {
   360  							pro.Write([]byte("PROXY UNKNOWN\r\n"))
   361  							break
   362  						}
   363  						if ipType == 4 {
   364  							pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
   365  						} else {
   366  							pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
   367  						}
   368  					case 2:
   369  						pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A")) // signature
   370  						if ipType == 0 {
   371  							pro.Write([]byte("\x20\x00\x00\x00")) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
   372  							break
   373  						}
   374  						if ipType == 4 {
   375  							pro.Write([]byte("\x21\x11\x00\x0C")) // v2 + PROXY + AF_INET + STREAM + 12 bytes
   376  							pro.Write(net.ParseIP(remoteAddr).To4())
   377  							pro.Write(net.ParseIP(localAddr).To4())
   378  						} else {
   379  							pro.Write([]byte("\x21\x21\x00\x24")) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
   380  							pro.Write(net.ParseIP(remoteAddr).To16())
   381  							pro.Write(net.ParseIP(localAddr).To16())
   382  						}
   383  						p1, _ := strconv.ParseUint(remotePort, 10, 16)
   384  						p2, _ := strconv.ParseUint(localPort, 10, 16)
   385  						pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)})
   386  					}
   387  					if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   388  						return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   389  					}
   390  				}
   391  				if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   392  					return newError("failed to fallback request payload").Base(err).AtInfo()
   393  				}
   394  				return nil
   395  			}
   396  
   397  			writer := buf.NewWriter(connection)
   398  
   399  			getResponse := func() error {
   400  				defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   401  				if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   402  					return newError("failed to deliver response payload").Base(err).AtInfo()
   403  				}
   404  				return nil
   405  			}
   406  
   407  			if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   408  				common.Interrupt(serverReader)
   409  				common.Interrupt(serverWriter)
   410  				return newError("fallback ends").Base(err).AtInfo()
   411  			}
   412  			return nil
   413  		}
   414  
   415  		if errors.Cause(err) != io.EOF {
   416  			log.Record(&log.AccessMessage{
   417  				From:   connection.RemoteAddr(),
   418  				To:     "",
   419  				Status: log.AccessRejected,
   420  				Reason: err,
   421  			})
   422  			err = newError("invalid request from ", connection.RemoteAddr()).Base(err).AtInfo()
   423  		}
   424  		return err
   425  	}
   426  
   427  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   428  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   429  	}
   430  	newError("received request for ", request.Destination()).AtInfo().WriteToLog(sid)
   431  
   432  	inbound := session.InboundFromContext(ctx)
   433  	if inbound == nil {
   434  		panic("no inbound metadata")
   435  	}
   436  	inbound.User = request.User
   437  
   438  	account := request.User.Account.(*vless.MemoryAccount)
   439  
   440  	responseAddons := &encoding.Addons{
   441  		// Flow: requestAddons.Flow,
   442  	}
   443  
   444  	var rawConn syscall.RawConn
   445  
   446  	switch requestAddons.Flow {
   447  	case vless.XRO, vless.XRD:
   448  		if account.Flow == requestAddons.Flow {
   449  			switch request.Command {
   450  			case protocol.RequestCommandMux:
   451  				return newError(requestAddons.Flow + " doesn't support Mux").AtWarning()
   452  			case protocol.RequestCommandUDP:
   453  				return newError(requestAddons.Flow + " doesn't support UDP").AtWarning()
   454  			case protocol.RequestCommandTCP:
   455  				if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   456  					xtlsConn.RPRX = true
   457  					xtlsConn.SHOW = xtls_show
   458  					xtlsConn.MARK = "XTLS"
   459  					if requestAddons.Flow == vless.XRD {
   460  						xtlsConn.DirectMode = true
   461  						if sc, ok := xtlsConn.Connection.(syscall.Conn); ok {
   462  							rawConn, _ = sc.SyscallConn()
   463  						}
   464  					}
   465  				} else {
   466  					return newError(`failed to use ` + requestAddons.Flow + `, maybe "security" is not "xtls"`).AtWarning()
   467  				}
   468  			}
   469  		} else {
   470  			return newError(account.ID.String() + " is not able to use " + requestAddons.Flow).AtWarning()
   471  		}
   472  	case "":
   473  	default:
   474  		return newError("unknown request flow " + requestAddons.Flow).AtWarning()
   475  	}
   476  
   477  	if request.Command != protocol.RequestCommandMux {
   478  		ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   479  			From:   connection.RemoteAddr(),
   480  			To:     request.Destination(),
   481  			Status: log.AccessAccepted,
   482  			Reason: "",
   483  			Email:  request.User.Email,
   484  		})
   485  	}
   486  
   487  	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
   488  	ctx, cancel := context.WithCancel(ctx)
   489  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   490  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   491  
   492  	link, err := dispatcher.Dispatch(ctx, request.Destination())
   493  	if err != nil {
   494  		return newError("failed to dispatch request to ", request.Destination()).Base(err).AtWarning()
   495  	}
   496  
   497  	serverReader := link.Reader // .(*pipe.Reader)
   498  	serverWriter := link.Writer // .(*pipe.Writer)
   499  
   500  	postRequest := func() error {
   501  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   502  
   503  		// default: clientReader := reader
   504  		clientReader := encoding.DecodeBodyAddons(reader, request, requestAddons)
   505  
   506  		var err error
   507  
   508  		if rawConn != nil {
   509  			var counter stats.Counter
   510  			if statConn != nil {
   511  				counter = statConn.ReadCounter
   512  			}
   513  			err = encoding.ReadV(clientReader, serverWriter, timer, iConn.(*xtls.Conn), rawConn, counter, nil)
   514  		} else {
   515  			// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBufer
   516  			err = buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer))
   517  		}
   518  
   519  		if err != nil {
   520  			return newError("failed to transfer request payload").Base(err).AtInfo()
   521  		}
   522  
   523  		return nil
   524  	}
   525  
   526  	getResponse := func() error {
   527  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   528  
   529  		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(connection))
   530  		if err := encoding.EncodeResponseHeader(bufferWriter, request, responseAddons); err != nil {
   531  			return newError("failed to encode response header").Base(err).AtWarning()
   532  		}
   533  
   534  		// default: clientWriter := bufferWriter
   535  		clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, responseAddons)
   536  		{
   537  			multiBuffer, err := serverReader.ReadMultiBuffer()
   538  			if err != nil {
   539  				return err // ...
   540  			}
   541  			if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil {
   542  				return err // ...
   543  			}
   544  		}
   545  
   546  		// Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer
   547  		if err := bufferWriter.SetBuffered(false); err != nil {
   548  			return newError("failed to write A response payload").Base(err).AtWarning()
   549  		}
   550  
   551  		// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBufer
   552  		if err := buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   553  			return newError("failed to transfer response payload").Base(err).AtInfo()
   554  		}
   555  
   556  		// Indicates the end of response payload.
   557  		switch responseAddons.Flow {
   558  		default:
   559  		}
   560  
   561  		return nil
   562  	}
   563  
   564  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), getResponse); err != nil {
   565  		common.Interrupt(serverReader)
   566  		common.Interrupt(serverWriter)
   567  		return newError("connection ends").Base(err).AtInfo()
   568  	}
   569  
   570  	return nil
   571  }