github.com/aergoio/aergo@v1.3.1/p2p/handshakev2_test.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package p2p
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"io"
    12  	"reflect"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/aergoio/aergo-lib/log"
    18  	"github.com/aergoio/aergo/p2p/p2pcommon"
    19  	"github.com/aergoio/aergo/p2p/p2pmock"
    20  	"github.com/aergoio/aergo/types"
    21  	"github.com/golang/mock/gomock"
    22  	"github.com/pkg/errors"
    23  )
    24  
    25  func Test_baseWireHandshaker_writeWireHSRequest(t *testing.T) {
    26  	tests := []struct {
    27  		name     string
    28  		args     p2pcommon.HSHeadReq
    29  		wantErr  bool
    30  		wantSize int
    31  		wantErr2 bool
    32  	}{
    33  		{"TEmpty", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, nil}, false, 8, true},
    34  		{"TSingle", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031}}, false, 12, false},
    35  		{"TMulti", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x033333, 0x092fa10, p2pcommon.P2PVersion031, p2pcommon.P2PVersion030}}, false, 24, false},
    36  	}
    37  	for _, tt := range tests {
    38  		t.Run(tt.name, func(t *testing.T) {
    39  			h := &baseWireHandshaker{}
    40  			buffer := bytes.NewBuffer(nil)
    41  			//wr := bufio.NewWriter(buffer)
    42  			err := h.writeWireHSRequest(tt.args, buffer)
    43  			if (err != nil) != tt.wantErr {
    44  				t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", err, tt.wantErr)
    45  			}
    46  			if buffer.Len() != tt.wantSize {
    47  				t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", buffer.Len(), tt.wantSize)
    48  			}
    49  
    50  			got, err2 := h.readWireHSRequest(buffer)
    51  			if (err2 != nil) != tt.wantErr2 {
    52  				t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", err2, tt.wantErr2)
    53  			}
    54  			if !reflect.DeepEqual(tt.args, got) {
    55  				t.Errorf("baseWireHandshaker.readWireHSRequest() = %v, want %v", got, tt.args)
    56  			}
    57  			if buffer.Len() != 0 {
    58  				t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", buffer.Len(), 0)
    59  			}
    60  
    61  		})
    62  	}
    63  }
    64  
    65  func Test_baseWireHandshaker_writeWireHSResponse(t *testing.T) {
    66  	tests := []struct {
    67  		name     string
    68  		args     p2pcommon.HSHeadResp
    69  		wantErr  bool
    70  		wantSize int
    71  		wantErr2 bool
    72  	}{
    73  		{"TSingle", p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}, false, 8, false},
    74  	}
    75  	for _, tt := range tests {
    76  		t.Run(tt.name, func(t *testing.T) {
    77  			h := &baseWireHandshaker{}
    78  			buffer := bytes.NewBuffer(nil)
    79  			err := h.writeWireHSResponse(tt.args, buffer)
    80  			if (err != nil) != tt.wantErr {
    81  				t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", err, tt.wantErr)
    82  			}
    83  			if buffer.Len() != tt.wantSize {
    84  				t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", buffer.Len(), tt.wantSize)
    85  			}
    86  
    87  			got, err2 := h.readWireHSResp(buffer)
    88  			if (err2 != nil) != tt.wantErr2 {
    89  				t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", err2, tt.wantErr2)
    90  			}
    91  			if !reflect.DeepEqual(tt.args, got) {
    92  				t.Errorf("baseWireHandshaker.readWireHSRequest() = %v, want %v", got, tt.args)
    93  			}
    94  			if buffer.Len() != 0 {
    95  				t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", buffer.Len(), 0)
    96  			}
    97  
    98  		})
    99  	}
   100  }
   101  
   102  func TestInboundWireHandshker_handleInboundPeer(t *testing.T) {
   103  	ctrl := gomock.NewController(t)
   104  	defer ctrl.Finish()
   105  
   106  	sampleChainID := &types.ChainID{}
   107  	sampleStatus := &types.Status{}
   108  	logger := log.NewLogger("p2p.test")
   109  	sampleEmptyHSReq := p2pcommon.HSHeadReq{p2pcommon.MAGICMain, nil}
   110  	sampleEmptyHSResp := p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeWrongHSReq}
   111  
   112  	type args struct {
   113  		r []byte
   114  	}
   115  	tests := []struct {
   116  		name string
   117  		in   []byte
   118  
   119  		bestVer   p2pcommon.P2PVersion
   120  		ctxCancel int  // 0 is not , 1 is during read, 2 is during write
   121  		vhErr     bool // version handshaker failed
   122  
   123  		wantW   []byte // sent header
   124  		wantErr bool
   125  	}{
   126  		// All valid
   127  		{"TCurrentVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), false},
   128  		{"TOldVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x000010, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion030, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}.Marshal(), false},
   129  		// wrong io read
   130  		{"TWrongRead", sampleEmptyHSReq.Marshal()[:7], p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true},
   131  		// empty version
   132  		{"TEmptyVersion", sampleEmptyHSReq.Marshal(), p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true},
   133  		// wrong io write
   134  		// {"TWrongWrite", sampleEmptyHSReq.Marshal()[:7], sampleEmptyHSResp.Marshal(), true },
   135  		// wrong magic
   136  		{"TWrongMagic", p2pcommon.HSHeadReq{0x0001, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031}}.Marshal(), p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true},
   137  		// not supported version (or wrong version)
   138  		{"TNoVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x000010, 0x030405, 0x000101}}.Marshal(), p2pcommon.P2PVersionUnknown, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeNoMatchedVersion}.Marshal(), true},
   139  		// protocol handshake failed
   140  		{"TVersionHSFailed", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 0, true, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), true},
   141  
   142  		// timeout while read, no reply to remote
   143  		{"TTimeoutRead", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 1, false, []byte{}, true},
   144  		// timeout while writing, sent but remote not receiving fast
   145  		{"TTimeoutWrite", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 2, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), true},
   146  	}
   147  	for _, tt := range tests {
   148  		t.Run(tt.name, func(t *testing.T) {
   149  			mockPM := p2pmock.NewMockPeerManager(ctrl)
   150  			mockActor := p2pmock.NewMockActorService(ctrl)
   151  			mockVM := p2pmock.NewMockVersionedManager(ctrl)
   152  			mockVH := p2pmock.NewMockVersionedHandshaker(ctrl)
   153  
   154  			mockCtx := NewContextTestDouble(tt.ctxCancel) // TODO make mock
   155  			wbuf := bytes.NewBuffer(nil)
   156  			dummyReader := &RWCWrapper{bytes.NewBuffer(tt.in), wbuf, nil}
   157  			dummyMsgRW := p2pmock.NewMockMsgReadWriter(ctrl)
   158  
   159  			mockVM.EXPECT().FindBestP2PVersion(gomock.Any()).Return(tt.bestVer).MaxTimes(1)
   160  			mockVM.EXPECT().GetVersionedHandshaker(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockVH, nil).MaxTimes(1)
   161  			if !tt.vhErr {
   162  				mockVH.EXPECT().DoForInbound(mockCtx).Return(sampleStatus, nil).MaxTimes(1)
   163  				mockVH.EXPECT().GetMsgRW().Return(dummyMsgRW).MaxTimes(1)
   164  			} else {
   165  				mockVH.EXPECT().DoForInbound(mockCtx).Return(nil, errors.New("version hs failed")).MaxTimes(1)
   166  				mockVH.EXPECT().GetMsgRW().Return(nil).MaxTimes(1)
   167  			}
   168  
   169  			h := NewInboundHSHandler(mockPM, mockActor, mockVM, logger, sampleChainID, samplePeerID).(*InboundWireHandshaker)
   170  			got, got1, err := h.handleInboundPeer(mockCtx, dummyReader)
   171  			if (err != nil) != tt.wantErr {
   172  				t.Errorf("InboundWireHandshaker.handleInboundPeer() error = %v, wantErr %v", err, tt.wantErr)
   173  			}
   174  			if !bytes.Equal(wbuf.Bytes(), tt.wantW) {
   175  				t.Errorf("InboundWireHandshaker.handleInboundPeer() send resp %v, want %v", wbuf.Bytes(), tt.wantW)
   176  			}
   177  			if !tt.wantErr {
   178  				if got == nil {
   179  					t.Errorf("InboundWireHandshaker.handleInboundPeer() got msgrw nil, want not")
   180  				}
   181  				if got1 == nil {
   182  					t.Errorf("InboundWireHandshaker.handleInboundPeer() got status nil, want not")
   183  				}
   184  			}
   185  		})
   186  	}
   187  }
   188  
   189  func TestOutboundWireHandshaker_handleOutboundPeer(t *testing.T) {
   190  	ctrl := gomock.NewController(t)
   191  	defer ctrl.Finish()
   192  
   193  	sampleChainID := &types.ChainID{}
   194  	sampleStatus := &types.Status{}
   195  	logger := log.NewLogger("p2p.test")
   196  	// This bytes is actually hard-coded in source handshake_v2.go.
   197  	outBytes := p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion032, p2pcommon.P2PVersion031}}.Marshal()
   198  
   199  	tests := []struct {
   200  		name string
   201  
   202  		remoteRespVer  p2pcommon.P2PVersion
   203  		ctxCancel      int    // 0 is not , 1 is during write, 2 is during read
   204  		versionHSerror bool   // whether version handshaker return failed or not
   205  		remoteResponse []byte // emulate response from remote peer
   206  
   207  		wantErr bool
   208  	}{
   209  		// remote listening peer accept my best p2p version
   210  		{"TCurrentVersion", p2pcommon.P2PVersion032, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), false},
   211  		// remote listening peer can connect, but old p2p version
   212  		{"TOldVersion", p2pcommon.P2PVersion031, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), false},
   213  		{"TOlderVersion", p2pcommon.P2PVersion030, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}.Marshal(), false},
   214  		// wrong io read
   215  		{"TWrongResp", p2pcommon.P2PVersion032, 0, false, outBytes[:6], true},
   216  		// {"TWrongWrite", sampleEmptyHSReq.Marshal()[:7], sampleEmptyHSResp.Marshal(), true },
   217  		// wrong magic
   218  		{"TWrongMagic", p2pcommon.P2PVersion032, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeWrongHSReq}.Marshal(), true},
   219  		// not supported version (or wrong version)
   220  		{"TNoVersion", p2pcommon.P2PVersionUnknown, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeNoMatchedVersion}.Marshal(), true},
   221  		// protocol handshake failed
   222  		{"TVersionHSFailed", p2pcommon.P2PVersion032, 0, true, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), true},
   223  
   224  		// timeout while read, no reply to remote
   225  		{"TTimeoutRead", p2pcommon.P2PVersion031, 1, false, []byte{}, true},
   226  		// timeout while writing, sent but remote not receiving fast
   227  		{"TTimeoutWrite", p2pcommon.P2PVersion032, 2, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), true},
   228  	}
   229  	for _, tt := range tests {
   230  		t.Run(tt.name, func(t *testing.T) {
   231  			mockPM := p2pmock.NewMockPeerManager(ctrl)
   232  			mockActor := p2pmock.NewMockActorService(ctrl)
   233  			mockVM := p2pmock.NewMockVersionedManager(ctrl)
   234  			mockVH := p2pmock.NewMockVersionedHandshaker(ctrl)
   235  
   236  			mockCtx := NewContextTestDouble(tt.ctxCancel) // TODO make mock
   237  			wbuf := bytes.NewBuffer(nil)
   238  			dummyRWC := &RWCWrapper{bytes.NewBuffer(tt.remoteResponse), wbuf, nil}
   239  			dummyMsgRW := p2pmock.NewMockMsgReadWriter(ctrl)
   240  
   241  			mockVM.EXPECT().GetVersionedHandshaker(tt.remoteRespVer, gomock.Any(), gomock.Any()).Return(mockVH, nil).MaxTimes(1)
   242  			if tt.versionHSerror {
   243  				mockVH.EXPECT().DoForOutbound(mockCtx).Return(nil, errors.New("version hs failed")).MaxTimes(1)
   244  				mockVH.EXPECT().GetMsgRW().Return(nil).MaxTimes(1)
   245  			} else {
   246  				mockVH.EXPECT().DoForOutbound(mockCtx).Return(sampleStatus, nil).MaxTimes(1)
   247  				mockVH.EXPECT().GetMsgRW().Return(dummyMsgRW).MaxTimes(1)
   248  			}
   249  
   250  			h := NewOutboundHSHandler(mockPM, mockActor, mockVM, logger, sampleChainID, samplePeerID).(*OutboundWireHandshaker)
   251  			got, got1, err := h.handleOutboundPeer(mockCtx, dummyRWC)
   252  			if (err != nil) != tt.wantErr {
   253  				t.Errorf("OutboundWireHandshaker.handleOutboundPeer() error = %v, wantErr %v", err, tt.wantErr)
   254  			}
   255  			if !bytes.Equal(wbuf.Bytes(), outBytes) {
   256  				t.Errorf("OutboundWireHandshaker.handleOutboundPeer() send resp %v, want %v", wbuf.Bytes(), outBytes)
   257  			}
   258  			if !tt.wantErr {
   259  				if got == nil {
   260  					t.Errorf("OutboundWireHandshaker.handleOutboundPeer() got msgrw nil, want not")
   261  				}
   262  				if got1 == nil {
   263  					t.Errorf("OutboundWireHandshaker.handleOutboundPeer() got status nil, want not")
   264  				}
   265  			}
   266  		})
   267  	}
   268  }
   269  
   270  type RWCWrapper struct {
   271  	r io.Reader
   272  	w io.Writer
   273  	c io.Closer
   274  }
   275  
   276  func (rwc RWCWrapper) Read(p []byte) (n int, err error) {
   277  	return rwc.r.Read(p)
   278  }
   279  
   280  func (rwc RWCWrapper) Write(p []byte) (n int, err error) {
   281  	return rwc.w.Write(p)
   282  }
   283  
   284  func (rwc RWCWrapper) Close() error {
   285  	return rwc.c.Close()
   286  }
   287  
   288  type ContextTestDouble struct {
   289  	doneChannel chan struct{}
   290  	expire      uint32
   291  	callCnt     uint32
   292  }
   293  
   294  var _ context.Context = (*ContextTestDouble)(nil)
   295  
   296  func NewContextTestDouble(expire int) *ContextTestDouble {
   297  	if expire <= 0 {
   298  		expire = 9999999
   299  	}
   300  	return &ContextTestDouble{expire: uint32(expire), doneChannel: make(chan struct{}, 1)}
   301  }
   302  
   303  func (*ContextTestDouble) Deadline() (deadline time.Time, ok bool) {
   304  	panic("implement me")
   305  }
   306  
   307  func (c *ContextTestDouble) Done() <-chan struct{} {
   308  	current := atomic.AddUint32(&c.callCnt, 1)
   309  	if current >= c.expire {
   310  		c.doneChannel <- struct{}{}
   311  	}
   312  	return c.doneChannel
   313  }
   314  
   315  func (c *ContextTestDouble) Err() error {
   316  	if atomic.LoadUint32(&c.callCnt) >= c.expire {
   317  		return errors.New("timeout")
   318  	} else {
   319  		return nil
   320  	}
   321  }
   322  
   323  func (*ContextTestDouble) Value(key interface{}) interface{} {
   324  	panic("implement me")
   325  }