github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/router/strategy_leastload.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"sort"
     7  	"time"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  
    11  	core "github.com/v2fly/v2ray-core/v5"
    12  	"github.com/v2fly/v2ray-core/v5/app/observatory"
    13  	"github.com/v2fly/v2ray-core/v5/common"
    14  	"github.com/v2fly/v2ray-core/v5/common/dice"
    15  	"github.com/v2fly/v2ray-core/v5/features"
    16  	"github.com/v2fly/v2ray-core/v5/features/extension"
    17  )
    18  
// LeastLoadStrategy represents a least load balancing strategy: among the
// candidate outbounds it prefers those with the lowest RTT deviation cost, as
// reported by an observatory.
type LeastLoadStrategy struct {
	// settings is the strategy configuration (MaxRTT, Expected, Baselines,
	// Costs, ObserverTag).
	settings *StrategyLeastLoadConfig
	// costs applies the configured cost weights to measured RTT values,
	// keyed by outbound tag.
	costs    *WeightManager

	// observer is resolved lazily from ctx the first time nodes are gathered.
	observer extension.Observatory

	// ctx is supplied through InjectContext; it is used to resolve the
	// observatory and to request observations.
	ctx context.Context
}
    28  
    29  func (l *LeastLoadStrategy) GetPrincipleTarget(strings []string) []string {
    30  	var ret []string
    31  	nodes := l.pickOutbounds(strings)
    32  	for _, v := range nodes {
    33  		ret = append(ret, v.Tag)
    34  	}
    35  	return ret
    36  }
    37  
    38  // NewLeastLoadStrategy creates a new LeastLoadStrategy with settings
    39  func NewLeastLoadStrategy(settings *StrategyLeastLoadConfig) *LeastLoadStrategy {
    40  	return &LeastLoadStrategy{
    41  		settings: settings,
    42  		costs: NewWeightManager(
    43  			settings.Costs, 1,
    44  			func(value, cost float64) float64 {
    45  				return value * math.Pow(cost, 0.5)
    46  			},
    47  		),
    48  	}
    49  }
    50  
// node is a minimal copy of HealthCheckResult.
// HealthCheckResult is not used directly because it may be changed by the
// health checker while routing is in progress.
type node struct {
	// Tag is the outbound tag this node stands for.
	Tag              string
	// CountAll and CountFail come from the health ping's All and Fail counters
	// when that data is available; otherwise getNodes sets both to 1.
	CountAll         int
	CountFail        int
	// RTTAverage and RTTDeviation come from the health ping when available;
	// otherwise getNodes uses the last measured delay for both.
	RTTAverage       time.Duration
	RTTDeviation     time.Duration
	// RTTDeviationCost is the deviation (or delay) after applying the
	// outbound's cost weight; it is the primary key in leastloadSort.
	RTTDeviationCost time.Duration
}
    62  
    63  func (l *LeastLoadStrategy) InjectContext(ctx context.Context) {
    64  	l.ctx = ctx
    65  }
    66  
    67  func (l *LeastLoadStrategy) PickOutbound(candidates []string) string {
    68  	selects := l.pickOutbounds(candidates)
    69  	count := len(selects)
    70  	if count == 0 {
    71  		// goes to fallbackTag
    72  		return ""
    73  	}
    74  	return selects[dice.Roll(count)].Tag
    75  }
    76  
    77  func (l *LeastLoadStrategy) pickOutbounds(candidates []string) []*node {
    78  	qualified := l.getNodes(candidates, time.Duration(l.settings.MaxRTT))
    79  	selects := l.selectLeastLoad(qualified)
    80  	return selects
    81  }
    82  
    83  // selectLeastLoad selects nodes according to Baselines and Expected Count.
    84  //
    85  // The strategy always improves network response speed, not matter which mode below is configured.
    86  // But they can still have different priorities.
    87  //
    88  // 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes.
    89  // (one if Expected Count <= 0)
    90  //
    91  // 2. Bandwidth priority advanced: Baselines + Expected Count > 0.
    92  // Select `Expected Count` amount of nodes, and also those near them according to baselines.
    93  // In other words, it selects according to different Baselines, until one of them matches
    94  // the Expected Count, if no Baseline matches, Expected Count applied.
    95  //
    96  // 3. Speed priority: Baselines + `Expected Count <= 0`.
    97  // go through all baselines until find selects, if not, select none. Used in combination
    98  // with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
    99  func (l *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node {
   100  	if len(nodes) == 0 {
   101  		newError("least load: no qualified outbound").AtInfo().WriteToLog()
   102  		return nil
   103  	}
   104  	expected := int(l.settings.Expected)
   105  	availableCount := len(nodes)
   106  	if expected > availableCount {
   107  		return nodes
   108  	}
   109  
   110  	if expected <= 0 {
   111  		expected = 1
   112  	}
   113  	if len(l.settings.Baselines) == 0 {
   114  		return nodes[:expected]
   115  	}
   116  
   117  	count := 0
   118  	// go through all base line until find expected selects
   119  	for _, b := range l.settings.Baselines {
   120  		baseline := time.Duration(b)
   121  		for i := count; i < availableCount; i++ {
   122  			if nodes[i].RTTDeviationCost >= baseline {
   123  				break
   124  			}
   125  			count = i + 1
   126  		}
   127  		// don't continue if find expected selects
   128  		if count >= expected {
   129  			newError("applied baseline: ", baseline).AtDebug().WriteToLog()
   130  			break
   131  		}
   132  	}
   133  	if l.settings.Expected > 0 && count < expected {
   134  		count = expected
   135  	}
   136  	return nodes[:count]
   137  }
   138  
   139  func (l *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node {
   140  	if l.observer == nil {
   141  		common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error {
   142  			l.observer = observatory
   143  			return nil
   144  		}))
   145  	}
   146  
   147  	var result proto.Message
   148  	if l.settings.ObserverTag == "" {
   149  		observeResult, err := l.observer.GetObservation(l.ctx)
   150  		if err != nil {
   151  			newError("cannot get observation").Base(err).WriteToLog()
   152  			return make([]*node, 0)
   153  		}
   154  		result = observeResult
   155  	} else {
   156  		observeResult, err := common.Must2(l.observer.(features.TaggedFeatures).GetFeaturesByTag(l.settings.ObserverTag)).(extension.Observatory).GetObservation(l.ctx)
   157  		if err != nil {
   158  			newError("cannot get observation").Base(err).WriteToLog()
   159  			return make([]*node, 0)
   160  		}
   161  		result = observeResult
   162  	}
   163  
   164  	results := result.(*observatory.ObservationResult)
   165  
   166  	outboundlist := outboundList(candidates)
   167  
   168  	var ret []*node
   169  
   170  	for _, v := range results.Status {
   171  		if v.Alive && (v.Delay < maxRTT.Milliseconds() || maxRTT == 0) && outboundlist.contains(v.OutboundTag) {
   172  			record := &node{
   173  				Tag:              v.OutboundTag,
   174  				CountAll:         1,
   175  				CountFail:        1,
   176  				RTTAverage:       time.Duration(v.Delay) * time.Millisecond,
   177  				RTTDeviation:     time.Duration(v.Delay) * time.Millisecond,
   178  				RTTDeviationCost: time.Duration(l.costs.Apply(v.OutboundTag, float64(time.Duration(v.Delay)*time.Millisecond))),
   179  			}
   180  
   181  			if v.HealthPing != nil {
   182  				record.RTTAverage = time.Duration(v.HealthPing.Average)
   183  				record.RTTDeviation = time.Duration(v.HealthPing.Deviation)
   184  				record.RTTDeviationCost = time.Duration(l.costs.Apply(v.OutboundTag, float64(v.HealthPing.Deviation)))
   185  				record.CountAll = int(v.HealthPing.All)
   186  				record.CountFail = int(v.HealthPing.Fail)
   187  			}
   188  			ret = append(ret, record)
   189  		}
   190  	}
   191  
   192  	leastloadSort(ret)
   193  	return ret
   194  }
   195  
   196  func leastloadSort(nodes []*node) {
   197  	sort.Slice(nodes, func(i, j int) bool {
   198  		left := nodes[i]
   199  		right := nodes[j]
   200  		if left.RTTDeviationCost != right.RTTDeviationCost {
   201  			return left.RTTDeviationCost < right.RTTDeviationCost
   202  		}
   203  		if left.RTTAverage != right.RTTAverage {
   204  			return left.RTTAverage < right.RTTAverage
   205  		}
   206  		if left.CountFail != right.CountFail {
   207  			return left.CountFail < right.CountFail
   208  		}
   209  		if left.CountAll != right.CountAll {
   210  			return left.CountAll > right.CountAll
   211  		}
   212  		return left.Tag < right.Tag
   213  	})
   214  }
   215  
   216  func init() {
   217  	common.Must(common.RegisterConfig((*StrategyLeastLoadConfig)(nil), nil))
   218  }