github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/router/balancing.go (about)

     1  package router
     2  
     3  import (
     4  	"context"
     5  	sync "sync"
     6  
     7  	"github.com/xmplusdev/xmcore/features/extension"
     8  	"github.com/xmplusdev/xmcore/features/outbound"
     9  )
    10  
    11  type BalancingStrategy interface {
    12  	PickOutbound([]string) string
    13  }
    14  
    15  type BalancingPrincipleTarget interface {
    16  	GetPrincipleTarget([]string) []string
    17  }
    18  
    19  type RoundRobinStrategy struct {
    20  	mu    sync.Mutex
    21  	index int
    22  }
    23  
    24  func (s *RoundRobinStrategy) PickOutbound(tags []string) string {
    25  	n := len(tags)
    26  	if n == 0 {
    27  		panic("0 tags")
    28  	}
    29  
    30  	s.mu.Lock()
    31  	defer s.mu.Unlock()
    32  	tag := tags[s.index%n]
    33  	s.index = (s.index + 1) % n
    34  	return tag
    35  }
    36  
// Balancer picks an outbound tag from the outbounds matched by its selectors,
// delegating the choice to a BalancingStrategy.
type Balancer struct {
	// selectors are passed to outbound.HandlerSelector.Select to obtain the
	// candidate outbound tags (see SelectOutbounds).
	selectors   []string
	// strategy picks one tag among the candidates. It may additionally
	// implement extension.ContextReceiver or BalancingPrincipleTarget.
	strategy    BalancingStrategy
	// ohm must also implement outbound.HandlerSelector, otherwise
	// SelectOutbounds returns an error.
	ohm         outbound.Manager
	// fallbackTag, when non-empty, is returned by PickOutbound instead of an
	// error when candidate selection fails or the picked tag is empty.
	fallbackTag string

	// override, when its Get returns a non-empty tag, takes precedence over
	// the strategy in PickOutbound (set via Router.SetOverrideTarget).
	override override
}
    45  
    46  // PickOutbound picks the tag of a outbound
    47  func (b *Balancer) PickOutbound() (string, error) {
    48  	candidates, err := b.SelectOutbounds()
    49  	if err != nil {
    50  		if b.fallbackTag != "" {
    51  			newError("fallback to [", b.fallbackTag, "], due to error: ", err).AtInfo().WriteToLog()
    52  			return b.fallbackTag, nil
    53  		}
    54  		return "", err
    55  	}
    56  	var tag string
    57  	if o := b.override.Get(); o != "" {
    58  		tag = o
    59  	} else {
    60  		tag = b.strategy.PickOutbound(candidates)
    61  	}
    62  	if tag == "" {
    63  		if b.fallbackTag != "" {
    64  			newError("fallback to [", b.fallbackTag, "], due to empty tag returned").AtInfo().WriteToLog()
    65  			return b.fallbackTag, nil
    66  		}
    67  		// will use default handler
    68  		return "", newError("balancing strategy returns empty tag")
    69  	}
    70  	return tag, nil
    71  }
    72  
    73  func (b *Balancer) InjectContext(ctx context.Context) {
    74  	if contextReceiver, ok := b.strategy.(extension.ContextReceiver); ok {
    75  		contextReceiver.InjectContext(ctx)
    76  	}
    77  }
    78  
    79  // SelectOutbounds select outbounds with selectors of the Balancer
    80  func (b *Balancer) SelectOutbounds() ([]string, error) {
    81  	hs, ok := b.ohm.(outbound.HandlerSelector)
    82  	if !ok {
    83  		return nil, newError("outbound.Manager is not a HandlerSelector")
    84  	}
    85  	tags := hs.Select(b.selectors)
    86  	return tags, nil
    87  }
    88  
    89  // GetPrincipleTarget implements routing.BalancerPrincipleTarget
    90  func (r *Router) GetPrincipleTarget(tag string) ([]string, error) {
    91  	if b, ok := r.balancers[tag]; ok {
    92  		if s, ok := b.strategy.(BalancingPrincipleTarget); ok {
    93  			candidates, err := b.SelectOutbounds()
    94  			if err != nil {
    95  				return nil, newError("unable to select outbounds").Base(err)
    96  			}
    97  			return s.GetPrincipleTarget(candidates), nil
    98  		}
    99  		return nil, newError("unsupported GetPrincipleTarget")
   100  	}
   101  	return nil, newError("cannot find tag")
   102  }
   103  
   104  // SetOverrideTarget implements routing.BalancerOverrider
   105  func (r *Router) SetOverrideTarget(tag, target string) error {
   106  	if b, ok := r.balancers[tag]; ok {
   107  		b.override.Put(target)
   108  		return nil
   109  	}
   110  	return newError("cannot find tag")
   111  }
   112  
   113  // GetOverrideTarget implements routing.BalancerOverrider
   114  func (r *Router) GetOverrideTarget(tag string) (string, error) {
   115  	if b, ok := r.balancers[tag]; ok {
   116  		return b.override.Get(), nil
   117  	}
   118  	return "", newError("cannot find tag")
   119  }