github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/memberlist/util.go

package memberlist

import (
	"bytes"
	"compress/lzw"
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"time"

	"github.com/hashicorp/go-msgpack/codec"
	"github.com/sean-/seed"
)

// pushPullScaleThreshold is the minimum number of nodes
// before we start scaling the push/pull timing. The scale
// effect is log2(Nodes) - log2(pushPullScaleThreshold). This means
// that the 33rd node will cause us to double the interval,
// while the 65th will triple it.
const pushPullScaleThreshold = 32

const (
	// lzwLitWidth is the LZW literal width; valid values are in the range 2-8.
	lzwLitWidth = 8
)

func init() {
	seed.Init()
}

// decode reverses the encode operation on a byte slice input.
func decode(buf []byte, out interface{}) error {
	r := bytes.NewReader(buf)
	hd := codec.MsgpackHandle{}
	dec := codec.NewDecoder(r, &hd)
	return dec.Decode(out)
}

// encode writes a message-type byte followed by the msgpack encoding of the
// object to a new bytes buffer.
func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
	buf := bytes.NewBuffer(nil)
	buf.WriteByte(uint8(msgType))
	hd := codec.MsgpackHandle{}
	enc := codec.NewEncoder(buf, &hd)
	err := enc.Encode(in)
	return buf, err
}
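
// The sketch below is illustrative only and not part of the original library
// (the function name exampleEncodeRoundTrip is hypothetical). It shows the
// framing convention used throughout this file: encode emits one message-type
// byte followed by the msgpack body, while decode expects only the msgpack
// body, so callers strip the leading type byte first.
func exampleEncodeRoundTrip() error {
	in := compress{Algo: lzwAlgo, Buf: []byte("payload")}

	buf, err := encode(compressMsg, &in)
	if err != nil {
		return err
	}

	// buf.Bytes()[0] is uint8(compressMsg); the remainder is the msgpack body.
	var out compress
	if err := decode(buf.Bytes()[1:], &out); err != nil {
		return err
	}
	fmt.Printf("algo=%d payload=%q\n", out.Algo, out.Buf)
	return nil
}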

// randomOffset returns a random offset in the range [0, n).
func randomOffset(n int) int {
	if n == 0 {
		return 0
	}
	return int(rand.Uint32() % uint32(n))
}

// suspicionTimeout computes the timeout that should be used when
// a node is suspected.
func suspicionTimeout(suspicionMult, n int, interval time.Duration) time.Duration {
	nodeScale := math.Max(1.0, math.Log10(math.Max(1.0, float64(n))))
	// multiply by 1000 to keep some precision because time.Duration is an int64 type
	timeout := time.Duration(suspicionMult) * time.Duration(nodeScale*1000) * interval / 1000
	return timeout
}
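
// The sketch below is illustrative only and not part of the original library
// (the function name exampleSuspicionTimeouts is hypothetical). It shows how
// the timeout grows with cluster size: nodeScale is max(1, log10(max(1, n))),
// so small clusters get the floor of suspicionMult*interval, a 100-node
// cluster doubles it, and a 1000-node cluster triples it.
func exampleSuspicionTimeouts() {
	interval := time.Second
	for _, n := range []int{3, 10, 100, 1000} {
		// With suspicionMult=5: 5s, 5s, 10s, 15s respectively.
		fmt.Printf("n=%d timeout=%s\n", n, suspicionTimeout(5, n, interval))
	}
}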

// retransmitLimit computes the limit of retransmissions
func retransmitLimit(retransmitMult, n int) int {
	nodeScale := math.Ceil(math.Log10(float64(n + 1)))
	limit := retransmitMult * int(nodeScale)
	return limit
}
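
// Another illustrative sketch (hypothetical name, not part of the library):
// with retransmitMult=4 the limit is 4*ceil(log10(n+1)), i.e. 4 retransmits
// for up to 9 nodes, 8 for 10-99 nodes, and 12 for 100-999 nodes.
func exampleRetransmitLimits() {
	for _, n := range []int{1, 9, 10, 99, 100} {
		fmt.Printf("n=%d limit=%d\n", n, retransmitLimit(4, n))
	}
}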

// shuffleNodes randomly shuffles the input nodes using the Fisher-Yates shuffle
func shuffleNodes(nodes []*nodeState) {
	n := len(nodes)
	rand.Shuffle(n, func(i, j int) {
		nodes[i], nodes[j] = nodes[j], nodes[i]
	})
}

// pushPullScale is used to scale the time interval at which push/pull
// syncs take place. It is used to prevent network saturation as the
// cluster size grows.
func pushPullScale(interval time.Duration, n int) time.Duration {
	// Don't scale until we cross the threshold
	if n <= pushPullScaleThreshold {
		return interval
	}

	multiplier := math.Ceil(math.Log2(float64(n))-math.Log2(pushPullScaleThreshold)) + 1.0
	return time.Duration(multiplier) * interval
}
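
// Illustrative sketch (hypothetical name, not part of the library) tying the
// pushPullScaleThreshold doc comment to concrete numbers: at or below 32 nodes
// the interval is unchanged, the 33rd node doubles it, and the 65th triples it.
func examplePushPullScale() {
	interval := 30 * time.Second
	for _, n := range []int{16, 32, 33, 64, 65, 128} {
		// Prints 30s, 30s, 1m0s, 1m0s, 1m30s, 1m30s with a 30s base interval.
		fmt.Printf("n=%d interval=%s\n", n, pushPullScale(interval, n))
	}
}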

// moveDeadNodes moves dead and left nodes that have not changed during the
// gossipToTheDeadTime interval to the end of the slice and returns the index
// of the first moved node.
func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int {
	numDead := 0
	n := len(nodes)
	for i := 0; i < n-numDead; i++ {
		if !nodes[i].DeadOrLeft() {
			continue
		}

		// Respect the gossip to the dead interval
		if time.Since(nodes[i].StateChange) <= gossipToTheDeadTime {
			continue
		}

		// Move this node to the end
		nodes[i], nodes[n-numDead-1] = nodes[n-numDead-1], nodes[i]
		numDead++
		i--
	}
	return n - numDead
}

// kRandomNodes is used to select up to k random nodes, excluding any nodes where
// the exclude function returns true. It is possible that fewer than k nodes are
// returned.
func kRandomNodes(k int, nodes []*nodeState, exclude func(*nodeState) bool) []Node {
	n := len(nodes)
	kNodes := make([]Node, 0, k)
OUTER:
	// Probe up to 3*n times; with large n this is not necessary
	// since k << n, but with small n we want the search to be
	// exhaustive.
	for i := 0; i < 3*n && len(kNodes) < k; i++ {
		// Get a random nodeState
		idx := randomOffset(n)
		state := nodes[idx]

		// Give the filter a shot at it.
		if exclude != nil && exclude(state) {
			continue OUTER
		}

		// Check if we have this node already
		for j := 0; j < len(kNodes); j++ {
			if state.Node.Name == kNodes[j].Name {
				continue OUTER
			}
		}

		// Append the node
		kNodes = append(kNodes, state.Node)
	}
	return kNodes
}

// makeCompoundMessage takes a list of messages and generates
// a single compound message containing all of them.
func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
	// Create a local buffer
	buf := bytes.NewBuffer(nil)

	// Write out the type
	buf.WriteByte(uint8(compoundMsg))

	// Write out the number of messages
	buf.WriteByte(uint8(len(msgs)))

	// Add the message lengths
	for _, m := range msgs {
		binary.Write(buf, binary.BigEndian, uint16(len(m)))
	}

	// Append the messages
	for _, m := range msgs {
		buf.Write(m)
	}

	return buf
}

// decodeCompoundMessage splits a compound message and returns
// the slices of individual messages. It also returns the number
// of truncated messages and any potential error.
func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
	if len(buf) < 1 {
		err = fmt.Errorf("missing compound length byte")
		return
	}
	numParts := int(buf[0])
	buf = buf[1:]

	// Check that we have enough bytes
	if len(buf) < numParts*2 {
		err = fmt.Errorf("truncated len slice")
		return
	}

	// Decode the lengths
	lengths := make([]uint16, numParts)
	for i := 0; i < numParts; i++ {
		lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
	}
	buf = buf[numParts*2:]

	// Split each message
	for idx, msgLen := range lengths {
		if len(buf) < int(msgLen) {
			trunc = numParts - idx
			return
		}

		// Extract the slice and advance the buffer past it
		slice := buf[:msgLen]
		buf = buf[msgLen:]
		parts = append(parts, slice)
	}
	return
}
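
// Illustrative sketch (hypothetical name, not part of the library) of the
// compound wire format produced above: one compoundMsg type byte, a one-byte
// message count, big-endian uint16 lengths, then the message bodies.
// decodeCompoundMessage expects the payload after the type byte, so the first
// byte is skipped here.
func exampleCompoundRoundTrip() error {
	msgs := [][]byte{[]byte("ping"), []byte("ack"), []byte("suspect")}
	buf := makeCompoundMessage(msgs)

	trunc, parts, err := decodeCompoundMessage(buf.Bytes()[1:])
	if err != nil {
		return err
	}
	fmt.Printf("truncated=%d parts=%d\n", trunc, len(parts))
	return nil
}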

// compressPayload takes an opaque input buffer, compresses it
// and wraps it in a compress{} message that is encoded.
func compressPayload(inp []byte) (*bytes.Buffer, error) {
	var buf bytes.Buffer
	compressor := lzw.NewWriter(&buf, lzw.LSB, lzwLitWidth)

	_, err := compressor.Write(inp)
	if err != nil {
		return nil, err
	}

	// Ensure we flush everything out
	if err := compressor.Close(); err != nil {
		return nil, err
	}

	// Create a compressed message
	c := compress{
		Algo: lzwAlgo,
		Buf:  buf.Bytes(),
	}
	return encode(compressMsg, &c)
}

// decompressPayload is used to unpack an encoded compress{}
// message and return its payload uncompressed.
func decompressPayload(msg []byte) ([]byte, error) {
	// Decode the message
	var c compress
	if err := decode(msg, &c); err != nil {
		return nil, err
	}
	return decompressBuffer(&c)
}
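
// Illustrative sketch (hypothetical name, not part of the library): a
// compress/decompress round trip. compressPayload returns an encoded
// compress{} message, so the leading compressMsg type byte is stripped before
// handing the remainder to decompressPayload.
func exampleCompressRoundTrip() error {
	original := bytes.Repeat([]byte("go-doudou memberlist "), 32)

	buf, err := compressPayload(original)
	if err != nil {
		return err
	}

	restored, err := decompressPayload(buf.Bytes()[1:])
	if err != nil {
		return err
	}
	fmt.Printf("in=%d encoded=%d out=%d equal=%v\n",
		len(original), buf.Len(), len(restored), bytes.Equal(original, restored))
	return nil
}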

// decompressBuffer is used to decompress the buffer of
// a single compress message, handling multiple algorithms.
func decompressBuffer(c *compress) ([]byte, error) {
	// Verify the algorithm
	if c.Algo != lzwAlgo {
		return nil, fmt.Errorf("Cannot decompress unknown algorithm %d", c.Algo)
	}

	// Create an LZW reader to uncompress the buffer
	uncomp := lzw.NewReader(bytes.NewReader(c.Buf), lzw.LSB, lzwLitWidth)
	defer uncomp.Close()

	// Read all the data
	var b bytes.Buffer
	_, err := io.Copy(&b, uncomp)
	if err != nil {
		return nil, err
	}

	// Return the uncompressed bytes
	return b.Bytes(), nil
}

// joinHostPort returns the host:port form of an address, for use with a
// transport.
func joinHostPort(host string, port uint16) string {
	return net.JoinHostPort(host, strconv.Itoa(int(port)))
}

// hasPort is given a string of the form "host", "host:port", "ipv6::address",
// or "[ipv6::address]:port", and returns true if the string includes a port.
func hasPort(s string) bool {
	// IPv6 address in brackets.
	if strings.LastIndex(s, "[") == 0 {
		return strings.LastIndex(s, ":") > strings.LastIndex(s, "]")
	}

	// Otherwise the presence of a single colon determines if there's a port
	// since IPv6 addresses outside of brackets (count > 1) can't have a
	// port.
	return strings.Count(s, ":") == 1
}

// ensurePort makes sure the given string has a port number on it, otherwise it
// appends the given port as a default.
func ensurePort(s string, port int) string {
	if hasPort(s) {
		return s
	}

	// If this is an IPv6 address, the join call will add another set of
	// brackets, so we have to trim before we add the default port.
	s = strings.Trim(s, "[]")
	s = net.JoinHostPort(s, strconv.Itoa(port))
	return s
}
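
// Illustrative sketch (hypothetical name, not part of the library) of the
// address helpers: bare hosts and bracket-less IPv6 addresses get the default
// port appended, while anything that already carries a port is returned
// unchanged.
func exampleEnsurePort() {
	for _, s := range []string{"10.0.0.1", "10.0.0.1:7946", "::1", "[::1]:7946"} {
		// Prints 10.0.0.1:7946, 10.0.0.1:7946, [::1]:7946, [::1]:7946.
		fmt.Printf("%-14s -> %s\n", s, ensurePort(s, 7946))
	}
}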