
     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    14  // Package consistenthash provides consistent hash utilities.
    15  package consistenthash
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"sort"
    21  	"strconv"
    22  	"sync"
    23  	"time"
    25  	""
    26  	""
    27  	""
    28  )
    30  // defaultReplicas is the default virtual node coefficient.
    31  const (
    32  	defaultReplicas int = 100
    33  	prime               = 16777619
    34  )
    36  // Hash is the hash function type.
    37  type Hash func(data []byte) uint64
    39  // defaultHashFunc uses CRC32 as the default.
    40  var defaultHashFunc Hash = xxhash.Sum64
    42  func init() {
    43  	loadbalance.Register("consistent_hash", NewConsistentHash())
    44  }
    46  // NewConsistentHash creates a new ConsistentHash.
    47  func NewConsistentHash() *ConsistentHash {
    48  	return &ConsistentHash{
    49  		pickers:  new(sync.Map),
    50  		hashFunc: defaultHashFunc,
    51  	}
    52  }
    54  // NewCustomConsistentHash creates a new ConsistentHash with custom hash function.
    55  func NewCustomConsistentHash(hashFunc Hash) *ConsistentHash {
    56  	return &ConsistentHash{
    57  		pickers:  new(sync.Map),
    58  		hashFunc: hashFunc,
    59  	}
    60  }
// ConsistentHash is a load balancer that maps a request key to a node using a
// consistent-hash ring. It caches one *chPicker (and therefore one ring) per
// service name.
type ConsistentHash struct {
	pickers  *sync.Map     // service name -> *chPicker, filled lazily by Select
	interval time.Duration // copied into each new chPicker; NOTE(review): not set by the constructors in this file
	hashFunc Hash          // hash for keys and virtual nodes, copied into each new chPicker
}
    69  // Select implements loadbalance.LoadBalancer.
    70  func (ch *ConsistentHash) Select(serviceName string, list []*registry.Node,
    71  	opt ...loadbalance.Option) (*registry.Node, error) {
    72  	opts := &loadbalance.Options{}
    73  	for _, o := range opt {
    74  		o(opts)
    75  	}
    76  	p, ok := ch.pickers.Load(serviceName)
    77  	if ok {
    78  		return p.(*chPicker).Pick(list, opts)
    79  	}
    81  	newPicker := &chPicker{
    82  		interval: ch.interval,
    83  		hashFunc: ch.hashFunc,
    84  	}
    85  	v, ok := ch.pickers.LoadOrStore(serviceName, newPicker)
    86  	if !ok {
    87  		return newPicker.Pick(list, opts)
    88  	}
    89  	return v.(*chPicker).Pick(list, opts)
    90  }
// chPicker holds the consistent-hash ring for a single service and picks nodes
// from it. It contains a sync.Mutex, so it must not be copied after first use.
type chPicker struct {
	list     []*registry.Node            // node list the current ring was built from
	hashFunc Hash                        // hash for request keys and virtual nodes
	keys     Uint64Slice                 // sorted virtual-node hashes; its length is #(node)*replica
	hashMap  map[uint64][]*registry.Node // virtual-node hash -> node(s) that produced it (several on collision)
	mu       sync.Mutex                  // guards list, keys and hashMap (expected to be held in updateState)
	interval time.Duration               // NOTE(review): not read anywhere in the code shown here
}
   102  // Pick picks a node.
   103  func (p *chPicker) Pick(list []*registry.Node, opts *loadbalance.Options) (*registry.Node, error) {
   104  	if len(list) == 0 {
   105  		return nil, loadbalance.ErrNoServerAvailable
   106  	}
   107  	// Returns error if opts.Key is not provided.
   108  	if opts.Key == "" {
   109  		return nil, errors.New("missing key")
   110  	}
   111  	tmpKeys, tmpMap, err := p.updateState(list, opts.Replicas)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	hash := p.hashFunc([]byte(opts.Key))
   116  	// Find the best matched node by binary search. Node A is better than B if A's hash value is
   117  	// greater than B's.
   118  	idx := sort.Search(len(tmpKeys), func(i int) bool { return tmpKeys[i] >= hash })
   119  	if idx == len(tmpKeys) {
   120  		idx = 0
   121  	}
   122  	nodes, ok := tmpMap[tmpKeys[idx]]
   123  	if !ok {
   124  		return nil, loadbalance.ErrNoServerAvailable
   125  	}
   126  	switch len(nodes) {
   127  	case 1:
   128  		return nodes[0], nil
   129  	default:
   130  		innerIndex := p.hashFunc(innerRepr(opts.Key))
   131  		pos := int(innerIndex % uint64(len(nodes)))
   132  		return nodes[pos], nil
   133  	}
   134  }
   136  // updateState recalculates list every so often if nodes changed.
   137  func (p *chPicker) updateState(list []*registry.Node, replicas int) (Uint64Slice, map[uint64][]*registry.Node, error) {
   139  	defer
   140  	// if node list is the same as last update, there is no need to update hash ring.
   141  	if isNodeSliceEqualBCE(p.list, list) {
   142  		return p.keys, p.hashMap, nil
   143  	}
   144  	actualReplicas := replicas
   145  	if actualReplicas <= 0 {
   146  		actualReplicas = defaultReplicas
   147  	}
   148  	// update node list.
   149  	p.list = list
   150  	p.hashMap = make(map[uint64][]*registry.Node)
   151  	p.keys = make(Uint64Slice, len(list)*actualReplicas)
   152  	for i, node := range list {
   153  		if node == nil {
   154  			// node must not be nil.
   155  			return nil, nil, errors.New("list contains nil node")
   156  		}
   157  		for j := 0; j < actualReplicas; j++ {
   158  			hash := p.hashFunc([]byte(strconv.Itoa(j) + node.Address))
   159  			p.keys[i*(actualReplicas)+j] = hash
   160  			p.hashMap[hash] = append(p.hashMap[hash], node)
   161  		}
   162  	}
   163  	sort.Sort(p.keys)
   164  	return p.keys, p.hashMap, nil
   165  }
   167  // Uint64Slice defines uint64 slice.
   168  type Uint64Slice []uint64
   170  // Len returns the length of the slice.
   171  func (s Uint64Slice) Len() int {
   172  	return len(s)
   173  }
   175  // Less returns whether the value at i is less than j.
   176  func (s Uint64Slice) Less(i, j int) bool {
   177  	return s[i] < s[j]
   178  }
   180  // Swap swaps values between i and j.
   181  func (s Uint64Slice) Swap(i, j int) {
   182  	s[i], s[j] = s[j], s[i]
   183  }
   185  // isNodeSliceEqualBCE check whether two node list is equal by BCE.
   186  func isNodeSliceEqualBCE(a, b []*registry.Node) bool {
   187  	if len(a) != len(b) {
   188  		return false
   189  	}
   190  	if (a == nil) != (b == nil) {
   191  		return false
   192  	}
   193  	b = b[:len(a)]
   194  	for i, v := range a {
   195  		if (v == nil) != (b[i] == nil) {
   196  			return false
   197  		}
   198  		if v.Address != b[i].Address {
   199  			return false
   200  		}
   201  	}
   202  	return true
   203  }
   205  func innerRepr(key interface{}) []byte {
   206  	return []byte(fmt.Sprintf("%d:%v", prime, key))
   207  }