github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/gossip/comm/comm_impl.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "encoding/hex" 14 "fmt" 15 "reflect" 16 "sync" 17 "sync/atomic" 18 "time" 19 20 "github.com/hechain20/hechain/gossip/api" 21 "github.com/hechain20/hechain/gossip/common" 22 "github.com/hechain20/hechain/gossip/identity" 23 "github.com/hechain20/hechain/gossip/metrics" 24 "github.com/hechain20/hechain/gossip/protoext" 25 "github.com/hechain20/hechain/gossip/util" 26 proto "github.com/hyperledger/fabric-protos-go/gossip" 27 "github.com/pkg/errors" 28 "google.golang.org/grpc" 29 "google.golang.org/grpc/peer" 30 ) 31 32 const ( 33 handshakeTimeout = time.Second * 10 34 DefDialTimeout = time.Second * 3 35 DefConnTimeout = time.Second * 2 36 DefRecvBuffSize = 20 37 DefSendBuffSize = 20 38 ) 39 40 var errProbe = errors.New("probe") 41 42 // SecurityAdvisor defines an external auxiliary object 43 // that provides security and identity related capabilities 44 type SecurityAdvisor interface { 45 // OrgByPeerIdentity returns the organization identity of the given PeerIdentityType 46 OrgByPeerIdentity(api.PeerIdentityType) api.OrgIdentityType 47 } 48 49 func (c *commImpl) SetDialOpts(opts ...grpc.DialOption) { 50 if len(opts) == 0 { 51 c.logger.Warning("Given an empty set of grpc.DialOption, aborting") 52 return 53 } 54 c.opts = opts 55 } 56 57 // NewCommInstance creates a new comm instance that binds itself to the given gRPC server 58 func NewCommInstance(s *grpc.Server, certs *common.TLSCertificates, idStore identity.Mapper, 59 peerIdentity api.PeerIdentityType, secureDialOpts api.PeerSecureDialOpts, sa api.SecurityAdvisor, 60 commMetrics *metrics.CommMetrics, config CommConfig, dialOpts ...grpc.DialOption) (Comm, error) { 61 commInst := &commImpl{ 62 sa: sa, 63 pubSub: util.NewPubSub(), 64 PKIID: idStore.GetPKIidOfCert(peerIdentity), 65 idMapper: idStore, 66 logger: util.GetLogger(util.CommLogger, ""), 67 peerIdentity: peerIdentity, 68 opts: dialOpts, 69 secureDialOpts: secureDialOpts, 70 msgPublisher: NewChannelDemultiplexer(), 71 lock: &sync.Mutex{}, 72 deadEndpoints: make(chan common.PKIidType, 100), 73 identityChanges: make(chan common.PKIidType, 1), 74 stopping: int32(0), 75 exitChan: make(chan struct{}), 76 subscriptions: make([]chan protoext.ReceivedMessage, 0), 77 tlsCerts: certs, 78 metrics: commMetrics, 79 dialTimeout: config.DialTimeout, 80 connTimeout: config.ConnTimeout, 81 recvBuffSize: config.RecvBuffSize, 82 sendBuffSize: config.SendBuffSize, 83 } 84 85 connConfig := ConnConfig{ 86 RecvBuffSize: config.RecvBuffSize, 87 SendBuffSize: config.SendBuffSize, 88 } 89 90 commInst.connStore = newConnStore(commInst, commInst.logger, connConfig) 91 92 proto.RegisterGossipServer(s, commInst) 93 94 return commInst, nil 95 } 96 97 // CommConfig is the configuration required to initialize a new comm 98 type CommConfig struct { 99 DialTimeout time.Duration // Dial timeout 100 ConnTimeout time.Duration // Connection timeout 101 RecvBuffSize int // Buffer size of received messages 102 SendBuffSize int // Buffer size of sending messages 103 } 104 105 type commImpl struct { 106 sa api.SecurityAdvisor 107 tlsCerts *common.TLSCertificates 108 pubSub *util.PubSub 109 peerIdentity api.PeerIdentityType 110 idMapper identity.Mapper 111 logger util.Logger 112 opts []grpc.DialOption 113 secureDialOpts func() []grpc.DialOption 114 connStore *connectionStore 115 PKIID []byte 116 deadEndpoints chan common.PKIidType 117 identityChanges chan common.PKIidType 118 msgPublisher *ChannelDeMultiplexer 119 lock *sync.Mutex 120 exitChan chan struct{} 121 stopWG sync.WaitGroup 122 subscriptions []chan protoext.ReceivedMessage 123 stopping int32 124 metrics *metrics.CommMetrics 125 dialTimeout time.Duration 126 connTimeout time.Duration 127 recvBuffSize int 128 sendBuffSize int 129 } 130 131 func (c *commImpl) createConnection(endpoint string, expectedPKIID common.PKIidType) (*connection, error) { 132 var err error 133 var cc *grpc.ClientConn 134 var stream proto.Gossip_GossipStreamClient 135 var pkiID common.PKIidType 136 var connInfo *protoext.ConnectionInfo 137 var dialOpts []grpc.DialOption 138 139 c.logger.Debug("Entering", endpoint, expectedPKIID) 140 defer c.logger.Debug("Exiting") 141 142 if c.isStopping() { 143 return nil, errors.New("Stopping") 144 } 145 dialOpts = append(dialOpts, c.secureDialOpts()...) 146 dialOpts = append(dialOpts, grpc.WithBlock()) 147 dialOpts = append(dialOpts, c.opts...) 148 ctx := context.Background() 149 ctx, cancel := context.WithTimeout(ctx, c.dialTimeout) 150 defer cancel() 151 cc, err = grpc.DialContext(ctx, endpoint, dialOpts...) 152 if err != nil { 153 return nil, errors.WithStack(err) 154 } 155 156 cl := proto.NewGossipClient(cc) 157 158 ctx, cancel = context.WithTimeout(context.Background(), c.connTimeout) 159 defer cancel() 160 if _, err = cl.Ping(ctx, &proto.Empty{}); err != nil { 161 cc.Close() 162 return nil, errors.WithStack(err) 163 } 164 165 ctx, cancel = context.WithCancel(context.Background()) 166 if stream, err = cl.GossipStream(ctx); err == nil { 167 connInfo, err = c.authenticateRemotePeer(stream, true, false) 168 if err == nil { 169 pkiID = connInfo.ID 170 // PKIID is nil when we don't know the remote PKI id's 171 if expectedPKIID != nil && !bytes.Equal(pkiID, expectedPKIID) { 172 actualOrg := c.sa.OrgByPeerIdentity(connInfo.Identity) 173 // If the identity isn't present, it's nil - therefore OrgByPeerIdentity would 174 // return nil too and thus would be different than the actual organization 175 identity, _ := c.idMapper.Get(expectedPKIID) 176 oldOrg := c.sa.OrgByPeerIdentity(identity) 177 if !bytes.Equal(actualOrg, oldOrg) { 178 c.logger.Warning("Remote endpoint claims to be a different peer, expected", expectedPKIID, "but got", pkiID) 179 cc.Close() 180 cancel() 181 return nil, errors.New("authentication failure") 182 } else { 183 c.logger.Infof("Peer %s changed its PKI-ID from %s to %s", endpoint, expectedPKIID, pkiID) 184 c.identityChanges <- expectedPKIID 185 } 186 } 187 connConfig := ConnConfig{ 188 RecvBuffSize: c.recvBuffSize, 189 SendBuffSize: c.sendBuffSize, 190 } 191 conn := newConnection(cl, cc, stream, c.metrics, connConfig) 192 conn.pkiID = pkiID 193 conn.info = connInfo 194 conn.logger = c.logger 195 conn.cancel = cancel 196 197 h := func(m *protoext.SignedGossipMessage) { 198 c.logger.Debug("Got message:", m) 199 c.msgPublisher.DeMultiplex(&ReceivedMessageImpl{ 200 conn: conn, 201 SignedGossipMessage: m, 202 connInfo: connInfo, 203 }) 204 } 205 conn.handler = interceptAcks(h, connInfo.ID, c.pubSub) 206 return conn, nil 207 } 208 c.logger.Warningf("Authentication failed: %+v", err) 209 } 210 cc.Close() 211 cancel() 212 return nil, errors.WithStack(err) 213 } 214 215 func (c *commImpl) Send(msg *protoext.SignedGossipMessage, peers ...*RemotePeer) { 216 if c.isStopping() || len(peers) == 0 { 217 return 218 } 219 c.logger.Debug("Entering, sending", msg, "to ", len(peers), "peers") 220 221 for _, peer := range peers { 222 go func(peer *RemotePeer, msg *protoext.SignedGossipMessage) { 223 c.sendToEndpoint(peer, msg, nonBlockingSend) 224 }(peer, msg) 225 } 226 } 227 228 func (c *commImpl) sendToEndpoint(peer *RemotePeer, msg *protoext.SignedGossipMessage, shouldBlock blockingBehavior) { 229 if c.isStopping() { 230 return 231 } 232 c.logger.Debug("Entering, Sending to", peer.Endpoint, ", msg:", msg) 233 defer c.logger.Debug("Exiting") 234 var err error 235 236 conn, err := c.connStore.getConnection(peer) 237 if err == nil { 238 disConnectOnErr := func(err error) { 239 c.logger.Warningf("%v isn't responsive: %v", peer, err) 240 c.disconnect(peer.PKIID) 241 conn.close() 242 } 243 conn.send(msg, disConnectOnErr, shouldBlock) 244 return 245 } 246 c.logger.Warningf("Failed obtaining connection for %v reason: %v", peer, err) 247 c.disconnect(peer.PKIID) 248 } 249 250 func (c *commImpl) isStopping() bool { 251 return atomic.LoadInt32(&c.stopping) == int32(1) 252 } 253 254 func (c *commImpl) Probe(remotePeer *RemotePeer) error { 255 var dialOpts []grpc.DialOption 256 endpoint := remotePeer.Endpoint 257 pkiID := remotePeer.PKIID 258 if c.isStopping() { 259 return errors.New("stopping") 260 } 261 c.logger.Debug("Entering, endpoint:", endpoint, "PKIID:", pkiID) 262 dialOpts = append(dialOpts, c.secureDialOpts()...) 263 dialOpts = append(dialOpts, grpc.WithBlock()) 264 dialOpts = append(dialOpts, c.opts...) 265 ctx := context.Background() 266 ctx, cancel := context.WithTimeout(ctx, c.dialTimeout) 267 defer cancel() 268 cc, err := grpc.DialContext(ctx, remotePeer.Endpoint, dialOpts...) 269 if err != nil { 270 c.logger.Debugf("Returning %v", err) 271 return err 272 } 273 defer cc.Close() 274 cl := proto.NewGossipClient(cc) 275 ctx, cancel = context.WithTimeout(context.Background(), c.connTimeout) 276 defer cancel() 277 _, err = cl.Ping(ctx, &proto.Empty{}) 278 c.logger.Debugf("Returning %v", err) 279 return err 280 } 281 282 func (c *commImpl) Handshake(remotePeer *RemotePeer) (api.PeerIdentityType, error) { 283 var dialOpts []grpc.DialOption 284 dialOpts = append(dialOpts, c.secureDialOpts()...) 285 dialOpts = append(dialOpts, grpc.WithBlock()) 286 dialOpts = append(dialOpts, c.opts...) 287 ctx := context.Background() 288 ctx, cancel := context.WithTimeout(ctx, c.dialTimeout) 289 defer cancel() 290 cc, err := grpc.DialContext(ctx, remotePeer.Endpoint, dialOpts...) 291 if err != nil { 292 return nil, err 293 } 294 defer cc.Close() 295 296 cl := proto.NewGossipClient(cc) 297 ctx, cancel = context.WithTimeout(context.Background(), c.connTimeout) 298 defer cancel() 299 if _, err = cl.Ping(ctx, &proto.Empty{}); err != nil { 300 return nil, err 301 } 302 303 ctx, cancel = context.WithTimeout(context.Background(), handshakeTimeout) 304 defer cancel() 305 stream, err := cl.GossipStream(ctx) 306 if err != nil { 307 return nil, err 308 } 309 connInfo, err := c.authenticateRemotePeer(stream, true, true) 310 if err != nil { 311 c.logger.Warningf("Authentication failed: %v", err) 312 return nil, err 313 } 314 if len(remotePeer.PKIID) > 0 && !bytes.Equal(connInfo.ID, remotePeer.PKIID) { 315 return nil, errors.New("PKI-ID of remote peer doesn't match expected PKI-ID") 316 } 317 return connInfo.Identity, nil 318 } 319 320 func (c *commImpl) Accept(acceptor common.MessageAcceptor) <-chan protoext.ReceivedMessage { 321 genericChan := c.msgPublisher.AddChannel(acceptor) 322 specificChan := make(chan protoext.ReceivedMessage, 10) 323 324 if c.isStopping() { 325 c.logger.Warning("Accept() called but comm module is stopping, returning empty channel") 326 return specificChan 327 } 328 329 c.lock.Lock() 330 c.subscriptions = append(c.subscriptions, specificChan) 331 c.lock.Unlock() 332 333 c.stopWG.Add(1) 334 go func() { 335 defer c.logger.Debug("Exiting Accept() loop") 336 337 defer c.stopWG.Done() 338 339 for { 340 select { 341 case msg, channelOpen := <-genericChan: 342 if !channelOpen { 343 return 344 } 345 select { 346 case specificChan <- msg.(*ReceivedMessageImpl): 347 case <-c.exitChan: 348 return 349 } 350 case <-c.exitChan: 351 return 352 } 353 } 354 }() 355 return specificChan 356 } 357 358 func (c *commImpl) PresumedDead() <-chan common.PKIidType { 359 return c.deadEndpoints 360 } 361 362 func (c *commImpl) IdentitySwitch() <-chan common.PKIidType { 363 return c.identityChanges 364 } 365 366 func (c *commImpl) CloseConn(peer *RemotePeer) { 367 c.logger.Debug("Closing connection for", peer) 368 c.connStore.closeConnByPKIid(peer.PKIID) 369 } 370 371 func (c *commImpl) closeSubscriptions() { 372 c.lock.Lock() 373 defer c.lock.Unlock() 374 for _, ch := range c.subscriptions { 375 close(ch) 376 } 377 } 378 379 func (c *commImpl) Stop() { 380 if !atomic.CompareAndSwapInt32(&c.stopping, 0, int32(1)) { 381 return 382 } 383 c.logger.Info("Stopping") 384 defer c.logger.Info("Stopped") 385 c.connStore.shutdown() 386 c.logger.Debug("Shut down connection store, connection count:", c.connStore.connNum()) 387 c.msgPublisher.Close() 388 close(c.exitChan) 389 c.stopWG.Wait() 390 c.closeSubscriptions() 391 } 392 393 func (c *commImpl) GetPKIid() common.PKIidType { 394 return c.PKIID 395 } 396 397 func extractRemoteAddress(stream stream) string { 398 var remoteAddress string 399 p, ok := peer.FromContext(stream.Context()) 400 if ok { 401 if address := p.Addr; address != nil { 402 remoteAddress = address.String() 403 } 404 } 405 return remoteAddress 406 } 407 408 func (c *commImpl) authenticateRemotePeer(stream stream, initiator, isProbe bool) (*protoext.ConnectionInfo, error) { 409 ctx := stream.Context() 410 remoteAddress := extractRemoteAddress(stream) 411 remoteCertHash := extractCertificateHashFromContext(ctx) 412 var err error 413 var cMsg *protoext.SignedGossipMessage 414 useTLS := c.tlsCerts != nil 415 var selfCertHash []byte 416 417 if useTLS { 418 certReference := c.tlsCerts.TLSServerCert 419 if initiator { 420 certReference = c.tlsCerts.TLSClientCert 421 } 422 selfCertHash = certHashFromRawCert(certReference.Load().(*tls.Certificate).Certificate[0]) 423 } 424 425 signer := func(msg []byte) ([]byte, error) { 426 return c.idMapper.Sign(msg) 427 } 428 429 // TLS enabled but not detected on other side 430 if useTLS && len(remoteCertHash) == 0 { 431 c.logger.Warningf("%s didn't send TLS certificate", remoteAddress) 432 return nil, errors.New("no TLS certificate") 433 } 434 435 cMsg, err = c.createConnectionMsg(c.PKIID, selfCertHash, c.peerIdentity, signer, isProbe) 436 if err != nil { 437 return nil, err 438 } 439 440 c.logger.Debug("Sending", cMsg, "to", remoteAddress) 441 stream.Send(cMsg.Envelope) 442 m, err := readWithTimeout(stream, c.connTimeout, remoteAddress) 443 if err != nil { 444 c.logger.Warningf("Failed reading message from %s, reason: %v", remoteAddress, err) 445 return nil, err 446 } 447 receivedMsg := m.GetConn() 448 if receivedMsg == nil { 449 c.logger.Warning("Expected connection message from", remoteAddress, "but got", receivedMsg) 450 return nil, errors.New("wrong type") 451 } 452 453 if receivedMsg.PkiId == nil { 454 c.logger.Warningf("%s didn't send a pkiID", remoteAddress) 455 return nil, errors.New("no PKI-ID") 456 } 457 458 c.logger.Debug("Received", receivedMsg, "from", remoteAddress) 459 err = c.idMapper.Put(receivedMsg.PkiId, receivedMsg.Identity) 460 if err != nil { 461 c.logger.Warningf("Identity store rejected %s : %v", remoteAddress, err) 462 return nil, err 463 } 464 465 connInfo := &protoext.ConnectionInfo{ 466 ID: receivedMsg.PkiId, 467 Identity: receivedMsg.Identity, 468 Endpoint: remoteAddress, 469 Auth: &protoext.AuthInfo{ 470 Signature: m.Signature, 471 SignedData: m.Payload, 472 }, 473 } 474 475 // if TLS is enabled and detected, verify remote peer 476 if useTLS { 477 // If the remote peer sent its TLS certificate, make sure it actually matches the TLS cert 478 // that the peer used. 479 if !bytes.Equal(remoteCertHash, receivedMsg.TlsCertHash) { 480 return nil, errors.Errorf("Expected %v in remote hash of TLS cert, but got %v", remoteCertHash, receivedMsg.TlsCertHash) 481 } 482 } 483 // Final step - verify the signature on the connection message itself 484 verifier := func(peerIdentity []byte, signature, message []byte) error { 485 pkiID := c.idMapper.GetPKIidOfCert(peerIdentity) 486 return c.idMapper.Verify(pkiID, signature, message) 487 } 488 err = m.Verify(receivedMsg.Identity, verifier) 489 if err != nil { 490 c.logger.Errorf("Failed verifying signature from %s : %v", remoteAddress, err) 491 return nil, err 492 } 493 494 c.logger.Debug("Authenticated", remoteAddress) 495 496 if receivedMsg.Probe { 497 return connInfo, errProbe 498 } 499 500 return connInfo, nil 501 } 502 503 // SendWithAck sends a message to remote peers, waiting for acknowledgement from minAck of them, or until a certain timeout expires 504 func (c *commImpl) SendWithAck(msg *protoext.SignedGossipMessage, timeout time.Duration, minAck int, peers ...*RemotePeer) AggregatedSendResult { 505 if len(peers) == 0 { 506 return nil 507 } 508 var err error 509 510 // Roll a random NONCE to be used as a send ID to differentiate 511 // between different invocations 512 msg.Nonce = util.RandomUInt64() 513 // Replace the envelope in the message to update the NONCE 514 msg, err = protoext.NoopSign(msg.GossipMessage) 515 516 if c.isStopping() || err != nil { 517 if err == nil { 518 err = errors.New("comm is stopping") 519 } 520 results := []SendResult{} 521 for _, p := range peers { 522 results = append(results, SendResult{ 523 error: err, 524 RemotePeer: *p, 525 }) 526 } 527 return results 528 } 529 c.logger.Debug("Entering, sending", msg, "to ", len(peers), "peers") 530 sndFunc := func(peer *RemotePeer, msg *protoext.SignedGossipMessage) { 531 c.sendToEndpoint(peer, msg, blockingSend) 532 } 533 // Subscribe to acks 534 subscriptions := make(map[string]func() error) 535 for _, p := range peers { 536 topic := topicForAck(msg.Nonce, p.PKIID) 537 sub := c.pubSub.Subscribe(topic, timeout) 538 subscriptions[string(p.PKIID)] = func() error { 539 msg, err := sub.Listen() 540 if err != nil { 541 return err 542 } 543 if msg, isAck := msg.(*proto.Acknowledgement); !isAck { 544 return errors.Errorf("received a message of type %s, expected *proto.Acknowledgement", reflect.TypeOf(msg)) 545 } else { 546 if msg.Error != "" { 547 return errors.New(msg.Error) 548 } 549 } 550 return nil 551 } 552 } 553 waitForAck := func(p *RemotePeer) error { 554 return subscriptions[string(p.PKIID)]() 555 } 556 ackOperation := newAckSendOperation(sndFunc, waitForAck) 557 return ackOperation.send(msg, minAck, peers...) 558 } 559 560 func (c *commImpl) GossipStream(stream proto.Gossip_GossipStreamServer) error { 561 if c.isStopping() { 562 return errors.New("shutting down") 563 } 564 connInfo, err := c.authenticateRemotePeer(stream, false, false) 565 566 if err == errProbe { 567 c.logger.Infof("Peer %s (%s) probed us", connInfo.ID, connInfo.Endpoint) 568 return nil 569 } 570 571 if err != nil { 572 c.logger.Errorf("Authentication failed: %v", err) 573 return err 574 } 575 c.logger.Debug("Servicing", extractRemoteAddress(stream)) 576 577 conn := c.connStore.onConnected(stream, connInfo, c.metrics) 578 579 h := func(m *protoext.SignedGossipMessage) { 580 c.msgPublisher.DeMultiplex(&ReceivedMessageImpl{ 581 conn: conn, 582 SignedGossipMessage: m, 583 connInfo: connInfo, 584 }) 585 } 586 587 conn.handler = interceptAcks(h, connInfo.ID, c.pubSub) 588 589 defer func() { 590 c.logger.Debug("Client", extractRemoteAddress(stream), " disconnected") 591 c.connStore.closeConnByPKIid(connInfo.ID) 592 }() 593 594 return conn.serviceConnection() 595 } 596 597 func (c *commImpl) Ping(context.Context, *proto.Empty) (*proto.Empty, error) { 598 return &proto.Empty{}, nil 599 } 600 601 func (c *commImpl) disconnect(pkiID common.PKIidType) { 602 select { 603 case c.deadEndpoints <- pkiID: 604 case <-c.exitChan: 605 return 606 } 607 608 c.connStore.closeConnByPKIid(pkiID) 609 } 610 611 func readWithTimeout(stream stream, timeout time.Duration, address string) (*protoext.SignedGossipMessage, error) { 612 incChan := make(chan *protoext.SignedGossipMessage, 1) 613 errChan := make(chan error, 1) 614 go func() { 615 if m, err := stream.Recv(); err == nil { 616 msg, err := protoext.EnvelopeToGossipMessage(m) 617 if err != nil { 618 errChan <- err 619 return 620 } 621 incChan <- msg 622 } 623 }() 624 select { 625 case <-time.After(timeout): 626 return nil, errors.Errorf("timed out waiting for connection message from %s", address) 627 case m := <-incChan: 628 return m, nil 629 case err := <-errChan: 630 return nil, errors.WithStack(err) 631 } 632 } 633 634 func (c *commImpl) createConnectionMsg(pkiID common.PKIidType, certHash []byte, cert api.PeerIdentityType, signer protoext.Signer, isProbe bool) (*protoext.SignedGossipMessage, error) { 635 m := &proto.GossipMessage{ 636 Tag: proto.GossipMessage_EMPTY, 637 Nonce: 0, 638 Content: &proto.GossipMessage_Conn{ 639 Conn: &proto.ConnEstablish{ 640 TlsCertHash: certHash, 641 Identity: cert, 642 PkiId: pkiID, 643 Probe: isProbe, 644 }, 645 }, 646 } 647 sMsg := &protoext.SignedGossipMessage{ 648 GossipMessage: m, 649 } 650 _, err := sMsg.Sign(signer) 651 return sMsg, errors.WithStack(err) 652 } 653 654 type stream interface { 655 Send(envelope *proto.Envelope) error 656 Recv() (*proto.Envelope, error) 657 Context() context.Context 658 } 659 660 func topicForAck(nonce uint64, pkiID common.PKIidType) string { 661 return fmt.Sprintf("%d %s", nonce, hex.EncodeToString(pkiID)) 662 }