github.com/defanghe/fabric@v2.1.1+incompatible/gossip/comm/comm_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha256"
    14  	"crypto/tls"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math/rand"
    19  	"net"
    20  	"strconv"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	cb "github.com/hyperledger/fabric-protos-go/common"
    27  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    28  	"github.com/hyperledger/fabric/bccsp/factory"
    29  	"github.com/hyperledger/fabric/common/flogging"
    30  	"github.com/hyperledger/fabric/common/metrics/disabled"
    31  	"github.com/hyperledger/fabric/gossip/api"
    32  	"github.com/hyperledger/fabric/gossip/api/mocks"
    33  	gmocks "github.com/hyperledger/fabric/gossip/comm/mocks"
    34  	"github.com/hyperledger/fabric/gossip/common"
    35  	"github.com/hyperledger/fabric/gossip/identity"
    36  	"github.com/hyperledger/fabric/gossip/metrics"
    37  	"github.com/hyperledger/fabric/gossip/protoext"
    38  	"github.com/hyperledger/fabric/gossip/util"
    39  	"github.com/hyperledger/fabric/internal/pkg/comm"
    40  	"github.com/stretchr/testify/assert"
    41  	"github.com/stretchr/testify/mock"
    42  	"google.golang.org/grpc"
    43  	"google.golang.org/grpc/credentials"
    44  )
    45  
    46  func init() {
    47  	util.SetupTestLogging()
    48  	rand.Seed(time.Now().UnixNano())
    49  	factory.InitFactories(nil)
    50  	naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(api.OrgIdentityType{})
    51  }
    52  
    53  var testCommConfig = CommConfig{
    54  	DialTimeout:  300 * time.Millisecond,
    55  	ConnTimeout:  DefConnTimeout,
    56  	RecvBuffSize: DefRecvBuffSize,
    57  	SendBuffSize: DefSendBuffSize,
    58  }
    59  
    60  func acceptAll(msg interface{}) bool {
    61  	return true
    62  }
    63  
    64  var noopPurgeIdentity = func(_ common.PKIidType, _ api.PeerIdentityType) {
    65  
    66  }
    67  
    68  var (
    69  	naiveSec        = &naiveSecProvider{}
    70  	hmacKey         = []byte{0, 0, 0}
    71  	disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics
    72  )
    73  
    74  type naiveSecProvider struct {
    75  	mocks.SecurityAdvisor
    76  }
    77  
    78  func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType {
    79  	return nsp.SecurityAdvisor.Called(identity).Get(0).(api.OrgIdentityType)
    80  }
    81  
    82  func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) {
    83  	return time.Now().Add(time.Hour), nil
    84  }
    85  
    86  func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error {
    87  	return nil
    88  }
    89  
    90  // GetPKIidOfCert returns the PKI-ID of a peer's identity
    91  func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType {
    92  	return common.PKIidType(peerIdentity)
    93  }
    94  
    95  // VerifyBlock returns nil if the block is properly signed,
    96  // else returns error
    97  func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error {
    98  	return nil
    99  }
   100  
   101  // Sign signs msg with this peer's signing key and outputs
   102  // the signature if no error occurred.
   103  func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) {
   104  	mac := hmac.New(sha256.New, hmacKey)
   105  	mac.Write(msg)
   106  	return mac.Sum(nil), nil
   107  }
   108  
   109  // Verify checks that signature is a valid signature of message under a peer's verification key.
   110  // If the verification succeeded, Verify returns nil meaning no error occurred.
   111  // If peerCert is nil, then the signature is verified against this peer's verification key.
   112  func (*naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error {
   113  	mac := hmac.New(sha256.New, hmacKey)
   114  	mac.Write(message)
   115  	expected := mac.Sum(nil)
   116  	if !bytes.Equal(signature, expected) {
   117  		return fmt.Errorf("Wrong certificate:%v, %v", signature, message)
   118  	}
   119  	return nil
   120  }
   121  
   122  // VerifyByChannel verifies a peer's signature on a message in the context
   123  // of a specific channel
   124  func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error {
   125  	return nil
   126  }
   127  
   128  func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider,
   129  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   130  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   131  
   132  	_, portString, err := net.SplitHostPort(gRPCServer.Address())
   133  	assert.NoError(t, err)
   134  
   135  	endpoint := fmt.Sprintf("127.0.0.1:%s", portString)
   136  	id := []byte(endpoint)
   137  	identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec)
   138  
   139  	commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts,
   140  		sec, commMetrics, testCommConfig, dialOpts...)
   141  	assert.NoError(t, err)
   142  
   143  	go func() {
   144  		err := gRPCServer.Start()
   145  		assert.NoError(t, err)
   146  	}()
   147  
   148  	return &commGRPC{commInst.(*commImpl), gRPCServer}
   149  }
   150  
   151  type commGRPC struct {
   152  	*commImpl
   153  	gRPCServer *comm.GRPCServer
   154  }
   155  
   156  func (c *commGRPC) Stop() {
   157  	c.commImpl.Stop()
   158  	c.commImpl.idMapper.Stop()
   159  	c.gRPCServer.Stop()
   160  }
   161  
   162  func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider,
   163  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   164  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   165  	return newCommInstanceOnlyWithMetrics(t, disabledMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   166  }
   167  
   168  func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) {
   169  	port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer()
   170  	comm := newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   171  	return comm, port
   172  }
   173  
   174  type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage
   175  
   176  type tlsType int
   177  
   178  const (
   179  	none tlsType = iota
   180  	oneWayTLS
   181  	mutualTLS
   182  )
   183  
   184  func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage {
   185  	c := &commImpl{}
   186  	cert := GenerateCertificatesOrPanic()
   187  	tlsCfg := &tls.Config{
   188  		InsecureSkipVerify: true,
   189  	}
   190  	if connType == mutualTLS {
   191  		tlsCfg.Certificates = []tls.Certificate{cert}
   192  	}
   193  	ta := credentials.NewTLS(tlsCfg)
   194  	secureOpts := grpc.WithTransportCredentials(ta)
   195  	if connType == none {
   196  		secureOpts = grpc.WithInsecure()
   197  	}
   198  	acceptChan := comm.Accept(acceptAll)
   199  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   200  	defer cancel()
   201  	target := fmt.Sprintf("127.0.0.1:%d", port)
   202  	conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock())
   203  	assert.NoError(t, err, "%v", err)
   204  	if err != nil {
   205  		return nil
   206  	}
   207  	cl := proto.NewGossipClient(conn)
   208  	stream, err := cl.GossipStream(context.Background())
   209  	assert.NoError(t, err, "%v", err)
   210  	if err != nil {
   211  		return nil
   212  	}
   213  
   214  	var clientCertHash []byte
   215  	if len(tlsCfg.Certificates) > 0 {
   216  		clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   217  	}
   218  
   219  	pkiID := common.PKIidType(endpoint)
   220  	assert.NoError(t, err, "%v", err)
   221  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) {
   222  		mac := hmac.New(sha256.New, hmacKey)
   223  		mac.Write(msg)
   224  		return mac.Sum(nil), nil
   225  	}, false)
   226  	// Mutate connection message to test negative paths
   227  	msg = connMutator(msg)
   228  	// Send your own connection message
   229  	stream.Send(msg.Envelope)
   230  	// Wait for connection message from the other side
   231  	envelope, err := stream.Recv()
   232  	if err != nil {
   233  		return acceptChan
   234  	}
   235  	assert.NoError(t, err, "%v", err)
   236  	msg, err = protoext.EnvelopeToGossipMessage(envelope)
   237  	assert.NoError(t, err, "%v", err)
   238  	assert.Equal(t, []byte(target), msg.GetConn().PkiId)
   239  	assert.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash)
   240  	msg2Send := createGossipMsg()
   241  	nonce := uint64(rand.Int())
   242  	msg2Send.Nonce = nonce
   243  	go stream.Send(msg2Send.Envelope)
   244  	return acceptChan
   245  }
   246  
   247  func TestMutualParallelSendWithAck(t *testing.T) {
   248  	t.Parallel()
   249  
   250  	// This test tests concurrent and parallel sending of many (1000) messages
   251  	// from 2 instances to one another at the same time.
   252  
   253  	msgNum := 1000
   254  
   255  	comm1, port1 := newCommInstance(t, naiveSec)
   256  	comm2, port2 := newCommInstance(t, naiveSec)
   257  	defer comm1.Stop()
   258  	defer comm2.Stop()
   259  
   260  	acceptData := func(o interface{}) bool {
   261  		m := o.(protoext.ReceivedMessage).GetGossipMessage()
   262  		return protoext.IsDataMsg(m.GossipMessage)
   263  	}
   264  
   265  	inc1 := comm1.Accept(acceptData)
   266  	inc2 := comm2.Accept(acceptData)
   267  
   268  	// Send a message from comm1 to comm2, to make the instances establish a preliminary connection
   269  	comm1.Send(createGossipMsg(), remotePeer(port2))
   270  	// Wait for the message to be received in comm2
   271  	<-inc2
   272  
   273  	for i := 0; i < msgNum; i++ {
   274  		go comm1.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port2))
   275  	}
   276  
   277  	for i := 0; i < msgNum; i++ {
   278  		go comm2.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port1))
   279  	}
   280  
   281  	go func() {
   282  		for i := 0; i < msgNum; i++ {
   283  			<-inc1
   284  		}
   285  	}()
   286  
   287  	for i := 0; i < msgNum; i++ {
   288  		<-inc2
   289  	}
   290  }
   291  
   292  func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) {
   293  	ll, err := net.Listen("tcp", "127.0.0.1:0")
   294  	assert.NoError(t, err)
   295  	endpoint = ll.Addr().String()
   296  	_, portS, err := net.SplitHostPort(endpoint)
   297  	assert.NoError(t, err)
   298  	portInt, err := strconv.Atoi(portS)
   299  	assert.NoError(t, err)
   300  	return portInt, endpoint, ll
   301  }
   302  
   303  func TestHandshake(t *testing.T) {
   304  	t.Parallel()
   305  	signer := func(msg []byte) ([]byte, error) {
   306  		mac := hmac.New(sha256.New, hmacKey)
   307  		mac.Write(msg)
   308  		return mac.Sum(nil), nil
   309  	}
   310  	mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   311  		return msg
   312  	}
   313  	assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) {
   314  		expectedPKIID := common.PKIidType(endpoint)
   315  		assert.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID)
   316  		assert.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity)
   317  		assert.NotNil(t, msg.GetConnectionInfo().Auth)
   318  		sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   319  		assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   320  	}
   321  
   322  	// Positive path 1 - check authentication without TLS
   323  	port, endpoint, ll := getAvailablePort(t)
   324  	s := grpc.NewServer()
   325  	id := []byte(endpoint)
   326  	idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec)
   327  	inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption {
   328  		return []grpc.DialOption{grpc.WithInsecure()}
   329  	}, naiveSec, disabledMetrics, testCommConfig)
   330  	go s.Serve(ll)
   331  	assert.NoError(t, err)
   332  	var msg protoext.ReceivedMessage
   333  
   334  	_, tempEndpoint, tempL := getAvailablePort(t)
   335  	acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none)
   336  	select {
   337  	case <-time.After(time.Duration(time.Second * 4)):
   338  		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
   339  	case msg = <-acceptChan:
   340  	}
   341  	assert.Equal(t, common.PKIidType(tempEndpoint), msg.GetConnectionInfo().ID)
   342  	assert.Equal(t, api.PeerIdentityType(tempEndpoint), msg.GetConnectionInfo().Identity)
   343  	sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   344  	assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   345  
   346  	inst.Stop()
   347  	s.Stop()
   348  	ll.Close()
   349  	tempL.Close()
   350  	time.Sleep(time.Second)
   351  
   352  	comm, port := newCommInstance(t, naiveSec)
   353  	defer comm.Stop()
   354  	// Positive path 2: initiating peer sends its own certificate
   355  	_, tempEndpoint, tempL = getAvailablePort(t)
   356  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   357  
   358  	select {
   359  	case <-time.After(time.Second * 2):
   360  		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
   361  	case msg = <-acceptChan:
   362  	}
   363  	assertPositivePath(msg, tempEndpoint)
   364  	tempL.Close()
   365  
   366  	// Negative path: initiating peer doesn't send its own certificate
   367  	_, tempEndpoint, tempL = getAvailablePort(t)
   368  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, oneWayTLS)
   369  	time.Sleep(time.Second)
   370  	assert.Equal(t, 0, len(acceptChan))
   371  	tempL.Close()
   372  
   373  	// Negative path, signature is wrong
   374  	_, tempEndpoint, tempL = getAvailablePort(t)
   375  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   376  		msg.Signature = append(msg.Signature, 0)
   377  		return msg
   378  	}
   379  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   380  	time.Sleep(time.Second)
   381  	assert.Equal(t, 0, len(acceptChan))
   382  	tempL.Close()
   383  
   384  	// Negative path, the PKIid doesn't match the identity
   385  	_, tempEndpoint, tempL = getAvailablePort(t)
   386  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   387  		msg.GetConn().PkiId = []byte(tempEndpoint)
   388  		// Sign the message again
   389  		msg.Sign(signer)
   390  		return msg
   391  	}
   392  	_, tempEndpoint2, tempL2 := getAvailablePort(t)
   393  	acceptChan = handshaker(port, tempEndpoint2, comm, t, mutator, mutualTLS)
   394  	time.Sleep(time.Second)
   395  	assert.Equal(t, 0, len(acceptChan))
   396  	tempL.Close()
   397  	tempL2.Close()
   398  
   399  	// Negative path, the cert hash isn't what is expected
   400  	_, tempEndpoint, tempL = getAvailablePort(t)
   401  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   402  		msg.GetConn().TlsCertHash = append(msg.GetConn().TlsCertHash, 0)
   403  		msg.Sign(signer)
   404  		return msg
   405  	}
   406  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   407  	time.Sleep(time.Second)
   408  	assert.Equal(t, 0, len(acceptChan))
   409  	tempL.Close()
   410  
   411  	// Negative path, no PKI-ID was sent
   412  	_, tempEndpoint, tempL = getAvailablePort(t)
   413  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   414  		msg.GetConn().PkiId = nil
   415  		msg.Sign(signer)
   416  		return msg
   417  	}
   418  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   419  	time.Sleep(time.Second)
   420  	assert.Equal(t, 0, len(acceptChan))
   421  	tempL.Close()
   422  
   423  	// Negative path, connection message is of a different type
   424  	_, tempEndpoint, tempL = getAvailablePort(t)
   425  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   426  		msg.Content = &proto.GossipMessage_Empty{
   427  			Empty: &proto.Empty{},
   428  		}
   429  		msg.Sign(signer)
   430  		return msg
   431  	}
   432  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   433  	time.Sleep(time.Second)
   434  	assert.Equal(t, 0, len(acceptChan))
   435  	tempL.Close()
   436  
   437  	// Negative path, the peer didn't respond to the handshake in due time
   438  	_, tempEndpoint, tempL = getAvailablePort(t)
   439  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   440  		time.Sleep(time.Second * 5)
   441  		return msg
   442  	}
   443  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   444  	time.Sleep(time.Second)
   445  	assert.Equal(t, 0, len(acceptChan))
   446  	tempL.Close()
   447  }
   448  
   449  func TestConnectUnexpectedPeer(t *testing.T) {
   450  	t.Parallel()
   451  	// Scenarios: In both scenarios, comm1 connects to comm2 or comm3.
   452  	// and expects to see a PKI-ID which is equal to comm4's PKI-ID.
   453  	// The connection attempt would succeed or fail based on whether comm2 or comm3
   454  	// are in the same org as comm4
   455  
   456  	identityByPort := func(port int) api.PeerIdentityType {
   457  		return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port))
   458  	}
   459  
   460  	customNaiveSec := &naiveSecProvider{}
   461  
   462  	comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer()
   463  	comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer()
   464  	comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer()
   465  	comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer()
   466  
   467  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm1Port)).Return(api.OrgIdentityType("O"))
   468  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm2Port)).Return(api.OrgIdentityType("A"))
   469  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm3Port)).Return(api.OrgIdentityType("B"))
   470  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm4Port)).Return(api.OrgIdentityType("A"))
   471  
   472  	comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...)
   473  	comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...)
   474  	comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...)
   475  	comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...)
   476  
   477  	defer comm1.Stop()
   478  	defer comm2.Stop()
   479  	defer comm3.Stop()
   480  	defer comm4.Stop()
   481  
   482  	messagesForComm1 := comm1.Accept(acceptAll)
   483  	messagesForComm2 := comm2.Accept(acceptAll)
   484  	messagesForComm3 := comm3.Accept(acceptAll)
   485  
   486  	// Have comm4 send a message to comm1
   487  	// in order for comm1 to know comm4
   488  	comm4.Send(createGossipMsg(), remotePeer(comm1Port))
   489  	<-messagesForComm1
   490  	// Close the connection with comm4
   491  	comm1.CloseConn(remotePeer(comm4Port))
   492  	// At this point, comm1 knows comm4's identity and organization
   493  
   494  	t.Run("Same organization", func(t *testing.T) {
   495  		unexpectedRemotePeer := remotePeer(comm2Port)
   496  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   497  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   498  		select {
   499  		case <-messagesForComm2:
   500  		case <-time.After(time.Second * 5):
   501  			assert.Fail(t, "Didn't receive a message within a timely manner")
   502  			util.PrintStackTrace()
   503  		}
   504  	})
   505  
   506  	t.Run("Unexpected organization", func(t *testing.T) {
   507  		unexpectedRemotePeer := remotePeer(comm3Port)
   508  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   509  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   510  		select {
   511  		case <-messagesForComm3:
   512  			assert.Fail(t, "Message shouldn't have been received")
   513  		case <-time.After(time.Second * 5):
   514  		}
   515  	})
   516  }
   517  
   518  func TestGetConnectionInfo(t *testing.T) {
   519  	t.Parallel()
   520  	comm1, port1 := newCommInstance(t, naiveSec)
   521  	comm2, _ := newCommInstance(t, naiveSec)
   522  	defer comm1.Stop()
   523  	defer comm2.Stop()
   524  	m1 := comm1.Accept(acceptAll)
   525  	comm2.Send(createGossipMsg(), remotePeer(port1))
   526  	select {
   527  	case <-time.After(time.Second * 10):
   528  		t.Fatal("Didn't receive a message in time")
   529  	case msg := <-m1:
   530  		assert.Equal(t, comm2.GetPKIid(), msg.GetConnectionInfo().ID)
   531  		assert.NotNil(t, msg.GetSourceEnvelope())
   532  	}
   533  }
   534  
   535  func TestCloseConn(t *testing.T) {
   536  	t.Parallel()
   537  	comm1, port1 := newCommInstance(t, naiveSec)
   538  	defer comm1.Stop()
   539  	acceptChan := comm1.Accept(acceptAll)
   540  
   541  	cert := GenerateCertificatesOrPanic()
   542  	tlsCfg := &tls.Config{
   543  		InsecureSkipVerify: true,
   544  		Certificates:       []tls.Certificate{cert},
   545  	}
   546  	ta := credentials.NewTLS(tlsCfg)
   547  
   548  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   549  	defer cancel()
   550  	target := fmt.Sprintf("127.0.0.1:%d", port1)
   551  	conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock())
   552  	assert.NoError(t, err, "%v", err)
   553  	cl := proto.NewGossipClient(conn)
   554  	stream, err := cl.GossipStream(context.Background())
   555  	assert.NoError(t, err, "%v", err)
   556  	c := &commImpl{}
   557  	tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   558  	connMsg, _ := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) {
   559  		mac := hmac.New(sha256.New, hmacKey)
   560  		mac.Write(msg)
   561  		return mac.Sum(nil), nil
   562  	}, false)
   563  	assert.NoError(t, stream.Send(connMsg.Envelope))
   564  	stream.Send(createGossipMsg().Envelope)
   565  	select {
   566  	case <-acceptChan:
   567  	case <-time.After(time.Second):
   568  		assert.Fail(t, "Didn't receive a message within a timely period")
   569  	}
   570  	comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")})
   571  	time.Sleep(time.Second * 10)
   572  	gotErr := false
   573  	msg2Send := createGossipMsg()
   574  	msg2Send.GetDataMsg().Payload = &proto.Payload{
   575  		Data: make([]byte, 1024*1024),
   576  	}
   577  	protoext.NoopSign(msg2Send.GossipMessage)
   578  	for i := 0; i < DefRecvBuffSize; i++ {
   579  		err := stream.Send(msg2Send.Envelope)
   580  		if err != nil {
   581  			gotErr = true
   582  			break
   583  		}
   584  	}
   585  	assert.True(t, gotErr, "Should have failed because connection is closed")
   586  }
   587  
   588  // TestCommSend makes sure that enough messages get through
   589  // eventually. Comm.Send() is both asynchronous and best-effort, so this test
   590  // case assumes some will fail, but that eventually enough messages will get
   591  // through that the test will end.
   592  func TestCommSend(t *testing.T) {
   593  	t.Parallel()
   594  	comm1, port1 := newCommInstance(t, naiveSec)
   595  	comm2, port2 := newCommInstance(t, naiveSec)
   596  	defer comm1.Stop()
   597  	defer comm2.Stop()
   598  
   599  	// Create the receive channel before sending the messages
   600  	ch1 := comm1.Accept(acceptAll)
   601  	ch2 := comm2.Accept(acceptAll)
   602  
   603  	// control channels for background senders
   604  	stopch1 := make(chan struct{})
   605  	stopch2 := make(chan struct{})
   606  
   607  	go func() {
   608  		for {
   609  			emptyMsg := createGossipMsg()
   610  			select {
   611  			case <-stopch1:
   612  				return
   613  			default:
   614  				comm1.Send(emptyMsg, remotePeer(port2))
   615  			}
   616  		}
   617  	}()
   618  
   619  	go func() {
   620  		for {
   621  			emptyMsg := createGossipMsg()
   622  			select {
   623  			case <-stopch2:
   624  				return
   625  			default:
   626  				comm2.Send(emptyMsg, remotePeer(port1))
   627  			}
   628  		}
   629  	}()
   630  
   631  	c1recieved := 0
   632  	c2recieved := 0
   633  	// hopefully in some runs we'll fill both send and receive buffers and
   634  	// drop overflowing messages, but still finish, because the endless
   635  	// stream of messages inexorably gets through unless something is very
   636  	// broken.
   637  	totalMessagesRecieved := (DefSendBuffSize + DefRecvBuffSize) * 2
   638  RECV:
   639  	for {
   640  		select {
   641  		case <-ch1:
   642  			c1recieved++
   643  			if c1recieved == totalMessagesRecieved {
   644  				close(stopch2)
   645  			}
   646  		case <-ch2:
   647  			c2recieved++
   648  			if c2recieved == totalMessagesRecieved {
   649  				close(stopch1)
   650  			}
   651  		case <-time.After(30 * time.Second):
   652  			t.Fatalf("timed out waiting for messages to be recieved.\nc1 got %d messages\nc2 got %d messages", c1recieved, c2recieved)
   653  		default:
   654  			if c1recieved > totalMessagesRecieved && c2recieved > totalMessagesRecieved {
   655  				break RECV
   656  			}
   657  		}
   658  	}
   659  	t.Logf("c1 got %d messages\nc2 got %d messages", c1recieved, c2recieved)
   660  }
   661  
   662  type nonResponsivePeer struct {
   663  	*grpc.Server
   664  	port int
   665  }
   666  
   667  func newNonResponsivePeer(t *testing.T) *nonResponsivePeer {
   668  	port, gRPCServer, _, _, _ := util.CreateGRPCLayer()
   669  	nrp := &nonResponsivePeer{
   670  		Server: gRPCServer.Server(),
   671  		port:   port,
   672  	}
   673  	proto.RegisterGossipServer(gRPCServer.Server(), nrp)
   674  	return nrp
   675  }
   676  
   677  func (bp *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
   678  	time.Sleep(time.Second * 15)
   679  	return &proto.Empty{}, nil
   680  }
   681  
   682  func (bp *nonResponsivePeer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
   683  	return nil
   684  }
   685  
   686  func (bp *nonResponsivePeer) stop() {
   687  	bp.Server.Stop()
   688  }
   689  
   690  func TestNonResponsivePing(t *testing.T) {
   691  	t.Parallel()
   692  	c, _ := newCommInstance(t, naiveSec)
   693  	defer c.Stop()
   694  	nonRespPeer := newNonResponsivePeer(t)
   695  	defer nonRespPeer.stop()
   696  	s := make(chan struct{})
   697  	go func() {
   698  		c.Probe(remotePeer(nonRespPeer.port))
   699  		s <- struct{}{}
   700  	}()
   701  	select {
   702  	case <-time.After(time.Second * 10):
   703  		assert.Fail(t, "Request wasn't cancelled on time")
   704  	case <-s:
   705  	}
   706  
   707  }
   708  
   709  func TestResponses(t *testing.T) {
   710  	t.Parallel()
   711  	comm1, port1 := newCommInstance(t, naiveSec)
   712  	comm2, _ := newCommInstance(t, naiveSec)
   713  
   714  	defer comm1.Stop()
   715  	defer comm2.Stop()
   716  
   717  	wg := sync.WaitGroup{}
   718  
   719  	msg := createGossipMsg()
   720  	wg.Add(1)
   721  	go func() {
   722  		inChan := comm1.Accept(acceptAll)
   723  		wg.Done()
   724  		for m := range inChan {
   725  			reply := createGossipMsg()
   726  			reply.Nonce = m.GetGossipMessage().Nonce + 1
   727  			m.Respond(reply.GossipMessage)
   728  		}
   729  	}()
   730  	expectedNOnce := uint64(msg.Nonce + 1)
   731  	responsesFromComm1 := comm2.Accept(acceptAll)
   732  
   733  	ticker := time.NewTicker(10 * time.Second)
   734  	wg.Wait()
   735  	comm2.Send(msg, remotePeer(port1))
   736  
   737  	select {
   738  	case <-ticker.C:
   739  		assert.Fail(t, "Haven't got response from comm1 within a timely manner")
   740  		break
   741  	case resp := <-responsesFromComm1:
   742  		ticker.Stop()
   743  		assert.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce)
   744  		break
   745  	}
   746  }
   747  
   748  // TestAccept makes sure that accept filters work. The probability of the parity
   749  // of all nonces being 0 or 1 is very low.
   750  func TestAccept(t *testing.T) {
   751  	t.Parallel()
   752  	comm1, port1 := newCommInstance(t, naiveSec)
   753  	comm2, _ := newCommInstance(t, naiveSec)
   754  
   755  	evenNONCESelector := func(m interface{}) bool {
   756  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 == 0
   757  	}
   758  
   759  	oddNONCESelector := func(m interface{}) bool {
   760  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 != 0
   761  	}
   762  
   763  	evenNONCES := comm1.Accept(evenNONCESelector)
   764  	oddNONCES := comm1.Accept(oddNONCESelector)
   765  
   766  	var evenResults []uint64
   767  	var oddResults []uint64
   768  
   769  	out := make(chan uint64)
   770  	sem := make(chan struct{})
   771  
   772  	readIntoSlice := func(a *[]uint64, ch <-chan protoext.ReceivedMessage) {
   773  		for m := range ch {
   774  			*a = append(*a, m.GetGossipMessage().Nonce)
   775  			select {
   776  			case out <- m.GetGossipMessage().Nonce:
   777  			default: // avoid blocking when we stop reading from out
   778  			}
   779  		}
   780  		sem <- struct{}{}
   781  	}
   782  
   783  	go readIntoSlice(&evenResults, evenNONCES)
   784  	go readIntoSlice(&oddResults, oddNONCES)
   785  
   786  	stopSend := make(chan struct{})
   787  	go func() {
   788  		for {
   789  			select {
   790  			case <-stopSend:
   791  				return
   792  			default:
   793  				comm2.Send(createGossipMsg(), remotePeer(port1))
   794  			}
   795  		}
   796  	}()
   797  
   798  	waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent")
   799  	close(stopSend)
   800  
   801  	comm1.Stop()
   802  	comm2.Stop()
   803  
   804  	<-sem
   805  	<-sem
   806  
   807  	t.Logf("%d even nonces received", len(evenResults))
   808  	t.Logf("%d  odd nonces received", len(oddResults))
   809  
   810  	assert.NotEmpty(t, evenResults)
   811  	assert.NotEmpty(t, oddResults)
   812  
   813  	remainderPredicate := func(a []uint64, rem uint64) {
   814  		for _, n := range a {
   815  			assert.Equal(t, n%2, rem)
   816  		}
   817  	}
   818  
   819  	remainderPredicate(evenResults, 0)
   820  	remainderPredicate(oddResults, 1)
   821  }
   822  
   823  func TestReConnections(t *testing.T) {
   824  	t.Parallel()
   825  	comm1, port1 := newCommInstance(t, naiveSec)
   826  	comm2, port2 := newCommInstance(t, naiveSec)
   827  
   828  	reader := func(out chan uint64, in <-chan protoext.ReceivedMessage) {
   829  		for {
   830  			msg := <-in
   831  			if msg == nil {
   832  				return
   833  			}
   834  			out <- msg.GetGossipMessage().Nonce
   835  		}
   836  	}
   837  
   838  	out1 := make(chan uint64, 10)
   839  	out2 := make(chan uint64, 10)
   840  
   841  	go reader(out1, comm1.Accept(acceptAll))
   842  	go reader(out2, comm2.Accept(acceptAll))
   843  
   844  	// comm1 connects to comm2
   845  	comm1.Send(createGossipMsg(), remotePeer(port2))
   846  	waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner")
   847  	// comm2 sends to comm1
   848  	comm2.Send(createGossipMsg(), remotePeer(port1))
   849  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   850  	comm1.Stop()
   851  
   852  	comm1, port1 = newCommInstance(t, naiveSec)
   853  	out1 = make(chan uint64, 1)
   854  	go reader(out1, comm1.Accept(acceptAll))
   855  	comm2.Send(createGossipMsg(), remotePeer(port1))
   856  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   857  	comm1.Stop()
   858  	comm2.Stop()
   859  }
   860  
// TestProbe checks Probe and Handshake from comm1 against a running comm2, a
// port with no comm instance behind it, comm2 after it has been stopped, and
// a restarted comm2 (in both directions). It then checks that Handshake fails
// when the expected PKI-ID does not match, and succeeds when no PKI-ID is
// supplied.
func TestProbe(t *testing.T) {
	t.Parallel()
	comm1, port1 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	comm2, port2 := newCommInstance(t, naiveSec)
	// Fixed delay to let the freshly created instances settle before probing.
	time.Sleep(time.Duration(1) * time.Second)
	assert.NoError(t, comm1.Probe(remotePeer(port2)))
	_, err := comm1.Handshake(remotePeer(port2))
	assert.NoError(t, err)
	// No comm instance serves tempPort, so both the probe and the handshake
	// are expected to fail. ll is closed when the test returns.
	tempPort, _, ll := getAvailablePort(t)
	defer ll.Close()
	assert.Error(t, comm1.Probe(remotePeer(tempPort)))
	_, err = comm1.Handshake(remotePeer(tempPort))
	assert.Error(t, err)
	// Once comm2 is stopped it must no longer be reachable.
	comm2.Stop()
	time.Sleep(time.Duration(1) * time.Second)
	assert.Error(t, comm1.Probe(remotePeer(port2)))
	_, err = comm1.Handshake(remotePeer(port2))
	assert.Error(t, err)
	// Start a new comm2. The deferred Stop below is bound to this new
	// instance, because the method value is evaluated at the defer statement.
	comm2, port2 = newCommInstance(t, naiveSec)
	defer comm2.Stop()
	time.Sleep(time.Duration(1) * time.Second)
	// The new instance must be reachable in both directions.
	assert.NoError(t, comm2.Probe(remotePeer(port1)))
	_, err = comm2.Handshake(remotePeer(port1))
	assert.NoError(t, err)
	assert.NoError(t, comm1.Probe(remotePeer(port2)))
	_, err = comm1.Handshake(remotePeer(port2))
	assert.NoError(t, err)
	// Now try a deep probe with an expected PKI-ID that doesn't match
	// (flip the first byte of the PKI-ID remotePeer would normally use).
	wrongRemotePeer := remotePeer(port2)
	if wrongRemotePeer.PKIID[0] == 0 {
		wrongRemotePeer.PKIID[0] = 1
	} else {
		wrongRemotePeer.PKIID[0] = 0
	}
	_, err = comm1.Handshake(wrongRemotePeer)
	assert.Error(t, err)
	// Try a deep probe with a nil PKI-ID: the handshake is expected to
	// succeed, and the returned identity is expected to equal the endpoint
	// string itself.
	endpoint := fmt.Sprintf("127.0.0.1:%d", port2)
	id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint})
	assert.NoError(t, err)
	assert.Equal(t, api.PeerIdentityType(endpoint), id)
}
   904  
   905  func TestPresumedDead(t *testing.T) {
   906  	t.Parallel()
   907  	comm1, _ := newCommInstance(t, naiveSec)
   908  	comm2, port2 := newCommInstance(t, naiveSec)
   909  
   910  	wg := sync.WaitGroup{}
   911  	wg.Add(1)
   912  	go func() {
   913  		wg.Wait()
   914  		comm1.Send(createGossipMsg(), remotePeer(port2))
   915  	}()
   916  
   917  	ticker := time.NewTicker(time.Duration(10) * time.Second)
   918  	acceptCh := comm2.Accept(acceptAll)
   919  	wg.Done()
   920  	select {
   921  	case <-acceptCh:
   922  		ticker.Stop()
   923  	case <-ticker.C:
   924  		assert.Fail(t, "Didn't get first message")
   925  	}
   926  
   927  	comm2.Stop()
   928  	go func() {
   929  		for i := 0; i < 5; i++ {
   930  			comm1.Send(createGossipMsg(), remotePeer(port2))
   931  			time.Sleep(time.Millisecond * 200)
   932  		}
   933  	}()
   934  
   935  	ticker = time.NewTicker(time.Second * time.Duration(3))
   936  	select {
   937  	case <-ticker.C:
   938  		assert.Fail(t, "Didn't get a presumed dead message within a timely manner")
   939  		break
   940  	case <-comm1.PresumedDead():
   941  		ticker.Stop()
   942  		break
   943  	}
   944  }
   945  
   946  func TestReadFromStream(t *testing.T) {
   947  	stream := &gmocks.MockStream{}
   948  	stream.On("CloseSend").Return(nil)
   949  	stream.On("Recv").Return(&proto.Envelope{Payload: []byte{1}}, nil).Once()
   950  	stream.On("Recv").Return(nil, errors.New("stream closed")).Once()
   951  
   952  	conn := newConnection(nil, nil, stream, disabledMetrics, ConnConfig{1, 1})
   953  	conn.logger = flogging.MustGetLogger("test")
   954  
   955  	errChan := make(chan error, 2)
   956  	msgChan := make(chan *protoext.SignedGossipMessage, 1)
   957  	var wg sync.WaitGroup
   958  	wg.Add(1)
   959  	go func() {
   960  		defer wg.Done()
   961  		conn.readFromStream(errChan, msgChan)
   962  	}()
   963  
   964  	select {
   965  	case <-msgChan:
   966  		assert.Fail(t, "malformed message shouldn't have been received")
   967  	case <-time.After(time.Millisecond * 100):
   968  		assert.Len(t, errChan, 1)
   969  	}
   970  
   971  	conn.close()
   972  	wg.Wait()
   973  }
   974  
   975  func TestSendBadEnvelope(t *testing.T) {
   976  	comm1, port := newCommInstance(t, naiveSec)
   977  	defer comm1.Stop()
   978  
   979  	stream, err := establishSession(t, port)
   980  	assert.NoError(t, err)
   981  
   982  	inc := comm1.Accept(acceptAll)
   983  
   984  	goodMsg := createGossipMsg()
   985  	err = stream.Send(goodMsg.Envelope)
   986  	assert.NoError(t, err)
   987  
   988  	select {
   989  	case goodMsgReceived := <-inc:
   990  		assert.Equal(t, goodMsg.Envelope.Payload, goodMsgReceived.GetSourceEnvelope().Payload)
   991  	case <-time.After(time.Minute):
   992  		assert.Fail(t, "Didn't receive message within a timely manner")
   993  		return
   994  	}
   995  
   996  	// Next, we corrupt a message and send it until the stream is closed forcefully from the remote peer
   997  	start := time.Now()
   998  	for {
   999  		badMsg := createGossipMsg()
  1000  		badMsg.Envelope.Payload = []byte{1}
  1001  		err = stream.Send(badMsg.Envelope)
  1002  		if err != nil {
  1003  			assert.Equal(t, io.EOF, err)
  1004  			break
  1005  		}
  1006  		if time.Now().After(start.Add(time.Second * 30)) {
  1007  			assert.Fail(t, "Didn't close stream within a timely manner")
  1008  			return
  1009  		}
  1010  	}
  1011  }
  1012  
  1013  func establishSession(t *testing.T, port int) (proto.Gossip_GossipStreamClient, error) {
  1014  	cert := GenerateCertificatesOrPanic()
  1015  	secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
  1016  		InsecureSkipVerify: true,
  1017  		Certificates:       []tls.Certificate{cert},
  1018  	}))
  1019  
  1020  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1021  	defer cancel()
  1022  
  1023  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1024  	conn, err := grpc.DialContext(ctx, endpoint, secureOpts, grpc.WithBlock())
  1025  	assert.NoError(t, err, "%v", err)
  1026  	if err != nil {
  1027  		return nil, err
  1028  	}
  1029  	cl := proto.NewGossipClient(conn)
  1030  	stream, err := cl.GossipStream(context.Background())
  1031  	assert.NoError(t, err, "%v", err)
  1032  	if err != nil {
  1033  		return nil, err
  1034  	}
  1035  
  1036  	clientCertHash := certHashFromRawCert(cert.Certificate[0])
  1037  	pkiID := common.PKIidType([]byte{1, 2, 3})
  1038  	c := &commImpl{}
  1039  	assert.NoError(t, err, "%v", err)
  1040  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte{1, 2, 3}, func(msg []byte) ([]byte, error) {
  1041  		mac := hmac.New(sha256.New, hmacKey)
  1042  		mac.Write(msg)
  1043  		return mac.Sum(nil), nil
  1044  	}, false)
  1045  	// Send your own connection message
  1046  	stream.Send(msg.Envelope)
  1047  	// Wait for connection message from the other side
  1048  	envelope, err := stream.Recv()
  1049  	if err != nil {
  1050  		return nil, err
  1051  	}
  1052  	assert.NotNil(t, envelope)
  1053  	return stream, nil
  1054  }
  1055  
  1056  func createGossipMsg() *protoext.SignedGossipMessage {
  1057  	msg, _ := protoext.NoopSign(&proto.GossipMessage{
  1058  		Tag:   proto.GossipMessage_EMPTY,
  1059  		Nonce: uint64(rand.Int()),
  1060  		Content: &proto.GossipMessage_DataMsg{
  1061  			DataMsg: &proto.DataMessage{},
  1062  		},
  1063  	})
  1064  	return msg
  1065  }
  1066  
  1067  func remotePeer(port int) *RemotePeer {
  1068  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1069  	return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}
  1070  }
  1071  
  1072  func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) {
  1073  	c := 0
  1074  	waiting := true
  1075  	ticker := time.NewTicker(time.Duration(10) * time.Second)
  1076  	for waiting {
  1077  		select {
  1078  		case <-msgChan:
  1079  			c++
  1080  			if c == count {
  1081  				waiting = false
  1082  			}
  1083  		case <-ticker.C:
  1084  			waiting = false
  1085  		}
  1086  	}
  1087  	assert.Equal(t, count, c, errMsg)
  1088  }
  1089  
  1090  func TestConcurrentCloseSend(t *testing.T) {
  1091  	t.Parallel()
  1092  	var stopping int32
  1093  
  1094  	comm1, _ := newCommInstance(t, naiveSec)
  1095  	comm2, port2 := newCommInstance(t, naiveSec)
  1096  	m := comm2.Accept(acceptAll)
  1097  	comm1.Send(createGossipMsg(), remotePeer(port2))
  1098  	<-m
  1099  	ready := make(chan struct{})
  1100  	done := make(chan struct{})
  1101  	go func() {
  1102  		defer close(done)
  1103  
  1104  		comm1.Send(createGossipMsg(), remotePeer(port2))
  1105  		close(ready)
  1106  
  1107  		for atomic.LoadInt32(&stopping) == int32(0) {
  1108  			comm1.Send(createGossipMsg(), remotePeer(port2))
  1109  		}
  1110  	}()
  1111  	<-ready
  1112  	comm2.Stop()
  1113  	atomic.StoreInt32(&stopping, int32(1))
  1114  	<-done
  1115  }