github.com/kaituanwang/hyperledger@v2.0.1+incompatible/gossip/comm/comm_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha256"
    14  	"crypto/tls"
    15  	"fmt"
    16  	"math/rand"
    17  	"net"
    18  	"strconv"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	cb "github.com/hyperledger/fabric-protos-go/common"
    25  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    26  	"github.com/hyperledger/fabric/bccsp/factory"
    27  	"github.com/hyperledger/fabric/common/metrics/disabled"
    28  	"github.com/hyperledger/fabric/core/comm"
    29  	"github.com/hyperledger/fabric/gossip/api"
    30  	"github.com/hyperledger/fabric/gossip/api/mocks"
    31  	"github.com/hyperledger/fabric/gossip/common"
    32  	"github.com/hyperledger/fabric/gossip/identity"
    33  	"github.com/hyperledger/fabric/gossip/metrics"
    34  	"github.com/hyperledger/fabric/gossip/protoext"
    35  	"github.com/hyperledger/fabric/gossip/util"
    36  	"github.com/stretchr/testify/assert"
    37  	"github.com/stretchr/testify/mock"
    38  	"google.golang.org/grpc"
    39  	"google.golang.org/grpc/credentials"
    40  )
    41  
    42  func init() {
    43  	util.SetupTestLogging()
    44  	rand.Seed(time.Now().UnixNano())
    45  	factory.InitFactories(nil)
    46  	naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(api.OrgIdentityType{})
    47  }
    48  
    49  var testCommConfig = CommConfig{
    50  	DialTimeout:  300 * time.Millisecond,
    51  	ConnTimeout:  DefConnTimeout,
    52  	RecvBuffSize: DefRecvBuffSize,
    53  	SendBuffSize: DefSendBuffSize,
    54  }
    55  
    56  func acceptAll(msg interface{}) bool {
    57  	return true
    58  }
    59  
    60  var noopPurgeIdentity = func(_ common.PKIidType, _ api.PeerIdentityType) {
    61  
    62  }
    63  
    64  var (
    65  	naiveSec        = &naiveSecProvider{}
    66  	hmacKey         = []byte{0, 0, 0}
    67  	disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics
    68  )
    69  
    70  type naiveSecProvider struct {
    71  	mocks.SecurityAdvisor
    72  }
    73  
    74  func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType {
    75  	return nsp.SecurityAdvisor.Called(identity).Get(0).(api.OrgIdentityType)
    76  }
    77  
    78  func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) {
    79  	return time.Now().Add(time.Hour), nil
    80  }
    81  
    82  func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error {
    83  	return nil
    84  }
    85  
    86  // GetPKIidOfCert returns the PKI-ID of a peer's identity
    87  func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType {
    88  	return common.PKIidType(peerIdentity)
    89  }
    90  
    91  // VerifyBlock returns nil if the block is properly signed,
    92  // else returns error
    93  func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error {
    94  	return nil
    95  }
    96  
    97  // Sign signs msg with this peer's signing key and outputs
    98  // the signature if no error occurred.
    99  func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) {
   100  	mac := hmac.New(sha256.New, hmacKey)
   101  	mac.Write(msg)
   102  	return mac.Sum(nil), nil
   103  }
   104  
   105  // Verify checks that signature is a valid signature of message under a peer's verification key.
   106  // If the verification succeeded, Verify returns nil meaning no error occurred.
   107  // If peerCert is nil, then the signature is verified against this peer's verification key.
   108  func (*naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error {
   109  	mac := hmac.New(sha256.New, hmacKey)
   110  	mac.Write(message)
   111  	expected := mac.Sum(nil)
   112  	if !bytes.Equal(signature, expected) {
   113  		return fmt.Errorf("Wrong certificate:%v, %v", signature, message)
   114  	}
   115  	return nil
   116  }
   117  
   118  // VerifyByChannel verifies a peer's signature on a message in the context
   119  // of a specific channel
   120  func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error {
   121  	return nil
   122  }
   123  
   124  func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider,
   125  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   126  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   127  
   128  	_, portString, err := net.SplitHostPort(gRPCServer.Address())
   129  	assert.NoError(t, err)
   130  
   131  	endpoint := fmt.Sprintf("127.0.0.1:%s", portString)
   132  	id := []byte(endpoint)
   133  	identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec)
   134  
   135  	commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts,
   136  		sec, commMetrics, testCommConfig, dialOpts...)
   137  	assert.NoError(t, err)
   138  
   139  	go func() {
   140  		err := gRPCServer.Start()
   141  		assert.NoError(t, err)
   142  	}()
   143  
   144  	return &commGRPC{commInst.(*commImpl), gRPCServer}
   145  }
   146  
   147  type commGRPC struct {
   148  	*commImpl
   149  	gRPCServer *comm.GRPCServer
   150  }
   151  
   152  func (c *commGRPC) Stop() {
   153  	c.commImpl.Stop()
   154  	c.commImpl.idMapper.Stop()
   155  	c.gRPCServer.Stop()
   156  }
   157  
   158  func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider,
   159  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   160  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   161  	return newCommInstanceOnlyWithMetrics(t, disabledMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   162  }
   163  
   164  func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) {
   165  	port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer()
   166  	comm := newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   167  	return comm, port
   168  }
   169  
   170  type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage
   171  
   172  type tlsType int
   173  
   174  const (
   175  	none tlsType = iota
   176  	oneWayTLS
   177  	mutualTLS
   178  )
   179  
   180  func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage {
   181  	c := &commImpl{}
   182  	cert := GenerateCertificatesOrPanic()
   183  	tlsCfg := &tls.Config{
   184  		InsecureSkipVerify: true,
   185  	}
   186  	if connType == mutualTLS {
   187  		tlsCfg.Certificates = []tls.Certificate{cert}
   188  	}
   189  	ta := credentials.NewTLS(tlsCfg)
   190  	secureOpts := grpc.WithTransportCredentials(ta)
   191  	if connType == none {
   192  		secureOpts = grpc.WithInsecure()
   193  	}
   194  	acceptChan := comm.Accept(acceptAll)
   195  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   196  	defer cancel()
   197  	target := fmt.Sprintf("127.0.0.1:%d", port)
   198  	conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock())
   199  	assert.NoError(t, err, "%v", err)
   200  	if err != nil {
   201  		return nil
   202  	}
   203  	cl := proto.NewGossipClient(conn)
   204  	stream, err := cl.GossipStream(context.Background())
   205  	assert.NoError(t, err, "%v", err)
   206  	if err != nil {
   207  		return nil
   208  	}
   209  
   210  	var clientCertHash []byte
   211  	if len(tlsCfg.Certificates) > 0 {
   212  		clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   213  	}
   214  
   215  	pkiID := common.PKIidType(endpoint)
   216  	assert.NoError(t, err, "%v", err)
   217  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) {
   218  		mac := hmac.New(sha256.New, hmacKey)
   219  		mac.Write(msg)
   220  		return mac.Sum(nil), nil
   221  	})
   222  	// Mutate connection message to test negative paths
   223  	msg = connMutator(msg)
   224  	// Send your own connection message
   225  	stream.Send(msg.Envelope)
   226  	// Wait for connection message from the other side
   227  	envelope, err := stream.Recv()
   228  	if err != nil {
   229  		return acceptChan
   230  	}
   231  	assert.NoError(t, err, "%v", err)
   232  	msg, err = protoext.EnvelopeToGossipMessage(envelope)
   233  	assert.NoError(t, err, "%v", err)
   234  	assert.Equal(t, []byte(target), msg.GetConn().PkiId)
   235  	assert.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash)
   236  	msg2Send := createGossipMsg()
   237  	nonce := uint64(rand.Int())
   238  	msg2Send.Nonce = nonce
   239  	go stream.Send(msg2Send.Envelope)
   240  	return acceptChan
   241  }
   242  
   243  func TestMutualParallelSendWithAck(t *testing.T) {
   244  	t.Parallel()
   245  
   246  	// This test tests concurrent and parallel sending of many (1000) messages
   247  	// from 2 instances to one another at the same time.
   248  
   249  	msgNum := 1000
   250  
   251  	comm1, port1 := newCommInstance(t, naiveSec)
   252  	comm2, port2 := newCommInstance(t, naiveSec)
   253  	defer comm1.Stop()
   254  	defer comm2.Stop()
   255  
   256  	acceptData := func(o interface{}) bool {
   257  		m := o.(protoext.ReceivedMessage).GetGossipMessage()
   258  		return protoext.IsDataMsg(m.GossipMessage)
   259  	}
   260  
   261  	inc1 := comm1.Accept(acceptData)
   262  	inc2 := comm2.Accept(acceptData)
   263  
   264  	// Send a message from comm1 to comm2, to make the instances establish a preliminary connection
   265  	comm1.Send(createGossipMsg(), remotePeer(port2))
   266  	// Wait for the message to be received in comm2
   267  	<-inc2
   268  
   269  	for i := 0; i < msgNum; i++ {
   270  		go comm1.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port2))
   271  	}
   272  
   273  	for i := 0; i < msgNum; i++ {
   274  		go comm2.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port1))
   275  	}
   276  
   277  	go func() {
   278  		for i := 0; i < msgNum; i++ {
   279  			<-inc1
   280  		}
   281  	}()
   282  
   283  	for i := 0; i < msgNum; i++ {
   284  		<-inc2
   285  	}
   286  }
   287  
   288  func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) {
   289  	ll, err := net.Listen("tcp", "127.0.0.1:0")
   290  	assert.NoError(t, err)
   291  	endpoint = ll.Addr().String()
   292  	_, portS, err := net.SplitHostPort(endpoint)
   293  	assert.NoError(t, err)
   294  	portInt, err := strconv.Atoi(portS)
   295  	assert.NoError(t, err)
   296  	return portInt, endpoint, ll
   297  }
   298  
   299  func TestHandshake(t *testing.T) {
   300  	t.Parallel()
   301  	signer := func(msg []byte) ([]byte, error) {
   302  		mac := hmac.New(sha256.New, hmacKey)
   303  		mac.Write(msg)
   304  		return mac.Sum(nil), nil
   305  	}
   306  	mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   307  		return msg
   308  	}
   309  	assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) {
   310  		expectedPKIID := common.PKIidType(endpoint)
   311  		assert.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID)
   312  		assert.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity)
   313  		assert.NotNil(t, msg.GetConnectionInfo().Auth)
   314  		sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   315  		assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   316  	}
   317  
   318  	// Positive path 1 - check authentication without TLS
   319  	port, endpoint, ll := getAvailablePort(t)
   320  	s := grpc.NewServer()
   321  	id := []byte(endpoint)
   322  	idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec)
   323  	inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption {
   324  		return []grpc.DialOption{grpc.WithInsecure()}
   325  	}, naiveSec, disabledMetrics, testCommConfig)
   326  	go s.Serve(ll)
   327  	assert.NoError(t, err)
   328  	var msg protoext.ReceivedMessage
   329  
   330  	_, tempEndpoint, tempL := getAvailablePort(t)
   331  	acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none)
   332  	select {
   333  	case <-time.After(time.Duration(time.Second * 4)):
   334  		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
   335  	case msg = <-acceptChan:
   336  	}
   337  	assert.Equal(t, common.PKIidType(tempEndpoint), msg.GetConnectionInfo().ID)
   338  	assert.Equal(t, api.PeerIdentityType(tempEndpoint), msg.GetConnectionInfo().Identity)
   339  	sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   340  	assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   341  
   342  	inst.Stop()
   343  	s.Stop()
   344  	ll.Close()
   345  	tempL.Close()
   346  	time.Sleep(time.Second)
   347  
   348  	comm, port := newCommInstance(t, naiveSec)
   349  	defer comm.Stop()
   350  	// Positive path 2: initiating peer sends its own certificate
   351  	_, tempEndpoint, tempL = getAvailablePort(t)
   352  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   353  
   354  	select {
   355  	case <-time.After(time.Second * 2):
   356  		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
   357  	case msg = <-acceptChan:
   358  	}
   359  	assertPositivePath(msg, tempEndpoint)
   360  	tempL.Close()
   361  
   362  	// Negative path: initiating peer doesn't send its own certificate
   363  	_, tempEndpoint, tempL = getAvailablePort(t)
   364  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, oneWayTLS)
   365  	time.Sleep(time.Second)
   366  	assert.Equal(t, 0, len(acceptChan))
   367  	tempL.Close()
   368  
   369  	// Negative path, signature is wrong
   370  	_, tempEndpoint, tempL = getAvailablePort(t)
   371  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   372  		msg.Signature = append(msg.Signature, 0)
   373  		return msg
   374  	}
   375  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   376  	time.Sleep(time.Second)
   377  	assert.Equal(t, 0, len(acceptChan))
   378  	tempL.Close()
   379  
   380  	// Negative path, the PKIid doesn't match the identity
   381  	_, tempEndpoint, tempL = getAvailablePort(t)
   382  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   383  		msg.GetConn().PkiId = []byte(tempEndpoint)
   384  		// Sign the message again
   385  		msg.Sign(signer)
   386  		return msg
   387  	}
   388  	_, tempEndpoint2, tempL2 := getAvailablePort(t)
   389  	acceptChan = handshaker(port, tempEndpoint2, comm, t, mutator, mutualTLS)
   390  	time.Sleep(time.Second)
   391  	assert.Equal(t, 0, len(acceptChan))
   392  	tempL.Close()
   393  	tempL2.Close()
   394  
   395  	// Negative path, the cert hash isn't what is expected
   396  	_, tempEndpoint, tempL = getAvailablePort(t)
   397  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   398  		msg.GetConn().TlsCertHash = append(msg.GetConn().TlsCertHash, 0)
   399  		msg.Sign(signer)
   400  		return msg
   401  	}
   402  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   403  	time.Sleep(time.Second)
   404  	assert.Equal(t, 0, len(acceptChan))
   405  	tempL.Close()
   406  
   407  	// Negative path, no PKI-ID was sent
   408  	_, tempEndpoint, tempL = getAvailablePort(t)
   409  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   410  		msg.GetConn().PkiId = nil
   411  		msg.Sign(signer)
   412  		return msg
   413  	}
   414  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   415  	time.Sleep(time.Second)
   416  	assert.Equal(t, 0, len(acceptChan))
   417  	tempL.Close()
   418  
   419  	// Negative path, connection message is of a different type
   420  	_, tempEndpoint, tempL = getAvailablePort(t)
   421  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   422  		msg.Content = &proto.GossipMessage_Empty{
   423  			Empty: &proto.Empty{},
   424  		}
   425  		msg.Sign(signer)
   426  		return msg
   427  	}
   428  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   429  	time.Sleep(time.Second)
   430  	assert.Equal(t, 0, len(acceptChan))
   431  	tempL.Close()
   432  
   433  	// Negative path, the peer didn't respond to the handshake in due time
   434  	_, tempEndpoint, tempL = getAvailablePort(t)
   435  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   436  		time.Sleep(time.Second * 5)
   437  		return msg
   438  	}
   439  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   440  	time.Sleep(time.Second)
   441  	assert.Equal(t, 0, len(acceptChan))
   442  	tempL.Close()
   443  }
   444  
   445  func TestConnectUnexpectedPeer(t *testing.T) {
   446  	t.Parallel()
   447  	// Scenarios: In both scenarios, comm1 connects to comm2 or comm3.
   448  	// and expects to see a PKI-ID which is equal to comm4's PKI-ID.
   449  	// The connection attempt would succeed or fail based on whether comm2 or comm3
   450  	// are in the same org as comm4
   451  
   452  	identityByPort := func(port int) api.PeerIdentityType {
   453  		return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port))
   454  	}
   455  
   456  	customNaiveSec := &naiveSecProvider{}
   457  
   458  	comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer()
   459  	comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer()
   460  	comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer()
   461  	comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer()
   462  
   463  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm1Port)).Return(api.OrgIdentityType("O"))
   464  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm2Port)).Return(api.OrgIdentityType("A"))
   465  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm3Port)).Return(api.OrgIdentityType("B"))
   466  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm4Port)).Return(api.OrgIdentityType("A"))
   467  
   468  	comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...)
   469  	comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...)
   470  	comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...)
   471  	comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...)
   472  
   473  	defer comm1.Stop()
   474  	defer comm2.Stop()
   475  	defer comm3.Stop()
   476  	defer comm4.Stop()
   477  
   478  	messagesForComm1 := comm1.Accept(acceptAll)
   479  	messagesForComm2 := comm2.Accept(acceptAll)
   480  	messagesForComm3 := comm3.Accept(acceptAll)
   481  
   482  	// Have comm4 send a message to comm1
   483  	// in order for comm1 to know comm4
   484  	comm4.Send(createGossipMsg(), remotePeer(comm1Port))
   485  	<-messagesForComm1
   486  	// Close the connection with comm4
   487  	comm1.CloseConn(remotePeer(comm4Port))
   488  	// At this point, comm1 knows comm4's identity and organization
   489  
   490  	t.Run("Same organization", func(t *testing.T) {
   491  		unexpectedRemotePeer := remotePeer(comm2Port)
   492  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   493  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   494  		select {
   495  		case <-messagesForComm2:
   496  		case <-time.After(time.Second * 5):
   497  			assert.Fail(t, "Didn't receive a message within a timely manner")
   498  			util.PrintStackTrace()
   499  		}
   500  	})
   501  
   502  	t.Run("Unexpected organization", func(t *testing.T) {
   503  		unexpectedRemotePeer := remotePeer(comm3Port)
   504  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   505  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   506  		select {
   507  		case <-messagesForComm3:
   508  			assert.Fail(t, "Message shouldn't have been received")
   509  		case <-time.After(time.Second * 5):
   510  		}
   511  	})
   512  }
   513  
   514  func TestGetConnectionInfo(t *testing.T) {
   515  	t.Parallel()
   516  	comm1, port1 := newCommInstance(t, naiveSec)
   517  	comm2, _ := newCommInstance(t, naiveSec)
   518  	defer comm1.Stop()
   519  	defer comm2.Stop()
   520  	m1 := comm1.Accept(acceptAll)
   521  	comm2.Send(createGossipMsg(), remotePeer(port1))
   522  	select {
   523  	case <-time.After(time.Second * 10):
   524  		t.Fatal("Didn't receive a message in time")
   525  	case msg := <-m1:
   526  		assert.Equal(t, comm2.GetPKIid(), msg.GetConnectionInfo().ID)
   527  		assert.NotNil(t, msg.GetSourceEnvelope())
   528  	}
   529  }
   530  
   531  func TestCloseConn(t *testing.T) {
   532  	t.Parallel()
   533  	comm1, port1 := newCommInstance(t, naiveSec)
   534  	defer comm1.Stop()
   535  	acceptChan := comm1.Accept(acceptAll)
   536  
   537  	cert := GenerateCertificatesOrPanic()
   538  	tlsCfg := &tls.Config{
   539  		InsecureSkipVerify: true,
   540  		Certificates:       []tls.Certificate{cert},
   541  	}
   542  	ta := credentials.NewTLS(tlsCfg)
   543  
   544  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   545  	defer cancel()
   546  	target := fmt.Sprintf("127.0.0.1:%d", port1)
   547  	conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock())
   548  	assert.NoError(t, err, "%v", err)
   549  	cl := proto.NewGossipClient(conn)
   550  	stream, err := cl.GossipStream(context.Background())
   551  	assert.NoError(t, err, "%v", err)
   552  	c := &commImpl{}
   553  	tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   554  	connMsg, _ := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) {
   555  		mac := hmac.New(sha256.New, hmacKey)
   556  		mac.Write(msg)
   557  		return mac.Sum(nil), nil
   558  	})
   559  	assert.NoError(t, stream.Send(connMsg.Envelope))
   560  	stream.Send(createGossipMsg().Envelope)
   561  	select {
   562  	case <-acceptChan:
   563  	case <-time.After(time.Second):
   564  		assert.Fail(t, "Didn't receive a message within a timely period")
   565  	}
   566  	comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")})
   567  	time.Sleep(time.Second * 10)
   568  	gotErr := false
   569  	msg2Send := createGossipMsg()
   570  	msg2Send.GetDataMsg().Payload = &proto.Payload{
   571  		Data: make([]byte, 1024*1024),
   572  	}
   573  	protoext.NoopSign(msg2Send.GossipMessage)
   574  	for i := 0; i < DefRecvBuffSize; i++ {
   575  		err := stream.Send(msg2Send.Envelope)
   576  		if err != nil {
   577  			gotErr = true
   578  			break
   579  		}
   580  	}
   581  	assert.True(t, gotErr, "Should have failed because connection is closed")
   582  }
   583  
   584  // TestCommSend makes sure that enough messages get through
   585  // eventually. Comm.Send() is both asynchronous and best-effort, so this test
   586  // case assumes some will fail, but that eventually enough messages will get
   587  // through that the test will end.
   588  func TestCommSend(t *testing.T) {
   589  	t.Parallel()
   590  	comm1, port1 := newCommInstance(t, naiveSec)
   591  	comm2, port2 := newCommInstance(t, naiveSec)
   592  	defer comm1.Stop()
   593  	defer comm2.Stop()
   594  
   595  	// Create the receive channel before sending the messages
   596  	ch1 := comm1.Accept(acceptAll)
   597  	ch2 := comm2.Accept(acceptAll)
   598  
   599  	// control channels for background senders
   600  	stopch1 := make(chan struct{})
   601  	stopch2 := make(chan struct{})
   602  
   603  	go func() {
   604  		for {
   605  			emptyMsg := createGossipMsg()
   606  			select {
   607  			case <-stopch1:
   608  				return
   609  			default:
   610  				comm1.Send(emptyMsg, remotePeer(port2))
   611  			}
   612  		}
   613  	}()
   614  
   615  	go func() {
   616  		for {
   617  			emptyMsg := createGossipMsg()
   618  			select {
   619  			case <-stopch2:
   620  				return
   621  			default:
   622  				comm2.Send(emptyMsg, remotePeer(port1))
   623  			}
   624  		}
   625  	}()
   626  
   627  	c1recieved := 0
   628  	c2recieved := 0
   629  	// hopefully in some runs we'll fill both send and receive buffers and
   630  	// drop overflowing messages, but still finish, because the endless
   631  	// stream of messages inexorably gets through unless something is very
   632  	// broken.
   633  	totalMessagesRecieved := (DefSendBuffSize + DefRecvBuffSize) * 2
   634  RECV:
   635  	for {
   636  		select {
   637  		case <-ch1:
   638  			c1recieved++
   639  			if c1recieved == totalMessagesRecieved {
   640  				close(stopch2)
   641  			}
   642  		case <-ch2:
   643  			c2recieved++
   644  			if c2recieved == totalMessagesRecieved {
   645  				close(stopch1)
   646  			}
   647  		case <-time.After(30 * time.Second):
   648  			t.Fatalf("timed out waiting for messages to be recieved.\nc1 got %d messages\nc2 got %d messages", c1recieved, c2recieved)
   649  		default:
   650  			if c1recieved > totalMessagesRecieved && c2recieved > totalMessagesRecieved {
   651  				break RECV
   652  			}
   653  		}
   654  	}
   655  	t.Logf("c1 got %d messages\nc2 got %d messages", c1recieved, c2recieved)
   656  }
   657  
   658  type nonResponsivePeer struct {
   659  	*grpc.Server
   660  	port int
   661  }
   662  
   663  func newNonResponsivePeer(t *testing.T) *nonResponsivePeer {
   664  	port, gRPCServer, _, _, _ := util.CreateGRPCLayer()
   665  	nrp := &nonResponsivePeer{
   666  		Server: gRPCServer.Server(),
   667  		port:   port,
   668  	}
   669  	proto.RegisterGossipServer(gRPCServer.Server(), nrp)
   670  	return nrp
   671  }
   672  
   673  func (bp *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
   674  	time.Sleep(time.Second * 15)
   675  	return &proto.Empty{}, nil
   676  }
   677  
   678  func (bp *nonResponsivePeer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
   679  	return nil
   680  }
   681  
   682  func (bp *nonResponsivePeer) stop() {
   683  	bp.Server.Stop()
   684  }
   685  
   686  func TestNonResponsivePing(t *testing.T) {
   687  	t.Parallel()
   688  	c, _ := newCommInstance(t, naiveSec)
   689  	defer c.Stop()
   690  	nonRespPeer := newNonResponsivePeer(t)
   691  	defer nonRespPeer.stop()
   692  	s := make(chan struct{})
   693  	go func() {
   694  		c.Probe(remotePeer(nonRespPeer.port))
   695  		s <- struct{}{}
   696  	}()
   697  	select {
   698  	case <-time.After(time.Second * 10):
   699  		assert.Fail(t, "Request wasn't cancelled on time")
   700  	case <-s:
   701  	}
   702  
   703  }
   704  
   705  func TestResponses(t *testing.T) {
   706  	t.Parallel()
   707  	comm1, port1 := newCommInstance(t, naiveSec)
   708  	comm2, _ := newCommInstance(t, naiveSec)
   709  
   710  	defer comm1.Stop()
   711  	defer comm2.Stop()
   712  
   713  	wg := sync.WaitGroup{}
   714  
   715  	msg := createGossipMsg()
   716  	wg.Add(1)
   717  	go func() {
   718  		inChan := comm1.Accept(acceptAll)
   719  		wg.Done()
   720  		for m := range inChan {
   721  			reply := createGossipMsg()
   722  			reply.Nonce = m.GetGossipMessage().Nonce + 1
   723  			m.Respond(reply.GossipMessage)
   724  		}
   725  	}()
   726  	expectedNOnce := uint64(msg.Nonce + 1)
   727  	responsesFromComm1 := comm2.Accept(acceptAll)
   728  
   729  	ticker := time.NewTicker(10 * time.Second)
   730  	wg.Wait()
   731  	comm2.Send(msg, remotePeer(port1))
   732  
   733  	select {
   734  	case <-ticker.C:
   735  		assert.Fail(t, "Haven't got response from comm1 within a timely manner")
   736  		break
   737  	case resp := <-responsesFromComm1:
   738  		ticker.Stop()
   739  		assert.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce)
   740  		break
   741  	}
   742  }
   743  
   744  // TestAccept makes sure that accept filters work. The probability of the parity
   745  // of all nonces being 0 or 1 is very low.
   746  func TestAccept(t *testing.T) {
   747  	t.Parallel()
   748  	comm1, port1 := newCommInstance(t, naiveSec)
   749  	comm2, _ := newCommInstance(t, naiveSec)
   750  
   751  	evenNONCESelector := func(m interface{}) bool {
   752  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 == 0
   753  	}
   754  
   755  	oddNONCESelector := func(m interface{}) bool {
   756  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 != 0
   757  	}
   758  
   759  	evenNONCES := comm1.Accept(evenNONCESelector)
   760  	oddNONCES := comm1.Accept(oddNONCESelector)
   761  
   762  	var evenResults []uint64
   763  	var oddResults []uint64
   764  
   765  	out := make(chan uint64)
   766  	sem := make(chan struct{})
   767  
   768  	readIntoSlice := func(a *[]uint64, ch <-chan protoext.ReceivedMessage) {
   769  		for m := range ch {
   770  			*a = append(*a, m.GetGossipMessage().Nonce)
   771  			select {
   772  			case out <- m.GetGossipMessage().Nonce:
   773  			default: // avoid blocking when we stop reading from out
   774  			}
   775  		}
   776  		sem <- struct{}{}
   777  	}
   778  
   779  	go readIntoSlice(&evenResults, evenNONCES)
   780  	go readIntoSlice(&oddResults, oddNONCES)
   781  
   782  	stopSend := make(chan struct{})
   783  	go func() {
   784  		for {
   785  			select {
   786  			case <-stopSend:
   787  				return
   788  			default:
   789  				comm2.Send(createGossipMsg(), remotePeer(port1))
   790  			}
   791  		}
   792  	}()
   793  
   794  	waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent")
   795  	close(stopSend)
   796  
   797  	comm1.Stop()
   798  	comm2.Stop()
   799  
   800  	<-sem
   801  	<-sem
   802  
   803  	t.Logf("%d even nonces received", len(evenResults))
   804  	t.Logf("%d  odd nonces received", len(oddResults))
   805  
   806  	assert.NotEmpty(t, evenResults)
   807  	assert.NotEmpty(t, oddResults)
   808  
   809  	remainderPredicate := func(a []uint64, rem uint64) {
   810  		for _, n := range a {
   811  			assert.Equal(t, n%2, rem)
   812  		}
   813  	}
   814  
   815  	remainderPredicate(evenResults, 0)
   816  	remainderPredicate(oddResults, 1)
   817  }
   818  
   819  func TestReConnections(t *testing.T) {
   820  	t.Parallel()
   821  	comm1, port1 := newCommInstance(t, naiveSec)
   822  	comm2, port2 := newCommInstance(t, naiveSec)
   823  
   824  	reader := func(out chan uint64, in <-chan protoext.ReceivedMessage) {
   825  		for {
   826  			msg := <-in
   827  			if msg == nil {
   828  				return
   829  			}
   830  			out <- msg.GetGossipMessage().Nonce
   831  		}
   832  	}
   833  
   834  	out1 := make(chan uint64, 10)
   835  	out2 := make(chan uint64, 10)
   836  
   837  	go reader(out1, comm1.Accept(acceptAll))
   838  	go reader(out2, comm2.Accept(acceptAll))
   839  
   840  	// comm1 connects to comm2
   841  	comm1.Send(createGossipMsg(), remotePeer(port2))
   842  	waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner")
   843  	// comm2 sends to comm1
   844  	comm2.Send(createGossipMsg(), remotePeer(port1))
   845  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   846  	comm1.Stop()
   847  
   848  	comm1, port1 = newCommInstance(t, naiveSec)
   849  	out1 = make(chan uint64, 1)
   850  	go reader(out1, comm1.Accept(acceptAll))
   851  	comm2.Send(createGossipMsg(), remotePeer(port1))
   852  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   853  	comm1.Stop()
   854  	comm2.Stop()
   855  }
   856  
   857  func TestProbe(t *testing.T) {
   858  	t.Parallel()
   859  	comm1, port1 := newCommInstance(t, naiveSec)
   860  	defer comm1.Stop()
   861  	comm2, port2 := newCommInstance(t, naiveSec)
   862  	time.Sleep(time.Duration(1) * time.Second)
   863  	assert.NoError(t, comm1.Probe(remotePeer(port2)))
   864  	_, err := comm1.Handshake(remotePeer(port2))
   865  	assert.NoError(t, err)
   866  	tempPort, _, ll := getAvailablePort(t)
   867  	defer ll.Close()
   868  	assert.Error(t, comm1.Probe(remotePeer(tempPort)))
   869  	_, err = comm1.Handshake(remotePeer(tempPort))
   870  	assert.Error(t, err)
   871  	comm2.Stop()
   872  	time.Sleep(time.Duration(1) * time.Second)
   873  	assert.Error(t, comm1.Probe(remotePeer(port2)))
   874  	_, err = comm1.Handshake(remotePeer(port2))
   875  	assert.Error(t, err)
   876  	comm2, port2 = newCommInstance(t, naiveSec)
   877  	defer comm2.Stop()
   878  	time.Sleep(time.Duration(1) * time.Second)
   879  	assert.NoError(t, comm2.Probe(remotePeer(port1)))
   880  	_, err = comm2.Handshake(remotePeer(port1))
   881  	assert.NoError(t, err)
   882  	assert.NoError(t, comm1.Probe(remotePeer(port2)))
   883  	_, err = comm1.Handshake(remotePeer(port2))
   884  	assert.NoError(t, err)
   885  	// Now try a deep probe with an expected PKI-ID that doesn't match
   886  	wrongRemotePeer := remotePeer(port2)
   887  	if wrongRemotePeer.PKIID[0] == 0 {
   888  		wrongRemotePeer.PKIID[0] = 1
   889  	} else {
   890  		wrongRemotePeer.PKIID[0] = 0
   891  	}
   892  	_, err = comm1.Handshake(wrongRemotePeer)
   893  	assert.Error(t, err)
   894  	// Try a deep probe with a nil PKI-ID
   895  	endpoint := fmt.Sprintf("127.0.0.1:%d", port2)
   896  	id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint})
   897  	assert.NoError(t, err)
   898  	assert.Equal(t, api.PeerIdentityType(endpoint), id)
   899  }
   900  
   901  func TestPresumedDead(t *testing.T) {
   902  	t.Parallel()
   903  	comm1, _ := newCommInstance(t, naiveSec)
   904  	comm2, port2 := newCommInstance(t, naiveSec)
   905  
   906  	wg := sync.WaitGroup{}
   907  	wg.Add(1)
   908  	go func() {
   909  		wg.Wait()
   910  		comm1.Send(createGossipMsg(), remotePeer(port2))
   911  	}()
   912  
   913  	ticker := time.NewTicker(time.Duration(10) * time.Second)
   914  	acceptCh := comm2.Accept(acceptAll)
   915  	wg.Done()
   916  	select {
   917  	case <-acceptCh:
   918  		ticker.Stop()
   919  	case <-ticker.C:
   920  		assert.Fail(t, "Didn't get first message")
   921  	}
   922  
   923  	comm2.Stop()
   924  	go func() {
   925  		for i := 0; i < 5; i++ {
   926  			comm1.Send(createGossipMsg(), remotePeer(port2))
   927  			time.Sleep(time.Millisecond * 200)
   928  		}
   929  	}()
   930  
   931  	ticker = time.NewTicker(time.Second * time.Duration(3))
   932  	select {
   933  	case <-ticker.C:
   934  		assert.Fail(t, "Didn't get a presumed dead message within a timely manner")
   935  		break
   936  	case <-comm1.PresumedDead():
   937  		ticker.Stop()
   938  		break
   939  	}
   940  }
   941  
   942  func createGossipMsg() *protoext.SignedGossipMessage {
   943  	msg, _ := protoext.NoopSign(&proto.GossipMessage{
   944  		Tag:   proto.GossipMessage_EMPTY,
   945  		Nonce: uint64(rand.Int()),
   946  		Content: &proto.GossipMessage_DataMsg{
   947  			DataMsg: &proto.DataMessage{},
   948  		},
   949  	})
   950  	return msg
   951  }
   952  
   953  func remotePeer(port int) *RemotePeer {
   954  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
   955  	return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}
   956  }
   957  
   958  func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) {
   959  	c := 0
   960  	waiting := true
   961  	ticker := time.NewTicker(time.Duration(10) * time.Second)
   962  	for waiting {
   963  		select {
   964  		case <-msgChan:
   965  			c++
   966  			if c == count {
   967  				waiting = false
   968  			}
   969  		case <-ticker.C:
   970  			waiting = false
   971  		}
   972  	}
   973  	assert.Equal(t, count, c, errMsg)
   974  }
   975  
   976  func TestConcurrentCloseSend(t *testing.T) {
   977  	t.Parallel()
   978  	var stopping int32
   979  
   980  	comm1, _ := newCommInstance(t, naiveSec)
   981  	comm2, port2 := newCommInstance(t, naiveSec)
   982  	m := comm2.Accept(acceptAll)
   983  	comm1.Send(createGossipMsg(), remotePeer(port2))
   984  	<-m
   985  	ready := make(chan struct{})
   986  	done := make(chan struct{})
   987  	go func() {
   988  		defer close(done)
   989  
   990  		comm1.Send(createGossipMsg(), remotePeer(port2))
   991  		close(ready)
   992  
   993  		for atomic.LoadInt32(&stopping) == int32(0) {
   994  			comm1.Send(createGossipMsg(), remotePeer(port2))
   995  		}
   996  	}()
   997  	<-ready
   998  	comm2.Stop()
   999  	atomic.StoreInt32(&stopping, int32(1))
  1000  	<-done
  1001  }