github.com/eagleql/xray-core@v1.4.4/app/router/router.go (about)

     1  package router
     2  
     3  //go:generate go run github.com/eagleql/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  
     8  	"github.com/eagleql/xray-core/common"
     9  	"github.com/eagleql/xray-core/core"
    10  	"github.com/eagleql/xray-core/features/dns"
    11  	"github.com/eagleql/xray-core/features/outbound"
    12  	"github.com/eagleql/xray-core/features/routing"
    13  	routing_dns "github.com/eagleql/xray-core/features/routing/dns"
    14  )
    15  
    16  // Router is an implementation of routing.Router.
    17  type Router struct {
    18  	domainStrategy Config_DomainStrategy
    19  	rules          []*Rule
    20  	balancers      map[string]*Balancer
    21  	dns            dns.Client
    22  }
    23  
    24  // Route is an implementation of routing.Route.
    25  type Route struct {
    26  	routing.Context
    27  	outboundGroupTags []string
    28  	outboundTag       string
    29  }
    30  
    31  // Init initializes the Router.
    32  func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error {
    33  	r.domainStrategy = config.DomainStrategy
    34  	r.dns = d
    35  
    36  	r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
    37  	for _, rule := range config.BalancingRule {
    38  		balancer, err := rule.Build(ohm)
    39  		if err != nil {
    40  			return err
    41  		}
    42  		r.balancers[rule.Tag] = balancer
    43  	}
    44  
    45  	r.rules = make([]*Rule, 0, len(config.Rule))
    46  	for _, rule := range config.Rule {
    47  		cond, err := rule.BuildCondition()
    48  		if err != nil {
    49  			return err
    50  		}
    51  		rr := &Rule{
    52  			Condition: cond,
    53  			Tag:       rule.GetTag(),
    54  		}
    55  		btag := rule.GetBalancingTag()
    56  		if len(btag) > 0 {
    57  			brule, found := r.balancers[btag]
    58  			if !found {
    59  				return newError("balancer ", btag, " not found")
    60  			}
    61  			rr.Balancer = brule
    62  		}
    63  		r.rules = append(r.rules, rr)
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  // PickRoute implements routing.Router.
    70  func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
    71  	rule, ctx, err := r.pickRouteInternal(ctx)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	tag, err := rule.GetTag()
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return &Route{Context: ctx, outboundTag: tag}, nil
    80  }
    81  
    82  func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
    83  	if r.domainStrategy == Config_IpOnDemand {
    84  		ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
    85  	}
    86  
    87  	for _, rule := range r.rules {
    88  		if rule.Apply(ctx) {
    89  			return rule, ctx, nil
    90  		}
    91  	}
    92  
    93  	if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 {
    94  		return nil, ctx, common.ErrNoClue
    95  	}
    96  
    97  	ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
    98  
    99  	// Try applying rules again if we have IPs.
   100  	for _, rule := range r.rules {
   101  		if rule.Apply(ctx) {
   102  			return rule, ctx, nil
   103  		}
   104  	}
   105  
   106  	return nil, ctx, common.ErrNoClue
   107  }
   108  
   109  // Start implements common.Runnable.
   110  func (*Router) Start() error {
   111  	return nil
   112  }
   113  
   114  // Close implements common.Closable.
   115  func (*Router) Close() error {
   116  	return nil
   117  }
   118  
   119  // Type implement common.HasType.
   120  func (*Router) Type() interface{} {
   121  	return routing.RouterType()
   122  }
   123  
   124  // GetOutboundGroupTags implements routing.Route.
   125  func (r *Route) GetOutboundGroupTags() []string {
   126  	return r.outboundGroupTags
   127  }
   128  
   129  // GetOutboundTag implements routing.Route.
   130  func (r *Route) GetOutboundTag() string {
   131  	return r.outboundTag
   132  }
   133  
   134  func init() {
   135  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   136  		r := new(Router)
   137  		if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager) error {
   138  			return r.Init(config.(*Config), d, ohm)
   139  		}); err != nil {
   140  			return nil, err
   141  		}
   142  		return r, nil
   143  	}))
   144  }