github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/state/raft/transport/transport_test.go (about)

     1  package transport
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/coreos/etcd/raft"
     9  	"github.com/coreos/etcd/raft/raftpb"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  // Build a snapshot message where each byte in the data is of the value (index % sizeof(byte))
    15  func newSnapshotMessage(from uint64, to uint64) raftpb.Message {
    16  	data := make([]byte, GRPCMaxMsgSize)
    17  	for i := 0; i < GRPCMaxMsgSize; i++ {
    18  		data[i] = byte(i % (1 << 8))
    19  	}
    20  
    21  	return raftpb.Message{
    22  		Type: raftpb.MsgSnap,
    23  		From: from,
    24  		To:   to,
    25  		Snapshot: raftpb.Snapshot{
    26  			Data: data,
    27  			// Include the snapshot size in the Index field for testing.
    28  			Metadata: raftpb.SnapshotMetadata{
    29  				Index: uint64(len(data)),
    30  			},
    31  		},
    32  	}
    33  }
    34  
    35  // Verify that the snapshot data where each byte is of the value (index % sizeof(byte)).
    36  func verifySnapshot(raftMsg *raftpb.Message) bool {
    37  	for i, b := range raftMsg.Snapshot.Data {
    38  		if int(b) != i%(1<<8) {
    39  			return false
    40  		}
    41  	}
    42  
    43  	return len(raftMsg.Snapshot.Data) == int(raftMsg.Snapshot.Metadata.Index)
    44  }
    45  
    46  func sendMessages(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) error {
    47  	var firstErr error
    48  	for _, id := range to {
    49  		var err error
    50  		if msgType == raftpb.MsgSnap {
    51  			err = c.Get(from).tr.Send(newSnapshotMessage(from, id))
    52  		} else {
    53  			err = c.Get(from).tr.Send(raftpb.Message{
    54  				Type: msgType,
    55  				From: from,
    56  				To:   id,
    57  			})
    58  		}
    59  		if firstErr == nil {
    60  			firstErr = err
    61  		}
    62  	}
    63  	return firstErr
    64  }
    65  
    66  func testSend(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) func(*testing.T) {
    67  	return func(t *testing.T) {
    68  		ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
    69  		defer cancel()
    70  		require.NoError(t, sendMessages(ctx, c, from, to, msgType))
    71  
    72  		for _, id := range to {
    73  			select {
    74  			case msg := <-c.Get(id).processedMessages:
    75  				assert.Equal(t, msg.To, id)
    76  				assert.Equal(t, msg.From, from)
    77  			case <-ctx.Done():
    78  				t.Fatal(ctx.Err())
    79  			}
    80  		}
    81  
    82  		if msgType == raftpb.MsgSnap {
    83  			var snaps []snapshotReport
    84  			for i := 0; i < len(to); i++ {
    85  				select {
    86  				case snap := <-c.Get(from).processedSnapshots:
    87  					snaps = append(snaps, snap)
    88  				case <-ctx.Done():
    89  					t.Fatal(ctx.Err())
    90  				}
    91  			}
    92  		loop:
    93  			for _, id := range to {
    94  				for _, s := range snaps {
    95  					if s.id == id {
    96  						assert.Equal(t, s.status, raft.SnapshotFinish)
    97  						continue loop
    98  					}
    99  				}
   100  				t.Fatalf("snapshot id %d is not reported", id)
   101  			}
   102  		}
   103  	}
   104  }
   105  
   106  func TestSend(t *testing.T) {
   107  	ctx, cancel := context.WithCancel(context.Background())
   108  	c := newCluster()
   109  	defer func() {
   110  		cancel()
   111  		c.Stop()
   112  	}()
   113  	require.NoError(t, c.Add(1))
   114  	require.NoError(t, c.Add(2))
   115  	require.NoError(t, c.Add(3))
   116  
   117  	t.Run("Send Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   118  	t.Run("Send_Snapshot_Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgSnap))
   119  
   120  	// Return error on streaming.
   121  	for _, raft := range c.rafts {
   122  		raft.forceErrorStream = true
   123  	}
   124  
   125  	// Messages should still be delivered.
   126  	t.Run("Send Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   127  }
   128  
   129  func TestSendRemoved(t *testing.T) {
   130  	ctx, cancel := context.WithCancel(context.Background())
   131  	c := newCluster()
   132  	defer func() {
   133  		cancel()
   134  		c.Stop()
   135  	}()
   136  	require.NoError(t, c.Add(1))
   137  	require.NoError(t, c.Add(2))
   138  	require.NoError(t, c.Add(3))
   139  	require.NoError(t, c.Get(1).RemovePeer(2))
   140  
   141  	err := sendMessages(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)
   142  	require.Error(t, err)
   143  	require.Contains(t, err.Error(), "to removed member")
   144  }
   145  
   146  func TestSendSnapshotFailure(t *testing.T) {
   147  	ctx, cancel := context.WithCancel(context.Background())
   148  	c := newCluster()
   149  	defer func() {
   150  		cancel()
   151  		c.Stop()
   152  	}()
   153  	require.NoError(t, c.Add(1))
   154  	require.NoError(t, c.Add(2))
   155  
   156  	// stop peer server to emulate error
   157  	c.Get(2).s.Stop()
   158  
   159  	msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second)
   160  	defer msgCancel()
   161  
   162  	require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap))
   163  
   164  	select {
   165  	case snap := <-c.Get(1).processedSnapshots:
   166  		assert.Equal(t, snap.id, uint64(2))
   167  		assert.Equal(t, snap.status, raft.SnapshotFailure)
   168  	case <-msgCtx.Done():
   169  		t.Fatal(ctx.Err())
   170  	}
   171  
   172  	select {
   173  	case id := <-c.Get(1).reportedUnreachables:
   174  		assert.Equal(t, id, uint64(2))
   175  	case <-msgCtx.Done():
   176  		t.Fatal(ctx.Err())
   177  	}
   178  }
   179  
   180  func TestSendUnknown(t *testing.T) {
   181  	ctx, cancel := context.WithCancel(context.Background())
   182  	c := newCluster()
   183  	defer func() {
   184  		cancel()
   185  		c.Stop()
   186  	}()
   187  	require.NoError(t, c.Add(1))
   188  	require.NoError(t, c.Add(2))
   189  	require.NoError(t, c.Add(3))
   190  
   191  	// remove peer from 1 transport to make it "unknown" to it
   192  	oldPeer := c.Get(1).tr.peers[2]
   193  	delete(c.Get(1).tr.peers, 2)
   194  	oldPeer.cancel()
   195  	<-oldPeer.done
   196  
   197  	// give peers time to mark each other as active
   198  	time.Sleep(1 * time.Second)
   199  
   200  	msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second)
   201  	defer msgCancel()
   202  
   203  	require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgHup))
   204  
   205  	select {
   206  	case msg := <-c.Get(2).processedMessages:
   207  		assert.Equal(t, msg.To, uint64(2))
   208  		assert.Equal(t, msg.From, uint64(1))
   209  	case <-msgCtx.Done():
   210  		t.Fatal(msgCtx.Err())
   211  	}
   212  }
   213  
   214  func TestUpdatePeerAddr(t *testing.T) {
   215  	ctx, cancel := context.WithCancel(context.Background())
   216  	c := newCluster()
   217  	defer func() {
   218  		cancel()
   219  		c.Stop()
   220  	}()
   221  	require.NoError(t, c.Add(1))
   222  	require.NoError(t, c.Add(2))
   223  	require.NoError(t, c.Add(3))
   224  
   225  	t.Run("Send Message Before Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   226  
   227  	nr, err := newMockRaft()
   228  	require.NoError(t, err)
   229  
   230  	c.Get(3).Stop()
   231  	c.rafts[3] = nr
   232  
   233  	require.NoError(t, c.Get(1).tr.UpdatePeer(3, nr.Addr()))
   234  	require.NoError(t, c.Get(1).tr.UpdatePeer(3, nr.Addr()))
   235  
   236  	t.Run("Send Message After Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   237  }
   238  
   239  func TestUpdatePeerAddrDelayed(t *testing.T) {
   240  	ctx, cancel := context.WithCancel(context.Background())
   241  	c := newCluster()
   242  	defer func() {
   243  		cancel()
   244  		c.Stop()
   245  	}()
   246  	require.NoError(t, c.Add(1))
   247  	require.NoError(t, c.Add(2))
   248  	require.NoError(t, c.Add(3))
   249  
   250  	t.Run("Send Message Before Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   251  
   252  	nr, err := newMockRaft()
   253  	require.NoError(t, err)
   254  
   255  	c.Get(3).Stop()
   256  	c.rafts[3] = nr
   257  
   258  	require.NoError(t, c.Get(1).tr.UpdatePeerAddr(3, nr.Addr()))
   259  
   260  	// initiate failure to replace connection, and wait for it
   261  	sendMessages(ctx, c, 1, []uint64{3}, raftpb.MsgHup)
   262  	updateCtx, updateCancel := context.WithTimeout(ctx, 4*time.Second)
   263  	defer updateCancel()
   264  	select {
   265  	case update := <-c.Get(1).updatedNodes:
   266  		require.Equal(t, update.id, uint64(3))
   267  		require.Equal(t, update.addr, nr.Addr())
   268  	case <-updateCtx.Done():
   269  		t.Fatal(updateCtx.Err())
   270  	}
   271  
   272  	t.Run("Send Message After Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup))
   273  }
   274  
   275  func TestSendUnreachable(t *testing.T) {
   276  	ctx, cancel := context.WithCancel(context.Background())
   277  	c := newCluster()
   278  	defer func() {
   279  		cancel()
   280  		c.Stop()
   281  	}()
   282  	require.NoError(t, c.Add(1))
   283  	require.NoError(t, c.Add(2))
   284  
   285  	// set channel to nil to emulate full queue
   286  	// we need to reset some fields after cancel
   287  	p2 := c.Get(1).tr.peers[2]
   288  	p2.cancel()
   289  	<-p2.done
   290  	p2.msgc = nil
   291  	p2.done = make(chan struct{})
   292  	p2.ctx = ctx
   293  	go p2.run(ctx)
   294  
   295  	msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second)
   296  	defer msgCancel()
   297  
   298  	err := sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap)
   299  	require.Error(t, err)
   300  	require.Contains(t, err.Error(), "peer is unreachable")
   301  	select {
   302  	case id := <-c.Get(1).reportedUnreachables:
   303  		assert.Equal(t, id, uint64(2))
   304  	case <-msgCtx.Done():
   305  		t.Fatal(ctx.Err())
   306  	}
   307  }
   308  
   309  func TestSendNodeRemoved(t *testing.T) {
   310  	ctx, cancel := context.WithCancel(context.Background())
   311  	c := newCluster()
   312  	defer func() {
   313  		cancel()
   314  		c.Stop()
   315  	}()
   316  	require.NoError(t, c.Add(1))
   317  	require.NoError(t, c.Add(2))
   318  
   319  	require.NoError(t, c.Get(1).RemovePeer(2))
   320  
   321  	msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second)
   322  	defer msgCancel()
   323  
   324  	require.NoError(t, sendMessages(msgCtx, c, 2, []uint64{1}, raftpb.MsgSnap))
   325  	select {
   326  	case <-c.Get(2).nodeRemovedSignal:
   327  	case <-msgCtx.Done():
   328  		t.Fatal(msgCtx.Err())
   329  	}
   330  }