github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/router/router.go (about)

     1  package router
     2  
     3  //go:generate go run github.com/xmplusdev/xmcore/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	sync "sync"
     8  	"runtime"
     9  	"sort"
    10  
    11  	"github.com/xmplusdev/xmcore/common"
    12  	"github.com/xmplusdev/xmcore/common/serial"
    13  	"github.com/xmplusdev/xmcore/core"
    14  	"github.com/xmplusdev/xmcore/features/dns"
    15  	"github.com/xmplusdev/xmcore/features/outbound"
    16  	"github.com/xmplusdev/xmcore/features/routing"
    17  	routing_dns "github.com/xmplusdev/xmcore/features/routing/dns"
    18  )
    19  
// Router is an implementation of routing.Router.
type Router struct {
	// domainStrategy controls when pickRouteInternal attaches a DNS client to
	// the routing context (IpOnDemand up front, IpIfNonMatch on a second pass).
	domainStrategy Config_DomainStrategy
	// rules are tried in order by pickRouteInternal; the first one whose
	// condition applies wins.
	rules []*Rule
	// balancers holds the built balancers keyed by their balancing-rule tag;
	// rules reference them through their balancing tag.
	balancers map[string]*Balancer
	// dns is handed to routing_dns.ContextWithDNSClient when domain
	// resolution is needed for matching.
	dns dns.Client

	// ctx is injected into every balancer built by Init/ReloadRules.
	ctx context.Context
	// ohm and dispatcher are passed to BalancingRule.Build.
	ohm        outbound.Manager
	dispatcher routing.Dispatcher
	// mu is taken by the methods that change the rule set at runtime
	// (AddUsers, RemoveUsers, ReloadRules, RemoveRule).
	mu sync.Mutex
	// tag2indexmap maps a tag passed to AddUsers to the index in rules of the
	// rule AddUsers appended for it; index2tag is the inverse mapping.
	// NOTE(review): these indexes are only meaningful while rules keeps its
	// positions — any code path that rebuilds rules must keep them in sync.
	tag2indexmap map[string]int
	index2tag    map[int]string
}
    34  
// Route is an implementation of routing.Route.
//
// It embeds the routing.Context the route was picked with (so the context's
// methods are promoted onto the Route) and records the selected outbound.
type Route struct {
	routing.Context
	// outboundGroupTags is returned by GetOutboundGroupTags.
	outboundGroupTags []string
	// outboundTag is the outbound chosen for this route, returned by GetOutboundTag.
	outboundTag string
}
    41  
    42  func NewRouter() *Router {
    43  	con := NewConditionChan()
    44  	con.Add(NewInboundTagMatcher([]string{"asdf"}))
    45  	con.Add(NewProtocolMatcher([]string{"tls"}))
    46  	con.Add(NewUserMatcher([]string{"bge"}))
    47  	return &Router{
    48  		domainStrategy:     Config_AsIs,
    49  		rules:              []*Rule{&Rule{Condition: con}},
    50  		balancers:          map[string]*Balancer{},
    51  		tag2indexmap: map[string]int{},
    52  		index2tag:    map[int]string{},
    53  	}
    54  }
    55  
    56  func Romvededuplicate(users []string) []string {
    57  	sort.Strings(users)
    58  	j := 0
    59  	for i := 1; i < len(users); i++ {
    60  		if users[j] == users[i] {
    61  			continue
    62  		}
    63  		j++
    64  		// preserve the original data
    65  		// in[i], in[j] = in[j], in[i]
    66  		// only set what is required
    67  		users[j] = users[i]
    68  	}
    69  	return users[:j+1]
    70  }
    71  
    72  func (r *Router) AddUsers(tag string, emails []string) {
    73  	r.mu.Lock()
    74  	defer r.mu.Unlock()
    75  	if index, ok := r.tag2indexmap[tag]; ok {
    76  		if conditioncan, ok := r.rules[index].Condition.(*ConditionChan); ok {
    77  			for _, condition := range *conditioncan {
    78  				if usermatcher, ok := condition.(*UserMatcher); ok {
    79  					usermatcher.user = Romvededuplicate(append(usermatcher.user, emails...))
    80  					break
    81  				}
    82  			}
    83  		} else if usermatcher, ok := r.rules[index].Condition.(*UserMatcher); ok {
    84  			usermatcher.user = Romvededuplicate(append(usermatcher.user, emails...))
    85  
    86  		}
    87  	} else {
    88  		tagStartIndex := len(r.rules)
    89  		r.tag2indexmap[tag] = tagStartIndex
    90  		r.index2tag[tagStartIndex] = tag
    91  		r.rules = append(r.rules, &Rule{Condition: NewUserMatcher(emails), Tag: tag})
    92  	}
    93  	runtime.GC()
    94  }
    95  
    96  func (r *Router) RemoveUsers(Users []string) {
    97  	r.mu.Lock()
    98  	defer r.mu.Unlock()
    99  	removed_index := make([]int, 0, len(r.rules))
   100  	for _, email := range Users {
   101  		for _, rl := range r.rules {
   102  			conditions, ok := rl.Condition.(*ConditionChan)
   103  			if ok {
   104  				for _, v := range *conditions {
   105  					usermatcher, ok := v.(*UserMatcher)
   106  					if ok {
   107  						index := -1
   108  						for i, e := range usermatcher.user {
   109  							if e == email {
   110  								index = i
   111  								break
   112  							}
   113  						}
   114  						if index != -1 {
   115  							usermatcher.user = append(usermatcher.user[:index], usermatcher.user[index+1:]...)
   116  						}
   117  						break
   118  					}
   119  				}
   120  			} else {
   121  				if usermatcher, ok := rl.Condition.(*UserMatcher); ok {
   122  					index := -1
   123  					for i, e := range usermatcher.user {
   124  						if e == email {
   125  							index = i
   126  							break
   127  						}
   128  					}
   129  					if index != -1 {
   130  						usermatcher.user = append(usermatcher.user[:index], usermatcher.user[index+1:]...)
   131  					}
   132  				}
   133  			}
   134  
   135  		}
   136  	}
   137  	
   138  	for index, rl := range r.rules {
   139  		conditions, ok := rl.Condition.(*ConditionChan)
   140  		if ok {
   141  			for _, v := range *conditions {
   142  				usermatcher, ok := v.(*UserMatcher)
   143  				if ok {
   144  					if len(usermatcher.user) == 0 {
   145  						removed_index = append(removed_index, index)
   146  						break
   147  					}
   148  
   149  				}
   150  			}
   151  		} else {
   152  			usermatcher, ok := rl.Condition.(*UserMatcher)
   153  			if ok {
   154  				if len(usermatcher.user) == 0 {
   155  					removed_index = append(removed_index, index)
   156  				}
   157  			}
   158  		}
   159  
   160  	} 
   161  	
   162  	newRules := make([]*Rule, len(r.rules) - len(removed_index))
   163  	m := make(map[int]bool, len(r.rules))
   164  	for _, reomve := range removed_index {
   165  		m[reomve] = true
   166  	}
   167  	
   168  	start := 0
   169  	for index, rl := range r.rules {
   170  		if !m[index] {
   171  			newRules[start] = rl
   172  			start += 1
   173  		}
   174  	}
   175  	
   176  	newtag2indexmap := make(map[string]int, len(newRules))
   177  	newindex2tag := make(map[int]string, len(newRules))
   178  	for index, rule := range newRules {
   179  		newtag2indexmap[rule.Tag] = index
   180  		newindex2tag[index] = rule.Tag
   181  	}
   182  	
   183  	r.rules = newRules
   184  	r.tag2indexmap = newtag2indexmap
   185  	r.index2tag = newindex2tag
   186  	runtime.GC()
   187  	return
   188  }
   189  
   190  // Init initializes the Router.
   191  func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
   192  	r.domainStrategy = config.DomainStrategy
   193  	r.dns = d
   194  	r.ctx = ctx
   195  	r.ohm = ohm
   196  	r.dispatcher = dispatcher
   197  	
   198  	r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
   199  	r.tag2indexmap = map[string]int{}
   200  	r.index2tag = map[int]string{}
   201  	
   202  	for _, rule := range config.BalancingRule {
   203  		balancer, err := rule.Build(ohm, dispatcher)
   204  		if err != nil {
   205  			return err
   206  		}
   207  		balancer.InjectContext(ctx)
   208  		r.balancers[rule.Tag] = balancer
   209  	}
   210  
   211  	r.rules = make([]*Rule, 0, len(config.Rule))
   212  	for _, rule := range config.Rule {
   213  		cond, err := rule.BuildCondition()
   214  		if err != nil {
   215  			return err
   216  		}
   217  		rr := &Rule{
   218  			Condition: cond,
   219  			Tag:       rule.GetTag(),
   220  			RuleTag:   rule.GetRuleTag(),
   221  		}
   222  		btag := rule.GetBalancingTag()
   223  		if len(btag) > 0 {
   224  			brule, found := r.balancers[btag]
   225  			if !found {
   226  				return newError("balancer ", btag, " not found")
   227  			}
   228  			rr.Balancer = brule
   229  		}
   230  		r.rules = append(r.rules, rr)
   231  	}
   232  
   233  	return nil
   234  }
   235  
   236  // PickRoute implements routing.Router.
   237  func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
   238  	rule, ctx, err := r.pickRouteInternal(ctx)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	tag, err := rule.GetTag()
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  	return &Route{Context: ctx, outboundTag: tag}, nil
   247  }
   248  
   249  // AddRule implements routing.Router.
   250  func (r *Router) AddRule(config *serial.TypedMessage, shouldAppend bool) error {
   251  
   252  	inst, err := config.GetInstance()
   253  	if err != nil {
   254  		return err
   255  	}
   256  	if c, ok := inst.(*Config); ok {
   257  		return r.ReloadRules(c, shouldAppend)
   258  	}
   259  	return newError("AddRule: config type error")
   260  }
   261  
   262  func (r *Router) ReloadRules(config *Config, shouldAppend bool) error {
   263  	r.mu.Lock()
   264  	defer r.mu.Unlock()
   265  
   266  	if !shouldAppend {
   267  		r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
   268  		r.rules = make([]*Rule, 0, len(config.Rule))
   269  	}
   270  	for _, rule := range config.BalancingRule {
   271  		_, found := r.balancers[rule.Tag]
   272  		if found {
   273  			return newError("duplicate balancer tag")
   274  		}
   275  		balancer, err := rule.Build(r.ohm, r.dispatcher)
   276  		if err != nil {
   277  			return err
   278  		}
   279  		balancer.InjectContext(r.ctx)
   280  		r.balancers[rule.Tag] = balancer
   281  	}
   282  
   283  	for _, rule := range config.Rule {
   284  		if r.RuleExists(rule.GetRuleTag()) {
   285  			return newError("duplicate ruleTag ", rule.GetRuleTag())
   286  		}
   287  		cond, err := rule.BuildCondition()
   288  		if err != nil {
   289  			return err
   290  		}
   291  		rr := &Rule{
   292  			Condition: cond,
   293  			Tag:       rule.GetTag(),
   294  			RuleTag:   rule.GetRuleTag(),
   295  		}
   296  		btag := rule.GetBalancingTag()
   297  		if len(btag) > 0 {
   298  			brule, found := r.balancers[btag]
   299  			if !found {
   300  				return newError("balancer ", btag, " not found")
   301  			}
   302  			rr.Balancer = brule
   303  		}
   304  		r.rules = append(r.rules, rr)
   305  	}
   306  
   307  	return nil
   308  }
   309  
   310  func (r *Router) RuleExists(tag string) bool {
   311  	if tag != "" {
   312  		for _, rule := range r.rules {
   313  			if rule.RuleTag == tag {
   314  				return true
   315  			}
   316  		}
   317  	}
   318  	return false
   319  }
   320  
   321  // RemoveRule implements routing.Router.
   322  func (r *Router) RemoveRule(tag string) error {
   323  	r.mu.Lock()
   324  	defer r.mu.Unlock()
   325  
   326  	newRules := []*Rule{}
   327  	if tag != "" {
   328  		for _, rule := range r.rules {
   329  			if rule.RuleTag != tag {
   330  				newRules = append(newRules, rule)
   331  			}
   332  		}
   333  		r.rules = newRules
   334  		return nil
   335  	}
   336  	return newError("empty tag name!")
   337  
   338  }
   339  func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
   340  	// SkipDNSResolve is set from DNS module.
   341  	// the DOH remote server maybe a domain name,
   342  	// this prevents cycle resolving dead loop
   343  	skipDNSResolve := ctx.GetSkipDNSResolve()
   344  
   345  	if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
   346  		ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   347  	}
   348  
   349  	for _, rule := range r.rules {
   350  		if rule.Apply(ctx) {
   351  			return rule, ctx, nil
   352  		}
   353  	}
   354  
   355  	if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
   356  		return nil, ctx, common.ErrNoClue
   357  	}
   358  
   359  	ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
   360  
   361  	// Try applying rules again if we have IPs.
   362  	for _, rule := range r.rules {
   363  		if rule.Apply(ctx) {
   364  			return rule, ctx, nil
   365  		}
   366  	}
   367  
   368  	return nil, ctx, common.ErrNoClue
   369  }
   370  
   371  // Start implements common.Runnable.
   372  func (r *Router) Start() error {
   373  	return nil
   374  }
   375  
   376  // Close implements common.Closable.
   377  func (r *Router) Close() error {
   378  	return nil
   379  }
   380  
   381  // Type implements common.HasType.
   382  func (*Router) Type() interface{} {
   383  	return routing.RouterType()
   384  }
   385  
   386  // GetOutboundGroupTags implements routing.Route.
   387  func (r *Route) GetOutboundGroupTags() []string {
   388  	return r.outboundGroupTags
   389  }
   390  
   391  // GetOutboundTag implements routing.Route.
   392  func (r *Route) GetOutboundTag() string {
   393  	return r.outboundTag
   394  }
   395  
   396  func init() {
   397  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   398  		r := new(Router)
   399  		if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager, dispatcher routing.Dispatcher) error {
   400  			return r.Init(ctx, config.(*Config), d, ohm, dispatcher)
   401  		}); err != nil {
   402  			return nil, err
   403  		}
   404  		return r, nil
   405  	}))
   406  }