github.com/eagleql/xray-core@v1.4.4/proxy/trojan/server.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"io"
     7  	"strconv"
     8  	"strings"
     9  	"syscall"
    10  	"time"
    11  
    12  	"github.com/eagleql/xray-core/common"
    13  	"github.com/eagleql/xray-core/common/buf"
    14  	"github.com/eagleql/xray-core/common/errors"
    15  	"github.com/eagleql/xray-core/common/log"
    16  	"github.com/eagleql/xray-core/common/net"
    17  	"github.com/eagleql/xray-core/common/platform"
    18  	"github.com/eagleql/xray-core/common/protocol"
    19  	udp_proto "github.com/eagleql/xray-core/common/protocol/udp"
    20  	"github.com/eagleql/xray-core/common/retry"
    21  	"github.com/eagleql/xray-core/common/session"
    22  	"github.com/eagleql/xray-core/common/signal"
    23  	"github.com/eagleql/xray-core/common/task"
    24  	core "github.com/eagleql/xray-core/core"
    25  	"github.com/eagleql/xray-core/features/policy"
    26  	"github.com/eagleql/xray-core/features/routing"
    27  	"github.com/eagleql/xray-core/features/stats"
    28  	"github.com/eagleql/xray-core/transport/internet"
    29  	"github.com/eagleql/xray-core/transport/internet/udp"
    30  	"github.com/eagleql/xray-core/transport/internet/xtls"
    31  )
    32  
    33  func init() {
    34  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    35  		return NewServer(ctx, config.(*ServerConfig))
    36  	}))
    37  
    38  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
    39  
    40  	xtlsShow := platform.NewEnvFlag("xray.trojan.xtls.show").GetValue(func() string { return defaultFlagValue })
    41  	if xtlsShow == "true" {
    42  		xtls_show = true
    43  	}
    44  }
    45  
    46  // Server is an inbound connection handler that handles messages in trojan protocol.
    47  type Server struct {
    48  	policyManager policy.Manager
    49  	validator     *Validator
    50  	fallbacks     map[string]map[string]map[string]*Fallback // or nil
    51  	cone          bool
    52  }
    53  
    54  // NewServer creates a new trojan inbound handler.
    55  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    56  	validator := new(Validator)
    57  	for _, user := range config.Users {
    58  		u, err := user.ToMemoryUser()
    59  		if err != nil {
    60  			return nil, newError("failed to get trojan user").Base(err).AtError()
    61  		}
    62  
    63  		if err := validator.Add(u); err != nil {
    64  			return nil, newError("failed to add user").Base(err).AtError()
    65  		}
    66  	}
    67  
    68  	v := core.MustFromContext(ctx)
    69  	server := &Server{
    70  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    71  		validator:     validator,
    72  		cone:          ctx.Value("cone").(bool),
    73  	}
    74  
    75  	if config.Fallbacks != nil {
    76  		server.fallbacks = make(map[string]map[string]map[string]*Fallback)
    77  		for _, fb := range config.Fallbacks {
    78  			if server.fallbacks[fb.Name] == nil {
    79  				server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback)
    80  			}
    81  			if server.fallbacks[fb.Name][fb.Alpn] == nil {
    82  				server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback)
    83  			}
    84  			server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb
    85  		}
    86  		if server.fallbacks[""] != nil {
    87  			for name, apfb := range server.fallbacks {
    88  				if name != "" {
    89  					for alpn := range server.fallbacks[""] {
    90  						if apfb[alpn] == nil {
    91  							apfb[alpn] = make(map[string]*Fallback)
    92  						}
    93  					}
    94  				}
    95  			}
    96  		}
    97  		for _, apfb := range server.fallbacks {
    98  			if apfb[""] != nil {
    99  				for alpn, pfb := range apfb {
   100  					if alpn != "" { // && alpn != "h2" {
   101  						for path, fb := range apfb[""] {
   102  							if pfb[path] == nil {
   103  								pfb[path] = fb
   104  							}
   105  						}
   106  					}
   107  				}
   108  			}
   109  		}
   110  		if server.fallbacks[""] != nil {
   111  			for name, apfb := range server.fallbacks {
   112  				if name != "" {
   113  					for alpn, pfb := range server.fallbacks[""] {
   114  						for path, fb := range pfb {
   115  							if apfb[alpn][path] == nil {
   116  								apfb[alpn][path] = fb
   117  							}
   118  						}
   119  					}
   120  				}
   121  			}
   122  		}
   123  	}
   124  
   125  	return server, nil
   126  }
   127  
   128  // AddUser implements proxy.UserManager.AddUser().
   129  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   130  	return s.validator.Add(u)
   131  }
   132  
   133  // RemoveUser implements proxy.UserManager.RemoveUser().
   134  func (s *Server) RemoveUser(ctx context.Context, e string) error {
   135  	return s.validator.Del(e)
   136  }
   137  
   138  // Network implements proxy.Inbound.Network().
   139  func (s *Server) Network() []net.Network {
   140  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   141  }
   142  
   143  // Process implements proxy.Inbound.Process().
   144  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
   145  	sid := session.ExportIDToError(ctx)
   146  
   147  	iConn := conn
   148  	statConn, ok := iConn.(*internet.StatCouterConnection)
   149  	if ok {
   150  		iConn = statConn.Connection
   151  	}
   152  
   153  	sessionPolicy := s.policyManager.ForLevel(0)
   154  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   155  		return newError("unable to set read deadline").Base(err).AtWarning()
   156  	}
   157  
   158  	first := buf.New()
   159  	defer first.Release()
   160  
   161  	firstLen, err := first.ReadFrom(conn)
   162  	if err != nil {
   163  		return newError("failed to read first request").Base(err)
   164  	}
   165  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   166  
   167  	bufferedReader := &buf.BufferedReader{
   168  		Reader: buf.NewReader(conn),
   169  		Buffer: buf.MultiBuffer{first},
   170  	}
   171  
   172  	var user *protocol.MemoryUser
   173  
   174  	napfb := s.fallbacks
   175  	isfb := napfb != nil
   176  
   177  	shouldFallback := false
   178  	if firstLen < 58 || first.Byte(56) != '\r' {
   179  		// invalid protocol
   180  		err = newError("not trojan protocol")
   181  		log.Record(&log.AccessMessage{
   182  			From:   conn.RemoteAddr(),
   183  			To:     "",
   184  			Status: log.AccessRejected,
   185  			Reason: err,
   186  		})
   187  
   188  		shouldFallback = true
   189  	} else {
   190  		user = s.validator.Get(hexString(first.BytesTo(56)))
   191  		if user == nil {
   192  			// invalid user, let's fallback
   193  			err = newError("not a valid user")
   194  			log.Record(&log.AccessMessage{
   195  				From:   conn.RemoteAddr(),
   196  				To:     "",
   197  				Status: log.AccessRejected,
   198  				Reason: err,
   199  			})
   200  
   201  			shouldFallback = true
   202  		}
   203  	}
   204  
   205  	if isfb && shouldFallback {
   206  		return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader)
   207  	} else if shouldFallback {
   208  		return newError("invalid protocol or invalid user")
   209  	}
   210  
   211  	clientReader := &ConnReader{Reader: bufferedReader}
   212  	if err := clientReader.ParseHeader(); err != nil {
   213  		log.Record(&log.AccessMessage{
   214  			From:   conn.RemoteAddr(),
   215  			To:     "",
   216  			Status: log.AccessRejected,
   217  			Reason: err,
   218  		})
   219  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   220  	}
   221  
   222  	destination := clientReader.Target
   223  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   224  		return newError("unable to set read deadline").Base(err).AtWarning()
   225  	}
   226  
   227  	inbound := session.InboundFromContext(ctx)
   228  	if inbound == nil {
   229  		panic("no inbound metadata")
   230  	}
   231  	inbound.User = user
   232  	sessionPolicy = s.policyManager.ForLevel(user.Level)
   233  
   234  	if destination.Network == net.Network_UDP { // handle udp request
   235  		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
   236  	}
   237  
   238  	// handle tcp request
   239  	account, ok := user.Account.(*MemoryAccount)
   240  	if !ok {
   241  		return newError("user account is not valid")
   242  	}
   243  
   244  	var rawConn syscall.RawConn
   245  
   246  	switch clientReader.Flow {
   247  	case XRO, XRD:
   248  		if account.Flow == clientReader.Flow {
   249  			if destination.Address.Family().IsDomain() && destination.Address.Domain() == muxCoolAddress {
   250  				return newError(clientReader.Flow + " doesn't support Mux").AtWarning()
   251  			}
   252  			if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   253  				xtlsConn.RPRX = true
   254  				xtlsConn.SHOW = xtls_show
   255  				xtlsConn.MARK = "XTLS"
   256  				if clientReader.Flow == XRD {
   257  					xtlsConn.DirectMode = true
   258  					if sc, ok := xtlsConn.Connection.(syscall.Conn); ok {
   259  						rawConn, _ = sc.SyscallConn()
   260  					}
   261  				}
   262  			} else {
   263  				return newError(`failed to use ` + clientReader.Flow + `, maybe "security" is not "xtls"`).AtWarning()
   264  			}
   265  		} else {
   266  			return newError(account.Password + " is not able to use " + clientReader.Flow).AtWarning()
   267  		}
   268  	case "":
   269  	}
   270  
   271  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   272  		From:   conn.RemoteAddr(),
   273  		To:     destination,
   274  		Status: log.AccessAccepted,
   275  		Reason: "",
   276  		Email:  user.Email,
   277  	})
   278  
   279  	newError("received request for ", destination).WriteToLog(sid)
   280  	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher, iConn, rawConn, statConn)
   281  }
   282  
   283  func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
   284  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   285  		udpPayload := packet.Payload
   286  		if udpPayload.UDP == nil {
   287  			udpPayload.UDP = &packet.Source
   288  		}
   289  
   290  		if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
   291  			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   292  		}
   293  	})
   294  
   295  	inbound := session.InboundFromContext(ctx)
   296  	user := inbound.User
   297  
   298  	var dest *net.Destination
   299  
   300  	for {
   301  		select {
   302  		case <-ctx.Done():
   303  			return nil
   304  		default:
   305  			mb, err := clientReader.ReadMultiBuffer()
   306  			if err != nil {
   307  				if errors.Cause(err) != io.EOF {
   308  					return newError("unexpected EOF").Base(err)
   309  				}
   310  				return nil
   311  			}
   312  
   313  			mb2, b := buf.SplitFirst(mb)
   314  			if b == nil {
   315  				continue
   316  			}
   317  			destination := *b.UDP
   318  
   319  			currentPacketCtx := ctx
   320  			if inbound.Source.IsValid() {
   321  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   322  					From:   inbound.Source,
   323  					To:     destination,
   324  					Status: log.AccessAccepted,
   325  					Reason: "",
   326  					Email:  user.Email,
   327  				})
   328  			}
   329  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
   330  
   331  			if !s.cone || dest == nil {
   332  				dest = &destination
   333  			}
   334  
   335  			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
   336  			for _, payload := range mb2 {
   337  				udpServer.Dispatch(currentPacketCtx, *dest, payload)
   338  			}
   339  		}
   340  	}
   341  }
   342  
   343  func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
   344  	destination net.Destination,
   345  	clientReader buf.Reader,
   346  	clientWriter buf.Writer, dispatcher routing.Dispatcher, iConn internet.Connection, rawConn syscall.RawConn, statConn *internet.StatCouterConnection) error {
   347  	ctx, cancel := context.WithCancel(ctx)
   348  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   349  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   350  
   351  	link, err := dispatcher.Dispatch(ctx, destination)
   352  	if err != nil {
   353  		return newError("failed to dispatch request to ", destination).Base(err)
   354  	}
   355  
   356  	requestDone := func() error {
   357  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   358  
   359  		var err error
   360  		if rawConn != nil {
   361  			var counter stats.Counter
   362  			if statConn != nil {
   363  				counter = statConn.ReadCounter
   364  			}
   365  			err = ReadV(clientReader, link.Writer, timer, iConn.(*xtls.Conn), rawConn, counter, nil)
   366  		} else {
   367  			err = buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer))
   368  		}
   369  		if err != nil {
   370  			return newError("failed to transfer request").Base(err)
   371  		}
   372  		return nil
   373  	}
   374  
   375  	responseDone := func() error {
   376  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   377  
   378  		if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   379  			return newError("failed to write response").Base(err)
   380  		}
   381  		return nil
   382  	}
   383  
   384  	var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
   385  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   386  		common.Must(common.Interrupt(link.Reader))
   387  		common.Must(common.Interrupt(link.Writer))
   388  		return newError("connection ends").Base(err)
   389  	}
   390  
   391  	return nil
   392  }
   393  
   394  func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection internet.Connection, iConn internet.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
   395  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   396  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   397  	}
   398  	newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   399  
   400  	name := ""
   401  	alpn := ""
   402  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   403  		cs := tlsConn.ConnectionState()
   404  		name = cs.ServerName
   405  		alpn = cs.NegotiatedProtocol
   406  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   407  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   408  	} else if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   409  		cs := xtlsConn.ConnectionState()
   410  		name = cs.ServerName
   411  		alpn = cs.NegotiatedProtocol
   412  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   413  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   414  	}
   415  	name = strings.ToLower(name)
   416  	alpn = strings.ToLower(alpn)
   417  
   418  	if len(napfb) > 1 || napfb[""] == nil {
   419  		if name != "" && napfb[name] == nil {
   420  			match := ""
   421  			for n := range napfb {
   422  				if n != "" && strings.Contains(name, n) && len(n) > len(match) {
   423  					match = n
   424  				}
   425  			}
   426  			name = match
   427  		}
   428  	}
   429  
   430  	if napfb[name] == nil {
   431  		name = ""
   432  	}
   433  	apfb := napfb[name]
   434  	if apfb == nil {
   435  		return newError(`failed to find the default "name" config`).AtWarning()
   436  	}
   437  
   438  	if apfb[alpn] == nil {
   439  		alpn = ""
   440  	}
   441  	pfb := apfb[alpn]
   442  	if pfb == nil {
   443  		return newError(`failed to find the default "alpn" config`).AtWarning()
   444  	}
   445  
   446  	path := ""
   447  	if len(pfb) > 1 || pfb[""] == nil {
   448  		if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   449  			firstBytes := first.Bytes()
   450  			for i := 4; i <= 8; i++ { // 5 -> 9
   451  				if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   452  					search := len(firstBytes)
   453  					if search > 64 {
   454  						search = 64 // up to about 60
   455  					}
   456  					for j := i + 1; j < search; j++ {
   457  						k := firstBytes[j]
   458  						if k == '\r' || k == '\n' { // avoid logging \r or \n
   459  							break
   460  						}
   461  						if k == '?' || k == ' ' {
   462  							path = string(firstBytes[i:j])
   463  							newError("realPath = " + path).AtInfo().WriteToLog(sid)
   464  							if pfb[path] == nil {
   465  								path = ""
   466  							}
   467  							break
   468  						}
   469  					}
   470  					break
   471  				}
   472  			}
   473  		}
   474  	}
   475  	fb := pfb[path]
   476  	if fb == nil {
   477  		return newError(`failed to find the default "path" config`).AtWarning()
   478  	}
   479  
   480  	ctx, cancel := context.WithCancel(ctx)
   481  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   482  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   483  
   484  	var conn net.Conn
   485  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   486  		var dialer net.Dialer
   487  		conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   488  		if err != nil {
   489  			return err
   490  		}
   491  		return nil
   492  	}); err != nil {
   493  		return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   494  	}
   495  	defer conn.Close()
   496  
   497  	serverReader := buf.NewReader(conn)
   498  	serverWriter := buf.NewWriter(conn)
   499  
   500  	postRequest := func() error {
   501  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   502  		if fb.Xver != 0 {
   503  			ipType := 4
   504  			remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   505  			if err != nil {
   506  				ipType = 0
   507  			}
   508  			localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   509  			if err != nil {
   510  				ipType = 0
   511  			}
   512  			if ipType == 4 {
   513  				for i := 0; i < len(remoteAddr); i++ {
   514  					if remoteAddr[i] == ':' {
   515  						ipType = 6
   516  						break
   517  					}
   518  				}
   519  			}
   520  			pro := buf.New()
   521  			defer pro.Release()
   522  			switch fb.Xver {
   523  			case 1:
   524  				if ipType == 0 {
   525  					common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n")))
   526  					break
   527  				}
   528  				if ipType == 4 {
   529  					common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   530  				} else {
   531  					common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   532  				}
   533  			case 2:
   534  				common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature
   535  				if ipType == 0 {
   536  					common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
   537  					break
   538  				}
   539  				if ipType == 4 {
   540  					common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes
   541  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
   542  					common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
   543  				} else {
   544  					common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
   545  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
   546  					common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
   547  				}
   548  				p1, _ := strconv.ParseUint(remotePort, 10, 16)
   549  				p2, _ := strconv.ParseUint(localPort, 10, 16)
   550  				common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
   551  			}
   552  			if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   553  				return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   554  			}
   555  		}
   556  		if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   557  			return newError("failed to fallback request payload").Base(err).AtInfo()
   558  		}
   559  		return nil
   560  	}
   561  
   562  	writer := buf.NewWriter(connection)
   563  
   564  	getResponse := func() error {
   565  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   566  		if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   567  			return newError("failed to deliver response payload").Base(err).AtInfo()
   568  		}
   569  		return nil
   570  	}
   571  
   572  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   573  		common.Must(common.Interrupt(serverReader))
   574  		common.Must(common.Interrupt(serverWriter))
   575  		return newError("fallback ends").Base(err).AtInfo()
   576  	}
   577  
   578  	return nil
   579  }