github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/memberlist/net_test.go (about)

     1  package memberlist
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"os"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/hashicorp/go-msgpack/codec"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  // As a regression we left this test very low-level and network-ey, even after
    22  // we abstracted the transport. We added some basic network-free transport tests
    23  // in transport_test.go to prove that we didn't hard code some network stuff
    24  // outside of NetTransport.
    25  
    26  func TestHandleCompoundPing(t *testing.T) {
    27  	//m := GetMemberlist(t, func(c *Config) {
    28  	//	c.EnableCompression = false
    29  	//})
    30  	//defer m.Shutdown()
    31  	//
    32  	//udp := listenUDP(t)
    33  	//defer udp.Close()
    34  	//
    35  	//udpAddr := udp.LocalAddr().(*net.UDPAddr)
    36  
    37  	// Encode a ping
    38  	p1 := ping{
    39  		SeqNo:      42,
    40  		SourceAddr: "127.0.0.1",
    41  		SourcePort: 56199,
    42  		SourceNode: "test",
    43  	}
    44  	buf, err := encode(pingMsg, p1)
    45  	fmt.Println(buf.Bytes())
    46  	if err != nil {
    47  		t.Fatalf("unexpected err %s", err)
    48  	}
    49  
    50  	msg := []byte{133, 165, 83, 101, 113, 78, 111, 42, 164, 78, 111, 100, 101, 160, 170, 83, 111, 117, 114, 99, 101, 65, 100, 100, 114, 169, 49, 50, 55, 46, 48, 46, 48, 46, 49, 170, 83, 111, 117, 114, 99, 101, 80, 111, 114, 116, 205, 219, 135, 170, 83, 111, 117, 114, 99, 101, 78, 111, 100, 101, 164, 116, 101, 115, 116}
    51  	var p ping
    52  	decode(msg, &p)
    53  	fmt.Printf("%+v\n", p)
    54  
    55  	//// Make a compound message
    56  	//compound := makeCompoundMessage([][]byte{buf.Bytes(), buf.Bytes(), buf.Bytes()})
    57  	//
    58  	//// Send compound version
    59  	//addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort}
    60  	//_, err = udp.WriteTo(compound.Bytes(), addr)
    61  	//if err != nil {
    62  	//	t.Fatalf("unexpected err %s", err)
    63  	//}
    64  	//
    65  	//// Wait for responses
    66  	//doneCh := make(chan struct{}, 1)
    67  	//go func() {
    68  	//	select {
    69  	//	case <-doneCh:
    70  	//	case <-time.After(2 * time.Second):
    71  	//		panic("timeout")
    72  	//	}
    73  	//}()
    74  	//
    75  	//for i := 0; i < 3; i++ {
    76  	//	in := make([]byte, 1500)
    77  	//	n, _, err := udp.ReadFrom(in)
    78  	//	if err != nil {
    79  	//		t.Fatalf("unexpected err %s", err)
    80  	//	}
    81  	//	in = in[0:n]
    82  	//
    83  	//	msgType := messageType(in[0])
    84  	//	if msgType != ackRespMsg {
    85  	//		t.Fatalf("bad response %v", in)
    86  	//	}
    87  	//
    88  	//	var ack ackResp
    89  	//	if err := decode(in[1:], &ack); err != nil {
    90  	//		t.Fatalf("unexpected err %s", err)
    91  	//	}
    92  	//
    93  	//	if ack.SeqNo != 42 {
    94  	//		t.Fatalf("bad sequence no")
    95  	//	}
    96  	//}
    97  	//
    98  	//doneCh <- struct{}{}
    99  }
   100  
   101  func TestHandlePing(t *testing.T) {
   102  	m := GetMemberlist(t, func(c *Config) {
   103  		c.EnableCompression = false
   104  	})
   105  	defer m.Shutdown()
   106  
   107  	udp := listenUDP(t)
   108  	defer udp.Close()
   109  
   110  	udpAddr := udp.LocalAddr().(*net.UDPAddr)
   111  
   112  	// Encode a ping
   113  	ping := ping{
   114  		SeqNo:      42,
   115  		SourceAddr: udpAddr.IP.String(),
   116  		SourcePort: uint16(udpAddr.Port),
   117  		SourceNode: "test",
   118  	}
   119  	buf, err := encode(pingMsg, ping)
   120  	if err != nil {
   121  		t.Fatalf("unexpected err %s", err)
   122  	}
   123  
   124  	// Send
   125  	addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort}
   126  	_, err = udp.WriteTo(buf.Bytes(), addr)
   127  	if err != nil {
   128  		t.Fatalf("unexpected err %s", err)
   129  	}
   130  
   131  	// Wait for response
   132  	doneCh := make(chan struct{}, 1)
   133  	go func() {
   134  		select {
   135  		case <-doneCh:
   136  		case <-time.After(120 * time.Second):
   137  			panic("timeout")
   138  		}
   139  	}()
   140  
   141  	in := make([]byte, 1500)
   142  	n, _, err := udp.ReadFrom(in)
   143  	if err != nil {
   144  		t.Fatalf("unexpected err %s", err)
   145  	}
   146  	in = in[0:n]
   147  
   148  	msgType := messageType(in[0])
   149  	if msgType != ackRespMsg {
   150  		t.Fatalf("bad response %v", in)
   151  	}
   152  
   153  	var ack ackResp
   154  	if err := decode(in[1:], &ack); err != nil {
   155  		t.Fatalf("unexpected err %s", err)
   156  	}
   157  
   158  	if ack.SeqNo != 42 {
   159  		t.Fatalf("bad sequence no")
   160  	}
   161  
   162  	doneCh <- struct{}{}
   163  }
   164  
   165  func TestHandlePing_WrongNode(t *testing.T) {
   166  	m := GetMemberlist(t, func(c *Config) {
   167  		c.EnableCompression = false
   168  	})
   169  	defer m.Shutdown()
   170  
   171  	udp := listenUDP(t)
   172  	defer udp.Close()
   173  
   174  	udpAddr := udp.LocalAddr().(*net.UDPAddr)
   175  
   176  	// Encode a ping, wrong node!
   177  	ping := ping{
   178  		SeqNo:      42,
   179  		Node:       m.config.Name + "-bad",
   180  		SourceAddr: udpAddr.IP.String(),
   181  		SourcePort: uint16(udpAddr.Port),
   182  		SourceNode: "test",
   183  	}
   184  	buf, err := encode(pingMsg, ping)
   185  	if err != nil {
   186  		t.Fatalf("unexpected err %s", err)
   187  	}
   188  
   189  	// Send
   190  	addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort}
   191  	_, err = udp.WriteTo(buf.Bytes(), addr)
   192  	if err != nil {
   193  		t.Fatalf("unexpected err %s", err)
   194  	}
   195  
   196  	// Wait for response
   197  	udp.SetDeadline(time.Now().Add(50 * time.Millisecond))
   198  	in := make([]byte, 1500)
   199  	_, _, err = udp.ReadFrom(in)
   200  
   201  	// Should get an i/o timeout
   202  	if err == nil {
   203  		t.Fatalf("expected err %s", err)
   204  	}
   205  }
   206  
   207  func TestHandleIndirectPing(t *testing.T) {
   208  	m := GetMemberlist(t, func(c *Config) {
   209  		c.EnableCompression = false
   210  	})
   211  	defer m.Shutdown()
   212  
   213  	udp := listenUDP(t)
   214  	defer udp.Close()
   215  
   216  	udpAddr := udp.LocalAddr().(*net.UDPAddr)
   217  
   218  	// Encode an indirect ping
   219  	ind := indirectPingReq{
   220  		SeqNo:      100,
   221  		Target:     m.config.BindAddr,
   222  		Port:       uint16(m.config.BindPort),
   223  		Node:       m.config.Name,
   224  		SourceAddr: udpAddr.IP.String(),
   225  		SourcePort: uint16(udpAddr.Port),
   226  		SourceNode: "test",
   227  	}
   228  	buf, err := encode(indirectPingMsg, &ind)
   229  	if err != nil {
   230  		t.Fatalf("unexpected err %s", err)
   231  	}
   232  
   233  	// Send
   234  	addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort}
   235  	_, err = udp.WriteTo(buf.Bytes(), addr)
   236  	if err != nil {
   237  		t.Fatalf("unexpected err %s", err)
   238  	}
   239  
   240  	// Wait for response
   241  	doneCh := make(chan struct{}, 1)
   242  	go func() {
   243  		select {
   244  		case <-doneCh:
   245  		case <-time.After(2 * time.Second):
   246  			panic("timeout")
   247  		}
   248  	}()
   249  
   250  	in := make([]byte, 1500)
   251  	n, _, err := udp.ReadFrom(in)
   252  	if err != nil {
   253  		t.Fatalf("unexpected err %s", err)
   254  	}
   255  	in = in[0:n]
   256  
   257  	msgType := messageType(in[0])
   258  	if msgType != ackRespMsg {
   259  		t.Fatalf("bad response %v", in)
   260  	}
   261  
   262  	var ack ackResp
   263  	if err := decode(in[1:], &ack); err != nil {
   264  		t.Fatalf("unexpected err %s", err)
   265  	}
   266  
   267  	if ack.SeqNo != 100 {
   268  		t.Fatalf("bad sequence no")
   269  	}
   270  
   271  	doneCh <- struct{}{}
   272  }
   273  
   274  func TestTCPPing(t *testing.T) {
   275  	var tcp *net.TCPListener
   276  	var tcpAddr *net.TCPAddr
   277  	for port := 60000; port < 61000; port++ {
   278  		tcpAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: port}
   279  		tcpLn, err := net.ListenTCP("tcp", tcpAddr)
   280  		if err == nil {
   281  			tcp = tcpLn
   282  			break
   283  		}
   284  	}
   285  	if tcp == nil {
   286  		t.Fatalf("no tcp listener")
   287  	}
   288  
   289  	tcpAddr2 := Address{Addr: tcpAddr.String(), Name: "test"}
   290  
   291  	// Note that tcp gets closed in the last test, so we avoid a deferred
   292  	// Close() call here.
   293  
   294  	m := GetMemberlist(t, nil)
   295  	defer m.Shutdown()
   296  
   297  	pingTimeout := m.config.ProbeInterval
   298  	pingTimeMax := m.config.ProbeInterval + 10*time.Millisecond
   299  
   300  	// Do a normal round trip.
   301  	pingOut := ping{SeqNo: 23, Node: "mongo"}
   302  	go func() {
   303  		tcp.SetDeadline(time.Now().Add(pingTimeMax))
   304  		conn, err := tcp.AcceptTCP()
   305  		if err != nil {
   306  			t.Fatalf("failed to connect: %s", err)
   307  		}
   308  		defer conn.Close()
   309  
   310  		msgType, _, dec, err := m.readStream(conn)
   311  		if err != nil {
   312  			t.Fatalf("failed to read ping: %s", err)
   313  		}
   314  
   315  		if msgType != pingMsg {
   316  			t.Fatalf("expecting ping, got message type (%d)", msgType)
   317  		}
   318  
   319  		var pingIn ping
   320  		if err := dec.Decode(&pingIn); err != nil {
   321  			t.Fatalf("failed to decode ping: %s", err)
   322  		}
   323  
   324  		if pingIn.SeqNo != pingOut.SeqNo {
   325  			t.Fatalf("sequence number isn't correct (%d) vs (%d)", pingIn.SeqNo, pingOut.SeqNo)
   326  		}
   327  
   328  		if pingIn.Node != pingOut.Node {
   329  			t.Fatalf("node name isn't correct (%s) vs (%s)", pingIn.Node, pingOut.Node)
   330  		}
   331  
   332  		ack := ackResp{pingIn.SeqNo, nil}
   333  		out, err := encode(ackRespMsg, &ack)
   334  		if err != nil {
   335  			t.Fatalf("failed to encode ack: %s", err)
   336  		}
   337  
   338  		err = m.rawSendMsgStream(conn, out.Bytes())
   339  		if err != nil {
   340  			t.Fatalf("failed to send ack: %s", err)
   341  		}
   342  	}()
   343  	deadline := time.Now().Add(pingTimeout)
   344  	didContact, err := m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline)
   345  	if err != nil {
   346  		t.Fatalf("error trying to ping: %s", err)
   347  	}
   348  	if !didContact {
   349  		t.Fatalf("expected successful ping")
   350  	}
   351  
   352  	// Make sure a mis-matched sequence number is caught.
   353  	go func() {
   354  		tcp.SetDeadline(time.Now().Add(pingTimeMax))
   355  		conn, err := tcp.AcceptTCP()
   356  		if err != nil {
   357  			t.Fatalf("failed to connect: %s", err)
   358  		}
   359  		defer conn.Close()
   360  
   361  		_, _, dec, err := m.readStream(conn)
   362  		if err != nil {
   363  			t.Fatalf("failed to read ping: %s", err)
   364  		}
   365  
   366  		var pingIn ping
   367  		if err := dec.Decode(&pingIn); err != nil {
   368  			t.Fatalf("failed to decode ping: %s", err)
   369  		}
   370  
   371  		ack := ackResp{pingIn.SeqNo + 1, nil}
   372  		out, err := encode(ackRespMsg, &ack)
   373  		if err != nil {
   374  			t.Fatalf("failed to encode ack: %s", err)
   375  		}
   376  
   377  		err = m.rawSendMsgStream(conn, out.Bytes())
   378  		if err != nil {
   379  			t.Fatalf("failed to send ack: %s", err)
   380  		}
   381  	}()
   382  	deadline = time.Now().Add(pingTimeout)
   383  	didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline)
   384  	if err == nil || !strings.Contains(err.Error(), "Sequence number") {
   385  		t.Fatalf("expected an error from mis-matched sequence number")
   386  	}
   387  	if didContact {
   388  		t.Fatalf("expected failed ping")
   389  	}
   390  
   391  	// Make sure an unexpected message type is handled gracefully.
   392  	go func() {
   393  		tcp.SetDeadline(time.Now().Add(pingTimeMax))
   394  		conn, err := tcp.AcceptTCP()
   395  		if err != nil {
   396  			t.Fatalf("failed to connect: %s", err)
   397  		}
   398  		defer conn.Close()
   399  
   400  		_, _, _, err = m.readStream(conn)
   401  		if err != nil {
   402  			t.Fatalf("failed to read ping: %s", err)
   403  		}
   404  
   405  		bogus := indirectPingReq{}
   406  		out, err := encode(indirectPingMsg, &bogus)
   407  		if err != nil {
   408  			t.Fatalf("failed to encode bogus msg: %s", err)
   409  		}
   410  
   411  		err = m.rawSendMsgStream(conn, out.Bytes())
   412  		if err != nil {
   413  			t.Fatalf("failed to send bogus msg: %s", err)
   414  		}
   415  	}()
   416  	deadline = time.Now().Add(pingTimeout)
   417  	didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline)
   418  	if err == nil || !strings.Contains(err.Error(), "Unexpected msgType") {
   419  		t.Fatalf("expected an error from bogus message")
   420  	}
   421  	if didContact {
   422  		t.Fatalf("expected failed ping")
   423  	}
   424  
   425  	// Make sure failed I/O respects the deadline. In this case we try the
   426  	// common case of the receiving node being totally down.
   427  	tcp.Close()
   428  	deadline = time.Now().Add(pingTimeout)
   429  	startPing := time.Now()
   430  	didContact, err = m.sendPingAndWaitForAck(tcpAddr2, pingOut, deadline)
   431  	pingTime := time.Now().Sub(startPing)
   432  	if err != nil {
   433  		t.Fatalf("expected no error during ping on closed socket, got: %s", err)
   434  	}
   435  	if didContact {
   436  		t.Fatalf("expected failed ping")
   437  	}
   438  	if pingTime > pingTimeMax {
   439  		t.Fatalf("took too long to fail ping, %9.6f", pingTime.Seconds())
   440  	}
   441  }
   442  
   443  func TestTCPPushPull(t *testing.T) {
   444  	m := GetMemberlist(t, nil)
   445  	defer m.Shutdown()
   446  
   447  	m.nodes = append(m.nodes, &nodeState{
   448  		Node: Node{
   449  			Name: "Test 0",
   450  			Addr: m.config.BindAddr,
   451  			Port: uint16(m.config.BindPort),
   452  		},
   453  		Incarnation: 0,
   454  		State:       StateSuspect,
   455  		StateChange: time.Now().Add(-1 * time.Second),
   456  	})
   457  
   458  	addr := net.JoinHostPort(m.config.BindAddr, strconv.Itoa(m.config.BindPort))
   459  	conn, err := net.Dial("tcp", addr)
   460  	if err != nil {
   461  		t.Fatalf("unexpected err %s", err)
   462  	}
   463  	defer conn.Close()
   464  
   465  	localNodes := make([]pushNodeState, 3)
   466  	localNodes[0].Name = "Test 0"
   467  	localNodes[0].Addr = m.config.BindAddr
   468  	localNodes[0].Port = uint16(m.config.BindPort)
   469  	localNodes[0].Incarnation = 1
   470  	localNodes[0].State = StateAlive
   471  	localNodes[1].Name = "Test 1"
   472  	localNodes[1].Addr = m.config.BindAddr
   473  	localNodes[1].Port = uint16(m.config.BindPort)
   474  	localNodes[1].Incarnation = 1
   475  	localNodes[1].State = StateAlive
   476  	localNodes[2].Name = "Test 2"
   477  	localNodes[2].Addr = m.config.BindAddr
   478  	localNodes[2].Port = uint16(m.config.BindPort)
   479  	localNodes[2].Incarnation = 1
   480  	localNodes[2].State = StateAlive
   481  
   482  	// Send our node state
   483  	header := pushPullHeader{Nodes: 3}
   484  	hd := codec.MsgpackHandle{}
   485  	enc := codec.NewEncoder(conn, &hd)
   486  
   487  	// Send the push/pull indicator
   488  	conn.Write([]byte{byte(pushPullMsg)})
   489  
   490  	if err := enc.Encode(&header); err != nil {
   491  		t.Fatalf("unexpected err %s", err)
   492  	}
   493  	for i := 0; i < header.Nodes; i++ {
   494  		if err := enc.Encode(&localNodes[i]); err != nil {
   495  			t.Fatalf("unexpected err %s", err)
   496  		}
   497  	}
   498  
   499  	// Read the message type
   500  	var msgType messageType
   501  	if err := binary.Read(conn, binary.BigEndian, &msgType); err != nil {
   502  		t.Fatalf("unexpected err %s", err)
   503  	}
   504  
   505  	var bufConn io.Reader = conn
   506  	msghd := codec.MsgpackHandle{}
   507  	dec := codec.NewDecoder(bufConn, &msghd)
   508  
   509  	// Check if we have a compressed message
   510  	if msgType == compressMsg {
   511  		var c compress
   512  		if err := dec.Decode(&c); err != nil {
   513  			t.Fatalf("unexpected err %s", err)
   514  		}
   515  		decomp, err := decompressBuffer(&c)
   516  		if err != nil {
   517  			t.Fatalf("unexpected err %s", err)
   518  		}
   519  
   520  		// Reset the message type
   521  		msgType = messageType(decomp[0])
   522  
   523  		// Create a new bufConn
   524  		bufConn = bytes.NewReader(decomp[1:])
   525  
   526  		// Create a new decoder
   527  		dec = codec.NewDecoder(bufConn, &hd)
   528  	}
   529  
   530  	// Quit if not push/pull
   531  	if msgType != pushPullMsg {
   532  		t.Fatalf("bad message type")
   533  	}
   534  
   535  	if err := dec.Decode(&header); err != nil {
   536  		t.Fatalf("unexpected err %s", err)
   537  	}
   538  
   539  	// Allocate space for the transfer
   540  	remoteNodes := make([]pushNodeState, header.Nodes)
   541  
   542  	// Try to decode all the states
   543  	for i := 0; i < header.Nodes; i++ {
   544  		if err := dec.Decode(&remoteNodes[i]); err != nil {
   545  			t.Fatalf("unexpected err %s", err)
   546  		}
   547  	}
   548  
   549  	if len(remoteNodes) != 1 {
   550  		t.Fatalf("bad response")
   551  	}
   552  
   553  	n := &remoteNodes[0]
   554  	if n.Name != "Test 0" {
   555  		t.Fatalf("bad name")
   556  	}
   557  	if n.Incarnation != 0 {
   558  		t.Fatal("bad incarnation")
   559  	}
   560  	if n.State != StateSuspect {
   561  		t.Fatal("bad state")
   562  	}
   563  }
   564  
   565  func TestSendMsg_Piggyback(t *testing.T) {
   566  	m := GetMemberlist(t, nil)
   567  	defer m.Shutdown()
   568  
   569  	// Add a message to be broadcast
   570  	a := alive{
   571  		Incarnation: 10,
   572  		Node:        "rand",
   573  		Addr:        "127.0.0.255",
   574  		Meta:        nil,
   575  		Vsn: []uint8{
   576  			ProtocolVersionMin, ProtocolVersionMax, ProtocolVersionMin,
   577  			1, 1, 1,
   578  		},
   579  	}
   580  	m.encodeAndBroadcast("rand", aliveMsg, &a)
   581  
   582  	udp := listenUDP(t)
   583  	defer udp.Close()
   584  
   585  	udpAddr := udp.LocalAddr().(*net.UDPAddr)
   586  
   587  	// Encode a ping
   588  	ping := ping{
   589  		SeqNo:      42,
   590  		SourceAddr: udpAddr.IP.String(),
   591  		SourcePort: uint16(udpAddr.Port),
   592  		SourceNode: "test",
   593  	}
   594  	buf, err := encode(pingMsg, ping)
   595  	if err != nil {
   596  		t.Fatalf("unexpected err %s", err)
   597  	}
   598  
   599  	// Send
   600  	addr := &net.UDPAddr{IP: net.ParseIP(m.config.BindAddr), Port: m.config.BindPort}
   601  	_, err = udp.WriteTo(buf.Bytes(), addr)
   602  	if err != nil {
   603  		t.Fatalf("unexpected err %s", err)
   604  	}
   605  
   606  	// Wait for response
   607  	doneCh := make(chan struct{}, 1)
   608  	go func() {
   609  		select {
   610  		case <-doneCh:
   611  		case <-time.After(2 * time.Second):
   612  			panic("timeout")
   613  		}
   614  	}()
   615  
   616  	in := make([]byte, 1500)
   617  	n, _, err := udp.ReadFrom(in)
   618  	if err != nil {
   619  		t.Fatalf("unexpected err %s", err)
   620  	}
   621  	in = in[0:n]
   622  
   623  	msgType := messageType(in[0])
   624  	if msgType != compoundMsg {
   625  		t.Fatalf("bad response %v", in)
   626  	}
   627  
   628  	// get the parts
   629  	trunc, parts, err := decodeCompoundMessage(in[1:])
   630  	if trunc != 0 {
   631  		t.Fatalf("unexpected truncation")
   632  	}
   633  	if len(parts) != 2 {
   634  		t.Fatalf("unexpected parts %v", parts)
   635  	}
   636  	if err != nil {
   637  		t.Fatalf("unexpected err %s", err)
   638  	}
   639  
   640  	var ack ackResp
   641  	if err := decode(parts[0][1:], &ack); err != nil {
   642  		t.Fatalf("unexpected err %s", err)
   643  	}
   644  
   645  	if ack.SeqNo != 42 {
   646  		t.Fatalf("bad sequence no")
   647  	}
   648  
   649  	var aliveout alive
   650  	if err := decode(parts[1][1:], &aliveout); err != nil {
   651  		t.Fatalf("unexpected err %s", err)
   652  	}
   653  
   654  	if aliveout.Node != "rand" || aliveout.Incarnation != 10 {
   655  		t.Fatalf("bad mesg")
   656  	}
   657  
   658  	doneCh <- struct{}{}
   659  }
   660  
   661  func TestEncryptDecryptState(t *testing.T) {
   662  	state := []byte("this is our internal state...")
   663  	config := &Config{
   664  		SecretKey:       []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   665  		ProtocolVersion: ProtocolVersionMax,
   666  	}
   667  
   668  	m, err := Create(config)
   669  	if err != nil {
   670  		t.Fatalf("err: %s", err)
   671  	}
   672  	defer m.Shutdown()
   673  
   674  	crypt, err := m.encryptLocalState(state)
   675  	if err != nil {
   676  		t.Fatalf("err: %v", err)
   677  	}
   678  
   679  	// Create reader, seek past the type byte
   680  	buf := bytes.NewReader(crypt)
   681  	buf.Seek(1, 0)
   682  
   683  	plain, err := m.decryptRemoteState(buf)
   684  	if err != nil {
   685  		t.Fatalf("err: %v", err)
   686  	}
   687  
   688  	if !reflect.DeepEqual(state, plain) {
   689  		t.Fatalf("Decrypt failed: %v", plain)
   690  	}
   691  }
   692  
   693  func testConfigNet(tb testing.TB, network byte) *Config {
   694  	tb.Helper()
   695  
   696  	config := DefaultLANConfig()
   697  	config.BindAddr = getBindAddrNet(network).String()
   698  	config.Name = config.BindAddr
   699  	config.BindPort = 0 // choose free port
   700  	config.RequireNodeNames = true
   701  	config.Logger = log.New(os.Stderr, config.Name, log.LstdFlags)
   702  	return config
   703  }
   704  
// testConfig returns the test Config produced by testConfigNet for network 0.
// NOTE(review): what network 0 selects is decided by getBindAddrNet, which is
// not visible here — confirm before relying on a specific address range.
func testConfig(tb testing.TB) *Config {
	return testConfigNet(tb, 0)
}
   708  
   709  func GetMemberlist(tb testing.TB, f func(c *Config)) *Memberlist {
   710  	c := testConfig(tb)
   711  	c.BindPort = 0
   712  	if f != nil {
   713  		f(c)
   714  	}
   715  
   716  	m, err := NewMemberlist(c)
   717  	require.NoError(tb, err)
   718  	return m
   719  }
   720  
   721  func TestRawSendUdp_CRC(t *testing.T) {
   722  	m := GetMemberlist(t, func(c *Config) {
   723  		c.EnableCompression = false
   724  	})
   725  	defer m.Shutdown()
   726  
   727  	udp := listenUDP(t)
   728  	defer udp.Close()
   729  
   730  	a := Address{
   731  		Addr: udp.LocalAddr().String(),
   732  		Name: "test",
   733  	}
   734  
   735  	// Pass a nil node with no nodes registered, should result in no checksum
   736  	payload := []byte{3, 3, 3, 3}
   737  	m.rawSendMsgPacket(a, nil, payload)
   738  
   739  	in := make([]byte, 1500)
   740  	n, _, err := udp.ReadFrom(in)
   741  	if err != nil {
   742  		t.Fatalf("unexpected err %s", err)
   743  	}
   744  	in = in[0:n]
   745  
   746  	if len(in) != 4 {
   747  		t.Fatalf("bad: %v", in)
   748  	}
   749  
   750  	// Pass a non-nil node with PMax >= 5, should result in a checksum
   751  	m.rawSendMsgPacket(a, &Node{PMax: 5}, payload)
   752  
   753  	in = make([]byte, 1500)
   754  	n, _, err = udp.ReadFrom(in)
   755  	if err != nil {
   756  		t.Fatalf("unexpected err %s", err)
   757  	}
   758  	in = in[0:n]
   759  
   760  	if len(in) != 9 {
   761  		t.Fatalf("bad: %v", in)
   762  	}
   763  
   764  	// Register a node with PMax >= 5 to be looked up, should result in a checksum
   765  	m.nodeMap["127.0.0.1"] = &nodeState{
   766  		Node: Node{PMax: 5},
   767  	}
   768  	m.rawSendMsgPacket(a, nil, payload)
   769  
   770  	in = make([]byte, 1500)
   771  	n, _, err = udp.ReadFrom(in)
   772  	if err != nil {
   773  		t.Fatalf("unexpected err %s", err)
   774  	}
   775  	in = in[0:n]
   776  
   777  	if len(in) != 9 {
   778  		t.Fatalf("bad: %v", in)
   779  	}
   780  }
   781  
   782  func TestIngestPacket_CRC(t *testing.T) {
   783  	m := GetMemberlist(t, func(c *Config) {
   784  		c.EnableCompression = false
   785  	})
   786  	defer m.Shutdown()
   787  
   788  	udp := listenUDP(t)
   789  	defer udp.Close()
   790  
   791  	a := Address{
   792  		Addr: udp.LocalAddr().String(),
   793  		Name: "test",
   794  	}
   795  
   796  	// Get a message with a checksum
   797  	payload := []byte{3, 3, 3, 3}
   798  	m.rawSendMsgPacket(a, &Node{PMax: 5}, payload)
   799  
   800  	in := make([]byte, 1500)
   801  	n, _, err := udp.ReadFrom(in)
   802  	if err != nil {
   803  		t.Fatalf("unexpected err %s", err)
   804  	}
   805  	in = in[0:n]
   806  
   807  	if len(in) != 9 {
   808  		t.Fatalf("bad: %v", in)
   809  	}
   810  
   811  	// Corrupt the checksum
   812  	in[1] <<= 1
   813  
   814  	logs := &bytes.Buffer{}
   815  	logger := log.New(logs, "", 0)
   816  	m.logger = logger
   817  	m.ingestPacket(in, udp.LocalAddr(), time.Now())
   818  
   819  	if !strings.Contains(logs.String(), "invalid checksum") {
   820  		t.Fatalf("bad: %s", logs.String())
   821  	}
   822  }
   823  
   824  func TestIngestPacket_ExportedFunc_EmptyMessage(t *testing.T) {
   825  	m := GetMemberlist(t, func(c *Config) {
   826  		c.EnableCompression = false
   827  	})
   828  	defer m.Shutdown()
   829  
   830  	udp := listenUDP(t)
   831  	defer udp.Close()
   832  
   833  	emptyConn := &emptyReadNetConn{}
   834  
   835  	logs := &bytes.Buffer{}
   836  	logger := log.New(logs, "", 0)
   837  	m.logger = logger
   838  
   839  	type ingestionAwareTransport interface {
   840  		IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error
   841  	}
   842  
   843  	err := m.transport.(ingestionAwareTransport).IngestPacket(emptyConn, udp.LocalAddr(), time.Now(), true)
   844  	require.Error(t, err)
   845  	require.Contains(t, err.Error(), "packet too short")
   846  }
   847  
   848  type emptyReadNetConn struct {
   849  	net.Conn
   850  }
   851  
   852  func (c *emptyReadNetConn) Read(b []byte) (n int, err error) {
   853  	return 0, io.EOF
   854  }
   855  
   856  func (c *emptyReadNetConn) Close() error {
   857  	return nil
   858  }
   859  
   860  func TestGossip_MismatchedKeys(t *testing.T) {
   861  	// Create two agents with different gossip keys
   862  	c1 := testConfig(t)
   863  	c1.SecretKey = []byte("4W6DGn2VQVqDEceOdmuRTQ==")
   864  	c1.BindPort = 56188
   865  	c1.AdvertisePort = 56188
   866  	m1, err := Create(c1)
   867  	require.NoError(t, err)
   868  	defer m1.Shutdown()
   869  
   870  	c2 := testConfig(t)
   871  	c2.BindPort = 56189
   872  	c2.AdvertisePort = 56189
   873  	c2.SecretKey = []byte("XhX/w702/JKKK7/7OtM9Ww==")
   874  
   875  	m2, err := Create(c2)
   876  	require.NoError(t, err)
   877  	defer m2.Shutdown()
   878  
   879  	// Make sure we get this error on the joining side
   880  	m1JoinUrl := fmt.Sprintf("%s/%s:%d", m1.config.Name, m1.advertiseAddr, m1.advertisePort)
   881  	_, err = m2.Join([]string{m1JoinUrl})
   882  	if err == nil || !strings.Contains(err.Error(), "No installed keys could decrypt the message") {
   883  		t.Fatalf("bad: %s", err)
   884  	}
   885  }
   886  
   887  func listenUDP(t *testing.T) *net.UDPConn {
   888  	var udp *net.UDPConn
   889  	for port := 56199; port < 60000; port++ {
   890  		udpAddr := fmt.Sprintf("127.0.0.1:%d", port)
   891  		udpLn, err := net.ListenPacket("udp", udpAddr)
   892  		if err == nil {
   893  			udp = udpLn.(*net.UDPConn)
   894  			break
   895  		}
   896  	}
   897  	if udp == nil {
   898  		t.Fatalf("no udp listener")
   899  	}
   900  	return udp
   901  }
   902  
   903  func TestHandleCommand(t *testing.T) {
   904  	var buf bytes.Buffer
   905  	m := Memberlist{
   906  		logger: log.New(&buf, "", 0),
   907  	}
   908  	m.handleCommand(nil, &net.TCPAddr{Port: 12345}, time.Now())
   909  	require.Contains(t, buf.String(), "missing message type byte")
   910  }