github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/orderer/common/cluster/connections.go (about) 1 /* 2 Copyright hechain. 2017 All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package cluster 8 9 import ( 10 "crypto/x509" 11 "sync" 12 13 "github.com/hechain20/hechain/common/crypto" 14 "github.com/hechain20/hechain/common/metrics" 15 "github.com/pkg/errors" 16 "google.golang.org/grpc" 17 ) 18 19 // RemoteVerifier verifies the connection to the remote host 20 type RemoteVerifier func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error 21 22 //go:generate mockery -dir . -name SecureDialer -case underscore -output ./mocks/ 23 24 // SecureDialer connects to a remote address 25 type SecureDialer interface { 26 Dial(address string, verifyFunc RemoteVerifier) (*grpc.ClientConn, error) 27 } 28 29 // ConnectionMapper maps certificates to connections 30 type ConnectionMapper interface { 31 Lookup(cert []byte) (*grpc.ClientConn, bool) 32 Put(cert []byte, conn *grpc.ClientConn) 33 Remove(cert []byte) 34 Size() int 35 } 36 37 // ConnectionStore stores connections to remote nodes 38 type ConnectionStore struct { 39 lock sync.RWMutex 40 Connections ConnectionMapper 41 dialer SecureDialer 42 } 43 44 // NewConnectionStore creates a new ConnectionStore with the given SecureDialer 45 func NewConnectionStore(dialer SecureDialer, tlsConnectionCount metrics.Gauge) *ConnectionStore { 46 connMapping := &ConnectionStore{ 47 Connections: &connMapperReporter{ 48 ConnectionMapper: make(ConnByCertMap), 49 tlsConnectionCountMetrics: tlsConnectionCount, 50 }, 51 dialer: dialer, 52 } 53 return connMapping 54 } 55 56 // verifyHandshake returns a predicate that verifies that the remote node authenticates 57 // itself with the given TLS certificate 58 func (c *ConnectionStore) verifyHandshake(endpoint string, certificate []byte) RemoteVerifier { 59 return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 60 err := crypto.CertificatesWithSamePublicKey(certificate, rawCerts[0]) 61 if err == nil { 62 return nil 63 } 64 return errors.Errorf("public key of server certificate presented by %s doesn't match the expected public key", 65 endpoint) 66 } 67 } 68 69 // Disconnect closes the gRPC connection that is mapped to the given certificate 70 func (c *ConnectionStore) Disconnect(expectedServerCert []byte) { 71 c.lock.Lock() 72 defer c.lock.Unlock() 73 74 conn, connected := c.Connections.Lookup(expectedServerCert) 75 if !connected { 76 return 77 } 78 conn.Close() 79 c.Connections.Remove(expectedServerCert) 80 } 81 82 // Connection obtains a connection to the given endpoint and expects the given server certificate 83 // to be presented by the remote node 84 func (c *ConnectionStore) Connection(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) { 85 c.lock.RLock() 86 conn, alreadyConnected := c.Connections.Lookup(expectedServerCert) 87 c.lock.RUnlock() 88 89 if alreadyConnected { 90 return conn, nil 91 } 92 93 // Else, we need to connect to the remote endpoint 94 return c.connect(endpoint, expectedServerCert) 95 } 96 97 // connect connects to the given endpoint and expects the given TLS server certificate 98 // to be presented at the time of authentication 99 func (c *ConnectionStore) connect(endpoint string, expectedServerCert []byte) (*grpc.ClientConn, error) { 100 c.lock.Lock() 101 defer c.lock.Unlock() 102 // Check again to see if some other goroutine has already connected while 103 // we were waiting on the lock 104 conn, alreadyConnected := c.Connections.Lookup(expectedServerCert) 105 if alreadyConnected { 106 return conn, nil 107 } 108 109 v := c.verifyHandshake(endpoint, expectedServerCert) 110 conn, err := c.dialer.Dial(endpoint, v) 111 if err != nil { 112 return nil, err 113 } 114 115 c.Connections.Put(expectedServerCert, conn) 116 return conn, nil 117 } 118 119 type connMapperReporter struct { 120 tlsConnectionCountMetrics metrics.Gauge 121 ConnectionMapper 122 } 123 124 func (cmg *connMapperReporter) Put(cert []byte, conn *grpc.ClientConn) { 125 cmg.ConnectionMapper.Put(cert, conn) 126 cmg.reportSize() 127 } 128 129 func (cmg *connMapperReporter) Remove(cert []byte) { 130 cmg.ConnectionMapper.Remove(cert) 131 cmg.reportSize() 132 } 133 134 func (cmg *connMapperReporter) reportSize() { 135 cmg.tlsConnectionCountMetrics.Set(float64(cmg.ConnectionMapper.Size())) 136 }