github.com/aergoio/aergo@v1.3.1/p2p/handshake_test.go (about)

     1  /**
     2   *  @file
     3   *  @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package p2p
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"github.com/aergoio/aergo/config"
    12  	"github.com/aergoio/aergo/p2p/p2pkey"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/aergoio/aergo/p2p/p2pcommon"
    19  	"github.com/aergoio/aergo/p2p/p2pmock"
    20  	"github.com/golang/mock/gomock"
    21  	"github.com/stretchr/testify/assert"
    22  
    23  	"github.com/aergoio/aergo-lib/log"
    24  	"github.com/aergoio/aergo/types"
    25  )
    26  
    27  const (
    28  	sampleKeyFile = "../test/sample.key"
    29  )
    30  
    31  var (
    32  	// sampleID matches the key defined in test config file
    33  	sampleID types.PeerID
    34  )
    35  
    36  func init() {
    37  	sampleID = "16Uiu2HAmP2iRDpPumUbKhNnEngoxAUQWBmCyn7FaYUrkaDAMXJPJ"
    38  	baseCfg := &config.BaseConfig{AuthDir: "test"}
    39  	p2pCfg := &config.P2PConfig{NPKey: sampleKeyFile}
    40  	p2pkey.InitNodeInfo(baseCfg, p2pCfg, "0.0.1-test", logger)
    41  }
    42  
    43  func TestPeerHandshaker_handshakeOutboundPeerTimeout(t *testing.T) {
    44  	var myChainID = &types.ChainID{Magic: "itSmain1"}
    45  
    46  	ctrl := gomock.NewController(t)
    47  	defer ctrl.Finish()
    48  
    49  	logger = log.NewLogger("test")
    50  	// dummyStatusMsg := &types.Status{}
    51  	tests := []struct {
    52  		name    string
    53  		delay   time.Duration
    54  		want    *types.Status
    55  		wantErr bool
    56  	}{
    57  		// {"TNormal", time.Millisecond, dummyStatusMsg, false},
    58  		{"TWriteTimeout", time.Millisecond * 100, nil, true},
    59  		// TODO: Add test cases.
    60  	}
    61  	for _, tt := range tests {
    62  		t.Run(tt.name, func(t *testing.T) {
    63  			mockActor := p2pmock.NewMockActorService(ctrl)
    64  			mockPM := p2pmock.NewMockPeerManager(ctrl)
    65  			mockCA := p2pmock.NewMockChainAccessor(ctrl)
    66  			if !tt.wantErr {
    67  				// these will be called if timeout is not happen, so version handshake is called.
    68  				mockPM.EXPECT().SelfMeta().Return(dummyMeta).Times(2)
    69  				mockActor.EXPECT().GetChainAccessor().Return(mockCA)
    70  				mockCA.EXPECT().GetBestBlock().Return(dummyBestBlock, nil)
    71  			}
    72  
    73  			h := newHandshaker(mockPM, mockActor, logger, myChainID, samplePeerID)
    74  			mockReader := p2pmock.NewMockReadWriteCloser(ctrl)
    75  			mockReader.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
    76  				time.Sleep(tt.delay)
    77  				return 0, fmt.Errorf("must not reach")
    78  			}).AnyTimes()
    79  			mockReader.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
    80  				time.Sleep(tt.delay)
    81  				return len(p), nil
    82  			})
    83  			ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
    84  			defer cancel()
    85  			_, got, err := h.handshakeOutboundPeer(ctx, mockReader)
    86  			//_, got, err := h.handshakeOutboundPeerTimeout(mockReader, mockWriter, time.Millisecond*50)
    87  			if !strings.Contains(err.Error(), "context deadline exceeded") {
    88  				t.Errorf("LegacyWireHandshaker.handshakeOutboundPeer() error = %v, wantErr %v", err, "context deadline exceeded")
    89  				return
    90  			}
    91  			if !reflect.DeepEqual(got, tt.want) {
    92  				t.Errorf("LegacyWireHandshaker.handshakeOutboundPeer() = %v, want %v", got, tt.want)
    93  			}
    94  		})
    95  	}
    96  }
    97  
    98  func TestPeerHandshaker_Select(t *testing.T) {
    99  	ctrl := gomock.NewController(t)
   100  	defer ctrl.Finish()
   101  
   102  	logger = log.NewLogger("test")
   103  	mockActor := p2pmock.NewMockActorService(ctrl)
   104  	mockPM := p2pmock.NewMockPeerManager(ctrl)
   105  
   106  	tests := []struct {
   107  		name     string
   108  		hsHeader p2pcommon.HSHeader
   109  		wantErr  bool
   110  	}{
   111  		{"TVer030", p2pcommon.HSHeader{p2pcommon.MAGICMain, p2pcommon.P2PVersion030}, false},
   112  		{"Tver020", p2pcommon.HSHeader{p2pcommon.MAGICMain, 0x00000200}, true},
   113  		{"TInvalid", p2pcommon.HSHeader{p2pcommon.MAGICMain, 0x000001}, true},
   114  	}
   115  	for _, test := range tests {
   116  		t.Run(test.name, func(t *testing.T) {
   117  			mockReader := p2pmock.NewMockReadWriteCloser(ctrl)
   118  
   119  			h := newHandshaker(mockPM, mockActor, logger, nil, samplePeerID)
   120  
   121  			actual, err := h.selectProtocolVersion(test.hsHeader.Version, mockReader)
   122  			assert.Equal(t, test.wantErr, err != nil)
   123  			if !test.wantErr {
   124  				assert.NotNil(t, actual)
   125  			}
   126  		})
   127  	}
   128  }