github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/gossip/comm/comm_test.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/hmac" 13 "crypto/sha256" 14 "crypto/tls" 15 "errors" 16 "fmt" 17 "io" 18 "math/rand" 19 "net" 20 "strconv" 21 "sync" 22 "sync/atomic" 23 "testing" 24 "time" 25 26 "github.com/hechain20/hechain/bccsp/factory" 27 "github.com/hechain20/hechain/common/flogging" 28 "github.com/hechain20/hechain/common/metrics/disabled" 29 "github.com/hechain20/hechain/gossip/api" 30 "github.com/hechain20/hechain/gossip/api/mocks" 31 gmocks "github.com/hechain20/hechain/gossip/comm/mocks" 32 "github.com/hechain20/hechain/gossip/common" 33 "github.com/hechain20/hechain/gossip/identity" 34 "github.com/hechain20/hechain/gossip/metrics" 35 "github.com/hechain20/hechain/gossip/protoext" 36 "github.com/hechain20/hechain/gossip/util" 37 "github.com/hechain20/hechain/internal/pkg/comm" 38 cb "github.com/hyperledger/fabric-protos-go/common" 39 proto "github.com/hyperledger/fabric-protos-go/gossip" 40 "github.com/stretchr/testify/mock" 41 "github.com/stretchr/testify/require" 42 "google.golang.org/grpc" 43 "google.golang.org/grpc/credentials" 44 ) 45 46 func init() { 47 util.SetupTestLogging() 48 rand.Seed(time.Now().UnixNano()) 49 factory.InitFactories(nil) 50 naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(api.OrgIdentityType{}) 51 } 52 53 var testCommConfig = CommConfig{ 54 DialTimeout: 300 * time.Millisecond, 55 ConnTimeout: DefConnTimeout, 56 RecvBuffSize: DefRecvBuffSize, 57 SendBuffSize: DefSendBuffSize, 58 } 59 60 func acceptAll(msg interface{}) bool { 61 return true 62 } 63 64 var noopPurgeIdentity = func(_ common.PKIidType, _ api.PeerIdentityType) { 65 } 66 67 var ( 68 naiveSec = &naiveSecProvider{} 69 hmacKey = []byte{0, 0, 0} 70 disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics 71 ) 72 73 type naiveSecProvider struct { 74 mocks.SecurityAdvisor 75 } 76 77 func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType { 78 return nsp.SecurityAdvisor.Called(identity).Get(0).(api.OrgIdentityType) 79 } 80 81 func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) { 82 return time.Now().Add(time.Hour), nil 83 } 84 85 func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error { 86 return nil 87 } 88 89 // GetPKIidOfCert returns the PKI-ID of a peer's identity 90 func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType { 91 return common.PKIidType(peerIdentity) 92 } 93 94 // VerifyBlock returns nil if the block is properly signed, 95 // else returns error 96 func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error { 97 return nil 98 } 99 100 // Sign signs msg with this peer's signing key and outputs 101 // the signature if no error occurred. 102 func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) { 103 mac := hmac.New(sha256.New, hmacKey) 104 mac.Write(msg) 105 return mac.Sum(nil), nil 106 } 107 108 // Verify checks that signature is a valid signature of message under a peer's verification key. 109 // If the verification succeeded, Verify returns nil meaning no error occurred. 110 // If peerCert is nil, then the signature is verified against this peer's verification key. 111 func (*naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error { 112 mac := hmac.New(sha256.New, hmacKey) 113 mac.Write(message) 114 expected := mac.Sum(nil) 115 if !bytes.Equal(signature, expected) { 116 return fmt.Errorf("Wrong certificate:%v, %v", signature, message) 117 } 118 return nil 119 } 120 121 // VerifyByChannel verifies a peer's signature on a message in the context 122 // of a specific channel 123 func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error { 124 return nil 125 } 126 127 func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider, 128 gRPCServer *comm.GRPCServer, certs *common.TLSCertificates, 129 secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm { 130 _, portString, err := net.SplitHostPort(gRPCServer.Address()) 131 require.NoError(t, err) 132 133 endpoint := fmt.Sprintf("127.0.0.1:%s", portString) 134 id := []byte(endpoint) 135 identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec) 136 137 commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts, 138 sec, commMetrics, testCommConfig, dialOpts...) 139 require.NoError(t, err) 140 141 go func() { 142 err := gRPCServer.Start() 143 require.NoError(t, err) 144 }() 145 146 return &commGRPC{commInst.(*commImpl), gRPCServer} 147 } 148 149 type commGRPC struct { 150 *commImpl 151 gRPCServer *comm.GRPCServer 152 } 153 154 func (c *commGRPC) Stop() { 155 c.commImpl.Stop() 156 c.commImpl.idMapper.Stop() 157 c.gRPCServer.Stop() 158 } 159 160 func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider, 161 gRPCServer *comm.GRPCServer, certs *common.TLSCertificates, 162 secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm { 163 return newCommInstanceOnlyWithMetrics(t, disabledMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...) 164 } 165 166 func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) { 167 port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer() 168 comm := newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...) 169 return comm, port 170 } 171 172 type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage 173 174 type tlsType int 175 176 const ( 177 none tlsType = iota 178 oneWayTLS 179 mutualTLS 180 ) 181 182 func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage { 183 c := &commImpl{} 184 cert := GenerateCertificatesOrPanic() 185 tlsCfg := &tls.Config{ 186 InsecureSkipVerify: true, 187 } 188 if connType == mutualTLS { 189 tlsCfg.Certificates = []tls.Certificate{cert} 190 } 191 ta := credentials.NewTLS(tlsCfg) 192 secureOpts := grpc.WithTransportCredentials(ta) 193 if connType == none { 194 secureOpts = grpc.WithInsecure() 195 } 196 acceptChan := comm.Accept(acceptAll) 197 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 198 defer cancel() 199 target := fmt.Sprintf("127.0.0.1:%d", port) 200 conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock()) 201 require.NoError(t, err, "%v", err) 202 if err != nil { 203 return nil 204 } 205 cl := proto.NewGossipClient(conn) 206 stream, err := cl.GossipStream(context.Background()) 207 require.NoError(t, err, "%v", err) 208 if err != nil { 209 return nil 210 } 211 212 var clientCertHash []byte 213 if len(tlsCfg.Certificates) > 0 { 214 clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0]) 215 } 216 217 pkiID := common.PKIidType(endpoint) 218 require.NoError(t, err, "%v", err) 219 msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) { 220 mac := hmac.New(sha256.New, hmacKey) 221 mac.Write(msg) 222 return mac.Sum(nil), nil 223 }, false) 224 // Mutate connection message to test negative paths 225 msg = connMutator(msg) 226 // Send your own connection message 227 stream.Send(msg.Envelope) 228 // Wait for connection message from the other side 229 envelope, err := stream.Recv() 230 if err != nil { 231 return acceptChan 232 } 233 require.NoError(t, err, "%v", err) 234 msg, err = protoext.EnvelopeToGossipMessage(envelope) 235 require.NoError(t, err, "%v", err) 236 require.Equal(t, []byte(target), msg.GetConn().PkiId) 237 require.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash) 238 msg2Send := createGossipMsg() 239 nonce := uint64(rand.Int()) 240 msg2Send.Nonce = nonce 241 go stream.Send(msg2Send.Envelope) 242 return acceptChan 243 } 244 245 func TestMutualParallelSendWithAck(t *testing.T) { 246 // This test tests concurrent and parallel sending of many (1000) messages 247 // from 2 instances to one another at the same time. 248 249 msgNum := 1000 250 251 comm1, port1 := newCommInstance(t, naiveSec) 252 comm2, port2 := newCommInstance(t, naiveSec) 253 defer comm1.Stop() 254 defer comm2.Stop() 255 256 acceptData := func(o interface{}) bool { 257 m := o.(protoext.ReceivedMessage).GetGossipMessage() 258 return protoext.IsDataMsg(m.GossipMessage) 259 } 260 261 inc1 := comm1.Accept(acceptData) 262 inc2 := comm2.Accept(acceptData) 263 264 // Send a message from comm1 to comm2, to make the instances establish a preliminary connection 265 comm1.Send(createGossipMsg(), remotePeer(port2)) 266 // Wait for the message to be received in comm2 267 <-inc2 268 269 for i := 0; i < msgNum; i++ { 270 go comm1.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port2)) 271 } 272 273 for i := 0; i < msgNum; i++ { 274 go comm2.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port1)) 275 } 276 277 go func() { 278 for i := 0; i < msgNum; i++ { 279 <-inc1 280 } 281 }() 282 283 for i := 0; i < msgNum; i++ { 284 <-inc2 285 } 286 } 287 288 func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) { 289 ll, err := net.Listen("tcp", "127.0.0.1:0") 290 require.NoError(t, err) 291 endpoint = ll.Addr().String() 292 _, portS, err := net.SplitHostPort(endpoint) 293 require.NoError(t, err) 294 portInt, err := strconv.Atoi(portS) 295 require.NoError(t, err) 296 return portInt, endpoint, ll 297 } 298 299 func TestHandshake(t *testing.T) { 300 signer := func(msg []byte) ([]byte, error) { 301 mac := hmac.New(sha256.New, hmacKey) 302 mac.Write(msg) 303 return mac.Sum(nil), nil 304 } 305 mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 306 return msg 307 } 308 assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) { 309 expectedPKIID := common.PKIidType(endpoint) 310 require.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID) 311 require.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity) 312 require.NotNil(t, msg.GetConnectionInfo().Auth) 313 sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData) 314 require.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature) 315 } 316 317 // Positive path 1 - check authentication without TLS 318 port, endpoint, ll := getAvailablePort(t) 319 s := grpc.NewServer() 320 id := []byte(endpoint) 321 idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec) 322 inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption { 323 return []grpc.DialOption{grpc.WithInsecure()} 324 }, naiveSec, disabledMetrics, testCommConfig) 325 go s.Serve(ll) 326 require.NoError(t, err) 327 var msg protoext.ReceivedMessage 328 329 _, tempEndpoint, tempL := getAvailablePort(t) 330 acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none) 331 select { 332 case <-time.After(time.Duration(time.Second * 4)): 333 require.FailNow(t, "Didn't receive a message, seems like handshake failed") 334 case msg = <-acceptChan: 335 } 336 require.Equal(t, common.PKIidType(tempEndpoint), msg.GetConnectionInfo().ID) 337 require.Equal(t, api.PeerIdentityType(tempEndpoint), msg.GetConnectionInfo().Identity) 338 sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData) 339 require.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature) 340 341 inst.Stop() 342 s.Stop() 343 ll.Close() 344 tempL.Close() 345 time.Sleep(time.Second) 346 347 comm, port := newCommInstance(t, naiveSec) 348 defer comm.Stop() 349 // Positive path 2: initiating peer sends its own certificate 350 _, tempEndpoint, tempL = getAvailablePort(t) 351 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 352 353 select { 354 case <-time.After(time.Second * 2): 355 require.FailNow(t, "Didn't receive a message, seems like handshake failed") 356 case msg = <-acceptChan: 357 } 358 assertPositivePath(msg, tempEndpoint) 359 tempL.Close() 360 361 // Negative path: initiating peer doesn't send its own certificate 362 _, tempEndpoint, tempL = getAvailablePort(t) 363 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, oneWayTLS) 364 time.Sleep(time.Second) 365 require.Equal(t, 0, len(acceptChan)) 366 tempL.Close() 367 368 // Negative path, signature is wrong 369 _, tempEndpoint, tempL = getAvailablePort(t) 370 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 371 msg.Signature = append(msg.Signature, 0) 372 return msg 373 } 374 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 375 time.Sleep(time.Second) 376 require.Equal(t, 0, len(acceptChan)) 377 tempL.Close() 378 379 // Negative path, the PKIid doesn't match the identity 380 _, tempEndpoint, tempL = getAvailablePort(t) 381 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 382 msg.GetConn().PkiId = []byte(tempEndpoint) 383 // Sign the message again 384 msg.Sign(signer) 385 return msg 386 } 387 _, tempEndpoint2, tempL2 := getAvailablePort(t) 388 acceptChan = handshaker(port, tempEndpoint2, comm, t, mutator, mutualTLS) 389 time.Sleep(time.Second) 390 require.Equal(t, 0, len(acceptChan)) 391 tempL.Close() 392 tempL2.Close() 393 394 // Negative path, the cert hash isn't what is expected 395 _, tempEndpoint, tempL = getAvailablePort(t) 396 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 397 msg.GetConn().TlsCertHash = append(msg.GetConn().TlsCertHash, 0) 398 msg.Sign(signer) 399 return msg 400 } 401 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 402 time.Sleep(time.Second) 403 require.Equal(t, 0, len(acceptChan)) 404 tempL.Close() 405 406 // Negative path, no PKI-ID was sent 407 _, tempEndpoint, tempL = getAvailablePort(t) 408 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 409 msg.GetConn().PkiId = nil 410 msg.Sign(signer) 411 return msg 412 } 413 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 414 time.Sleep(time.Second) 415 require.Equal(t, 0, len(acceptChan)) 416 tempL.Close() 417 418 // Negative path, connection message is of a different type 419 _, tempEndpoint, tempL = getAvailablePort(t) 420 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 421 msg.Content = &proto.GossipMessage_Empty{ 422 Empty: &proto.Empty{}, 423 } 424 msg.Sign(signer) 425 return msg 426 } 427 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 428 time.Sleep(time.Second) 429 require.Equal(t, 0, len(acceptChan)) 430 tempL.Close() 431 432 // Negative path, the peer didn't respond to the handshake in due time 433 _, tempEndpoint, tempL = getAvailablePort(t) 434 mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage { 435 time.Sleep(time.Second * 5) 436 return msg 437 } 438 acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS) 439 time.Sleep(time.Second) 440 require.Equal(t, 0, len(acceptChan)) 441 tempL.Close() 442 } 443 444 func TestConnectUnexpectedPeer(t *testing.T) { 445 // Scenarios: In both scenarios, comm1 connects to comm2 or comm3. 446 // and expects to see a PKI-ID which is equal to comm4's PKI-ID. 447 // The connection attempt would succeed or fail based on whether comm2 or comm3 448 // are in the same org as comm4 449 450 identityByPort := func(port int) api.PeerIdentityType { 451 return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port)) 452 } 453 454 customNaiveSec := &naiveSecProvider{} 455 456 comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer() 457 comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer() 458 comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer() 459 comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer() 460 461 customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm1Port)).Return(api.OrgIdentityType("O")) 462 customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm2Port)).Return(api.OrgIdentityType("A")) 463 customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm3Port)).Return(api.OrgIdentityType("B")) 464 customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm4Port)).Return(api.OrgIdentityType("A")) 465 466 comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...) 467 comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...) 468 comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...) 469 comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...) 470 471 defer comm1.Stop() 472 defer comm2.Stop() 473 defer comm3.Stop() 474 defer comm4.Stop() 475 476 messagesForComm1 := comm1.Accept(acceptAll) 477 messagesForComm2 := comm2.Accept(acceptAll) 478 messagesForComm3 := comm3.Accept(acceptAll) 479 480 // Have comm4 send a message to comm1 481 // in order for comm1 to know comm4 482 comm4.Send(createGossipMsg(), remotePeer(comm1Port)) 483 <-messagesForComm1 484 // Close the connection with comm4 485 comm1.CloseConn(remotePeer(comm4Port)) 486 // At this point, comm1 knows comm4's identity and organization 487 488 t.Run("Same organization", func(t *testing.T) { 489 unexpectedRemotePeer := remotePeer(comm2Port) 490 unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID 491 comm1.Send(createGossipMsg(), unexpectedRemotePeer) 492 select { 493 case <-messagesForComm2: 494 case <-time.After(time.Second * 5): 495 require.Fail(t, "Didn't receive a message within a timely manner") 496 util.PrintStackTrace() 497 } 498 }) 499 500 t.Run("Unexpected organization", func(t *testing.T) { 501 unexpectedRemotePeer := remotePeer(comm3Port) 502 unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID 503 comm1.Send(createGossipMsg(), unexpectedRemotePeer) 504 select { 505 case <-messagesForComm3: 506 require.Fail(t, "Message shouldn't have been received") 507 case <-time.After(time.Second * 5): 508 } 509 }) 510 } 511 512 func TestGetConnectionInfo(t *testing.T) { 513 comm1, port1 := newCommInstance(t, naiveSec) 514 comm2, _ := newCommInstance(t, naiveSec) 515 defer comm1.Stop() 516 defer comm2.Stop() 517 m1 := comm1.Accept(acceptAll) 518 comm2.Send(createGossipMsg(), remotePeer(port1)) 519 select { 520 case <-time.After(time.Second * 10): 521 t.Fatal("Didn't receive a message in time") 522 case msg := <-m1: 523 require.Equal(t, comm2.GetPKIid(), msg.GetConnectionInfo().ID) 524 require.NotNil(t, msg.GetSourceEnvelope()) 525 } 526 } 527 528 func TestCloseConn(t *testing.T) { 529 comm1, port1 := newCommInstance(t, naiveSec) 530 defer comm1.Stop() 531 acceptChan := comm1.Accept(acceptAll) 532 533 cert := GenerateCertificatesOrPanic() 534 tlsCfg := &tls.Config{ 535 InsecureSkipVerify: true, 536 Certificates: []tls.Certificate{cert}, 537 } 538 ta := credentials.NewTLS(tlsCfg) 539 540 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 541 defer cancel() 542 target := fmt.Sprintf("127.0.0.1:%d", port1) 543 conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock()) 544 require.NoError(t, err, "%v", err) 545 cl := proto.NewGossipClient(conn) 546 stream, err := cl.GossipStream(context.Background()) 547 require.NoError(t, err, "%v", err) 548 c := &commImpl{} 549 tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0]) 550 connMsg, _ := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) { 551 mac := hmac.New(sha256.New, hmacKey) 552 mac.Write(msg) 553 return mac.Sum(nil), nil 554 }, false) 555 require.NoError(t, stream.Send(connMsg.Envelope)) 556 stream.Send(createGossipMsg().Envelope) 557 select { 558 case <-acceptChan: 559 case <-time.After(time.Second): 560 require.Fail(t, "Didn't receive a message within a timely period") 561 } 562 comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")}) 563 time.Sleep(time.Second * 10) 564 gotErr := false 565 msg2Send := createGossipMsg() 566 msg2Send.GetDataMsg().Payload = &proto.Payload{ 567 Data: make([]byte, 1024*1024), 568 } 569 protoext.NoopSign(msg2Send.GossipMessage) 570 for i := 0; i < DefRecvBuffSize; i++ { 571 err := stream.Send(msg2Send.Envelope) 572 if err != nil { 573 gotErr = true 574 break 575 } 576 } 577 require.True(t, gotErr, "Should have failed because connection is closed") 578 } 579 580 // TestCommSend makes sure that enough messages get through 581 // eventually. Comm.Send() is both asynchronous and best-effort, so this test 582 // case assumes some will fail, but that eventually enough messages will get 583 // through that the test will end. 584 func TestCommSend(t *testing.T) { 585 sendMessages := func(c Comm, peer *RemotePeer, stopChan <-chan struct{}) { 586 ticker := time.NewTicker(time.Millisecond) 587 defer ticker.Stop() 588 for { 589 emptyMsg := createGossipMsg() 590 select { 591 case <-stopChan: 592 return 593 case <-ticker.C: 594 c.Send(emptyMsg, peer) 595 } 596 } 597 } 598 599 comm1, port1 := newCommInstance(t, naiveSec) 600 comm2, port2 := newCommInstance(t, naiveSec) 601 defer comm1.Stop() 602 defer comm2.Stop() 603 604 // Create the receive channel before sending the messages 605 ch1 := comm1.Accept(acceptAll) 606 ch2 := comm2.Accept(acceptAll) 607 608 // control channels for background senders 609 stopch1 := make(chan struct{}) 610 stopch2 := make(chan struct{}) 611 612 go sendMessages(comm1, remotePeer(port2), stopch1) 613 go sendMessages(comm2, remotePeer(port1), stopch2) 614 615 c1received := 0 616 c2received := 0 617 // hopefully in some runs we'll fill both send and receive buffers and 618 // drop overflowing messages, but still finish, because the endless 619 // stream of messages inexorably gets through unless something is very 620 // broken. 621 totalMessagesReceived := (DefSendBuffSize + DefRecvBuffSize) * 2 622 timer := time.NewTimer(30 * time.Second) 623 defer timer.Stop() 624 RECV: 625 for { 626 select { 627 case <-ch1: 628 c1received++ 629 if c1received == totalMessagesReceived { 630 close(stopch2) 631 } 632 case <-ch2: 633 c2received++ 634 if c2received == totalMessagesReceived { 635 close(stopch1) 636 } 637 case <-timer.C: 638 t.Fatalf("timed out waiting for messages to be received.\nc1 got %d messages\nc2 got %d messages", c1received, c2received) 639 default: 640 if c1received >= totalMessagesReceived && c2received >= totalMessagesReceived { 641 break RECV 642 } 643 } 644 } 645 t.Logf("c1 got %d messages\nc2 got %d messages", c1received, c2received) 646 } 647 648 type nonResponsivePeer struct { 649 *grpc.Server 650 port int 651 } 652 653 func newNonResponsivePeer(t *testing.T) *nonResponsivePeer { 654 port, gRPCServer, _, _, _ := util.CreateGRPCLayer() 655 nrp := &nonResponsivePeer{ 656 Server: gRPCServer.Server(), 657 port: port, 658 } 659 proto.RegisterGossipServer(gRPCServer.Server(), nrp) 660 return nrp 661 } 662 663 func (bp *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) { 664 time.Sleep(time.Second * 15) 665 return &proto.Empty{}, nil 666 } 667 668 func (bp *nonResponsivePeer) GossipStream(stream proto.Gossip_GossipStreamServer) error { 669 return nil 670 } 671 672 func (bp *nonResponsivePeer) stop() { 673 bp.Server.Stop() 674 } 675 676 func TestNonResponsivePing(t *testing.T) { 677 c, _ := newCommInstance(t, naiveSec) 678 defer c.Stop() 679 nonRespPeer := newNonResponsivePeer(t) 680 defer nonRespPeer.stop() 681 s := make(chan struct{}) 682 go func() { 683 c.Probe(remotePeer(nonRespPeer.port)) 684 s <- struct{}{} 685 }() 686 select { 687 case <-time.After(time.Second * 10): 688 require.Fail(t, "Request wasn't cancelled on time") 689 case <-s: 690 } 691 } 692 693 func TestResponses(t *testing.T) { 694 comm1, port1 := newCommInstance(t, naiveSec) 695 comm2, _ := newCommInstance(t, naiveSec) 696 697 defer comm1.Stop() 698 defer comm2.Stop() 699 700 wg := sync.WaitGroup{} 701 702 msg := createGossipMsg() 703 wg.Add(1) 704 go func() { 705 inChan := comm1.Accept(acceptAll) 706 wg.Done() 707 for m := range inChan { 708 reply := createGossipMsg() 709 reply.Nonce = m.GetGossipMessage().Nonce + 1 710 m.Respond(reply.GossipMessage) 711 } 712 }() 713 expectedNOnce := uint64(msg.Nonce + 1) 714 responsesFromComm1 := comm2.Accept(acceptAll) 715 716 ticker := time.NewTicker(10 * time.Second) 717 wg.Wait() 718 comm2.Send(msg, remotePeer(port1)) 719 720 select { 721 case <-ticker.C: 722 require.Fail(t, "Haven't got response from comm1 within a timely manner") 723 break 724 case resp := <-responsesFromComm1: 725 ticker.Stop() 726 require.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce) 727 break 728 } 729 } 730 731 // TestAccept makes sure that accept filters work. The probability of the parity 732 // of all nonces being 0 or 1 is very low. 733 func TestAccept(t *testing.T) { 734 comm1, port1 := newCommInstance(t, naiveSec) 735 comm2, _ := newCommInstance(t, naiveSec) 736 737 evenNONCESelector := func(m interface{}) bool { 738 return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 == 0 739 } 740 741 oddNONCESelector := func(m interface{}) bool { 742 return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 != 0 743 } 744 745 evenNONCES := comm1.Accept(evenNONCESelector) 746 oddNONCES := comm1.Accept(oddNONCESelector) 747 748 var evenResults []uint64 749 var oddResults []uint64 750 751 out := make(chan uint64) 752 sem := make(chan struct{}) 753 754 readIntoSlice := func(a *[]uint64, ch <-chan protoext.ReceivedMessage) { 755 for m := range ch { 756 *a = append(*a, m.GetGossipMessage().Nonce) 757 select { 758 case out <- m.GetGossipMessage().Nonce: 759 default: // avoid blocking when we stop reading from out 760 } 761 } 762 sem <- struct{}{} 763 } 764 765 go readIntoSlice(&evenResults, evenNONCES) 766 go readIntoSlice(&oddResults, oddNONCES) 767 768 stopSend := make(chan struct{}) 769 go func() { 770 for { 771 select { 772 case <-stopSend: 773 return 774 default: 775 comm2.Send(createGossipMsg(), remotePeer(port1)) 776 } 777 } 778 }() 779 780 waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent") 781 close(stopSend) 782 783 comm1.Stop() 784 comm2.Stop() 785 786 <-sem 787 <-sem 788 789 t.Logf("%d even nonces received", len(evenResults)) 790 t.Logf("%d odd nonces received", len(oddResults)) 791 792 require.NotEmpty(t, evenResults) 793 require.NotEmpty(t, oddResults) 794 795 remainderPredicate := func(a []uint64, rem uint64) { 796 for _, n := range a { 797 require.Equal(t, n%2, rem) 798 } 799 } 800 801 remainderPredicate(evenResults, 0) 802 remainderPredicate(oddResults, 1) 803 } 804 805 func TestReConnections(t *testing.T) { 806 comm1, port1 := newCommInstance(t, naiveSec) 807 comm2, port2 := newCommInstance(t, naiveSec) 808 809 reader := func(out chan uint64, in <-chan protoext.ReceivedMessage) { 810 for { 811 msg := <-in 812 if msg == nil { 813 return 814 } 815 out <- msg.GetGossipMessage().Nonce 816 } 817 } 818 819 out1 := make(chan uint64, 10) 820 out2 := make(chan uint64, 10) 821 822 go reader(out1, comm1.Accept(acceptAll)) 823 go reader(out2, comm2.Accept(acceptAll)) 824 825 // comm1 connects to comm2 826 comm1.Send(createGossipMsg(), remotePeer(port2)) 827 waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner") 828 // comm2 sends to comm1 829 comm2.Send(createGossipMsg(), remotePeer(port1)) 830 waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner") 831 comm1.Stop() 832 833 comm1, port1 = newCommInstance(t, naiveSec) 834 out1 = make(chan uint64, 1) 835 go reader(out1, comm1.Accept(acceptAll)) 836 comm2.Send(createGossipMsg(), remotePeer(port1)) 837 waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner") 838 comm1.Stop() 839 comm2.Stop() 840 } 841 842 func TestProbe(t *testing.T) { 843 comm1, port1 := newCommInstance(t, naiveSec) 844 defer comm1.Stop() 845 comm2, port2 := newCommInstance(t, naiveSec) 846 time.Sleep(time.Duration(1) * time.Second) 847 require.NoError(t, comm1.Probe(remotePeer(port2))) 848 _, err := comm1.Handshake(remotePeer(port2)) 849 require.NoError(t, err) 850 tempPort, _, ll := getAvailablePort(t) 851 defer ll.Close() 852 require.Error(t, comm1.Probe(remotePeer(tempPort))) 853 _, err = comm1.Handshake(remotePeer(tempPort)) 854 require.Error(t, err) 855 comm2.Stop() 856 time.Sleep(time.Duration(1) * time.Second) 857 require.Error(t, comm1.Probe(remotePeer(port2))) 858 _, err = comm1.Handshake(remotePeer(port2)) 859 require.Error(t, err) 860 comm2, port2 = newCommInstance(t, naiveSec) 861 defer comm2.Stop() 862 time.Sleep(time.Duration(1) * time.Second) 863 require.NoError(t, comm2.Probe(remotePeer(port1))) 864 _, err = comm2.Handshake(remotePeer(port1)) 865 require.NoError(t, err) 866 require.NoError(t, comm1.Probe(remotePeer(port2))) 867 _, err = comm1.Handshake(remotePeer(port2)) 868 require.NoError(t, err) 869 // Now try a deep probe with an expected PKI-ID that doesn't match 870 wrongRemotePeer := remotePeer(port2) 871 if wrongRemotePeer.PKIID[0] == 0 { 872 wrongRemotePeer.PKIID[0] = 1 873 } else { 874 wrongRemotePeer.PKIID[0] = 0 875 } 876 _, err = comm1.Handshake(wrongRemotePeer) 877 require.Error(t, err) 878 // Try a deep probe with a nil PKI-ID 879 endpoint := fmt.Sprintf("127.0.0.1:%d", port2) 880 id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint}) 881 require.NoError(t, err) 882 require.Equal(t, api.PeerIdentityType(endpoint), id) 883 } 884 885 func TestPresumedDead(t *testing.T) { 886 comm1, _ := newCommInstance(t, naiveSec) 887 comm2, port2 := newCommInstance(t, naiveSec) 888 889 wg := sync.WaitGroup{} 890 wg.Add(1) 891 go func() { 892 wg.Wait() 893 comm1.Send(createGossipMsg(), remotePeer(port2)) 894 }() 895 896 ticker := time.NewTicker(time.Duration(10) * time.Second) 897 acceptCh := comm2.Accept(acceptAll) 898 wg.Done() 899 select { 900 case <-acceptCh: 901 ticker.Stop() 902 case <-ticker.C: 903 require.Fail(t, "Didn't get first message") 904 } 905 906 comm2.Stop() 907 go func() { 908 for i := 0; i < 5; i++ { 909 comm1.Send(createGossipMsg(), remotePeer(port2)) 910 time.Sleep(time.Millisecond * 200) 911 } 912 }() 913 914 ticker = time.NewTicker(time.Second * time.Duration(3)) 915 select { 916 case <-ticker.C: 917 require.Fail(t, "Didn't get a presumed dead message within a timely manner") 918 break 919 case <-comm1.PresumedDead(): 920 ticker.Stop() 921 break 922 } 923 } 924 925 func TestReadFromStream(t *testing.T) { 926 stream := &gmocks.MockStream{} 927 stream.On("CloseSend").Return(nil) 928 stream.On("Recv").Return(&proto.Envelope{Payload: []byte{1}}, nil).Once() 929 stream.On("Recv").Return(nil, errors.New("stream closed")).Once() 930 931 conn := newConnection(nil, nil, stream, disabledMetrics, ConnConfig{1, 1}) 932 conn.logger = flogging.MustGetLogger("test") 933 934 errChan := make(chan error, 2) 935 msgChan := make(chan *protoext.SignedGossipMessage, 1) 936 var wg sync.WaitGroup 937 wg.Add(1) 938 go func() { 939 defer wg.Done() 940 conn.readFromStream(errChan, msgChan) 941 }() 942 943 select { 944 case <-msgChan: 945 require.Fail(t, "malformed message shouldn't have been received") 946 case <-time.After(time.Millisecond * 100): 947 require.Len(t, errChan, 1) 948 } 949 950 conn.close() 951 wg.Wait() 952 } 953 954 func TestSendBadEnvelope(t *testing.T) { 955 comm1, port := newCommInstance(t, naiveSec) 956 defer comm1.Stop() 957 958 stream, err := establishSession(t, port) 959 require.NoError(t, err) 960 961 inc := comm1.Accept(acceptAll) 962 963 goodMsg := createGossipMsg() 964 err = stream.Send(goodMsg.Envelope) 965 require.NoError(t, err) 966 967 select { 968 case goodMsgReceived := <-inc: 969 require.Equal(t, goodMsg.Envelope.Payload, goodMsgReceived.GetSourceEnvelope().Payload) 970 case <-time.After(time.Minute): 971 require.Fail(t, "Didn't receive message within a timely manner") 972 return 973 } 974 975 // Next, we corrupt a message and send it until the stream is closed forcefully from the remote peer 976 start := time.Now() 977 for { 978 badMsg := createGossipMsg() 979 badMsg.Envelope.Payload = []byte{1} 980 err = stream.Send(badMsg.Envelope) 981 if err != nil { 982 require.Equal(t, io.EOF, err) 983 break 984 } 985 if time.Now().After(start.Add(time.Second * 30)) { 986 require.Fail(t, "Didn't close stream within a timely manner") 987 return 988 } 989 } 990 } 991 992 func establishSession(t *testing.T, port int) (proto.Gossip_GossipStreamClient, error) { 993 cert := GenerateCertificatesOrPanic() 994 secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ 995 InsecureSkipVerify: true, 996 Certificates: []tls.Certificate{cert}, 997 })) 998 999 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 1000 defer cancel() 1001 1002 endpoint := fmt.Sprintf("127.0.0.1:%d", port) 1003 conn, err := grpc.DialContext(ctx, endpoint, secureOpts, grpc.WithBlock()) 1004 require.NoError(t, err, "%v", err) 1005 if err != nil { 1006 return nil, err 1007 } 1008 cl := proto.NewGossipClient(conn) 1009 stream, err := cl.GossipStream(context.Background()) 1010 require.NoError(t, err, "%v", err) 1011 if err != nil { 1012 return nil, err 1013 } 1014 1015 clientCertHash := certHashFromRawCert(cert.Certificate[0]) 1016 pkiID := common.PKIidType([]byte{1, 2, 3}) 1017 c := &commImpl{} 1018 require.NoError(t, err, "%v", err) 1019 msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte{1, 2, 3}, func(msg []byte) ([]byte, error) { 1020 mac := hmac.New(sha256.New, hmacKey) 1021 mac.Write(msg) 1022 return mac.Sum(nil), nil 1023 }, false) 1024 // Send your own connection message 1025 stream.Send(msg.Envelope) 1026 // Wait for connection message from the other side 1027 envelope, err := stream.Recv() 1028 if err != nil { 1029 return nil, err 1030 } 1031 require.NotNil(t, envelope) 1032 return stream, nil 1033 } 1034 1035 func createGossipMsg() *protoext.SignedGossipMessage { 1036 msg, _ := protoext.NoopSign(&proto.GossipMessage{ 1037 Tag: proto.GossipMessage_EMPTY, 1038 Nonce: uint64(rand.Int()), 1039 Content: &proto.GossipMessage_DataMsg{ 1040 DataMsg: &proto.DataMessage{}, 1041 }, 1042 }) 1043 return msg 1044 } 1045 1046 func remotePeer(port int) *RemotePeer { 1047 endpoint := fmt.Sprintf("127.0.0.1:%d", port) 1048 return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)} 1049 } 1050 1051 func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) { 1052 c := 0 1053 waiting := true 1054 ticker := time.NewTicker(time.Duration(10) * time.Second) 1055 for waiting { 1056 select { 1057 case <-msgChan: 1058 c++ 1059 if c == count { 1060 waiting = false 1061 } 1062 case <-ticker.C: 1063 waiting = false 1064 } 1065 } 1066 require.Equal(t, count, c, errMsg) 1067 } 1068 1069 func TestConcurrentCloseSend(t *testing.T) { 1070 var stopping int32 1071 1072 comm1, _ := newCommInstance(t, naiveSec) 1073 comm2, port2 := newCommInstance(t, naiveSec) 1074 m := comm2.Accept(acceptAll) 1075 comm1.Send(createGossipMsg(), remotePeer(port2)) 1076 <-m 1077 ready := make(chan struct{}) 1078 done := make(chan struct{}) 1079 go func() { 1080 defer close(done) 1081 1082 comm1.Send(createGossipMsg(), remotePeer(port2)) 1083 close(ready) 1084 1085 for atomic.LoadInt32(&stopping) == int32(0) { 1086 comm1.Send(createGossipMsg(), remotePeer(port2)) 1087 } 1088 }() 1089 <-ready 1090 comm2.Stop() 1091 atomic.StoreInt32(&stopping, int32(1)) 1092 <-done 1093 }