github.com/xmplusdev/xray-core@v1.8.10/app/router/strategy_leastload.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"sort"
     7  	"time"
     8  
     9  	"github.com/xmplusdev/xray-core/app/observatory"
    10  	"github.com/xmplusdev/xray-core/common"
    11  	"github.com/xmplusdev/xray-core/common/dice"
    12  	"github.com/xmplusdev/xray-core/core"
    13  	"github.com/xmplusdev/xray-core/features/extension"
    14  )
    15  
    16  // LeastLoadStrategy represents a least load balancing strategy
    17  type LeastLoadStrategy struct {
    18  	settings *StrategyLeastLoadConfig
    19  	costs    *WeightManager
    20  
    21  	observer extension.Observatory
    22  
    23  	ctx context.Context
    24  }
    25  
    26  func (l *LeastLoadStrategy) GetPrincipleTarget(strings []string) []string {
    27  	var ret []string
    28  	nodes := l.pickOutbounds(strings)
    29  	for _, v := range nodes {
    30  		ret = append(ret, v.Tag)
    31  	}
    32  	return ret
    33  }
    34  
    35  // NewLeastLoadStrategy creates a new LeastLoadStrategy with settings
    36  func NewLeastLoadStrategy(settings *StrategyLeastLoadConfig) *LeastLoadStrategy {
    37  	return &LeastLoadStrategy{
    38  		settings: settings,
    39  		costs: NewWeightManager(
    40  			settings.Costs, 1,
    41  			func(value, cost float64) float64 {
    42  				return value * math.Pow(cost, 0.5)
    43  			},
    44  		),
    45  	}
    46  }
    47  
    48  // node is a minimal copy of HealthCheckResult
    49  // we don't use HealthCheckResult directly because
    50  // it may change by health checker during routing
    51  type node struct {
    52  	Tag              string
    53  	CountAll         int
    54  	CountFail        int
    55  	RTTAverage       time.Duration
    56  	RTTDeviation     time.Duration
    57  	RTTDeviationCost time.Duration
    58  }
    59  
    60  func (l *LeastLoadStrategy) InjectContext(ctx context.Context) {
    61  	l.ctx = ctx
    62  }
    63  
    64  func (s *LeastLoadStrategy) PickOutbound(candidates []string) string {
    65  	selects := s.pickOutbounds(candidates)
    66  	count := len(selects)
    67  	if count == 0 {
    68  		// goes to fallbackTag
    69  		return ""
    70  	}
    71  	return selects[dice.Roll(count)].Tag
    72  }
    73  
    74  func (s *LeastLoadStrategy) pickOutbounds(candidates []string) []*node {
    75  	qualified := s.getNodes(candidates, time.Duration(s.settings.MaxRTT))
    76  	selects := s.selectLeastLoad(qualified)
    77  	return selects
    78  }
    79  
    80  // selectLeastLoad selects nodes according to Baselines and Expected Count.
    81  //
    82  // The strategy always improves network response speed, not matter which mode below is configured.
    83  // But they can still have different priorities.
    84  //
    85  // 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes.
    86  // (one if Expected Count <= 0)
    87  //
    88  // 2. Bandwidth priority advanced: Baselines + Expected Count > 0.
    89  // Select `Expected Count` amount of nodes, and also those near them according to baselines.
    90  // In other words, it selects according to different Baselines, until one of them matches
    91  // the Expected Count, if no Baseline matches, Expected Count applied.
    92  //
    93  // 3. Speed priority: Baselines + `Expected Count <= 0`.
    94  // go through all baselines until find selects, if not, select none. Used in combination
    95  // with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
    96  func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node {
    97  	if len(nodes) == 0 {
    98  		newError("least load: no qualified outbound").AtInfo().WriteToLog()
    99  		return nil
   100  	}
   101  	expected := int(s.settings.Expected)
   102  	availableCount := len(nodes)
   103  	if expected > availableCount {
   104  		return nodes
   105  	}
   106  
   107  	if expected <= 0 {
   108  		expected = 1
   109  	}
   110  	if len(s.settings.Baselines) == 0 {
   111  		return nodes[:expected]
   112  	}
   113  
   114  	count := 0
   115  	// go through all base line until find expected selects
   116  	for _, b := range s.settings.Baselines {
   117  		baseline := time.Duration(b)
   118  		for i := count; i < availableCount; i++ {
   119  			if nodes[i].RTTDeviationCost >= baseline {
   120  				break
   121  			}
   122  			count = i + 1
   123  		}
   124  		// don't continue if find expected selects
   125  		if count >= expected {
   126  			newError("applied baseline: ", baseline).AtDebug().WriteToLog()
   127  			break
   128  		}
   129  	}
   130  	if s.settings.Expected > 0 && count < expected {
   131  		count = expected
   132  	}
   133  	return nodes[:count]
   134  }
   135  
   136  func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node {
   137  	if s.observer == nil {
   138  		common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
   139  			s.observer = observatory
   140  			return nil
   141  		}))
   142  	}
   143  	observeResult, err := s.observer.GetObservation(s.ctx)
   144  	if err != nil {
   145  		newError("cannot get observation").Base(err).WriteToLog()
   146  		return make([]*node, 0)
   147  	}
   148  
   149  	results := observeResult.(*observatory.ObservationResult)
   150  
   151  	outboundlist := outboundList(candidates)
   152  
   153  	var ret []*node
   154  
   155  	for _, v := range results.Status {
   156  		if v.Alive && (v.Delay < maxRTT.Milliseconds() || maxRTT == 0) && outboundlist.contains(v.OutboundTag) {
   157  			record := &node{
   158  				Tag:              v.OutboundTag,
   159  				CountAll:         1,
   160  				CountFail:        1,
   161  				RTTAverage:       time.Duration(v.Delay) * time.Millisecond,
   162  				RTTDeviation:     time.Duration(v.Delay) * time.Millisecond,
   163  				RTTDeviationCost: time.Duration(s.costs.Apply(v.OutboundTag, float64(time.Duration(v.Delay)*time.Millisecond))),
   164  			}
   165  
   166  			if v.HealthPing != nil {
   167  				record.RTTAverage = time.Duration(v.HealthPing.Average)
   168  				record.RTTDeviation = time.Duration(v.HealthPing.Deviation)
   169  				record.RTTDeviationCost = time.Duration(s.costs.Apply(v.OutboundTag, float64(v.HealthPing.Deviation)))
   170  				record.CountAll = int(v.HealthPing.All)
   171  				record.CountFail = int(v.HealthPing.Fail)
   172  
   173  			}
   174  			ret = append(ret, record)
   175  		}
   176  	}
   177  
   178  	leastloadSort(ret)
   179  	return ret
   180  }
   181  
   182  func leastloadSort(nodes []*node) {
   183  	sort.Slice(nodes, func(i, j int) bool {
   184  		left := nodes[i]
   185  		right := nodes[j]
   186  		if left.RTTDeviationCost != right.RTTDeviationCost {
   187  			return left.RTTDeviationCost < right.RTTDeviationCost
   188  		}
   189  		if left.RTTAverage != right.RTTAverage {
   190  			return left.RTTAverage < right.RTTAverage
   191  		}
   192  		if left.CountFail != right.CountFail {
   193  			return left.CountFail < right.CountFail
   194  		}
   195  		if left.CountAll != right.CountAll {
   196  			return left.CountAll > right.CountAll
   197  		}
   198  		return left.Tag < right.Tag
   199  	})
   200  }