github.com/xmplusdev/xray-core@v1.8.10/app/router/router.go (about) 1 package router 2 3 //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 sync "sync" 8 9 "github.com/xmplusdev/xray-core/common" 10 "github.com/xmplusdev/xray-core/common/serial" 11 "github.com/xmplusdev/xray-core/core" 12 "github.com/xmplusdev/xray-core/features/dns" 13 "github.com/xmplusdev/xray-core/features/outbound" 14 "github.com/xmplusdev/xray-core/features/routing" 15 routing_dns "github.com/xmplusdev/xray-core/features/routing/dns" 16 ) 17 18 // Router is an implementation of routing.Router. 19 type Router struct { 20 domainStrategy Config_DomainStrategy 21 rules []*Rule 22 balancers map[string]*Balancer 23 dns dns.Client 24 25 ctx context.Context 26 ohm outbound.Manager 27 dispatcher routing.Dispatcher 28 mu sync.Mutex 29 } 30 31 // Route is an implementation of routing.Route. 32 type Route struct { 33 routing.Context 34 outboundGroupTags []string 35 outboundTag string 36 } 37 38 39 // Init initializes the Router. 40 func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 41 r.domainStrategy = config.DomainStrategy 42 r.dns = d 43 r.ctx = ctx 44 r.ohm = ohm 45 r.dispatcher = dispatcher 46 47 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 48 49 for _, rule := range config.BalancingRule { 50 balancer, err := rule.Build(ohm, dispatcher) 51 if err != nil { 52 return err 53 } 54 balancer.InjectContext(ctx) 55 r.balancers[rule.Tag] = balancer 56 } 57 58 r.rules = make([]*Rule, 0, len(config.Rule)) 59 for _, rule := range config.Rule { 60 cond, err := rule.BuildCondition() 61 if err != nil { 62 return err 63 } 64 rr := &Rule{ 65 Condition: cond, 66 Tag: rule.GetTag(), 67 RuleTag: rule.GetRuleTag(), 68 } 69 btag := rule.GetBalancingTag() 70 if len(btag) > 0 { 71 brule, found := r.balancers[btag] 72 if !found { 73 return newError("balancer ", btag, " not found") 74 } 75 rr.Balancer = brule 76 } 77 r.rules = append(r.rules, rr) 78 } 79 80 return nil 81 } 82 83 // PickRoute implements routing.Router. 84 func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { 85 rule, ctx, err := r.pickRouteInternal(ctx) 86 if err != nil { 87 return nil, err 88 } 89 tag, err := rule.GetTag() 90 if err != nil { 91 return nil, err 92 } 93 return &Route{Context: ctx, outboundTag: tag}, nil 94 } 95 96 // AddRule implements routing.Router. 97 func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error { 98 99 inst, err := config.GetInstance() 100 if err != nil { 101 return err 102 } 103 if c, ok := inst.(*Config); ok { 104 return r.ReloadRules(c, shouldAppend) 105 } 106 return newError("AddRule: config type error") 107 } 108 109 func (r *Router) ReloadRules(config *Config, shouldAppend bool) error { 110 r.mu.Lock() 111 defer r.mu.Unlock() 112 113 if !shouldAppend { 114 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 115 r.rules = make([]*Rule, 0, len(config.Rule)) 116 } 117 for _, rule := range config.BalancingRule { 118 _, found := r.balancers[rule.Tag] 119 if found { 120 return newError("duplicate balancer tag") 121 } 122 balancer, err := rule.Build(r.ohm, r.dispatcher) 123 if err != nil { 124 return err 125 } 126 balancer.InjectContext(r.ctx) 127 r.balancers[rule.Tag] = balancer 128 } 129 130 for _, rule := range config.Rule { 131 if r.RuleExists(rule.GetRuleTag()) { 132 return newError("duplicate ruleTag ", rule.GetRuleTag()) 133 } 134 cond, err := rule.BuildCondition() 135 if err != nil { 136 return err 137 } 138 rr := &Rule{ 139 Condition: cond, 140 Tag: rule.GetTag(), 141 RuleTag: rule.GetRuleTag(), 142 } 143 btag := rule.GetBalancingTag() 144 if len(btag) > 0 { 145 brule, found := r.balancers[btag] 146 if !found { 147 return newError("balancer ", btag, " not found") 148 } 149 rr.Balancer = brule 150 } 151 r.rules = append(r.rules, rr) 152 } 153 154 return nil 155 } 156 157 func (r *Router) RuleExists(tag string) bool { 158 if tag != "" { 159 for _, rule := range r.rules { 160 if rule.RuleTag == tag { 161 return true 162 } 163 } 164 } 165 return false 166 } 167 168 // RemoveRule implements routing.Router. 169 func (r *Router) RemoveRule(tag string) error { 170 r.mu.Lock() 171 defer r.mu.Unlock() 172 173 newRules := []*Rule{} 174 if tag != "" { 175 for _, rule := range r.rules { 176 if rule.RuleTag != tag { 177 newRules = append(newRules, rule) 178 } 179 } 180 r.rules = newRules 181 return nil 182 } 183 return newError("empty tag name!") 184 185 } 186 func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) { 187 // SkipDNSResolve is set from DNS module. 188 // the DOH remote server maybe a domain name, 189 // this prevents cycle resolving dead loop 190 skipDNSResolve := ctx.GetSkipDNSResolve() 191 192 if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve { 193 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 194 } 195 196 for _, rule := range r.rules { 197 if rule.Apply(ctx) { 198 return rule, ctx, nil 199 } 200 } 201 202 if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve { 203 return nil, ctx, common.ErrNoClue 204 } 205 206 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 207 208 // Try applying rules again if we have IPs. 209 for _, rule := range r.rules { 210 if rule.Apply(ctx) { 211 return rule, ctx, nil 212 } 213 } 214 215 return nil, ctx, common.ErrNoClue 216 } 217 218 // Start implements common.Runnable. 219 func (r *Router) Start() error { 220 return nil 221 } 222 223 // Close implements common.Closable. 224 func (r *Router) Close() error { 225 return nil 226 } 227 228 // Type implements common.HasType. 229 func (*Router) Type() interface{} { 230 return routing.RouterType() 231 } 232 233 // GetOutboundGroupTags implements routing.Route. 234 func (r *Route) GetOutboundGroupTags() []string { 235 return r.outboundGroupTags 236 } 237 238 // GetOutboundTag implements routing.Route. 239 func (r *Route) GetOutboundTag() string { 240 return r.outboundTag 241 } 242 243 func init() { 244 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 245 r := new(Router) 246 if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 247 return r.Init(ctx, config.(*Config), d, ohm, dispatcher) 248 }); err != nil { 249 return nil, err 250 } 251 return r, nil 252 })) 253 }