github.com/eagleql/xray-core@v1.4.4/proxy/vmess/inbound/inbound.go (about)

     1  package inbound
     2  
     3  //go:generate go run github.com/eagleql/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/eagleql/xray-core/common"
    13  	"github.com/eagleql/xray-core/common/buf"
    14  	"github.com/eagleql/xray-core/common/errors"
    15  	"github.com/eagleql/xray-core/common/log"
    16  	"github.com/eagleql/xray-core/common/net"
    17  	"github.com/eagleql/xray-core/common/platform"
    18  	"github.com/eagleql/xray-core/common/protocol"
    19  	"github.com/eagleql/xray-core/common/session"
    20  	"github.com/eagleql/xray-core/common/signal"
    21  	"github.com/eagleql/xray-core/common/task"
    22  	"github.com/eagleql/xray-core/common/uuid"
    23  	"github.com/eagleql/xray-core/core"
    24  	feature_inbound "github.com/eagleql/xray-core/features/inbound"
    25  	"github.com/eagleql/xray-core/features/policy"
    26  	"github.com/eagleql/xray-core/features/routing"
    27  	"github.com/eagleql/xray-core/proxy/vmess"
    28  	"github.com/eagleql/xray-core/proxy/vmess/encoding"
    29  	"github.com/eagleql/xray-core/transport/internet"
    30  )
    31  
    32  var (
    33  	aeadForced = false
    34  )
    35  
    36  type userByEmail struct {
    37  	sync.Mutex
    38  	cache           map[string]*protocol.MemoryUser
    39  	defaultLevel    uint32
    40  	defaultAlterIDs uint16
    41  }
    42  
    43  func newUserByEmail(config *DefaultConfig) *userByEmail {
    44  	return &userByEmail{
    45  		cache:           make(map[string]*protocol.MemoryUser),
    46  		defaultLevel:    config.Level,
    47  		defaultAlterIDs: uint16(config.AlterId),
    48  	}
    49  }
    50  
    51  func (v *userByEmail) addNoLock(u *protocol.MemoryUser) bool {
    52  	email := strings.ToLower(u.Email)
    53  	_, found := v.cache[email]
    54  	if found {
    55  		return false
    56  	}
    57  	v.cache[email] = u
    58  	return true
    59  }
    60  
    61  func (v *userByEmail) Add(u *protocol.MemoryUser) bool {
    62  	v.Lock()
    63  	defer v.Unlock()
    64  
    65  	return v.addNoLock(u)
    66  }
    67  
    68  func (v *userByEmail) Get(email string) (*protocol.MemoryUser, bool) {
    69  	email = strings.ToLower(email)
    70  
    71  	v.Lock()
    72  	defer v.Unlock()
    73  
    74  	user, found := v.cache[email]
    75  	if !found {
    76  		id := uuid.New()
    77  		rawAccount := &vmess.Account{
    78  			Id:      id.String(),
    79  			AlterId: uint32(v.defaultAlterIDs),
    80  		}
    81  		account, err := rawAccount.AsAccount()
    82  		common.Must(err)
    83  		user = &protocol.MemoryUser{
    84  			Level:   v.defaultLevel,
    85  			Email:   email,
    86  			Account: account,
    87  		}
    88  		v.cache[email] = user
    89  	}
    90  	return user, found
    91  }
    92  
    93  func (v *userByEmail) Remove(email string) bool {
    94  	email = strings.ToLower(email)
    95  
    96  	v.Lock()
    97  	defer v.Unlock()
    98  
    99  	if _, found := v.cache[email]; !found {
   100  		return false
   101  	}
   102  	delete(v.cache, email)
   103  	return true
   104  }
   105  
   106  // Handler is an inbound connection handler that handles messages in VMess protocol.
   107  type Handler struct {
   108  	policyManager         policy.Manager
   109  	inboundHandlerManager feature_inbound.Manager
   110  	clients               *vmess.TimedUserValidator
   111  	usersByEmail          *userByEmail
   112  	detours               *DetourConfig
   113  	sessionHistory        *encoding.SessionHistory
   114  	secure                bool
   115  }
   116  
   117  // New creates a new VMess inbound handler.
   118  func New(ctx context.Context, config *Config) (*Handler, error) {
   119  	v := core.MustFromContext(ctx)
   120  	handler := &Handler{
   121  		policyManager:         v.GetFeature(policy.ManagerType()).(policy.Manager),
   122  		inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager),
   123  		clients:               vmess.NewTimedUserValidator(protocol.DefaultIDHash),
   124  		detours:               config.Detour,
   125  		usersByEmail:          newUserByEmail(config.GetDefaultValue()),
   126  		sessionHistory:        encoding.NewSessionHistory(),
   127  		secure:                config.SecureEncryptionOnly,
   128  	}
   129  
   130  	for _, user := range config.User {
   131  		mUser, err := user.ToMemoryUser()
   132  		if err != nil {
   133  			return nil, newError("failed to get VMess user").Base(err)
   134  		}
   135  
   136  		if err := handler.AddUser(ctx, mUser); err != nil {
   137  			return nil, newError("failed to initiate user").Base(err)
   138  		}
   139  	}
   140  
   141  	return handler, nil
   142  }
   143  
   144  // Close implements common.Closable.
   145  func (h *Handler) Close() error {
   146  	return errors.Combine(
   147  		h.clients.Close(),
   148  		h.sessionHistory.Close(),
   149  		common.Close(h.usersByEmail))
   150  }
   151  
   152  // Network implements proxy.Inbound.Network().
   153  func (*Handler) Network() []net.Network {
   154  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   155  }
   156  
   157  func (h *Handler) GetUser(email string) *protocol.MemoryUser {
   158  	user, existing := h.usersByEmail.Get(email)
   159  	if !existing {
   160  		h.clients.Add(user)
   161  	}
   162  	return user
   163  }
   164  
   165  func (h *Handler) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
   166  	if len(user.Email) > 0 && !h.usersByEmail.Add(user) {
   167  		return newError("User ", user.Email, " already exists.")
   168  	}
   169  	return h.clients.Add(user)
   170  }
   171  
   172  func (h *Handler) RemoveUser(ctx context.Context, email string) error {
   173  	if email == "" {
   174  		return newError("Email must not be empty.")
   175  	}
   176  	if !h.usersByEmail.Remove(email) {
   177  		return newError("User ", email, " not found.")
   178  	}
   179  	h.clients.Remove(email)
   180  	return nil
   181  }
   182  
   183  func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output *buf.BufferedWriter) error {
   184  	session.EncodeResponseHeader(response, output)
   185  
   186  	bodyWriter := session.EncodeResponseBody(request, output)
   187  
   188  	{
   189  		// Optimize for small response packet
   190  		data, err := input.ReadMultiBuffer()
   191  		if err != nil {
   192  			return err
   193  		}
   194  
   195  		if err := bodyWriter.WriteMultiBuffer(data); err != nil {
   196  			return err
   197  		}
   198  	}
   199  
   200  	if err := output.SetBuffered(false); err != nil {
   201  		return err
   202  	}
   203  
   204  	if err := buf.Copy(input, bodyWriter, buf.UpdateActivity(timer)); err != nil {
   205  		return err
   206  	}
   207  
   208  	if request.Option.Has(protocol.RequestOptionChunkStream) {
   209  		if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
   210  			return err
   211  		}
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  func isInsecureEncryption(s protocol.SecurityType) bool {
   218  	return s == protocol.SecurityType_NONE || s == protocol.SecurityType_LEGACY || s == protocol.SecurityType_UNKNOWN
   219  }
   220  
   221  // Process implements proxy.Inbound.Process().
   222  func (h *Handler) Process(ctx context.Context, network net.Network, connection internet.Connection, dispatcher routing.Dispatcher) error {
   223  	sessionPolicy := h.policyManager.ForLevel(0)
   224  	if err := connection.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   225  		return newError("unable to set read deadline").Base(err).AtWarning()
   226  	}
   227  
   228  	iConn := connection
   229  	if statConn, ok := iConn.(*internet.StatCouterConnection); ok {
   230  		iConn = statConn.Connection
   231  	}
   232  	_, isDrain := iConn.(*net.TCPConn)
   233  	if !isDrain {
   234  		_, isDrain = iConn.(*net.UnixConn)
   235  	}
   236  
   237  	reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
   238  	svrSession := encoding.NewServerSession(h.clients, h.sessionHistory)
   239  	svrSession.SetAEADForced(aeadForced)
   240  	request, err := svrSession.DecodeRequestHeader(reader, isDrain)
   241  	if err != nil {
   242  		if errors.Cause(err) != io.EOF {
   243  			log.Record(&log.AccessMessage{
   244  				From:   connection.RemoteAddr(),
   245  				To:     "",
   246  				Status: log.AccessRejected,
   247  				Reason: err,
   248  			})
   249  			err = newError("invalid request from ", connection.RemoteAddr()).Base(err).AtInfo()
   250  		}
   251  		return err
   252  	}
   253  
   254  	if h.secure && isInsecureEncryption(request.Security) {
   255  		log.Record(&log.AccessMessage{
   256  			From:   connection.RemoteAddr(),
   257  			To:     "",
   258  			Status: log.AccessRejected,
   259  			Reason: "Insecure encryption",
   260  			Email:  request.User.Email,
   261  		})
   262  		return newError("client is using insecure encryption: ", request.Security)
   263  	}
   264  
   265  	if request.Command != protocol.RequestCommandMux {
   266  		ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   267  			From:   connection.RemoteAddr(),
   268  			To:     request.Destination(),
   269  			Status: log.AccessAccepted,
   270  			Reason: "",
   271  			Email:  request.User.Email,
   272  		})
   273  	}
   274  
   275  	newError("received request for ", request.Destination()).WriteToLog(session.ExportIDToError(ctx))
   276  
   277  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   278  		newError("unable to set back read deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   279  	}
   280  
   281  	inbound := session.InboundFromContext(ctx)
   282  	if inbound == nil {
   283  		panic("no inbound metadata")
   284  	}
   285  	inbound.User = request.User
   286  
   287  	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
   288  
   289  	ctx, cancel := context.WithCancel(ctx)
   290  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   291  
   292  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   293  	link, err := dispatcher.Dispatch(ctx, request.Destination())
   294  	if err != nil {
   295  		return newError("failed to dispatch request to ", request.Destination()).Base(err)
   296  	}
   297  
   298  	requestDone := func() error {
   299  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   300  
   301  		bodyReader := svrSession.DecodeRequestBody(request, reader)
   302  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   303  			return newError("failed to transfer request").Base(err)
   304  		}
   305  		return nil
   306  	}
   307  
   308  	responseDone := func() error {
   309  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   310  
   311  		writer := buf.NewBufferedWriter(buf.NewWriter(connection))
   312  		defer writer.Flush()
   313  
   314  		response := &protocol.ResponseHeader{
   315  			Command: h.generateCommand(ctx, request),
   316  		}
   317  		return transferResponse(timer, svrSession, request, response, link.Reader, writer)
   318  	}
   319  
   320  	var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
   321  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   322  		common.Interrupt(link.Reader)
   323  		common.Interrupt(link.Writer)
   324  		return newError("connection ends").Base(err)
   325  	}
   326  
   327  	return nil
   328  }
   329  
   330  func (h *Handler) generateCommand(ctx context.Context, request *protocol.RequestHeader) protocol.ResponseCommand {
   331  	if h.detours != nil {
   332  		tag := h.detours.To
   333  		if h.inboundHandlerManager != nil {
   334  			handler, err := h.inboundHandlerManager.GetHandler(ctx, tag)
   335  			if err != nil {
   336  				newError("failed to get detour handler: ", tag).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   337  				return nil
   338  			}
   339  			proxyHandler, port, availableMin := handler.GetRandomInboundProxy()
   340  			inboundHandler, ok := proxyHandler.(*Handler)
   341  			if ok && inboundHandler != nil {
   342  				if availableMin > 255 {
   343  					availableMin = 255
   344  				}
   345  
   346  				newError("pick detour handler for port ", port, " for ", availableMin, " minutes.").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   347  				user := inboundHandler.GetUser(request.User.Email)
   348  				if user == nil {
   349  					return nil
   350  				}
   351  				account := user.Account.(*vmess.MemoryAccount)
   352  				return &protocol.CommandSwitchAccount{
   353  					Port:     port,
   354  					ID:       account.ID.UUID(),
   355  					AlterIds: uint16(len(account.AlterIDs)),
   356  					Level:    user.Level,
   357  					ValidMin: byte(availableMin),
   358  				}
   359  			}
   360  		}
   361  	}
   362  
   363  	return nil
   364  }
   365  
   366  func init() {
   367  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   368  		return New(ctx, config.(*Config))
   369  	}))
   370  
   371  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
   372  
   373  	isAeadForced := platform.NewEnvFlag("xray.vmess.aead.forced").GetValue(func() string { return defaultFlagValue })
   374  	aeadForced = (isAeadForced == "true")
   375  }