github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/gossip/comm/comm_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha256"
    14  	"crypto/tls"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math/rand"
    19  	"net"
    20  	"strconv"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/hechain20/hechain/bccsp/factory"
    27  	"github.com/hechain20/hechain/common/flogging"
    28  	"github.com/hechain20/hechain/common/metrics/disabled"
    29  	"github.com/hechain20/hechain/gossip/api"
    30  	"github.com/hechain20/hechain/gossip/api/mocks"
    31  	gmocks "github.com/hechain20/hechain/gossip/comm/mocks"
    32  	"github.com/hechain20/hechain/gossip/common"
    33  	"github.com/hechain20/hechain/gossip/identity"
    34  	"github.com/hechain20/hechain/gossip/metrics"
    35  	"github.com/hechain20/hechain/gossip/protoext"
    36  	"github.com/hechain20/hechain/gossip/util"
    37  	"github.com/hechain20/hechain/internal/pkg/comm"
    38  	cb "github.com/hyperledger/fabric-protos-go/common"
    39  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    40  	"github.com/stretchr/testify/mock"
    41  	"github.com/stretchr/testify/require"
    42  	"google.golang.org/grpc"
    43  	"google.golang.org/grpc/credentials"
    44  )
    45  
// init prepares package-wide test state: it sets up test logging, seeds
// math/rand from the current time, calls factory.InitFactories with nil, and
// registers a catch-all OrgByPeerIdentity expectation on the shared naiveSec
// mock that returns an empty api.OrgIdentityType.
func init() {
	util.SetupTestLogging()
	rand.Seed(time.Now().UnixNano())
	factory.InitFactories(nil)
	naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(api.OrgIdentityType{})
}
    52  
// testCommConfig is the CommConfig handed to every comm instance created by
// the helpers in this file. It uses a 300ms dial timeout and the package
// defaults for the connection timeout and the send/receive buffer sizes.
var testCommConfig = CommConfig{
	DialTimeout:  300 * time.Millisecond,
	ConnTimeout:  DefConnTimeout,
	RecvBuffSize: DefRecvBuffSize,
	SendBuffSize: DefSendBuffSize,
}
    59  
// acceptAll is a message predicate that matches every message; it is passed
// to Comm.Accept by tests that want to see all incoming traffic.
func acceptAll(msg interface{}) bool {
	return true
}
    63  
// noopPurgeIdentity is a purge callback that does nothing; it is passed to
// identity.NewIdentityMapper by the tests in this file.
var noopPurgeIdentity = func(_ common.PKIidType, _ api.PeerIdentityType) {
}
    66  
var (
	// naiveSec is the shared mock security provider used by most tests.
	naiveSec = &naiveSecProvider{}
	// hmacKey is the fixed HMAC key used to "sign" and "verify" test messages.
	hmacKey = []byte{0, 0, 0}
	// disabledMetrics holds comm metrics backed by the disabled metrics provider.
	disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics
)
    72  
// naiveSecProvider is a test security provider. It embeds the generated
// mocks.SecurityAdvisor so OrgByPeerIdentity can be stubbed per test, and its
// remaining methods (defined below) are fixed, permissive implementations.
type naiveSecProvider struct {
	mocks.SecurityAdvisor
}
    76  
    77  func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType {
    78  	return nsp.SecurityAdvisor.Called(identity).Get(0).(api.OrgIdentityType)
    79  }
    80  
    81  func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) {
    82  	return time.Now().Add(time.Hour), nil
    83  }
    84  
// ValidateIdentity accepts every identity: it always returns nil.
func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error {
	return nil
}
    88  
    89  // GetPKIidOfCert returns the PKI-ID of a peer's identity
    90  func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType {
    91  	return common.PKIidType(peerIdentity)
    92  }
    93  
// VerifyBlock is a stub that performs no verification: it returns nil for
// every channel, sequence number and block.
func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error {
	return nil
}
    99  
   100  // Sign signs msg with this peer's signing key and outputs
   101  // the signature if no error occurred.
   102  func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) {
   103  	mac := hmac.New(sha256.New, hmacKey)
   104  	mac.Write(msg)
   105  	return mac.Sum(nil), nil
   106  }
   107  
   108  // Verify checks that signature is a valid signature of message under a peer's verification key.
   109  // If the verification succeeded, Verify returns nil meaning no error occurred.
   110  // If peerCert is nil, then the signature is verified against this peer's verification key.
   111  func (*naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error {
   112  	mac := hmac.New(sha256.New, hmacKey)
   113  	mac.Write(message)
   114  	expected := mac.Sum(nil)
   115  	if !bytes.Equal(signature, expected) {
   116  		return fmt.Errorf("Wrong certificate:%v, %v", signature, message)
   117  	}
   118  	return nil
   119  }
   120  
// VerifyByChannel is a stub for channel-scoped signature verification: it
// ignores all of its arguments and always returns nil.
func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error {
	return nil
}
   126  
   127  func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider,
   128  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   129  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   130  	_, portString, err := net.SplitHostPort(gRPCServer.Address())
   131  	require.NoError(t, err)
   132  
   133  	endpoint := fmt.Sprintf("127.0.0.1:%s", portString)
   134  	id := []byte(endpoint)
   135  	identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec)
   136  
   137  	commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts,
   138  		sec, commMetrics, testCommConfig, dialOpts...)
   139  	require.NoError(t, err)
   140  
   141  	go func() {
   142  		err := gRPCServer.Start()
   143  		require.NoError(t, err)
   144  	}()
   145  
   146  	return &commGRPC{commInst.(*commImpl), gRPCServer}
   147  }
   148  
// commGRPC bundles a comm implementation with the gRPC server it is
// registered on, so both can be shut down together (see Stop).
type commGRPC struct {
	*commImpl
	gRPCServer *comm.GRPCServer
}
   153  
// Stop shuts down, in order, the comm implementation, its identity mapper
// and the underlying gRPC server.
func (c *commGRPC) Stop() {
	c.commImpl.Stop()
	c.commImpl.idMapper.Stop()
	c.gRPCServer.Stop()
}
   159  
// newCommInstanceOnly is newCommInstanceOnlyWithMetrics with the package-level
// disabledMetrics supplied as the metrics argument.
func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider,
	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
	return newCommInstanceOnlyWithMetrics(t, disabledMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
}
   165  
   166  func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) {
   167  	port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer()
   168  	comm := newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   169  	return comm, port
   170  }
   171  
// msgMutator transforms a connection message before handshaker sends it;
// tests use it to tamper with the message and exercise negative paths.
type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage
   173  
   174  type tlsType int
   175  
   176  const (
   177  	none tlsType = iota
   178  	oneWayTLS
   179  	mutualTLS
   180  )
   181  
   182  func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage {
   183  	c := &commImpl{}
   184  	cert := GenerateCertificatesOrPanic()
   185  	tlsCfg := &tls.Config{
   186  		InsecureSkipVerify: true,
   187  	}
   188  	if connType == mutualTLS {
   189  		tlsCfg.Certificates = []tls.Certificate{cert}
   190  	}
   191  	ta := credentials.NewTLS(tlsCfg)
   192  	secureOpts := grpc.WithTransportCredentials(ta)
   193  	if connType == none {
   194  		secureOpts = grpc.WithInsecure()
   195  	}
   196  	acceptChan := comm.Accept(acceptAll)
   197  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   198  	defer cancel()
   199  	target := fmt.Sprintf("127.0.0.1:%d", port)
   200  	conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock())
   201  	require.NoError(t, err, "%v", err)
   202  	if err != nil {
   203  		return nil
   204  	}
   205  	cl := proto.NewGossipClient(conn)
   206  	stream, err := cl.GossipStream(context.Background())
   207  	require.NoError(t, err, "%v", err)
   208  	if err != nil {
   209  		return nil
   210  	}
   211  
   212  	var clientCertHash []byte
   213  	if len(tlsCfg.Certificates) > 0 {
   214  		clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   215  	}
   216  
   217  	pkiID := common.PKIidType(endpoint)
   218  	require.NoError(t, err, "%v", err)
   219  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) {
   220  		mac := hmac.New(sha256.New, hmacKey)
   221  		mac.Write(msg)
   222  		return mac.Sum(nil), nil
   223  	}, false)
   224  	// Mutate connection message to test negative paths
   225  	msg = connMutator(msg)
   226  	// Send your own connection message
   227  	stream.Send(msg.Envelope)
   228  	// Wait for connection message from the other side
   229  	envelope, err := stream.Recv()
   230  	if err != nil {
   231  		return acceptChan
   232  	}
   233  	require.NoError(t, err, "%v", err)
   234  	msg, err = protoext.EnvelopeToGossipMessage(envelope)
   235  	require.NoError(t, err, "%v", err)
   236  	require.Equal(t, []byte(target), msg.GetConn().PkiId)
   237  	require.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash)
   238  	msg2Send := createGossipMsg()
   239  	nonce := uint64(rand.Int())
   240  	msg2Send.Nonce = nonce
   241  	go stream.Send(msg2Send.Envelope)
   242  	return acceptChan
   243  }
   244  
   245  func TestMutualParallelSendWithAck(t *testing.T) {
   246  	// This test tests concurrent and parallel sending of many (1000) messages
   247  	// from 2 instances to one another at the same time.
   248  
   249  	msgNum := 1000
   250  
   251  	comm1, port1 := newCommInstance(t, naiveSec)
   252  	comm2, port2 := newCommInstance(t, naiveSec)
   253  	defer comm1.Stop()
   254  	defer comm2.Stop()
   255  
   256  	acceptData := func(o interface{}) bool {
   257  		m := o.(protoext.ReceivedMessage).GetGossipMessage()
   258  		return protoext.IsDataMsg(m.GossipMessage)
   259  	}
   260  
   261  	inc1 := comm1.Accept(acceptData)
   262  	inc2 := comm2.Accept(acceptData)
   263  
   264  	// Send a message from comm1 to comm2, to make the instances establish a preliminary connection
   265  	comm1.Send(createGossipMsg(), remotePeer(port2))
   266  	// Wait for the message to be received in comm2
   267  	<-inc2
   268  
   269  	for i := 0; i < msgNum; i++ {
   270  		go comm1.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port2))
   271  	}
   272  
   273  	for i := 0; i < msgNum; i++ {
   274  		go comm2.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port1))
   275  	}
   276  
   277  	go func() {
   278  		for i := 0; i < msgNum; i++ {
   279  			<-inc1
   280  		}
   281  	}()
   282  
   283  	for i := 0; i < msgNum; i++ {
   284  		<-inc2
   285  	}
   286  }
   287  
   288  func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) {
   289  	ll, err := net.Listen("tcp", "127.0.0.1:0")
   290  	require.NoError(t, err)
   291  	endpoint = ll.Addr().String()
   292  	_, portS, err := net.SplitHostPort(endpoint)
   293  	require.NoError(t, err)
   294  	portInt, err := strconv.Atoi(portS)
   295  	require.NoError(t, err)
   296  	return portInt, endpoint, ll
   297  }
   298  
   299  func TestHandshake(t *testing.T) {
   300  	signer := func(msg []byte) ([]byte, error) {
   301  		mac := hmac.New(sha256.New, hmacKey)
   302  		mac.Write(msg)
   303  		return mac.Sum(nil), nil
   304  	}
   305  	mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   306  		return msg
   307  	}
   308  	assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) {
   309  		expectedPKIID := common.PKIidType(endpoint)
   310  		require.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID)
   311  		require.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity)
   312  		require.NotNil(t, msg.GetConnectionInfo().Auth)
   313  		sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   314  		require.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   315  	}
   316  
   317  	// Positive path 1 - check authentication without TLS
   318  	port, endpoint, ll := getAvailablePort(t)
   319  	s := grpc.NewServer()
   320  	id := []byte(endpoint)
   321  	idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec)
   322  	inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption {
   323  		return []grpc.DialOption{grpc.WithInsecure()}
   324  	}, naiveSec, disabledMetrics, testCommConfig)
   325  	go s.Serve(ll)
   326  	require.NoError(t, err)
   327  	var msg protoext.ReceivedMessage
   328  
   329  	_, tempEndpoint, tempL := getAvailablePort(t)
   330  	acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none)
   331  	select {
   332  	case <-time.After(time.Duration(time.Second * 4)):
   333  		require.FailNow(t, "Didn't receive a message, seems like handshake failed")
   334  	case msg = <-acceptChan:
   335  	}
   336  	require.Equal(t, common.PKIidType(tempEndpoint), msg.GetConnectionInfo().ID)
   337  	require.Equal(t, api.PeerIdentityType(tempEndpoint), msg.GetConnectionInfo().Identity)
   338  	sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
   339  	require.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
   340  
   341  	inst.Stop()
   342  	s.Stop()
   343  	ll.Close()
   344  	tempL.Close()
   345  	time.Sleep(time.Second)
   346  
   347  	comm, port := newCommInstance(t, naiveSec)
   348  	defer comm.Stop()
   349  	// Positive path 2: initiating peer sends its own certificate
   350  	_, tempEndpoint, tempL = getAvailablePort(t)
   351  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   352  
   353  	select {
   354  	case <-time.After(time.Second * 2):
   355  		require.FailNow(t, "Didn't receive a message, seems like handshake failed")
   356  	case msg = <-acceptChan:
   357  	}
   358  	assertPositivePath(msg, tempEndpoint)
   359  	tempL.Close()
   360  
   361  	// Negative path: initiating peer doesn't send its own certificate
   362  	_, tempEndpoint, tempL = getAvailablePort(t)
   363  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, oneWayTLS)
   364  	time.Sleep(time.Second)
   365  	require.Equal(t, 0, len(acceptChan))
   366  	tempL.Close()
   367  
   368  	// Negative path, signature is wrong
   369  	_, tempEndpoint, tempL = getAvailablePort(t)
   370  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   371  		msg.Signature = append(msg.Signature, 0)
   372  		return msg
   373  	}
   374  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   375  	time.Sleep(time.Second)
   376  	require.Equal(t, 0, len(acceptChan))
   377  	tempL.Close()
   378  
   379  	// Negative path, the PKIid doesn't match the identity
   380  	_, tempEndpoint, tempL = getAvailablePort(t)
   381  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   382  		msg.GetConn().PkiId = []byte(tempEndpoint)
   383  		// Sign the message again
   384  		msg.Sign(signer)
   385  		return msg
   386  	}
   387  	_, tempEndpoint2, tempL2 := getAvailablePort(t)
   388  	acceptChan = handshaker(port, tempEndpoint2, comm, t, mutator, mutualTLS)
   389  	time.Sleep(time.Second)
   390  	require.Equal(t, 0, len(acceptChan))
   391  	tempL.Close()
   392  	tempL2.Close()
   393  
   394  	// Negative path, the cert hash isn't what is expected
   395  	_, tempEndpoint, tempL = getAvailablePort(t)
   396  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   397  		msg.GetConn().TlsCertHash = append(msg.GetConn().TlsCertHash, 0)
   398  		msg.Sign(signer)
   399  		return msg
   400  	}
   401  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   402  	time.Sleep(time.Second)
   403  	require.Equal(t, 0, len(acceptChan))
   404  	tempL.Close()
   405  
   406  	// Negative path, no PKI-ID was sent
   407  	_, tempEndpoint, tempL = getAvailablePort(t)
   408  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   409  		msg.GetConn().PkiId = nil
   410  		msg.Sign(signer)
   411  		return msg
   412  	}
   413  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   414  	time.Sleep(time.Second)
   415  	require.Equal(t, 0, len(acceptChan))
   416  	tempL.Close()
   417  
   418  	// Negative path, connection message is of a different type
   419  	_, tempEndpoint, tempL = getAvailablePort(t)
   420  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   421  		msg.Content = &proto.GossipMessage_Empty{
   422  			Empty: &proto.Empty{},
   423  		}
   424  		msg.Sign(signer)
   425  		return msg
   426  	}
   427  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   428  	time.Sleep(time.Second)
   429  	require.Equal(t, 0, len(acceptChan))
   430  	tempL.Close()
   431  
   432  	// Negative path, the peer didn't respond to the handshake in due time
   433  	_, tempEndpoint, tempL = getAvailablePort(t)
   434  	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
   435  		time.Sleep(time.Second * 5)
   436  		return msg
   437  	}
   438  	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
   439  	time.Sleep(time.Second)
   440  	require.Equal(t, 0, len(acceptChan))
   441  	tempL.Close()
   442  }
   443  
   444  func TestConnectUnexpectedPeer(t *testing.T) {
   445  	// Scenarios: In both scenarios, comm1 connects to comm2 or comm3.
   446  	// and expects to see a PKI-ID which is equal to comm4's PKI-ID.
   447  	// The connection attempt would succeed or fail based on whether comm2 or comm3
   448  	// are in the same org as comm4
   449  
   450  	identityByPort := func(port int) api.PeerIdentityType {
   451  		return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port))
   452  	}
   453  
   454  	customNaiveSec := &naiveSecProvider{}
   455  
   456  	comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer()
   457  	comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer()
   458  	comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer()
   459  	comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer()
   460  
   461  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm1Port)).Return(api.OrgIdentityType("O"))
   462  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm2Port)).Return(api.OrgIdentityType("A"))
   463  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm3Port)).Return(api.OrgIdentityType("B"))
   464  	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm4Port)).Return(api.OrgIdentityType("A"))
   465  
   466  	comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...)
   467  	comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...)
   468  	comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...)
   469  	comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...)
   470  
   471  	defer comm1.Stop()
   472  	defer comm2.Stop()
   473  	defer comm3.Stop()
   474  	defer comm4.Stop()
   475  
   476  	messagesForComm1 := comm1.Accept(acceptAll)
   477  	messagesForComm2 := comm2.Accept(acceptAll)
   478  	messagesForComm3 := comm3.Accept(acceptAll)
   479  
   480  	// Have comm4 send a message to comm1
   481  	// in order for comm1 to know comm4
   482  	comm4.Send(createGossipMsg(), remotePeer(comm1Port))
   483  	<-messagesForComm1
   484  	// Close the connection with comm4
   485  	comm1.CloseConn(remotePeer(comm4Port))
   486  	// At this point, comm1 knows comm4's identity and organization
   487  
   488  	t.Run("Same organization", func(t *testing.T) {
   489  		unexpectedRemotePeer := remotePeer(comm2Port)
   490  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   491  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   492  		select {
   493  		case <-messagesForComm2:
   494  		case <-time.After(time.Second * 5):
   495  			require.Fail(t, "Didn't receive a message within a timely manner")
   496  			util.PrintStackTrace()
   497  		}
   498  	})
   499  
   500  	t.Run("Unexpected organization", func(t *testing.T) {
   501  		unexpectedRemotePeer := remotePeer(comm3Port)
   502  		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
   503  		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
   504  		select {
   505  		case <-messagesForComm3:
   506  			require.Fail(t, "Message shouldn't have been received")
   507  		case <-time.After(time.Second * 5):
   508  		}
   509  	})
   510  }
   511  
   512  func TestGetConnectionInfo(t *testing.T) {
   513  	comm1, port1 := newCommInstance(t, naiveSec)
   514  	comm2, _ := newCommInstance(t, naiveSec)
   515  	defer comm1.Stop()
   516  	defer comm2.Stop()
   517  	m1 := comm1.Accept(acceptAll)
   518  	comm2.Send(createGossipMsg(), remotePeer(port1))
   519  	select {
   520  	case <-time.After(time.Second * 10):
   521  		t.Fatal("Didn't receive a message in time")
   522  	case msg := <-m1:
   523  		require.Equal(t, comm2.GetPKIid(), msg.GetConnectionInfo().ID)
   524  		require.NotNil(t, msg.GetSourceEnvelope())
   525  	}
   526  }
   527  
   528  func TestCloseConn(t *testing.T) {
   529  	comm1, port1 := newCommInstance(t, naiveSec)
   530  	defer comm1.Stop()
   531  	acceptChan := comm1.Accept(acceptAll)
   532  
   533  	cert := GenerateCertificatesOrPanic()
   534  	tlsCfg := &tls.Config{
   535  		InsecureSkipVerify: true,
   536  		Certificates:       []tls.Certificate{cert},
   537  	}
   538  	ta := credentials.NewTLS(tlsCfg)
   539  
   540  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   541  	defer cancel()
   542  	target := fmt.Sprintf("127.0.0.1:%d", port1)
   543  	conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock())
   544  	require.NoError(t, err, "%v", err)
   545  	cl := proto.NewGossipClient(conn)
   546  	stream, err := cl.GossipStream(context.Background())
   547  	require.NoError(t, err, "%v", err)
   548  	c := &commImpl{}
   549  	tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   550  	connMsg, _ := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) {
   551  		mac := hmac.New(sha256.New, hmacKey)
   552  		mac.Write(msg)
   553  		return mac.Sum(nil), nil
   554  	}, false)
   555  	require.NoError(t, stream.Send(connMsg.Envelope))
   556  	stream.Send(createGossipMsg().Envelope)
   557  	select {
   558  	case <-acceptChan:
   559  	case <-time.After(time.Second):
   560  		require.Fail(t, "Didn't receive a message within a timely period")
   561  	}
   562  	comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")})
   563  	time.Sleep(time.Second * 10)
   564  	gotErr := false
   565  	msg2Send := createGossipMsg()
   566  	msg2Send.GetDataMsg().Payload = &proto.Payload{
   567  		Data: make([]byte, 1024*1024),
   568  	}
   569  	protoext.NoopSign(msg2Send.GossipMessage)
   570  	for i := 0; i < DefRecvBuffSize; i++ {
   571  		err := stream.Send(msg2Send.Envelope)
   572  		if err != nil {
   573  			gotErr = true
   574  			break
   575  		}
   576  	}
   577  	require.True(t, gotErr, "Should have failed because connection is closed")
   578  }
   579  
   580  // TestCommSend makes sure that enough messages get through
   581  // eventually. Comm.Send() is both asynchronous and best-effort, so this test
   582  // case assumes some will fail, but that eventually enough messages will get
   583  // through that the test will end.
   584  func TestCommSend(t *testing.T) {
   585  	sendMessages := func(c Comm, peer *RemotePeer, stopChan <-chan struct{}) {
   586  		ticker := time.NewTicker(time.Millisecond)
   587  		defer ticker.Stop()
   588  		for {
   589  			emptyMsg := createGossipMsg()
   590  			select {
   591  			case <-stopChan:
   592  				return
   593  			case <-ticker.C:
   594  				c.Send(emptyMsg, peer)
   595  			}
   596  		}
   597  	}
   598  
   599  	comm1, port1 := newCommInstance(t, naiveSec)
   600  	comm2, port2 := newCommInstance(t, naiveSec)
   601  	defer comm1.Stop()
   602  	defer comm2.Stop()
   603  
   604  	// Create the receive channel before sending the messages
   605  	ch1 := comm1.Accept(acceptAll)
   606  	ch2 := comm2.Accept(acceptAll)
   607  
   608  	// control channels for background senders
   609  	stopch1 := make(chan struct{})
   610  	stopch2 := make(chan struct{})
   611  
   612  	go sendMessages(comm1, remotePeer(port2), stopch1)
   613  	go sendMessages(comm2, remotePeer(port1), stopch2)
   614  
   615  	c1received := 0
   616  	c2received := 0
   617  	// hopefully in some runs we'll fill both send and receive buffers and
   618  	// drop overflowing messages, but still finish, because the endless
   619  	// stream of messages inexorably gets through unless something is very
   620  	// broken.
   621  	totalMessagesReceived := (DefSendBuffSize + DefRecvBuffSize) * 2
   622  	timer := time.NewTimer(30 * time.Second)
   623  	defer timer.Stop()
   624  RECV:
   625  	for {
   626  		select {
   627  		case <-ch1:
   628  			c1received++
   629  			if c1received == totalMessagesReceived {
   630  				close(stopch2)
   631  			}
   632  		case <-ch2:
   633  			c2received++
   634  			if c2received == totalMessagesReceived {
   635  				close(stopch1)
   636  			}
   637  		case <-timer.C:
   638  			t.Fatalf("timed out waiting for messages to be received.\nc1 got %d messages\nc2 got %d messages", c1received, c2received)
   639  		default:
   640  			if c1received >= totalMessagesReceived && c2received >= totalMessagesReceived {
   641  				break RECV
   642  			}
   643  		}
   644  	}
   645  	t.Logf("c1 got %d messages\nc2 got %d messages", c1received, c2received)
   646  }
   647  
// nonResponsivePeer is a gossip server stand-in whose Ping handler stalls for
// a long time. It embeds the gRPC server it is registered on and records the
// port of the gRPC layer it was created from.
type nonResponsivePeer struct {
	*grpc.Server
	port int
}
   652  
   653  func newNonResponsivePeer(t *testing.T) *nonResponsivePeer {
   654  	port, gRPCServer, _, _, _ := util.CreateGRPCLayer()
   655  	nrp := &nonResponsivePeer{
   656  		Server: gRPCServer.Server(),
   657  		port:   port,
   658  	}
   659  	proto.RegisterGossipServer(gRPCServer.Server(), nrp)
   660  	return nrp
   661  }
   662  
   663  func (bp *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
   664  	time.Sleep(time.Second * 15)
   665  	return &proto.Empty{}, nil
   666  }
   667  
// GossipStream ignores the stream and returns nil immediately.
func (bp *nonResponsivePeer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
	return nil
}
   671  
// stop stops the embedded gRPC server.
func (bp *nonResponsivePeer) stop() {
	bp.Server.Stop()
}
   675  
   676  func TestNonResponsivePing(t *testing.T) {
   677  	c, _ := newCommInstance(t, naiveSec)
   678  	defer c.Stop()
   679  	nonRespPeer := newNonResponsivePeer(t)
   680  	defer nonRespPeer.stop()
   681  	s := make(chan struct{})
   682  	go func() {
   683  		c.Probe(remotePeer(nonRespPeer.port))
   684  		s <- struct{}{}
   685  	}()
   686  	select {
   687  	case <-time.After(time.Second * 10):
   688  		require.Fail(t, "Request wasn't cancelled on time")
   689  	case <-s:
   690  	}
   691  }
   692  
   693  func TestResponses(t *testing.T) {
   694  	comm1, port1 := newCommInstance(t, naiveSec)
   695  	comm2, _ := newCommInstance(t, naiveSec)
   696  
   697  	defer comm1.Stop()
   698  	defer comm2.Stop()
   699  
   700  	wg := sync.WaitGroup{}
   701  
   702  	msg := createGossipMsg()
   703  	wg.Add(1)
   704  	go func() {
   705  		inChan := comm1.Accept(acceptAll)
   706  		wg.Done()
   707  		for m := range inChan {
   708  			reply := createGossipMsg()
   709  			reply.Nonce = m.GetGossipMessage().Nonce + 1
   710  			m.Respond(reply.GossipMessage)
   711  		}
   712  	}()
   713  	expectedNOnce := uint64(msg.Nonce + 1)
   714  	responsesFromComm1 := comm2.Accept(acceptAll)
   715  
   716  	ticker := time.NewTicker(10 * time.Second)
   717  	wg.Wait()
   718  	comm2.Send(msg, remotePeer(port1))
   719  
   720  	select {
   721  	case <-ticker.C:
   722  		require.Fail(t, "Haven't got response from comm1 within a timely manner")
   723  		break
   724  	case resp := <-responsesFromComm1:
   725  		ticker.Stop()
   726  		require.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce)
   727  		break
   728  	}
   729  }
   730  
   731  // TestAccept makes sure that accept filters work. The probability of the parity
   732  // of all nonces being 0 or 1 is very low.
   733  func TestAccept(t *testing.T) {
   734  	comm1, port1 := newCommInstance(t, naiveSec)
   735  	comm2, _ := newCommInstance(t, naiveSec)
   736  
   737  	evenNONCESelector := func(m interface{}) bool {
   738  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 == 0
   739  	}
   740  
   741  	oddNONCESelector := func(m interface{}) bool {
   742  		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 != 0
   743  	}
   744  
   745  	evenNONCES := comm1.Accept(evenNONCESelector)
   746  	oddNONCES := comm1.Accept(oddNONCESelector)
   747  
   748  	var evenResults []uint64
   749  	var oddResults []uint64
   750  
   751  	out := make(chan uint64)
   752  	sem := make(chan struct{})
   753  
   754  	readIntoSlice := func(a *[]uint64, ch <-chan protoext.ReceivedMessage) {
   755  		for m := range ch {
   756  			*a = append(*a, m.GetGossipMessage().Nonce)
   757  			select {
   758  			case out <- m.GetGossipMessage().Nonce:
   759  			default: // avoid blocking when we stop reading from out
   760  			}
   761  		}
   762  		sem <- struct{}{}
   763  	}
   764  
   765  	go readIntoSlice(&evenResults, evenNONCES)
   766  	go readIntoSlice(&oddResults, oddNONCES)
   767  
   768  	stopSend := make(chan struct{})
   769  	go func() {
   770  		for {
   771  			select {
   772  			case <-stopSend:
   773  				return
   774  			default:
   775  				comm2.Send(createGossipMsg(), remotePeer(port1))
   776  			}
   777  		}
   778  	}()
   779  
   780  	waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent")
   781  	close(stopSend)
   782  
   783  	comm1.Stop()
   784  	comm2.Stop()
   785  
   786  	<-sem
   787  	<-sem
   788  
   789  	t.Logf("%d even nonces received", len(evenResults))
   790  	t.Logf("%d  odd nonces received", len(oddResults))
   791  
   792  	require.NotEmpty(t, evenResults)
   793  	require.NotEmpty(t, oddResults)
   794  
   795  	remainderPredicate := func(a []uint64, rem uint64) {
   796  		for _, n := range a {
   797  			require.Equal(t, n%2, rem)
   798  		}
   799  	}
   800  
   801  	remainderPredicate(evenResults, 0)
   802  	remainderPredicate(oddResults, 1)
   803  }
   804  
   805  func TestReConnections(t *testing.T) {
   806  	comm1, port1 := newCommInstance(t, naiveSec)
   807  	comm2, port2 := newCommInstance(t, naiveSec)
   808  
   809  	reader := func(out chan uint64, in <-chan protoext.ReceivedMessage) {
   810  		for {
   811  			msg := <-in
   812  			if msg == nil {
   813  				return
   814  			}
   815  			out <- msg.GetGossipMessage().Nonce
   816  		}
   817  	}
   818  
   819  	out1 := make(chan uint64, 10)
   820  	out2 := make(chan uint64, 10)
   821  
   822  	go reader(out1, comm1.Accept(acceptAll))
   823  	go reader(out2, comm2.Accept(acceptAll))
   824  
   825  	// comm1 connects to comm2
   826  	comm1.Send(createGossipMsg(), remotePeer(port2))
   827  	waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner")
   828  	// comm2 sends to comm1
   829  	comm2.Send(createGossipMsg(), remotePeer(port1))
   830  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   831  	comm1.Stop()
   832  
   833  	comm1, port1 = newCommInstance(t, naiveSec)
   834  	out1 = make(chan uint64, 1)
   835  	go reader(out1, comm1.Accept(acceptAll))
   836  	comm2.Send(createGossipMsg(), remotePeer(port1))
   837  	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
   838  	comm1.Stop()
   839  	comm2.Stop()
   840  }
   841  
   842  func TestProbe(t *testing.T) {
   843  	comm1, port1 := newCommInstance(t, naiveSec)
   844  	defer comm1.Stop()
   845  	comm2, port2 := newCommInstance(t, naiveSec)
   846  	time.Sleep(time.Duration(1) * time.Second)
   847  	require.NoError(t, comm1.Probe(remotePeer(port2)))
   848  	_, err := comm1.Handshake(remotePeer(port2))
   849  	require.NoError(t, err)
   850  	tempPort, _, ll := getAvailablePort(t)
   851  	defer ll.Close()
   852  	require.Error(t, comm1.Probe(remotePeer(tempPort)))
   853  	_, err = comm1.Handshake(remotePeer(tempPort))
   854  	require.Error(t, err)
   855  	comm2.Stop()
   856  	time.Sleep(time.Duration(1) * time.Second)
   857  	require.Error(t, comm1.Probe(remotePeer(port2)))
   858  	_, err = comm1.Handshake(remotePeer(port2))
   859  	require.Error(t, err)
   860  	comm2, port2 = newCommInstance(t, naiveSec)
   861  	defer comm2.Stop()
   862  	time.Sleep(time.Duration(1) * time.Second)
   863  	require.NoError(t, comm2.Probe(remotePeer(port1)))
   864  	_, err = comm2.Handshake(remotePeer(port1))
   865  	require.NoError(t, err)
   866  	require.NoError(t, comm1.Probe(remotePeer(port2)))
   867  	_, err = comm1.Handshake(remotePeer(port2))
   868  	require.NoError(t, err)
   869  	// Now try a deep probe with an expected PKI-ID that doesn't match
   870  	wrongRemotePeer := remotePeer(port2)
   871  	if wrongRemotePeer.PKIID[0] == 0 {
   872  		wrongRemotePeer.PKIID[0] = 1
   873  	} else {
   874  		wrongRemotePeer.PKIID[0] = 0
   875  	}
   876  	_, err = comm1.Handshake(wrongRemotePeer)
   877  	require.Error(t, err)
   878  	// Try a deep probe with a nil PKI-ID
   879  	endpoint := fmt.Sprintf("127.0.0.1:%d", port2)
   880  	id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint})
   881  	require.NoError(t, err)
   882  	require.Equal(t, api.PeerIdentityType(endpoint), id)
   883  }
   884  
   885  func TestPresumedDead(t *testing.T) {
   886  	comm1, _ := newCommInstance(t, naiveSec)
   887  	comm2, port2 := newCommInstance(t, naiveSec)
   888  
   889  	wg := sync.WaitGroup{}
   890  	wg.Add(1)
   891  	go func() {
   892  		wg.Wait()
   893  		comm1.Send(createGossipMsg(), remotePeer(port2))
   894  	}()
   895  
   896  	ticker := time.NewTicker(time.Duration(10) * time.Second)
   897  	acceptCh := comm2.Accept(acceptAll)
   898  	wg.Done()
   899  	select {
   900  	case <-acceptCh:
   901  		ticker.Stop()
   902  	case <-ticker.C:
   903  		require.Fail(t, "Didn't get first message")
   904  	}
   905  
   906  	comm2.Stop()
   907  	go func() {
   908  		for i := 0; i < 5; i++ {
   909  			comm1.Send(createGossipMsg(), remotePeer(port2))
   910  			time.Sleep(time.Millisecond * 200)
   911  		}
   912  	}()
   913  
   914  	ticker = time.NewTicker(time.Second * time.Duration(3))
   915  	select {
   916  	case <-ticker.C:
   917  		require.Fail(t, "Didn't get a presumed dead message within a timely manner")
   918  		break
   919  	case <-comm1.PresumedDead():
   920  		ticker.Stop()
   921  		break
   922  	}
   923  }
   924  
   925  func TestReadFromStream(t *testing.T) {
   926  	stream := &gmocks.MockStream{}
   927  	stream.On("CloseSend").Return(nil)
   928  	stream.On("Recv").Return(&proto.Envelope{Payload: []byte{1}}, nil).Once()
   929  	stream.On("Recv").Return(nil, errors.New("stream closed")).Once()
   930  
   931  	conn := newConnection(nil, nil, stream, disabledMetrics, ConnConfig{1, 1})
   932  	conn.logger = flogging.MustGetLogger("test")
   933  
   934  	errChan := make(chan error, 2)
   935  	msgChan := make(chan *protoext.SignedGossipMessage, 1)
   936  	var wg sync.WaitGroup
   937  	wg.Add(1)
   938  	go func() {
   939  		defer wg.Done()
   940  		conn.readFromStream(errChan, msgChan)
   941  	}()
   942  
   943  	select {
   944  	case <-msgChan:
   945  		require.Fail(t, "malformed message shouldn't have been received")
   946  	case <-time.After(time.Millisecond * 100):
   947  		require.Len(t, errChan, 1)
   948  	}
   949  
   950  	conn.close()
   951  	wg.Wait()
   952  }
   953  
   954  func TestSendBadEnvelope(t *testing.T) {
   955  	comm1, port := newCommInstance(t, naiveSec)
   956  	defer comm1.Stop()
   957  
   958  	stream, err := establishSession(t, port)
   959  	require.NoError(t, err)
   960  
   961  	inc := comm1.Accept(acceptAll)
   962  
   963  	goodMsg := createGossipMsg()
   964  	err = stream.Send(goodMsg.Envelope)
   965  	require.NoError(t, err)
   966  
   967  	select {
   968  	case goodMsgReceived := <-inc:
   969  		require.Equal(t, goodMsg.Envelope.Payload, goodMsgReceived.GetSourceEnvelope().Payload)
   970  	case <-time.After(time.Minute):
   971  		require.Fail(t, "Didn't receive message within a timely manner")
   972  		return
   973  	}
   974  
   975  	// Next, we corrupt a message and send it until the stream is closed forcefully from the remote peer
   976  	start := time.Now()
   977  	for {
   978  		badMsg := createGossipMsg()
   979  		badMsg.Envelope.Payload = []byte{1}
   980  		err = stream.Send(badMsg.Envelope)
   981  		if err != nil {
   982  			require.Equal(t, io.EOF, err)
   983  			break
   984  		}
   985  		if time.Now().After(start.Add(time.Second * 30)) {
   986  			require.Fail(t, "Didn't close stream within a timely manner")
   987  			return
   988  		}
   989  	}
   990  }
   991  
   992  func establishSession(t *testing.T, port int) (proto.Gossip_GossipStreamClient, error) {
   993  	cert := GenerateCertificatesOrPanic()
   994  	secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
   995  		InsecureSkipVerify: true,
   996  		Certificates:       []tls.Certificate{cert},
   997  	}))
   998  
   999  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1000  	defer cancel()
  1001  
  1002  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1003  	conn, err := grpc.DialContext(ctx, endpoint, secureOpts, grpc.WithBlock())
  1004  	require.NoError(t, err, "%v", err)
  1005  	if err != nil {
  1006  		return nil, err
  1007  	}
  1008  	cl := proto.NewGossipClient(conn)
  1009  	stream, err := cl.GossipStream(context.Background())
  1010  	require.NoError(t, err, "%v", err)
  1011  	if err != nil {
  1012  		return nil, err
  1013  	}
  1014  
  1015  	clientCertHash := certHashFromRawCert(cert.Certificate[0])
  1016  	pkiID := common.PKIidType([]byte{1, 2, 3})
  1017  	c := &commImpl{}
  1018  	require.NoError(t, err, "%v", err)
  1019  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte{1, 2, 3}, func(msg []byte) ([]byte, error) {
  1020  		mac := hmac.New(sha256.New, hmacKey)
  1021  		mac.Write(msg)
  1022  		return mac.Sum(nil), nil
  1023  	}, false)
  1024  	// Send your own connection message
  1025  	stream.Send(msg.Envelope)
  1026  	// Wait for connection message from the other side
  1027  	envelope, err := stream.Recv()
  1028  	if err != nil {
  1029  		return nil, err
  1030  	}
  1031  	require.NotNil(t, envelope)
  1032  	return stream, nil
  1033  }
  1034  
  1035  func createGossipMsg() *protoext.SignedGossipMessage {
  1036  	msg, _ := protoext.NoopSign(&proto.GossipMessage{
  1037  		Tag:   proto.GossipMessage_EMPTY,
  1038  		Nonce: uint64(rand.Int()),
  1039  		Content: &proto.GossipMessage_DataMsg{
  1040  			DataMsg: &proto.DataMessage{},
  1041  		},
  1042  	})
  1043  	return msg
  1044  }
  1045  
  1046  func remotePeer(port int) *RemotePeer {
  1047  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1048  	return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}
  1049  }
  1050  
  1051  func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) {
  1052  	c := 0
  1053  	waiting := true
  1054  	ticker := time.NewTicker(time.Duration(10) * time.Second)
  1055  	for waiting {
  1056  		select {
  1057  		case <-msgChan:
  1058  			c++
  1059  			if c == count {
  1060  				waiting = false
  1061  			}
  1062  		case <-ticker.C:
  1063  			waiting = false
  1064  		}
  1065  	}
  1066  	require.Equal(t, count, c, errMsg)
  1067  }
  1068  
  1069  func TestConcurrentCloseSend(t *testing.T) {
  1070  	var stopping int32
  1071  
  1072  	comm1, _ := newCommInstance(t, naiveSec)
  1073  	comm2, port2 := newCommInstance(t, naiveSec)
  1074  	m := comm2.Accept(acceptAll)
  1075  	comm1.Send(createGossipMsg(), remotePeer(port2))
  1076  	<-m
  1077  	ready := make(chan struct{})
  1078  	done := make(chan struct{})
  1079  	go func() {
  1080  		defer close(done)
  1081  
  1082  		comm1.Send(createGossipMsg(), remotePeer(port2))
  1083  		close(ready)
  1084  
  1085  		for atomic.LoadInt32(&stopping) == int32(0) {
  1086  			comm1.Send(createGossipMsg(), remotePeer(port2))
  1087  		}
  1088  	}()
  1089  	<-ready
  1090  	comm2.Stop()
  1091  	atomic.StoreInt32(&stopping, int32(1))
  1092  	<-done
  1093  }