github.com/xraypb/xray-core@v1.6.6/proxy/trojan/server.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"strconv"
     7  	"strings"
     8  	"syscall"
     9  	"time"
    10  
    11  	"github.com/xraypb/xray-core/common"
    12  	"github.com/xraypb/xray-core/common/buf"
    13  	"github.com/xraypb/xray-core/common/errors"
    14  	"github.com/xraypb/xray-core/common/log"
    15  	"github.com/xraypb/xray-core/common/net"
    16  	"github.com/xraypb/xray-core/common/platform"
    17  	"github.com/xraypb/xray-core/common/protocol"
    18  	udp_proto "github.com/xraypb/xray-core/common/protocol/udp"
    19  	"github.com/xraypb/xray-core/common/retry"
    20  	"github.com/xraypb/xray-core/common/session"
    21  	"github.com/xraypb/xray-core/common/signal"
    22  	"github.com/xraypb/xray-core/common/task"
    23  	"github.com/xraypb/xray-core/core"
    24  	"github.com/xraypb/xray-core/features/policy"
    25  	"github.com/xraypb/xray-core/features/routing"
    26  	"github.com/xraypb/xray-core/features/stats"
    27  	"github.com/xraypb/xray-core/transport/internet/stat"
    28  	"github.com/xraypb/xray-core/transport/internet/tls"
    29  	"github.com/xraypb/xray-core/transport/internet/udp"
    30  	"github.com/xraypb/xray-core/transport/internet/xtls"
    31  )
    32  
    33  func init() {
    34  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    35  		return NewServer(ctx, config.(*ServerConfig))
    36  	}))
    37  
    38  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
    39  
    40  	xtlsShow := platform.NewEnvFlag("xray.trojan.xtls.show").GetValue(func() string { return defaultFlagValue })
    41  	if xtlsShow == "true" {
    42  		xtls_show = true
    43  	}
    44  }
    45  
    46  // Server is an inbound connection handler that handles messages in trojan protocol.
    47  type Server struct {
    48  	policyManager policy.Manager
    49  	validator     *Validator
    50  	fallbacks     map[string]map[string]map[string]*Fallback // or nil
    51  	cone          bool
    52  }
    53  
    54  // NewServer creates a new trojan inbound handler.
    55  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    56  	validator := new(Validator)
    57  	for _, user := range config.Users {
    58  		u, err := user.ToMemoryUser()
    59  		if err != nil {
    60  			return nil, newError("failed to get trojan user").Base(err).AtError()
    61  		}
    62  
    63  		if err := validator.Add(u); err != nil {
    64  			return nil, newError("failed to add user").Base(err).AtError()
    65  		}
    66  	}
    67  
    68  	v := core.MustFromContext(ctx)
    69  	server := &Server{
    70  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    71  		validator:     validator,
    72  		cone:          ctx.Value("cone").(bool),
    73  	}
    74  
    75  	if config.Fallbacks != nil {
    76  		server.fallbacks = make(map[string]map[string]map[string]*Fallback)
    77  		for _, fb := range config.Fallbacks {
    78  			if server.fallbacks[fb.Name] == nil {
    79  				server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback)
    80  			}
    81  			if server.fallbacks[fb.Name][fb.Alpn] == nil {
    82  				server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback)
    83  			}
    84  			server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb
    85  		}
    86  		if server.fallbacks[""] != nil {
    87  			for name, apfb := range server.fallbacks {
    88  				if name != "" {
    89  					for alpn := range server.fallbacks[""] {
    90  						if apfb[alpn] == nil {
    91  							apfb[alpn] = make(map[string]*Fallback)
    92  						}
    93  					}
    94  				}
    95  			}
    96  		}
    97  		for _, apfb := range server.fallbacks {
    98  			if apfb[""] != nil {
    99  				for alpn, pfb := range apfb {
   100  					if alpn != "" { // && alpn != "h2" {
   101  						for path, fb := range apfb[""] {
   102  							if pfb[path] == nil {
   103  								pfb[path] = fb
   104  							}
   105  						}
   106  					}
   107  				}
   108  			}
   109  		}
   110  		if server.fallbacks[""] != nil {
   111  			for name, apfb := range server.fallbacks {
   112  				if name != "" {
   113  					for alpn, pfb := range server.fallbacks[""] {
   114  						for path, fb := range pfb {
   115  							if apfb[alpn][path] == nil {
   116  								apfb[alpn][path] = fb
   117  							}
   118  						}
   119  					}
   120  				}
   121  			}
   122  		}
   123  	}
   124  
   125  	return server, nil
   126  }
   127  
   128  // AddUser implements proxy.UserManager.AddUser().
   129  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   130  	return s.validator.Add(u)
   131  }
   132  
   133  // RemoveUser implements proxy.UserManager.RemoveUser().
   134  func (s *Server) RemoveUser(ctx context.Context, e string) error {
   135  	return s.validator.Del(e)
   136  }
   137  
   138  // Network implements proxy.Inbound.Network().
   139  func (s *Server) Network() []net.Network {
   140  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   141  }
   142  
   143  // Process implements proxy.Inbound.Process().
   144  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
   145  	sid := session.ExportIDToError(ctx)
   146  
   147  	iConn := conn
   148  	statConn, ok := iConn.(*stat.CounterConnection)
   149  	if ok {
   150  		iConn = statConn.Connection
   151  	}
   152  
   153  	sessionPolicy := s.policyManager.ForLevel(0)
   154  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   155  		return newError("unable to set read deadline").Base(err).AtWarning()
   156  	}
   157  
   158  	first := buf.New()
   159  	defer first.Release()
   160  
   161  	firstLen, err := first.ReadFrom(conn)
   162  	if err != nil {
   163  		return newError("failed to read first request").Base(err)
   164  	}
   165  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   166  
   167  	bufferedReader := &buf.BufferedReader{
   168  		Reader: buf.NewReader(conn),
   169  		Buffer: buf.MultiBuffer{first},
   170  	}
   171  
   172  	var user *protocol.MemoryUser
   173  
   174  	napfb := s.fallbacks
   175  	isfb := napfb != nil
   176  
   177  	shouldFallback := false
   178  	if firstLen < 58 || first.Byte(56) != '\r' {
   179  		// invalid protocol
   180  		err = newError("not trojan protocol")
   181  		log.Record(&log.AccessMessage{
   182  			From:   conn.RemoteAddr(),
   183  			To:     "",
   184  			Status: log.AccessRejected,
   185  			Reason: err,
   186  		})
   187  
   188  		shouldFallback = true
   189  	} else {
   190  		user = s.validator.Get(hexString(first.BytesTo(56)))
   191  		if user == nil {
   192  			// invalid user, let's fallback
   193  			err = newError("not a valid user")
   194  			log.Record(&log.AccessMessage{
   195  				From:   conn.RemoteAddr(),
   196  				To:     "",
   197  				Status: log.AccessRejected,
   198  				Reason: err,
   199  			})
   200  
   201  			shouldFallback = true
   202  		}
   203  	}
   204  
   205  	if isfb && shouldFallback {
   206  		return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader)
   207  	} else if shouldFallback {
   208  		return newError("invalid protocol or invalid user")
   209  	}
   210  
   211  	clientReader := &ConnReader{Reader: bufferedReader}
   212  	if err := clientReader.ParseHeader(); err != nil {
   213  		log.Record(&log.AccessMessage{
   214  			From:   conn.RemoteAddr(),
   215  			To:     "",
   216  			Status: log.AccessRejected,
   217  			Reason: err,
   218  		})
   219  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   220  	}
   221  
   222  	destination := clientReader.Target
   223  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   224  		return newError("unable to set read deadline").Base(err).AtWarning()
   225  	}
   226  
   227  	inbound := session.InboundFromContext(ctx)
   228  	if inbound == nil {
   229  		panic("no inbound metadata")
   230  	}
   231  	inbound.User = user
   232  	sessionPolicy = s.policyManager.ForLevel(user.Level)
   233  
   234  	if destination.Network == net.Network_UDP { // handle udp request
   235  		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
   236  	}
   237  
   238  	// handle tcp request
   239  	account, ok := user.Account.(*MemoryAccount)
   240  	if !ok {
   241  		return newError("user account is not valid")
   242  	}
   243  
   244  	var rawConn syscall.RawConn
   245  
   246  	switch clientReader.Flow {
   247  	case XRO, XRD:
   248  		if account.Flow == clientReader.Flow {
   249  			if destination.Address.Family().IsDomain() && destination.Address.Domain() == muxCoolAddress {
   250  				return newError(clientReader.Flow + " doesn't support Mux").AtWarning()
   251  			}
   252  			if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   253  				xtlsConn.RPRX = true
   254  				xtlsConn.SHOW = xtls_show
   255  				xtlsConn.MARK = "XTLS"
   256  				if clientReader.Flow == XRD {
   257  					xtlsConn.DirectMode = true
   258  					if sc, ok := xtlsConn.NetConn().(syscall.Conn); ok {
   259  						rawConn, _ = sc.SyscallConn()
   260  					}
   261  				}
   262  			} else {
   263  				return newError(`failed to use ` + clientReader.Flow + `, maybe "security" is not "xtls"`).AtWarning()
   264  			}
   265  		} else {
   266  			return newError(account.Password + " is not able to use " + clientReader.Flow).AtWarning()
   267  		}
   268  	case "":
   269  	}
   270  
   271  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   272  		From:   conn.RemoteAddr(),
   273  		To:     destination,
   274  		Status: log.AccessAccepted,
   275  		Reason: "",
   276  		Email:  user.Email,
   277  	})
   278  
   279  	newError("received request for ", destination).WriteToLog(sid)
   280  	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher, iConn, rawConn, statConn)
   281  }
   282  
   283  func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
   284  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   285  		udpPayload := packet.Payload
   286  		if udpPayload.UDP == nil {
   287  			udpPayload.UDP = &packet.Source
   288  		}
   289  
   290  		if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
   291  			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   292  		}
   293  	})
   294  
   295  	inbound := session.InboundFromContext(ctx)
   296  	user := inbound.User
   297  
   298  	var dest *net.Destination
   299  
   300  	for {
   301  		select {
   302  		case <-ctx.Done():
   303  			return nil
   304  		default:
   305  			mb, err := clientReader.ReadMultiBuffer()
   306  			if err != nil {
   307  				if errors.Cause(err) != io.EOF {
   308  					return newError("unexpected EOF").Base(err)
   309  				}
   310  				return nil
   311  			}
   312  
   313  			mb2, b := buf.SplitFirst(mb)
   314  			if b == nil {
   315  				continue
   316  			}
   317  			destination := *b.UDP
   318  
   319  			currentPacketCtx := ctx
   320  			if inbound.Source.IsValid() {
   321  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   322  					From:   inbound.Source,
   323  					To:     destination,
   324  					Status: log.AccessAccepted,
   325  					Reason: "",
   326  					Email:  user.Email,
   327  				})
   328  			}
   329  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
   330  
   331  			if !s.cone || dest == nil {
   332  				dest = &destination
   333  			}
   334  
   335  			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
   336  			for _, payload := range mb2 {
   337  				udpServer.Dispatch(currentPacketCtx, *dest, payload)
   338  			}
   339  		}
   340  	}
   341  }
   342  
   343  func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
   344  	destination net.Destination,
   345  	clientReader buf.Reader,
   346  	clientWriter buf.Writer, dispatcher routing.Dispatcher, iConn stat.Connection, rawConn syscall.RawConn, statConn *stat.CounterConnection,
   347  ) error {
   348  	ctx, cancel := context.WithCancel(ctx)
   349  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   350  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   351  
   352  	link, err := dispatcher.Dispatch(ctx, destination)
   353  	if err != nil {
   354  		return newError("failed to dispatch request to ", destination).Base(err)
   355  	}
   356  
   357  	requestDone := func() error {
   358  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   359  
   360  		var err error
   361  		if rawConn != nil {
   362  			var counter stats.Counter
   363  			if statConn != nil {
   364  				counter = statConn.ReadCounter
   365  			}
   366  			err = ReadV(clientReader, link.Writer, timer, iConn.(*xtls.Conn), rawConn, counter, nil)
   367  		} else {
   368  			err = buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer))
   369  		}
   370  		if err != nil {
   371  			return newError("failed to transfer request").Base(err)
   372  		}
   373  		return nil
   374  	}
   375  
   376  	responseDone := func() error {
   377  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   378  
   379  		if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   380  			return newError("failed to write response").Base(err)
   381  		}
   382  		return nil
   383  	}
   384  
   385  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   386  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   387  		common.Must(common.Interrupt(link.Reader))
   388  		common.Must(common.Interrupt(link.Writer))
   389  		return newError("connection ends").Base(err)
   390  	}
   391  
   392  	return nil
   393  }
   394  
   395  func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
   396  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   397  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   398  	}
   399  	newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   400  
   401  	name := ""
   402  	alpn := ""
   403  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   404  		cs := tlsConn.ConnectionState()
   405  		name = cs.ServerName
   406  		alpn = cs.NegotiatedProtocol
   407  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   408  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   409  	} else if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   410  		cs := xtlsConn.ConnectionState()
   411  		name = cs.ServerName
   412  		alpn = cs.NegotiatedProtocol
   413  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   414  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   415  	}
   416  	name = strings.ToLower(name)
   417  	alpn = strings.ToLower(alpn)
   418  
   419  	if len(napfb) > 1 || napfb[""] == nil {
   420  		if name != "" && napfb[name] == nil {
   421  			match := ""
   422  			for n := range napfb {
   423  				if n != "" && strings.Contains(name, n) && len(n) > len(match) {
   424  					match = n
   425  				}
   426  			}
   427  			name = match
   428  		}
   429  	}
   430  
   431  	if napfb[name] == nil {
   432  		name = ""
   433  	}
   434  	apfb := napfb[name]
   435  	if apfb == nil {
   436  		return newError(`failed to find the default "name" config`).AtWarning()
   437  	}
   438  
   439  	if apfb[alpn] == nil {
   440  		alpn = ""
   441  	}
   442  	pfb := apfb[alpn]
   443  	if pfb == nil {
   444  		return newError(`failed to find the default "alpn" config`).AtWarning()
   445  	}
   446  
   447  	path := ""
   448  	if len(pfb) > 1 || pfb[""] == nil {
   449  		if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   450  			firstBytes := first.Bytes()
   451  			for i := 4; i <= 8; i++ { // 5 -> 9
   452  				if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   453  					search := len(firstBytes)
   454  					if search > 64 {
   455  						search = 64 // up to about 60
   456  					}
   457  					for j := i + 1; j < search; j++ {
   458  						k := firstBytes[j]
   459  						if k == '\r' || k == '\n' { // avoid logging \r or \n
   460  							break
   461  						}
   462  						if k == '?' || k == ' ' {
   463  							path = string(firstBytes[i:j])
   464  							newError("realPath = " + path).AtInfo().WriteToLog(sid)
   465  							if pfb[path] == nil {
   466  								path = ""
   467  							}
   468  							break
   469  						}
   470  					}
   471  					break
   472  				}
   473  			}
   474  		}
   475  	}
   476  	fb := pfb[path]
   477  	if fb == nil {
   478  		return newError(`failed to find the default "path" config`).AtWarning()
   479  	}
   480  
   481  	ctx, cancel := context.WithCancel(ctx)
   482  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   483  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   484  
   485  	var conn net.Conn
   486  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   487  		var dialer net.Dialer
   488  		conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   489  		if err != nil {
   490  			return err
   491  		}
   492  		return nil
   493  	}); err != nil {
   494  		return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   495  	}
   496  	defer conn.Close()
   497  
   498  	serverReader := buf.NewReader(conn)
   499  	serverWriter := buf.NewWriter(conn)
   500  
   501  	postRequest := func() error {
   502  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   503  		if fb.Xver != 0 {
   504  			ipType := 4
   505  			remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   506  			if err != nil {
   507  				ipType = 0
   508  			}
   509  			localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   510  			if err != nil {
   511  				ipType = 0
   512  			}
   513  			if ipType == 4 {
   514  				for i := 0; i < len(remoteAddr); i++ {
   515  					if remoteAddr[i] == ':' {
   516  						ipType = 6
   517  						break
   518  					}
   519  				}
   520  			}
   521  			pro := buf.New()
   522  			defer pro.Release()
   523  			switch fb.Xver {
   524  			case 1:
   525  				if ipType == 0 {
   526  					common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n")))
   527  					break
   528  				}
   529  				if ipType == 4 {
   530  					common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   531  				} else {
   532  					common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   533  				}
   534  			case 2:
   535  				common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature
   536  				if ipType == 0 {
   537  					common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
   538  					break
   539  				}
   540  				if ipType == 4 {
   541  					common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes
   542  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
   543  					common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
   544  				} else {
   545  					common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
   546  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
   547  					common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
   548  				}
   549  				p1, _ := strconv.ParseUint(remotePort, 10, 16)
   550  				p2, _ := strconv.ParseUint(localPort, 10, 16)
   551  				common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
   552  			}
   553  			if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   554  				return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   555  			}
   556  		}
   557  		if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   558  			return newError("failed to fallback request payload").Base(err).AtInfo()
   559  		}
   560  		return nil
   561  	}
   562  
   563  	writer := buf.NewWriter(connection)
   564  
   565  	getResponse := func() error {
   566  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   567  		if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   568  			return newError("failed to deliver response payload").Base(err).AtInfo()
   569  		}
   570  		return nil
   571  	}
   572  
   573  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   574  		common.Must(common.Interrupt(serverReader))
   575  		common.Must(common.Interrupt(serverWriter))
   576  		return newError("fallback ends").Base(err).AtInfo()
   577  	}
   578  
   579  	return nil
   580  }