github.com/moqsien/xraycore@v1.8.5/app/router/router.go (about)

     1  package router
     2  
     3  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  
     8  	"github.com/moqsien/xraycore/common"
     9  	"github.com/moqsien/xraycore/core"
    10  	"github.com/moqsien/xraycore/features/dns"
    11  	"github.com/moqsien/xraycore/features/outbound"
    12  	"github.com/moqsien/xraycore/features/routing"
    13  	routing_dns "github.com/moqsien/xraycore/features/routing/dns"
    14  )
    15  
// Router is an implementation of routing.Router.
//
// Init populates every field from a *Config; the routing methods in this file
// (PickRoute and its helper) only read them.
type Router struct {
	// domainStrategy decides whether/when a DNS client is attached to the
	// routing context while matching (see Config_IpOnDemand and
	// Config_IpIfNonMatch handling in pickRouteInternal).
	domainStrategy Config_DomainStrategy
	// rules are the compiled routing rules, tried in slice order; the first
	// rule whose Apply returns true is selected.
	rules          []*Rule
	// balancers maps each balancing rule's Tag to the Balancer built for it.
	balancers      map[string]*Balancer
	// dns is the client wrapped into routing contexts when the domain
	// strategy allows DNS resolution.
	dns            dns.Client
}
    23  
// Route is an implementation of routing.Route.
//
// It embeds the routing.Context the route was picked for (possibly wrapped
// with a DNS client during matching) and records the selected outbound tag.
type Route struct {
	routing.Context
	// outboundGroupTags is returned by GetOutboundGroupTags; PickRoute in this
	// file never sets it, so it is nil for routes built there.
	outboundGroupTags []string
	// outboundTag is the tag obtained from the matched rule's GetTag.
	outboundTag       string
}
    30  
    31  // Init initializes the Router.
    32  func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager) error {
    33  	r.domainStrategy = config.DomainStrategy
    34  	r.dns = d
    35  
    36  	r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
    37  	for _, rule := range config.BalancingRule {
    38  		balancer, err := rule.Build(ohm)
    39  		if err != nil {
    40  			return err
    41  		}
    42  		balancer.InjectContext(ctx)
    43  		r.balancers[rule.Tag] = balancer
    44  	}
    45  
    46  	r.rules = make([]*Rule, 0, len(config.Rule))
    47  	for _, rule := range config.Rule {
    48  		cond, err := rule.BuildCondition()
    49  		if err != nil {
    50  			return err
    51  		}
    52  		rr := &Rule{
    53  			Condition: cond,
    54  			Tag:       rule.GetTag(),
    55  		}
    56  		btag := rule.GetBalancingTag()
    57  		if len(btag) > 0 {
    58  			brule, found := r.balancers[btag]
    59  			if !found {
    60  				return newError("balancer ", btag, " not found")
    61  			}
    62  			rr.Balancer = brule
    63  		}
    64  		r.rules = append(r.rules, rr)
    65  	}
    66  
    67  	return nil
    68  }
    69  
    70  // PickRoute implements routing.Router.
    71  func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
    72  	rule, ctx, err := r.pickRouteInternal(ctx)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	tag, err := rule.GetTag()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	return &Route{Context: ctx, outboundTag: tag}, nil
    81  }
    82  
    83  func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
    84  	// SkipDNSResolve is set from DNS module.
    85  	// the DOH remote server maybe a domain name,
    86  	// this prevents cycle resolving dead loop
    87  	skipDNSResolve := ctx.GetSkipDNSResolve()
    88  
    89  	if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
    90  		ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
    91  	}
    92  
    93  	for _, rule := range r.rules {
    94  		if rule.Apply(ctx) {
    95  			return rule, ctx, nil
    96  		}
    97  	}
    98  
    99  	if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
   100  		return nil, ctx, common.ErrNoClue
   101  	}
   102  
   103  	ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   104  
   105  	// Try applying rules again if we have IPs.
   106  	for _, rule := range r.rules {
   107  		if rule.Apply(ctx) {
   108  			return rule, ctx, nil
   109  		}
   110  	}
   111  
   112  	return nil, ctx, common.ErrNoClue
   113  }
   114  
   115  // Start implements common.Runnable.
   116  func (*Router) Start() error {
   117  	return nil
   118  }
   119  
   120  // Close implements common.Closable.
   121  func (*Router) Close() error {
   122  	return nil
   123  }
   124  
   125  // Type implements common.HasType.
   126  func (*Router) Type() interface{} {
   127  	return routing.RouterType()
   128  }
   129  
   130  // GetOutboundGroupTags implements routing.Route.
   131  func (r *Route) GetOutboundGroupTags() []string {
   132  	return r.outboundGroupTags
   133  }
   134  
   135  // GetOutboundTag implements routing.Route.
   136  func (r *Route) GetOutboundTag() string {
   137  	return r.outboundTag
   138  }
   139  
   140  func init() {
   141  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   142  		r := new(Router)
   143  		if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager) error {
   144  			return r.Init(ctx, config.(*Config), d, ohm)
   145  		}); err != nil {
   146  			return nil, err
   147  		}
   148  		return r, nil
   149  	}))
   150  }