github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/creds.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm 8 9 import ( 10 "context" 11 "crypto/tls" 12 "crypto/x509" 13 "errors" 14 "net" 15 "sync" 16 "time" 17 18 "github.com/hechain20/hechain/common/flogging" 19 "google.golang.org/grpc/credentials" 20 ) 21 22 var ( 23 ErrClientHandshakeNotImplemented = errors.New("core/comm: client handshakes are not implemented with serverCreds") 24 ErrServerHandshakeNotImplemented = errors.New("core/comm: server handshakes are not implemented with clientCreds") 25 ErrOverrideHostnameNotSupported = errors.New("core/comm: OverrideServerName is not supported") 26 27 // alpnProtoStr are the specified application level protocols for gRPC. 28 alpnProtoStr = []string{"h2"} 29 30 // Logger for TLS client connections 31 tlsClientLogger = flogging.MustGetLogger("comm.tls") 32 ) 33 34 // NewServerTransportCredentials returns a new initialized 35 // grpc/credentials.TransportCredentials 36 func NewServerTransportCredentials( 37 serverConfig *TLSConfig, 38 logger *flogging.FabricLogger) credentials.TransportCredentials { 39 // NOTE: unlike the default grpc/credentials implementation, we do not 40 // clone the tls.Config which allows us to update it dynamically 41 serverConfig.config.NextProtos = alpnProtoStr 42 serverConfig.config.MinVersion = tls.VersionTLS12 43 44 if logger == nil { 45 logger = tlsClientLogger 46 } 47 48 return &serverCreds{ 49 serverConfig: serverConfig, 50 logger: logger, 51 } 52 } 53 54 // serverCreds is an implementation of grpc/credentials.TransportCredentials. 55 type serverCreds struct { 56 serverConfig *TLSConfig 57 logger *flogging.FabricLogger 58 } 59 60 type TLSConfig struct { 61 config *tls.Config 62 lock sync.RWMutex 63 } 64 65 func NewTLSConfig(config *tls.Config) *TLSConfig { 66 return &TLSConfig{ 67 config: config, 68 } 69 } 70 71 func (t *TLSConfig) Config() tls.Config { 72 t.lock.RLock() 73 defer t.lock.RUnlock() 74 75 if t.config != nil { 76 return *t.config.Clone() 77 } 78 79 return tls.Config{} 80 } 81 82 func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) { 83 t.lock.Lock() 84 defer t.lock.Unlock() 85 86 t.config.ClientCAs.AddCert(cert) 87 } 88 89 func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) { 90 t.lock.Lock() 91 defer t.lock.Unlock() 92 93 t.config.ClientCAs = certPool 94 } 95 96 // ClientHandShake is not implemented for `serverCreds`. 97 func (sc *serverCreds) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) { 98 return nil, nil, ErrClientHandshakeNotImplemented 99 } 100 101 // ServerHandshake does the authentication handshake for servers. 102 func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 103 serverConfig := sc.serverConfig.Config() 104 105 conn := tls.Server(rawConn, &serverConfig) 106 l := sc.logger.With("remote address", conn.RemoteAddr().String()) 107 start := time.Now() 108 if err := conn.Handshake(); err != nil { 109 l.Errorf("Server TLS handshake failed in %s with error %s", time.Since(start), err) 110 return nil, nil, err 111 } 112 l.Debugf("Server TLS handshake completed in %s", time.Since(start)) 113 return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil 114 } 115 116 // Info provides the ProtocolInfo of this TransportCredentials. 117 func (sc *serverCreds) Info() credentials.ProtocolInfo { 118 return credentials.ProtocolInfo{ 119 SecurityProtocol: "tls", 120 SecurityVersion: "1.2", 121 } 122 } 123 124 // Clone makes a copy of this TransportCredentials. 125 func (sc *serverCreds) Clone() credentials.TransportCredentials { 126 config := sc.serverConfig.Config() 127 serverConfig := NewTLSConfig(&config) 128 return NewServerTransportCredentials(serverConfig, sc.logger) 129 } 130 131 // OverrideServerName overrides the server name used to verify the hostname 132 // on the returned certificates from the server. 133 func (sc *serverCreds) OverrideServerName(string) error { 134 return ErrOverrideHostnameNotSupported 135 } 136 137 type DynamicClientCredentials struct { 138 TLSConfig *tls.Config 139 } 140 141 func (dtc *DynamicClientCredentials) latestConfig() *tls.Config { 142 return dtc.TLSConfig.Clone() 143 } 144 145 func (dtc *DynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 146 l := tlsClientLogger.With("remote address", rawConn.RemoteAddr().String()) 147 creds := credentials.NewTLS(dtc.latestConfig()) 148 start := time.Now() 149 conn, auth, err := creds.ClientHandshake(ctx, authority, rawConn) 150 if err != nil { 151 l.Errorf("Client TLS handshake failed after %s with error: %s", time.Since(start), err) 152 } else { 153 l.Debugf("Client TLS handshake completed in %s", time.Since(start)) 154 } 155 return conn, auth, err 156 } 157 158 func (dtc *DynamicClientCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 159 return nil, nil, ErrServerHandshakeNotImplemented 160 } 161 162 func (dtc *DynamicClientCredentials) Info() credentials.ProtocolInfo { 163 return credentials.NewTLS(dtc.latestConfig()).Info() 164 } 165 166 func (dtc *DynamicClientCredentials) Clone() credentials.TransportCredentials { 167 return credentials.NewTLS(dtc.latestConfig()) 168 } 169 170 func (dtc *DynamicClientCredentials) OverrideServerName(name string) error { 171 dtc.TLSConfig.ServerName = name 172 return nil 173 }