github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/router/strategy_leastload.go (about) 1 package router 2 3 import ( 4 "context" 5 "math" 6 "sort" 7 "time" 8 9 "github.com/golang/protobuf/proto" 10 11 core "github.com/v2fly/v2ray-core/v5" 12 "github.com/v2fly/v2ray-core/v5/app/observatory" 13 "github.com/v2fly/v2ray-core/v5/common" 14 "github.com/v2fly/v2ray-core/v5/common/dice" 15 "github.com/v2fly/v2ray-core/v5/features" 16 "github.com/v2fly/v2ray-core/v5/features/extension" 17 ) 18 19 // LeastLoadStrategy represents a least load balancing strategy 20 type LeastLoadStrategy struct { 21 settings *StrategyLeastLoadConfig 22 costs *WeightManager 23 24 observer extension.Observatory 25 26 ctx context.Context 27 } 28 29 func (l *LeastLoadStrategy) GetPrincipleTarget(strings []string) []string { 30 var ret []string 31 nodes := l.pickOutbounds(strings) 32 for _, v := range nodes { 33 ret = append(ret, v.Tag) 34 } 35 return ret 36 } 37 38 // NewLeastLoadStrategy creates a new LeastLoadStrategy with settings 39 func NewLeastLoadStrategy(settings *StrategyLeastLoadConfig) *LeastLoadStrategy { 40 return &LeastLoadStrategy{ 41 settings: settings, 42 costs: NewWeightManager( 43 settings.Costs, 1, 44 func(value, cost float64) float64 { 45 return value * math.Pow(cost, 0.5) 46 }, 47 ), 48 } 49 } 50 51 // node is a minimal copy of HealthCheckResult 52 // we don't use HealthCheckResult directly because 53 // it may change by health checker during routing 54 type node struct { 55 Tag string 56 CountAll int 57 CountFail int 58 RTTAverage time.Duration 59 RTTDeviation time.Duration 60 RTTDeviationCost time.Duration 61 } 62 63 func (l *LeastLoadStrategy) InjectContext(ctx context.Context) { 64 l.ctx = ctx 65 } 66 67 func (l *LeastLoadStrategy) PickOutbound(candidates []string) string { 68 selects := l.pickOutbounds(candidates) 69 count := len(selects) 70 if count == 0 { 71 // goes to fallbackTag 72 return "" 73 } 74 return selects[dice.Roll(count)].Tag 75 } 76 77 func (l *LeastLoadStrategy) pickOutbounds(candidates []string) []*node { 78 qualified := l.getNodes(candidates, time.Duration(l.settings.MaxRTT)) 79 selects := l.selectLeastLoad(qualified) 80 return selects 81 } 82 83 // selectLeastLoad selects nodes according to Baselines and Expected Count. 84 // 85 // The strategy always improves network response speed, not matter which mode below is configured. 86 // But they can still have different priorities. 87 // 88 // 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes. 89 // (one if Expected Count <= 0) 90 // 91 // 2. Bandwidth priority advanced: Baselines + Expected Count > 0. 92 // Select `Expected Count` amount of nodes, and also those near them according to baselines. 93 // In other words, it selects according to different Baselines, until one of them matches 94 // the Expected Count, if no Baseline matches, Expected Count applied. 95 // 96 // 3. Speed priority: Baselines + `Expected Count <= 0`. 97 // go through all baselines until find selects, if not, select none. Used in combination 98 // with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback. 99 func (l *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node { 100 if len(nodes) == 0 { 101 newError("least load: no qualified outbound").AtInfo().WriteToLog() 102 return nil 103 } 104 expected := int(l.settings.Expected) 105 availableCount := len(nodes) 106 if expected > availableCount { 107 return nodes 108 } 109 110 if expected <= 0 { 111 expected = 1 112 } 113 if len(l.settings.Baselines) == 0 { 114 return nodes[:expected] 115 } 116 117 count := 0 118 // go through all base line until find expected selects 119 for _, b := range l.settings.Baselines { 120 baseline := time.Duration(b) 121 for i := count; i < availableCount; i++ { 122 if nodes[i].RTTDeviationCost >= baseline { 123 break 124 } 125 count = i + 1 126 } 127 // don't continue if find expected selects 128 if count >= expected { 129 newError("applied baseline: ", baseline).AtDebug().WriteToLog() 130 break 131 } 132 } 133 if l.settings.Expected > 0 && count < expected { 134 count = expected 135 } 136 return nodes[:count] 137 } 138 139 func (l *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node { 140 if l.observer == nil { 141 common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { 142 l.observer = observatory 143 return nil 144 })) 145 } 146 147 var result proto.Message 148 if l.settings.ObserverTag == "" { 149 observeResult, err := l.observer.GetObservation(l.ctx) 150 if err != nil { 151 newError("cannot get observation").Base(err).WriteToLog() 152 return make([]*node, 0) 153 } 154 result = observeResult 155 } else { 156 observeResult, err := common.Must2(l.observer.(features.TaggedFeatures).GetFeaturesByTag(l.settings.ObserverTag)).(extension.Observatory).GetObservation(l.ctx) 157 if err != nil { 158 newError("cannot get observation").Base(err).WriteToLog() 159 return make([]*node, 0) 160 } 161 result = observeResult 162 } 163 164 results := result.(*observatory.ObservationResult) 165 166 outboundlist := outboundList(candidates) 167 168 var ret []*node 169 170 for _, v := range results.Status { 171 if v.Alive && (v.Delay < maxRTT.Milliseconds() || maxRTT == 0) && outboundlist.contains(v.OutboundTag) { 172 record := &node{ 173 Tag: v.OutboundTag, 174 CountAll: 1, 175 CountFail: 1, 176 RTTAverage: time.Duration(v.Delay) * time.Millisecond, 177 RTTDeviation: time.Duration(v.Delay) * time.Millisecond, 178 RTTDeviationCost: time.Duration(l.costs.Apply(v.OutboundTag, float64(time.Duration(v.Delay)*time.Millisecond))), 179 } 180 181 if v.HealthPing != nil { 182 record.RTTAverage = time.Duration(v.HealthPing.Average) 183 record.RTTDeviation = time.Duration(v.HealthPing.Deviation) 184 record.RTTDeviationCost = time.Duration(l.costs.Apply(v.OutboundTag, float64(v.HealthPing.Deviation))) 185 record.CountAll = int(v.HealthPing.All) 186 record.CountFail = int(v.HealthPing.Fail) 187 } 188 ret = append(ret, record) 189 } 190 } 191 192 leastloadSort(ret) 193 return ret 194 } 195 196 func leastloadSort(nodes []*node) { 197 sort.Slice(nodes, func(i, j int) bool { 198 left := nodes[i] 199 right := nodes[j] 200 if left.RTTDeviationCost != right.RTTDeviationCost { 201 return left.RTTDeviationCost < right.RTTDeviationCost 202 } 203 if left.RTTAverage != right.RTTAverage { 204 return left.RTTAverage < right.RTTAverage 205 } 206 if left.CountFail != right.CountFail { 207 return left.CountFail < right.CountFail 208 } 209 if left.CountAll != right.CountAll { 210 return left.CountAll > right.CountAll 211 } 212 return left.Tag < right.Tag 213 }) 214 } 215 216 func init() { 217 common.Must(common.RegisterConfig((*StrategyLeastLoadConfig)(nil), nil)) 218 }