github.com/aergoio/aergo@v1.3.1/p2p/raftsupport/snapshotsender_test.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package raftsupport
     7  
     8  import (
     9  	"bytes"
    10  	"github.com/aergoio/aergo/p2p/p2pcommon"
    11  	"github.com/aergoio/aergo/p2p/p2pmock"
    12  	"github.com/aergoio/etcd/raft/raftpb"
    13  	"github.com/golang/mock/gomock"
    14  	"github.com/libp2p/go-libp2p-core/network"
    15  	"github.com/libp2p/go-libp2p-core/protocol"
    16  	"github.com/pkg/errors"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/aergoio/aergo-lib/log"
    21  	"github.com/aergoio/aergo/types"
    22  	"github.com/aergoio/etcd/snap"
    23  )
    24  
    25  func Test_snapshotSender_Send(t *testing.T) {
    26  	ctrl := gomock.NewController(t)
    27  	defer ctrl.Finish()
    28  
    29  	sampleSnaps := make([]byte, 10000)
    30  	logger := log.NewLogger("raft.support.test")
    31  	pid, _ := types.IDB58Decode("16Uiu2HAmFqptXPfcdaCdwipB2fhHATgKGVFVPehDAPZsDKSU7jRm")
    32  	sampleMeta := p2pcommon.PeerMeta{ID: pid}
    33  	tests := []struct {
    34  		name string
    35  
    36  		ntErr    error
    37  		wantSucc bool
    38  	}{
    39  		{"TRemoteDown", errors.New("conn fail"), false},
    40  		{"TLaterFail", nil, false},
    41  		// TODO : add success cases
    42  	}
    43  	for _, tt := range tests {
    44  		t.Run(tt.name, func(t *testing.T) {
    45  			mockNT := p2pmock.NewMockNetworkTransport(ctrl)
    46  			mockRaft := p2pmock.NewMockAergoRaftAccessor(ctrl)
    47  			mockPeer := p2pmock.NewMockRemotePeer(ctrl)
    48  
    49  			rc := &testStream{in: sampleSnaps, out: nil}
    50  			dummyStream := &testStream{out: bytes.NewBuffer(nil)}
    51  			mockPeer.EXPECT().ID().Return(pid).AnyTimes()
    52  			mockPeer.EXPECT().Meta().Return(sampleMeta).AnyTimes()
    53  			mockPeer.EXPECT().Name().Return("tester").AnyTimes()
    54  			mockNT.EXPECT().GetOrCreateStream(sampleMeta, p2pcommon.RaftSnapSubAddr).Return(dummyStream, tt.ntErr)
    55  			if !tt.wantSucc {
    56  				mockRaft.EXPECT().ReportUnreachable(gomock.Any())
    57  			}
    58  			mockRaft.EXPECT().ReportSnapshot(gomock.Any(), gomock.Any())
    59  
    60  			rs := raftpb.Message{}
    61  			msg := snap.NewMessage(rs, rc, 1000)
    62  
    63  			s := snapshotSender{nt: mockNT, logger: logger, rAcc: mockRaft, stopChan: make(chan interface{}), peer:mockPeer}
    64  
    65  			s.Send(msg)
    66  
    67  			if tt.ntErr != nil {
    68  				return
    69  			}
    70  
    71  			// Wait for send function finished
    72  			tick := time.NewTicker(time.Millisecond * 100)
    73  			select {
    74  			case r := <-msg.CloseNotify():
    75  				if r != tt.wantSucc {
    76  					t.Errorf("send result %v , want %v", r, tt.wantSucc)
    77  				}
    78  			case <-tick.C:
    79  				t.Errorf("unexpected timeout in send")
    80  			}
    81  
    82  		})
    83  	}
    84  }
    85  
    86  func Test_readWireHSResp(t *testing.T) {
    87  	sampleBuf := bytes.NewBuffer(nil)
    88  	sampleResp := types.SnapshotResponse{Status: types.ResultStatus_INVALID_ARGUMENT, Message: "wrong type"}
    89  	(&snapshotReceiver{}).sendResp(sampleBuf, &sampleResp)
    90  	sample := sampleBuf.Bytes()
    91  	currupted := CopyOf(sample)
    92  	lastidx := len(currupted) - 1
    93  	currupted[lastidx] = currupted[lastidx] ^ 0xff
    94  
    95  	tests := []struct {
    96  		name string
    97  		in   []byte
    98  
    99  		wantErr bool
   100  	}{
   101  		{"TNormal", CopyOf(sample), false},
   102  		{"TLongBody", append(CopyOf(sample), []byte("dummies")...), false},
   103  		{"TShortBody", CopyOf(sample)[:len(sample)-1], true},
   104  		{"TWrongHead", CopyOf(sample)[:3], true},
   105  		{"TInvalidByte", CopyOf(currupted), true},
   106  	}
   107  	for _, tt := range tests {
   108  		t.Run(tt.name, func(t *testing.T) {
   109  			buf := bytes.NewBuffer(tt.in)
   110  
   111  			gotResp, err := readWireHSResp(buf)
   112  			if (err != nil) != tt.wantErr {
   113  				t.Errorf("readWireHSResp() error = %v, wantErr %v", err, tt.wantErr)
   114  				return
   115  			}
   116  			if !tt.wantErr {
   117  				if gotResp.Status != sampleResp.Status || gotResp.Message != sampleResp.Message {
   118  					t.Errorf("readWireHSResp() = %v, want %v", gotResp, sampleResp)
   119  				}
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func CopyOf(org []byte) []byte {
   126  	dst := make([]byte, len(org))
   127  	copy(dst, org)
   128  	return dst
   129  }
   130  
   131  type testStream struct {
   132  	in     []byte
   133  	out    *bytes.Buffer
   134  	closed bool
   135  }
   136  
   137  func (s *testStream) Read(p []byte) (n int, err error) {
   138  	size := copy(p, s.in)
   139  	return size, nil
   140  }
   141  
   142  func (s *testStream) Write(p []byte) (n int, err error) {
   143  	return s.out.Write(p)
   144  }
   145  
   146  func (s *testStream) Close() error {
   147  	s.closed = true
   148  	return nil
   149  }
   150  
   151  func (*testStream) Reset() error {
   152  	panic("implement me")
   153  }
   154  
   155  func (*testStream) SetDeadline(time.Time) error {
   156  	panic("implement me")
   157  }
   158  
   159  func (*testStream) SetReadDeadline(time.Time) error {
   160  	panic("implement me")
   161  }
   162  
   163  func (*testStream) SetWriteDeadline(time.Time) error {
   164  	panic("implement me")
   165  }
   166  
   167  func (*testStream) Protocol() protocol.ID {
   168  	panic("implement me")
   169  }
   170  
   171  func (*testStream) SetProtocol(id protocol.ID) {
   172  	panic("implement me")
   173  }
   174  
   175  func (*testStream) Stat() network.Stat {
   176  	panic("implement me")
   177  }
   178  
   179  func (*testStream) Conn() network.Conn {
   180  	panic("implement me")
   181  }