github.com/xmplusdev/xray-core@v1.8.10/app/router/strategy_leastload.go (about) 1 package router 2 3 import ( 4 "context" 5 "math" 6 "sort" 7 "time" 8 9 "github.com/xmplusdev/xray-core/app/observatory" 10 "github.com/xmplusdev/xray-core/common" 11 "github.com/xmplusdev/xray-core/common/dice" 12 "github.com/xmplusdev/xray-core/core" 13 "github.com/xmplusdev/xray-core/features/extension" 14 ) 15 16 // LeastLoadStrategy represents a least load balancing strategy 17 type LeastLoadStrategy struct { 18 settings *StrategyLeastLoadConfig 19 costs *WeightManager 20 21 observer extension.Observatory 22 23 ctx context.Context 24 } 25 26 func (l *LeastLoadStrategy) GetPrincipleTarget(strings []string) []string { 27 var ret []string 28 nodes := l.pickOutbounds(strings) 29 for _, v := range nodes { 30 ret = append(ret, v.Tag) 31 } 32 return ret 33 } 34 35 // NewLeastLoadStrategy creates a new LeastLoadStrategy with settings 36 func NewLeastLoadStrategy(settings *StrategyLeastLoadConfig) *LeastLoadStrategy { 37 return &LeastLoadStrategy{ 38 settings: settings, 39 costs: NewWeightManager( 40 settings.Costs, 1, 41 func(value, cost float64) float64 { 42 return value * math.Pow(cost, 0.5) 43 }, 44 ), 45 } 46 } 47 48 // node is a minimal copy of HealthCheckResult 49 // we don't use HealthCheckResult directly because 50 // it may change by health checker during routing 51 type node struct { 52 Tag string 53 CountAll int 54 CountFail int 55 RTTAverage time.Duration 56 RTTDeviation time.Duration 57 RTTDeviationCost time.Duration 58 } 59 60 func (l *LeastLoadStrategy) InjectContext(ctx context.Context) { 61 l.ctx = ctx 62 } 63 64 func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { 65 selects := s.pickOutbounds(candidates) 66 count := len(selects) 67 if count == 0 { 68 // goes to fallbackTag 69 return "" 70 } 71 return selects[dice.Roll(count)].Tag 72 } 73 74 func (s *LeastLoadStrategy) pickOutbounds(candidates []string) []*node { 75 qualified := s.getNodes(candidates, time.Duration(s.settings.MaxRTT)) 76 selects := s.selectLeastLoad(qualified) 77 return selects 78 } 79 80 // selectLeastLoad selects nodes according to Baselines and Expected Count. 81 // 82 // The strategy always improves network response speed, not matter which mode below is configured. 83 // But they can still have different priorities. 84 // 85 // 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes. 86 // (one if Expected Count <= 0) 87 // 88 // 2. Bandwidth priority advanced: Baselines + Expected Count > 0. 89 // Select `Expected Count` amount of nodes, and also those near them according to baselines. 90 // In other words, it selects according to different Baselines, until one of them matches 91 // the Expected Count, if no Baseline matches, Expected Count applied. 92 // 93 // 3. Speed priority: Baselines + `Expected Count <= 0`. 94 // go through all baselines until find selects, if not, select none. Used in combination 95 // with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback. 96 func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node { 97 if len(nodes) == 0 { 98 newError("least load: no qualified outbound").AtInfo().WriteToLog() 99 return nil 100 } 101 expected := int(s.settings.Expected) 102 availableCount := len(nodes) 103 if expected > availableCount { 104 return nodes 105 } 106 107 if expected <= 0 { 108 expected = 1 109 } 110 if len(s.settings.Baselines) == 0 { 111 return nodes[:expected] 112 } 113 114 count := 0 115 // go through all base line until find expected selects 116 for _, b := range s.settings.Baselines { 117 baseline := time.Duration(b) 118 for i := count; i < availableCount; i++ { 119 if nodes[i].RTTDeviationCost >= baseline { 120 break 121 } 122 count = i + 1 123 } 124 // don't continue if find expected selects 125 if count >= expected { 126 newError("applied baseline: ", baseline).AtDebug().WriteToLog() 127 break 128 } 129 } 130 if s.settings.Expected > 0 && count < expected { 131 count = expected 132 } 133 return nodes[:count] 134 } 135 136 func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node { 137 if s.observer == nil { 138 common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { 139 s.observer = observatory 140 return nil 141 })) 142 } 143 observeResult, err := s.observer.GetObservation(s.ctx) 144 if err != nil { 145 newError("cannot get observation").Base(err).WriteToLog() 146 return make([]*node, 0) 147 } 148 149 results := observeResult.(*observatory.ObservationResult) 150 151 outboundlist := outboundList(candidates) 152 153 var ret []*node 154 155 for _, v := range results.Status { 156 if v.Alive && (v.Delay < maxRTT.Milliseconds() || maxRTT == 0) && outboundlist.contains(v.OutboundTag) { 157 record := &node{ 158 Tag: v.OutboundTag, 159 CountAll: 1, 160 CountFail: 1, 161 RTTAverage: time.Duration(v.Delay) * time.Millisecond, 162 RTTDeviation: time.Duration(v.Delay) * time.Millisecond, 163 RTTDeviationCost: time.Duration(s.costs.Apply(v.OutboundTag, float64(time.Duration(v.Delay)*time.Millisecond))), 164 } 165 166 if v.HealthPing != nil { 167 record.RTTAverage = time.Duration(v.HealthPing.Average) 168 record.RTTDeviation = time.Duration(v.HealthPing.Deviation) 169 record.RTTDeviationCost = time.Duration(s.costs.Apply(v.OutboundTag, float64(v.HealthPing.Deviation))) 170 record.CountAll = int(v.HealthPing.All) 171 record.CountFail = int(v.HealthPing.Fail) 172 173 } 174 ret = append(ret, record) 175 } 176 } 177 178 leastloadSort(ret) 179 return ret 180 } 181 182 func leastloadSort(nodes []*node) { 183 sort.Slice(nodes, func(i, j int) bool { 184 left := nodes[i] 185 right := nodes[j] 186 if left.RTTDeviationCost != right.RTTDeviationCost { 187 return left.RTTDeviationCost < right.RTTDeviationCost 188 } 189 if left.RTTAverage != right.RTTAverage { 190 return left.RTTAverage < right.RTTAverage 191 } 192 if left.CountFail != right.CountFail { 193 return left.CountFail < right.CountFail 194 } 195 if left.CountAll != right.CountAll { 196 return left.CountAll > right.CountAll 197 } 198 return left.Tag < right.Tag 199 }) 200 }