github.com/xmplusdev/xray-core@v1.8.10/app/router/router.go (about)

     1  package router
     2  
     3  //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	sync "sync"
     8  
     9  	"github.com/xmplusdev/xray-core/common"
    10  	"github.com/xmplusdev/xray-core/common/serial"
    11  	"github.com/xmplusdev/xray-core/core"
    12  	"github.com/xmplusdev/xray-core/features/dns"
    13  	"github.com/xmplusdev/xray-core/features/outbound"
    14  	"github.com/xmplusdev/xray-core/features/routing"
    15  	routing_dns "github.com/xmplusdev/xray-core/features/routing/dns"
    16  )
    17  
// Router is an implementation of routing.Router.
//
// It evaluates an ordered list of rules against a routing.Context; the first
// rule that matches is used to build the resulting Route.
type Router struct {
	// domainStrategy decides whether, and at which point, the routing context
	// is wrapped with the DNS client during matching (see pickRouteInternal).
	domainStrategy Config_DomainStrategy
	// rules are evaluated in order; the first match wins.
	rules          []*Rule
	// balancers maps a balancing rule's Tag to the Balancer built from it.
	balancers      map[string]*Balancer
	// dns is attached to the routing context when domain resolution is needed.
	dns            dns.Client

	// ctx, ohm and dispatcher are stored by Init so that ReloadRules can build
	// new balancers after initialization.
	ctx        context.Context
	ohm        outbound.Manager
	dispatcher routing.Dispatcher
	// mu is held by ReloadRules and RemoveRule while they replace or extend
	// rules and balancers.
	mu         sync.Mutex
}
    30  
// Route is an implementation of routing.Route.
//
// It embeds the routing.Context the routing decision was made against and
// carries the resulting outbound tag.
type Route struct {
	routing.Context
	// outboundGroupTags is returned by GetOutboundGroupTags; PickRoute in this
	// file does not set it.
	outboundGroupTags []string
	// outboundTag is the tag PickRoute obtained from the matched rule.
	outboundTag       string
}
    37  
    38  
    39  // Init initializes the Router.
    40  func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
    41  	r.domainStrategy = config.DomainStrategy
    42  	r.dns = d
    43  	r.ctx = ctx
    44  	r.ohm = ohm
    45  	r.dispatcher = dispatcher
    46  	
    47  	r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
    48  	
    49  	for _, rule := range config.BalancingRule {
    50  		balancer, err := rule.Build(ohm, dispatcher)
    51  		if err != nil {
    52  			return err
    53  		}
    54  		balancer.InjectContext(ctx)
    55  		r.balancers[rule.Tag] = balancer
    56  	}
    57  
    58  	r.rules = make([]*Rule, 0, len(config.Rule))
    59  	for _, rule := range config.Rule {
    60  		cond, err := rule.BuildCondition()
    61  		if err != nil {
    62  			return err
    63  		}
    64  		rr := &Rule{
    65  			Condition: cond,
    66  			Tag:       rule.GetTag(),
    67  			RuleTag:   rule.GetRuleTag(),
    68  		}
    69  		btag := rule.GetBalancingTag()
    70  		if len(btag) > 0 {
    71  			brule, found := r.balancers[btag]
    72  			if !found {
    73  				return newError("balancer ", btag, " not found")
    74  			}
    75  			rr.Balancer = brule
    76  		}
    77  		r.rules = append(r.rules, rr)
    78  	}
    79  
    80  	return nil
    81  }
    82  
    83  // PickRoute implements routing.Router.
    84  func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
    85  	rule, ctx, err := r.pickRouteInternal(ctx)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	tag, err := rule.GetTag()
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	return &Route{Context: ctx, outboundTag: tag}, nil
    94  }
    95  
    96  // AddRule implements routing.Router.
    97  func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error {
    98  
    99  	inst, err := config.GetInstance()
   100  	if err != nil {
   101  		return err
   102  	}
   103  	if c, ok := inst.(*Config); ok {
   104  		return r.ReloadRules(c, shouldAppend)
   105  	}
   106  	return newError("AddRule: config type error")
   107  }
   108  
   109  func (r *Router) ReloadRules(config *Config, shouldAppend bool) error {
   110  	r.mu.Lock()
   111  	defer r.mu.Unlock()
   112  
   113  	if !shouldAppend {
   114  		r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
   115  		r.rules = make([]*Rule, 0, len(config.Rule))
   116  	}
   117  	for _, rule := range config.BalancingRule {
   118  		_, found := r.balancers[rule.Tag]
   119  		if found {
   120  			return newError("duplicate balancer tag")
   121  		}
   122  		balancer, err := rule.Build(r.ohm, r.dispatcher)
   123  		if err != nil {
   124  			return err
   125  		}
   126  		balancer.InjectContext(r.ctx)
   127  		r.balancers[rule.Tag] = balancer
   128  	}
   129  
   130  	for _, rule := range config.Rule {
   131  		if r.RuleExists(rule.GetRuleTag()) {
   132  			return newError("duplicate ruleTag ", rule.GetRuleTag())
   133  		}
   134  		cond, err := rule.BuildCondition()
   135  		if err != nil {
   136  			return err
   137  		}
   138  		rr := &Rule{
   139  			Condition: cond,
   140  			Tag:       rule.GetTag(),
   141  			RuleTag:   rule.GetRuleTag(),
   142  		}
   143  		btag := rule.GetBalancingTag()
   144  		if len(btag) > 0 {
   145  			brule, found := r.balancers[btag]
   146  			if !found {
   147  				return newError("balancer ", btag, " not found")
   148  			}
   149  			rr.Balancer = brule
   150  		}
   151  		r.rules = append(r.rules, rr)
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (r *Router) RuleExists(tag string) bool {
   158  	if tag != "" {
   159  		for _, rule := range r.rules {
   160  			if rule.RuleTag == tag {
   161  				return true
   162  			}
   163  		}
   164  	}
   165  	return false
   166  }
   167  
   168  // RemoveRule implements routing.Router.
   169  func (r *Router) RemoveRule(tag string) error {
   170  	r.mu.Lock()
   171  	defer r.mu.Unlock()
   172  
   173  	newRules := []*Rule{}
   174  	if tag != "" {
   175  		for _, rule := range r.rules {
   176  			if rule.RuleTag != tag {
   177  				newRules = append(newRules, rule)
   178  			}
   179  		}
   180  		r.rules = newRules
   181  		return nil
   182  	}
   183  	return newError("empty tag name!")
   184  
   185  }
   186  func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
   187  	// SkipDNSResolve is set from DNS module.
   188  	// the DOH remote server maybe a domain name,
   189  	// this prevents cycle resolving dead loop
   190  	skipDNSResolve := ctx.GetSkipDNSResolve()
   191  
   192  	if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
   193  		ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   194  	}
   195  
   196  	for _, rule := range r.rules {
   197  		if rule.Apply(ctx) {
   198  			return rule, ctx, nil
   199  		}
   200  	}
   201  
   202  	if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
   203  		return nil, ctx, common.ErrNoClue
   204  	}
   205  
   206  	ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   207  
   208  	// Try applying rules again if we have IPs.
   209  	for _, rule := range r.rules {
   210  		if rule.Apply(ctx) {
   211  			return rule, ctx, nil
   212  		}
   213  	}
   214  
   215  	return nil, ctx, common.ErrNoClue
   216  }
   217  
   218  // Start implements common.Runnable.
   219  func (r *Router) Start() error {
   220  	return nil
   221  }
   222  
   223  // Close implements common.Closable.
   224  func (r *Router) Close() error {
   225  	return nil
   226  }
   227  
   228  // Type implements common.HasType.
   229  func (*Router) Type() interface{} {
   230  	return routing.RouterType()
   231  }
   232  
   233  // GetOutboundGroupTags implements routing.Route.
   234  func (r *Route) GetOutboundGroupTags() []string {
   235  	return r.outboundGroupTags
   236  }
   237  
   238  // GetOutboundTag implements routing.Route.
   239  func (r *Route) GetOutboundTag() string {
   240  	return r.outboundTag
   241  }
   242  
   243  func init() {
   244  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   245  		r := new(Router)
   246  		if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
   247  			return r.Init(ctx, config.(*Config), d, ohm, dispatcher)
   248  		}); err != nil {
   249  			return nil, err
   250  		}
   251  		return r, nil
   252  	}))
   253  }