github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/memberlist/net.go (about) 1 package memberlist 2 3 import ( 4 "bufio" 5 "bytes" 6 "encoding/binary" 7 "fmt" 8 "github.com/unionj-cloud/go-doudou/v2/toolkit/stringutils" 9 "hash/crc32" 10 "io" 11 "net" 12 "sync/atomic" 13 "time" 14 15 metrics "github.com/armon/go-metrics" 16 "github.com/hashicorp/go-msgpack/codec" 17 ) 18 19 // This is the minimum and maximum protocol version that we can 20 // _understand_. We're allowed to speak at any version within this 21 // range. This range is inclusive. 22 const ( 23 ProtocolVersionMin uint8 = 1 24 25 // Version 3 added support for TCP pings but we kept the default 26 // protocol version at 2 to ease transition to this new feature. 27 // A memberlist speaking version 2 of the protocol will attempt 28 // to TCP ping another memberlist who understands version 3 or 29 // greater. 30 // 31 // Version 4 added support for nacks as part of indirect probes. 32 // A memberlist speaking version 2 of the protocol will expect 33 // nacks from another memberlist who understands version 4 or 34 // greater, and likewise nacks will be sent to memberlists who 35 // understand version 4 or greater. 36 ProtocolVersion2Compatible = 2 37 38 ProtocolVersionMax = 5 39 ) 40 41 // messageType is an integer ID of a type of message that can be received 42 // on network channels from other members. 43 type messageType uint8 44 45 // The list of available message types. 
46 const ( 47 pingMsg messageType = iota 48 indirectPingMsg 49 ackRespMsg 50 suspectMsg 51 aliveMsg 52 deadMsg 53 pushPullMsg 54 compoundMsg 55 userMsg // User message, not handled by us 56 compressMsg 57 encryptMsg 58 nackRespMsg 59 hasCrcMsg 60 errMsg 61 weightMsg 62 ) 63 64 // compressionType is used to specify the compression algorithm 65 type compressionType uint8 66 67 const ( 68 lzwAlgo compressionType = iota 69 ) 70 71 const ( 72 MetaMaxSize = 512 // Maximum size for node meta data 73 compoundHeaderOverhead = 2 // Assumed header overhead 74 compoundOverhead = 2 // Assumed overhead per entry in compoundHeader 75 userMsgOverhead = 1 76 weightMsgOverhead = 1 77 blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process 78 maxPushStateBytes = 20 * 1024 * 1024 79 maxPushPullRequests = 128 // Maximum number of concurrent push/pull requests 80 ) 81 82 // ping request sent directly to node 83 type ping struct { 84 SeqNo uint32 85 86 // Node is sent so the target can verify they are 87 // the intended recipient. This is to protect again an agent 88 // restart with a new name. 89 Node string 90 91 SourceAddr string `codec:",omitempty"` // Source address, used for a direct reply 92 SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply 93 SourceNode string `codec:",omitempty"` // Source name, used for a direct reply 94 } 95 96 // indirect ping sent to an indirect node 97 type indirectPingReq struct { 98 SeqNo uint32 99 Target string 100 Port uint16 101 102 // Node is sent so the target can verify they are 103 // the intended recipient. This is to protect against an agent 104 // restart with a new name. 
105 Node string 106 107 Nack bool // true if we'd like a nack back 108 109 SourceAddr string `codec:",omitempty"` // Source address, used for a direct reply 110 SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply 111 SourceNode string `codec:",omitempty"` // Source name, used for a direct reply 112 } 113 114 // ack response is sent for a ping 115 type ackResp struct { 116 SeqNo uint32 117 Payload []byte 118 } 119 120 // nack response is sent for an indirect ping when the pinger doesn't hear from 121 // the ping-ee within the configured timeout. This lets the original node know 122 // that the indirect ping attempt happened but didn't succeed. 123 type nackResp struct { 124 SeqNo uint32 125 } 126 127 // err response is sent to relay the error from the remote end 128 type errResp struct { 129 Error string 130 } 131 132 // suspect is broadcast when we suspect a node is dead 133 type suspect struct { 134 Incarnation uint32 135 Node string 136 From string // Include who is suspecting 137 } 138 139 // alive is broadcast when we know a node is alive. 
140 // Overloaded for nodes joining 141 type alive struct { 142 Incarnation uint32 143 Node string 144 Addr string 145 Port uint16 146 Meta []byte 147 148 // The versions of the protocol/delegate that are being spoken, order: 149 // pmin, pmax, pcur, dmin, dmax, dcur 150 Vsn []uint8 151 } 152 153 func NewAlive(incarnation uint32, node string, addr string, port uint16, meta []byte, vsn []uint8) alive { 154 return alive{Incarnation: incarnation, Node: node, Addr: addr, Port: port, Meta: meta, Vsn: vsn} 155 } 156 157 // dead is broadcast when we confirm a node is dead 158 // Overloaded for nodes leaving 159 type dead struct { 160 Incarnation uint32 161 Node string 162 From string // Include who is suspecting 163 } 164 165 // weight is broadcast when we send local node weight 166 type weight struct { 167 Incarnation uint32 168 // Node whose weight 169 Node string 170 // From message from which node 171 From string 172 // Weight the weight for Node 173 Weight int 174 // WeightAt is UTC timestamp which the weight calculated at, used for ignoring old weight messages in milliseconds 175 WeightAt int64 176 } 177 178 func NewWeight(incarnation uint32, node string, from string, wei int, weightAt int64) *weight { 179 return &weight{Incarnation: incarnation, Node: node, From: from, Weight: wei, WeightAt: weightAt} 180 } 181 182 // pushPullHeader is used to inform the 183 // otherside how many states we are transferring 184 type pushPullHeader struct { 185 Nodes int 186 UserStateLen int // Encodes the byte lengh of user state 187 Join bool // Is this a join request or a anti-entropy run 188 } 189 190 // userMsgHeader is used to encapsulate a userMsg 191 type userMsgHeader struct { 192 UserMsgLen int // Encodes the byte lengh of user state 193 } 194 195 // pushNodeState is used for pushPullReq when we are 196 // transferring out node states 197 type pushNodeState struct { 198 Name string 199 Addr string 200 Port uint16 201 Meta []byte 202 Incarnation uint32 203 State 
NodeStateType 204 Vsn []uint8 // Protocol versions 205 } 206 207 // compress is used to wrap an underlying payload 208 // using a specified compression algorithm 209 type compress struct { 210 Algo compressionType 211 Buf []byte 212 } 213 214 // msgHandoff is used to transfer a message between goroutines 215 type msgHandoff struct { 216 msgType messageType 217 buf []byte 218 from net.Addr 219 } 220 221 // encryptionVersion returns the encryption version to use 222 func (m *Memberlist) encryptionVersion() encryptionVersion { 223 switch m.ProtocolVersion() { 224 case 1: 225 return 0 226 default: 227 return 1 228 } 229 } 230 231 // streamListen is a long running goroutine that pulls incoming streams from the 232 // transport and hands them off for processing. 233 func (m *Memberlist) streamListen() { 234 for { 235 select { 236 case conn := <-m.transport.StreamCh(): 237 go m.handleConn(conn) 238 239 case <-m.shutdownCh: 240 return 241 } 242 } 243 } 244 245 // handleConn handles a single incoming stream connection from the transport. 
246 func (m *Memberlist) handleConn(conn net.Conn) { 247 defer conn.Close() 248 m.logger.Printf("[DEBUG] memberlist: Stream connection %s", LogConn(conn)) 249 250 metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) 251 252 from := conn.RemoteAddr() 253 if err := m.ensureCanConnect(from); err != nil { 254 m.logger.Printf("[DEBUG] memberlist: Blocked message: %s from %s", err, LogAddress(from)) 255 return 256 } 257 258 conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) 259 260 msgType, bufConn, dec, err := m.readStream(conn) 261 if err != nil { 262 if err != io.EOF { 263 m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn)) 264 265 resp := errResp{err.Error()} 266 out, err := encode(errMsg, &resp) 267 if err != nil { 268 m.logger.Printf("[ERR] memberlist: Failed to encode error response: %s", err) 269 return 270 } 271 272 err = m.rawSendMsgStream(conn, out.Bytes()) 273 if err != nil { 274 m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn)) 275 return 276 } 277 } 278 return 279 } 280 281 switch msgType { 282 case userMsg: 283 if err := m.readUserMsg(bufConn, dec); err != nil { 284 m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn)) 285 } 286 case pushPullMsg: 287 // Increment counter of pending push/pulls 288 numConcurrent := atomic.AddUint32(&m.pushPullReq, 1) 289 defer atomic.AddUint32(&m.pushPullReq, ^uint32(0)) 290 291 // Check if we have too many open push/pull requests 292 if numConcurrent >= maxPushPullRequests { 293 m.logger.Printf("[ERR] memberlist: Too many pending push/pull requests") 294 return 295 } 296 297 join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec) 298 if err != nil { 299 m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn)) 300 return 301 } 302 303 if join { 304 remote := remoteNodes[0] 305 if m.config.IPMustBeChecked() { 306 if stringutils.IsNotEmpty(remote.Addr) { 307 if 
err := m.config.AddrAllowed(remote.Addr); err != nil { 308 m.logger.Printf("[DEBUG] memberlist: Blocked join.Addr=%s message from: %s %s", remote.Addr, err, LogAddress(from)) 309 return 310 } 311 } 312 } 313 } 314 315 if err := m.sendLocalState(conn, join); err != nil { 316 m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn)) 317 return 318 } 319 320 if err := m.mergeRemoteState(join, remoteNodes, userState); err != nil { 321 m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn)) 322 return 323 } 324 case pingMsg: 325 var p ping 326 if err := dec.Decode(&p); err != nil { 327 m.logger.Printf("[ERR] memberlist: Failed to decode ping: %s %s", err, LogConn(conn)) 328 return 329 } 330 331 if p.Node != "" && p.Node != m.config.Name { 332 m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", p.Node, LogConn(conn)) 333 return 334 } 335 336 if m.config.IPMustBeChecked() { 337 if stringutils.IsNotEmpty(p.SourceAddr) { 338 if err := m.config.AddrAllowed(p.SourceAddr); err != nil { 339 m.logger.Printf("[DEBUG] memberlist: Blocked ping.Addr=%s message from: %s %s", p.SourceAddr, err, LogAddress(from)) 340 return 341 } 342 } 343 } 344 345 ack := ackResp{p.SeqNo, nil} 346 out, err := encode(ackRespMsg, &ack) 347 if err != nil { 348 m.logger.Printf("[ERR] memberlist: Failed to encode ack: %s", err) 349 return 350 } 351 352 err = m.rawSendMsgStream(conn, out.Bytes()) 353 if err != nil { 354 m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn)) 355 return 356 } 357 default: 358 m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", msgType, LogConn(conn)) 359 } 360 } 361 362 // packetListen is a long-running goroutine that pulls packets out of the 363 // transport and hands them off for processing. 
364 func (m *Memberlist) packetListen() { 365 for { 366 select { 367 case packet := <-m.transport.PacketCh(): 368 m.ingestPacket(packet.Buf, packet.From, packet.Timestamp) 369 370 case <-m.shutdownCh: 371 return 372 } 373 } 374 } 375 376 func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { 377 if err := m.ensureCanConnect(from); err != nil { 378 m.logger.Printf("[DEBUG] memberlist: Blocked message: %s from %s", err, LogAddress(from)) 379 return 380 } 381 // Check if encryption is enabled 382 if m.config.EncryptionEnabled() { 383 // Decrypt the payload 384 plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) 385 if err != nil { 386 if !m.config.GossipVerifyIncoming { 387 // Treat the message as plaintext 388 plain = buf 389 } else { 390 m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from)) 391 return 392 } 393 } 394 395 // Continue processing the plaintext buffer 396 buf = plain 397 } 398 399 // See if there's a checksum included to verify the contents of the message 400 if len(buf) >= 5 && messageType(buf[0]) == hasCrcMsg { 401 crc := crc32.ChecksumIEEE(buf[5:]) 402 expected := binary.BigEndian.Uint32(buf[1:5]) 403 if crc != expected { 404 m.logger.Printf("[WARN] memberlist: Got invalid checksum for UDP packet: %x, %x", crc, expected) 405 return 406 } 407 m.handleCommand(buf[5:], from, timestamp) 408 } else { 409 m.handleCommand(buf, from, timestamp) 410 } 411 } 412 413 func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) { 414 if len(buf) < 1 { 415 m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from)) 416 return 417 } 418 // Decode the message type 419 msgType := messageType(buf[0]) 420 buf = buf[1:] 421 422 // Switch on the msgType 423 switch msgType { 424 case compoundMsg: 425 m.handleCompound(buf, from, timestamp) 426 case compressMsg: 427 m.handleCompressed(buf, from, timestamp) 428 429 case pingMsg: 430 m.handlePing(buf, 
from) 431 case indirectPingMsg: 432 m.handleIndirectPing(buf, from) 433 case ackRespMsg: 434 m.handleAck(buf, from, timestamp) 435 case nackRespMsg: 436 m.handleNack(buf, from) 437 438 case suspectMsg: 439 fallthrough 440 case aliveMsg: 441 fallthrough 442 case deadMsg: 443 fallthrough 444 case weightMsg: 445 fallthrough 446 case userMsg: 447 // Determine the message queue, prioritize alive 448 queue := m.lowPriorityMsgQueue 449 if msgType == aliveMsg { 450 queue = m.highPriorityMsgQueue 451 } 452 453 // Check for overflow and append if not full 454 m.msgQueueLock.Lock() 455 if queue.Len() >= m.config.HandoffQueueDepth { 456 m.logger.Printf("[WARN] memberlist: handler queue full, dropping message (%d) %s", msgType, LogAddress(from)) 457 } else { 458 queue.PushBack(msgHandoff{msgType, buf, from}) 459 } 460 m.msgQueueLock.Unlock() 461 462 // Notify of pending message 463 select { 464 case m.handoffCh <- struct{}{}: 465 default: 466 } 467 468 default: 469 m.logger.Printf("[ERR] memberlist: msg type (%d) not supported %s", msgType, LogAddress(from)) 470 } 471 } 472 473 // getNextMessage returns the next message to process in priority order, using LIFO 474 func (m *Memberlist) getNextMessage() (msgHandoff, bool) { 475 m.msgQueueLock.Lock() 476 defer m.msgQueueLock.Unlock() 477 478 if el := m.highPriorityMsgQueue.Back(); el != nil { 479 m.highPriorityMsgQueue.Remove(el) 480 msg := el.Value.(msgHandoff) 481 return msg, true 482 } else if el := m.lowPriorityMsgQueue.Back(); el != nil { 483 m.lowPriorityMsgQueue.Remove(el) 484 msg := el.Value.(msgHandoff) 485 return msg, true 486 } 487 return msgHandoff{}, false 488 } 489 490 // packetHandler is a long-running goroutine that processes messages received 491 // over the packet interface, but is decoupled from the listener to avoid 492 // blocking the listener which may cause ping/ack messages to be delayed. 
493 func (m *Memberlist) packetHandler() { 494 for { 495 select { 496 case <-m.handoffCh: 497 for { 498 msg, ok := m.getNextMessage() 499 if !ok { 500 break 501 } 502 msgType := msg.msgType 503 buf := msg.buf 504 from := msg.from 505 506 switch msgType { 507 case suspectMsg: 508 m.handleSuspect(buf, from) 509 case aliveMsg: 510 m.handleAlive(buf, from) 511 case deadMsg: 512 m.handleDead(buf, from) 513 case weightMsg: 514 m.handleWeight(buf, from) 515 case userMsg: 516 m.handleUser(buf, from) 517 default: 518 m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from)) 519 } 520 } 521 522 case <-m.shutdownCh: 523 return 524 } 525 } 526 } 527 528 func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) { 529 // Decode the parts 530 trunc, parts, err := decodeCompoundMessage(buf) 531 if err != nil { 532 m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from)) 533 return 534 } 535 536 // Log any truncation 537 if trunc > 0 { 538 m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", trunc, LogAddress(from)) 539 } 540 541 // Handle each message 542 for _, part := range parts { 543 m.handleCommand(part, from, timestamp) 544 } 545 } 546 547 func (m *Memberlist) handlePing(buf []byte, from net.Addr) { 548 var p ping 549 if err := decode(buf, &p); err != nil { 550 m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from)) 551 return 552 } 553 // If node is provided, verify that it is for us 554 if p.Node != "" && p.Node != m.config.Name { 555 m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", p.Node, LogAddress(from)) 556 return 557 } 558 559 if m.config.IPMustBeChecked() { 560 if stringutils.IsNotEmpty(p.SourceAddr) { 561 if err := m.config.AddrAllowed(p.SourceAddr); err != nil { 562 m.logger.Printf("[DEBUG] memberlist: Blocked ping.Addr=%s message from: %s 
%s", p.SourceAddr, err, LogAddress(from)) 563 return 564 } 565 } 566 } 567 568 var ack ackResp 569 ack.SeqNo = p.SeqNo 570 if m.config.Ping != nil { 571 ack.Payload = m.config.Ping.AckPayload() 572 } 573 574 addr := "" 575 if len(p.SourceAddr) > 0 && p.SourcePort > 0 { 576 addr = joinHostPort(p.SourceAddr, p.SourcePort) 577 } else { 578 addr = from.String() 579 } 580 581 a := Address{ 582 Addr: addr, 583 Name: p.SourceNode, 584 } 585 if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { 586 m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from)) 587 } 588 } 589 590 func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { 591 var ind indirectPingReq 592 if err := decode(buf, &ind); err != nil { 593 m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from)) 594 return 595 } 596 597 if m.config.IPMustBeChecked() { 598 if stringutils.IsNotEmpty(ind.SourceAddr) { 599 if err := m.config.AddrAllowed(ind.SourceAddr); err != nil { 600 m.logger.Printf("[DEBUG] memberlist: Blocked indirectPing.Addr=%s message from: %s %s", ind.SourceAddr, err, LogAddress(from)) 601 return 602 } 603 } 604 } 605 606 // For proto versions < 2, there is no port provided. Mask old 607 // behavior by using the configured port. 608 if m.ProtocolVersion() < 2 || ind.Port == 0 { 609 ind.Port = uint16(m.config.BindPort) 610 } 611 612 // Send a ping to the correct host. 613 localSeqNo := m.nextSeqNo() 614 selfAddr, selfPort := m.getAdvertise() 615 ping := ping{ 616 SeqNo: localSeqNo, 617 Node: ind.Node, 618 // The outbound message is addressed FROM us. 619 SourceAddr: selfAddr, 620 SourcePort: selfPort, 621 SourceNode: m.config.Name, 622 } 623 624 // Forward the ack back to the requestor. If the request encodes an origin 625 // use that otherwise assume that the other end of the UDP socket is 626 // usable. 
627 indAddr := "" 628 if len(ind.SourceAddr) > 0 && ind.SourcePort > 0 { 629 indAddr = joinHostPort(ind.SourceAddr, ind.SourcePort) 630 } else { 631 indAddr = from.String() 632 } 633 634 // Setup a response handler to relay the ack 635 cancelCh := make(chan struct{}) 636 respHandler := func(payload []byte, timestamp time.Time) { 637 // Try to prevent the nack if we've caught it in time. 638 close(cancelCh) 639 640 ack := ackResp{ind.SeqNo, nil} 641 a := Address{ 642 Addr: indAddr, 643 Name: ind.SourceNode, 644 } 645 if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { 646 m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogStringAddress(indAddr)) 647 } 648 } 649 m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout) 650 651 // Send the ping. 652 addr := joinHostPort(ind.Target, ind.Port) 653 a := Address{ 654 Addr: addr, 655 Name: ind.Node, 656 } 657 if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil { 658 m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogStringAddress(indAddr)) 659 } 660 661 // Setup a timer to fire off a nack if no ack is seen in time. 
662 if ind.Nack { 663 go func() { 664 select { 665 case <-cancelCh: 666 return 667 case <-time.After(m.config.ProbeTimeout): 668 nack := nackResp{ind.SeqNo} 669 a := Address{ 670 Addr: indAddr, 671 Name: ind.SourceNode, 672 } 673 if err := m.encodeAndSendMsg(a, nackRespMsg, &nack); err != nil { 674 m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogStringAddress(indAddr)) 675 } 676 } 677 }() 678 } 679 } 680 681 func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) { 682 var ack ackResp 683 if err := decode(buf, &ack); err != nil { 684 m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from)) 685 return 686 } 687 m.invokeAckHandler(ack, timestamp) 688 } 689 690 func (m *Memberlist) handleNack(buf []byte, from net.Addr) { 691 var nack nackResp 692 if err := decode(buf, &nack); err != nil { 693 m.logger.Printf("[ERR] memberlist: Failed to decode nack response: %s %s", err, LogAddress(from)) 694 return 695 } 696 m.invokeNackHandler(nack) 697 } 698 699 func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) { 700 var sus suspect 701 if err := decode(buf, &sus); err != nil { 702 m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from)) 703 return 704 } 705 m.suspectNode(&sus) 706 } 707 708 // ensureCanConnect return the IP from a RemoteAddress 709 // return error if this client must not connect 710 func (m *Memberlist) ensureCanConnect(from net.Addr) error { 711 if !m.config.IPMustBeChecked() { 712 return nil 713 } 714 source := from.String() 715 if source == "pipe" { 716 return nil 717 } 718 host, _, err := net.SplitHostPort(source) 719 if err != nil { 720 return err 721 } 722 return m.config.AddrAllowed(host) 723 } 724 725 func (m *Memberlist) handleAlive(buf []byte, from net.Addr) { 726 if err := m.ensureCanConnect(from); err != nil { 727 m.logger.Printf("[DEBUG] memberlist: Blocked alive message: %s %s", err, LogAddress(from)) 728 
return 729 } 730 var live alive 731 if err := decode(buf, &live); err != nil { 732 m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from)) 733 return 734 } 735 if m.config.IPMustBeChecked() { 736 if stringutils.IsNotEmpty(live.Addr) { 737 if err := m.config.AddrAllowed(live.Addr); err != nil { 738 m.logger.Printf("[DEBUG] memberlist: Blocked alive.Addr=%s message from: %s %s", live.Addr, err, LogAddress(from)) 739 return 740 } 741 } 742 } 743 744 // For proto versions < 2, there is no port provided. Mask old 745 // behavior by using the configured port 746 if m.ProtocolVersion() < 2 || live.Port == 0 { 747 live.Port = uint16(m.config.BindPort) 748 } 749 750 m.aliveNode(&live, nil, false) 751 } 752 753 func (m *Memberlist) handleDead(buf []byte, from net.Addr) { 754 var d dead 755 if err := decode(buf, &d); err != nil { 756 m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from)) 757 return 758 } 759 m.deadNode(&d) 760 } 761 762 func (m *Memberlist) handleWeight(buf []byte, from net.Addr) { 763 var wei weight 764 if err := decode(buf, &wei); err != nil { 765 m.logger.Printf("[ERR] memberlist: Failed to decode weight message: %s %s", err, LogAddress(from)) 766 return 767 } 768 m.weightNode(&wei) 769 } 770 771 // handleUser is used to notify channels of incoming user data 772 func (m *Memberlist) handleUser(buf []byte, from net.Addr) { 773 d := m.config.Delegate 774 if d != nil { 775 d.NotifyMsg(buf) 776 } 777 } 778 779 // handleCompressed is used to unpack a compressed message 780 func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) { 781 // Try to decode the payload 782 payload, err := decompressPayload(buf) 783 if err != nil { 784 m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from)) 785 return 786 } 787 788 // Recursively handle the payload 789 m.handleCommand(payload, from, timestamp) 790 } 791 792 // 
encodeAndSendMsg is used to combine the encoding and sending steps 793 func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interface{}) error { 794 out, err := encode(msgType, msg) 795 if err != nil { 796 return err 797 } 798 if err := m.sendMsg(a, out.Bytes()); err != nil { 799 return err 800 } 801 return nil 802 } 803 804 // sendMsg is used to send a message via packet to another host. It will 805 // opportunistically create a compoundMsg and piggy back other broadcasts. 806 func (m *Memberlist) sendMsg(a Address, msg []byte) error { 807 // Check if we can piggy back any messages 808 bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead 809 if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { 810 bytesAvail -= encryptOverhead(m.encryptionVersion()) 811 } 812 extra := m.getBroadcasts(compoundOverhead, bytesAvail) 813 814 // Fast path if nothing to piggypack 815 if len(extra) == 0 { 816 return m.rawSendMsgPacket(a, nil, msg) 817 } 818 819 // Join all the messages 820 msgs := make([][]byte, 0, 1+len(extra)) 821 msgs = append(msgs, msg) 822 msgs = append(msgs, extra...) 823 824 // Create a compound message 825 compound := makeCompoundMessage(msgs) 826 827 // Send the message 828 return m.rawSendMsgPacket(a, nil, compound.Bytes()) 829 } 830 831 // rawSendMsgPacket is used to send message via packet to another host without 832 // modification, other than compression or encryption if enabled. 
833 func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { 834 if a.Name == "" && m.config.RequireNodeNames { 835 return errNodeNamesAreRequired 836 } 837 838 // Check if we have compression enabled 839 if m.config.EnableCompression { 840 buf, err := compressPayload(msg) 841 if err != nil { 842 m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err) 843 } else { 844 // Only use compression if it reduced the size 845 if buf.Len() < len(msg) { 846 msg = buf.Bytes() 847 } 848 } 849 } 850 851 // Try to look up the destination node. Note this will only work if the 852 // bare ip address is used as the node name, which is not guaranteed. 853 if node == nil { 854 toAddr, _, err := net.SplitHostPort(a.Addr) 855 if err != nil { 856 m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", a.Addr, err) 857 return err 858 } 859 m.nodeLock.RLock() 860 nodeState, ok := m.nodeMap[toAddr] 861 m.nodeLock.RUnlock() 862 if ok { 863 node = &nodeState.Node 864 } 865 } 866 867 // Add a CRC to the end of the payload if the recipient understands 868 // ProtocolVersion >= 5 869 if node != nil && node.PMax >= 5 { 870 crc := crc32.ChecksumIEEE(msg) 871 header := make([]byte, 5, 5+len(msg)) 872 header[0] = byte(hasCrcMsg) 873 binary.BigEndian.PutUint32(header[1:], crc) 874 msg = append(header, msg...) 
875 } 876 877 // Check if we have encryption enabled 878 if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { 879 // Encrypt the payload 880 var buf bytes.Buffer 881 primaryKey := m.config.Keyring.GetPrimaryKey() 882 err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) 883 if err != nil { 884 m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) 885 return err 886 } 887 msg = buf.Bytes() 888 } 889 890 metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg))) 891 _, err := m.transport.WriteToAddress(msg, a) 892 return err 893 } 894 895 // rawSendMsgStream is used to stream a message to another host without 896 // modification, other than applying compression and encryption if enabled. 897 func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { 898 // Check if compression is enabled 899 if m.config.EnableCompression { 900 compBuf, err := compressPayload(sendBuf) 901 if err != nil { 902 m.logger.Printf("[ERROR] memberlist: Failed to compress payload: %v", err) 903 } else { 904 sendBuf = compBuf.Bytes() 905 } 906 } 907 908 // Check if encryption is enabled 909 if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { 910 crypt, err := m.encryptLocalState(sendBuf) 911 if err != nil { 912 m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) 913 return err 914 } 915 sendBuf = crypt 916 } 917 918 // Write out the entire send buffer 919 metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf))) 920 921 if n, err := conn.Write(sendBuf); err != nil { 922 return err 923 } else if n != len(sendBuf) { 924 return fmt.Errorf("only %d of %d bytes written", n, len(sendBuf)) 925 } 926 927 return nil 928 } 929 930 // sendUserMsg is used to stream a user message to another host. 
931 func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { 932 if a.Name == "" && m.config.RequireNodeNames { 933 return errNodeNamesAreRequired 934 } 935 936 conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout) 937 if err != nil { 938 return err 939 } 940 defer conn.Close() 941 942 bufConn := bytes.NewBuffer(nil) 943 if err := bufConn.WriteByte(byte(userMsg)); err != nil { 944 return err 945 } 946 947 header := userMsgHeader{UserMsgLen: len(sendBuf)} 948 hd := codec.MsgpackHandle{} 949 enc := codec.NewEncoder(bufConn, &hd) 950 if err := enc.Encode(&header); err != nil { 951 return err 952 } 953 if _, err := bufConn.Write(sendBuf); err != nil { 954 return err 955 } 956 return m.rawSendMsgStream(conn, bufConn.Bytes()) 957 } 958 959 // sendAndReceiveState is used to initiate a push/pull over a stream with a 960 // remote host. 961 func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, []byte, error) { 962 if a.Name == "" && m.config.RequireNodeNames { 963 return nil, nil, errNodeNamesAreRequired 964 } 965 966 // Attempt to connect 967 conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout) 968 if err != nil { 969 return nil, nil, err 970 } 971 defer conn.Close() 972 m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s %s", a.Name, conn.RemoteAddr()) 973 metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) 974 975 // Send our state 976 if err := m.sendLocalState(conn, join); err != nil { 977 return nil, nil, err 978 } 979 980 conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) 981 msgType, bufConn, dec, err := m.readStream(conn) 982 if err != nil { 983 return nil, nil, err 984 } 985 986 if msgType == errMsg { 987 var resp errResp 988 if err := dec.Decode(&resp); err != nil { 989 return nil, nil, err 990 } 991 return nil, nil, fmt.Errorf("remote error: %v", resp.Error) 992 } 993 994 // Quit if not push/pull 995 if msgType != pushPullMsg { 996 err := fmt.Errorf("received 
invalid msgType (%d), expected pushPullMsg (%d) %s", msgType, pushPullMsg, LogConn(conn))
		return nil, nil, err
	}

	// Read remote state
	_, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
	return remoteNodes, userState, err
}

// sendLocalState is invoked to send our local state over a stream connection.
//
// It snapshots m.nodes under the read lock, asks the delegate (if any) for
// its user state, serializes everything into an in-memory buffer as
// [pushPullMsg byte][pushPullHeader][one pushNodeState per node][user state],
// and hands the buffer to rawSendMsgStream. The join flag is forwarded to
// Delegate.LocalState and recorded in the header.
func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
	// Setup a deadline. Best effort: the SetDeadline error is not checked.
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	// Prepare the local node state while holding the read lock, so the
	// snapshot is taken against a stable m.nodes.
	m.nodeLock.RLock()
	localNodes := make([]pushNodeState, len(m.nodes))
	for idx, n := range m.nodes {
		localNodes[idx].Name = n.Name
		localNodes[idx].Addr = n.Addr
		localNodes[idx].Port = n.Port
		localNodes[idx].Incarnation = n.Incarnation
		localNodes[idx].State = n.State
		localNodes[idx].Meta = n.Meta
		localNodes[idx].Vsn = []uint8{
			n.PMin, n.PMax, n.PCur,
			n.DMin, n.DMax, n.DCur,
		}
	}
	m.nodeLock.RUnlock()

	// Get the delegate state
	var userData []byte
	if m.config.Delegate != nil {
		userData = m.config.Delegate.LocalState(join)
	}

	// Create a bytes buffer writer
	bufConn := bytes.NewBuffer(nil)

	// Send our node state
	header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join}
	hd := codec.MsgpackHandle{}
	enc := codec.NewEncoder(bufConn, &hd)

	// Begin state push
	if _, err := bufConn.Write([]byte{byte(pushPullMsg)}); err != nil {
		return err
	}

	if err := enc.Encode(&header); err != nil {
		return err
	}
	// header.Nodes equals len(localNodes), so ranging over the slice encodes
	// exactly the number of entries announced in the header.
	for i := range localNodes {
		if err := enc.Encode(&localNodes[i]); err != nil {
			return err
		}
	}

	// Write the user state as well. It is appended raw (not msgpack encoded);
	// its length is carried in header.UserStateLen.
	if userData != nil {
		if _, err := bufConn.Write(userData); err != nil {
			return err
		}
	}

	// Get the send buffer
	return m.rawSendMsgStream(conn, bufConn.Bytes())
}

// encryptLocalState is used to help encrypt local state before sending.
//
// The returned buffer is laid out as
// [encryptMsg byte][4-byte big-endian length][output of encryptPayload],
// where the length field holds encryptedLength(encVsn, len(sendBuf)).
func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
	var buf bytes.Buffer

	// Write the encryptMsg byte
	buf.WriteByte(byte(encryptMsg))

	// Write the size of the message
	sizeBuf := make([]byte, 4)
	encVsn := m.encryptionVersion()
	encLen := encryptedLength(encVsn, len(sendBuf))
	binary.BigEndian.PutUint32(sizeBuf, uint32(encLen))
	buf.Write(sizeBuf)

	// Write the encrypted cipher text to the buffer. The 5-byte header
	// written above is passed alongside the payload; decryptRemoteState
	// hands the same 5 bytes to decryptPayload on the receiving side.
	key := m.config.Keyring.GetPrimaryKey()
	err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf)
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decryptRemoteState is used to help decrypt the remote state.
//
// bufConn must be positioned just after the encryptMsg type byte. The
// function reads the 4-byte big-endian length, refuses payloads larger than
// maxPushStateBytes, reads that many bytes and returns the result of
// decryptPayload over them.
func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
	// Read in enough to determine message length. The type byte has already
	// been consumed by the caller, so it is re-inserted here to rebuild the
	// same 5-byte header the sender produced.
	cipherText := bytes.NewBuffer(nil)
	cipherText.WriteByte(byte(encryptMsg))
	_, err := io.CopyN(cipherText, bufConn, 4)
	if err != nil {
		return nil, err
	}

	// Ensure we aren't asked to download too much. This is to guard against
	// an attack vector where a huge amount of state is sent
	moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5])
	if moreBytes > maxPushStateBytes {
		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes)
	}

	// Read in the rest of the payload
	_, err = io.CopyN(cipherText, bufConn, int64(moreBytes))
	if err != nil {
		return nil, err
	}

	// Split the buffer into the 5-byte header and the cipher text
	dataBytes := cipherText.Bytes()[:5]
	cipherBytes := cipherText.Bytes()[5:]

	// Decrypt the payload
	keys := m.config.Keyring.GetKeys()
	return decryptPayload(keys, cipherBytes, dataBytes)
}

// readStream is used to read from a stream connection, decrypting and
// decompressing the stream if necessary.
//
// It returns the type of the innermost message, a reader positioned just
// after that type byte, and a msgpack decoder bound to that reader. An error
// is returned when the encryption state of the message does not match the
// local configuration, or when a decrypted or decompressed payload is empty
// and therefore carries no message type.
func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
	// Created a buffered reader
	var bufConn io.Reader = bufio.NewReader(conn)

	// Read the message type. io.ReadFull guarantees the byte was actually
	// read whenever no error is reported.
	buf := [1]byte{0}
	if _, err := io.ReadFull(bufConn, buf[:]); err != nil {
		return 0, nil, nil, err
	}
	msgType := messageType(buf[0])

	// Check if the message is encrypted
	if msgType == encryptMsg {
		if !m.config.EncryptionEnabled() {
			return 0, nil, nil,
				fmt.Errorf("Remote state is encrypted and encryption is not configured")
		}

		plain, err := m.decryptRemoteState(bufConn)
		if err != nil {
			return 0, nil, nil, err
		}
		// The payload comes from the network; an empty one has no type
		// byte and must not be indexed.
		if len(plain) == 0 {
			return 0, nil, nil,
				fmt.Errorf("Decrypted remote state is empty")
		}

		// Reset message type and bufConn
		msgType = messageType(plain[0])
		bufConn = bytes.NewReader(plain[1:])
	} else if m.config.EncryptionEnabled() && m.config.GossipVerifyIncoming {
		return 0, nil, nil,
			fmt.Errorf("Encryption is configured but remote state is not encrypted")
	}

	// Get the msgPack decoders
	hd := codec.MsgpackHandle{}
	dec := codec.NewDecoder(bufConn, &hd)

	// Check if we have a compressed message
	if msgType == compressMsg {
		var c compress
		if err := dec.Decode(&c); err != nil {
			return 0, nil, nil, err
		}
		decomp, err := decompressBuffer(&c)
		if err != nil {
			return 0, nil, nil, err
		}
		// Same guard as for the decrypted payload above.
		if len(decomp) == 0 {
			return 0, nil, nil,
				fmt.Errorf("Decompressed remote state is empty")
		}

		// Reset the message type
		msgType = messageType(decomp[0])

		// Create a new bufConn
		bufConn = bytes.NewReader(decomp[1:])

		// Create a new decoder
		dec = codec.NewDecoder(bufConn, &hd)
	}

	return msgType, bufConn, dec, nil
}

// readRemoteState is used to read the remote state from a connection.
//
// It decodes the pushPullHeader, then header.Nodes node states, then reads
// header.UserStateLen raw bytes of user state from bufConn. It returns the
// header's Join flag, the decoded node states and the user state (nil when
// the header announces none).
func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) {
	// Read the push/pull header
	var header pushPullHeader
	if err := dec.Decode(&header); err != nil {
		return false, nil, nil, err
	}

	// The count is decoded from the wire; a negative value would make the
	// allocation below panic.
	if header.Nodes < 0 {
		return false, nil, nil,
			fmt.Errorf("Invalid node count in push/pull header (%d)", header.Nodes)
	}

	// Allocate space for the transfer
	remoteNodes := make([]pushNodeState, header.Nodes)

	// Try to decode all the states
	for i := range remoteNodes {
		if err := dec.Decode(&remoteNodes[i]); err != nil {
			return false, nil, nil, err
		}
	}

	// Read the remote user state into a buffer. io.ReadFull either fills the
	// whole buffer or reports an error, so no separate short-read check is
	// needed.
	var userBuf []byte
	if header.UserStateLen > 0 {
		userBuf = make([]byte, header.UserStateLen)
		if _, err := io.ReadFull(bufConn, userBuf); err != nil {
			return false, nil, nil, err
		}
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	for idx := range remoteNodes {
		if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 {
			remoteNodes[idx].Port = uint16(m.config.BindPort)
		}
	}

	return header.Join, remoteNodes, userBuf, nil
}

// mergeRemoteState is used to merge the remote state with our local state.
//
// The remote nodes are first checked with verifyProtocol. When join is true
// and a merge delegate is configured, the delegate is notified and may veto
// the merge by returning an error. The membership state is then merged and,
// if user state was received and a delegate is configured, passed to
// Delegate.MergeRemoteState.
func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error {
	if err := m.verifyProtocol(remoteNodes); err != nil {
		return err
	}

	// Invoke the merge delegate if any
	if join && m.config.Merge != nil {
		nodes := make([]*Node, len(remoteNodes))
		for idx, n := range remoteNodes {
			node := &Node{
				Name:  n.Name,
				Addr:  n.Addr,
				Port:  n.Port,
				Meta:  n.Meta,
				State: n.State,
			}
			// Vsn comes from the wire and may carry fewer than the six
			// version bytes sendLocalState emits; leave the version fields
			// at zero in that case instead of indexing out of range.
			if len(n.Vsn) >= 6 {
				node.PMin = n.Vsn[0]
				node.PMax = n.Vsn[1]
				node.PCur = n.Vsn[2]
				node.DMin = n.Vsn[3]
				node.DMax = n.Vsn[4]
				node.DCur = n.Vsn[5]
			}
			nodes[idx] = node
		}
		if err := m.config.Merge.NotifyMerge(nodes); err != nil {
			return err
		}
	}

	// Merge the membership state
	m.mergeState(remoteNodes)

	// Invoke the delegate for user state
	if userBuf != nil && m.config.Delegate != nil {
		m.config.Delegate.MergeRemoteState(userBuf, join)
	}
	return nil
}

// readUserMsg is used to decode a userMsg from a stream.
//
// It decodes the userMsgHeader, reads header.UserMsgLen raw bytes from
// bufConn and passes them to Delegate.NotifyMsg when a delegate is
// configured. A header announcing no payload is accepted and ignored.
func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
	// Read the user message header
	var header userMsgHeader
	if err := dec.Decode(&header); err != nil {
		return err
	}

	// Read the user message into a buffer. io.ReadFull either fills the
	// whole buffer or reports an error.
	if header.UserMsgLen > 0 {
		userBuf := make([]byte, header.UserMsgLen)
		if _, err := io.ReadFull(bufConn, userBuf); err != nil {
			return err
		}

		d := m.config.Delegate
		if d != nil {
			d.NotifyMsg(userBuf)
		}
	}

	return nil
}

// sendPingAndWaitForAck makes a stream connection to the given address, sends
// a ping, and waits for an ack. All of this is done as a series of blocking
// operations, given the deadline. The bool return parameter is true if we
// were able to round trip a ping to the other node.
//
// A failure to dial is reported as (false, nil); every later failure is
// returned as an error.
func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.Time) (bool, error) {
	if a.Name == "" && m.config.RequireNodeNames {
		return false, errNodeNamesAreRequired
	}

	conn, err := m.transport.DialAddressTimeout(a, time.Until(deadline))
	if err != nil {
		// If the node is actually dead we expect this to fail, so we
		// shouldn't spam the logs with it. After this point, errors
		// with the connection are real, unexpected errors and should
		// get propagated up.
		return false, nil
	}
	defer conn.Close()
	conn.SetDeadline(deadline)

	out, err := encode(pingMsg, &ping)
	if err != nil {
		return false, err
	}

	if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil {
		return false, err
	}

	msgType, _, dec, err := m.readStream(conn)
	if err != nil {
		return false, err
	}

	if msgType != ackRespMsg {
		return false, fmt.Errorf("Unexpected msgType (%d) from ping %s", msgType, LogConn(conn))
	}

	var ack ackResp
	if err = dec.Decode(&ack); err != nil {
		return false, err
	}

	// The ack must echo the sequence number of the ping we just sent.
	if ack.SeqNo != ping.SeqNo {
		return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo)
	}

	return true, nil
}