github.com/dubbogo/gost@v1.14.0/hash/consistent/consistent.go (about)

     1  /*
     2   * Licensed to the Apache Software Foundation (ASF) under one or more
     3   * contributor license agreements.  See the NOTICE file distributed with
     4   * this work for additional information regarding copyright ownership.
     5   * The ASF licenses this file to You under the Apache License, Version 2.0
     6   * (the "License"); you may not use this file except in compliance with
     7   * the License.  You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   */
    17  
    18  package consistent
    19  
    20  import (
    21  	"encoding/binary"
    22  	"math"
    23  	"sort"
    24  	"strconv"
    25  	"sync"
    26  	"sync/atomic"
    27  )
    28  
    29  import (
    30  	"github.com/pkg/errors"
    31  
    32  	"golang.org/x/crypto/blake2b"
    33  )
    34  
    35  import (
    36  	"github.com/dubbogo/gost/strings"
    37  )
    38  
    39  const (
    40  	replicationFactor = 10
    41  	maxBucketNum      = math.MaxUint32
    42  )
    43  
    44  var ErrNoHosts = errors.New("no hosts added")
    45  
    46  type Options struct {
    47  	HashFunc    HashFunc
    48  	ReplicaNum  int
    49  	MaxVnodeNum int
    50  }
    51  
    52  type Option func(option *Options)
    53  
    54  func WithHashFunc(hash HashFunc) Option {
    55  	return func(opts *Options) {
    56  		opts.HashFunc = hash
    57  	}
    58  }
    59  
    60  func WithReplicaNum(replicaNum int) Option {
    61  	return func(opts *Options) {
    62  		opts.ReplicaNum = replicaNum
    63  	}
    64  }
    65  
    66  func WithMaxVnodeNum(maxVnodeNum int) Option {
    67  	return func(opts *Options) {
    68  		opts.MaxVnodeNum = maxVnodeNum
    69  	}
    70  }
    71  
    72  type hashArray []uint32
    73  
    74  // Len returns the length of the hashArray
    75  func (h hashArray) Len() int { return len(h) }
    76  
    77  // Less returns true if element i is less than element j
    78  func (h hashArray) Less(i, j int) bool { return h[i] < h[j] }
    79  
    80  // Swap exchanges elements i and j
    81  func (h hashArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
    82  
    83  type Host struct {
    84  	Name string
    85  	Load int64
    86  }
    87  
    88  type HashFunc func([]byte) uint64
    89  
    90  func hash(key []byte) uint64 {
    91  	out := blake2b.Sum512(key)
    92  	return binary.LittleEndian.Uint64(out[:])
    93  }
    94  
    95  type Consistent struct {
    96  	circle        map[uint32]string // hash -> node name
    97  	sortedHashes  hashArray         // hash valid in ascending
    98  	loadMap       map[string]*Host  // node name -> struct Host
    99  	totalLoad     int64             // total load
   100  	replicaFactor uint32
   101  	bucketNum     uint32
   102  	hashFunc      HashFunc
   103  
   104  	sync.RWMutex
   105  }
   106  
   107  func NewConsistentHash(opts ...Option) *Consistent {
   108  	options := Options{
   109  		HashFunc:    hash,
   110  		ReplicaNum:  replicationFactor,
   111  		MaxVnodeNum: maxBucketNum,
   112  	}
   113  
   114  	for index := range opts {
   115  		opts[index](&options)
   116  	}
   117  
   118  	return &Consistent{
   119  		circle:        map[uint32]string{},
   120  		loadMap:       map[string]*Host{},
   121  		replicaFactor: uint32(options.ReplicaNum),
   122  		bucketNum:     uint32(options.MaxVnodeNum),
   123  		hashFunc:      options.HashFunc,
   124  	}
   125  }
   126  
   127  func (c *Consistent) SetHashFunc(f HashFunc) {
   128  	c.hashFunc = f
   129  }
   130  
   131  // eltKey generates a string key for an element with an index
   132  func (c *Consistent) eltKey(elt string, idx int) string {
   133  	return strconv.Itoa(idx) + elt
   134  }
   135  
   136  func (c *Consistent) Hash(key string) uint32 {
   137  	return uint32(c.hashFunc(gxstrings.Slice(key))) % c.bucketNum
   138  }
   139  
   140  // updateSortedHashes sort hashes in ascending
   141  func (c *Consistent) updateSortedHashes() {
   142  	hashes := c.sortedHashes[:0]
   143  	// reallocate if we're holding on to too much (1/4th)
   144  	if c.sortedHashes.Len()/int(c.replicaFactor*4) > len(c.circle) {
   145  		hashes = nil
   146  	}
   147  	for k := range c.circle {
   148  		hashes = append(hashes, k)
   149  	}
   150  	sort.Sort(hashes)
   151  	c.sortedHashes = hashes
   152  }
   153  
   154  func (c *Consistent) Add(host string) {
   155  	c.Lock()
   156  	defer c.Unlock()
   157  
   158  	c.add(host)
   159  }
   160  
   161  func (c *Consistent) add(host string) {
   162  	if _, ok := c.loadMap[host]; ok {
   163  		return
   164  	}
   165  
   166  	c.loadMap[host] = &Host{Name: host}
   167  	for i := uint32(0); i < c.replicaFactor; i++ {
   168  		h := c.Hash(c.eltKey(host, int(i)))
   169  		c.circle[h] = host
   170  		c.sortedHashes = append(c.sortedHashes, h)
   171  	}
   172  
   173  	c.updateSortedHashes()
   174  }
   175  
   176  // Set sets all the elements in the hash. If there are existing elements not
   177  // present in elts, they will be removed.
   178  func (c *Consistent) Set(elts []string) {
   179  	c.Lock()
   180  	defer c.Unlock()
   181  
   182  	for k := range c.loadMap {
   183  		found := true
   184  		for _, elt := range elts {
   185  			if k == elt {
   186  				found = false
   187  				break
   188  			}
   189  		}
   190  
   191  		if found {
   192  			c.remove(k)
   193  		}
   194  
   195  		for _, elt := range elts {
   196  			if _, ok := c.loadMap[elt]; !ok {
   197  				c.add(elt)
   198  			}
   199  		}
   200  	}
   201  }
   202  
   203  func (c *Consistent) Members() []string {
   204  	c.RLock()
   205  	defer c.RUnlock()
   206  
   207  	m := make([]string, 0, len(c.loadMap))
   208  	for k := range c.loadMap {
   209  		m = append(m, k)
   210  	}
   211  	return m
   212  }
   213  
   214  // Get It returns ErrNoHosts if the ring has no hosts in it
   215  func (c *Consistent) Get(key string) (string, error) {
   216  	c.RLock()
   217  	defer c.RUnlock()
   218  
   219  	if len(c.circle) == 0 {
   220  		return "", ErrNoHosts
   221  	}
   222  	return c.circle[c.sortedHashes[c.search(c.Hash(key))]], nil
   223  }
   224  
   225  // GetHash It returns ErrNoHosts if the ring has no hosts in it
   226  func (c *Consistent) GetHash(hashKey uint32) (string, error) {
   227  	c.RLock()
   228  	defer c.RUnlock()
   229  
   230  	if len(c.circle) == 0 {
   231  		return "", ErrNoHosts
   232  	}
   233  	return c.circle[c.sortedHashes[c.search(hashKey)]], nil
   234  }
   235  
   236  // GetTwo returns the two closest distinct elements to the name input in the circle
   237  func (c *Consistent) GetTwo(name string) (string, string, error) {
   238  	c.RLock()
   239  	defer c.RUnlock()
   240  
   241  	if len(c.circle) == 0 {
   242  		return "", "", ErrNoHosts
   243  	}
   244  
   245  	i := c.search(c.Hash(name))
   246  	a := c.circle[c.sortedHashes[i]]
   247  
   248  	if len(c.loadMap) == 1 {
   249  		return a, "", nil
   250  	}
   251  
   252  	start := i
   253  	var b string
   254  
   255  	for i = start + 1; i != start; i++ {
   256  		if i >= len(c.sortedHashes) {
   257  			i = 0
   258  		}
   259  		b = c.circle[c.sortedHashes[i]]
   260  		if b != a {
   261  			break
   262  		}
   263  	}
   264  	return a, b, nil
   265  }
   266  
   267  func sliceContainsMember(set []string, member string) bool {
   268  	for _, m := range set {
   269  		if m == member {
   270  			return true
   271  		}
   272  	}
   273  	return false
   274  }
   275  
   276  // GetN returns the N closest distinct elements to the name input in the circle
   277  func (c *Consistent) GetN(name string, n int) ([]string, error) {
   278  	c.RLock()
   279  	defer c.RUnlock()
   280  
   281  	if len(c.circle) == 0 {
   282  		return nil, ErrNoHosts
   283  	}
   284  
   285  	if len(c.loadMap) < n {
   286  		n = len(c.loadMap)
   287  	}
   288  
   289  	var (
   290  		i     = c.search(c.Hash(name))
   291  		start = i
   292  		res   = make([]string, 0, n)
   293  		elem  = c.circle[c.sortedHashes[i]]
   294  	)
   295  
   296  	res = append(res, elem)
   297  
   298  	if len(res) == n {
   299  		return res, nil
   300  	}
   301  
   302  	for i = start + 1; i != start; i++ {
   303  		if i >= len(c.sortedHashes) {
   304  			i = 0
   305  		}
   306  		elem = c.circle[c.sortedHashes[i]]
   307  		if !sliceContainsMember(res, elem) {
   308  			res = append(res, elem)
   309  		}
   310  		if len(res) == n {
   311  			break
   312  		}
   313  	}
   314  
   315  	return res, nil
   316  }
   317  
   318  // GetLeast It uses Consistent Hashing With Bounded loads
   319  // https://research.googleblog.com/2017/04/consistent-hashing-with-bounded-loads.html
   320  // to pick the least loaded host that can serve the key
   321  // It returns ErrNoHosts if the ring has no hosts in it.
   322  func (c *Consistent) GetLeast(key string) (string, error) {
   323  	c.RLock()
   324  	defer c.RUnlock()
   325  
   326  	if len(c.circle) == 0 {
   327  		return "", ErrNoHosts
   328  	}
   329  
   330  	idx := c.search(c.Hash(key))
   331  
   332  	i := idx
   333  	for {
   334  		host := c.circle[c.sortedHashes[i]]
   335  		if c.loadOK(host) {
   336  			return host, nil
   337  		}
   338  		i++
   339  		if i >= len(c.circle) {
   340  			i = 0
   341  		}
   342  	}
   343  }
   344  
   345  func (c *Consistent) search(key uint32) int {
   346  	idx := sort.Search(len(c.sortedHashes), func(i int) bool { return c.sortedHashes[i] >= key })
   347  	if idx >= len(c.sortedHashes) {
   348  		return 0
   349  	}
   350  	return idx
   351  }
   352  
   353  // UpdateLoad Sets the load of `host` to the given `load`
   354  func (c *Consistent) UpdateLoad(host string, load int64) {
   355  	c.Lock()
   356  	defer c.Unlock()
   357  
   358  	if _, ok := c.loadMap[host]; !ok {
   359  		return
   360  	}
   361  
   362  	c.totalLoad -= c.loadMap[host].Load
   363  	c.loadMap[host].Load = load
   364  	c.totalLoad += load
   365  }
   366  
   367  // Inc Increments the load of host by 1
   368  // should only be used with if you obtained a host with GetLeast
   369  func (c *Consistent) Inc(host string) {
   370  	c.Lock()
   371  	defer c.Unlock()
   372  
   373  	atomic.AddInt64(&c.loadMap[host].Load, 1)
   374  	atomic.AddInt64(&c.totalLoad, 1)
   375  }
   376  
   377  // Done Decrements the load of host by 1
   378  // should only be used with if you obtained a host with GetLeast
   379  func (c *Consistent) Done(host string) {
   380  	c.Lock()
   381  	defer c.Unlock()
   382  
   383  	if _, ok := c.loadMap[host]; !ok {
   384  		return
   385  	}
   386  
   387  	atomic.AddInt64(&c.loadMap[host].Load, -1)
   388  	atomic.AddInt64(&c.totalLoad, 1)
   389  }
   390  
   391  // Remove Deletes host from the ring
   392  func (c *Consistent) Remove(host string) bool {
   393  	c.Lock()
   394  	defer c.Unlock()
   395  	return c.remove(host)
   396  }
   397  
   398  func (c *Consistent) remove(host string) bool {
   399  	for i := uint32(0); i < c.replicaFactor; i++ {
   400  		h := c.Hash(c.eltKey(host, int(i)))
   401  		delete(c.circle, h)
   402  		c.delSlice(h)
   403  	}
   404  
   405  	if _, ok := c.loadMap[host]; ok {
   406  		atomic.AddInt64(&c.totalLoad, -c.loadMap[host].Load)
   407  		delete(c.loadMap, host)
   408  	}
   409  	return true
   410  }
   411  
   412  // Hosts Return the list of hosts in the ring
   413  func (c *Consistent) Hosts() []string {
   414  	c.RLock()
   415  	defer c.RUnlock()
   416  
   417  	hosts := make([]string, 0, len(c.loadMap))
   418  	for k := range c.loadMap {
   419  		hosts = append(hosts, k)
   420  	}
   421  	return hosts
   422  }
   423  
   424  // GetLoads Returns the loads of all the hosts
   425  func (c *Consistent) GetLoads() map[string]int64 {
   426  	loads := make(map[string]int64, len(c.loadMap))
   427  
   428  	for k, v := range c.loadMap {
   429  		loads[k] = v.Load
   430  	}
   431  	return loads
   432  }
   433  
   434  // MaxLoad Returns the maximum load of the single host
   435  // which is:
   436  // (total_load/number_of_hosts)*1.25
   437  // total_load = is the total number of active requests served by hosts
   438  // for more info:
   439  // https://research.googleblog.com/2017/04/consistent-hashing-with-bounded-loads.html
   440  func (c *Consistent) MaxLoad() int64 {
   441  	if c.totalLoad == 0 {
   442  		c.totalLoad = 1
   443  	}
   444  
   445  	avgLoadPerNode := float64(c.totalLoad / int64(len(c.loadMap)))
   446  	if avgLoadPerNode == 0 {
   447  		avgLoadPerNode = 1
   448  	}
   449  	avgLoadPerNode = math.Ceil(avgLoadPerNode * 1.25)
   450  	return int64(avgLoadPerNode)
   451  }
   452  
   453  func (c *Consistent) loadOK(host string) bool {
   454  	// a safety check if someone performed c.Done more than needed
   455  	if c.totalLoad < 0 {
   456  		c.totalLoad = 0
   457  	}
   458  
   459  	var avgLoadPerNode float64
   460  	avgLoadPerNode = float64((c.totalLoad + 1) / int64(len(c.loadMap)))
   461  	if avgLoadPerNode == 0 {
   462  		avgLoadPerNode = 1
   463  	}
   464  	avgLoadPerNode = math.Ceil(avgLoadPerNode * 1.25)
   465  
   466  	bhost, ok := c.loadMap[host]
   467  	if !ok {
   468  		panic("given host(" + bhost.Name + ") not in loadsMap")
   469  	}
   470  
   471  	if float64(bhost.Load)+1 <= avgLoadPerNode {
   472  		return true
   473  	}
   474  
   475  	return false
   476  }
   477  
   478  func (c *Consistent) delSlice(val uint32) {
   479  	idx := -1
   480  	l := 0
   481  	r := len(c.sortedHashes) - 1
   482  	for l <= r {
   483  		m := (l + r) / 2
   484  		if c.sortedHashes[m] == val {
   485  			idx = m
   486  			break
   487  		} else if c.sortedHashes[m] < val {
   488  			l = m + 1
   489  		} else if c.sortedHashes[m] > val {
   490  			r = m - 1
   491  		}
   492  	}
   493  	if idx != -1 {
   494  		c.sortedHashes = append(c.sortedHashes[:idx], c.sortedHashes[idx+1:]...)
   495  	}
   496  }