github.com/aergoio/aergo@v1.3.1/p2p/v030/v032handshake_test.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package v030
     7  
     8  import (
     9  	"context"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"github.com/aergoio/aergo-lib/log"
    16  	"github.com/aergoio/aergo/p2p/p2pcommon"
    17  	"github.com/aergoio/aergo/p2p/p2pmock"
    18  	"github.com/aergoio/aergo/p2p/p2putil"
    19  	"github.com/aergoio/aergo/types"
    20  	"github.com/golang/mock/gomock"
    21  )
    22  
    23  func TestV032VersionedHS_DoForOutbound(t *testing.T) {
    24  	ctrl := gomock.NewController(t)
    25  	defer ctrl.Finish()
    26  
    27  	logger := log.NewLogger("test")
    28  	mockActor := p2pmock.NewMockActorService(ctrl)
    29  	mockCA := p2pmock.NewMockChainAccessor(ctrl)
    30  	mockPM := p2pmock.NewMockPeerManager(ctrl)
    31  
    32  	dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"}
    33  	dummyAddr := dummyMeta.ToPeerAddress()
    34  	mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes()
    35  	dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}}
    36  	mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes()
    37  	mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes()
    38  	dummyGenHash := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    39  	diffGenesis := []byte{0xff, 0xfe, 0xfd, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    40  	dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash}
    41  	diffGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: diffGenesis}
    42  	nilGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: nil}
    43  	nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil, Genesis: dummyGenHash}
    44  	diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash}
    45  	tests := []struct {
    46  		name       string
    47  		readReturn *types.Status
    48  		readError  error
    49  		writeError error
    50  		want       *types.Status
    51  		wantErr    bool
    52  		wantGoAway bool
    53  	}{
    54  		{"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false},
    55  		{"TUnexpMsg", nil, nil, nil, nil, true, true},
    56  		{"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true},
    57  		{"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true},
    58  		{"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false},
    59  		{"TDiffChain", diffStatusMsg, nil, nil, nil, true, true},
    60  		{"TNilGenesis", nilGenesisStatusMsg, nil, nil, nil, true, true},
    61  		{"TDiffGenesis", diffGenesisStatusMsg, nil, nil, nil, true, true},
    62  	}
    63  	for _, tt := range tests {
    64  		t.Run(tt.name, func(t *testing.T) {
    65  			dummyReader := p2pmock.NewMockReadWriteCloser(ctrl)
    66  			mockRW := p2pmock.NewMockMsgReadWriter(ctrl)
    67  
    68  			var containerMsg *p2pcommon.MessageValue
    69  			if tt.readReturn != nil {
    70  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID())
    71  				statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn)
    72  				containerMsg.SetPayload(statusBytes)
    73  			} else {
    74  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID())
    75  			}
    76  			mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes()
    77  			if tt.wantGoAway {
    78  				mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError)
    79  			}
    80  			mockRW.EXPECT().WriteMsg(gomock.Any()).Return(tt.writeError).AnyTimes()
    81  
    82  			h := NewV032VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader, dummyGenHash)
    83  			h.msgRW = mockRW
    84  			got, err := h.DoForOutbound(context.Background())
    85  			if (err != nil) != tt.wantErr {
    86  				t.Errorf("PeerHandshaker.DoForOutbound() error = %v, wantErr %v", err, tt.wantErr)
    87  				return
    88  			}
    89  			if got != nil && tt.want != nil {
    90  				if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) {
    91  					fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID))
    92  					fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID))
    93  					t.Errorf("PeerHandshaker.DoForOutbound() = %v, want %v", got.ChainID, tt.want.ChainID)
    94  				}
    95  			} else if !reflect.DeepEqual(got, tt.want) {
    96  				t.Errorf("PeerHandshaker.DoForOutbound() = %v, want %v", got, tt.want)
    97  			}
    98  		})
    99  	}
   100  }
   101  
   102  func TestV032VersionedHS_DoForInbound(t *testing.T) {
   103  	ctrl := gomock.NewController(t)
   104  	defer ctrl.Finish()
   105  
   106  	// t.SkipNow()
   107  	logger := log.NewLogger("test")
   108  	mockActor := p2pmock.NewMockActorService(ctrl)
   109  	mockCA := p2pmock.NewMockChainAccessor(ctrl)
   110  	mockPM := p2pmock.NewMockPeerManager(ctrl)
   111  
   112  	dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"}
   113  	dummyAddr := dummyMeta.ToPeerAddress()
   114  	mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes()
   115  	dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}}
   116  	//dummyBlkRsp := message.GetBestBlockRsp{Block: dummyBlock}
   117  	mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes()
   118  	mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes()
   119  
   120  	dummyGenHash := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
   121  	diffGenHash := []byte{0xff, 0xfe, 0xfd, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
   122  	dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash}
   123  	diffGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: diffGenHash}
   124  	nilGenesisStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr, Genesis: nil}
   125  	nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil, Genesis: dummyGenHash}
   126  	diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr, Genesis: dummyGenHash}
   127  	tests := []struct {
   128  		name       string
   129  		readReturn *types.Status
   130  		readError  error
   131  		writeError error
   132  		want       *types.Status
   133  		wantErr    bool
   134  		wantGoAway bool
   135  	}{
   136  		{"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false},
   137  		{"TUnexpMsg", nil, nil, nil, nil, true, true},
   138  		{"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true},
   139  		{"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true},
   140  		{"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false},
   141  		{"TDiffChain", diffStatusMsg, nil, nil, nil, true, true},
   142  		{"TNilGenesis", nilGenesisStatusMsg, nil, nil, nil, true, true},
   143  		{"TDiffGenesis", diffGenesisStatusMsg, nil, nil, nil, true, true},
   144  	}
   145  	for _, tt := range tests {
   146  		t.Run(tt.name, func(t *testing.T) {
   147  			dummyReader := p2pmock.NewMockReadWriteCloser(ctrl)
   148  			mockRW := p2pmock.NewMockMsgReadWriter(ctrl)
   149  
   150  			containerMsg := &p2pcommon.MessageValue{}
   151  			if tt.readReturn != nil {
   152  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID())
   153  				statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn)
   154  				containerMsg.SetPayload(statusBytes)
   155  			} else {
   156  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID())
   157  			}
   158  
   159  			mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes()
   160  			if tt.wantGoAway {
   161  				mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError)
   162  			}
   163  			mockRW.EXPECT().WriteMsg(gomock.Any()).Return(tt.writeError).AnyTimes()
   164  
   165  			h := NewV032VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader, dummyGenHash)
   166  			h.msgRW = mockRW
   167  			got, err := h.DoForInbound(context.Background())
   168  			if (err != nil) != tt.wantErr {
   169  				t.Errorf("PeerHandshaker.DoForInbound() error = %v, wantErr %v", err, tt.wantErr)
   170  				return
   171  			}
   172  			if got != nil && tt.want != nil {
   173  				if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) {
   174  					fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID))
   175  					fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID))
   176  					t.Errorf("PeerHandshaker.DoForInbound() = %v, want %v", got.ChainID, tt.want.ChainID)
   177  				}
   178  			} else if !reflect.DeepEqual(got, tt.want) {
   179  				t.Errorf("PeerHandshaker.DoForInbound() = %v, want %v", got, tt.want)
   180  			}
   181  		})
   182  	}
   183  }