github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/proxyman/outbound/handler.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"errors"
     7  	"github.com/xtls/xray-core/app/proxyman"
     8  	"github.com/xtls/xray-core/common"
     9  	"github.com/xtls/xray-core/common/buf"
    10  	"github.com/xtls/xray-core/common/mux"
    11  	"github.com/xtls/xray-core/common/net"
    12  	"github.com/xtls/xray-core/common/net/cnc"
    13  	"github.com/xtls/xray-core/common/session"
    14  	"github.com/xtls/xray-core/core"
    15  	"github.com/xtls/xray-core/features/outbound"
    16  	"github.com/xtls/xray-core/features/policy"
    17  	"github.com/xtls/xray-core/features/stats"
    18  	"github.com/xtls/xray-core/proxy"
    19  	"github.com/xtls/xray-core/transport"
    20  	"github.com/xtls/xray-core/transport/internet"
    21  	"github.com/xtls/xray-core/transport/internet/stat"
    22  	"github.com/xtls/xray-core/transport/internet/tls"
    23  	"github.com/xtls/xray-core/transport/pipe"
    24  	"io"
    25  	"math/big"
    26  	gonet "net"
    27  	"os"
    28  )
    29  
    30  func getStatCounter(v *core.Instance, tag string) (stats.Counter, stats.Counter) {
    31  	var uplinkCounter stats.Counter
    32  	var downlinkCounter stats.Counter
    33  
    34  	policy := v.GetFeature(policy.ManagerType()).(policy.Manager)
    35  	if len(tag) > 0 && policy.ForSystem().Stats.OutboundUplink {
    36  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    37  		name := "outbound>>>" + tag + ">>>traffic>>>uplink"
    38  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    39  		if c != nil {
    40  			uplinkCounter = c
    41  		}
    42  	}
    43  	if len(tag) > 0 && policy.ForSystem().Stats.OutboundDownlink {
    44  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    45  		name := "outbound>>>" + tag + ">>>traffic>>>downlink"
    46  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    47  		if c != nil {
    48  			downlinkCounter = c
    49  		}
    50  	}
    51  
    52  	return uplinkCounter, downlinkCounter
    53  }
    54  
    55  // Handler is an implements of outbound.Handler.
    56  type Handler struct {
    57  	tag             string
    58  	senderSettings  *proxyman.SenderConfig
    59  	streamSettings  *internet.MemoryStreamConfig
    60  	proxy           proxy.Outbound
    61  	outboundManager outbound.Manager
    62  	mux             *mux.ClientManager
    63  	xudp            *mux.ClientManager
    64  	udp443          string
    65  	uplinkCounter   stats.Counter
    66  	downlinkCounter stats.Counter
    67  }
    68  
    69  // NewHandler creates a new Handler based on the given configuration.
    70  func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbound.Handler, error) {
    71  	v := core.MustFromContext(ctx)
    72  	uplinkCounter, downlinkCounter := getStatCounter(v, config.Tag)
    73  	h := &Handler{
    74  		tag:             config.Tag,
    75  		outboundManager: v.GetFeature(outbound.ManagerType()).(outbound.Manager),
    76  		uplinkCounter:   uplinkCounter,
    77  		downlinkCounter: downlinkCounter,
    78  	}
    79  
    80  	if config.SenderSettings != nil {
    81  		senderSettings, err := config.SenderSettings.GetInstance()
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  		switch s := senderSettings.(type) {
    86  		case *proxyman.SenderConfig:
    87  			h.senderSettings = s
    88  			mss, err := internet.ToMemoryStreamConfig(s.StreamSettings)
    89  			if err != nil {
    90  				return nil, newError("failed to parse stream settings").Base(err).AtWarning()
    91  			}
    92  			h.streamSettings = mss
    93  		default:
    94  			return nil, newError("settings is not SenderConfig")
    95  		}
    96  	}
    97  
    98  	proxyConfig, err := config.ProxySettings.GetInstance()
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	rawProxyHandler, err := common.CreateObject(ctx, proxyConfig)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	proxyHandler, ok := rawProxyHandler.(proxy.Outbound)
   109  	if !ok {
   110  		return nil, newError("not an outbound handler")
   111  	}
   112  
   113  	if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil {
   114  		if config := h.senderSettings.MultiplexSettings; config.Enabled {
   115  			if config.Concurrency < 0 {
   116  				h.mux = &mux.ClientManager{Enabled: false}
   117  			}
   118  			if config.Concurrency == 0 {
   119  				config.Concurrency = 8 // same as before
   120  			}
   121  			if config.Concurrency > 0 {
   122  				h.mux = &mux.ClientManager{
   123  					Enabled: true,
   124  					Picker: &mux.IncrementalWorkerPicker{
   125  						Factory: &mux.DialingWorkerFactory{
   126  							Proxy:  proxyHandler,
   127  							Dialer: h,
   128  							Strategy: mux.ClientStrategy{
   129  								MaxConcurrency: uint32(config.Concurrency),
   130  								MaxConnection:  128,
   131  							},
   132  						},
   133  					},
   134  				}
   135  			}
   136  			if config.XudpConcurrency < 0 {
   137  				h.xudp = &mux.ClientManager{Enabled: false}
   138  			}
   139  			if config.XudpConcurrency == 0 {
   140  				h.xudp = nil // same as before
   141  			}
   142  			if config.XudpConcurrency > 0 {
   143  				h.xudp = &mux.ClientManager{
   144  					Enabled: true,
   145  					Picker: &mux.IncrementalWorkerPicker{
   146  						Factory: &mux.DialingWorkerFactory{
   147  							Proxy:  proxyHandler,
   148  							Dialer: h,
   149  							Strategy: mux.ClientStrategy{
   150  								MaxConcurrency: uint32(config.XudpConcurrency),
   151  								MaxConnection:  128,
   152  							},
   153  						},
   154  					},
   155  				}
   156  			}
   157  			h.udp443 = config.XudpProxyUDP443
   158  		}
   159  	}
   160  
   161  	h.proxy = proxyHandler
   162  	return h, nil
   163  }
   164  
   165  // Tag implements outbound.Handler.
   166  func (h *Handler) Tag() string {
   167  	return h.tag
   168  }
   169  
   170  // Dispatch implements proxy.Outbound.Dispatch.
   171  func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
   172  	outbounds := session.OutboundsFromContext(ctx)
   173  	ob := outbounds[len(outbounds) - 1]
   174  	if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
   175  		link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
   176  		link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
   177  	}
   178  	if h.mux != nil {
   179  		test := func(err error) {
   180  			if err != nil {
   181  				err := newError("failed to process mux outbound traffic").Base(err)
   182  				session.SubmitOutboundErrorToOriginator(ctx, err)
   183  				err.WriteToLog(session.ExportIDToError(ctx))
   184  				common.Interrupt(link.Writer)
   185  			}
   186  		}
   187  		if ob.Target.Network == net.Network_UDP && ob.Target.Port == 443 {
   188  			switch h.udp443 {
   189  			case "reject":
   190  				test(newError("XUDP rejected UDP/443 traffic").AtInfo())
   191  				return
   192  			case "skip":
   193  				goto out
   194  			}
   195  		}
   196  		if h.xudp != nil && ob.Target.Network == net.Network_UDP {
   197  			if !h.xudp.Enabled {
   198  				goto out
   199  			}
   200  			test(h.xudp.Dispatch(ctx, link))
   201  			return
   202  		}
   203  		if h.mux.Enabled {
   204  			test(h.mux.Dispatch(ctx, link))
   205  			return
   206  		}
   207  	}
   208  out:
   209  	err := h.proxy.Process(ctx, link, h)
   210  	if err != nil {
   211  		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, context.Canceled) {
   212  			err = nil
   213  		}
   214  	}
   215  	if err != nil {
   216  		// Ensure outbound ray is properly closed.
   217  		err := newError("failed to process outbound traffic").Base(err)
   218  		session.SubmitOutboundErrorToOriginator(ctx, err)
   219  		err.WriteToLog(session.ExportIDToError(ctx))
   220  		common.Interrupt(link.Writer)
   221  	} else {
   222  		common.Close(link.Writer)
   223  	}
   224  	common.Interrupt(link.Reader)
   225  }
   226  
   227  // Address implements internet.Dialer.
   228  func (h *Handler) Address() net.Address {
   229  	if h.senderSettings == nil || h.senderSettings.Via == nil {
   230  		return nil
   231  	}
   232  	return h.senderSettings.Via.AsAddress()
   233  }
   234  
   235  func (h *Handler) DestIpAddress() net.IP {
   236  	return internet.DestIpAddress()
   237  }
   238  
   239  // Dial implements internet.Dialer.
   240  func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
   241  	if h.senderSettings != nil {
   242  		if h.senderSettings.ProxySettings.HasTag() {
   243  			tag := h.senderSettings.ProxySettings.Tag
   244  			handler := h.outboundManager.GetHandler(tag)
   245  			if handler != nil {
   246  				newError("proxying to ", tag, " for dest ", dest).AtDebug().WriteToLog(session.ExportIDToError(ctx))
   247  				outbounds := session.OutboundsFromContext(ctx)
   248  				ctx = session.ContextWithOutbounds(ctx, append(outbounds, &session.Outbound{
   249  					Target: dest,
   250  					Tag: tag,
   251  				})) // add another outbound in session ctx
   252  				opts := pipe.OptionsFromContext(ctx)
   253  				uplinkReader, uplinkWriter := pipe.New(opts...)
   254  				downlinkReader, downlinkWriter := pipe.New(opts...)
   255  
   256  				go handler.Dispatch(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter})
   257  				conn := cnc.NewConnection(cnc.ConnectionInputMulti(uplinkWriter), cnc.ConnectionOutputMulti(downlinkReader))
   258  
   259  				if config := tls.ConfigFromStreamSettings(h.streamSettings); config != nil {
   260  					tlsConfig := config.GetTLSConfig(tls.WithDestination(dest))
   261  					conn = tls.Client(conn, tlsConfig)
   262  				}
   263  
   264  				return h.getStatCouterConnection(conn), nil
   265  			}
   266  
   267  			newError("failed to get outbound handler with tag: ", tag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   268  		}
   269  
   270  		if h.senderSettings.Via != nil {
   271  			outbounds := session.OutboundsFromContext(ctx)
   272  			ob := outbounds[len(outbounds) - 1]
   273  			if h.senderSettings.ViaCidr == "" {
   274  				ob.Gateway = h.senderSettings.Via.AsAddress()
   275  			} else { //Get a random address.
   276  				ob.Gateway = ParseRandomIPv6(h.senderSettings.Via.AsAddress(), h.senderSettings.ViaCidr)
   277  			}
   278  		}
   279  	}
   280  
   281  	if conn, err := h.getUoTConnection(ctx, dest); err != os.ErrInvalid {
   282  		return conn, err
   283  	}
   284  
   285  	conn, err := internet.Dial(ctx, dest, h.streamSettings)
   286  	conn = h.getStatCouterConnection(conn)
   287  	outbounds := session.OutboundsFromContext(ctx)
   288  	ob := outbounds[len(outbounds) - 1]
   289  	ob.Conn = conn
   290  	return conn, err
   291  }
   292  
   293  func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection {
   294  	if h.uplinkCounter != nil || h.downlinkCounter != nil {
   295  		return &stat.CounterConnection{
   296  			Connection:   conn,
   297  			ReadCounter:  h.downlinkCounter,
   298  			WriteCounter: h.uplinkCounter,
   299  		}
   300  	}
   301  	return conn
   302  }
   303  
   304  // GetOutbound implements proxy.GetOutbound.
   305  func (h *Handler) GetOutbound() proxy.Outbound {
   306  	return h.proxy
   307  }
   308  
   309  // Start implements common.Runnable.
   310  func (h *Handler) Start() error {
   311  	return nil
   312  }
   313  
   314  // Close implements common.Closable.
   315  func (h *Handler) Close() error {
   316  	common.Close(h.mux)
   317  	return nil
   318  }
   319  
   320  
   321  func ParseRandomIPv6(address net.Address, prefix string) net.Address {
   322  	_, network, _ := gonet.ParseCIDR(address.IP().String() + "/" + prefix)
   323  
   324  	maskSize, totalBits := network.Mask.Size()
   325  	subnetSize := big.NewInt(1).Lsh(big.NewInt(1), uint(totalBits-maskSize))
   326  
   327  	// random
   328  	randomBigInt, _ := rand.Int(rand.Reader, subnetSize)
   329  
   330  	startIPBigInt := big.NewInt(0).SetBytes(network.IP.To16())
   331  	randomIPBigInt := big.NewInt(0).Add(startIPBigInt, randomBigInt)
   332  
   333  	randomIPBytes := randomIPBigInt.Bytes()
   334  	randomIPBytes = append(make([]byte, 16-len(randomIPBytes)), randomIPBytes...)
   335  
   336  	return net.ParseAddress(gonet.IP(randomIPBytes).String())
   337  }