github.com/eagleql/xray-core@v1.4.4/app/proxyman/inbound/always.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/eagleql/xray-core/app/proxyman"
     7  	"github.com/eagleql/xray-core/common"
     8  	"github.com/eagleql/xray-core/common/dice"
     9  	"github.com/eagleql/xray-core/common/errors"
    10  	"github.com/eagleql/xray-core/common/mux"
    11  	"github.com/eagleql/xray-core/common/net"
    12  	"github.com/eagleql/xray-core/core"
    13  	"github.com/eagleql/xray-core/features/policy"
    14  	"github.com/eagleql/xray-core/features/stats"
    15  	"github.com/eagleql/xray-core/proxy"
    16  	"github.com/eagleql/xray-core/transport/internet"
    17  )
    18  
    19  func getStatCounter(v *core.Instance, tag string) (stats.Counter, stats.Counter) {
    20  	var uplinkCounter stats.Counter
    21  	var downlinkCounter stats.Counter
    22  
    23  	policy := v.GetFeature(policy.ManagerType()).(policy.Manager)
    24  	if len(tag) > 0 && policy.ForSystem().Stats.InboundUplink {
    25  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    26  		name := "inbound>>>" + tag + ">>>traffic>>>uplink"
    27  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    28  		if c != nil {
    29  			uplinkCounter = c
    30  		}
    31  	}
    32  	if len(tag) > 0 && policy.ForSystem().Stats.InboundDownlink {
    33  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    34  		name := "inbound>>>" + tag + ">>>traffic>>>downlink"
    35  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    36  		if c != nil {
    37  			downlinkCounter = c
    38  		}
    39  	}
    40  
    41  	return uplinkCounter, downlinkCounter
    42  }
    43  
// AlwaysOnInboundHandler bundles an inbound proxy with the workers created
// for it, the mux server those workers use as their dispatcher, and the tag
// of the inbound.
type AlwaysOnInboundHandler struct {
	// proxy is the inbound proxy shared by every worker.
	proxy   proxy.Inbound
	// workers are the workers built by NewAlwaysOnInboundHandler.
	workers []worker
	// mux is handed to each worker as its dispatcher and closed in Close.
	mux     *mux.Server
	// tag is the value returned by Tag.
	tag     string
}
    50  
    51  func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*AlwaysOnInboundHandler, error) {
    52  	rawProxy, err := common.CreateObject(ctx, proxyConfig)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	p, ok := rawProxy.(proxy.Inbound)
    57  	if !ok {
    58  		return nil, newError("not an inbound proxy.")
    59  	}
    60  
    61  	h := &AlwaysOnInboundHandler{
    62  		proxy: p,
    63  		mux:   mux.NewServer(ctx),
    64  		tag:   tag,
    65  	}
    66  
    67  	uplinkCounter, downlinkCounter := getStatCounter(core.MustFromContext(ctx), tag)
    68  
    69  	nl := p.Network()
    70  	pr := receiverConfig.PortRange
    71  	address := receiverConfig.Listen.AsAddress()
    72  	if address == nil {
    73  		address = net.AnyIP
    74  	}
    75  
    76  	mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings)
    77  	if err != nil {
    78  		return nil, newError("failed to parse stream config").Base(err).AtWarning()
    79  	}
    80  
    81  	if receiverConfig.ReceiveOriginalDestination {
    82  		if mss.SocketSettings == nil {
    83  			mss.SocketSettings = &internet.SocketConfig{}
    84  		}
    85  		if mss.SocketSettings.Tproxy == internet.SocketConfig_Off {
    86  			mss.SocketSettings.Tproxy = internet.SocketConfig_Redirect
    87  		}
    88  		mss.SocketSettings.ReceiveOriginalDestAddress = true
    89  	}
    90  	if pr == nil {
    91  		if net.HasNetwork(nl, net.Network_UNIX) {
    92  			newError("creating unix domain socket worker on ", address).AtDebug().WriteToLog()
    93  
    94  			worker := &dsWorker{
    95  				address:         address,
    96  				proxy:           p,
    97  				stream:          mss,
    98  				tag:             tag,
    99  				dispatcher:      h.mux,
   100  				sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   101  				uplinkCounter:   uplinkCounter,
   102  				downlinkCounter: downlinkCounter,
   103  				ctx:             ctx,
   104  			}
   105  			h.workers = append(h.workers, worker)
   106  		}
   107  	}
   108  	if pr != nil {
   109  		for port := pr.From; port <= pr.To; port++ {
   110  			if net.HasNetwork(nl, net.Network_TCP) {
   111  				newError("creating stream worker on ", address, ":", port).AtDebug().WriteToLog()
   112  
   113  				worker := &tcpWorker{
   114  					address:         address,
   115  					port:            net.Port(port),
   116  					proxy:           p,
   117  					stream:          mss,
   118  					recvOrigDest:    receiverConfig.ReceiveOriginalDestination,
   119  					tag:             tag,
   120  					dispatcher:      h.mux,
   121  					sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   122  					uplinkCounter:   uplinkCounter,
   123  					downlinkCounter: downlinkCounter,
   124  					ctx:             ctx,
   125  				}
   126  				h.workers = append(h.workers, worker)
   127  			}
   128  
   129  			if net.HasNetwork(nl, net.Network_UDP) {
   130  				worker := &udpWorker{
   131  					tag:             tag,
   132  					proxy:           p,
   133  					address:         address,
   134  					port:            net.Port(port),
   135  					dispatcher:      h.mux,
   136  					sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   137  					uplinkCounter:   uplinkCounter,
   138  					downlinkCounter: downlinkCounter,
   139  					stream:          mss,
   140  					ctx:             ctx,
   141  				}
   142  				h.workers = append(h.workers, worker)
   143  			}
   144  		}
   145  	}
   146  
   147  	return h, nil
   148  }
   149  
   150  // Start implements common.Runnable.
   151  func (h *AlwaysOnInboundHandler) Start() error {
   152  	for _, worker := range h.workers {
   153  		if err := worker.Start(); err != nil {
   154  			return err
   155  		}
   156  	}
   157  	return nil
   158  }
   159  
   160  // Close implements common.Closable.
   161  func (h *AlwaysOnInboundHandler) Close() error {
   162  	var errs []error
   163  	for _, worker := range h.workers {
   164  		errs = append(errs, worker.Close())
   165  	}
   166  	errs = append(errs, h.mux.Close())
   167  	if err := errors.Combine(errs...); err != nil {
   168  		return newError("failed to close all resources").Base(err)
   169  	}
   170  	return nil
   171  }
   172  
   173  func (h *AlwaysOnInboundHandler) GetRandomInboundProxy() (interface{}, net.Port, int) {
   174  	if len(h.workers) == 0 {
   175  		return nil, 0, 0
   176  	}
   177  	w := h.workers[dice.Roll(len(h.workers))]
   178  	return w.Proxy(), w.Port(), 9999
   179  }
   180  
// Tag returns the tag this handler was created with.
func (h *AlwaysOnInboundHandler) Tag() string {
	return h.tag
}
   184  
// GetInbound returns the inbound proxy held by this handler.
func (h *AlwaysOnInboundHandler) GetInbound() proxy.Inbound {
	return h.proxy
}