github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/router/router.go (about)

     1  package router
     2  
     3  //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	sync "sync"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/serial"
    11  	"github.com/xtls/xray-core/core"
    12  	"github.com/xtls/xray-core/features/dns"
    13  	"github.com/xtls/xray-core/features/outbound"
    14  	"github.com/xtls/xray-core/features/routing"
    15  	routing_dns "github.com/xtls/xray-core/features/routing/dns"
    16  )
    17  
    18  // Router is an implementation of routing.Router.
    19  type Router struct {
    20  	domainStrategy Config_DomainStrategy
    21  	rules          []*Rule
    22  	balancers      map[string]*Balancer
    23  	dns            dns.Client
    24  
    25  	ctx        context.Context
    26  	ohm        outbound.Manager
    27  	dispatcher routing.Dispatcher
    28  	mu         sync.Mutex
    29  }
    30  
    31  // Route is an implementation of routing.Route.
    32  type Route struct {
    33  	routing.Context
    34  	outboundGroupTags []string
    35  	outboundTag       string
    36  }
    37  
    38  // Init initializes the Router.
    39  func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
    40  	r.domainStrategy = config.DomainStrategy
    41  	r.dns = d
    42  	r.ctx = ctx
    43  	r.ohm = ohm
    44  	r.dispatcher = dispatcher
    45  
    46  	r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
    47  	for _, rule := range config.BalancingRule {
    48  		balancer, err := rule.Build(ohm, dispatcher)
    49  		if err != nil {
    50  			return err
    51  		}
    52  		balancer.InjectContext(ctx)
    53  		r.balancers[rule.Tag] = balancer
    54  	}
    55  
    56  	r.rules = make([]*Rule, 0, len(config.Rule))
    57  	for _, rule := range config.Rule {
    58  		cond, err := rule.BuildCondition()
    59  		if err != nil {
    60  			return err
    61  		}
    62  		rr := &Rule{
    63  			Condition: cond,
    64  			Tag:       rule.GetTag(),
    65  			RuleTag:   rule.GetRuleTag(),
    66  		}
    67  		btag := rule.GetBalancingTag()
    68  		if len(btag) > 0 {
    69  			brule, found := r.balancers[btag]
    70  			if !found {
    71  				return newError("balancer ", btag, " not found")
    72  			}
    73  			rr.Balancer = brule
    74  		}
    75  		r.rules = append(r.rules, rr)
    76  	}
    77  
    78  	return nil
    79  }
    80  
    81  // PickRoute implements routing.Router.
    82  func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
    83  	rule, ctx, err := r.pickRouteInternal(ctx)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	tag, err := rule.GetTag()
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	return &Route{Context: ctx, outboundTag: tag}, nil
    92  }
    93  
    94  // AddRule implements routing.Router.
    95  func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error {
    96  
    97  	inst, err := config.GetInstance()
    98  	if err != nil {
    99  		return err
   100  	}
   101  	if c, ok := inst.(*Config); ok {
   102  		return r.ReloadRules(c, shouldAppend)
   103  	}
   104  	return newError("AddRule: config type error")
   105  }
   106  
   107  func (r *Router) ReloadRules(config *Config, shouldAppend bool) error {
   108  	r.mu.Lock()
   109  	defer r.mu.Unlock()
   110  
   111  	if !shouldAppend {
   112  		r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
   113  		r.rules = make([]*Rule, 0, len(config.Rule))
   114  	}
   115  	for _, rule := range config.BalancingRule {
   116  		_, found := r.balancers[rule.Tag]
   117  		if found {
   118  			return newError("duplicate balancer tag")
   119  		}
   120  		balancer, err := rule.Build(r.ohm, r.dispatcher)
   121  		if err != nil {
   122  			return err
   123  		}
   124  		balancer.InjectContext(r.ctx)
   125  		r.balancers[rule.Tag] = balancer
   126  	}
   127  
   128  	for _, rule := range config.Rule {
   129  		if r.RuleExists(rule.GetRuleTag()) {
   130  			return newError("duplicate ruleTag ", rule.GetRuleTag())
   131  		}
   132  		cond, err := rule.BuildCondition()
   133  		if err != nil {
   134  			return err
   135  		}
   136  		rr := &Rule{
   137  			Condition: cond,
   138  			Tag:       rule.GetTag(),
   139  			RuleTag:   rule.GetRuleTag(),
   140  		}
   141  		btag := rule.GetBalancingTag()
   142  		if len(btag) > 0 {
   143  			brule, found := r.balancers[btag]
   144  			if !found {
   145  				return newError("balancer ", btag, " not found")
   146  			}
   147  			rr.Balancer = brule
   148  		}
   149  		r.rules = append(r.rules, rr)
   150  	}
   151  
   152  	return nil
   153  }
   154  
   155  func (r *Router) RuleExists(tag string) bool {
   156  	if tag != "" {
   157  		for _, rule := range r.rules {
   158  			if rule.RuleTag == tag {
   159  				return true
   160  			}
   161  		}
   162  	}
   163  	return false
   164  }
   165  
   166  // RemoveRule implements routing.Router.
   167  func (r *Router) RemoveRule(tag string) error {
   168  	r.mu.Lock()
   169  	defer r.mu.Unlock()
   170  
   171  	newRules := []*Rule{}
   172  	if tag != "" {
   173  		for _, rule := range r.rules {
   174  			if rule.RuleTag != tag {
   175  				newRules = append(newRules, rule)
   176  			}
   177  		}
   178  		r.rules = newRules
   179  		return nil
   180  	}
   181  	return newError("empty tag name!")
   182  
   183  }
   184  func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
   185  	// SkipDNSResolve is set from DNS module.
   186  	// the DOH remote server maybe a domain name,
   187  	// this prevents cycle resolving dead loop
   188  	skipDNSResolve := ctx.GetSkipDNSResolve()
   189  
   190  	if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
   191  		ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   192  	}
   193  
   194  	for _, rule := range r.rules {
   195  		if rule.Apply(ctx) {
   196  			return rule, ctx, nil
   197  		}
   198  	}
   199  
   200  	if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
   201  		return nil, ctx, common.ErrNoClue
   202  	}
   203  
   204  	ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   205  
   206  	// Try applying rules again if we have IPs.
   207  	for _, rule := range r.rules {
   208  		if rule.Apply(ctx) {
   209  			return rule, ctx, nil
   210  		}
   211  	}
   212  
   213  	return nil, ctx, common.ErrNoClue
   214  }
   215  
   216  // Start implements common.Runnable.
   217  func (r *Router) Start() error {
   218  	return nil
   219  }
   220  
   221  // Close implements common.Closable.
   222  func (r *Router) Close() error {
   223  	return nil
   224  }
   225  
   226  // Type implements common.HasType.
   227  func (*Router) Type() interface{} {
   228  	return routing.RouterType()
   229  }
   230  
   231  // GetOutboundGroupTags implements routing.Route.
   232  func (r *Route) GetOutboundGroupTags() []string {
   233  	return r.outboundGroupTags
   234  }
   235  
   236  // GetOutboundTag implements routing.Route.
   237  func (r *Route) GetOutboundTag() string {
   238  	return r.outboundTag
   239  }
   240  
   241  func init() {
   242  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   243  		r := new(Router)
   244  		if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
   245  			return r.Init(ctx, config.(*Config), d, ohm, dispatcher)
   246  		}); err != nil {
   247  			return nil, err
   248  		}
   249  		return r, nil
   250  	}))
   251  }