github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/memberlist/util.go (about) 1 package memberlist 2 3 import ( 4 "bytes" 5 "compress/lzw" 6 "encoding/binary" 7 "fmt" 8 "io" 9 "math" 10 "math/rand" 11 "net" 12 "strconv" 13 "strings" 14 "time" 15 16 "github.com/hashicorp/go-msgpack/codec" 17 "github.com/sean-/seed" 18 ) 19 20 // pushPullScale is the minimum number of nodes 21 // before we start scaling the push/pull timing. The scale 22 // effect is the log2(Nodes) - log2(pushPullScale). This means 23 // that the 33rd node will cause us to double the interval, 24 // while the 65th will triple it. 25 const pushPullScaleThreshold = 32 26 27 const ( 28 // Constant litWidth 2-8 29 lzwLitWidth = 8 30 ) 31 32 func init() { 33 seed.Init() 34 } 35 36 // Decode reverses the encode operation on a byte slice input 37 func decode(buf []byte, out interface{}) error { 38 r := bytes.NewReader(buf) 39 hd := codec.MsgpackHandle{} 40 dec := codec.NewDecoder(r, &hd) 41 return dec.Decode(out) 42 } 43 44 // Encode writes an encoded object to a new bytes buffer 45 func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) { 46 buf := bytes.NewBuffer(nil) 47 buf.WriteByte(uint8(msgType)) 48 hd := codec.MsgpackHandle{} 49 enc := codec.NewEncoder(buf, &hd) 50 err := enc.Encode(in) 51 return buf, err 52 } 53 54 // Returns a random offset between 0 and n 55 func randomOffset(n int) int { 56 if n == 0 { 57 return 0 58 } 59 return int(rand.Uint32() % uint32(n)) 60 } 61 62 // suspicionTimeout computes the timeout that should be used when 63 // a node is suspected 64 func suspicionTimeout(suspicionMult, n int, interval time.Duration) time.Duration { 65 nodeScale := math.Max(1.0, math.Log10(math.Max(1.0, float64(n)))) 66 // multiply by 1000 to keep some precision because time.Duration is an int64 type 67 timeout := time.Duration(suspicionMult) * time.Duration(nodeScale*1000) * interval / 1000 68 return timeout 69 } 70 71 // retransmitLimit computes the limit of retransmissions 72 func retransmitLimit(retransmitMult, n int) int { 73 nodeScale := math.Ceil(math.Log10(float64(n + 1))) 74 limit := retransmitMult * int(nodeScale) 75 return limit 76 } 77 78 // shuffleNodes randomly shuffles the input nodes using the Fisher-Yates shuffle 79 func shuffleNodes(nodes []*nodeState) { 80 n := len(nodes) 81 rand.Shuffle(n, func(i, j int) { 82 nodes[i], nodes[j] = nodes[j], nodes[i] 83 }) 84 } 85 86 // pushPushScale is used to scale the time interval at which push/pull 87 // syncs take place. It is used to prevent network saturation as the 88 // cluster size grows 89 func pushPullScale(interval time.Duration, n int) time.Duration { 90 // Don't scale until we cross the threshold 91 if n <= pushPullScaleThreshold { 92 return interval 93 } 94 95 multiplier := math.Ceil(math.Log2(float64(n))-math.Log2(pushPullScaleThreshold)) + 1.0 96 return time.Duration(multiplier) * interval 97 } 98 99 // moveDeadNodes moves dead and left nodes that that have not changed during the gossipToTheDeadTime interval 100 // to the end of the slice and returns the index of the first moved node. 101 func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int { 102 numDead := 0 103 n := len(nodes) 104 for i := 0; i < n-numDead; i++ { 105 if !nodes[i].DeadOrLeft() { 106 continue 107 } 108 109 // Respect the gossip to the dead interval 110 if time.Since(nodes[i].StateChange) <= gossipToTheDeadTime { 111 continue 112 } 113 114 // Move this node to the end 115 nodes[i], nodes[n-numDead-1] = nodes[n-numDead-1], nodes[i] 116 numDead++ 117 i-- 118 } 119 return n - numDead 120 } 121 122 // kRandomNodes is used to select up to k random Nodes, excluding any nodes where 123 // the exclude function returns true. It is possible that less than k nodes are 124 // returned. 125 func kRandomNodes(k int, nodes []*nodeState, exclude func(*nodeState) bool) []Node { 126 n := len(nodes) 127 kNodes := make([]Node, 0, k) 128 OUTER: 129 // Probe up to 3*n times, with large n this is not necessary 130 // since k << n, but with small n we want search to be 131 // exhaustive 132 for i := 0; i < 3*n && len(kNodes) < k; i++ { 133 // Get random nodeState 134 idx := randomOffset(n) 135 state := nodes[idx] 136 137 // Give the filter a shot at it. 138 if exclude != nil && exclude(state) { 139 continue OUTER 140 } 141 142 // Check if we have this node already 143 for j := 0; j < len(kNodes); j++ { 144 if state.Node.Name == kNodes[j].Name { 145 continue OUTER 146 } 147 } 148 149 // Append the node 150 kNodes = append(kNodes, state.Node) 151 } 152 return kNodes 153 } 154 155 // makeCompoundMessage takes a list of messages and generates 156 // a single compound message containing all of them 157 func makeCompoundMessage(msgs [][]byte) *bytes.Buffer { 158 // Create a local buffer 159 buf := bytes.NewBuffer(nil) 160 161 // Write out the type 162 buf.WriteByte(uint8(compoundMsg)) 163 164 // Write out the number of message 165 buf.WriteByte(uint8(len(msgs))) 166 167 // Add the message lengths 168 for _, m := range msgs { 169 binary.Write(buf, binary.BigEndian, uint16(len(m))) 170 } 171 172 // Append the messages 173 for _, m := range msgs { 174 buf.Write(m) 175 } 176 177 return buf 178 } 179 180 // decodeCompoundMessage splits a compound message and returns 181 // the slices of individual messages. Also returns the number 182 // of truncated messages and any potential error 183 func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) { 184 if len(buf) < 1 { 185 err = fmt.Errorf("missing compound length byte") 186 return 187 } 188 numParts := int(buf[0]) 189 buf = buf[1:] 190 191 // Check we have enough bytes 192 if len(buf) < numParts*2 { 193 err = fmt.Errorf("truncated len slice") 194 return 195 } 196 197 // Decode the lengths 198 lengths := make([]uint16, numParts) 199 for i := 0; i < numParts; i++ { 200 lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2]) 201 } 202 buf = buf[numParts*2:] 203 204 // Split each message 205 for idx, msgLen := range lengths { 206 if len(buf) < int(msgLen) { 207 trunc = numParts - idx 208 return 209 } 210 211 // Extract the slice, seek past on the buffer 212 slice := buf[:msgLen] 213 buf = buf[msgLen:] 214 parts = append(parts, slice) 215 } 216 return 217 } 218 219 // compressPayload takes an opaque input buffer, compresses it 220 // and wraps it in a compress{} message that is encoded. 221 func compressPayload(inp []byte) (*bytes.Buffer, error) { 222 var buf bytes.Buffer 223 compressor := lzw.NewWriter(&buf, lzw.LSB, lzwLitWidth) 224 225 _, err := compressor.Write(inp) 226 if err != nil { 227 return nil, err 228 } 229 230 // Ensure we flush everything out 231 if err := compressor.Close(); err != nil { 232 return nil, err 233 } 234 235 // Create a compressed message 236 c := compress{ 237 Algo: lzwAlgo, 238 Buf: buf.Bytes(), 239 } 240 return encode(compressMsg, &c) 241 } 242 243 // decompressPayload is used to unpack an encoded compress{} 244 // message and return its payload uncompressed 245 func decompressPayload(msg []byte) ([]byte, error) { 246 // Decode the message 247 var c compress 248 if err := decode(msg, &c); err != nil { 249 return nil, err 250 } 251 return decompressBuffer(&c) 252 } 253 254 // decompressBuffer is used to decompress the buffer of 255 // a single compress message, handling multiple algorithms 256 func decompressBuffer(c *compress) ([]byte, error) { 257 // Verify the algorithm 258 if c.Algo != lzwAlgo { 259 return nil, fmt.Errorf("Cannot decompress unknown algorithm %d", c.Algo) 260 } 261 262 // Create a uncompressor 263 uncomp := lzw.NewReader(bytes.NewReader(c.Buf), lzw.LSB, lzwLitWidth) 264 defer uncomp.Close() 265 266 // Read all the data 267 var b bytes.Buffer 268 _, err := io.Copy(&b, uncomp) 269 if err != nil { 270 return nil, err 271 } 272 273 // Return the uncompressed bytes 274 return b.Bytes(), nil 275 } 276 277 // joinHostPort returns the host:port form of an address, for use with a 278 // transport. 279 func joinHostPort(host string, port uint16) string { 280 return net.JoinHostPort(host, strconv.Itoa(int(port))) 281 } 282 283 // hasPort is given a string of the form "host", "host:port", "ipv6::address", 284 // or "[ipv6::address]:port", and returns true if the string includes a port. 285 func hasPort(s string) bool { 286 // IPv6 address in brackets. 287 if strings.LastIndex(s, "[") == 0 { 288 return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") 289 } 290 291 // Otherwise the presence of a single colon determines if there's a port 292 // since IPv6 addresses outside of brackets (count > 1) can't have a 293 // port. 294 return strings.Count(s, ":") == 1 295 } 296 297 // ensurePort makes sure the given string has a port number on it, otherwise it 298 // appends the given port as a default. 299 func ensurePort(s string, port int) string { 300 if hasPort(s) { 301 return s 302 } 303 304 // If this is an IPv6 address, the join call will add another set of 305 // brackets, so we have to trim before we add the default port. 306 s = strings.Trim(s, "[]") 307 s = net.JoinHostPort(s, strconv.Itoa(port)) 308 return s 309 }