github.com/xraypb/xray-core@v1.6.6/app/proxyman/inbound/always.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/xraypb/xray-core/app/proxyman"
     7  	"github.com/xraypb/xray-core/common"
     8  	"github.com/xraypb/xray-core/common/dice"
     9  	"github.com/xraypb/xray-core/common/errors"
    10  	"github.com/xraypb/xray-core/common/mux"
    11  	"github.com/xraypb/xray-core/common/net"
    12  	"github.com/xraypb/xray-core/core"
    13  	"github.com/xraypb/xray-core/features/policy"
    14  	"github.com/xraypb/xray-core/features/stats"
    15  	"github.com/xraypb/xray-core/proxy"
    16  	"github.com/xraypb/xray-core/transport/internet"
    17  )
    18  
    19  func getStatCounter(v *core.Instance, tag string) (stats.Counter, stats.Counter) {
    20  	var uplinkCounter stats.Counter
    21  	var downlinkCounter stats.Counter
    22  
    23  	policy := v.GetFeature(policy.ManagerType()).(policy.Manager)
    24  	if len(tag) > 0 && policy.ForSystem().Stats.InboundUplink {
    25  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    26  		name := "inbound>>>" + tag + ">>>traffic>>>uplink"
    27  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    28  		if c != nil {
    29  			uplinkCounter = c
    30  		}
    31  	}
    32  	if len(tag) > 0 && policy.ForSystem().Stats.InboundDownlink {
    33  		statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
    34  		name := "inbound>>>" + tag + ">>>traffic>>>downlink"
    35  		c, _ := stats.GetOrRegisterCounter(statsManager, name)
    36  		if c != nil {
    37  			downlinkCounter = c
    38  		}
    39  	}
    40  
    41  	return uplinkCounter, downlinkCounter
    42  }
    43  
    44  type AlwaysOnInboundHandler struct {
    45  	proxy   proxy.Inbound
    46  	workers []worker
    47  	mux     *mux.Server
    48  	tag     string
    49  }
    50  
    51  func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*AlwaysOnInboundHandler, error) {
    52  	rawProxy, err := common.CreateObject(ctx, proxyConfig)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	p, ok := rawProxy.(proxy.Inbound)
    57  	if !ok {
    58  		return nil, newError("not an inbound proxy.")
    59  	}
    60  
    61  	h := &AlwaysOnInboundHandler{
    62  		proxy: p,
    63  		mux:   mux.NewServer(ctx),
    64  		tag:   tag,
    65  	}
    66  
    67  	uplinkCounter, downlinkCounter := getStatCounter(core.MustFromContext(ctx), tag)
    68  
    69  	nl := p.Network()
    70  	pl := receiverConfig.PortList
    71  	address := receiverConfig.Listen.AsAddress()
    72  	if address == nil {
    73  		address = net.AnyIP
    74  	}
    75  
    76  	mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings)
    77  	if err != nil {
    78  		return nil, newError("failed to parse stream config").Base(err).AtWarning()
    79  	}
    80  
    81  	if receiverConfig.ReceiveOriginalDestination {
    82  		if mss.SocketSettings == nil {
    83  			mss.SocketSettings = &internet.SocketConfig{}
    84  		}
    85  		if mss.SocketSettings.Tproxy == internet.SocketConfig_Off {
    86  			mss.SocketSettings.Tproxy = internet.SocketConfig_Redirect
    87  		}
    88  		mss.SocketSettings.ReceiveOriginalDestAddress = true
    89  	}
    90  	if pl == nil {
    91  		if net.HasNetwork(nl, net.Network_UNIX) {
    92  			newError("creating unix domain socket worker on ", address).AtDebug().WriteToLog()
    93  
    94  			worker := &dsWorker{
    95  				address:         address,
    96  				proxy:           p,
    97  				stream:          mss,
    98  				tag:             tag,
    99  				dispatcher:      h.mux,
   100  				sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   101  				uplinkCounter:   uplinkCounter,
   102  				downlinkCounter: downlinkCounter,
   103  				ctx:             ctx,
   104  			}
   105  			h.workers = append(h.workers, worker)
   106  		}
   107  	}
   108  	if pl != nil {
   109  		for _, pr := range pl.Range {
   110  			for port := pr.From; port <= pr.To; port++ {
   111  				if net.HasNetwork(nl, net.Network_TCP) {
   112  					newError("creating stream worker on ", address, ":", port).AtDebug().WriteToLog()
   113  
   114  					worker := &tcpWorker{
   115  						address:         address,
   116  						port:            net.Port(port),
   117  						proxy:           p,
   118  						stream:          mss,
   119  						recvOrigDest:    receiverConfig.ReceiveOriginalDestination,
   120  						tag:             tag,
   121  						dispatcher:      h.mux,
   122  						sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   123  						uplinkCounter:   uplinkCounter,
   124  						downlinkCounter: downlinkCounter,
   125  						ctx:             ctx,
   126  					}
   127  					h.workers = append(h.workers, worker)
   128  				}
   129  
   130  				if net.HasNetwork(nl, net.Network_UDP) {
   131  					worker := &udpWorker{
   132  						tag:             tag,
   133  						proxy:           p,
   134  						address:         address,
   135  						port:            net.Port(port),
   136  						dispatcher:      h.mux,
   137  						sniffingConfig:  receiverConfig.GetEffectiveSniffingSettings(),
   138  						uplinkCounter:   uplinkCounter,
   139  						downlinkCounter: downlinkCounter,
   140  						stream:          mss,
   141  						ctx:             ctx,
   142  					}
   143  					h.workers = append(h.workers, worker)
   144  				}
   145  			}
   146  		}
   147  	}
   148  
   149  	return h, nil
   150  }
   151  
   152  // Start implements common.Runnable.
   153  func (h *AlwaysOnInboundHandler) Start() error {
   154  	for _, worker := range h.workers {
   155  		if err := worker.Start(); err != nil {
   156  			return err
   157  		}
   158  	}
   159  	return nil
   160  }
   161  
   162  // Close implements common.Closable.
   163  func (h *AlwaysOnInboundHandler) Close() error {
   164  	var errs []error
   165  	for _, worker := range h.workers {
   166  		errs = append(errs, worker.Close())
   167  	}
   168  	errs = append(errs, h.mux.Close())
   169  	if err := errors.Combine(errs...); err != nil {
   170  		return newError("failed to close all resources").Base(err)
   171  	}
   172  	return nil
   173  }
   174  
   175  func (h *AlwaysOnInboundHandler) GetRandomInboundProxy() (interface{}, net.Port, int) {
   176  	if len(h.workers) == 0 {
   177  		return nil, 0, 0
   178  	}
   179  	w := h.workers[dice.Roll(len(h.workers))]
   180  	return w.Proxy(), w.Port(), 9999
   181  }
   182  
   183  func (h *AlwaysOnInboundHandler) Tag() string {
   184  	return h.tag
   185  }
   186  
   187  func (h *AlwaysOnInboundHandler) GetInbound() proxy.Inbound {
   188  	return h.proxy
   189  }