github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/gossip/comm/comm_test.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/hmac"
    13  	"crypto/sha256"
    14  	"crypto/tls"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math/rand"
    19  	"net"
    20  	"strconv"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	cb "github.com/hyperledger/fabric-protos-go/common"
    27  	proto "github.com/hyperledger/fabric-protos-go/gossip"
    28  	"github.com/osdi23p228/fabric/bccsp/factory"
    29  	"github.com/osdi23p228/fabric/common/flogging"
    30  	"github.com/osdi23p228/fabric/common/metrics/disabled"
    31  	"github.com/osdi23p228/fabric/gossip/api"
    32  	"github.com/osdi23p228/fabric/gossip/api/mocks"
    33  	gmocks "github.com/osdi23p228/fabric/gossip/comm/mocks"
    34  	"github.com/osdi23p228/fabric/gossip/common"
    35  	"github.com/osdi23p228/fabric/gossip/identity"
    36  	"github.com/osdi23p228/fabric/gossip/metrics"
    37  	"github.com/osdi23p228/fabric/gossip/protoext"
    38  	"github.com/osdi23p228/fabric/gossip/util"
    39  	"github.com/osdi23p228/fabric/internal/pkg/comm"
    40  	"github.com/stretchr/testify/assert"
    41  	"github.com/stretchr/testify/mock"
    42  	"google.golang.org/grpc"
    43  	"google.golang.org/grpc/credentials"
    44  )
    45  
// init prepares process-wide test state before any test runs:
//   - test logging via util.SetupTestLogging,
//   - a time-based seed for math/rand (handshaker draws message nonces from it),
//   - the BCCSP factories, initialized with a nil config (presumably the
//     defaults — confirm against factory.InitFactories),
//   - a catch-all expectation on the shared naiveSec mock so that
//     OrgByPeerIdentity returns an empty organization for any identity.
func init() {
	util.SetupTestLogging()
	rand.Seed(time.Now().UnixNano())
	factory.InitFactories(nil)
	// Tests needing per-identity organizations build their own
	// naiveSecProvider instead of reprogramming this shared one.
	naiveSec.On("OrgByPeerIdentity", mock.Anything).Return(api.OrgIdentityType{})
}
    52  
    53  var testCommConfig = CommConfig{
    54  	DialTimeout:  300 * time.Millisecond,
    55  	ConnTimeout:  DefConnTimeout,
    56  	RecvBuffSize: DefRecvBuffSize,
    57  	SendBuffSize: DefSendBuffSize,
    58  }
    59  
    60  func acceptAll(msg interface{}) bool {
    61  	return true
    62  }
    63  
    64  var noopPurgeIdentity = func(_ common.PKIidType, _ api.PeerIdentityType) {
    65  
    66  }
    67  
    68  var (
    69  	naiveSec        = &naiveSecProvider{}
    70  	hmacKey         = []byte{0, 0, 0}
    71  	disabledMetrics = metrics.NewGossipMetrics(&disabled.Provider{}).CommMetrics
    72  )
    73  
    74  type naiveSecProvider struct {
    75  	mocks.SecurityAdvisor
    76  }
    77  
    78  func (nsp *naiveSecProvider) OrgByPeerIdentity(identity api.PeerIdentityType) api.OrgIdentityType {
    79  	return nsp.SecurityAdvisor.Called(identity).Get(0).(api.OrgIdentityType)
    80  }
    81  
    82  func (*naiveSecProvider) Expiration(peerIdentity api.PeerIdentityType) (time.Time, error) {
    83  	return time.Now().Add(time.Hour), nil
    84  }
    85  
    86  func (*naiveSecProvider) ValidateIdentity(peerIdentity api.PeerIdentityType) error {
    87  	return nil
    88  }
    89  
    90  // GetPKIidOfCert returns the PKI-ID of a peer's identity
    91  func (*naiveSecProvider) GetPKIidOfCert(peerIdentity api.PeerIdentityType) common.PKIidType {
    92  	return common.PKIidType(peerIdentity)
    93  }
    94  
    95  // VerifyBlock returns nil if the block is properly signed,
    96  // else returns error
    97  func (*naiveSecProvider) VerifyBlock(channelID common.ChannelID, seqNum uint64, signedBlock *cb.Block) error {
    98  	return nil
    99  }
   100  
   101  // Sign signs msg with this peer's signing key and outputs
   102  // the signature if no error occurred.
   103  func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) {
   104  	mac := hmac.New(sha256.New, hmacKey)
   105  	mac.Write(msg)
   106  	return mac.Sum(nil), nil
   107  }
   108  
   109  // Verify checks that signature is a valid signature of message under a peer's verification key.
   110  // If the verification succeeded, Verify returns nil meaning no error occurred.
   111  // If peerCert is nil, then the signature is verified against this peer's verification key.
   112  func (*naiveSecProvider) Verify(peerIdentity api.PeerIdentityType, signature, message []byte) error {
   113  	mac := hmac.New(sha256.New, hmacKey)
   114  	mac.Write(message)
   115  	expected := mac.Sum(nil)
   116  	if !bytes.Equal(signature, expected) {
   117  		return fmt.Errorf("Wrong certificate:%v, %v", signature, message)
   118  	}
   119  	return nil
   120  }
   121  
   122  // VerifyByChannel verifies a peer's signature on a message in the context
   123  // of a specific channel
   124  func (*naiveSecProvider) VerifyByChannel(_ common.ChannelID, _ api.PeerIdentityType, _, _ []byte) error {
   125  	return nil
   126  }
   127  
   128  func newCommInstanceOnlyWithMetrics(t *testing.T, commMetrics *metrics.CommMetrics, sec *naiveSecProvider,
   129  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   130  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   131  
   132  	_, portString, err := net.SplitHostPort(gRPCServer.Address())
   133  	assert.NoError(t, err)
   134  
   135  	endpoint := fmt.Sprintf("127.0.0.1:%s", portString)
   136  	id := []byte(endpoint)
   137  	identityMapper := identity.NewIdentityMapper(sec, id, noopPurgeIdentity, sec)
   138  
   139  	commInst, err := NewCommInstance(gRPCServer.Server(), certs, identityMapper, id, secureDialOpts,
   140  		sec, commMetrics, testCommConfig, dialOpts...)
   141  	assert.NoError(t, err)
   142  
   143  	go func() {
   144  		err := gRPCServer.Start()
   145  		assert.NoError(t, err)
   146  	}()
   147  
   148  	return &commGRPC{commInst.(*commImpl), gRPCServer}
   149  }
   150  
   151  type commGRPC struct {
   152  	*commImpl
   153  	gRPCServer *comm.GRPCServer
   154  }
   155  
   156  func (c *commGRPC) Stop() {
   157  	c.commImpl.Stop()
   158  	c.commImpl.idMapper.Stop()
   159  	c.gRPCServer.Stop()
   160  }
   161  
   162  func newCommInstanceOnly(t *testing.T, sec *naiveSecProvider,
   163  	gRPCServer *comm.GRPCServer, certs *common.TLSCertificates,
   164  	secureDialOpts api.PeerSecureDialOpts, dialOpts ...grpc.DialOption) Comm {
   165  	return newCommInstanceOnlyWithMetrics(t, disabledMetrics, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   166  }
   167  
   168  func newCommInstance(t *testing.T, sec *naiveSecProvider) (c Comm, port int) {
   169  	port, gRPCServer, certs, secureDialOpts, dialOpts := util.CreateGRPCLayer()
   170  	comm := newCommInstanceOnly(t, sec, gRPCServer, certs, secureDialOpts, dialOpts...)
   171  	return comm, port
   172  }
   173  
   174  type msgMutator func(*protoext.SignedGossipMessage) *protoext.SignedGossipMessage
   175  
   176  type tlsType int
   177  
   178  const (
   179  	none tlsType = iota
   180  	oneWayTLS
   181  	mutualTLS
   182  )
   183  
   184  func handshaker(port int, endpoint string, comm Comm, t *testing.T, connMutator msgMutator, connType tlsType) <-chan protoext.ReceivedMessage {
   185  	c := &commImpl{}
   186  	cert := GenerateCertificatesOrPanic()
   187  	tlsCfg := &tls.Config{
   188  		InsecureSkipVerify: true,
   189  	}
   190  	if connType == mutualTLS {
   191  		tlsCfg.Certificates = []tls.Certificate{cert}
   192  	}
   193  	ta := credentials.NewTLS(tlsCfg)
   194  	secureOpts := grpc.WithTransportCredentials(ta)
   195  	if connType == none {
   196  		secureOpts = grpc.WithInsecure()
   197  	}
   198  	acceptChan := comm.Accept(acceptAll)
   199  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   200  	defer cancel()
   201  	target := fmt.Sprintf("127.0.0.1:%d", port)
   202  	conn, err := grpc.DialContext(ctx, target, secureOpts, grpc.WithBlock())
   203  	assert.NoError(t, err, "%v", err)
   204  	if err != nil {
   205  		return nil
   206  	}
   207  	cl := proto.NewGossipClient(conn)
   208  	stream, err := cl.GossipStream(context.Background())
   209  	assert.NoError(t, err, "%v", err)
   210  	if err != nil {
   211  		return nil
   212  	}
   213  
   214  	var clientCertHash []byte
   215  	if len(tlsCfg.Certificates) > 0 {
   216  		clientCertHash = certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   217  	}
   218  
   219  	pkiID := common.PKIidType(endpoint)
   220  	assert.NoError(t, err, "%v", err)
   221  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte(endpoint), func(msg []byte) ([]byte, error) {
   222  		mac := hmac.New(sha256.New, hmacKey)
   223  		mac.Write(msg)
   224  		return mac.Sum(nil), nil
   225  	}, false)
   226  	// Mutate connection message to test negative paths
   227  	msg = connMutator(msg)
   228  	// Send your own connection message
   229  	stream.Send(msg.Envelope)
   230  	// Wait for connection message from the other side
   231  	envelope, err := stream.Recv()
   232  	if err != nil {
   233  		return acceptChan
   234  	}
   235  	assert.NoError(t, err, "%v", err)
   236  	msg, err = protoext.EnvelopeToGossipMessage(envelope)
   237  	assert.NoError(t, err, "%v", err)
   238  	assert.Equal(t, []byte(target), msg.GetConn().PkiId)
   239  	assert.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().TlsCertHash)
   240  	msg2Send := createGossipMsg()
   241  	nonce := uint64(rand.Int())
   242  	msg2Send.Nonce = nonce
   243  	go stream.Send(msg2Send.Envelope)
   244  	return acceptChan
   245  }
   246  
// TestMutualParallelSendWithAck has two comm instances fire msgNum acked sends
// at each other concurrently. It then waits until comm2 has received all
// msgNum data messages from comm1.
func TestMutualParallelSendWithAck(t *testing.T) {
	// This test tests concurrent and parallel sending of many (1000) messages
	// from 2 instances to one another at the same time.

	msgNum := 1000

	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, port2 := newCommInstance(t, naiveSec)
	defer comm1.Stop()
	defer comm2.Stop()

	// Only data messages are delivered to the channels below.
	acceptData := func(o interface{}) bool {
		m := o.(protoext.ReceivedMessage).GetGossipMessage()
		return protoext.IsDataMsg(m.GossipMessage)
	}

	inc1 := comm1.Accept(acceptData)
	inc2 := comm2.Accept(acceptData)

	// Send a message from comm1 to comm2, to make the instances establish a preliminary connection
	comm1.Send(createGossipMsg(), remotePeer(port2))
	// Wait for the message to be received in comm2
	<-inc2

	// Fire-and-forget acked sends in both directions. The results of
	// SendWithAck are ignored. The `1` is presumably the minimum number of
	// acks required (confirm against SendWithAck).
	for i := 0; i < msgNum; i++ {
		go comm1.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port2))
	}

	for i := 0; i < msgNum; i++ {
		go comm2.SendWithAck(createGossipMsg(), time.Second*5, 1, remotePeer(port1))
	}

	// Drain comm1's incoming messages in the background. Presumably this keeps
	// comm1's accept channel from filling up; the test does not wait on it.
	go func() {
		for i := 0; i < msgNum; i++ {
			<-inc1
		}
	}()

	// NOTE(review): no timeout here. A lost message blocks until the
	// `go test` timeout fires.
	for i := 0; i < msgNum; i++ {
		<-inc2
	}
}
   289  
   290  func getAvailablePort(t *testing.T) (port int, endpoint string, ll net.Listener) {
   291  	ll, err := net.Listen("tcp", "127.0.0.1:0")
   292  	assert.NoError(t, err)
   293  	endpoint = ll.Addr().String()
   294  	_, portS, err := net.SplitHostPort(endpoint)
   295  	assert.NoError(t, err)
   296  	portInt, err := strconv.Atoi(portS)
   297  	assert.NoError(t, err)
   298  	return portInt, endpoint, ll
   299  }
   300  
// TestHandshake drives the connection handshake by hand (through handshaker)
// against a comm instance.
//
// The positive paths check that the accepted message carries the initiator's
// PKI-ID and identity, and a correct HMAC signature over the signed auth data:
//   - without TLS;
//   - with mutual TLS.
//
// The negative paths check that nothing reaches the accept channel when:
//   - no client certificate is sent;
//   - the signature is wrong;
//   - the PKI-ID does not match the identity;
//   - the TLS cert hash is wrong;
//   - the PKI-ID is missing;
//   - the message is not a connection message;
//   - the initiator is too slow to send it.
//
// handshaker may return nil; len of a nil channel is 0, which the negative
// checks also accept.
func TestHandshake(t *testing.T) {
	// signer reproduces naiveSecProvider's HMAC-SHA256 signing; mutators use
	// it to re-sign tampered messages.
	signer := func(msg []byte) ([]byte, error) {
		mac := hmac.New(sha256.New, hmacKey)
		mac.Write(msg)
		return mac.Sum(nil), nil
	}
	// Identity mutator: leaves the connection message untouched.
	mutator := func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		return msg
	}
	assertPositivePath := func(msg protoext.ReceivedMessage, endpoint string) {
		expectedPKIID := common.PKIidType(endpoint)
		assert.Equal(t, expectedPKIID, msg.GetConnectionInfo().ID)
		assert.Equal(t, api.PeerIdentityType(endpoint), msg.GetConnectionInfo().Identity)
		assert.NotNil(t, msg.GetConnectionInfo().Auth)
		sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
		assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)
	}

	// Positive path 1 - check authentication without TLS
	// The server side is built by hand on a plain grpc.Server (nil TLS
	// certificates, insecure dial options).
	port, endpoint, ll := getAvailablePort(t)
	s := grpc.NewServer()
	id := []byte(endpoint)
	idMapper := identity.NewIdentityMapper(naiveSec, id, noopPurgeIdentity, naiveSec)
	inst, err := NewCommInstance(s, nil, idMapper, api.PeerIdentityType(endpoint), func() []grpc.DialOption {
		return []grpc.DialOption{grpc.WithInsecure()}
	}, naiveSec, disabledMetrics, testCommConfig)
	// NOTE(review): Serve starts before err is checked; a failed
	// NewCommInstance leaves inst nil and the test would panic below.
	go s.Serve(ll)
	assert.NoError(t, err)
	var msg protoext.ReceivedMessage

	_, tempEndpoint, tempL := getAvailablePort(t)
	acceptChan := handshaker(port, tempEndpoint, inst, t, mutator, none)
	select {
	case <-time.After(time.Duration(time.Second * 4)):
		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
	case msg = <-acceptChan:
	}
	assert.Equal(t, common.PKIidType(tempEndpoint), msg.GetConnectionInfo().ID)
	assert.Equal(t, api.PeerIdentityType(tempEndpoint), msg.GetConnectionInfo().Identity)
	sig, _ := (&naiveSecProvider{}).Sign(msg.GetConnectionInfo().Auth.SignedData)
	assert.Equal(t, sig, msg.GetConnectionInfo().Auth.Signature)

	// Tear down the hand-built server, then pause (presumably to let it shut
	// down fully) before switching to a TLS-enabled instance.
	inst.Stop()
	s.Stop()
	ll.Close()
	tempL.Close()
	time.Sleep(time.Second)

	// NOTE(review): this local `comm` shadows the imported comm package for
	// the rest of the function.
	comm, port := newCommInstance(t, naiveSec)
	defer comm.Stop()
	// Positive path 2: initiating peer sends its own certificate
	_, tempEndpoint, tempL = getAvailablePort(t)
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)

	select {
	case <-time.After(time.Second * 2):
		assert.FailNow(t, "Didn't receive a message, seems like handshake failed")
	case msg = <-acceptChan:
	}
	assertPositivePath(msg, tempEndpoint)
	tempL.Close()

	// Each negative path below gives the instance one second, then expects
	// nothing on the accept channel.

	// Negative path: initiating peer doesn't send its own certificate
	_, tempEndpoint, tempL = getAvailablePort(t)
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, oneWayTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()

	// Negative path, signature is wrong
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		msg.Signature = append(msg.Signature, 0)
		return msg
	}
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()

	// Negative path, the PKIid doesn't match the identity
	// The mutator swaps in tempEndpoint as PKI-ID while the handshake presents
	// tempEndpoint2 as identity, then re-signs so only the mismatch is wrong.
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		msg.GetConn().PkiId = []byte(tempEndpoint)
		// Sign the message again
		msg.Sign(signer)
		return msg
	}
	_, tempEndpoint2, tempL2 := getAvailablePort(t)
	acceptChan = handshaker(port, tempEndpoint2, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()
	tempL2.Close()

	// Negative path, the cert hash isn't what is expected
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		msg.GetConn().TlsCertHash = append(msg.GetConn().TlsCertHash, 0)
		msg.Sign(signer)
		return msg
	}
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()

	// Negative path, no PKI-ID was sent
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		msg.GetConn().PkiId = nil
		msg.Sign(signer)
		return msg
	}
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()

	// Negative path, connection message is of a different type
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		msg.Content = &proto.GossipMessage_Empty{
			Empty: &proto.Empty{},
		}
		msg.Sign(signer)
		return msg
	}
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()

	// Negative path, the peer didn't respond to the handshake in due time
	// The mutator stalls 5 seconds inside handshaker before the message is
	// sent, so handshaker itself blocks for at least that long.
	_, tempEndpoint, tempL = getAvailablePort(t)
	mutator = func(msg *protoext.SignedGossipMessage) *protoext.SignedGossipMessage {
		time.Sleep(time.Second * 5)
		return msg
	}
	acceptChan = handshaker(port, tempEndpoint, comm, t, mutator, mutualTLS)
	time.Sleep(time.Second)
	assert.Equal(t, 0, len(acceptChan))
	tempL.Close()
}
   445  
// TestConnectUnexpectedPeer checks how comm1 treats a peer whose PKI-ID
// differs from the one comm1 expected.
//
// comm1 uses a provider with per-identity organizations:
//   - comm1 is in "O";
//   - comm2 and comm4 are in "A";
//   - comm3 is in "B".
//
// After comm1 learns comm4's identity, comm1 sends to comm2's and comm3's
// endpoints while expecting comm4's PKI-ID. Delivery must succeed for comm2
// (same org as comm4) and must not happen for comm3 (different org).
func TestConnectUnexpectedPeer(t *testing.T) {
	// Scenarios: In both scenarios, comm1 connects to comm2 or comm3.
	// and expects to see a PKI-ID which is equal to comm4's PKI-ID.
	// The connection attempt would succeed or fail based on whether comm2 or comm3
	// are in the same org as comm4

	// identityByPort mirrors the "127.0.0.1:<port>" identity that
	// newCommInstanceOnly assigns to each instance.
	identityByPort := func(port int) api.PeerIdentityType {
		return api.PeerIdentityType(fmt.Sprintf("127.0.0.1:%d", port))
	}

	customNaiveSec := &naiveSecProvider{}

	comm1Port, gRPCServer1, certs1, secureDialOpts1, dialOpts1 := util.CreateGRPCLayer()
	comm2Port, gRPCServer2, certs2, secureDialOpts2, dialOpts2 := util.CreateGRPCLayer()
	comm3Port, gRPCServer3, certs3, secureDialOpts3, dialOpts3 := util.CreateGRPCLayer()
	comm4Port, gRPCServer4, certs4, secureDialOpts4, dialOpts4 := util.CreateGRPCLayer()

	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm1Port)).Return(api.OrgIdentityType("O"))
	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm2Port)).Return(api.OrgIdentityType("A"))
	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm3Port)).Return(api.OrgIdentityType("B"))
	customNaiveSec.On("OrgByPeerIdentity", identityByPort(comm4Port)).Return(api.OrgIdentityType("A"))

	// Only comm1 uses the org-aware provider; the others use the shared naiveSec.
	comm1 := newCommInstanceOnly(t, customNaiveSec, gRPCServer1, certs1, secureDialOpts1, dialOpts1...)
	comm2 := newCommInstanceOnly(t, naiveSec, gRPCServer2, certs2, secureDialOpts2, dialOpts2...)
	comm3 := newCommInstanceOnly(t, naiveSec, gRPCServer3, certs3, secureDialOpts3, dialOpts3...)
	comm4 := newCommInstanceOnly(t, naiveSec, gRPCServer4, certs4, secureDialOpts4, dialOpts4...)

	defer comm1.Stop()
	defer comm2.Stop()
	defer comm3.Stop()
	defer comm4.Stop()

	messagesForComm1 := comm1.Accept(acceptAll)
	messagesForComm2 := comm2.Accept(acceptAll)
	messagesForComm3 := comm3.Accept(acceptAll)

	// Have comm4 send a message to comm1
	// in order for comm1 to know comm4
	comm4.Send(createGossipMsg(), remotePeer(comm1Port))
	<-messagesForComm1
	// Close the connection with comm4
	comm1.CloseConn(remotePeer(comm4Port))
	// At this point, comm1 knows comm4's identity and organization

	t.Run("Same organization", func(t *testing.T) {
		// Target comm2's endpoint but expect comm4's PKI-ID.
		unexpectedRemotePeer := remotePeer(comm2Port)
		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
		select {
		case <-messagesForComm2:
		case <-time.After(time.Second * 5):
			assert.Fail(t, "Didn't receive a message within a timely manner")
			util.PrintStackTrace()
		}
	})

	t.Run("Unexpected organization", func(t *testing.T) {
		// Target comm3's endpoint but expect comm4's PKI-ID; nothing should arrive.
		unexpectedRemotePeer := remotePeer(comm3Port)
		unexpectedRemotePeer.PKIID = remotePeer(comm4Port).PKIID
		comm1.Send(createGossipMsg(), unexpectedRemotePeer)
		select {
		case <-messagesForComm3:
			assert.Fail(t, "Message shouldn't have been received")
		case <-time.After(time.Second * 5):
		}
	})
}
   513  
   514  func TestGetConnectionInfo(t *testing.T) {
   515  	comm1, port1 := newCommInstance(t, naiveSec)
   516  	comm2, _ := newCommInstance(t, naiveSec)
   517  	defer comm1.Stop()
   518  	defer comm2.Stop()
   519  	m1 := comm1.Accept(acceptAll)
   520  	comm2.Send(createGossipMsg(), remotePeer(port1))
   521  	select {
   522  	case <-time.After(time.Second * 10):
   523  		t.Fatal("Didn't receive a message in time")
   524  	case msg := <-m1:
   525  		assert.Equal(t, comm2.GetPKIid(), msg.GetConnectionInfo().ID)
   526  		assert.NotNil(t, msg.GetSourceEnvelope())
   527  	}
   528  }
   529  
   530  func TestCloseConn(t *testing.T) {
   531  	comm1, port1 := newCommInstance(t, naiveSec)
   532  	defer comm1.Stop()
   533  	acceptChan := comm1.Accept(acceptAll)
   534  
   535  	cert := GenerateCertificatesOrPanic()
   536  	tlsCfg := &tls.Config{
   537  		InsecureSkipVerify: true,
   538  		Certificates:       []tls.Certificate{cert},
   539  	}
   540  	ta := credentials.NewTLS(tlsCfg)
   541  
   542  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   543  	defer cancel()
   544  	target := fmt.Sprintf("127.0.0.1:%d", port1)
   545  	conn, err := grpc.DialContext(ctx, target, grpc.WithTransportCredentials(ta), grpc.WithBlock())
   546  	assert.NoError(t, err, "%v", err)
   547  	cl := proto.NewGossipClient(conn)
   548  	stream, err := cl.GossipStream(context.Background())
   549  	assert.NoError(t, err, "%v", err)
   550  	c := &commImpl{}
   551  	tlsCertHash := certHashFromRawCert(tlsCfg.Certificates[0].Certificate[0])
   552  	connMsg, _ := c.createConnectionMsg(common.PKIidType("pkiID"), tlsCertHash, api.PeerIdentityType("pkiID"), func(msg []byte) ([]byte, error) {
   553  		mac := hmac.New(sha256.New, hmacKey)
   554  		mac.Write(msg)
   555  		return mac.Sum(nil), nil
   556  	}, false)
   557  	assert.NoError(t, stream.Send(connMsg.Envelope))
   558  	stream.Send(createGossipMsg().Envelope)
   559  	select {
   560  	case <-acceptChan:
   561  	case <-time.After(time.Second):
   562  		assert.Fail(t, "Didn't receive a message within a timely period")
   563  	}
   564  	comm1.CloseConn(&RemotePeer{PKIID: common.PKIidType("pkiID")})
   565  	time.Sleep(time.Second * 10)
   566  	gotErr := false
   567  	msg2Send := createGossipMsg()
   568  	msg2Send.GetDataMsg().Payload = &proto.Payload{
   569  		Data: make([]byte, 1024*1024),
   570  	}
   571  	protoext.NoopSign(msg2Send.GossipMessage)
   572  	for i := 0; i < DefRecvBuffSize; i++ {
   573  		err := stream.Send(msg2Send.Envelope)
   574  		if err != nil {
   575  			gotErr = true
   576  			break
   577  		}
   578  	}
   579  	assert.True(t, gotErr, "Should have failed because connection is closed")
   580  }
   581  
   582  // TestCommSend makes sure that enough messages get through
   583  // eventually. Comm.Send() is both asynchronous and best-effort, so this test
   584  // case assumes some will fail, but that eventually enough messages will get
   585  // through that the test will end.
   586  func TestCommSend(t *testing.T) {
   587  	sendMessages := func(c Comm, peer *RemotePeer, stopChan <-chan struct{}) {
   588  		ticker := time.NewTicker(time.Millisecond)
   589  		defer ticker.Stop()
   590  		for {
   591  			emptyMsg := createGossipMsg()
   592  			select {
   593  			case <-stopChan:
   594  				return
   595  			case <-ticker.C:
   596  				c.Send(emptyMsg, peer)
   597  			}
   598  		}
   599  	}
   600  
   601  	comm1, port1 := newCommInstance(t, naiveSec)
   602  	comm2, port2 := newCommInstance(t, naiveSec)
   603  	defer comm1.Stop()
   604  	defer comm2.Stop()
   605  
   606  	// Create the receive channel before sending the messages
   607  	ch1 := comm1.Accept(acceptAll)
   608  	ch2 := comm2.Accept(acceptAll)
   609  
   610  	// control channels for background senders
   611  	stopch1 := make(chan struct{})
   612  	stopch2 := make(chan struct{})
   613  
   614  	go sendMessages(comm1, remotePeer(port2), stopch1)
   615  	go sendMessages(comm2, remotePeer(port1), stopch2)
   616  
   617  	c1received := 0
   618  	c2received := 0
   619  	// hopefully in some runs we'll fill both send and receive buffers and
   620  	// drop overflowing messages, but still finish, because the endless
   621  	// stream of messages inexorably gets through unless something is very
   622  	// broken.
   623  	totalMessagesReceived := (DefSendBuffSize + DefRecvBuffSize) * 2
   624  	timer := time.NewTimer(30 * time.Second)
   625  	defer timer.Stop()
   626  RECV:
   627  	for {
   628  		select {
   629  		case <-ch1:
   630  			c1received++
   631  			if c1received == totalMessagesReceived {
   632  				close(stopch2)
   633  			}
   634  		case <-ch2:
   635  			c2received++
   636  			if c2received == totalMessagesReceived {
   637  				close(stopch1)
   638  			}
   639  		case <-timer.C:
   640  			t.Fatalf("timed out waiting for messages to be received.\nc1 got %d messages\nc2 got %d messages", c1received, c2received)
   641  		default:
   642  			if c1received >= totalMessagesReceived && c2received >= totalMessagesReceived {
   643  				break RECV
   644  			}
   645  		}
   646  	}
   647  	t.Logf("c1 got %d messages\nc2 got %d messages", c1received, c2received)
   648  }
   649  
   650  type nonResponsivePeer struct {
   651  	*grpc.Server
   652  	port int
   653  }
   654  
   655  func newNonResponsivePeer(t *testing.T) *nonResponsivePeer {
   656  	port, gRPCServer, _, _, _ := util.CreateGRPCLayer()
   657  	nrp := &nonResponsivePeer{
   658  		Server: gRPCServer.Server(),
   659  		port:   port,
   660  	}
   661  	proto.RegisterGossipServer(gRPCServer.Server(), nrp)
   662  	return nrp
   663  }
   664  
   665  func (bp *nonResponsivePeer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) {
   666  	time.Sleep(time.Second * 15)
   667  	return &proto.Empty{}, nil
   668  }
   669  
   670  func (bp *nonResponsivePeer) GossipStream(stream proto.Gossip_GossipStreamServer) error {
   671  	return nil
   672  }
   673  
   674  func (bp *nonResponsivePeer) stop() {
   675  	bp.Server.Stop()
   676  }
   677  
   678  func TestNonResponsivePing(t *testing.T) {
   679  	c, _ := newCommInstance(t, naiveSec)
   680  	defer c.Stop()
   681  	nonRespPeer := newNonResponsivePeer(t)
   682  	defer nonRespPeer.stop()
   683  	s := make(chan struct{})
   684  	go func() {
   685  		c.Probe(remotePeer(nonRespPeer.port))
   686  		s <- struct{}{}
   687  	}()
   688  	select {
   689  	case <-time.After(time.Second * 10):
   690  		assert.Fail(t, "Request wasn't cancelled on time")
   691  	case <-s:
   692  	}
   693  
   694  }
   695  
   696  func TestResponses(t *testing.T) {
   697  	comm1, port1 := newCommInstance(t, naiveSec)
   698  	comm2, _ := newCommInstance(t, naiveSec)
   699  
   700  	defer comm1.Stop()
   701  	defer comm2.Stop()
   702  
   703  	wg := sync.WaitGroup{}
   704  
   705  	msg := createGossipMsg()
   706  	wg.Add(1)
   707  	go func() {
   708  		inChan := comm1.Accept(acceptAll)
   709  		wg.Done()
   710  		for m := range inChan {
   711  			reply := createGossipMsg()
   712  			reply.Nonce = m.GetGossipMessage().Nonce + 1
   713  			m.Respond(reply.GossipMessage)
   714  		}
   715  	}()
   716  	expectedNOnce := uint64(msg.Nonce + 1)
   717  	responsesFromComm1 := comm2.Accept(acceptAll)
   718  
   719  	ticker := time.NewTicker(10 * time.Second)
   720  	wg.Wait()
   721  	comm2.Send(msg, remotePeer(port1))
   722  
   723  	select {
   724  	case <-ticker.C:
   725  		assert.Fail(t, "Haven't got response from comm1 within a timely manner")
   726  		break
   727  	case resp := <-responsesFromComm1:
   728  		ticker.Stop()
   729  		assert.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce)
   730  		break
   731  	}
   732  }
   733  
// TestAccept makes sure that accept filters work. The probability of the parity
// of all nonces being 0 or 1 is very low.
//
// comm2 floods comm1 with messages. comm1 has two subscriptions, one for even
// nonces and one for odd nonces. After enough messages arrive, both instances
// are stopped. The test then checks that each subscription received only
// nonces of its parity, and that neither subscription is empty.
func TestAccept(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, _ := newCommInstance(t, naiveSec)

	evenNONCESelector := func(m interface{}) bool {
		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 == 0
	}

	oddNONCESelector := func(m interface{}) bool {
		return m.(protoext.ReceivedMessage).GetGossipMessage().Nonce%2 != 0
	}

	evenNONCES := comm1.Accept(evenNONCESelector)
	oddNONCES := comm1.Accept(oddNONCESelector)

	var evenResults []uint64
	var oddResults []uint64

	// out feeds waitForMessages with received nonces; sem signals that a
	// reader goroutine has finished.
	out := make(chan uint64)
	sem := make(chan struct{})

	// readIntoSlice records every nonce from ch into *a until ch is closed
	// (presumably by comm1.Stop). The slices are read only after both readers
	// have signalled on sem, so no extra locking is needed.
	readIntoSlice := func(a *[]uint64, ch <-chan protoext.ReceivedMessage) {
		for m := range ch {
			*a = append(*a, m.GetGossipMessage().Nonce)
			select {
			case out <- m.GetGossipMessage().Nonce:
			default: // avoid blocking when we stop reading from out
			}
		}
		sem <- struct{}{}
	}

	go readIntoSlice(&evenResults, evenNONCES)
	go readIntoSlice(&oddResults, oddNONCES)

	// Send to comm1 in a tight loop until stopSend is closed.
	stopSend := make(chan struct{})
	go func() {
		for {
			select {
			case <-stopSend:
				return
			default:
				comm2.Send(createGossipMsg(), remotePeer(port1))
			}
		}
	}()

	waitForMessages(t, out, (DefSendBuffSize+DefRecvBuffSize)*2, "Didn't receive all messages sent")
	close(stopSend)

	comm1.Stop()
	comm2.Stop()

	<-sem
	<-sem

	t.Logf("%d even nonces received", len(evenResults))
	t.Logf("%d  odd nonces received", len(oddResults))

	assert.NotEmpty(t, evenResults)
	assert.NotEmpty(t, oddResults)

	remainderPredicate := func(a []uint64, rem uint64) {
		for _, n := range a {
			assert.Equal(t, n%2, rem)
		}
	}

	remainderPredicate(evenResults, 0)
	remainderPredicate(oddResults, 1)
}
   807  
// TestReConnections checks three steps:
//  1. comm1 and comm2 can message each other.
//  2. comm1 is stopped and replaced by a fresh instance on a new port.
//  3. comm2 can still reach the replacement.
func TestReConnections(t *testing.T) {
	comm1, port1 := newCommInstance(t, naiveSec)
	comm2, port2 := newCommInstance(t, naiveSec)

	// reader forwards the nonce of each received message to out. It stops on
	// a nil message, which is what a receive from a closed channel yields.
	reader := func(out chan uint64, in <-chan protoext.ReceivedMessage) {
		for {
			msg := <-in
			if msg == nil {
				return
			}
			out <- msg.GetGossipMessage().Nonce
		}
	}

	out1 := make(chan uint64, 10)
	out2 := make(chan uint64, 10)

	go reader(out1, comm1.Accept(acceptAll))
	go reader(out2, comm2.Accept(acceptAll))

	// comm1 connects to comm2
	comm1.Send(createGossipMsg(), remotePeer(port2))
	waitForMessages(t, out2, 1, "Comm2 didn't receive a message from comm1 in a timely manner")
	// comm2 sends to comm1
	comm2.Send(createGossipMsg(), remotePeer(port1))
	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
	comm1.Stop()

	// Replace comm1 with a brand new instance and verify comm2 can reach it.
	comm1, port1 = newCommInstance(t, naiveSec)
	out1 = make(chan uint64, 1)
	go reader(out1, comm1.Accept(acceptAll))
	comm2.Send(createGossipMsg(), remotePeer(port1))
	waitForMessages(t, out1, 1, "Comm1 didn't receive a message from comm2 in a timely manner")
	comm1.Stop()
	comm2.Stop()
}
   844  
   845  func TestProbe(t *testing.T) {
   846  	comm1, port1 := newCommInstance(t, naiveSec)
   847  	defer comm1.Stop()
   848  	comm2, port2 := newCommInstance(t, naiveSec)
   849  	time.Sleep(time.Duration(1) * time.Second)
   850  	assert.NoError(t, comm1.Probe(remotePeer(port2)))
   851  	_, err := comm1.Handshake(remotePeer(port2))
   852  	assert.NoError(t, err)
   853  	tempPort, _, ll := getAvailablePort(t)
   854  	defer ll.Close()
   855  	assert.Error(t, comm1.Probe(remotePeer(tempPort)))
   856  	_, err = comm1.Handshake(remotePeer(tempPort))
   857  	assert.Error(t, err)
   858  	comm2.Stop()
   859  	time.Sleep(time.Duration(1) * time.Second)
   860  	assert.Error(t, comm1.Probe(remotePeer(port2)))
   861  	_, err = comm1.Handshake(remotePeer(port2))
   862  	assert.Error(t, err)
   863  	comm2, port2 = newCommInstance(t, naiveSec)
   864  	defer comm2.Stop()
   865  	time.Sleep(time.Duration(1) * time.Second)
   866  	assert.NoError(t, comm2.Probe(remotePeer(port1)))
   867  	_, err = comm2.Handshake(remotePeer(port1))
   868  	assert.NoError(t, err)
   869  	assert.NoError(t, comm1.Probe(remotePeer(port2)))
   870  	_, err = comm1.Handshake(remotePeer(port2))
   871  	assert.NoError(t, err)
   872  	// Now try a deep probe with an expected PKI-ID that doesn't match
   873  	wrongRemotePeer := remotePeer(port2)
   874  	if wrongRemotePeer.PKIID[0] == 0 {
   875  		wrongRemotePeer.PKIID[0] = 1
   876  	} else {
   877  		wrongRemotePeer.PKIID[0] = 0
   878  	}
   879  	_, err = comm1.Handshake(wrongRemotePeer)
   880  	assert.Error(t, err)
   881  	// Try a deep probe with a nil PKI-ID
   882  	endpoint := fmt.Sprintf("127.0.0.1:%d", port2)
   883  	id, err := comm1.Handshake(&RemotePeer{Endpoint: endpoint})
   884  	assert.NoError(t, err)
   885  	assert.Equal(t, api.PeerIdentityType(endpoint), id)
   886  }
   887  
   888  func TestPresumedDead(t *testing.T) {
   889  	comm1, _ := newCommInstance(t, naiveSec)
   890  	comm2, port2 := newCommInstance(t, naiveSec)
   891  
   892  	wg := sync.WaitGroup{}
   893  	wg.Add(1)
   894  	go func() {
   895  		wg.Wait()
   896  		comm1.Send(createGossipMsg(), remotePeer(port2))
   897  	}()
   898  
   899  	ticker := time.NewTicker(time.Duration(10) * time.Second)
   900  	acceptCh := comm2.Accept(acceptAll)
   901  	wg.Done()
   902  	select {
   903  	case <-acceptCh:
   904  		ticker.Stop()
   905  	case <-ticker.C:
   906  		assert.Fail(t, "Didn't get first message")
   907  	}
   908  
   909  	comm2.Stop()
   910  	go func() {
   911  		for i := 0; i < 5; i++ {
   912  			comm1.Send(createGossipMsg(), remotePeer(port2))
   913  			time.Sleep(time.Millisecond * 200)
   914  		}
   915  	}()
   916  
   917  	ticker = time.NewTicker(time.Second * time.Duration(3))
   918  	select {
   919  	case <-ticker.C:
   920  		assert.Fail(t, "Didn't get a presumed dead message within a timely manner")
   921  		break
   922  	case <-comm1.PresumedDead():
   923  		ticker.Stop()
   924  		break
   925  	}
   926  }
   927  
   928  func TestReadFromStream(t *testing.T) {
   929  	stream := &gmocks.MockStream{}
   930  	stream.On("CloseSend").Return(nil)
   931  	stream.On("Recv").Return(&proto.Envelope{Payload: []byte{1}}, nil).Once()
   932  	stream.On("Recv").Return(nil, errors.New("stream closed")).Once()
   933  
   934  	conn := newConnection(nil, nil, stream, disabledMetrics, ConnConfig{1, 1})
   935  	conn.logger = flogging.MustGetLogger("test")
   936  
   937  	errChan := make(chan error, 2)
   938  	msgChan := make(chan *protoext.SignedGossipMessage, 1)
   939  	var wg sync.WaitGroup
   940  	wg.Add(1)
   941  	go func() {
   942  		defer wg.Done()
   943  		conn.readFromStream(errChan, msgChan)
   944  	}()
   945  
   946  	select {
   947  	case <-msgChan:
   948  		assert.Fail(t, "malformed message shouldn't have been received")
   949  	case <-time.After(time.Millisecond * 100):
   950  		assert.Len(t, errChan, 1)
   951  	}
   952  
   953  	conn.close()
   954  	wg.Wait()
   955  }
   956  
   957  func TestSendBadEnvelope(t *testing.T) {
   958  	comm1, port := newCommInstance(t, naiveSec)
   959  	defer comm1.Stop()
   960  
   961  	stream, err := establishSession(t, port)
   962  	assert.NoError(t, err)
   963  
   964  	inc := comm1.Accept(acceptAll)
   965  
   966  	goodMsg := createGossipMsg()
   967  	err = stream.Send(goodMsg.Envelope)
   968  	assert.NoError(t, err)
   969  
   970  	select {
   971  	case goodMsgReceived := <-inc:
   972  		assert.Equal(t, goodMsg.Envelope.Payload, goodMsgReceived.GetSourceEnvelope().Payload)
   973  	case <-time.After(time.Minute):
   974  		assert.Fail(t, "Didn't receive message within a timely manner")
   975  		return
   976  	}
   977  
   978  	// Next, we corrupt a message and send it until the stream is closed forcefully from the remote peer
   979  	start := time.Now()
   980  	for {
   981  		badMsg := createGossipMsg()
   982  		badMsg.Envelope.Payload = []byte{1}
   983  		err = stream.Send(badMsg.Envelope)
   984  		if err != nil {
   985  			assert.Equal(t, io.EOF, err)
   986  			break
   987  		}
   988  		if time.Now().After(start.Add(time.Second * 30)) {
   989  			assert.Fail(t, "Didn't close stream within a timely manner")
   990  			return
   991  		}
   992  	}
   993  }
   994  
   995  func establishSession(t *testing.T, port int) (proto.Gossip_GossipStreamClient, error) {
   996  	cert := GenerateCertificatesOrPanic()
   997  	secureOpts := grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
   998  		InsecureSkipVerify: true,
   999  		Certificates:       []tls.Certificate{cert},
  1000  	}))
  1001  
  1002  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1003  	defer cancel()
  1004  
  1005  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1006  	conn, err := grpc.DialContext(ctx, endpoint, secureOpts, grpc.WithBlock())
  1007  	assert.NoError(t, err, "%v", err)
  1008  	if err != nil {
  1009  		return nil, err
  1010  	}
  1011  	cl := proto.NewGossipClient(conn)
  1012  	stream, err := cl.GossipStream(context.Background())
  1013  	assert.NoError(t, err, "%v", err)
  1014  	if err != nil {
  1015  		return nil, err
  1016  	}
  1017  
  1018  	clientCertHash := certHashFromRawCert(cert.Certificate[0])
  1019  	pkiID := common.PKIidType([]byte{1, 2, 3})
  1020  	c := &commImpl{}
  1021  	assert.NoError(t, err, "%v", err)
  1022  	msg, _ := c.createConnectionMsg(pkiID, clientCertHash, []byte{1, 2, 3}, func(msg []byte) ([]byte, error) {
  1023  		mac := hmac.New(sha256.New, hmacKey)
  1024  		mac.Write(msg)
  1025  		return mac.Sum(nil), nil
  1026  	}, false)
  1027  	// Send your own connection message
  1028  	stream.Send(msg.Envelope)
  1029  	// Wait for connection message from the other side
  1030  	envelope, err := stream.Recv()
  1031  	if err != nil {
  1032  		return nil, err
  1033  	}
  1034  	assert.NotNil(t, envelope)
  1035  	return stream, nil
  1036  }
  1037  
  1038  func createGossipMsg() *protoext.SignedGossipMessage {
  1039  	msg, _ := protoext.NoopSign(&proto.GossipMessage{
  1040  		Tag:   proto.GossipMessage_EMPTY,
  1041  		Nonce: uint64(rand.Int()),
  1042  		Content: &proto.GossipMessage_DataMsg{
  1043  			DataMsg: &proto.DataMessage{},
  1044  		},
  1045  	})
  1046  	return msg
  1047  }
  1048  
  1049  func remotePeer(port int) *RemotePeer {
  1050  	endpoint := fmt.Sprintf("127.0.0.1:%d", port)
  1051  	return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}
  1052  }
  1053  
  1054  func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) {
  1055  	c := 0
  1056  	waiting := true
  1057  	ticker := time.NewTicker(time.Duration(10) * time.Second)
  1058  	for waiting {
  1059  		select {
  1060  		case <-msgChan:
  1061  			c++
  1062  			if c == count {
  1063  				waiting = false
  1064  			}
  1065  		case <-ticker.C:
  1066  			waiting = false
  1067  		}
  1068  	}
  1069  	assert.Equal(t, count, c, errMsg)
  1070  }
  1071  
  1072  func TestConcurrentCloseSend(t *testing.T) {
  1073  	var stopping int32
  1074  
  1075  	comm1, _ := newCommInstance(t, naiveSec)
  1076  	comm2, port2 := newCommInstance(t, naiveSec)
  1077  	m := comm2.Accept(acceptAll)
  1078  	comm1.Send(createGossipMsg(), remotePeer(port2))
  1079  	<-m
  1080  	ready := make(chan struct{})
  1081  	done := make(chan struct{})
  1082  	go func() {
  1083  		defer close(done)
  1084  
  1085  		comm1.Send(createGossipMsg(), remotePeer(port2))
  1086  		close(ready)
  1087  
  1088  		for atomic.LoadInt32(&stopping) == int32(0) {
  1089  			comm1.Send(createGossipMsg(), remotePeer(port2))
  1090  		}
  1091  	}()
  1092  	<-ready
  1093  	comm2.Stop()
  1094  	atomic.StoreInt32(&stopping, int32(1))
  1095  	<-done
  1096  }