github.com/metacubex/mihomo@v1.18.5/adapter/outboundgroup/loadbalance.go (about)

     1  package outboundgroup
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/metacubex/mihomo/adapter/outbound"
    13  	"github.com/metacubex/mihomo/common/callback"
    14  	"github.com/metacubex/mihomo/common/lru"
    15  	N "github.com/metacubex/mihomo/common/net"
    16  	"github.com/metacubex/mihomo/common/utils"
    17  	"github.com/metacubex/mihomo/component/dialer"
    18  	C "github.com/metacubex/mihomo/constant"
    19  	"github.com/metacubex/mihomo/constant/provider"
    20  
    21  	"golang.org/x/net/publicsuffix"
    22  )
    23  
    24  type strategyFn = func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy
    25  
    26  type LoadBalance struct {
    27  	*GroupBase
    28  	disableUDP     bool
    29  	strategyFn     strategyFn
    30  	testUrl        string
    31  	expectedStatus string
    32  	Hidden         bool
    33  	Icon           string
    34  }
    35  
    36  var errStrategy = errors.New("unsupported strategy")
    37  
    38  func parseStrategy(config map[string]any) string {
    39  	if strategy, ok := config["strategy"].(string); ok {
    40  		return strategy
    41  	}
    42  	return "consistent-hashing"
    43  }
    44  
    45  func getKey(metadata *C.Metadata) string {
    46  	if metadata == nil {
    47  		return ""
    48  	}
    49  
    50  	if metadata.Host != "" {
    51  		// ip host
    52  		if ip := net.ParseIP(metadata.Host); ip != nil {
    53  			return metadata.Host
    54  		}
    55  
    56  		if etld, err := publicsuffix.EffectiveTLDPlusOne(metadata.Host); err == nil {
    57  			return etld
    58  		}
    59  	}
    60  
    61  	if !metadata.DstIP.IsValid() {
    62  		return ""
    63  	}
    64  
    65  	return metadata.DstIP.String()
    66  }
    67  
    68  func getKeyWithSrcAndDst(metadata *C.Metadata) string {
    69  	dst := getKey(metadata)
    70  	src := ""
    71  	if metadata != nil {
    72  		src = metadata.SrcIP.String()
    73  	}
    74  
    75  	return fmt.Sprintf("%s%s", src, dst)
    76  }
    77  
    78  func jumpHash(key uint64, buckets int32) int32 {
    79  	var b, j int64
    80  
    81  	for j < int64(buckets) {
    82  		b = j
    83  		key = key*2862933555777941757 + 1
    84  		j = int64(float64(b+1) * (float64(int64(1)<<31) / float64((key>>33)+1)))
    85  	}
    86  
    87  	return int32(b)
    88  }
    89  
    90  // DialContext implements C.ProxyAdapter
    91  func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
    92  	proxy := lb.Unwrap(metadata, true)
    93  	c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
    94  
    95  	if err == nil {
    96  		c.AppendToChains(lb)
    97  	} else {
    98  		lb.onDialFailed(proxy.Type(), err)
    99  	}
   100  
   101  	if N.NeedHandshake(c) {
   102  		c = callback.NewFirstWriteCallBackConn(c, func(err error) {
   103  			if err == nil {
   104  				lb.onDialSuccess()
   105  			} else {
   106  				lb.onDialFailed(proxy.Type(), err)
   107  			}
   108  		})
   109  	}
   110  
   111  	return
   112  }
   113  
   114  // ListenPacketContext implements C.ProxyAdapter
   115  func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (pc C.PacketConn, err error) {
   116  	defer func() {
   117  		if err == nil {
   118  			pc.AppendToChains(lb)
   119  		}
   120  	}()
   121  
   122  	proxy := lb.Unwrap(metadata, true)
   123  	return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
   124  }
   125  
   126  // SupportUDP implements C.ProxyAdapter
   127  func (lb *LoadBalance) SupportUDP() bool {
   128  	return !lb.disableUDP
   129  }
   130  
   131  // IsL3Protocol implements C.ProxyAdapter
   132  func (lb *LoadBalance) IsL3Protocol(metadata *C.Metadata) bool {
   133  	return lb.Unwrap(metadata, false).IsL3Protocol(metadata)
   134  }
   135  
   136  func strategyRoundRobin(url string) strategyFn {
   137  	idx := 0
   138  	idxMutex := sync.Mutex{}
   139  	return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
   140  		idxMutex.Lock()
   141  		defer idxMutex.Unlock()
   142  
   143  		i := 0
   144  		length := len(proxies)
   145  
   146  		if touch {
   147  			defer func() {
   148  				idx = (idx + i) % length
   149  			}()
   150  		}
   151  
   152  		for ; i < length; i++ {
   153  			id := (idx + i) % length
   154  			proxy := proxies[id]
   155  			if proxy.AliveForTestUrl(url) {
   156  				i++
   157  				return proxy
   158  			}
   159  		}
   160  
   161  		return proxies[0]
   162  	}
   163  }
   164  
   165  func strategyConsistentHashing(url string) strategyFn {
   166  	maxRetry := 5
   167  	return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
   168  		key := utils.MapHash(getKey(metadata))
   169  		buckets := int32(len(proxies))
   170  		for i := 0; i < maxRetry; i, key = i+1, key+1 {
   171  			idx := jumpHash(key, buckets)
   172  			proxy := proxies[idx]
   173  			if proxy.AliveForTestUrl(url) {
   174  				return proxy
   175  			}
   176  		}
   177  
   178  		// when availability is poor, traverse the entire list to get the available nodes
   179  		for _, proxy := range proxies {
   180  			if proxy.AliveForTestUrl(url) {
   181  				return proxy
   182  			}
   183  		}
   184  
   185  		return proxies[0]
   186  	}
   187  }
   188  
   189  func strategyStickySessions(url string) strategyFn {
   190  	ttl := time.Minute * 10
   191  	maxRetry := 5
   192  	lruCache := lru.New[uint64, int](
   193  		lru.WithAge[uint64, int](int64(ttl.Seconds())),
   194  		lru.WithSize[uint64, int](1000))
   195  	return func(proxies []C.Proxy, metadata *C.Metadata, touch bool) C.Proxy {
   196  		key := utils.MapHash(getKeyWithSrcAndDst(metadata))
   197  		length := len(proxies)
   198  		idx, has := lruCache.Get(key)
   199  		if !has {
   200  			idx = int(jumpHash(key+uint64(time.Now().UnixNano()), int32(length)))
   201  		}
   202  
   203  		nowIdx := idx
   204  		for i := 1; i < maxRetry; i++ {
   205  			proxy := proxies[nowIdx]
   206  			if proxy.AliveForTestUrl(url) {
   207  				if nowIdx != idx {
   208  					lruCache.Delete(key)
   209  					lruCache.Set(key, nowIdx)
   210  				}
   211  
   212  				return proxy
   213  			} else {
   214  				nowIdx = int(jumpHash(key+uint64(time.Now().UnixNano()), int32(length)))
   215  			}
   216  		}
   217  
   218  		lruCache.Delete(key)
   219  		lruCache.Set(key, 0)
   220  		return proxies[0]
   221  	}
   222  }
   223  
   224  // Unwrap implements C.ProxyAdapter
   225  func (lb *LoadBalance) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
   226  	proxies := lb.GetProxies(touch)
   227  	return lb.strategyFn(proxies, metadata, touch)
   228  }
   229  
   230  // MarshalJSON implements C.ProxyAdapter
   231  func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
   232  	var all []string
   233  	for _, proxy := range lb.GetProxies(false) {
   234  		all = append(all, proxy.Name())
   235  	}
   236  	return json.Marshal(map[string]any{
   237  		"type":           lb.Type().String(),
   238  		"all":            all,
   239  		"testUrl":        lb.testUrl,
   240  		"expectedStatus": lb.expectedStatus,
   241  		"hidden":         lb.Hidden,
   242  		"icon":           lb.Icon,
   243  	})
   244  }
   245  
   246  func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
   247  	var strategyFn strategyFn
   248  	switch strategy {
   249  	case "consistent-hashing":
   250  		strategyFn = strategyConsistentHashing(option.URL)
   251  	case "round-robin":
   252  		strategyFn = strategyRoundRobin(option.URL)
   253  	case "sticky-sessions":
   254  		strategyFn = strategyStickySessions(option.URL)
   255  	default:
   256  		return nil, fmt.Errorf("%w: %s", errStrategy, strategy)
   257  	}
   258  	return &LoadBalance{
   259  		GroupBase: NewGroupBase(GroupBaseOption{
   260  			outbound.BaseOption{
   261  				Name:        option.Name,
   262  				Type:        C.LoadBalance,
   263  				Interface:   option.Interface,
   264  				RoutingMark: option.RoutingMark,
   265  			},
   266  			option.Filter,
   267  			option.ExcludeFilter,
   268  			option.ExcludeType,
   269  			option.TestTimeout,
   270  			option.MaxFailedTimes,
   271  			providers,
   272  		}),
   273  		strategyFn:     strategyFn,
   274  		disableUDP:     option.DisableUDP,
   275  		testUrl:        option.URL,
   276  		expectedStatus: option.ExpectedStatus,
   277  		Hidden:         option.Hidden,
   278  		Icon:           option.Icon,
   279  	}, nil
   280  }