github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/router/router.go (about) 1 package router 2 3 //go:generate go run github.com/xtls/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 sync "sync" 8 9 "github.com/xtls/xray-core/common" 10 "github.com/xtls/xray-core/common/serial" 11 "github.com/xtls/xray-core/core" 12 "github.com/xtls/xray-core/features/dns" 13 "github.com/xtls/xray-core/features/outbound" 14 "github.com/xtls/xray-core/features/routing" 15 routing_dns "github.com/xtls/xray-core/features/routing/dns" 16 ) 17 18 // Router is an implementation of routing.Router. 19 type Router struct { 20 domainStrategy Config_DomainStrategy 21 rules []*Rule 22 balancers map[string]*Balancer 23 dns dns.Client 24 25 ctx context.Context 26 ohm outbound.Manager 27 dispatcher routing.Dispatcher 28 mu sync.Mutex 29 } 30 31 // Route is an implementation of routing.Route. 32 type Route struct { 33 routing.Context 34 outboundGroupTags []string 35 outboundTag string 36 } 37 38 // Init initializes the Router. 39 func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 40 r.domainStrategy = config.DomainStrategy 41 r.dns = d 42 r.ctx = ctx 43 r.ohm = ohm 44 r.dispatcher = dispatcher 45 46 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 47 for _, rule := range config.BalancingRule { 48 balancer, err := rule.Build(ohm, dispatcher) 49 if err != nil { 50 return err 51 } 52 balancer.InjectContext(ctx) 53 r.balancers[rule.Tag] = balancer 54 } 55 56 r.rules = make([]*Rule, 0, len(config.Rule)) 57 for _, rule := range config.Rule { 58 cond, err := rule.BuildCondition() 59 if err != nil { 60 return err 61 } 62 rr := &Rule{ 63 Condition: cond, 64 Tag: rule.GetTag(), 65 RuleTag: rule.GetRuleTag(), 66 } 67 btag := rule.GetBalancingTag() 68 if len(btag) > 0 { 69 brule, found := r.balancers[btag] 70 if !found { 71 return newError("balancer ", btag, " not found") 72 } 73 rr.Balancer = brule 74 } 75 r.rules = append(r.rules, rr) 76 } 77 78 return nil 79 } 80 81 // PickRoute implements routing.Router. 82 func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { 83 rule, ctx, err := r.pickRouteInternal(ctx) 84 if err != nil { 85 return nil, err 86 } 87 tag, err := rule.GetTag() 88 if err != nil { 89 return nil, err 90 } 91 return &Route{Context: ctx, outboundTag: tag}, nil 92 } 93 94 // AddRule implements routing.Router. 95 func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error { 96 97 inst, err := config.GetInstance() 98 if err != nil { 99 return err 100 } 101 if c, ok := inst.(*Config); ok { 102 return r.ReloadRules(c, shouldAppend) 103 } 104 return newError("AddRule: config type error") 105 } 106 107 func (r *Router) ReloadRules(config *Config, shouldAppend bool) error { 108 r.mu.Lock() 109 defer r.mu.Unlock() 110 111 if !shouldAppend { 112 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 113 r.rules = make([]*Rule, 0, len(config.Rule)) 114 } 115 for _, rule := range config.BalancingRule { 116 _, found := r.balancers[rule.Tag] 117 if found { 118 return newError("duplicate balancer tag") 119 } 120 balancer, err := rule.Build(r.ohm, r.dispatcher) 121 if err != nil { 122 return err 123 } 124 balancer.InjectContext(r.ctx) 125 r.balancers[rule.Tag] = balancer 126 } 127 128 for _, rule := range config.Rule { 129 if r.RuleExists(rule.GetRuleTag()) { 130 return newError("duplicate ruleTag ", rule.GetRuleTag()) 131 } 132 cond, err := rule.BuildCondition() 133 if err != nil { 134 return err 135 } 136 rr := &Rule{ 137 Condition: cond, 138 Tag: rule.GetTag(), 139 RuleTag: rule.GetRuleTag(), 140 } 141 btag := rule.GetBalancingTag() 142 if len(btag) > 0 { 143 brule, found := r.balancers[btag] 144 if !found { 145 return newError("balancer ", btag, " not found") 146 } 147 rr.Balancer = brule 148 } 149 r.rules = append(r.rules, rr) 150 } 151 152 return nil 153 } 154 155 func (r *Router) RuleExists(tag string) bool { 156 if tag != "" { 157 for _, rule := range r.rules { 158 if rule.RuleTag == tag { 159 return true 160 } 161 } 162 } 163 return false 164 } 165 166 // RemoveRule implements routing.Router. 167 func (r *Router) RemoveRule(tag string) error { 168 r.mu.Lock() 169 defer r.mu.Unlock() 170 171 newRules := []*Rule{} 172 if tag != "" { 173 for _, rule := range r.rules { 174 if rule.RuleTag != tag { 175 newRules = append(newRules, rule) 176 } 177 } 178 r.rules = newRules 179 return nil 180 } 181 return newError("empty tag name!") 182 183 } 184 func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) { 185 // SkipDNSResolve is set from DNS module. 186 // the DOH remote server maybe a domain name, 187 // this prevents cycle resolving dead loop 188 skipDNSResolve := ctx.GetSkipDNSResolve() 189 190 if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve { 191 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 192 } 193 194 for _, rule := range r.rules { 195 if rule.Apply(ctx) { 196 return rule, ctx, nil 197 } 198 } 199 200 if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve { 201 return nil, ctx, common.ErrNoClue 202 } 203 204 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 205 206 // Try applying rules again if we have IPs. 207 for _, rule := range r.rules { 208 if rule.Apply(ctx) { 209 return rule, ctx, nil 210 } 211 } 212 213 return nil, ctx, common.ErrNoClue 214 } 215 216 // Start implements common.Runnable. 217 func (r *Router) Start() error { 218 return nil 219 } 220 221 // Close implements common.Closable. 222 func (r *Router) Close() error { 223 return nil 224 } 225 226 // Type implements common.HasType. 227 func (*Router) Type() interface{} { 228 return routing.RouterType() 229 } 230 231 // GetOutboundGroupTags implements routing.Route. 232 func (r *Route) GetOutboundGroupTags() []string { 233 return r.outboundGroupTags 234 } 235 236 // GetOutboundTag implements routing.Route. 237 func (r *Route) GetOutboundTag() string { 238 return r.outboundTag 239 } 240 241 func init() { 242 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 243 r := new(Router) 244 if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 245 return r.Init(ctx, config.(*Config), d, ohm, dispatcher) 246 }); err != nil { 247 return nil, err 248 } 249 return r, nil 250 })) 251 }