github.com/xraypb/xray-core@v1.6.6/proxy/trojan/client.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"syscall"
     6  	"time"
     7  
     8  	"github.com/xraypb/xray-core/common"
     9  	"github.com/xraypb/xray-core/common/buf"
    10  	"github.com/xraypb/xray-core/common/errors"
    11  	"github.com/xraypb/xray-core/common/net"
    12  	"github.com/xraypb/xray-core/common/platform"
    13  	"github.com/xraypb/xray-core/common/protocol"
    14  	"github.com/xraypb/xray-core/common/retry"
    15  	"github.com/xraypb/xray-core/common/session"
    16  	"github.com/xraypb/xray-core/common/signal"
    17  	"github.com/xraypb/xray-core/common/task"
    18  	core "github.com/xraypb/xray-core/core"
    19  	"github.com/xraypb/xray-core/features/policy"
    20  	"github.com/xraypb/xray-core/features/stats"
    21  	"github.com/xraypb/xray-core/transport"
    22  	"github.com/xraypb/xray-core/transport/internet"
    23  	"github.com/xraypb/xray-core/transport/internet/stat"
    24  	"github.com/xraypb/xray-core/transport/internet/xtls"
    25  )
    26  
// Client is an outbound handler for the trojan protocol.
type Client struct {
	// serverPicker selects the server to dial for each dial attempt.
	serverPicker  protocol.ServerPicker
	// policyManager provides the session policy looked up by user level.
	policyManager policy.Manager
}
    32  
    33  // NewClient create a new trojan client.
    34  func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
    35  	serverList := protocol.NewServerList()
    36  	for _, rec := range config.Server {
    37  		s, err := protocol.NewServerSpecFromPB(rec)
    38  		if err != nil {
    39  			return nil, newError("failed to parse server spec").Base(err)
    40  		}
    41  		serverList.AddServer(s)
    42  	}
    43  	if serverList.Size() == 0 {
    44  		return nil, newError("0 server")
    45  	}
    46  
    47  	v := core.MustFromContext(ctx)
    48  	client := &Client{
    49  		serverPicker:  protocol.NewRoundRobinServerPicker(serverList),
    50  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    51  	}
    52  	return client, nil
    53  }
    54  
    55  // Process implements OutboundHandler.Process().
    56  func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
    57  	outbound := session.OutboundFromContext(ctx)
    58  	if outbound == nil || !outbound.Target.IsValid() {
    59  		return newError("target not specified")
    60  	}
    61  	destination := outbound.Target
    62  	network := destination.Network
    63  
    64  	var server *protocol.ServerSpec
    65  	var conn stat.Connection
    66  
    67  	err := retry.ExponentialBackoff(5, 100).On(func() error {
    68  		server = c.serverPicker.PickServer()
    69  		rawConn, err := dialer.Dial(ctx, server.Destination())
    70  		if err != nil {
    71  			return err
    72  		}
    73  
    74  		conn = rawConn
    75  		return nil
    76  	})
    77  	if err != nil {
    78  		return newError("failed to find an available destination").AtWarning().Base(err)
    79  	}
    80  	newError("tunneling request to ", destination, " via ", server.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx))
    81  
    82  	defer conn.Close()
    83  
    84  	iConn := conn
    85  	statConn, ok := iConn.(*stat.CounterConnection)
    86  	if ok {
    87  		iConn = statConn.Connection
    88  	}
    89  
    90  	user := server.PickUser()
    91  	account, ok := user.Account.(*MemoryAccount)
    92  	if !ok {
    93  		return newError("user account is not valid")
    94  	}
    95  
    96  	connWriter := &ConnWriter{
    97  		Flow: account.Flow,
    98  	}
    99  
   100  	var rawConn syscall.RawConn
   101  	var sctx context.Context
   102  
   103  	allowUDP443 := false
   104  	switch connWriter.Flow {
   105  	case XRO + "-udp443", XRD + "-udp443", XRS + "-udp443":
   106  		allowUDP443 = true
   107  		connWriter.Flow = connWriter.Flow[:16]
   108  		fallthrough
   109  	case XRO, XRD, XRS:
   110  		if destination.Address.Family().IsDomain() && destination.Address.Domain() == muxCoolAddress {
   111  			return newError(connWriter.Flow + " doesn't support Mux").AtWarning()
   112  		}
   113  		if destination.Network == net.Network_UDP {
   114  			if !allowUDP443 && destination.Port == 443 {
   115  				return newError(connWriter.Flow + " stopped UDP/443").AtInfo()
   116  			}
   117  			connWriter.Flow = ""
   118  		} else { // enable XTLS only if making TCP request
   119  			if xtlsConn, ok := iConn.(*xtls.Conn); ok {
   120  				xtlsConn.RPRX = true
   121  				xtlsConn.SHOW = xtls_show
   122  				xtlsConn.MARK = "XTLS"
   123  				if connWriter.Flow == XRS {
   124  					sctx = ctx
   125  					connWriter.Flow = XRD
   126  				}
   127  				if connWriter.Flow == XRD {
   128  					xtlsConn.DirectMode = true
   129  					if sc, ok := xtlsConn.NetConn().(syscall.Conn); ok {
   130  						rawConn, _ = sc.SyscallConn()
   131  					}
   132  				}
   133  			} else {
   134  				return newError(`failed to use ` + connWriter.Flow + `, maybe "security" is not "xtls"`).AtWarning()
   135  			}
   136  		}
   137  	default:
   138  		if _, ok := iConn.(*xtls.Conn); ok {
   139  			panic(`To avoid misunderstanding, you must fill in Trojan "flow" when using XTLS.`)
   140  		}
   141  	}
   142  
   143  	sessionPolicy := c.policyManager.ForLevel(user.Level)
   144  	ctx, cancel := context.WithCancel(ctx)
   145  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   146  
   147  	postRequest := func() error {
   148  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   149  
   150  		bufferWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   151  
   152  		connWriter.Writer = bufferWriter
   153  		connWriter.Target = destination
   154  		connWriter.Account = account
   155  
   156  		var bodyWriter buf.Writer
   157  		if destination.Network == net.Network_UDP {
   158  			bodyWriter = &PacketWriter{Writer: connWriter, Target: destination}
   159  		} else {
   160  			bodyWriter = connWriter
   161  		}
   162  
   163  		// write some request payload to buffer
   164  		if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
   165  			return newError("failed to write A request payload").Base(err).AtWarning()
   166  		}
   167  
   168  		// Flush; bufferWriter.WriteMultiBufer now is bufferWriter.writer.WriteMultiBuffer
   169  		if err = bufferWriter.SetBuffered(false); err != nil {
   170  			return newError("failed to flush payload").Base(err).AtWarning()
   171  		}
   172  
   173  		// Send header if not sent yet
   174  		if _, err = connWriter.Write([]byte{}); err != nil {
   175  			return err.(*errors.Error).AtWarning()
   176  		}
   177  
   178  		if err = buf.Copy(link.Reader, bodyWriter, buf.UpdateActivity(timer)); err != nil {
   179  			return newError("failed to transfer request payload").Base(err).AtInfo()
   180  		}
   181  
   182  		return nil
   183  	}
   184  
   185  	getResponse := func() error {
   186  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   187  
   188  		var reader buf.Reader
   189  		if network == net.Network_UDP {
   190  			reader = &PacketReader{
   191  				Reader: conn,
   192  			}
   193  		} else {
   194  			reader = buf.NewReader(conn)
   195  		}
   196  		if rawConn != nil {
   197  			var counter stats.Counter
   198  			if statConn != nil {
   199  				counter = statConn.ReadCounter
   200  			}
   201  			return ReadV(reader, link.Writer, timer, iConn.(*xtls.Conn), rawConn, counter, sctx)
   202  		}
   203  		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
   204  	}
   205  
   206  	responseDoneAndCloseWriter := task.OnSuccess(getResponse, task.Close(link.Writer))
   207  	if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil {
   208  		return newError("connection ends").Base(err)
   209  	}
   210  
   211  	return nil
   212  }
   213  
   214  func init() {
   215  	common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   216  		return NewClient(ctx, config.(*ClientConfig))
   217  	}))
   218  
   219  	const defaultFlagValue = "NOT_DEFINED_AT_ALL"
   220  
   221  	xtlsShow := platform.NewEnvFlag("xray.trojan.xtls.show").GetValue(func() string { return defaultFlagValue })
   222  	if xtlsShow == "true" {
   223  		xtls_show = true
   224  	}
   225  }