github.com/MetalBlockchain/subnet-evm@v0.4.9/peer/network_test.go (about) 1 // (c) 2019-2022, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package peer 5 6 import ( 7 "context" 8 "errors" 9 "fmt" 10 "sync" 11 "sync/atomic" 12 "testing" 13 "time" 14 15 "github.com/MetalBlockchain/metalgo/snow/engine/common" 16 "github.com/MetalBlockchain/metalgo/utils/set" 17 18 "github.com/MetalBlockchain/subnet-evm/plugin/evm/message" 19 20 "github.com/MetalBlockchain/metalgo/codec" 21 "github.com/MetalBlockchain/metalgo/codec/linearcodec" 22 "github.com/MetalBlockchain/metalgo/ids" 23 "github.com/MetalBlockchain/metalgo/version" 24 "github.com/stretchr/testify/assert" 25 26 ethcommon "github.com/ethereum/go-ethereum/common" 27 ) 28 29 var ( 30 defaultPeerVersion = &version.Application{ 31 Major: 1, 32 Minor: 0, 33 Patch: 0, 34 } 35 36 _ message.Request = &HelloRequest{} 37 _ = &HelloResponse{} 38 _ = &GreetingRequest{} 39 _ = &GreetingResponse{} 40 _ = &TestMessage{} 41 42 _ message.RequestHandler = &HelloGreetingRequestHandler{} 43 _ message.RequestHandler = &testRequestHandler{} 44 45 _ common.AppSender = testAppSender{} 46 _ message.GossipMessage = HelloGossip{} 47 _ message.GossipHandler = &testGossipHandler{} 48 49 _ message.CrossChainRequest = &ExampleCrossChainRequest{} 50 _ message.CrossChainRequestHandler = &testCrossChainHandler{} 51 ) 52 53 func TestNetworkDoesNotConnectToItself(t *testing.T) { 54 selfNodeID := ids.GenerateTestNodeID() 55 n := NewNetwork(nil, nil, nil, selfNodeID, 1, 1) 56 assert.NoError(t, n.Connected(context.Background(), selfNodeID, defaultPeerVersion)) 57 assert.EqualValues(t, 0, n.Size()) 58 } 59 60 func TestRequestAnyRequestsRoutingAndResponse(t *testing.T) { 61 callNum := uint32(0) 62 senderWg := &sync.WaitGroup{} 63 var net Network 64 sender := testAppSender{ 65 sendAppRequestFn: func(nodes set.Set[ids.NodeID], requestID uint32, requestBytes []byte) error { 66 nodeID, _ := nodes.Pop() 67 senderWg.Add(1) 68 go func() { 69 defer senderWg.Done() 70 if err := net.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { 71 panic(err) 72 } 73 }() 74 return nil 75 }, 76 sendAppResponseFn: func(nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { 77 senderWg.Add(1) 78 go func() { 79 defer senderWg.Done() 80 if err := net.AppResponse(context.Background(), nodeID, requestID, responseBytes); err != nil { 81 panic(err) 82 } 83 atomic.AddUint32(&callNum, 1) 84 }() 85 return nil 86 }, 87 } 88 89 codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) 90 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 91 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) 92 net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) 93 client := NewNetworkClient(net) 94 nodeID := ids.GenerateTestNodeID() 95 assert.NoError(t, net.Connected(context.Background(), nodeID, defaultPeerVersion)) 96 97 requestMessage := HelloRequest{Message: "this is a request"} 98 99 defer net.Shutdown() 100 assert.NoError(t, net.Connected(context.Background(), nodeID, defaultPeerVersion)) 101 102 totalRequests := 5000 103 numCallsPerRequest := 1 // on sending response 104 totalCalls := totalRequests * numCallsPerRequest 105 106 requestWg := &sync.WaitGroup{} 107 requestWg.Add(totalCalls) 108 for i := 0; i < totalCalls; i++ { 109 go func(wg *sync.WaitGroup) { 110 defer wg.Done() 111 requestBytes, err := message.RequestToBytes(codecManager, requestMessage) 112 assert.NoError(t, err) 113 responseBytes, _, err := client.SendAppRequestAny(defaultPeerVersion, requestBytes) 114 assert.NoError(t, err) 115 assert.NotNil(t, responseBytes) 116 117 var response TestMessage 118 if _, err = codecManager.Unmarshal(responseBytes, &response); err != nil { 119 panic(fmt.Errorf("unexpected error during unmarshal: %w", err)) 120 } 121 assert.Equal(t, "Hi", response.Message) 122 }(requestWg) 123 } 124 125 requestWg.Wait() 126 senderWg.Wait() 127 assert.Equal(t, totalCalls, int(atomic.LoadUint32(&callNum))) 128 } 129 130 func TestRequestRequestsRoutingAndResponse(t *testing.T) { 131 callNum := uint32(0) 132 senderWg := &sync.WaitGroup{} 133 var net Network 134 var lock sync.Mutex 135 contactedNodes := make(map[ids.NodeID]struct{}) 136 sender := testAppSender{ 137 sendAppRequestFn: func(nodes set.Set[ids.NodeID], requestID uint32, requestBytes []byte) error { 138 nodeID, _ := nodes.Pop() 139 lock.Lock() 140 contactedNodes[nodeID] = struct{}{} 141 lock.Unlock() 142 senderWg.Add(1) 143 go func() { 144 defer senderWg.Done() 145 if err := net.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { 146 panic(err) 147 } 148 }() 149 return nil 150 }, 151 sendAppResponseFn: func(nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { 152 senderWg.Add(1) 153 go func() { 154 defer senderWg.Done() 155 if err := net.AppResponse(context.Background(), nodeID, requestID, responseBytes); err != nil { 156 panic(err) 157 } 158 atomic.AddUint32(&callNum, 1) 159 }() 160 return nil 161 }, 162 } 163 164 codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) 165 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 166 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 16, 16) 167 net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) 168 client := NewNetworkClient(net) 169 170 nodes := []ids.NodeID{ 171 ids.GenerateTestNodeID(), 172 ids.GenerateTestNodeID(), 173 ids.GenerateTestNodeID(), 174 ids.GenerateTestNodeID(), 175 ids.GenerateTestNodeID(), 176 } 177 for _, nodeID := range nodes { 178 assert.NoError(t, net.Connected(context.Background(), nodeID, defaultPeerVersion)) 179 } 180 181 requestMessage := HelloRequest{Message: "this is a request"} 182 defer net.Shutdown() 183 184 totalRequests := 5000 185 numCallsPerRequest := 1 // on sending response 186 totalCalls := totalRequests * numCallsPerRequest 187 188 requestWg := &sync.WaitGroup{} 189 requestWg.Add(totalCalls) 190 nodeIdx := 0 191 for i := 0; i < totalCalls; i++ { 192 nodeIdx = (nodeIdx + 1) % (len(nodes)) 193 nodeID := nodes[nodeIdx] 194 go func(wg *sync.WaitGroup, nodeID ids.NodeID) { 195 defer wg.Done() 196 requestBytes, err := message.RequestToBytes(codecManager, requestMessage) 197 assert.NoError(t, err) 198 responseBytes, err := client.SendAppRequest(nodeID, requestBytes) 199 assert.NoError(t, err) 200 assert.NotNil(t, responseBytes) 201 202 var response TestMessage 203 if _, err = codecManager.Unmarshal(responseBytes, &response); err != nil { 204 panic(fmt.Errorf("unexpected error during unmarshal: %w", err)) 205 } 206 assert.Equal(t, "Hi", response.Message) 207 }(requestWg, nodeID) 208 } 209 210 requestWg.Wait() 211 senderWg.Wait() 212 assert.Equal(t, totalCalls, int(atomic.LoadUint32(&callNum))) 213 for _, nodeID := range nodes { 214 if _, exists := contactedNodes[nodeID]; !exists { 215 t.Fatalf("expected nodeID %s to be contacted but was not", nodeID) 216 } 217 } 218 219 // ensure empty nodeID is not allowed 220 _, err := client.SendAppRequest(ids.EmptyNodeID, []byte("hello there")) 221 assert.Error(t, err) 222 assert.Contains(t, err.Error(), "cannot send request to empty nodeID") 223 } 224 225 func TestRequestMinVersion(t *testing.T) { 226 callNum := uint32(0) 227 nodeID := ids.GenerateTestNodeID() 228 codecManager := buildCodec(t, TestMessage{}) 229 230 var net Network 231 sender := testAppSender{ 232 sendAppRequestFn: func(nodes set.Set[ids.NodeID], reqID uint32, messageBytes []byte) error { 233 atomic.AddUint32(&callNum, 1) 234 assert.True(t, nodes.Contains(nodeID), "request nodes should contain expected nodeID") 235 assert.Len(t, nodes, 1, "request nodes should contain exactly one node") 236 237 go func() { 238 time.Sleep(200 * time.Millisecond) 239 atomic.AddUint32(&callNum, 1) 240 responseBytes, err := codecManager.Marshal(message.Version, TestMessage{Message: "this is a response"}) 241 if err != nil { 242 panic(err) 243 } 244 err = net.AppResponse(context.Background(), nodeID, reqID, responseBytes) 245 assert.NoError(t, err) 246 }() 247 return nil 248 }, 249 } 250 251 // passing nil as codec works because the net.AppRequest is never called 252 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 253 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 16) 254 client := NewNetworkClient(net) 255 requestMessage := TestMessage{Message: "this is a request"} 256 requestBytes, err := message.RequestToBytes(codecManager, requestMessage) 257 assert.NoError(t, err) 258 assert.NoError(t, 259 net.Connected( 260 context.Background(), 261 nodeID, 262 &version.Application{ 263 Major: 1, 264 Minor: 7, 265 Patch: 1, 266 }, 267 ), 268 ) 269 270 // ensure version does not match 271 responseBytes, _, err := client.SendAppRequestAny( 272 &version.Application{ 273 Major: 2, 274 Minor: 0, 275 Patch: 0, 276 }, 277 requestBytes, 278 ) 279 assert.Equal(t, err.Error(), "no peers found matching version metal/2.0.0 out of 1 peers") 280 assert.Nil(t, responseBytes) 281 282 // ensure version matches and the request goes through 283 responseBytes, _, err = client.SendAppRequestAny(defaultPeerVersion, requestBytes) 284 assert.NoError(t, err) 285 286 var response TestMessage 287 if _, err = codecManager.Unmarshal(responseBytes, &response); err != nil { 288 t.Fatal("unexpected error during unmarshal", err) 289 } 290 assert.Equal(t, "this is a response", response.Message) 291 } 292 293 func TestOnRequestHonoursDeadline(t *testing.T) { 294 var net Network 295 responded := false 296 sender := testAppSender{ 297 sendAppRequestFn: func(nodes set.Set[ids.NodeID], reqID uint32, message []byte) error { 298 return nil 299 }, 300 sendAppResponseFn: func(nodeID ids.NodeID, reqID uint32, message []byte) error { 301 responded = true 302 return nil 303 }, 304 } 305 306 codecManager := buildCodec(t, TestMessage{}) 307 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 308 309 requestBytes, err := marshalStruct(codecManager, TestMessage{Message: "hello there"}) 310 assert.NoError(t, err) 311 312 requestHandler := &testRequestHandler{ 313 processingDuration: 500 * time.Millisecond, 314 } 315 316 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 317 net.SetRequestHandler(requestHandler) 318 nodeID := ids.GenerateTestNodeID() 319 320 requestHandler.response, err = marshalStruct(codecManager, TestMessage{Message: "hi there"}) 321 assert.NoError(t, err) 322 err = net.AppRequest(context.Background(), nodeID, 1, time.Now().Add(1*time.Millisecond), requestBytes) 323 assert.NoError(t, err) 324 // ensure the handler didn't get called (as peer.Network would've dropped the request) 325 assert.EqualValues(t, requestHandler.calls, 0) 326 327 requestHandler.processingDuration = 0 328 err = net.AppRequest(context.Background(), nodeID, 2, time.Now().Add(250*time.Millisecond), requestBytes) 329 assert.NoError(t, err) 330 assert.True(t, responded) 331 assert.EqualValues(t, requestHandler.calls, 1) 332 } 333 334 func TestGossip(t *testing.T) { 335 codecManager := buildCodec(t, HelloGossip{}) 336 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 337 338 nodeID := ids.GenerateTestNodeID() 339 var clientNetwork Network 340 wg := &sync.WaitGroup{} 341 sentGossip := false 342 wg.Add(1) 343 sender := testAppSender{ 344 sendAppGossipFn: func(msg []byte) error { 345 go func() { 346 defer wg.Done() 347 err := clientNetwork.AppGossip(context.Background(), nodeID, msg) 348 assert.NoError(t, err) 349 }() 350 sentGossip = true 351 return nil 352 }, 353 } 354 355 gossipHandler := &testGossipHandler{} 356 clientNetwork = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 357 clientNetwork.SetGossipHandler(gossipHandler) 358 359 assert.NoError(t, clientNetwork.Connected(context.Background(), nodeID, defaultPeerVersion)) 360 361 client := NewNetworkClient(clientNetwork) 362 defer clientNetwork.Shutdown() 363 364 b, err := buildGossip(codecManager, HelloGossip{Msg: "hello there!"}) 365 assert.NoError(t, err) 366 367 err = client.Gossip(b) 368 assert.NoError(t, err) 369 370 wg.Wait() 371 assert.True(t, sentGossip) 372 assert.True(t, gossipHandler.received) 373 } 374 375 func TestHandleInvalidMessages(t *testing.T) { 376 codecManager := buildCodec(t, HelloGossip{}, TestMessage{}) 377 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 378 379 nodeID := ids.GenerateTestNodeID() 380 requestID := uint32(1) 381 sender := testAppSender{} 382 383 clientNetwork := NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 384 clientNetwork.SetGossipHandler(message.NoopMempoolGossipHandler{}) 385 clientNetwork.SetRequestHandler(&testRequestHandler{}) 386 387 assert.NoError(t, clientNetwork.Connected(context.Background(), nodeID, defaultPeerVersion)) 388 389 defer clientNetwork.Shutdown() 390 391 // Ensure a valid gossip message sent as any App specific message type does not trigger a fatal error 392 gossipMsg, err := buildGossip(codecManager, HelloGossip{Msg: "hello there!"}) 393 assert.NoError(t, err) 394 395 // Ensure a valid request message sent as any App specific message type does not trigger a fatal error 396 requestMessage, err := marshalStruct(codecManager, TestMessage{Message: "Hello"}) 397 assert.NoError(t, err) 398 399 // Ensure a random message sent as any App specific message type does not trigger a fatal error 400 garbageResponse := make([]byte, 10) 401 // Ensure a zero-length message sent as any App specific message type does not trigger a fatal error 402 emptyResponse := make([]byte, 0) 403 // Ensure a nil byte slice sent as any App specific message type does not trigger a fatal error 404 var nilResponse []byte 405 406 // Check for edge cases 407 assert.NoError(t, clientNetwork.AppGossip(context.Background(), nodeID, gossipMsg)) 408 assert.NoError(t, clientNetwork.AppGossip(context.Background(), nodeID, requestMessage)) 409 assert.NoError(t, clientNetwork.AppGossip(context.Background(), nodeID, garbageResponse)) 410 assert.NoError(t, clientNetwork.AppGossip(context.Background(), nodeID, emptyResponse)) 411 assert.NoError(t, clientNetwork.AppGossip(context.Background(), nodeID, nilResponse)) 412 assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), gossipMsg)) 413 assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), requestMessage)) 414 assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), garbageResponse)) 415 assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), emptyResponse)) 416 assert.NoError(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), nilResponse)) 417 assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, gossipMsg)) 418 assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, requestMessage)) 419 assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, garbageResponse)) 420 assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, emptyResponse)) 421 assert.NoError(t, clientNetwork.AppResponse(context.Background(), nodeID, requestID, nilResponse)) 422 assert.NoError(t, clientNetwork.AppRequestFailed(context.Background(), nodeID, requestID)) 423 } 424 425 func TestNetworkPropagatesRequestHandlerError(t *testing.T) { 426 codecManager := buildCodec(t, TestMessage{}) 427 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 428 429 nodeID := ids.GenerateTestNodeID() 430 requestID := uint32(1) 431 sender := testAppSender{} 432 433 clientNetwork := NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 434 clientNetwork.SetGossipHandler(message.NoopMempoolGossipHandler{}) 435 clientNetwork.SetRequestHandler(&testRequestHandler{err: errors.New("fail")}) // Return an error from the request handler 436 437 assert.NoError(t, clientNetwork.Connected(context.Background(), nodeID, defaultPeerVersion)) 438 439 defer clientNetwork.Shutdown() 440 441 // Ensure a valid request message sent as any App specific message type does not trigger a fatal error 442 requestMessage, err := marshalStruct(codecManager, TestMessage{Message: "Hello"}) 443 assert.NoError(t, err) 444 445 // Check that if the request handler returns an error, it is propagated as a fatal error. 446 assert.Error(t, clientNetwork.AppRequest(context.Background(), nodeID, requestID, time.Now().Add(time.Second), requestMessage)) 447 } 448 449 func TestCrossChainAppRequest(t *testing.T) { 450 var net Network 451 codecManager := buildCodec(t, TestMessage{}) 452 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 453 454 sender := testAppSender{ 455 sendCrossChainAppRequestFn: func(requestingChainID ids.ID, requestID uint32, requestBytes []byte) error { 456 go func() { 457 if err := net.CrossChainAppRequest(context.Background(), requestingChainID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { 458 panic(err) 459 } 460 }() 461 return nil 462 }, 463 sendCrossChainAppResponseFn: func(respondingChainID ids.ID, requestID uint32, responseBytes []byte) error { 464 go func() { 465 if err := net.CrossChainAppResponse(context.Background(), respondingChainID, requestID, responseBytes); err != nil { 466 panic(err) 467 } 468 }() 469 return nil 470 }, 471 } 472 473 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 474 net.SetCrossChainRequestHandler(&testCrossChainHandler{codec: crossChainCodecManager}) 475 client := NewNetworkClient(net) 476 477 exampleCrossChainRequest := ExampleCrossChainRequest{ 478 Message: "hello this is an example request", 479 } 480 481 crossChainRequest, err := buildCrossChainRequest(crossChainCodecManager, exampleCrossChainRequest) 482 assert.NoError(t, err) 483 484 chainID := ids.ID(ethcommon.BytesToHash([]byte{1, 2, 3, 4, 5})) 485 responseBytes, err := client.SendCrossChainRequest(chainID, crossChainRequest) 486 assert.NoError(t, err) 487 488 var response ExampleCrossChainResponse 489 if _, err = crossChainCodecManager.Unmarshal(responseBytes, &response); err != nil { 490 t.Fatal("unexpected error during unmarshal", err) 491 } 492 assert.Equal(t, "this is an example response", response.Response) 493 } 494 495 func TestCrossChainRequestRequestsRoutingAndResponse(t *testing.T) { 496 var ( 497 callNum uint32 498 senderWg sync.WaitGroup 499 net Network 500 ) 501 502 sender := testAppSender{ 503 sendCrossChainAppRequestFn: func(requestingChainID ids.ID, requestID uint32, requestBytes []byte) error { 504 senderWg.Add(1) 505 go func() { 506 defer senderWg.Done() 507 if err := net.CrossChainAppRequest(context.Background(), requestingChainID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { 508 panic(err) 509 } 510 }() 511 return nil 512 }, 513 sendCrossChainAppResponseFn: func(respondingChainID ids.ID, requestID uint32, responseBytes []byte) error { 514 senderWg.Add(1) 515 go func() { 516 defer senderWg.Done() 517 if err := net.CrossChainAppResponse(context.Background(), respondingChainID, requestID, responseBytes); err != nil { 518 panic(err) 519 } 520 atomic.AddUint32(&callNum, 1) 521 }() 522 return nil 523 }, 524 } 525 526 codecManager := buildCodec(t, TestMessage{}) 527 crossChainCodecManager := buildCodec(t, ExampleCrossChainRequest{}, ExampleCrossChainResponse{}) 528 net = NewNetwork(sender, codecManager, crossChainCodecManager, ids.EmptyNodeID, 1, 1) 529 net.SetCrossChainRequestHandler(&testCrossChainHandler{codec: crossChainCodecManager}) 530 client := NewNetworkClient(net) 531 532 exampleCrossChainRequest := ExampleCrossChainRequest{ 533 Message: "hello this is an example request", 534 } 535 536 chainID := ids.ID(ethcommon.BytesToHash([]byte{1, 2, 3, 4, 5})) 537 defer net.Shutdown() 538 539 totalRequests := 500 540 numCallsPerRequest := 1 // on sending response 541 totalCalls := totalRequests * numCallsPerRequest 542 543 var requestWg sync.WaitGroup 544 requestWg.Add(totalCalls) 545 546 for i := 0; i < totalCalls; i++ { 547 go func() { 548 defer requestWg.Done() 549 crossChainRequest, err := buildCrossChainRequest(crossChainCodecManager, exampleCrossChainRequest) 550 assert.NoError(t, err) 551 responseBytes, err := client.SendCrossChainRequest(chainID, crossChainRequest) 552 assert.NoError(t, err) 553 assert.NotNil(t, responseBytes) 554 555 var response ExampleCrossChainResponse 556 if _, err = crossChainCodecManager.Unmarshal(responseBytes, &response); err != nil { 557 panic(fmt.Errorf("unexpected error during unmarshal: %w", err)) 558 } 559 assert.Equal(t, "this is an example response", response.Response) 560 }() 561 } 562 563 requestWg.Wait() 564 senderWg.Wait() 565 assert.Equal(t, totalCalls, int(atomic.LoadUint32(&callNum))) 566 } 567 568 func buildCodec(t *testing.T, types ...interface{}) codec.Manager { 569 codecManager := codec.NewDefaultManager() 570 c := linearcodec.NewDefault() 571 for _, typ := range types { 572 assert.NoError(t, c.RegisterType(typ)) 573 } 574 assert.NoError(t, codecManager.RegisterCodec(message.Version, c)) 575 return codecManager 576 } 577 578 // marshalStruct is a helper method used to marshal an object as `interface{}` 579 // so that the codec is able to include the TypeID in the resulting bytes 580 func marshalStruct(codec codec.Manager, obj interface{}) ([]byte, error) { 581 return codec.Marshal(message.Version, &obj) 582 } 583 584 func buildGossip(codec codec.Manager, msg message.GossipMessage) ([]byte, error) { 585 return codec.Marshal(message.Version, &msg) 586 } 587 588 func buildCrossChainRequest(codec codec.Manager, msg message.CrossChainRequest) ([]byte, error) { 589 return codec.Marshal(message.Version, &msg) 590 } 591 592 type testAppSender struct { 593 sendCrossChainAppRequestFn func(ids.ID, uint32, []byte) error 594 sendCrossChainAppResponseFn func(ids.ID, uint32, []byte) error 595 sendAppRequestFn func(set.Set[ids.NodeID], uint32, []byte) error 596 sendAppResponseFn func(ids.NodeID, uint32, []byte) error 597 sendAppGossipFn func([]byte) error 598 } 599 600 func (t testAppSender) SendCrossChainAppRequest(_ context.Context, chainID ids.ID, requestID uint32, appRequestBytes []byte) error { 601 return t.sendCrossChainAppRequestFn(chainID, requestID, appRequestBytes) 602 } 603 604 func (t testAppSender) SendCrossChainAppResponse(_ context.Context, chainID ids.ID, requestID uint32, appResponseBytes []byte) error { 605 return t.sendCrossChainAppResponseFn(chainID, requestID, appResponseBytes) 606 } 607 608 func (t testAppSender) SendAppGossipSpecific(context.Context, set.Set[ids.NodeID], []byte) error { 609 panic("not implemented") 610 } 611 612 func (t testAppSender) SendAppRequest(_ context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, message []byte) error { 613 return t.sendAppRequestFn(nodeIDs, requestID, message) 614 } 615 616 func (t testAppSender) SendAppResponse(_ context.Context, nodeID ids.NodeID, requestID uint32, message []byte) error { 617 return t.sendAppResponseFn(nodeID, requestID, message) 618 } 619 620 func (t testAppSender) SendAppGossip(_ context.Context, message []byte) error { 621 return t.sendAppGossipFn(message) 622 } 623 624 type HelloRequest struct { 625 Message string `serialize:"true"` 626 } 627 628 func (h HelloRequest) Handle(ctx context.Context, nodeID ids.NodeID, requestID uint32, handler message.RequestHandler) ([]byte, error) { 629 // casting is only necessary for test since RequestHandler does not implement anything at the moment 630 return handler.(TestRequestHandler).HandleHelloRequest(ctx, nodeID, requestID, &h) 631 } 632 633 func (h HelloRequest) String() string { 634 return fmt.Sprintf("HelloRequest(%s)", h.Message) 635 } 636 637 type GreetingRequest struct { 638 Greeting string `serialize:"true"` 639 } 640 641 func (g GreetingRequest) Handle(ctx context.Context, nodeID ids.NodeID, requestID uint32, handler message.RequestHandler) ([]byte, error) { 642 // casting is only necessary for test since RequestHandler does not implement anything at the moment 643 return handler.(TestRequestHandler).HandleGreetingRequest(ctx, nodeID, requestID, &g) 644 } 645 646 func (g GreetingRequest) String() string { 647 return fmt.Sprintf("GreetingRequest(%s)", g.Greeting) 648 } 649 650 type HelloResponse struct { 651 Response string `serialize:"true"` 652 } 653 654 type GreetingResponse struct { 655 Greet string `serialize:"true"` 656 } 657 658 type TestRequestHandler interface { 659 HandleHelloRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, request *HelloRequest) ([]byte, error) 660 HandleGreetingRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, request *GreetingRequest) ([]byte, error) 661 } 662 663 type HelloGreetingRequestHandler struct { 664 message.RequestHandler 665 codec codec.Manager 666 } 667 668 func (h *HelloGreetingRequestHandler) HandleHelloRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, request *HelloRequest) ([]byte, error) { 669 return h.codec.Marshal(message.Version, HelloResponse{Response: "Hi"}) 670 } 671 672 func (h *HelloGreetingRequestHandler) HandleGreetingRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, request *GreetingRequest) ([]byte, error) { 673 return h.codec.Marshal(message.Version, GreetingResponse{Greet: "Hey there"}) 674 } 675 676 type TestMessage struct { 677 Message string `serialize:"true"` 678 } 679 680 func (t TestMessage) Handle(ctx context.Context, nodeID ids.NodeID, requestID uint32, handler message.RequestHandler) ([]byte, error) { 681 return handler.(*testRequestHandler).handleTestRequest(ctx, nodeID, requestID, &t) 682 } 683 684 func (t TestMessage) String() string { 685 return fmt.Sprintf("TestMessage(%s)", t.Message) 686 } 687 688 type HelloGossip struct { 689 Msg string `serialize:"true"` 690 } 691 692 func (h HelloGossip) Handle(handler message.GossipHandler, nodeID ids.NodeID) error { 693 return handler.HandleTxs(nodeID, message.TxsGossip{}) 694 } 695 696 func (h HelloGossip) String() string { 697 return fmt.Sprintf("HelloGossip(%s)", h.Msg) 698 } 699 700 func (h HelloGossip) initialize(_ []byte) { 701 // no op 702 } 703 704 func (h HelloGossip) Bytes() []byte { 705 // no op 706 return nil 707 } 708 709 type testGossipHandler struct { 710 received bool 711 nodeID ids.NodeID 712 msg []byte 713 } 714 715 func (t *testGossipHandler) HandleTxs(nodeID ids.NodeID, msg message.TxsGossip) error { 716 t.received = true 717 t.nodeID = nodeID 718 return nil 719 } 720 721 type testRequestHandler struct { 722 message.RequestHandler 723 calls uint32 724 processingDuration time.Duration 725 response []byte 726 err error 727 } 728 729 func (r *testRequestHandler) handleTestRequest(ctx context.Context, _ ids.NodeID, _ uint32, _ *TestMessage) ([]byte, error) { 730 r.calls++ 731 select { 732 case <-time.After(r.processingDuration): 733 break 734 case <-ctx.Done(): 735 return nil, ctx.Err() 736 } 737 return r.response, r.err 738 } 739 740 type ExampleCrossChainRequest struct { 741 Message string `serialize:"true"` 742 } 743 744 func (e ExampleCrossChainRequest) Handle(ctx context.Context, requestingChainID ids.ID, requestID uint32, handler message.CrossChainRequestHandler) ([]byte, error) { 745 return handler.(*testCrossChainHandler).HandleCrossChainRequest(ctx, requestingChainID, requestID, e) 746 } 747 748 func (e ExampleCrossChainRequest) String() string { 749 return fmt.Sprintf("TestMessage(%s)", e.Message) 750 } 751 752 type ExampleCrossChainResponse struct { 753 Response string `serialize:"true"` 754 } 755 756 type TestCrossChainRequestHandler interface { 757 HandleCrossChainRequest(ctx context.Context, requestingchainID ids.ID, requestID uint32, exampleRequest message.CrossChainRequest) ([]byte, error) 758 } 759 760 type testCrossChainHandler struct { 761 message.CrossChainRequestHandler 762 codec codec.Manager 763 } 764 765 func (t *testCrossChainHandler) HandleCrossChainRequest(ctx context.Context, requestingChainID ids.ID, requestID uint32, exampleRequest message.CrossChainRequest) ([]byte, error) { 766 return t.codec.Marshal(message.Version, ExampleCrossChainResponse{Response: "this is an example response"}) 767 }