github.com/imannamdari/v2ray-core/v5@v5.0.5/proxy/vless/inbound/inbound.go (about)

     1  package inbound
     2  
     3  //go:generate go run github.com/imannamdari/v2ray-core/v5/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"strconv"
     9  	"time"
    10  
    11  	core "github.com/imannamdari/v2ray-core/v5"
    12  	"github.com/imannamdari/v2ray-core/v5/common"
    13  	"github.com/imannamdari/v2ray-core/v5/common/buf"
    14  	"github.com/imannamdari/v2ray-core/v5/common/errors"
    15  	"github.com/imannamdari/v2ray-core/v5/common/log"
    16  	"github.com/imannamdari/v2ray-core/v5/common/net"
    17  	"github.com/imannamdari/v2ray-core/v5/common/protocol"
    18  	"github.com/imannamdari/v2ray-core/v5/common/retry"
    19  	"github.com/imannamdari/v2ray-core/v5/common/serial"
    20  	"github.com/imannamdari/v2ray-core/v5/common/session"
    21  	"github.com/imannamdari/v2ray-core/v5/common/signal"
    22  	"github.com/imannamdari/v2ray-core/v5/common/task"
    23  	"github.com/imannamdari/v2ray-core/v5/features/dns"
    24  	feature_inbound "github.com/imannamdari/v2ray-core/v5/features/inbound"
    25  	"github.com/imannamdari/v2ray-core/v5/features/policy"
    26  	"github.com/imannamdari/v2ray-core/v5/features/routing"
    27  	"github.com/imannamdari/v2ray-core/v5/proxy/vless"
    28  	"github.com/imannamdari/v2ray-core/v5/proxy/vless/encoding"
    29  	"github.com/imannamdari/v2ray-core/v5/transport/internet"
    30  	"github.com/imannamdari/v2ray-core/v5/transport/internet/tls"
    31  )
    32  
    33  func init() {
    34  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    35  		var dc dns.Client
    36  		if err := core.RequireFeatures(ctx, func(d dns.Client) error {
    37  			dc = d
    38  			return nil
    39  		}); err != nil {
    40  			return nil, err
    41  		}
    42  		return New(ctx, config.(*Config), dc)
    43  	}))
    44  
    45  	common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    46  		simplifiedServer := config.(*SimplifiedConfig)
    47  		fullConfig := &Config{
    48  			Clients: func() (users []*protocol.User) {
    49  				for _, v := range simplifiedServer.Users {
    50  					account := &vless.Account{Id: v}
    51  					users = append(users, &protocol.User{
    52  						Account: serial.ToTypedMessage(account),
    53  					})
    54  				}
    55  				return
    56  			}(),
    57  			Decryption: "none",
    58  		}
    59  
    60  		return common.CreateObject(ctx, fullConfig)
    61  	}))
    62  }
    63  
// Handler is an inbound connection handler that handles messages in VLess protocol.
//
// It authenticates clients against its validator, dispatches valid requests
// through the router, and, when fallbacks are configured, relays connections
// that are not valid VLESS requests to a fallback destination.
type Handler struct {
	// inboundHandlerManager is the core's inbound handler manager feature, looked up in New.
	inboundHandlerManager feature_inbound.Manager
	// policyManager supplies per-user-level session policies (timeouts, buffer settings).
	policyManager         policy.Manager
	// validator holds the registered users; AddUser/RemoveUser modify it and
	// request-header decoding consults it.
	validator             *vless.Validator
	// dns is the DNS client passed to New; it is stored but not read in this file.
	dns                   dns.Client
	// fallbacks maps ALPN -> HTTP path -> fallback. The "" key on either level
	// acts as the default. A nil map means fallback is disabled.
	fallbacks             map[string]map[string]*Fallback // or nil
	// Regexp-based path matching is currently disabled.
	// regexps               map[string]*regexp.Regexp       // or nil
}
    73  
    74  // New creates a new VLess inbound handler.
    75  func New(ctx context.Context, config *Config, dc dns.Client) (*Handler, error) {
    76  	v := core.MustFromContext(ctx)
    77  	handler := &Handler{
    78  		inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
    79  		policyManager:         v.GetFeature(policy.ManagerType()).(policy.Manager),
    80  		validator:             new(vless.Validator),
    81  		dns:                   dc,
    82  	}
    83  
    84  	for _, user := range config.Clients {
    85  		u, err := user.ToMemoryUser()
    86  		if err != nil {
    87  			return nil, newError("failed to get VLESS user").Base(err).AtError()
    88  		}
    89  		if err := handler.AddUser(ctx, u); err != nil {
    90  			return nil, newError("failed to initiate user").Base(err).AtError()
    91  		}
    92  	}
    93  
    94  	if config.Fallbacks != nil {
    95  		handler.fallbacks = make(map[string]map[string]*Fallback)
    96  		// handler.regexps = make(map[string]*regexp.Regexp)
    97  		for _, fb := range config.Fallbacks {
    98  			if handler.fallbacks[fb.Alpn] == nil {
    99  				handler.fallbacks[fb.Alpn] = make(map[string]*Fallback)
   100  			}
   101  			handler.fallbacks[fb.Alpn][fb.Path] = fb
   102  			/*
   103  				if fb.Path != "" {
   104  					if r, err := regexp.Compile(fb.Path); err != nil {
   105  						return nil, newError("invalid path regexp").Base(err).AtError()
   106  					} else {
   107  						handler.regexps[fb.Path] = r
   108  					}
   109  				}
   110  			*/
   111  		}
   112  		if handler.fallbacks[""] != nil {
   113  			for alpn, pfb := range handler.fallbacks {
   114  				if alpn != "" { // && alpn != "h2" {
   115  					for path, fb := range handler.fallbacks[""] {
   116  						if pfb[path] == nil {
   117  							pfb[path] = fb
   118  						}
   119  					}
   120  				}
   121  			}
   122  		}
   123  	}
   124  
   125  	return handler, nil
   126  }
   127  
   128  // Close implements common.Closable.Close().
   129  func (h *Handler) Close() error {
   130  	return errors.Combine(common.Close(h.validator))
   131  }
   132  
   133  // AddUser implements proxy.UserManager.AddUser().
   134  func (h *Handler) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   135  	return h.validator.Add(u)
   136  }
   137  
   138  // RemoveUser implements proxy.UserManager.RemoveUser().
   139  func (h *Handler) RemoveUser(ctx context.Context, e string) error {
   140  	return h.validator.Del(e)
   141  }
   142  
   143  // Network implements proxy.Inbound.Network().
   144  func (*Handler) Network() []net.Network {
   145  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   146  }
   147  
   148  // Process implements proxy.Inbound.Process().
   149  func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error {
   150  	sid := session.ExportIDToError(ctx)
   151  
   152  	iConn := connection
   153  	statConn, ok := iConn.(*internet.StatCouterConnection)
   154  	if ok {
   155  		iConn = statConn.Connection
   156  	}
   157  
   158  	sessionPolicy := h.policyManager.ForLevel(0)
   159  	if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   160  		return newError("unable to set read deadline").Base(err).AtWarning()
   161  	}
   162  
   163  	first := buf.New()
   164  	defer first.Release()
   165  
   166  	firstLen, _ := first.ReadFrom(connection)
   167  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   168  
   169  	reader := &buf.BufferedReader{
   170  		Reader: buf.NewReader(connection),
   171  		Buffer: buf.MultiBuffer{first},
   172  	}
   173  
   174  	var request *protocol.RequestHeader
   175  	var requestAddons *encoding.Addons
   176  	var err error
   177  
   178  	apfb := h.fallbacks
   179  	isfb := apfb != nil
   180  
   181  	if isfb && firstLen < 18 {
   182  		err = newError("fallback directly")
   183  	} else {
   184  		request, requestAddons, isfb, err = encoding.DecodeRequestHeader(isfb, first, reader, h.validator)
   185  	}
   186  
   187  	if err != nil {
   188  		if isfb {
   189  			if err := connection.SetReadDeadline(time.Time{}); err != nil {
   190  				newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   191  			}
   192  			newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   193  
   194  			alpn := ""
   195  			if len(apfb) > 1 || apfb[""] == nil {
   196  				if tlsConn, ok := iConn.(*tls.Conn); ok {
   197  					alpn = tlsConn.ConnectionState().NegotiatedProtocol
   198  					newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   199  				}
   200  				if apfb[alpn] == nil {
   201  					alpn = ""
   202  				}
   203  			}
   204  			pfb := apfb[alpn]
   205  			if pfb == nil {
   206  				return newError(`failed to find the default "alpn" config`).AtWarning()
   207  			}
   208  
   209  			path := ""
   210  			if len(pfb) > 1 || pfb[""] == nil {
   211  				/*
   212  					if lines := bytes.Split(firstBytes, []byte{'\r', '\n'}); len(lines) > 1 {
   213  						if s := bytes.Split(lines[0], []byte{' '}); len(s) == 3 {
   214  							if len(s[0]) < 8 && len(s[1]) > 0 && len(s[2]) == 8 {
   215  								newError("realPath = " + string(s[1])).AtInfo().WriteToLog(sid)
   216  								for _, fb := range pfb {
   217  									if fb.Path != "" && h.regexps[fb.Path].Match(s[1]) {
   218  										path = fb.Path
   219  										break
   220  									}
   221  								}
   222  							}
   223  						}
   224  					}
   225  				*/
   226  				if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   227  					firstBytes := first.Bytes()
   228  					for i := 4; i <= 8; i++ { // 5 -> 9
   229  						if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   230  							search := len(firstBytes)
   231  							if search > 64 {
   232  								search = 64 // up to about 60
   233  							}
   234  							for j := i + 1; j < search; j++ {
   235  								k := firstBytes[j]
   236  								if k == '\r' || k == '\n' { // avoid logging \r or \n
   237  									break
   238  								}
   239  								if k == ' ' {
   240  									path = string(firstBytes[i:j])
   241  									newError("realPath = " + path).AtInfo().WriteToLog(sid)
   242  									if pfb[path] == nil {
   243  										path = ""
   244  									}
   245  									break
   246  								}
   247  							}
   248  							break
   249  						}
   250  					}
   251  				}
   252  			}
   253  			fb := pfb[path]
   254  			if fb == nil {
   255  				return newError(`failed to find the default "path" config`).AtWarning()
   256  			}
   257  
   258  			ctx, cancel := context.WithCancel(ctx)
   259  			timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   260  			ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   261  
   262  			var conn net.Conn
   263  			if err := retry.ExponentialBackoff(5, 100).On(func() error {
   264  				var dialer net.Dialer
   265  				conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   266  				if err != nil {
   267  					return err
   268  				}
   269  				return nil
   270  			}); err != nil {
   271  				return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   272  			}
   273  			defer conn.Close()
   274  
   275  			serverReader := buf.NewReader(conn)
   276  			serverWriter := buf.NewWriter(conn)
   277  
   278  			postRequest := func() error {
   279  				defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   280  				if fb.Xver != 0 {
   281  					remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   282  					if err != nil {
   283  						return err
   284  					}
   285  					localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   286  					if err != nil {
   287  						return err
   288  					}
   289  					ipv4 := true
   290  					for i := 0; i < len(remoteAddr); i++ {
   291  						if remoteAddr[i] == ':' {
   292  							ipv4 = false
   293  							break
   294  						}
   295  					}
   296  					pro := buf.New()
   297  					defer pro.Release()
   298  					switch fb.Xver {
   299  					case 1:
   300  						if ipv4 {
   301  							pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
   302  						} else {
   303  							pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n"))
   304  						}
   305  
   306  					case 2:
   307  						pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21")) // signature + v2 + PROXY
   308  						if ipv4 {
   309  							pro.Write([]byte("\x11\x00\x0C")) // AF_INET + STREAM + 12 bytes
   310  							pro.Write(net.ParseIP(remoteAddr).To4())
   311  							pro.Write(net.ParseIP(localAddr).To4())
   312  						} else {
   313  							pro.Write([]byte("\x21\x00\x24")) // AF_INET6 + STREAM + 36 bytes
   314  							pro.Write(net.ParseIP(remoteAddr).To16())
   315  							pro.Write(net.ParseIP(localAddr).To16())
   316  						}
   317  						p1, _ := strconv.ParseUint(remotePort, 10, 16)
   318  						p2, _ := strconv.ParseUint(localPort, 10, 16)
   319  						pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)})
   320  					}
   321  					if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   322  						return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   323  					}
   324  				}
   325  				if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   326  					return newError("failed to fallback request payload").Base(err).AtInfo()
   327  				}
   328  				return nil
   329  			}
   330  
   331  			writer := buf.NewWriter(connection)
   332  
   333  			getResponse := func() error {
   334  				defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   335  				if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   336  					return newError("failed to deliver response payload").Base(err).AtInfo()
   337  				}
   338  				return nil
   339  			}
   340  
   341  			if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   342  				common.Interrupt(serverReader)
   343  				common.Interrupt(serverWriter)
   344  				return newError("fallback ends").Base(err).AtInfo()
   345  			}
   346  			return nil
   347  		}
   348  
   349  		if errors.Cause(err) != io.EOF {
   350  			log.Record(&log.AccessMessage{
   351  				From:   connection.RemoteAddr(),
   352  				To:     "",
   353  				Status: log.AccessRejected,
   354  				Reason: err,
   355  			})
   356  			err = newError("invalid request from ", connection.RemoteAddr()).Base(err).AtInfo()
   357  		}
   358  		return err
   359  	}
   360  
   361  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   362  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   363  	}
   364  	newError("received request for ", request.Destination()).AtInfo().WriteToLog(sid)
   365  
   366  	inbound := session.InboundFromContext(ctx)
   367  	if inbound == nil {
   368  		panic("no inbound metadata")
   369  	}
   370  	inbound.User = request.User
   371  
   372  	responseAddons := &encoding.Addons{}
   373  
   374  	if request.Command != protocol.RequestCommandMux {
   375  		ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   376  			From:   connection.RemoteAddr(),
   377  			To:     request.Destination(),
   378  			Status: log.AccessAccepted,
   379  			Reason: "",
   380  			Email:  request.User.Email,
   381  		})
   382  	}
   383  
   384  	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
   385  	ctx, cancel := context.WithCancel(ctx)
   386  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   387  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   388  
   389  	link, err := dispatcher.Dispatch(ctx, request.Destination())
   390  	if err != nil {
   391  		return newError("failed to dispatch request to ", request.Destination()).Base(err).AtWarning()
   392  	}
   393  
   394  	serverReader := link.Reader // .(*pipe.Reader)
   395  	serverWriter := link.Writer // .(*pipe.Writer)
   396  
   397  	postRequest := func() error {
   398  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   399  
   400  		// default: clientReader := reader
   401  		clientReader := encoding.DecodeBodyAddons(reader, request, requestAddons)
   402  
   403  		// from clientReader.ReadMultiBuffer to serverWriter.WriteMultiBuffer
   404  		if err := buf.Copy(clientReader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   405  			return newError("failed to transfer request payload").Base(err).AtInfo()
   406  		}
   407  
   408  		return nil
   409  	}
   410  
   411  	getResponse := func() error {
   412  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   413  
   414  		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(connection))
   415  		if err := encoding.EncodeResponseHeader(bufferWriter, request, responseAddons); err != nil {
   416  			return newError("failed to encode response header").Base(err).AtWarning()
   417  		}
   418  
   419  		// default: clientWriter := bufferWriter
   420  		clientWriter := encoding.EncodeBodyAddons(bufferWriter, request, responseAddons)
   421  		{
   422  			multiBuffer, err := serverReader.ReadMultiBuffer()
   423  			if err != nil {
   424  				return err // ...
   425  			}
   426  			if err := clientWriter.WriteMultiBuffer(multiBuffer); err != nil {
   427  				return err // ...
   428  			}
   429  		}
   430  
   431  		// Flush; bufferWriter.WriteMultiBuffer now is bufferWriter.writer.WriteMultiBuffer
   432  		if err := bufferWriter.SetBuffered(false); err != nil {
   433  			return newError("failed to write A response payload").Base(err).AtWarning()
   434  		}
   435  
   436  		// from serverReader.ReadMultiBuffer to clientWriter.WriteMultiBuffer
   437  		if err := buf.Copy(serverReader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   438  			return newError("failed to transfer response payload").Base(err).AtInfo()
   439  		}
   440  
   441  		return nil
   442  	}
   443  
   444  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), getResponse); err != nil {
   445  		common.Interrupt(serverReader)
   446  		common.Interrupt(serverWriter)
   447  		return newError("connection ends").Base(err).AtInfo()
   448  	}
   449  
   450  	return nil
   451  }