github.com/igoogolx/clash@v1.19.8/adapter/outboundgroup/loadbalance.go (about)

     1  package outboundgroup
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sync"

	"github.com/igoogolx/clash/adapter/outbound"
	"github.com/igoogolx/clash/common/murmur3"
	"github.com/igoogolx/clash/common/singledo"
	"github.com/igoogolx/clash/component/dialer"
	C "github.com/igoogolx/clash/constant"
	"github.com/igoogolx/clash/constant/provider"

	"golang.org/x/net/publicsuffix"
)
    19  
    20  type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy
    21  
    22  type LoadBalance struct {
    23  	*outbound.Base
    24  	disableUDP bool
    25  	single     *singledo.Single
    26  	providers  []provider.ProxyProvider
    27  	strategyFn strategyFn
    28  }
    29  
    30  var errStrategy = errors.New("unsupported strategy")
    31  
    32  func parseStrategy(config map[string]any) string {
    33  	if strategy, ok := config["strategy"].(string); ok {
    34  		return strategy
    35  	}
    36  	return "consistent-hashing"
    37  }
    38  
    39  func getKey(metadata *C.Metadata) string {
    40  	if metadata.Host != "" {
    41  		// ip host
    42  		if ip := net.ParseIP(metadata.Host); ip != nil {
    43  			return metadata.Host
    44  		}
    45  
    46  		if etld, err := publicsuffix.EffectiveTLDPlusOne(metadata.Host); err == nil {
    47  			return etld
    48  		}
    49  	}
    50  
    51  	if metadata.DstIP == nil {
    52  		return ""
    53  	}
    54  
    55  	return metadata.DstIP.String()
    56  }
    57  
    58  func jumpHash(key uint64, buckets int32) int32 {
    59  	var b, j int64
    60  
    61  	for j < int64(buckets) {
    62  		b = j
    63  		key = key*2862933555777941757 + 1
    64  		j = int64(float64(b+1) * (float64(int64(1)<<31) / float64((key>>33)+1)))
    65  	}
    66  
    67  	return int32(b)
    68  }
    69  
    70  // DialContext implements C.ProxyAdapter
    71  func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
    72  	defer func() {
    73  		if err == nil {
    74  			c.AppendToChains(lb)
    75  		}
    76  	}()
    77  
    78  	proxy := lb.Unwrap(metadata)
    79  
    80  	c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    81  	return
    82  }
    83  
    84  // ListenPacketContext implements C.ProxyAdapter
    85  func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (pc C.PacketConn, err error) {
    86  	defer func() {
    87  		if err == nil {
    88  			pc.AppendToChains(lb)
    89  		}
    90  	}()
    91  
    92  	proxy := lb.Unwrap(metadata)
    93  	return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    94  }
    95  
    96  // SupportUDP implements C.ProxyAdapter
    97  func (lb *LoadBalance) SupportUDP() bool {
    98  	return !lb.disableUDP
    99  }
   100  
   101  func strategyRoundRobin() strategyFn {
   102  	idx := 0
   103  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   104  		length := len(proxies)
   105  		for i := 0; i < length; i++ {
   106  			idx = (idx + 1) % length
   107  			proxy := proxies[idx]
   108  			if proxy.Alive() {
   109  				return proxy
   110  			}
   111  		}
   112  
   113  		return proxies[0]
   114  	}
   115  }
   116  
   117  func strategyConsistentHashing() strategyFn {
   118  	maxRetry := 5
   119  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   120  		key := uint64(murmur3.Sum32([]byte(getKey(metadata))))
   121  		buckets := int32(len(proxies))
   122  		for i := 0; i < maxRetry; i, key = i+1, key+1 {
   123  			idx := jumpHash(key, buckets)
   124  			proxy := proxies[idx]
   125  			if proxy.Alive() {
   126  				return proxy
   127  			}
   128  		}
   129  
   130  		// when availability is poor, traverse the entire list to get the available nodes
   131  		for _, proxy := range proxies {
   132  			if proxy.Alive() {
   133  				return proxy
   134  			}
   135  		}
   136  
   137  		return proxies[0]
   138  	}
   139  }
   140  
   141  // Unwrap implements C.ProxyAdapter
   142  func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
   143  	proxies := lb.proxies(true)
   144  	return lb.strategyFn(proxies, metadata)
   145  }
   146  
   147  func (lb *LoadBalance) proxies(touch bool) []C.Proxy {
   148  	elm, _, _ := lb.single.Do(func() (any, error) {
   149  		return getProvidersProxies(lb.providers, touch), nil
   150  	})
   151  
   152  	return elm.([]C.Proxy)
   153  }
   154  
   155  // MarshalJSON implements C.ProxyAdapter
   156  func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
   157  	var all []string
   158  	for _, proxy := range lb.proxies(false) {
   159  		all = append(all, proxy.Name())
   160  	}
   161  	return json.Marshal(map[string]any{
   162  		"type": lb.Type().String(),
   163  		"all":  all,
   164  	})
   165  }
   166  
   167  func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
   168  	var strategyFn strategyFn
   169  	switch strategy {
   170  	case "consistent-hashing":
   171  		strategyFn = strategyConsistentHashing()
   172  	case "round-robin":
   173  		strategyFn = strategyRoundRobin()
   174  	default:
   175  		return nil, fmt.Errorf("%w: %s", errStrategy, strategy)
   176  	}
   177  	return &LoadBalance{
   178  		Base: outbound.NewBase(outbound.BaseOption{
   179  			Name:        option.Name,
   180  			Type:        C.LoadBalance,
   181  			Interface:   option.Interface,
   182  			RoutingMark: option.RoutingMark,
   183  		}),
   184  		single:     singledo.NewSingle(defaultGetProxiesDuration),
   185  		providers:  providers,
   186  		strategyFn: strategyFn,
   187  		disableUDP: option.DisableUDP,
   188  	}, nil
   189  }