github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/router/router.go (about) 1 package router 2 3 //go:generate go run github.com/xmplusdev/xmcore/common/errors/errorgen 4 5 import ( 6 "context" 7 sync "sync" 8 "runtime" 9 "sort" 10 11 "github.com/xmplusdev/xmcore/common" 12 "github.com/xmplusdev/xmcore/common/serial" 13 "github.com/xmplusdev/xmcore/core" 14 "github.com/xmplusdev/xmcore/features/dns" 15 "github.com/xmplusdev/xmcore/features/outbound" 16 "github.com/xmplusdev/xmcore/features/routing" 17 routing_dns "github.com/xmplusdev/xmcore/features/routing/dns" 18 ) 19 20 // Router is an implementation of routing.Router. 21 type Router struct { 22 domainStrategy Config_DomainStrategy 23 rules []*Rule 24 balancers map[string]*Balancer 25 dns dns.Client 26 27 ctx context.Context 28 ohm outbound.Manager 29 dispatcher routing.Dispatcher 30 mu sync.Mutex 31 tag2indexmap map[string]int 32 index2tag map[int]string 33 } 34 35 // Route is an implementation of routing.Route. 36 type Route struct { 37 routing.Context 38 outboundGroupTags []string 39 outboundTag string 40 } 41 42 func NewRouter() *Router { 43 con := NewConditionChan() 44 con.Add(NewInboundTagMatcher([]string{"asdf"})) 45 con.Add(NewProtocolMatcher([]string{"tls"})) 46 con.Add(NewUserMatcher([]string{"bge"})) 47 return &Router{ 48 domainStrategy: Config_AsIs, 49 rules: []*Rule{&Rule{Condition: con}}, 50 balancers: map[string]*Balancer{}, 51 tag2indexmap: map[string]int{}, 52 index2tag: map[int]string{}, 53 } 54 } 55 56 func Romvededuplicate(users []string) []string { 57 sort.Strings(users) 58 j := 0 59 for i := 1; i < len(users); i++ { 60 if users[j] == users[i] { 61 continue 62 } 63 j++ 64 // preserve the original data 65 // in[i], in[j] = in[j], in[i] 66 // only set what is required 67 users[j] = users[i] 68 } 69 return users[:j+1] 70 } 71 72 func (r *Router) AddUsers(tag string, emails []string) { 73 r.mu.Lock() 74 defer r.mu.Unlock() 75 if index, ok := r.tag2indexmap[tag]; ok { 76 if conditioncan, ok := r.rules[index].Condition.(*ConditionChan); ok { 77 for _, condition := range *conditioncan { 78 if usermatcher, ok := condition.(*UserMatcher); ok { 79 usermatcher.user = Romvededuplicate(append(usermatcher.user, emails...)) 80 break 81 } 82 } 83 } else if usermatcher, ok := r.rules[index].Condition.(*UserMatcher); ok { 84 usermatcher.user = Romvededuplicate(append(usermatcher.user, emails...)) 85 86 } 87 } else { 88 tagStartIndex := len(r.rules) 89 r.tag2indexmap[tag] = tagStartIndex 90 r.index2tag[tagStartIndex] = tag 91 r.rules = append(r.rules, &Rule{Condition: NewUserMatcher(emails), Tag: tag}) 92 } 93 runtime.GC() 94 } 95 96 func (r *Router) RemoveUsers(Users []string) { 97 r.mu.Lock() 98 defer r.mu.Unlock() 99 removed_index := make([]int, 0, len(r.rules)) 100 for _, email := range Users { 101 for _, rl := range r.rules { 102 conditions, ok := rl.Condition.(*ConditionChan) 103 if ok { 104 for _, v := range *conditions { 105 usermatcher, ok := v.(*UserMatcher) 106 if ok { 107 index := -1 108 for i, e := range usermatcher.user { 109 if e == email { 110 index = i 111 break 112 } 113 } 114 if index != -1 { 115 usermatcher.user = append(usermatcher.user[:index], usermatcher.user[index+1:]...) 116 } 117 break 118 } 119 } 120 } else { 121 if usermatcher, ok := rl.Condition.(*UserMatcher); ok { 122 index := -1 123 for i, e := range usermatcher.user { 124 if e == email { 125 index = i 126 break 127 } 128 } 129 if index != -1 { 130 usermatcher.user = append(usermatcher.user[:index], usermatcher.user[index+1:]...) 131 } 132 } 133 } 134 135 } 136 } 137 138 for index, rl := range r.rules { 139 conditions, ok := rl.Condition.(*ConditionChan) 140 if ok { 141 for _, v := range *conditions { 142 usermatcher, ok := v.(*UserMatcher) 143 if ok { 144 if len(usermatcher.user) == 0 { 145 removed_index = append(removed_index, index) 146 break 147 } 148 149 } 150 } 151 } else { 152 usermatcher, ok := rl.Condition.(*UserMatcher) 153 if ok { 154 if len(usermatcher.user) == 0 { 155 removed_index = append(removed_index, index) 156 } 157 } 158 } 159 160 } 161 162 newRules := make([]*Rule, len(r.rules) - len(removed_index)) 163 m := make(map[int]bool, len(r.rules)) 164 for _, reomve := range removed_index { 165 m[reomve] = true 166 } 167 168 start := 0 169 for index, rl := range r.rules { 170 if !m[index] { 171 newRules[start] = rl 172 start += 1 173 } 174 } 175 176 newtag2indexmap := make(map[string]int, len(newRules)) 177 newindex2tag := make(map[int]string, len(newRules)) 178 for index, rule := range newRules { 179 newtag2indexmap[rule.Tag] = index 180 newindex2tag[index] = rule.Tag 181 } 182 183 r.rules = newRules 184 r.tag2indexmap = newtag2indexmap 185 r.index2tag = newindex2tag 186 runtime.GC() 187 return 188 } 189 190 // Init initializes the Router. 191 func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 192 r.domainStrategy = config.DomainStrategy 193 r.dns = d 194 r.ctx = ctx 195 r.ohm = ohm 196 r.dispatcher = dispatcher 197 198 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 199 r.tag2indexmap = map[string]int{} 200 r.index2tag = map[int]string{} 201 202 for _, rule := range config.BalancingRule { 203 balancer, err := rule.Build(ohm, dispatcher) 204 if err != nil { 205 return err 206 } 207 balancer.InjectContext(ctx) 208 r.balancers[rule.Tag] = balancer 209 } 210 211 r.rules = make([]*Rule, 0, len(config.Rule)) 212 for _, rule := range config.Rule { 213 cond, err := rule.BuildCondition() 214 if err != nil { 215 return err 216 } 217 rr := &Rule{ 218 Condition: cond, 219 Tag: rule.GetTag(), 220 RuleTag: rule.GetRuleTag(), 221 } 222 btag := rule.GetBalancingTag() 223 if len(btag) > 0 { 224 brule, found := r.balancers[btag] 225 if !found { 226 return newError("balancer ", btag, " not found") 227 } 228 rr.Balancer = brule 229 } 230 r.rules = append(r.rules, rr) 231 } 232 233 return nil 234 } 235 236 // PickRoute implements routing.Router. 237 func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) { 238 rule, ctx, err := r.pickRouteInternal(ctx) 239 if err != nil { 240 return nil, err 241 } 242 tag, err := rule.GetTag() 243 if err != nil { 244 return nil, err 245 } 246 return &Route{Context: ctx, outboundTag: tag}, nil 247 } 248 249 // AddRule implements routing.Router. 250 func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error { 251 252 inst, err := config.GetInstance() 253 if err != nil { 254 return err 255 } 256 if c, ok := inst.(*Config); ok { 257 return r.ReloadRules(c, shouldAppend) 258 } 259 return newError("AddRule: config type error") 260 } 261 262 func (r *Router) ReloadRules(config *Config, shouldAppend bool) error { 263 r.mu.Lock() 264 defer r.mu.Unlock() 265 266 if !shouldAppend { 267 r.balancers = make(map[string]*Balancer, len(config.BalancingRule)) 268 r.rules = make([]*Rule, 0, len(config.Rule)) 269 } 270 for _, rule := range config.BalancingRule { 271 _, found := r.balancers[rule.Tag] 272 if found { 273 return newError("duplicate balancer tag") 274 } 275 balancer, err := rule.Build(r.ohm, r.dispatcher) 276 if err != nil { 277 return err 278 } 279 balancer.InjectContext(r.ctx) 280 r.balancers[rule.Tag] = balancer 281 } 282 283 for _, rule := range config.Rule { 284 if r.RuleExists(rule.GetRuleTag()) { 285 return newError("duplicate ruleTag ", rule.GetRuleTag()) 286 } 287 cond, err := rule.BuildCondition() 288 if err != nil { 289 return err 290 } 291 rr := &Rule{ 292 Condition: cond, 293 Tag: rule.GetTag(), 294 RuleTag: rule.GetRuleTag(), 295 } 296 btag := rule.GetBalancingTag() 297 if len(btag) > 0 { 298 brule, found := r.balancers[btag] 299 if !found { 300 return newError("balancer ", btag, " not found") 301 } 302 rr.Balancer = brule 303 } 304 r.rules = append(r.rules, rr) 305 } 306 307 return nil 308 } 309 310 func (r *Router) RuleExists(tag string) bool { 311 if tag != "" { 312 for _, rule := range r.rules { 313 if rule.RuleTag == tag { 314 return true 315 } 316 } 317 } 318 return false 319 } 320 321 // RemoveRule implements routing.Router. 322 func (r *Router) RemoveRule(tag string) error { 323 r.mu.Lock() 324 defer r.mu.Unlock() 325 326 newRules := []*Rule{} 327 if tag != "" { 328 for _, rule := range r.rules { 329 if rule.RuleTag != tag { 330 newRules = append(newRules, rule) 331 } 332 } 333 r.rules = newRules 334 return nil 335 } 336 return newError("empty tag name!") 337 338 } 339 func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) { 340 // SkipDNSResolve is set from DNS module. 341 // the DOH remote server maybe a domain name, 342 // this prevents cycle resolving dead loop 343 skipDNSResolve := ctx.GetSkipDNSResolve() 344 345 if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve { 346 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 347 } 348 349 for _, rule := range r.rules { 350 if rule.Apply(ctx) { 351 return rule, ctx, nil 352 } 353 } 354 355 if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve { 356 return nil, ctx, common.ErrNoClue 357 } 358 359 ctx = routing_dns.ContextWithDNSClient(ctx, r.dns) 360 361 // Try applying rules again if we have IPs. 362 for _, rule := range r.rules { 363 if rule.Apply(ctx) { 364 return rule, ctx, nil 365 } 366 } 367 368 return nil, ctx, common.ErrNoClue 369 } 370 371 // Start implements common.Runnable. 372 func (r *Router) Start() error { 373 return nil 374 } 375 376 // Close implements common.Closable. 377 func (r *Router) Close() error { 378 return nil 379 } 380 381 // Type implements common.HasType. 382 func (*Router) Type() interface{} { 383 return routing.RouterType() 384 } 385 386 // GetOutboundGroupTags implements routing.Route. 387 func (r *Route) GetOutboundGroupTags() []string { 388 return r.outboundGroupTags 389 } 390 391 // GetOutboundTag implements routing.Route. 392 func (r *Route) GetOutboundTag() string { 393 return r.outboundTag 394 } 395 396 func init() { 397 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 398 r := new(Router) 399 if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error { 400 return r.Init(ctx, config.(*Config), d, ohm, dispatcher) 401 }); err != nil { 402 return nil, err 403 } 404 return r, nil 405 })) 406 }