github.com/yaling888/clash@v1.53.0/adapter/outboundgroup/loadbalance.go (about)

     1  package outboundgroup
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sync"

	"golang.org/x/net/publicsuffix"

	"github.com/yaling888/clash/adapter/outbound"
	"github.com/yaling888/clash/common/murmur3"
	"github.com/yaling888/clash/common/singledo"
	"github.com/yaling888/clash/component/dialer"
	C "github.com/yaling888/clash/constant"
	"github.com/yaling888/clash/constant/provider"
)
    19  
    20  type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy
    21  
    22  var _ C.ProxyAdapter = (*LoadBalance)(nil)
    23  
    24  type LoadBalance struct {
    25  	*outbound.Base
    26  	disableUDP bool
    27  	disableDNS bool
    28  	single     *singledo.Single[[]C.Proxy]
    29  	providers  []provider.ProxyProvider
    30  	strategyFn strategyFn
    31  }
    32  
    33  var errStrategy = errors.New("unsupported strategy")
    34  
    35  func parseStrategy(config map[string]any) string {
    36  	if strategy, ok := config["strategy"].(string); ok {
    37  		return strategy
    38  	}
    39  	return "consistent-hashing"
    40  }
    41  
    42  func getKey(metadata *C.Metadata) string {
    43  	if metadata.Host != "" {
    44  		// ip host
    45  		if ip := net.ParseIP(metadata.Host); ip != nil {
    46  			return metadata.Host
    47  		}
    48  
    49  		if eTLD, err := publicsuffix.EffectiveTLDPlusOne(metadata.Host); err == nil {
    50  			return eTLD
    51  		}
    52  	}
    53  
    54  	if !metadata.DstIP.IsValid() {
    55  		return ""
    56  	}
    57  
    58  	return metadata.DstIP.String()
    59  }
    60  
    61  func jumpHash(key uint64, buckets int32) int32 {
    62  	var b, j int64
    63  
    64  	for j < int64(buckets) {
    65  		b = j
    66  		key = key*2862933555777941757 + 1
    67  		j = int64(float64(b+1) * (float64(int64(1)<<31) / float64((key>>33)+1)))
    68  	}
    69  
    70  	return int32(b)
    71  }
    72  
    73  // DialContext implements C.ProxyAdapter
    74  func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
    75  	defer func() {
    76  		if err == nil {
    77  			c.AppendToChains(lb)
    78  		}
    79  	}()
    80  
    81  	proxy := lb.Unwrap(metadata)
    82  
    83  	c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    84  	return
    85  }
    86  
    87  // ListenPacketContext implements C.ProxyAdapter
    88  func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (pc C.PacketConn, err error) {
    89  	defer func() {
    90  		if err == nil {
    91  			pc.AppendToChains(lb)
    92  		}
    93  	}()
    94  
    95  	proxy := lb.Unwrap(metadata)
    96  	return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    97  }
    98  
    99  // SupportUDP implements C.ProxyAdapter
   100  func (lb *LoadBalance) SupportUDP() bool {
   101  	return !lb.disableUDP
   102  }
   103  
   104  // DisableDnsResolve implements C.DisableDnsResolve
   105  func (lb *LoadBalance) DisableDnsResolve() bool {
   106  	return lb.disableDNS
   107  }
   108  
   109  func strategyRoundRobin() strategyFn {
   110  	idx := 0
   111  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   112  		length := len(proxies)
   113  		for i := 0; i < length; i++ {
   114  			idx = (idx + 1) % length
   115  			proxy := proxies[idx]
   116  			if proxy.Alive() {
   117  				return proxy
   118  			}
   119  		}
   120  
   121  		return proxies[0]
   122  	}
   123  }
   124  
   125  func strategyConsistentHashing() strategyFn {
   126  	maxRetry := 5
   127  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   128  		key := uint64(murmur3.Sum32([]byte(getKey(metadata))))
   129  		buckets := int32(len(proxies))
   130  		for i := 0; i < maxRetry; i, key = i+1, key+1 {
   131  			idx := jumpHash(key, buckets)
   132  			proxy := proxies[idx]
   133  			if proxy.Alive() {
   134  				return proxy
   135  			}
   136  		}
   137  
   138  		// when availability is poor, traverse the entire list to get the available nodes
   139  		for _, proxy := range proxies {
   140  			if proxy.Alive() {
   141  				return proxy
   142  			}
   143  		}
   144  
   145  		return proxies[0]
   146  	}
   147  }
   148  
   149  // Unwrap implements C.ProxyAdapter
   150  func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
   151  	proxies := lb.proxies(true)
   152  	return lb.strategyFn(proxies, metadata)
   153  }
   154  
   155  func (lb *LoadBalance) proxies(touch bool) []C.Proxy {
   156  	elm, _, _ := lb.single.Do(func() ([]C.Proxy, error) {
   157  		return getProvidersProxies(lb.providers, touch), nil
   158  	})
   159  
   160  	return elm
   161  }
   162  
   163  // MarshalJSON implements C.ProxyAdapter
   164  func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
   165  	var all []string
   166  	for _, proxy := range lb.proxies(false) {
   167  		all = append(all, proxy.Name())
   168  	}
   169  	return json.Marshal(map[string]any{
   170  		"type": lb.Type().String(),
   171  		"all":  all,
   172  	})
   173  }
   174  
   175  // Cleanup implements C.ProxyAdapter
   176  func (lb *LoadBalance) Cleanup() {
   177  	lb.single.Reset()
   178  }
   179  
   180  func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
   181  	var strategyFn strategyFn
   182  	switch strategy {
   183  	case "consistent-hashing":
   184  		strategyFn = strategyConsistentHashing()
   185  	case "round-robin":
   186  		strategyFn = strategyRoundRobin()
   187  	default:
   188  		return nil, fmt.Errorf("%w: %s", errStrategy, strategy)
   189  	}
   190  	return &LoadBalance{
   191  		Base: outbound.NewBase(outbound.BaseOption{
   192  			Name:        option.Name,
   193  			Type:        C.LoadBalance,
   194  			Interface:   option.Interface,
   195  			RoutingMark: option.RoutingMark,
   196  		}),
   197  		single:     singledo.NewSingle[[]C.Proxy](defaultGetProxiesDuration),
   198  		providers:  providers,
   199  		strategyFn: strategyFn,
   200  		disableUDP: option.DisableUDP,
   201  		disableDNS: option.DisableDNS,
   202  	}, nil
   203  }