github.com/aergoio/aergo@v1.3.1/p2p/v030/v030handshake_test.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package v030
     7  
     8  import (
     9  	"context"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"github.com/aergoio/aergo-lib/log"
    16  	"github.com/aergoio/aergo/config"
    17  	"github.com/aergoio/aergo/p2p/p2pcommon"
    18  	"github.com/aergoio/aergo/p2p/p2pkey"
    19  	"github.com/aergoio/aergo/p2p/p2pmock"
    20  	"github.com/aergoio/aergo/p2p/p2putil"
    21  	"github.com/aergoio/aergo/types"
    22  	"github.com/golang/mock/gomock"
    23  )
    24  
    25  var (
    26  	myChainID, theirChainID       *types.ChainID
    27  	myChainBytes, theirChainBytes []byte
    28  	samplePeerID, _                      = types.IDB58Decode("16Uiu2HAmFqptXPfcdaCdwipB2fhHATgKGVFVPehDAPZsDKSU7jRm")
    29  	dummyBlockHash, _                    = hex.DecodeString("4f461d85e869ade8a0544f8313987c33a9c06534e50c4ad941498299579bd7ac")
    30  	dummyBlockHeight              uint64 = 100215
    31  )
    32  
    33  func init() {
    34  	myChainID = types.NewChainID()
    35  	myChainID.Magic = "itSmain1"
    36  	myChainBytes, _ = myChainID.Bytes()
    37  
    38  	theirChainID = types.NewChainID()
    39  	theirChainID.Read(myChainBytes)
    40  	theirChainID.Magic = "itsdiff2"
    41  	theirChainBytes, _ = theirChainID.Bytes()
    42  
    43  	sampleKeyFile := "../../test/sample.key"
    44  	baseCfg := &config.BaseConfig{AuthDir: "test"}
    45  	p2pCfg := &config.P2PConfig{NPKey: sampleKeyFile}
    46  	p2pkey.InitNodeInfo(baseCfg, p2pCfg, "0.0.1-test", log.NewLogger("v030.test"))
    47  }
    48  
    49  func TestDeepEqual(t *testing.T) {
    50  	b1, _ := myChainID.Bytes()
    51  	b2 := make([]byte, len(b1), len(b1)<<1)
    52  	copy(b2, b1)
    53  
    54  	s1 := &types.Status{ChainID: b1}
    55  	s2 := &types.Status{ChainID: b2}
    56  
    57  	if !reflect.DeepEqual(s1, s2) {
    58  		t.Errorf("byte slice cant do DeepEqual! %v, %v", b1, b2)
    59  	}
    60  
    61  }
    62  
    63  func TestV030StatusHS_doForOutbound(t *testing.T) {
    64  	ctrl := gomock.NewController(t)
    65  	defer ctrl.Finish()
    66  
    67  	logger := log.NewLogger("test")
    68  	mockActor := p2pmock.NewMockActorService(ctrl)
    69  	mockCA := p2pmock.NewMockChainAccessor(ctrl)
    70  	mockPM := p2pmock.NewMockPeerManager(ctrl)
    71  
    72  	dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"}
    73  	dummyAddr := dummyMeta.ToPeerAddress()
    74  	mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes()
    75  	dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}}
    76  	mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes()
    77  	mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes()
    78  
    79  	dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr}
    80  	nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil}
    81  	diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr}
    82  	tests := []struct {
    83  		name       string
    84  		readReturn *types.Status
    85  		readError  error
    86  		writeError error
    87  		want       *types.Status
    88  		wantErr    bool
    89  		wantGoAway bool
    90  	}{
    91  		{"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false},
    92  		{"TUnexpMsg", nil, nil, nil, nil, true, true},
    93  		{"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true},
    94  		{"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true},
    95  		{"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false},
    96  		{"TDiffChain", diffStatusMsg, nil, nil, nil, true, true},
    97  	}
    98  	for _, tt := range tests {
    99  		t.Run(tt.name, func(t *testing.T) {
   100  			dummyReader := p2pmock.NewMockReadWriteCloser(ctrl)
   101  			mockRW := p2pmock.NewMockMsgReadWriter(ctrl)
   102  
   103  			var containerMsg *p2pcommon.MessageValue
   104  			if tt.readReturn != nil {
   105  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID())
   106  				statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn)
   107  				containerMsg.SetPayload(statusBytes)
   108  			} else {
   109  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID())
   110  			}
   111  			mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes()
   112  			if tt.wantGoAway {
   113  				mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError)
   114  			}
   115  			mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.StatusRequest}).Return(tt.writeError).MaxTimes(1)
   116  
   117  			h := NewV030VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader)
   118  			h.msgRW = mockRW
   119  			got, err := h.DoForOutbound(context.Background())
   120  			if (err != nil) != tt.wantErr {
   121  				t.Errorf("PeerHandshaker.handshakeOutboundPeer() error = %v, wantErr %v", err, tt.wantErr)
   122  				return
   123  			}
   124  			if got != nil && tt.want != nil {
   125  				if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) {
   126  					fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID))
   127  					fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID))
   128  					t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got.ChainID, tt.want.ChainID)
   129  				}
   130  			} else if !reflect.DeepEqual(got, tt.want) {
   131  				t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got, tt.want)
   132  			}
   133  		})
   134  	}
   135  }
   136  
   137  func TestV030StatusHS_handshakeInboundPeer(t *testing.T) {
   138  	ctrl := gomock.NewController(t)
   139  	defer ctrl.Finish()
   140  
   141  	// t.SkipNow()
   142  	logger := log.NewLogger("test")
   143  	mockActor := p2pmock.NewMockActorService(ctrl)
   144  	mockCA := p2pmock.NewMockChainAccessor(ctrl)
   145  	mockPM := p2pmock.NewMockPeerManager(ctrl)
   146  
   147  	dummyMeta := p2pcommon.PeerMeta{ID: samplePeerID, IPAddress: "dummy.aergo.io"}
   148  	dummyAddr := dummyMeta.ToPeerAddress()
   149  	mockPM.EXPECT().SelfMeta().Return(dummyMeta).AnyTimes()
   150  	dummyBlock := &types.Block{Hash: dummyBlockHash, Header: &types.BlockHeader{BlockNo: dummyBlockHeight}}
   151  	//dummyBlkRsp := message.GetBestBlockRsp{Block: dummyBlock}
   152  	mockActor.EXPECT().GetChainAccessor().Return(mockCA).AnyTimes()
   153  	mockCA.EXPECT().GetBestBlock().Return(dummyBlock, nil).AnyTimes()
   154  
   155  	dummyStatusMsg := &types.Status{ChainID: myChainBytes, Sender: &dummyAddr}
   156  	nilSenderStatusMsg := &types.Status{ChainID: myChainBytes, Sender: nil}
   157  	diffStatusMsg := &types.Status{ChainID: theirChainBytes, Sender: &dummyAddr}
   158  	tests := []struct {
   159  		name       string
   160  		readReturn *types.Status
   161  		readError  error
   162  		writeError error
   163  		want       *types.Status
   164  		wantErr    bool
   165  		wantGoAway bool
   166  	}{
   167  		{"TSuccess", dummyStatusMsg, nil, nil, dummyStatusMsg, false, false},
   168  		{"TUnexpMsg", nil, nil, nil, nil, true, true},
   169  		{"TRFail", dummyStatusMsg, fmt.Errorf("failed"), nil, nil, true, true},
   170  		{"TRNoSender", nilSenderStatusMsg, nil, nil, nil, true, true},
   171  		{"TWFail", dummyStatusMsg, nil, fmt.Errorf("failed"), nil, true, false},
   172  		{"TDiffChain", diffStatusMsg, nil, nil, nil, true, true},
   173  	}
   174  	for _, tt := range tests {
   175  		t.Run(tt.name, func(t *testing.T) {
   176  			dummyReader := p2pmock.NewMockReadWriteCloser(ctrl)
   177  			mockRW := p2pmock.NewMockMsgReadWriter(ctrl)
   178  
   179  			containerMsg := &p2pcommon.MessageValue{}
   180  			if tt.readReturn != nil {
   181  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.StatusRequest, p2pcommon.NewMsgID())
   182  				statusBytes, _ := p2putil.MarshalMessageBody(tt.readReturn)
   183  				containerMsg.SetPayload(statusBytes)
   184  			} else {
   185  				containerMsg = p2pcommon.NewSimpleMsgVal(p2pcommon.AddressesRequest, p2pcommon.NewMsgID())
   186  			}
   187  
   188  			mockRW.EXPECT().ReadMsg().Return(containerMsg, tt.readError).AnyTimes()
   189  			if tt.wantGoAway {
   190  				mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.GoAway}).Return(tt.writeError)
   191  			}
   192  			mockRW.EXPECT().WriteMsg(&MsgMatcher{p2pcommon.StatusRequest}).Return(tt.writeError).MaxTimes(1)
   193  
   194  			h := NewV030VersionedHS(mockPM, mockActor, logger, myChainID, samplePeerID, dummyReader)
   195  			h.msgRW = mockRW
   196  			got, err := h.DoForInbound(context.Background())
   197  			if (err != nil) != tt.wantErr {
   198  				t.Errorf("PeerHandshaker.handshakeInboundPeer() error = %v, wantErr %v", err, tt.wantErr)
   199  				return
   200  			}
   201  			if got != nil && tt.want != nil {
   202  				if !reflect.DeepEqual(got.ChainID, tt.want.ChainID) {
   203  					fmt.Printf("got:(%d) %s \n", len(got.ChainID), hex.EncodeToString(got.ChainID))
   204  					fmt.Printf("got:(%d) %s \n", len(tt.want.ChainID), hex.EncodeToString(tt.want.ChainID))
   205  					t.Errorf("PeerHandshaker.handshakeOutboundPeer() = %v, want %v", got.ChainID, tt.want.ChainID)
   206  				}
   207  			} else if !reflect.DeepEqual(got, tt.want) {
   208  				t.Errorf("PeerHandshaker.handshakeInboundPeer() = %v, want %v", got, tt.want)
   209  			}
   210  		})
   211  	}
   212  }
   213  
   214  type MsgMatcher struct {
   215  	sub p2pcommon.SubProtocol
   216  }
   217  
   218  func (m MsgMatcher) Matches(x interface{}) bool {
   219  	return x.(p2pcommon.Message).Subprotocol() == m.sub
   220  }
   221  
   222  func (m MsgMatcher) String() string {
   223  	return "matcher "+m.sub.String()
   224  }
   225  
   226  func Test_createMessage(t *testing.T) {
   227  	type args struct {
   228  		protocolID p2pcommon.SubProtocol
   229  		msgBody    p2pcommon.MessageBody
   230  	}
   231  	tests := []struct {
   232  		name string
   233  		args args
   234  		wantNil bool
   235  	}{
   236  		{"TStatus", args{protocolID:p2pcommon.StatusRequest,msgBody:&types.Status{Version:"11"}}, false},
   237  		{"TGOAway", args{protocolID:p2pcommon.GoAway,msgBody:&types.GoAwayNotice{Message:"test"}}, false},
   238  		{"TNil", args{protocolID:p2pcommon.StatusRequest,msgBody:nil}, true},
   239  		// TODO: Add test cases.
   240  	}
   241  	for _, tt := range tests {
   242  		t.Run(tt.name, func(t *testing.T) {
   243  			got := createMessage(tt.args.protocolID, p2pcommon.NewMsgID(), tt.args.msgBody)
   244  			if (got == nil) != tt.wantNil {
   245  				t.Errorf("createMessage() = %v, want nil %v", got, tt.wantNil)
   246  			}
   247  			if got != nil &&  got.Subprotocol() != tt.args.protocolID {
   248  				t.Errorf("status.ProtocolID = %v, want %v", got.Subprotocol() , tt.args.protocolID)
   249  			}
   250  		})
   251  	}
   252  }