github.com/aergoio/aergo@v1.3.1/p2p/v030/v032handshake_test.go (about) 1 /* 2 * @file 3 * @copyright defined in aergo/LICENSE.txt 4 */ 5 6 package v030 7 8 import ( 9 "context" 10 "encoding/hex" 11 "fmt" 12 "reflect" 13 "testing" 14 15 "github.com/aergoio/aergo-lib/log" 16 "github.com/aergoio/aergo/p2p/p2pcommon" 17 "github.com/aergoio/aergo/p2p/p2pmock" 18 "github.com/aergoio/aergo/p2p/p2putil" 19 "github.com/aergoio/aergo/types" 20 "github.com/golang/mock/gomock" 21 ) 22 23 func TestV032VersionedHS_DoForOutbound(t *testing.T) { 24 ctrl := gomock.NewController(t) 25 defer ctrl.Finish() 26 27 logger := log.NewLogger("test") 28 mockActor := p2pmock.NewMockActorService(ctrl) 29 mockCA := p2pmock.NewMockChainAccessor(ctrl) 30 mockPM := p2pmock.NewMockPeerManager(ctrl) 31 32 dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"} 33 dummyAddr := dummyMeta.ToPeerAddress() 34 mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes() 35 dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}} 36 mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes() 37 mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes() 38 dummyGenHash := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 39 diffGenesis := []byte{0xff, 0xfe, 0xfd, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 40 dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash} 41 diffGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: diffGenesis} 42 nilGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: nil} 43 nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil, Genesis: dummyGenHash} 44 diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash} 45 tests := []struct { 46 name string 47 readReturn *types.Status 48 readError error 49 writeError error 50 want *types.Status 51 wantErr bool 52 wantGoAway bool 53 }{ 54 {"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false}, 55 {"TUnexpMsg", nil, nil, nil, nil, true, true}, 56 {"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true}, 57 {"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true}, 58 {"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false}, 59 {"TDiffChain", diffStatusMsg, nil, nil, nil, true, true}, 60 {"TNilGenesis", nilGenesisStatusMsg, nil, nil, nil, true, true}, 61 {"TDiffGenesis", diffGenesisStatusMsg, nil, nil, nil, true, true}, 62 } 63 for _, tt := range tests { 64 t.Run(tt.name, func(t *testing.T) { 65 dummyReader := p2pmock.NewMockReadWriteCloser(ctrl) 66 mockRW := p2pmock.NewMockMsgReadWriter(ctrl) 67 68 var containerMsg *p2pcommon.MessageValue 69 if tt.readReturn != nil { 70 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID()) 71 statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn) 72 containerMsg.SetPayload(statusBytes) 73 } else { 74 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID()) 75 } 76 mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes() 77 if tt.wantGoAway { 78 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError) 79 } 80 mockRW.EXPECT().WriteMsg(gomock.Any()).Return(tt.writeError).AnyTimes() 81 82 h := NewV032VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader, dummyGenHash) 83 h.msgRW = mockRW 84 got, err := h.DoForOutbound(context.Background()) 85 if (err != nil) != tt.wantErr { 86 t.Errorf("PeerHandshaker.DoForOutbound() error = %v, wantErr %v", err, tt.wantErr) 87 return 88 } 89 if got != nil && tt.want != nil { 90 if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) { 91 fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID)) 92 fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID)) 93 t.Errorf("PeerHandshaker.DoForOutbound() = %v, want %v", got.ChainID, tt.want.ChainID) 94 } 95 } else if !reflect.DeepEqual(got, tt.want) { 96 t.Errorf("PeerHandshaker.DoForOutbound() = %v, want %v", got, tt.want) 97 } 98 }) 99 } 100 } 101 102 func TestV032VersionedHS_DoForInbound(t *testing.T) { 103 ctrl := gomock.NewController(t) 104 defer ctrl.Finish() 105 106 // t.SkipNow() 107 logger := log.NewLogger("test") 108 mockActor := p2pmock.NewMockActorService(ctrl) 109 mockCA := p2pmock.NewMockChainAccessor(ctrl) 110 mockPM := p2pmock.NewMockPeerManager(ctrl) 111 112 dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"} 113 dummyAddr := dummyMeta.ToPeerAddress() 114 mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes() 115 dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}} 116 //dummyBlkRsp := message.GetBestBlockRsp{Block: dummyBlock} 117 mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes() 118 mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes() 119 120 dummyGenHash := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 121 diffGenHash := []byte{0xff, 0xfe, 0xfd, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 122 dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash} 123 diffGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: diffGenHash} 124 nilGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: nil} 125 nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil, Genesis: dummyGenHash} 126 diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash} 127 tests := []struct { 128 name string 129 readReturn *types.Status 130 readError error 131 writeError error 132 want *types.Status 133 wantErr bool 134 wantGoAway bool 135 }{ 136 {"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false}, 137 {"TUnexpMsg", nil, nil, nil, nil, true, true}, 138 {"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true}, 139 {"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true}, 140 {"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false}, 141 {"TDiffChain", diffStatusMsg, nil, nil, nil, true, true}, 142 {"TNilGenesis", nilGenesisStatusMsg, nil, nil, nil, true, true}, 143 {"TDiffGenesis", diffGenesisStatusMsg, nil, nil, nil, true, true}, 144 } 145 for _, tt := range tests { 146 t.Run(tt.name, func(t *testing.T) { 147 dummyReader := p2pmock.NewMockReadWriteCloser(ctrl) 148 mockRW := p2pmock.NewMockMsgReadWriter(ctrl) 149 150 containerMsg := &p2pcommon.MessageValue{} 151 if tt.readReturn != nil { 152 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID()) 153 statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn) 154 containerMsg.SetPayload(statusBytes) 155 } else { 156 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID()) 157 } 158 159 mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes() 160 if tt.wantGoAway { 161 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError) 162 } 163 mockRW.EXPECT().WriteMsg(gomock.Any()).Return(tt.writeError).AnyTimes() 164 165 h := NewV032VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader, dummyGenHash) 166 h.msgRW = mockRW 167 got, err := h.DoForInbound(context.Background()) 168 if (err != nil) != tt.wantErr { 169 t.Errorf("PeerHandshaker.DoForInbound() error = %v, wantErr %v", err, tt.wantErr) 170 return 171 } 172 if got != nil && tt.want != nil { 173 if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) { 174 fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID)) 175 fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID)) 176 t.Errorf("PeerHandshaker.DoForInbound() = %v, want %v", got.ChainID, tt.want.ChainID) 177 } 178 } else if !reflect.DeepEqual(got, tt.want) { 179 t.Errorf("PeerHandshaker.DoForInbound() = %v, want %v", got, tt.want) 180 } 181 }) 182 } 183 }