github.com/moqsien/xraycore@v1.8.5/proxy/trojan/server.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/moqsien/xraycore/common"
    11  	"github.com/moqsien/xraycore/common/buf"
    12  	"github.com/moqsien/xraycore/common/errors"
    13  	"github.com/moqsien/xraycore/common/log"
    14  	"github.com/moqsien/xraycore/common/net"
    15  	"github.com/moqsien/xraycore/common/protocol"
    16  	udp_proto "github.com/moqsien/xraycore/common/protocol/udp"
    17  	"github.com/moqsien/xraycore/common/retry"
    18  	"github.com/moqsien/xraycore/common/session"
    19  	"github.com/moqsien/xraycore/common/signal"
    20  	"github.com/moqsien/xraycore/common/task"
    21  	"github.com/moqsien/xraycore/core"
    22  	"github.com/moqsien/xraycore/features/policy"
    23  	"github.com/moqsien/xraycore/features/routing"
    24  	"github.com/moqsien/xraycore/transport/internet/reality"
    25  	"github.com/moqsien/xraycore/transport/internet/stat"
    26  	"github.com/moqsien/xraycore/transport/internet/tls"
    27  	"github.com/moqsien/xraycore/transport/internet/udp"
    28  )
    29  
    30  func init() {
    31  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    32  		return NewServer(ctx, config.(*ServerConfig))
    33  	}))
    34  }
    35  
    36  // Server is an inbound connection handler that handles messages in trojan protocol.
    37  type Server struct {
    38  	policyManager policy.Manager
    39  	validator     *Validator
    40  	fallbacks     map[string]map[string]map[string]*Fallback // or nil
    41  	cone          bool
    42  }
    43  
    44  // NewServer creates a new trojan inbound handler.
    45  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    46  	validator := new(Validator)
    47  	for _, user := range config.Users {
    48  		u, err := user.ToMemoryUser()
    49  		if err != nil {
    50  			return nil, newError("failed to get trojan user").Base(err).AtError()
    51  		}
    52  
    53  		if err := validator.Add(u); err != nil {
    54  			return nil, newError("failed to add user").Base(err).AtError()
    55  		}
    56  	}
    57  
    58  	v := core.MustFromContext(ctx)
    59  	server := &Server{
    60  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    61  		validator:     validator,
    62  		cone:          ctx.Value("cone").(bool),
    63  	}
    64  
    65  	if config.Fallbacks != nil {
    66  		server.fallbacks = make(map[string]map[string]map[string]*Fallback)
    67  		for _, fb := range config.Fallbacks {
    68  			if server.fallbacks[fb.Name] == nil {
    69  				server.fallbacks[fb.Name] = make(map[string]map[string]*Fallback)
    70  			}
    71  			if server.fallbacks[fb.Name][fb.Alpn] == nil {
    72  				server.fallbacks[fb.Name][fb.Alpn] = make(map[string]*Fallback)
    73  			}
    74  			server.fallbacks[fb.Name][fb.Alpn][fb.Path] = fb
    75  		}
    76  		if server.fallbacks[""] != nil {
    77  			for name, apfb := range server.fallbacks {
    78  				if name != "" {
    79  					for alpn := range server.fallbacks[""] {
    80  						if apfb[alpn] == nil {
    81  							apfb[alpn] = make(map[string]*Fallback)
    82  						}
    83  					}
    84  				}
    85  			}
    86  		}
    87  		for _, apfb := range server.fallbacks {
    88  			if apfb[""] != nil {
    89  				for alpn, pfb := range apfb {
    90  					if alpn != "" { // && alpn != "h2" {
    91  						for path, fb := range apfb[""] {
    92  							if pfb[path] == nil {
    93  								pfb[path] = fb
    94  							}
    95  						}
    96  					}
    97  				}
    98  			}
    99  		}
   100  		if server.fallbacks[""] != nil {
   101  			for name, apfb := range server.fallbacks {
   102  				if name != "" {
   103  					for alpn, pfb := range server.fallbacks[""] {
   104  						for path, fb := range pfb {
   105  							if apfb[alpn][path] == nil {
   106  								apfb[alpn][path] = fb
   107  							}
   108  						}
   109  					}
   110  				}
   111  			}
   112  		}
   113  	}
   114  
   115  	return server, nil
   116  }
   117  
   118  // AddUser implements proxy.UserManager.AddUser().
   119  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
   120  	return s.validator.Add(u)
   121  }
   122  
   123  // RemoveUser implements proxy.UserManager.RemoveUser().
   124  func (s *Server) RemoveUser(ctx context.Context, e string) error {
   125  	return s.validator.Del(e)
   126  }
   127  
   128  // Network implements proxy.Inbound.Network().
   129  func (s *Server) Network() []net.Network {
   130  	return []net.Network{net.Network_TCP, net.Network_UNIX}
   131  }
   132  
   133  // Process implements proxy.Inbound.Process().
   134  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
   135  	sid := session.ExportIDToError(ctx)
   136  
   137  	iConn := conn
   138  	statConn, ok := iConn.(*stat.CounterConnection)
   139  	if ok {
   140  		iConn = statConn.Connection
   141  	}
   142  
   143  	sessionPolicy := s.policyManager.ForLevel(0)
   144  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   145  		return newError("unable to set read deadline").Base(err).AtWarning()
   146  	}
   147  
   148  	first := buf.FromBytes(make([]byte, buf.Size))
   149  	first.Clear()
   150  	firstLen, err := first.ReadFrom(conn)
   151  	if err != nil {
   152  		return newError("failed to read first request").Base(err)
   153  	}
   154  	newError("firstLen = ", firstLen).AtInfo().WriteToLog(sid)
   155  
   156  	bufferedReader := &buf.BufferedReader{
   157  		Reader: buf.NewReader(conn),
   158  		Buffer: buf.MultiBuffer{first},
   159  	}
   160  
   161  	var user *protocol.MemoryUser
   162  
   163  	napfb := s.fallbacks
   164  	isfb := napfb != nil
   165  
   166  	shouldFallback := false
   167  	if firstLen < 58 || first.Byte(56) != '\r' {
   168  		// invalid protocol
   169  		err = newError("not trojan protocol")
   170  		log.Record(&log.AccessMessage{
   171  			From:   conn.RemoteAddr(),
   172  			To:     "",
   173  			Status: log.AccessRejected,
   174  			Reason: err,
   175  		})
   176  
   177  		shouldFallback = true
   178  	} else {
   179  		user = s.validator.Get(hexString(first.BytesTo(56)))
   180  		if user == nil {
   181  			// invalid user, let's fallback
   182  			err = newError("not a valid user")
   183  			log.Record(&log.AccessMessage{
   184  				From:   conn.RemoteAddr(),
   185  				To:     "",
   186  				Status: log.AccessRejected,
   187  				Reason: err,
   188  			})
   189  
   190  			shouldFallback = true
   191  		}
   192  	}
   193  
   194  	if isfb && shouldFallback {
   195  		return s.fallback(ctx, sid, err, sessionPolicy, conn, iConn, napfb, first, firstLen, bufferedReader)
   196  	} else if shouldFallback {
   197  		return newError("invalid protocol or invalid user")
   198  	}
   199  
   200  	clientReader := &ConnReader{Reader: bufferedReader}
   201  	if err := clientReader.ParseHeader(); err != nil {
   202  		log.Record(&log.AccessMessage{
   203  			From:   conn.RemoteAddr(),
   204  			To:     "",
   205  			Status: log.AccessRejected,
   206  			Reason: err,
   207  		})
   208  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   209  	}
   210  
   211  	destination := clientReader.Target
   212  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   213  		return newError("unable to set read deadline").Base(err).AtWarning()
   214  	}
   215  
   216  	inbound := session.InboundFromContext(ctx)
   217  	if inbound == nil {
   218  		panic("no inbound metadata")
   219  	}
   220  	inbound.Name = "trojan"
   221  	inbound.User = user
   222  	sessionPolicy = s.policyManager.ForLevel(user.Level)
   223  
   224  	if destination.Network == net.Network_UDP { // handle udp request
   225  		return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
   226  	}
   227  
   228  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   229  		From:   conn.RemoteAddr(),
   230  		To:     destination,
   231  		Status: log.AccessAccepted,
   232  		Reason: "",
   233  		Email:  user.Email,
   234  	})
   235  
   236  	newError("received request for ", destination).WriteToLog(sid)
   237  	return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
   238  }
   239  
   240  func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
   241  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   242  		udpPayload := packet.Payload
   243  		if udpPayload.UDP == nil {
   244  			udpPayload.UDP = &packet.Source
   245  		}
   246  
   247  		if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
   248  			newError("failed to write response").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   249  		}
   250  	})
   251  
   252  	inbound := session.InboundFromContext(ctx)
   253  	user := inbound.User
   254  
   255  	var dest *net.Destination
   256  
   257  	for {
   258  		select {
   259  		case <-ctx.Done():
   260  			return nil
   261  		default:
   262  			mb, err := clientReader.ReadMultiBuffer()
   263  			if err != nil {
   264  				if errors.Cause(err) != io.EOF {
   265  					return newError("unexpected EOF").Base(err)
   266  				}
   267  				return nil
   268  			}
   269  
   270  			mb2, b := buf.SplitFirst(mb)
   271  			if b == nil {
   272  				continue
   273  			}
   274  			destination := *b.UDP
   275  
   276  			currentPacketCtx := ctx
   277  			if inbound.Source.IsValid() {
   278  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   279  					From:   inbound.Source,
   280  					To:     destination,
   281  					Status: log.AccessAccepted,
   282  					Reason: "",
   283  					Email:  user.Email,
   284  				})
   285  			}
   286  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(ctx))
   287  
   288  			if !s.cone || dest == nil {
   289  				dest = &destination
   290  			}
   291  
   292  			udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
   293  			for _, payload := range mb2 {
   294  				udpServer.Dispatch(currentPacketCtx, *dest, payload)
   295  			}
   296  		}
   297  	}
   298  }
   299  
   300  func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
   301  	destination net.Destination,
   302  	clientReader buf.Reader,
   303  	clientWriter buf.Writer, dispatcher routing.Dispatcher,
   304  ) error {
   305  	ctx, cancel := context.WithCancel(ctx)
   306  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   307  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   308  
   309  	link, err := dispatcher.Dispatch(ctx, destination)
   310  	if err != nil {
   311  		return newError("failed to dispatch request to ", destination).Base(err)
   312  	}
   313  
   314  	requestDone := func() error {
   315  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   316  		if buf.Copy(clientReader, link.Writer, buf.UpdateActivity(timer)) != nil {
   317  			return newError("failed to transfer request").Base(err)
   318  		}
   319  		return nil
   320  	}
   321  
   322  	responseDone := func() error {
   323  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   324  
   325  		if err := buf.Copy(link.Reader, clientWriter, buf.UpdateActivity(timer)); err != nil {
   326  			return newError("failed to write response").Base(err)
   327  		}
   328  		return nil
   329  	}
   330  
   331  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   332  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   333  		common.Must(common.Interrupt(link.Reader))
   334  		common.Must(common.Interrupt(link.Writer))
   335  		return newError("connection ends").Base(err)
   336  	}
   337  
   338  	return nil
   339  }
   340  
   341  func (s *Server) fallback(ctx context.Context, sid errors.ExportOption, err error, sessionPolicy policy.Session, connection stat.Connection, iConn stat.Connection, napfb map[string]map[string]map[string]*Fallback, first *buf.Buffer, firstLen int64, reader buf.Reader) error {
   342  	if err := connection.SetReadDeadline(time.Time{}); err != nil {
   343  		newError("unable to set back read deadline").Base(err).AtWarning().WriteToLog(sid)
   344  	}
   345  	newError("fallback starts").Base(err).AtInfo().WriteToLog(sid)
   346  
   347  	name := ""
   348  	alpn := ""
   349  	if tlsConn, ok := iConn.(*tls.Conn); ok {
   350  		cs := tlsConn.ConnectionState()
   351  		name = cs.ServerName
   352  		alpn = cs.NegotiatedProtocol
   353  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   354  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   355  	} else if realityConn, ok := iConn.(*reality.Conn); ok {
   356  		cs := realityConn.ConnectionState()
   357  		name = cs.ServerName
   358  		alpn = cs.NegotiatedProtocol
   359  		newError("realName = " + name).AtInfo().WriteToLog(sid)
   360  		newError("realAlpn = " + alpn).AtInfo().WriteToLog(sid)
   361  	}
   362  	name = strings.ToLower(name)
   363  	alpn = strings.ToLower(alpn)
   364  
   365  	if len(napfb) > 1 || napfb[""] == nil {
   366  		if name != "" && napfb[name] == nil {
   367  			match := ""
   368  			for n := range napfb {
   369  				if n != "" && strings.Contains(name, n) && len(n) > len(match) {
   370  					match = n
   371  				}
   372  			}
   373  			name = match
   374  		}
   375  	}
   376  
   377  	if napfb[name] == nil {
   378  		name = ""
   379  	}
   380  	apfb := napfb[name]
   381  	if apfb == nil {
   382  		return newError(`failed to find the default "name" config`).AtWarning()
   383  	}
   384  
   385  	if apfb[alpn] == nil {
   386  		alpn = ""
   387  	}
   388  	pfb := apfb[alpn]
   389  	if pfb == nil {
   390  		return newError(`failed to find the default "alpn" config`).AtWarning()
   391  	}
   392  
   393  	path := ""
   394  	if len(pfb) > 1 || pfb[""] == nil {
   395  		if firstLen >= 18 && first.Byte(4) != '*' { // not h2c
   396  			firstBytes := first.Bytes()
   397  			for i := 4; i <= 8; i++ { // 5 -> 9
   398  				if firstBytes[i] == '/' && firstBytes[i-1] == ' ' {
   399  					search := len(firstBytes)
   400  					if search > 64 {
   401  						search = 64 // up to about 60
   402  					}
   403  					for j := i + 1; j < search; j++ {
   404  						k := firstBytes[j]
   405  						if k == '\r' || k == '\n' { // avoid logging \r or \n
   406  							break
   407  						}
   408  						if k == '?' || k == ' ' {
   409  							path = string(firstBytes[i:j])
   410  							newError("realPath = " + path).AtInfo().WriteToLog(sid)
   411  							if pfb[path] == nil {
   412  								path = ""
   413  							}
   414  							break
   415  						}
   416  					}
   417  					break
   418  				}
   419  			}
   420  		}
   421  	}
   422  	fb := pfb[path]
   423  	if fb == nil {
   424  		return newError(`failed to find the default "path" config`).AtWarning()
   425  	}
   426  
   427  	ctx, cancel := context.WithCancel(ctx)
   428  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   429  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   430  
   431  	var conn net.Conn
   432  	if err := retry.ExponentialBackoff(5, 100).On(func() error {
   433  		var dialer net.Dialer
   434  		conn, err = dialer.DialContext(ctx, fb.Type, fb.Dest)
   435  		if err != nil {
   436  			return err
   437  		}
   438  		return nil
   439  	}); err != nil {
   440  		return newError("failed to dial to " + fb.Dest).Base(err).AtWarning()
   441  	}
   442  	defer conn.Close()
   443  
   444  	serverReader := buf.NewReader(conn)
   445  	serverWriter := buf.NewWriter(conn)
   446  
   447  	postRequest := func() error {
   448  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   449  		if fb.Xver != 0 {
   450  			ipType := 4
   451  			remoteAddr, remotePort, err := net.SplitHostPort(connection.RemoteAddr().String())
   452  			if err != nil {
   453  				ipType = 0
   454  			}
   455  			localAddr, localPort, err := net.SplitHostPort(connection.LocalAddr().String())
   456  			if err != nil {
   457  				ipType = 0
   458  			}
   459  			if ipType == 4 {
   460  				for i := 0; i < len(remoteAddr); i++ {
   461  					if remoteAddr[i] == ':' {
   462  						ipType = 6
   463  						break
   464  					}
   465  				}
   466  			}
   467  			pro := buf.New()
   468  			defer pro.Release()
   469  			switch fb.Xver {
   470  			case 1:
   471  				if ipType == 0 {
   472  					common.Must2(pro.Write([]byte("PROXY UNKNOWN\r\n")))
   473  					break
   474  				}
   475  				if ipType == 4 {
   476  					common.Must2(pro.Write([]byte("PROXY TCP4 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   477  				} else {
   478  					common.Must2(pro.Write([]byte("PROXY TCP6 " + remoteAddr + " " + localAddr + " " + remotePort + " " + localPort + "\r\n")))
   479  				}
   480  			case 2:
   481  				common.Must2(pro.Write([]byte("\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"))) // signature
   482  				if ipType == 0 {
   483  					common.Must2(pro.Write([]byte("\x20\x00\x00\x00"))) // v2 + LOCAL + UNSPEC + UNSPEC + 0 bytes
   484  					break
   485  				}
   486  				if ipType == 4 {
   487  					common.Must2(pro.Write([]byte("\x21\x11\x00\x0C"))) // v2 + PROXY + AF_INET + STREAM + 12 bytes
   488  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To4()))
   489  					common.Must2(pro.Write(net.ParseIP(localAddr).To4()))
   490  				} else {
   491  					common.Must2(pro.Write([]byte("\x21\x21\x00\x24"))) // v2 + PROXY + AF_INET6 + STREAM + 36 bytes
   492  					common.Must2(pro.Write(net.ParseIP(remoteAddr).To16()))
   493  					common.Must2(pro.Write(net.ParseIP(localAddr).To16()))
   494  				}
   495  				p1, _ := strconv.ParseUint(remotePort, 10, 16)
   496  				p2, _ := strconv.ParseUint(localPort, 10, 16)
   497  				common.Must2(pro.Write([]byte{byte(p1 >> 8), byte(p1), byte(p2 >> 8), byte(p2)}))
   498  			}
   499  			if err := serverWriter.WriteMultiBuffer(buf.MultiBuffer{pro}); err != nil {
   500  				return newError("failed to set PROXY protocol v", fb.Xver).Base(err).AtWarning()
   501  			}
   502  		}
   503  		if err := buf.Copy(reader, serverWriter, buf.UpdateActivity(timer)); err != nil {
   504  			return newError("failed to fallback request payload").Base(err).AtInfo()
   505  		}
   506  		return nil
   507  	}
   508  
   509  	writer := buf.NewWriter(connection)
   510  
   511  	getResponse := func() error {
   512  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   513  		if err := buf.Copy(serverReader, writer, buf.UpdateActivity(timer)); err != nil {
   514  			return newError("failed to deliver response payload").Base(err).AtInfo()
   515  		}
   516  		return nil
   517  	}
   518  
   519  	if err := task.Run(ctx, task.OnSuccess(postRequest, task.Close(serverWriter)), task.OnSuccess(getResponse, task.Close(writer))); err != nil {
   520  		common.Must(common.Interrupt(serverReader))
   521  		common.Must(common.Interrupt(serverWriter))
   522  		return newError("fallback ends").Base(err).AtInfo()
   523  	}
   524  
   525  	return nil
   526  }