github.com/aergoio/aergo@v1.3.1/p2p/v030/v030handshake_test.go (about) 1 /* 2 * @file 3 * @copyright defined in aergo/LICENSE.txt 4 */ 5 6 package v030 7 8 import ( 9 "context" 10 "encoding/hex" 11 "fmt" 12 "reflect" 13 "testing" 14 15 "github.com/aergoio/aergo-lib/log" 16 "github.com/aergoio/aergo/config" 17 "github.com/aergoio/aergo/p2p/p2pcommon" 18 "github.com/aergoio/aergo/p2p/p2pkey" 19 "github.com/aergoio/aergo/p2p/p2pmock" 20 "github.com/aergoio/aergo/p2p/p2putil" 21 "github.com/aergoio/aergo/types" 22 "github.com/golang/mock/gomock" 23 ) 24 25 var ( 26 myChainID, theirChainID *types.ChainID 27 myChainBytes, theirChainBytes []byte 28 samplePeerID, _ = types.IDB58Decode("16Uiu2HAmFqptXPfcdaCdwipB2fhHATgKGVFVPehDAPZsDKSU7jRm") 29 dummyBlockHash, _ = hex.DecodeString("4f461d85e869ade8a0544f8313987c33a9c06534e50c4ad941498299579bd7ac") 30 dummyBlockHeight uint64 = 100215 31 ) 32 33 func init() { 34 myChainID = types.NewChainID() 35 myChainID.Magic = "itSmain1" 36 myChainBytes, _ = myChainID.Bytes() 37 38 theirChainID = types.NewChainID() 39 theirChainID.Read(myChainBytes) 40 theirChainID.Magic = "itsdiff2" 41 theirChainBytes, _ = theirChainID.Bytes() 42 43 sampleKeyFile := "../../test/sample.key" 44 baseCfg := &config.BaseConfig{AuthDir: "test"} 45 p2pCfg := &config.P2PConfig{NPKey: sampleKeyFile} 46 p2pkey.InitNodeInfo(baseCfg, p2pCfg, "0.0.1-test", log.NewLogger("v030.test")) 47 } 48 49 func TestDeepEqual(t *testing.T) { 50 b1, _ := myChainID.Bytes() 51 b2 := make([]byte, len(b1), len(b1)<<1) 52 copy(b2, b1) 53 54 s1 := &types.Status{ChainID: b1} 55 s2 := &types.Status{ChainID: b2} 56 57 if !reflect.DeepEqual(s1, s2) { 58 t.Errorf("byte slice cant do DeepEqual! %v, %v", b1, b2) 59 } 60 61 } 62 63 func TestV030StatusHS_doForOutbound(t *testing.T) { 64 ctrl := gomock.NewController(t) 65 defer ctrl.Finish() 66 67 logger := log.NewLogger("test") 68 mockActor := p2pmock.NewMockActorService(ctrl) 69 mockCA := p2pmock.NewMockChainAccessor(ctrl) 70 mockPM := p2pmock.NewMockPeerManager(ctrl) 71 72 dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"} 73 dummyAddr := dummyMeta.ToPeerAddress() 74 mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes() 75 dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}} 76 mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes() 77 mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes() 78 79 dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr} 80 nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil} 81 diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr} 82 tests := []struct { 83 name string 84 readReturn *types.Status 85 readError error 86 writeError error 87 want *types.Status 88 wantErr bool 89 wantGoAway bool 90 }{ 91 {"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false}, 92 {"TUnexpMsg", nil, nil, nil, nil, true, true}, 93 {"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true}, 94 {"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true}, 95 {"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false}, 96 {"TDiffChain", diffStatusMsg, nil, nil, nil, true, true}, 97 } 98 for _, tt := range tests { 99 t.Run(tt.name, func(t *testing.T) { 100 dummyReader := p2pmock.NewMockReadWriteCloser(ctrl) 101 mockRW := p2pmock.NewMockMsgReadWriter(ctrl) 102 103 var containerMsg *p2pcommon.MessageValue 104 if tt.readReturn != nil { 105 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID()) 106 statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn) 107 containerMsg.SetPayload(statusBytes) 108 } else { 109 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID()) 110 } 111 mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes() 112 if tt.wantGoAway { 113 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError) 114 } 115 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.StatusRequest}).Return(tt.writeError).MaxTimes(1) 116 117 h := NewV030VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader) 118 h.msgRW = mockRW 119 got, err := h.DoForOutbound(context.Background()) 120 if (err != nil) != tt.wantErr { 121 t.Errorf("PeerHandshaker.handshakeOutboundPeer() error = %v, wantErr %v", err, tt.wantErr) 122 return 123 } 124 if got != nil && tt.want != nil { 125 if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) { 126 fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID)) 127 fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID)) 128 t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got.ChainID, tt.want.ChainID) 129 } 130 } else if !reflect.DeepEqual(got, tt.want) { 131 t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got, tt.want) 132 } 133 }) 134 } 135 } 136 137 func TestV030StatusHS_handshakeInboundPeer(t *testing.T) { 138 ctrl := gomock.NewController(t) 139 defer ctrl.Finish() 140 141 // t.SkipNow() 142 logger := log.NewLogger("test") 143 mockActor := p2pmock.NewMockActorService(ctrl) 144 mockCA := p2pmock.NewMockChainAccessor(ctrl) 145 mockPM := p2pmock.NewMockPeerManager(ctrl) 146 147 dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"} 148 dummyAddr := dummyMeta.ToPeerAddress() 149 mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes() 150 dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}} 151 //dummyBlkRsp := message.GetBestBlockRsp{Block: dummyBlock} 152 mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes() 153 mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes() 154 155 dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr} 156 nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil} 157 diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr} 158 tests := []struct { 159 name string 160 readReturn *types.Status 161 readError error 162 writeError error 163 want *types.Status 164 wantErr bool 165 wantGoAway bool 166 }{ 167 {"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false}, 168 {"TUnexpMsg", nil, nil, nil, nil, true, true}, 169 {"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true}, 170 {"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true}, 171 {"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false}, 172 {"TDiffChain", diffStatusMsg, nil, nil, nil, true, true}, 173 } 174 for _, tt := range tests { 175 t.Run(tt.name, func(t *testing.T) { 176 dummyReader := p2pmock.NewMockReadWriteCloser(ctrl) 177 mockRW := p2pmock.NewMockMsgReadWriter(ctrl) 178 179 containerMsg := &p2pcommon.MessageValue{} 180 if tt.readReturn != nil { 181 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID()) 182 statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn) 183 containerMsg.SetPayload(statusBytes) 184 } else { 185 containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID()) 186 } 187 188 mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes() 189 if tt.wantGoAway { 190 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError) 191 } 192 mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.StatusRequest}).Return(tt.writeError).MaxTimes(1) 193 194 h := NewV030VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader) 195 h.msgRW = mockRW 196 got, err := h.DoForInbound(context.Background()) 197 if (err != nil) != tt.wantErr { 198 t.Errorf("PeerHandshaker.handshakeInboundPeer() error = %v, wantErr %v", err, tt.wantErr) 199 return 200 } 201 if got != nil && tt.want != nil { 202 if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) { 203 fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID)) 204 fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID)) 205 t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got.ChainID, tt.want.ChainID) 206 } 207 } else if !reflect.DeepEqual(got, tt.want) { 208 t.Errorf("PeerHandshaker.handshakeInboundPeer() = %v, want %v", got, tt.want) 209 } 210 }) 211 } 212 } 213 214 type MsgMatcher struct { 215 sub p2pcommon.SubProtocol 216 } 217 218 func (m MsgMatcher) Matches(x interface{}) bool { 219 return x.(p2pcommon.Message).Subprotocol() == m.sub 220 } 221 222 func (m MsgMatcher) String() string { 223 return "matcher "+m.sub.String() 224 } 225 226 func Test_createMessage(t *testing.T) { 227 type args struct { 228 protocolID p2pcommon.SubProtocol 229 msgBody p2pcommon.MessageBody 230 } 231 tests := []struct { 232 name string 233 args args 234 wantNil bool 235 }{ 236 {"TStatus", args{protocolID:p2pcommon.StatusRequest,msgBody:&types.Status{Version:"11"}}, false}, 237 {"TGOAway", args{protocolID:p2pcommon.GoAway,msgBody:&types.GoAwayNotice{Message:"test"}}, false}, 238 {"TNil", args{protocolID:p2pcommon.StatusRequest,msgBody:nil}, true}, 239 // TODO: Add test cases. 240 } 241 for _, tt := range tests { 242 t.Run(tt.name, func(t *testing.T) { 243 got := createMessage(tt.args.protocolID, p2pcommon.NewMsgID(), tt.args.msgBody) 244 if (got == nil) != tt.wantNil { 245 t.Errorf("createMessage() = %v, want nil %v", got, tt.wantNil) 246 } 247 if got != nil && got.Subprotocol() != tt.args.protocolID { 248 t.Errorf("status.ProtocolID = %v, want %v", got.Subprotocol() , tt.args.protocolID) 249 } 250 }) 251 } 252 }