github.com/eagleql/xray-core@v1.4.4/proxy/trojan/client.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"syscall"
     6  	"time"
     7  
     8  	"github.com/eagleql/xray-core/common"
     9  	"github.com/eagleql/xray-core/common/buf"
    10  	"github.com/eagleql/xray-core/common/errors"
    11  	"github.com/eagleql/xray-core/common/net"
    12  	"github.com/eagleql/xray-core/common/platform"
    13  	"github.com/eagleql/xray-core/common/protocol"
    14  	"github.com/eagleql/xray-core/common/retry"
    15  	"github.com/eagleql/xray-core/common/session"
    16  	"github.com/eagleql/xray-core/common/signal"
    17  	"github.com/eagleql/xray-core/common/task"
    18  	core "github.com/eagleql/xray-core/core"
    19  	"github.com/eagleql/xray-core/features/policy"
    20  	"github.com/eagleql/xray-core/features/stats"
    21  	"github.com/eagleql/xray-core/transport"
    22  	"github.com/eagleql/xray-core/transport/internet"
    23  	"github.com/eagleql/xray-core/transport/internet/xtls"
    24  )
    25  
    26  // Client is a inbound handler for trojan protocol
    27  type Client struct {
    28  	serverPicker  protocol.ServerPicker
    29  	policyManager policy.Manager
    30  }
    31  
    32  // NewClient create a new trojan client.
    33  func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
    34  	serverList := protocol.NewServerList()
    35  	for _, rec := range config.Server {
    36  		s, err := protocol.NewServerSpecFromPB(rec)
    37  		if err != nil {
    38  			return nil, newError("failed to parse server spec").Base(err)
    39  		}
    40  		serverList.AddServer(s)
    41  	}
    42  	if serverList.Size() == 0 {
    43  		return nil, newError("0 server")
    44  	}
    45  
    46  	v := core.MustFromContext(ctx)
    47  	client := &Client{
    48  		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
    49  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    50  	}
    51  	return client, nil
    52  }
    53  
    54  // Process implements OutboundHandler.Process().
    55  func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    56  	outbound := session.OutboundFromContext(ctx)
    57  	if outbound == nil || !outbound.Target.IsValid() {
    58  		return newError("target not specified")
    59  	}
    60  	destination := outbound.Target
    61  	network := destination.Network
    62  
    63  	var server *protocol.ServerSpec
    64  	var conn internet.Connection
    65  
    66  	err := retry.ExponentialBackoff(5, 100).On(func() error {
    67  		server = c.serverPicker.PickServer()
    68  		rawConn, err := dialer.Dial(ctx, server.Destination())
    69  		if err != nil {
    70  			return err
    71  		}
    72  
    73  		conn = rawConn
    74  		return nil
    75  	})
    76  	if err != nil {
    77  		return newError("failed to find an available destination").AtWarning().Base(err)
    78  	}
    79  	newError("tunneling request to ", destination, " via ", server.Destination()).WriteToLog(session.ExportIDToError(ctx))
    80  
    81  	defer conn.Close()
    82  
    83  	iConn := conn
    84  	statConn, ok := iConn.(*internet.StatCouterConnection)
    85  	if ok {
    86  		iConn = statConn.Connection
    87  	}
    88  
    89  	user := server.PickUser()
    90  	account, ok := user.Account.(*MemoryAccount)
    91  	if !ok {
    92  		return newError("user account is not valid")
    93  	}
    94  
    95  	connWriter := &ConnWriter{
    96  		Flow: account.Flow,
    97  	}
    98  
    99  	var rawConn syscall.RawConn
   100  	var sctx context.Context
   101  
   102  	allowUDP443 := false
   103  	switch connWriter.Flow {
   104  	case XRO + "-udp443", XRD + "-udp443", XRS + "-udp443":
   105  		allowUDP443 = true
   106  		connWriter.Flow = connWriter.Flow[:16]
   107  		fallthrough
   108  	case XRO, XRD, XRS:
   109  		if destination.Address.Family().IsDomain() && destination.Address.Domain() == muxCoolAddress {
   110  			return newError(connWriter.Flow + " doesn't support Mux").AtWarning()
   111  		}
   112  		if destination.Network == net.Network_UDP {
   113  			if !allowUDP443 && destination.Port == 443 {
   114  				return newError(connWriter.Flow + " stopped UDP/443").AtInfo()
   115  			}
   116  			connWriter.Flow = ""
   117  		} else { // enable XTLS only if making TCP request
   118  			if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   119  				xtlsConn.RPRX = true
   120  				xtlsConn.SHOW = xtls_show
   121  				xtlsConn.MARK = "XTLS"
   122  				if connWriter.Flow == XRS {
   123  					sctx = ctx
   124  					connWriter.Flow = XRD
   125  				}
   126  				if connWriter.Flow == XRD {
   127  					xtlsConn.DirectMode = true
   128  					if sc, ok := xtlsConn.Connection.(syscall.Conn); ok {
   129  						rawConn, _ = sc.SyscallConn()
   130  					}
   131  				}
   132  			} else {
   133  				return newError(`failed to use ` + connWriter.Flow + `, maybe "security" is not "xtls"`).AtWarning()
   134  			}
   135  		}
   136  	default:
   137  		if _, ok := iConn.(*xtls.Conn); ok {
   138  			panic(`To avoid misunderstanding, you must fill in Trojan "flow" when using XTLS.`)
   139  		}
   140  	}
   141  
   142  	sessionPolicy := c.policyManager.ForLevel(user.Level)
   143  	ctx, cancel := context.WithCancel(ctx)
   144  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   145  
   146  	postRequest := func() error {
   147  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   148  
   149  		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   150  
   151  		connWriter.Writer = bufferWriter
   152  		connWriter.Target = destination
   153  		connWriter.Account = account
   154  
   155  		var bodyWriter buf.Writer
   156  		if destination.Network == net.Network_UDP {
   157  			bodyWriter = &PacketWriter{Writer: connWriter, Target: destination}
   158  		} else {
   159  			bodyWriter = connWriter
   160  		}
   161  
   162  		// write some request payload to buffer
   163  		if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
   164  			return newError("failed to write A request payload").Base(err).AtWarning()
   165  		}
   166  
   167  		// Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer
   168  		if err = bufferWriter.SetBuffered(false); err != nil {
   169  			return newError("failed to flush payload").Base(err).AtWarning()
   170  		}
   171  
   172  		// Send header if not sent yet
   173  		if _, err = connWriter.Write([]byte{}); err != nil {
   174  			return err.(*errors.Error).AtWarning()
   175  		}
   176  
   177  		if err = buf.Copy(link.Reader, bodyWriter, buf.UpdateActivity(timer)); err != nil {
   178  			return newError("failed to transfer request payload").Base(err).AtInfo()
   179  		}
   180  
   181  		return nil
   182  	}
   183  
   184  	getResponse := func() error {
   185  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   186  
   187  		var reader buf.Reader
   188  		if network == net.Network_UDP {
   189  			reader = &PacketReader{
   190  				Reader: conn,
   191  			}
   192  		} else {
   193  			reader = buf.NewReader(conn)
   194  		}
   195  		if rawConn != nil {
   196  			var counter stats.Counter
   197  			if statConn != nil {
   198  				counter = statConn.ReadCounter
   199  			}
   200  			return ReadV(reader, link.Writer, timer, iConn.(*xtls.Conn), rawConn, counter, sctx)
   201  		}
   202  		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
   203  	}
   204  
   205  	var responseDoneAndCloseWriter = task.OnSuccess(getResponse, task.Close(link.Writer))
   206  	if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil {
   207  		return newError("connection ends").Base(err)
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  func init() {
   214  	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   215  		return NewClient(ctx, config.(*ClientConfig))
   216  	}))
   217  
   218  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
   219  
   220  	xtlsShow := platform.NewEnvFlag("xray.trojan.xtls.show").GetValue(func() string { return defaultFlagValue })
   221  	if xtlsShow == "true" {
   222  		xtls_show = true
   223  	}
   224  }