github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/internal/pkg/comm/creds.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"errors"
    14  	"net"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/osdi23p228/fabric/common/flogging"
    19  	"google.golang.org/grpc/credentials"
    20  )
    21  
    22  var (
    23  	ErrClientHandshakeNotImplemented = errors.New("core/comm: client handshakes are not implemented with serverCreds")
    24  	ErrServerHandshakeNotImplemented = errors.New("core/comm: server handshakes are not implemented with clientCreds")
    25  	ErrOverrideHostnameNotSupported  = errors.New("core/comm: OverrideServerName is not supported")
    26  
    27  	// alpnProtoStr are the specified application level protocols for gRPC.
    28  	alpnProtoStr = []string{"h2"}
    29  
    30  	// Logger for TLS client connections
    31  	tlsClientLogger = flogging.MustGetLogger("comm.tls")
    32  )
    33  
    34  // NewServerTransportCredentials returns a new initialized
    35  // grpc/credentials.TransportCredentials
    36  func NewServerTransportCredentials(
    37  	serverConfig *TLSConfig,
    38  	logger *flogging.FabricLogger) credentials.TransportCredentials {
    39  	// NOTE: unlike the default grpc/credentials implementation, we do not
    40  	// clone the tls.Config which allows us to update it dynamically
    41  	serverConfig.config.NextProtos = alpnProtoStr
    42  	serverConfig.config.MinVersion = tls.VersionTLS12
    43  
    44  	if logger == nil {
    45  		logger = tlsClientLogger
    46  	}
    47  
    48  	return &serverCreds{
    49  		serverConfig: serverConfig,
    50  		logger:       logger}
    51  }
    52  
    53  // serverCreds is an implementation of grpc/credentials.TransportCredentials.
    54  type serverCreds struct {
    55  	serverConfig *TLSConfig
    56  	logger       *flogging.FabricLogger
    57  }
    58  
    59  type TLSConfig struct {
    60  	config *tls.Config
    61  	lock   sync.RWMutex
    62  }
    63  
    64  func NewTLSConfig(config *tls.Config) *TLSConfig {
    65  	return &TLSConfig{
    66  		config: config,
    67  	}
    68  }
    69  
    70  func (t *TLSConfig) Config() tls.Config {
    71  	t.lock.RLock()
    72  	defer t.lock.RUnlock()
    73  
    74  	if t.config != nil {
    75  		return *t.config.Clone()
    76  	}
    77  
    78  	return tls.Config{}
    79  }
    80  
    81  func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) {
    82  	t.lock.Lock()
    83  	defer t.lock.Unlock()
    84  
    85  	t.config.ClientCAs.AddCert(cert)
    86  }
    87  
    88  func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) {
    89  	t.lock.Lock()
    90  	defer t.lock.Unlock()
    91  
    92  	t.config.ClientCAs = certPool
    93  }
    94  
    95  // ClientHandShake is not implemented for `serverCreds`.
    96  func (sc *serverCreds) ClientHandshake(context.Context,
    97  	string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
    98  	return nil, nil, ErrClientHandshakeNotImplemented
    99  }
   100  
   101  // ServerHandshake does the authentication handshake for servers.
   102  func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   103  	serverConfig := sc.serverConfig.Config()
   104  
   105  	conn := tls.Server(rawConn, &serverConfig)
   106  	l := sc.logger.With("remote address", conn.RemoteAddr().String())
   107  	start := time.Now()
   108  	if err := conn.Handshake(); err != nil {
   109  		l.Errorf("Server TLS handshake failed in %s with error %s", time.Since(start), err)
   110  		return nil, nil, err
   111  	}
   112  	l.Debugf("Server TLS handshake completed in %s", time.Since(start))
   113  	return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
   114  }
   115  
   116  // Info provides the ProtocolInfo of this TransportCredentials.
   117  func (sc *serverCreds) Info() credentials.ProtocolInfo {
   118  	return credentials.ProtocolInfo{
   119  		SecurityProtocol: "tls",
   120  		SecurityVersion:  "1.2",
   121  	}
   122  }
   123  
   124  // Clone makes a copy of this TransportCredentials.
   125  func (sc *serverCreds) Clone() credentials.TransportCredentials {
   126  	config := sc.serverConfig.Config()
   127  	serverConfig := NewTLSConfig(&config)
   128  	return NewServerTransportCredentials(serverConfig, sc.logger)
   129  }
   130  
   131  // OverrideServerName overrides the server name used to verify the hostname
   132  // on the returned certificates from the server.
   133  func (sc *serverCreds) OverrideServerName(string) error {
   134  	return ErrOverrideHostnameNotSupported
   135  }
   136  
   137  type DynamicClientCredentials struct {
   138  	TLSConfig  *tls.Config
   139  	TLSOptions []TLSOption
   140  }
   141  
   142  func (dtc *DynamicClientCredentials) latestConfig() *tls.Config {
   143  	tlsConfigCopy := dtc.TLSConfig.Clone()
   144  	for _, tlsOption := range dtc.TLSOptions {
   145  		tlsOption(tlsConfigCopy)
   146  	}
   147  	return tlsConfigCopy
   148  }
   149  
   150  func (dtc *DynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   151  	l := tlsClientLogger.With("remote address", rawConn.RemoteAddr().String())
   152  	creds := credentials.NewTLS(dtc.latestConfig())
   153  	start := time.Now()
   154  	conn, auth, err := creds.ClientHandshake(ctx, authority, rawConn)
   155  	if err != nil {
   156  		l.Errorf("Client TLS handshake failed after %s with error: %s", time.Since(start), err)
   157  	} else {
   158  		l.Debugf("Client TLS handshake completed in %s", time.Since(start))
   159  	}
   160  	return conn, auth, err
   161  }
   162  
   163  func (dtc *DynamicClientCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   164  	return nil, nil, ErrServerHandshakeNotImplemented
   165  }
   166  
   167  func (dtc *DynamicClientCredentials) Info() credentials.ProtocolInfo {
   168  	return credentials.NewTLS(dtc.latestConfig()).Info()
   169  }
   170  
   171  func (dtc *DynamicClientCredentials) Clone() credentials.TransportCredentials {
   172  	return credentials.NewTLS(dtc.latestConfig())
   173  }
   174  
   175  func (dtc *DynamicClientCredentials) OverrideServerName(name string) error {
   176  	dtc.TLSConfig.ServerName = name
   177  	return nil
   178  }