github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/creds.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"errors"
    14  	"net"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/hechain20/hechain/common/flogging"
    19  	"google.golang.org/grpc/credentials"
    20  )
    21  
    22  var (
    23  	ErrClientHandshakeNotImplemented = errors.New("core/comm: client handshakes are not implemented with serverCreds")
    24  	ErrServerHandshakeNotImplemented = errors.New("core/comm: server handshakes are not implemented with clientCreds")
    25  	ErrOverrideHostnameNotSupported  = errors.New("core/comm: OverrideServerName is not supported")
    26  
    27  	// alpnProtoStr are the specified application level protocols for gRPC.
    28  	alpnProtoStr = []string{"h2"}
    29  
    30  	// Logger for TLS client connections
    31  	tlsClientLogger = flogging.MustGetLogger("comm.tls")
    32  )
    33  
    34  // NewServerTransportCredentials returns a new initialized
    35  // grpc/credentials.TransportCredentials
    36  func NewServerTransportCredentials(
    37  	serverConfig *TLSConfig,
    38  	logger *flogging.FabricLogger) credentials.TransportCredentials {
    39  	// NOTE: unlike the default grpc/credentials implementation, we do not
    40  	// clone the tls.Config which allows us to update it dynamically
    41  	serverConfig.config.NextProtos = alpnProtoStr
    42  	serverConfig.config.MinVersion = tls.VersionTLS12
    43  
    44  	if logger == nil {
    45  		logger = tlsClientLogger
    46  	}
    47  
    48  	return &serverCreds{
    49  		serverConfig: serverConfig,
    50  		logger:       logger,
    51  	}
    52  }
    53  
    54  // serverCreds is an implementation of grpc/credentials.TransportCredentials.
    55  type serverCreds struct {
    56  	serverConfig *TLSConfig
    57  	logger       *flogging.FabricLogger
    58  }
    59  
    60  type TLSConfig struct {
    61  	config *tls.Config
    62  	lock   sync.RWMutex
    63  }
    64  
    65  func NewTLSConfig(config *tls.Config) *TLSConfig {
    66  	return &TLSConfig{
    67  		config: config,
    68  	}
    69  }
    70  
    71  func (t *TLSConfig) Config() tls.Config {
    72  	t.lock.RLock()
    73  	defer t.lock.RUnlock()
    74  
    75  	if t.config != nil {
    76  		return *t.config.Clone()
    77  	}
    78  
    79  	return tls.Config{}
    80  }
    81  
    82  func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) {
    83  	t.lock.Lock()
    84  	defer t.lock.Unlock()
    85  
    86  	t.config.ClientCAs.AddCert(cert)
    87  }
    88  
    89  func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) {
    90  	t.lock.Lock()
    91  	defer t.lock.Unlock()
    92  
    93  	t.config.ClientCAs = certPool
    94  }
    95  
    96  // ClientHandShake is not implemented for `serverCreds`.
    97  func (sc *serverCreds) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
    98  	return nil, nil, ErrClientHandshakeNotImplemented
    99  }
   100  
   101  // ServerHandshake does the authentication handshake for servers.
   102  func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   103  	serverConfig := sc.serverConfig.Config()
   104  
   105  	conn := tls.Server(rawConn, &serverConfig)
   106  	l := sc.logger.With("remote address", conn.RemoteAddr().String())
   107  	start := time.Now()
   108  	if err := conn.Handshake(); err != nil {
   109  		l.Errorf("Server TLS handshake failed in %s with error %s", time.Since(start), err)
   110  		return nil, nil, err
   111  	}
   112  	l.Debugf("Server TLS handshake completed in %s", time.Since(start))
   113  	return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
   114  }
   115  
   116  // Info provides the ProtocolInfo of this TransportCredentials.
   117  func (sc *serverCreds) Info() credentials.ProtocolInfo {
   118  	return credentials.ProtocolInfo{
   119  		SecurityProtocol: "tls",
   120  		SecurityVersion:  "1.2",
   121  	}
   122  }
   123  
   124  // Clone makes a copy of this TransportCredentials.
   125  func (sc *serverCreds) Clone() credentials.TransportCredentials {
   126  	config := sc.serverConfig.Config()
   127  	serverConfig := NewTLSConfig(&config)
   128  	return NewServerTransportCredentials(serverConfig, sc.logger)
   129  }
   130  
   131  // OverrideServerName overrides the server name used to verify the hostname
   132  // on the returned certificates from the server.
   133  func (sc *serverCreds) OverrideServerName(string) error {
   134  	return ErrOverrideHostnameNotSupported
   135  }
   136  
   137  type DynamicClientCredentials struct {
   138  	TLSConfig *tls.Config
   139  }
   140  
   141  func (dtc *DynamicClientCredentials) latestConfig() *tls.Config {
   142  	return dtc.TLSConfig.Clone()
   143  }
   144  
   145  func (dtc *DynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   146  	l := tlsClientLogger.With("remote address", rawConn.RemoteAddr().String())
   147  	creds := credentials.NewTLS(dtc.latestConfig())
   148  	start := time.Now()
   149  	conn, auth, err := creds.ClientHandshake(ctx, authority, rawConn)
   150  	if err != nil {
   151  		l.Errorf("Client TLS handshake failed after %s with error: %s", time.Since(start), err)
   152  	} else {
   153  		l.Debugf("Client TLS handshake completed in %s", time.Since(start))
   154  	}
   155  	return conn, auth, err
   156  }
   157  
   158  func (dtc *DynamicClientCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   159  	return nil, nil, ErrServerHandshakeNotImplemented
   160  }
   161  
   162  func (dtc *DynamicClientCredentials) Info() credentials.ProtocolInfo {
   163  	return credentials.NewTLS(dtc.latestConfig()).Info()
   164  }
   165  
   166  func (dtc *DynamicClientCredentials) Clone() credentials.TransportCredentials {
   167  	return credentials.NewTLS(dtc.latestConfig())
   168  }
   169  
   170  func (dtc *DynamicClientCredentials) OverrideServerName(name string) error {
   171  	dtc.TLSConfig.ServerName = name
   172  	return nil
   173  }