github.com/xmplusdev/xray-core@v1.8.10/proxy/trojan/server.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/xmplusdev/xray-core/common"
    11  	"github.com/xmplusdev/xray-core/common/buf"
    12  	"github.com/xmplusdev/xray-core/common/errors"
    13  	"github.com/xmplusdev/xray-core/common/log"
    14  	"github.com/xmplusdev/xray-core/common/net"
    15  	"github.com/xmplusdev/xray-core/common/protocol"
    16  	udp_proto "github.com/xmplusdev/xray-core/common/protocol/udp"
    17  	"github.com/xmplusdev/xray-core/common/retry"
    18  	"github.com/xmplusdev/xray-core/common/session"
    19  	"github.com/xmplusdev/xray-core/common/signal"
    20  	"github.com/xmplusdev/xray-core/common/task"
    21  	"github.com/xmplusdev/xray-core/core"
    22  	"github.com/xmplusdev/xray-core/features/policy"
    23  	"github.com/xmplusdev/xray-core/features/routing"
    24  	"github.com/xmplusdev/xray-core/transport/internet/reality"
    25  	"github.com/xmplusdev/xray-core/transport/internet/stat"
    26  	"github.com/xmplusdev/xray-core/transport/internet/tls"
    27  	"github.com/xmplusdev/xray-core/transport/internet/udp"
    28  )
    29  
    30  func init() {
    31  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    32  		return NewServer(ctx, config.(*ServerConfig))
    33  	}))
    34  }
    35  
    36  // Server is an inbound connection handler that handles messages in trojan protocol.
    37  type Server struct {
    38  	policyManager policy.Manager
    39  	validator     *Validator
    40  	fallbacks     map[string]map[string]map[string]*Fallback // or nil
    41  	cone          bool
    42  }
    43  
    44  // NewServer creates a new trojan inbound handler from config.
        //
        // Each configured user is converted with ToMemoryUser and added to the
        // validator; the first failure aborts construction with an error.
        // NOTE(review): ctx.Value("cone").(bool) is an unchecked type assertion and
        // panics if ctx carries no bool under "cone" — confirm every caller sets it.
        //
        // When config.Fallbacks is non-nil, fallbacks are indexed as
        // server.fallbacks[name][alpn][path]; a later entry with the same
        // (name, alpn, path) replaces an earlier one. The "" keys then act as
        // defaults and are propagated without overwriting any existing entry:
        //   1. every non-empty name gets a (possibly empty) map for each alpn
        //      configured under the "" name;
        //   2. within each name, every non-empty alpn inherits the paths of that
        //      name's own "" alpn;
        //   3. every non-empty name inherits all alpn/path entries of the "" name.
        // Because step 2 runs before step 3, a name's own "" alpn entries take
        // precedence over the "" name's entries for the same alpn and path.
        // NOTE(review): a non-nil but empty config.Fallbacks yields an empty,
        // non-nil map, so Process will still route rejected clients to fallback.
    45  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    46  	validator := new(Validator)
    47  	for _, user := range config.Users {
    48  		u, err := user.ToMemoryUser()
    49  		if err != nil {
    50  			return nil, newError("failed to get trojan user").Base(err).AtError()
    51  		}
    52
    53  		if err := validator.Add(u); err != nil {
    54  			return nil, newError("failed to add user").Base(err).AtError()
    55  		}
    56  	}
    57
    58  	v := core.MustFromContext(ctx)
    59  	server := &Server{
    60  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    61  		validator:     validator,
    62  		cone:          ctx.Value("cone").(bool),
    63  	}
    64
        	// Build the [name][alpn][path] fallback index.
    65  	if config.Fallbacks != nil {
    66  		server.fallbacks = make(map[string]map[string]map[string]*Fallback)
    67  		for _, fb := range config.Fallbacks {
    68  			if server.fallbacks[fb.Name] == nil {
    69  				server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback)
    70  			}
    71  			if server.fallbacks[fb.Name][fb.Alpn] == nil {
    72  				server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback)
    73  			}
    74  			server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb
    75  		}
        		// Step 1: give every named entry an alpn bucket for each alpn known to
        		// the "" name, so steps 2 and 3 have maps to fill.
    76  		if server.fallbacks[""] != nil {
    77  			for name, apfb := range server.fallbacks {
    78  				if name != "" {
    79  					for alpn := range server.fallbacks[""] {
    80  						if apfb[alpn] == nil {
    81  							apfb[alpn] = make(map[string]*Fallback)
    82  						}
    83  					}
    84  				}
    85  			}
    86  		}
        		// Step 2: inside each name, copy the "" alpn's paths into every other
        		// alpn that does not already define them.
    87  		for _, apfb := range server.fallbacks {
    88  			if apfb[""] != nil {
    89  				for alpn, pfb := range apfb {
    90  					if alpn != "" { // && alpn != "h2" {
    91  						for path, fb := range apfb[""] {
    92  							if pfb[path] == nil {
    93  								pfb[path] = fb
    94  							}
    95  						}
    96  					}
    97  				}
    98  			}
    99  		}
        		// Step 3: copy the "" name's alpn/path entries into every named entry
        		// that does not already define them (buckets exist thanks to step 1).
   100  		if server.fallbacks[""] != nil {
   101  			for name, apfb := range server.fallbacks {
   102  				if name != "" {
   103  					for alpn, pfb := range server.fallbacks[""] {
   104  						for path, fb := range pfb {
   105  							if apfb[alpn][path] == nil {
   106  								apfb[alpn][path] = fb
   107  							}
   108  						}
   109  					}
   110  				}
   111  			}
   112  		}
   113  	}
   114
   115  	return server, nil
   116  }
   117  
   118  // AddUser implements proxy.UserManager.AddUser().
   119  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   120  	return s.validator.Add(u)
   121  }
   122  
   123  // RemoveUser implements proxy.UserManager.RemoveUser().
   124  func (s *Server) RemoveUser(ctx context.Context, e string) error {
   125  	return s.validator.Del(e)
   126  }
   127  
   128  // Network implements proxy.Inbound.Network().
   129  func (s *Server) Network() []net.Network {
   130  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   131  }
   132  
   133  // Process implements proxy.Inbound.Process().
   134  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
   135  	sid := session.ExportIDToError(ctx)
   136  
   137  	iConn := conn
   138  	statConn, ok := iConn.(*stat.CounterConnection)
   139  	if ok {
   140  		iConn = statConn.Connection
   141  	}
   142  
   143  	sessionPolicy := s.policyManager.ForLevel(0)
   144  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   145  		return newError("unable to set read deadline").Base(err).AtWarning()
   146  	}
   147  
   148  	first := buf.FromBytes(make([]byte, buf.Size))
   149  	first.Clear()
   150  	firstLen, err := first.ReadFrom(conn)
   151  	if err != nil {
   152  		return newError("failed to read first request").Base(err)
   153  	}
   154  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   155  
   156  	bufferedReader := &buf.BufferedReader{
   157  		Reader: buf.NewReader(conn),
   158  		Buffer: buf.MultiBuffer{first},
   159  	}
   160  
   161  	var user *protocol.MemoryUser
   162  
   163  	napfb := s.fallbacks
   164  	isfb := napfb != nil
   165  
   166  	shouldFallback := false
   167  	if firstLen < 58 || first.Byte(56) != '\r' {
   168  		// invalid protocol
   169  		err = newError("not trojan protocol")
   170  		log.Record(&log.AccessMessage{
   171  			From:   conn.RemoteAddr(),
   172  			To:     "",
   173  			Status: log.AccessRejected,
   174  			Reason: err,
   175  		})
   176  
   177  		shouldFallback = true
   178  	} else {
   179  		user = s.validator.Get(hexString(first.BytesTo(56)))
   180  		if user == nil {
   181  			// invalid user, let's fallback
   182  			err = newError("not a valid user")
   183  			log.Record(&log.AccessMessage{
   184  				From:   conn.RemoteAddr(),
   185  				To:     "",
   186  				Status: log.AccessRejected,
   187  				Reason: err,
   188  			})
   189  
   190  			shouldFallback = true
   191  		}
   192  	}
   193  
   194  	if isfb && shouldFallback {
   195  		return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader)
   196  	} else if shouldFallback {
   197  		return newError("invalid protocol or invalid user")
   198  	}
   199  
   200  	clientReader := &ConnReader{Reader: bufferedReader}
   201  	if err := clientReader.ParseHeader(); err != nil {
   202  		log.Record(&log.AccessMessage{
   203  			From:   conn.RemoteAddr(),
   204  			To:     "",
   205  			Status: log.AccessRejected,
   206  			Reason: err,
   207  		})
   208  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   209  	}
   210  
   211  	destination := clientReader.Target
   212  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   213  		return newError("unable to set read deadline").Base(err).AtWarning()
   214  	}
   215  
   216  	inbound := session.InboundFromContext(ctx)
   217  	inbound.Name = "trojan"
   218  	inbound.SetCanSpliceCopy(3)
   219  	inbound.User = user
   220  	sessionPolicy = s.policyManager.ForLevel(user.Level)
   221  
   222  	if destination.Network == net.Network_UDP { // handle udp request
   223  		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
   224  	}
   225  
   226  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   227  		From:   conn.RemoteAddr(),
   228  		To:     destination,
   229  		Status: log.AccessAccepted,
   230  		Reason: "",
   231  		Email:  user.Email,
   232  	})
   233  
   234  	newError("received request for ", destination).WriteToLog(sid)
   235  	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
   236  }
   237  
   238  func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
   239  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   240  		udpPayload := packet.Payload
   241  		if udpPayload.UDP == nil {
   242  			udpPayload.UDP = &packet.Source
   243  		}
   244  
   245  		if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
   246  			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   247  		}
   248  	})
   249  
   250  	inbound := session.InboundFromContext(ctx)
   251  	user := inbound.User
   252  
   253  	var dest *net.Destination
   254  
   255  	for {
   256  		select {
   257  		case <-ctx.Done():
   258  			return nil
   259  		default:
   260  			mb, err := clientReader.ReadMultiBuffer()
   261  			if err != nil {
   262  				if errors.Cause(err) != io.EOF {
   263  					return newError("unexpected EOF").Base(err)
   264  				}
   265  				return nil
   266  			}
   267  
   268  			mb2, b := buf.SplitFirst(mb)
   269  			if b == nil {
   270  				continue
   271  			}
   272  			destination := *b.UDP
   273  
   274  			currentPacketCtx := ctx
   275  			if inbound.Source.IsValid() {
   276  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   277  					From:   inbound.Source,
   278  					To:     destination,
   279  					Status: log.AccessAccepted,
   280  					Reason: "",
   281  					Email:  user.Email,
   282  				})
   283  			}
   284  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
   285  
   286  			if !s.cone || dest == nil {
   287  				dest = &destination
   288  			}
   289  
   290  			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
   291  			for _, payload := range mb2 {
   292  				udpServer.Dispatch(currentPacketCtx, *dest, payload)
   293  			}
   294  		}
   295  	}
   296  }
   297  
   298  func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
   299  	destination net.Destination,
   300  	clientReader buf.Reader,
   301  	clientWriter buf.Writer, dispatcher routing.Dispatcher,
   302  ) error {
   303  	ctx, cancel := context.WithCancel(ctx)
   304  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   305  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   306  
   307  	link, err := dispatcher.Dispatch(ctx, destination)
   308  	if err != nil {
   309  		return newError("failed to dispatch request to ", destination).Base(err)
   310  	}
   311  
   312  	requestDone := func() error {
   313  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   314  		if buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)) != nil {
   315  			return newError("failed to transfer request").Base(err)
   316  		}
   317  		return nil
   318  	}
   319  
   320  	responseDone := func() error {
   321  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   322  
   323  		if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   324  			return newError("failed to write response").Base(err)
   325  		}
   326  		return nil
   327  	}
   328  
   329  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   330  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   331  		common.Must(common.Interrupt(link.Reader))
   332  		common.Must(common.Interrupt(link.Writer))
   333  		return newError("connection ends").Base(err)
   334  	}
   335  
   336  	return nil
   337  }
   338
        // fallback forwards a connection that failed trojan authentication to a
        // configured fallback destination, replaying the bytes already read.
        //
        // Parameters:
        //   - err: the rejection reason from Process; it is only logged here.
        //   - sessionPolicy: supplies the idle/uplink/downlink timeouts and buffer policy.
        //   - connection: the client connection; used for addresses, the read
        //     deadline, and writing the fallback's response.
        //   - iConn: connection as seen beneath any stats wrapper; inspected for a
        //     *tls.Conn or *reality.Conn to learn the SNI name and negotiated ALPN.
        //   - napfb: fallbacks indexed as [name][alpn][path].
        //   - first/firstLen: the first chunk read from the client, used to sniff an
        //     HTTP/1.x request path.
        //   - reader: the client stream; in Process this is a BufferedReader that
        //     yields first before the rest of the connection.
        //
        // Selection: the name is matched exactly, else by the longest configured
        // non-empty name contained in the SNI (skipped when "" is the only name),
        // else "". An unconfigured ALPN falls back to "". A sniffed path that is
        // not configured falls back to "". A missing default at any level is an
        // error. If fb.Xver is non-zero, a PROXY protocol header is sent to the
        // fallback before the client bytes.
   339  func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
   340  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   341  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   342  	}
   343  	newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   344
        	// Recover the SNI name and negotiated ALPN from the TLS or REALITY layer, if any.
   345  	name := ""
   346  	alpn := ""
   347  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   348  		cs := tlsConn.ConnectionState()
   349  		name = cs.ServerName
   350  		alpn = cs.NegotiatedProtocol
   351  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   352  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   353  	} else if realityConn, ok := iConn.(*reality.Conn); ok {
   354  		cs := realityConn.ConnectionState()
   355  		name = cs.ServerName
   356  		alpn = cs.NegotiatedProtocol
   357  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   358  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   359  	}
        	// NOTE(review): only the observed values are lowercased; the config keys
        	// are used as given, so they presumably must be lowercase — confirm.
   360  	name = strings.ToLower(name)
   361  	alpn = strings.ToLower(alpn)
   362
        	// Resolve the name: exact match, else the longest configured non-empty
        	// name that is a substring of the SNI (empty match if none).
   363  	if len(napfb) > 1 || napfb[""] == nil {
   364  		if name != "" && napfb[name] == nil {
   365  			match := ""
   366  			for n := range napfb {
   367  				if n != "" && strings.Contains(name, n) && len(n) > len(match) {
   368  					match = n
   369  				}
   370  			}
   371  			name = match
   372  		}
   373  	}
   374
   375  	if napfb[name] == nil {
   376  		name = ""
   377  	}
   378  	apfb := napfb[name]
   379  	if apfb == nil {
   380  		return newError(`failed to find the default "name" config`).AtWarning()
   381  	}
   382
        	// Unconfigured ALPN falls back to the "" ALPN.
   383  	if apfb[alpn] == nil {
   384  		alpn = ""
   385  	}
   386  	pfb := apfb[alpn]
   387  	if pfb == nil {
   388  		return newError(`failed to find the default "alpn" config`).AtWarning()
   389  	}
   390
        	// Sniff an HTTP/1.x request path only when it could change the choice:
        	// there are path-specific entries, or no "" default path.
   391  	path := ""
   392  	if len(pfb) > 1 || pfb[""] == nil {
        		// The h2c preface "PRI * HTTP/2.0" has '*' at index 4; skip it.
   393  		if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   394  			firstBytes := first.Bytes()
        			// A method of 3..7 letters puts the path's leading '/' at index
        			// 4..8, right after a space.
   395  			for i := 4; i <= 8; i++ { // 5 -> 9
   396  				if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   397  					search := len(firstBytes)
   398  					if search > 64 {
   399  						search = 64 // up to about 60
   400  					}
        					// The path ends at '?' or ' '. Hitting CR/LF first, or no
        					// terminator within the first 64 bytes, leaves path as "".
   401  					for j := i + 1; j < search; j++ {
   402  						k := firstBytes[j]
   403  						if k == '\r' || k == '\n' { // avoid logging \r or \n
   404  							break
   405  						}
   406  						if k == '?' || k == ' ' {
   407  							path = string(firstBytes[i:j])
   408  							newError("realPath = " + path).AtInfo().WriteToLog(sid)
   409  							if pfb[path] == nil {
   410  								path = ""
   411  							}
   412  							break
   413  						}
   414  					}
   415  					break
   416  				}
   417  			}
   418  		}
   419  	}
   420  	fb := pfb[path]
   421  	if fb == nil {
   422  		return newError(`failed to find the default "path" config`).AtWarning()
   423  	}
   424
        	// Relay under an inactivity timer that cancels ctx after ConnectionIdle.
        	// NOTE(review): if the dial below fails, cancel is not called here; ctx
        	// is only released when the timer fires.
   425  	ctx, cancel := context.WithCancel(ctx)
   426  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   427  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   428
        	// Dial the fallback with exponential backoff (5 attempts).
        	// NOTE(review): the closure assigns to the outer `err` parameter, not a
        	// fresh local; harmless today because that parameter is not read again.
   429  	var conn net.Conn
   430  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   431  		var dialer net.Dialer
   432  		conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   433  		if err != nil {
   434  			return err
   435  		}
   436  		return nil
   437  	}); err != nil {
   438  		return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   439  	}
   440  	defer conn.Close()
   441
   442  	serverReader := buf.NewReader(conn)
   443  	serverWriter := buf.NewWriter(conn)
   444
        	// postRequest sends an optional PROXY protocol header, then copies the
        	// client stream (starting with the replayed first chunk) to the fallback.
   445  	postRequest := func() error {
   446  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   447  		if fb.Xver != 0 {
        			// ipType: 4 or 6 for the address family, 0 if an address could not be split.
   448  			ipType := 4
   449  			remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   450  			if err != nil {
   451  				ipType = 0
   452  			}
   453  			localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   454  			if err != nil {
   455  				ipType = 0
   456  			}
        			// NOTE(review): the family is taken from remoteAddr only; localAddr
        			// is assumed to be of the same family.
   457  			if ipType == 4 {
   458  				for i := 0; i < len(remoteAddr); i++ {
   459  					if remoteAddr[i] == ':' {
   460  						ipType = 6
   461  						break
   462  					}
   463  				}
   464  			}
   465  			pro := buf.New()
   466  			defer pro.Release()
        			// NOTE(review): for Xver values other than 1 or 2, no case matches
        			// and an empty buffer is written below.
   467  			switch fb.Xver {
   468  			case 1:
        				// Text header: client (remote) is the source, this listener the destination.
   469  				if ipType == 0 {
   470  					common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n")))
   471  					break
   472  				}
   473  				if ipType == 4 {
   474  					common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   475  				} else {
   476  					common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   477  				}
   478  			case 2:
        				// Binary header: signature, version/command, family, length,
        				// then source/destination addresses and ports.
   479  				common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature
   480  				if ipType == 0 {
   481  					common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
   482  					break
   483  				}
   484  				if ipType == 4 {
   485  					common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes
   486  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
   487  					common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
   488  				} else {
   489  					common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
   490  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
   491  					common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
   492  				}
        				// Ports in network byte order.
        				// NOTE(review): parse errors are ignored, so a bad port encodes as 0.
   493  				p1, _ := strconv.ParseUint(remotePort, 10, 16)
   494  				p2, _ := strconv.ParseUint(localPort, 10, 16)
   495  				common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
   496  			}
   497  			if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   498  				return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   499  			}
   500  		}
   501  		if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   502  			return newError("failed to fallback request payload").Base(err).AtInfo()
   503  		}
   504  		return nil
   505  	}
   506
        	// Responses from the fallback go back to the original client connection.
   507  	writer := buf.NewWriter(connection)
   508
   509  	getResponse := func() error {
   510  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   511  		if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   512  			return newError("failed to deliver response payload").Base(err).AtInfo()
   513  		}
   514  		return nil
   515  	}
   516
        	// Run both directions. Each closes its writer on success; on failure the
        	// fallback side is interrupted.
   517  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   518  		common.Must(common.Interrupt(serverReader))
   519  		common.Must(common.Interrupt(serverWriter))
   520  		return newError("fallback ends").Base(err).AtInfo()
   521  	}
   522
   523  	return nil
   524  }