github.com/chwjbn/xclash@v0.2.0/adapter/outboundgroup/loadbalance.go (about)

     1  package outboundgroup
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"sync"

	"github.com/chwjbn/xclash/adapter/outbound"
	"github.com/chwjbn/xclash/common/murmur3"
	"github.com/chwjbn/xclash/common/singledo"
	"github.com/chwjbn/xclash/component/dialer"
	C "github.com/chwjbn/xclash/constant"
	"github.com/chwjbn/xclash/constant/provider"

	"golang.org/x/net/publicsuffix"
)
    19  
    20  type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy
    21  
    22  type LoadBalance struct {
    23  	*outbound.Base
    24  	disableUDP bool
    25  	single     *singledo.Single
    26  	providers  []provider.ProxyProvider
    27  	strategyFn strategyFn
    28  }
    29  
    30  var errStrategy = errors.New("unsupported strategy")
    31  
    32  func parseStrategy(config map[string]interface{}) string {
    33  	if elm, ok := config["strategy"]; ok {
    34  		if strategy, ok := elm.(string); ok {
    35  			return strategy
    36  		}
    37  	}
    38  	return "consistent-hashing"
    39  }
    40  
    41  func getKey(metadata *C.Metadata) string {
    42  	if metadata.Host != "" {
    43  		// ip host
    44  		if ip := net.ParseIP(metadata.Host); ip != nil {
    45  			return metadata.Host
    46  		}
    47  
    48  		if etld, err := publicsuffix.EffectiveTLDPlusOne(metadata.Host); err == nil {
    49  			return etld
    50  		}
    51  	}
    52  
    53  	if metadata.DstIP == nil {
    54  		return ""
    55  	}
    56  
    57  	return metadata.DstIP.String()
    58  }
    59  
    60  func jumpHash(key uint64, buckets int32) int32 {
    61  	var b, j int64
    62  
    63  	for j < int64(buckets) {
    64  		b = j
    65  		key = key*2862933555777941757 + 1
    66  		j = int64(float64(b+1) * (float64(int64(1)<<31) / float64((key>>33)+1)))
    67  	}
    68  
    69  	return int32(b)
    70  }
    71  
    72  // DialContext implements C.ProxyAdapter
    73  func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
    74  	defer func() {
    75  		if err == nil {
    76  			c.AppendToChains(lb)
    77  		}
    78  	}()
    79  
    80  	proxy := lb.Unwrap(metadata)
    81  
    82  	c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    83  	return
    84  }
    85  
    86  // ListenPacketContext implements C.ProxyAdapter
    87  func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (pc C.PacketConn, err error) {
    88  	defer func() {
    89  		if err == nil {
    90  			pc.AppendToChains(lb)
    91  		}
    92  	}()
    93  
    94  	proxy := lb.Unwrap(metadata)
    95  	return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    96  }
    97  
    98  // SupportUDP implements C.ProxyAdapter
    99  func (lb *LoadBalance) SupportUDP() bool {
   100  	return !lb.disableUDP
   101  }
   102  
   103  func strategyRoundRobin() strategyFn {
   104  	idx := 0
   105  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   106  		length := len(proxies)
   107  		for i := 0; i < length; i++ {
   108  			idx = (idx + 1) % length
   109  			proxy := proxies[idx]
   110  			if proxy.Alive() {
   111  				return proxy
   112  			}
   113  		}
   114  
   115  		return proxies[0]
   116  	}
   117  }
   118  
   119  func strategyConsistentHashing() strategyFn {
   120  	maxRetry := 5
   121  	return func(proxies []C.Proxy, metadata *C.Metadata) C.Proxy {
   122  		key := uint64(murmur3.Sum32([]byte(getKey(metadata))))
   123  		buckets := int32(len(proxies))
   124  		for i := 0; i < maxRetry; i, key = i+1, key+1 {
   125  			idx := jumpHash(key, buckets)
   126  			proxy := proxies[idx]
   127  			if proxy.Alive() {
   128  				return proxy
   129  			}
   130  		}
   131  
   132  		return proxies[0]
   133  	}
   134  }
   135  
   136  // Unwrap implements C.ProxyAdapter
   137  func (lb *LoadBalance) Unwrap(metadata *C.Metadata) C.Proxy {
   138  	proxies := lb.proxies(true)
   139  	return lb.strategyFn(proxies, metadata)
   140  }
   141  
   142  func (lb *LoadBalance) proxies(touch bool) []C.Proxy {
   143  	elm, _, _ := lb.single.Do(func() (interface{}, error) {
   144  		return getProvidersProxies(lb.providers, touch), nil
   145  	})
   146  
   147  	return elm.([]C.Proxy)
   148  }
   149  
   150  // MarshalJSON implements C.ProxyAdapter
   151  func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
   152  	var all []string
   153  	for _, proxy := range lb.proxies(false) {
   154  		all = append(all, proxy.Name())
   155  	}
   156  	return json.Marshal(map[string]interface{}{
   157  		"type": lb.Type().String(),
   158  		"all":  all,
   159  	})
   160  }
   161  
   162  func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
   163  	var strategyFn strategyFn
   164  	switch strategy {
   165  	case "consistent-hashing":
   166  		strategyFn = strategyConsistentHashing()
   167  	case "round-robin":
   168  		strategyFn = strategyRoundRobin()
   169  	default:
   170  		return nil, fmt.Errorf("%w: %s", errStrategy, strategy)
   171  	}
   172  	return &LoadBalance{
   173  		Base: outbound.NewBase(outbound.BaseOption{
   174  			Name:        option.Name,
   175  			Type:        C.LoadBalance,
   176  			Interface:   option.Interface,
   177  			RoutingMark: option.RoutingMark,
   178  		}),
   179  		single:     singledo.NewSingle(defaultGetProxiesDuration),
   180  		providers:  providers,
   181  		strategyFn: strategyFn,
   182  		disableUDP: option.DisableUDP,
   183  	}, nil
   184  }