github.com/aergoio/aergo@v1.3.1/p2p/handshakev2_test.go (about) 1 /* 2 * @file 3 * @copyright defined in aergo/LICENSE.txt 4 */ 5 6 package p2p 7 8 import ( 9 "bytes" 10 "context" 11 "io" 12 "reflect" 13 "sync/atomic" 14 "testing" 15 "time" 16 17 "github.com/aergoio/aergo-lib/log" 18 "github.com/aergoio/aergo/p2p/p2pcommon" 19 "github.com/aergoio/aergo/p2p/p2pmock" 20 "github.com/aergoio/aergo/types" 21 "github.com/golang/mock/gomock" 22 "github.com/pkg/errors" 23 ) 24 25 func Test_baseWireHandshaker_writeWireHSRequest(t *testing.T) { 26 tests := []struct { 27 name string 28 args p2pcommon.HSHeadReq 29 wantErr bool 30 wantSize int 31 wantErr2 bool 32 }{ 33 {"TEmpty", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, nil}, false, 8, true}, 34 {"TSingle", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031}}, false, 12, false}, 35 {"TMulti", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x033333, 0x092fa10, p2pcommon.P2PVersion031, p2pcommon.P2PVersion030}}, false, 24, false}, 36 } 37 for _, tt := range tests { 38 t.Run(tt.name, func(t *testing.T) { 39 h := &baseWireHandshaker{} 40 buffer := bytes.NewBuffer(nil) 41 //wr := bufio.NewWriter(buffer) 42 err := h.writeWireHSRequest(tt.args, buffer) 43 if (err != nil) != tt.wantErr { 44 t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", err, tt.wantErr) 45 } 46 if buffer.Len() != tt.wantSize { 47 t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", buffer.Len(), tt.wantSize) 48 } 49 50 got, err2 := h.readWireHSRequest(buffer) 51 if (err2 != nil) != tt.wantErr2 { 52 t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", err2, tt.wantErr2) 53 } 54 if !reflect.DeepEqual(tt.args, got) { 55 t.Errorf("baseWireHandshaker.readWireHSRequest() = %v, want %v", got, tt.args) 56 } 57 if buffer.Len() != 0 { 58 t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", buffer.Len(), 0) 59 } 60 61 }) 62 } 63 } 64 65 func Test_baseWireHandshaker_writeWireHSResponse(t *testing.T) { 66 tests := []struct { 67 name string 68 args p2pcommon.HSHeadResp 69 wantErr bool 70 wantSize int 71 wantErr2 bool 72 }{ 73 {"TSingle", p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}, false, 8, false}, 74 } 75 for _, tt := range tests { 76 t.Run(tt.name, func(t *testing.T) { 77 h := &baseWireHandshaker{} 78 buffer := bytes.NewBuffer(nil) 79 err := h.writeWireHSResponse(tt.args, buffer) 80 if (err != nil) != tt.wantErr { 81 t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", err, tt.wantErr) 82 } 83 if buffer.Len() != tt.wantSize { 84 t.Errorf("baseWireHandshaker.writeWireHSRequest() error = %v, wantErr %v", buffer.Len(), tt.wantSize) 85 } 86 87 got, err2 := h.readWireHSResp(buffer) 88 if (err2 != nil) != tt.wantErr2 { 89 t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", err2, tt.wantErr2) 90 } 91 if !reflect.DeepEqual(tt.args, got) { 92 t.Errorf("baseWireHandshaker.readWireHSRequest() = %v, want %v", got, tt.args) 93 } 94 if buffer.Len() != 0 { 95 t.Errorf("baseWireHandshaker.readWireHSRequest() error = %v, wantErr %v", buffer.Len(), 0) 96 } 97 98 }) 99 } 100 } 101 102 func TestInboundWireHandshker_handleInboundPeer(t *testing.T) { 103 ctrl := gomock.NewController(t) 104 defer ctrl.Finish() 105 106 sampleChainID := &types.ChainID{} 107 sampleStatus := &types.Status{} 108 logger := log.NewLogger("p2p.test") 109 sampleEmptyHSReq := p2pcommon.HSHeadReq{p2pcommon.MAGICMain, nil} 110 sampleEmptyHSResp := p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeWrongHSReq} 111 112 type args struct { 113 r []byte 114 } 115 tests := []struct { 116 name string 117 in []byte 118 119 bestVer p2pcommon.P2PVersion 120 ctxCancel int // 0 is not , 1 is during read, 2 is during write 121 vhErr bool // version handshaker failed 122 123 wantW []byte // sent header 124 wantErr bool 125 }{ 126 // All valid 127 {"TCurrentVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), false}, 128 {"TOldVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x000010, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion030, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}.Marshal(), false}, 129 // wrong io read 130 {"TWrongRead", sampleEmptyHSReq.Marshal()[:7], p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true}, 131 // empty version 132 {"TEmptyVersion", sampleEmptyHSReq.Marshal(), p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true}, 133 // wrong io write 134 // {"TWrongWrite", sampleEmptyHSReq.Marshal()[:7], sampleEmptyHSResp.Marshal(), true }, 135 // wrong magic 136 {"TWrongMagic", p2pcommon.HSHeadReq{0x0001, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031}}.Marshal(), p2pcommon.P2PVersion031, 0, false, sampleEmptyHSResp.Marshal(), true}, 137 // not supported version (or wrong version) 138 {"TNoVersion", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{0x000010, 0x030405, 0x000101}}.Marshal(), p2pcommon.P2PVersionUnknown, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeNoMatchedVersion}.Marshal(), true}, 139 // protocol handshake failed 140 {"TVersionHSFailed", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 0, true, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), true}, 141 142 // timeout while read, no reply to remote 143 {"TTimeoutRead", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 1, false, []byte{}, true}, 144 // timeout while writing, sent but remote not receiving fast 145 {"TTimeoutWrite", p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion031, p2pcommon.P2PVersion030, 0x000101}}.Marshal(), p2pcommon.P2PVersion031, 2, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), true}, 146 } 147 for _, tt := range tests { 148 t.Run(tt.name, func(t *testing.T) { 149 mockPM := p2pmock.NewMockPeerManager(ctrl) 150 mockActor := p2pmock.NewMockActorService(ctrl) 151 mockVM := p2pmock.NewMockVersionedManager(ctrl) 152 mockVH := p2pmock.NewMockVersionedHandshaker(ctrl) 153 154 mockCtx := NewContextTestDouble(tt.ctxCancel) // TODO make mock 155 wbuf := bytes.NewBuffer(nil) 156 dummyReader := &RWCWrapper{bytes.NewBuffer(tt.in), wbuf, nil} 157 dummyMsgRW := p2pmock.NewMockMsgReadWriter(ctrl) 158 159 mockVM.EXPECT().FindBestP2PVersion(gomock.Any()).Return(tt.bestVer).MaxTimes(1) 160 mockVM.EXPECT().GetVersionedHandshaker(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockVH, nil).MaxTimes(1) 161 if !tt.vhErr { 162 mockVH.EXPECT().DoForInbound(mockCtx).Return(sampleStatus, nil).MaxTimes(1) 163 mockVH.EXPECT().GetMsgRW().Return(dummyMsgRW).MaxTimes(1) 164 } else { 165 mockVH.EXPECT().DoForInbound(mockCtx).Return(nil, errors.New("version hs failed")).MaxTimes(1) 166 mockVH.EXPECT().GetMsgRW().Return(nil).MaxTimes(1) 167 } 168 169 h := NewInboundHSHandler(mockPM, mockActor, mockVM, logger, sampleChainID, samplePeerID).(*InboundWireHandshaker) 170 got, got1, err := h.handleInboundPeer(mockCtx, dummyReader) 171 if (err != nil) != tt.wantErr { 172 t.Errorf("InboundWireHandshaker.handleInboundPeer() error = %v, wantErr %v", err, tt.wantErr) 173 } 174 if !bytes.Equal(wbuf.Bytes(), tt.wantW) { 175 t.Errorf("InboundWireHandshaker.handleInboundPeer() send resp %v, want %v", wbuf.Bytes(), tt.wantW) 176 } 177 if !tt.wantErr { 178 if got == nil { 179 t.Errorf("InboundWireHandshaker.handleInboundPeer() got msgrw nil, want not") 180 } 181 if got1 == nil { 182 t.Errorf("InboundWireHandshaker.handleInboundPeer() got status nil, want not") 183 } 184 } 185 }) 186 } 187 } 188 189 func TestOutboundWireHandshaker_handleOutboundPeer(t *testing.T) { 190 ctrl := gomock.NewController(t) 191 defer ctrl.Finish() 192 193 sampleChainID := &types.ChainID{} 194 sampleStatus := &types.Status{} 195 logger := log.NewLogger("p2p.test") 196 // This bytes is actually hard-coded in source handshake_v2.go. 197 outBytes := p2pcommon.HSHeadReq{p2pcommon.MAGICMain, []p2pcommon.P2PVersion{p2pcommon.P2PVersion032, p2pcommon.P2PVersion031}}.Marshal() 198 199 tests := []struct { 200 name string 201 202 remoteRespVer p2pcommon.P2PVersion 203 ctxCancel int // 0 is not , 1 is during write, 2 is during read 204 versionHSerror bool // whether version handshaker return failed or not 205 remoteResponse []byte // emulate response from remote peer 206 207 wantErr bool 208 }{ 209 // remote listening peer accept my best p2p version 210 {"TCurrentVersion", p2pcommon.P2PVersion032, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), false}, 211 // remote listening peer can connect, but old p2p version 212 {"TOldVersion", p2pcommon.P2PVersion031, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion031.Uint32()}.Marshal(), false}, 213 {"TOlderVersion", p2pcommon.P2PVersion030, 0, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion030.Uint32()}.Marshal(), false}, 214 // wrong io read 215 {"TWrongResp", p2pcommon.P2PVersion032, 0, false, outBytes[:6], true}, 216 // {"TWrongWrite", sampleEmptyHSReq.Marshal()[:7], sampleEmptyHSResp.Marshal(), true }, 217 // wrong magic 218 {"TWrongMagic", p2pcommon.P2PVersion032, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeWrongHSReq}.Marshal(), true}, 219 // not supported version (or wrong version) 220 {"TNoVersion", p2pcommon.P2PVersionUnknown, 0, false, p2pcommon.HSHeadResp{p2pcommon.HSError, p2pcommon.HSCodeNoMatchedVersion}.Marshal(), true}, 221 // protocol handshake failed 222 {"TVersionHSFailed", p2pcommon.P2PVersion032, 0, true, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), true}, 223 224 // timeout while read, no reply to remote 225 {"TTimeoutRead", p2pcommon.P2PVersion031, 1, false, []byte{}, true}, 226 // timeout while writing, sent but remote not receiving fast 227 {"TTimeoutWrite", p2pcommon.P2PVersion032, 2, false, p2pcommon.HSHeadResp{p2pcommon.MAGICMain, p2pcommon.P2PVersion032.Uint32()}.Marshal(), true}, 228 } 229 for _, tt := range tests { 230 t.Run(tt.name, func(t *testing.T) { 231 mockPM := p2pmock.NewMockPeerManager(ctrl) 232 mockActor := p2pmock.NewMockActorService(ctrl) 233 mockVM := p2pmock.NewMockVersionedManager(ctrl) 234 mockVH := p2pmock.NewMockVersionedHandshaker(ctrl) 235 236 mockCtx := NewContextTestDouble(tt.ctxCancel) // TODO make mock 237 wbuf := bytes.NewBuffer(nil) 238 dummyRWC := &RWCWrapper{bytes.NewBuffer(tt.remoteResponse), wbuf, nil} 239 dummyMsgRW := p2pmock.NewMockMsgReadWriter(ctrl) 240 241 mockVM.EXPECT().GetVersionedHandshaker(tt.remoteRespVer, gomock.Any(), gomock.Any()).Return(mockVH, nil).MaxTimes(1) 242 if tt.versionHSerror { 243 mockVH.EXPECT().DoForOutbound(mockCtx).Return(nil, errors.New("version hs failed")).MaxTimes(1) 244 mockVH.EXPECT().GetMsgRW().Return(nil).MaxTimes(1) 245 } else { 246 mockVH.EXPECT().DoForOutbound(mockCtx).Return(sampleStatus, nil).MaxTimes(1) 247 mockVH.EXPECT().GetMsgRW().Return(dummyMsgRW).MaxTimes(1) 248 } 249 250 h := NewOutboundHSHandler(mockPM, mockActor, mockVM, logger, sampleChainID, samplePeerID).(*OutboundWireHandshaker) 251 got, got1, err := h.handleOutboundPeer(mockCtx, dummyRWC) 252 if (err != nil) != tt.wantErr { 253 t.Errorf("OutboundWireHandshaker.handleOutboundPeer() error = %v, wantErr %v", err, tt.wantErr) 254 } 255 if !bytes.Equal(wbuf.Bytes(), outBytes) { 256 t.Errorf("OutboundWireHandshaker.handleOutboundPeer() send resp %v, want %v", wbuf.Bytes(), outBytes) 257 } 258 if !tt.wantErr { 259 if got == nil { 260 t.Errorf("OutboundWireHandshaker.handleOutboundPeer() got msgrw nil, want not") 261 } 262 if got1 == nil { 263 t.Errorf("OutboundWireHandshaker.handleOutboundPeer() got status nil, want not") 264 } 265 } 266 }) 267 } 268 } 269 270 type RWCWrapper struct { 271 r io.Reader 272 w io.Writer 273 c io.Closer 274 } 275 276 func (rwc RWCWrapper) Read(p []byte) (n int, err error) { 277 return rwc.r.Read(p) 278 } 279 280 func (rwc RWCWrapper) Write(p []byte) (n int, err error) { 281 return rwc.w.Write(p) 282 } 283 284 func (rwc RWCWrapper) Close() error { 285 return rwc.c.Close() 286 } 287 288 type ContextTestDouble struct { 289 doneChannel chan struct{} 290 expire uint32 291 callCnt uint32 292 } 293 294 var _ context.Context = (*ContextTestDouble)(nil) 295 296 func NewContextTestDouble(expire int) *ContextTestDouble { 297 if expire <= 0 { 298 expire = 9999999 299 } 300 return &ContextTestDouble{expire: uint32(expire), doneChannel: make(chan struct{}, 1)} 301 } 302 303 func (*ContextTestDouble) Deadline() (deadline time.Time, ok bool) { 304 panic("implement me") 305 } 306 307 func (c *ContextTestDouble) Done() <-chan struct{} { 308 current := atomic.AddUint32(&c.callCnt, 1) 309 if current >= c.expire { 310 c.doneChannel <- struct{}{} 311 } 312 return c.doneChannel 313 } 314 315 func (c *ContextTestDouble) Err() error { 316 if atomic.LoadUint32(&c.callCnt) >= c.expire { 317 return errors.New("timeout") 318 } else { 319 return nil 320 } 321 } 322 323 func (*ContextTestDouble) Value(key interface{}) interface{} { 324 panic("implement me") 325 }