github.com/osdi23p228/fabric@v0.0.0-20221218062954-77808885f5db/internal/pkg/comm/client.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"time"
    14  
    15  	"github.com/pkg/errors"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/keepalive"
    18  )
    19  
    20  type GRPCClient struct {
    21  	// TLS configuration used by the grpc.ClientConn
    22  	tlsConfig *tls.Config
    23  	// Options for setting up new connections
    24  	dialOpts []grpc.DialOption
    25  	// Duration for which to block while established a new connection
    26  	timeout time.Duration
    27  	// Maximum message size the client can receive
    28  	maxRecvMsgSize int
    29  	// Maximum message size the client can send
    30  	maxSendMsgSize int
    31  }
    32  
    33  // NewGRPCClient creates a new implementation of GRPCClient given an address
    34  // and client configuration
    35  func NewGRPCClient(config ClientConfig) (*GRPCClient, error) {
    36  	client := &GRPCClient{}
    37  
    38  	// parse secure options
    39  	err := client.parseSecureOptions(config.SecOpts)
    40  	if err != nil {
    41  		return client, err
    42  	}
    43  
    44  	// keepalive options
    45  
    46  	kap := keepalive.ClientParameters{
    47  		Time:                config.KaOpts.ClientInterval,
    48  		Timeout:             config.KaOpts.ClientTimeout,
    49  		PermitWithoutStream: true,
    50  	}
    51  	// set keepalive
    52  	client.dialOpts = append(client.dialOpts, grpc.WithKeepaliveParams(kap))
    53  	// Unless asynchronous connect is set, make connection establishment blocking.
    54  	if !config.AsyncConnect {
    55  		client.dialOpts = append(client.dialOpts, grpc.WithBlock())
    56  		client.dialOpts = append(client.dialOpts, grpc.FailOnNonTempDialError(true))
    57  	}
    58  	client.timeout = config.Timeout
    59  	// set send/recv message size to package defaults
    60  	client.maxRecvMsgSize = MaxRecvMsgSize
    61  	client.maxSendMsgSize = MaxSendMsgSize
    62  
    63  	return client, nil
    64  }
    65  
    66  func (client *GRPCClient) parseSecureOptions(opts SecureOptions) error {
    67  	// if TLS is not enabled, return
    68  	if !opts.UseTLS {
    69  		return nil
    70  	}
    71  
    72  	client.tlsConfig = &tls.Config{
    73  		VerifyPeerCertificate: opts.VerifyCertificate,
    74  		MinVersion:            tls.VersionTLS12,
    75  	}
    76  	if len(opts.ServerRootCAs) > 0 {
    77  		client.tlsConfig.RootCAs = x509.NewCertPool()
    78  		for _, certBytes := range opts.ServerRootCAs {
    79  			err := AddPemToCertPool(certBytes, client.tlsConfig.RootCAs)
    80  			if err != nil {
    81  				commLogger.Debugf("error adding root certificate: %v", err)
    82  				return errors.WithMessage(err, "error adding root certificate")
    83  			}
    84  		}
    85  	}
    86  	if opts.RequireClientCert {
    87  		// make sure we have both Key and Certificate
    88  		if opts.Key != nil &&
    89  			opts.Certificate != nil {
    90  			cert, err := tls.X509KeyPair(opts.Certificate,
    91  				opts.Key)
    92  			if err != nil {
    93  				return errors.WithMessage(err, "failed to load client certificate")
    94  			}
    95  			client.tlsConfig.Certificates = append(
    96  				client.tlsConfig.Certificates, cert)
    97  		} else {
    98  			return errors.New("both Key and Certificate are required when using mutual TLS")
    99  		}
   100  	}
   101  
   102  	if opts.TimeShift > 0 {
   103  		client.tlsConfig.Time = func() time.Time {
   104  			return time.Now().Add((-1) * opts.TimeShift)
   105  		}
   106  	}
   107  
   108  	return nil
   109  }
   110  
   111  // Certificate returns the tls.Certificate used to make TLS connections
   112  // when client certificates are required by the server
   113  func (client *GRPCClient) Certificate() tls.Certificate {
   114  	cert := tls.Certificate{}
   115  	if client.tlsConfig != nil && len(client.tlsConfig.Certificates) > 0 {
   116  		cert = client.tlsConfig.Certificates[0]
   117  	}
   118  	return cert
   119  }
   120  
   121  // TLSEnabled is a flag indicating whether to use TLS for client
   122  // connections
   123  func (client *GRPCClient) TLSEnabled() bool {
   124  	return client.tlsConfig != nil
   125  }
   126  
   127  // MutualTLSRequired is a flag indicating whether the client
   128  // must send a certificate when making TLS connections
   129  func (client *GRPCClient) MutualTLSRequired() bool {
   130  	return client.tlsConfig != nil &&
   131  		len(client.tlsConfig.Certificates) > 0
   132  }
   133  
   134  // SetMaxRecvMsgSize sets the maximum message size the client can receive
   135  func (client *GRPCClient) SetMaxRecvMsgSize(size int) {
   136  	client.maxRecvMsgSize = size
   137  }
   138  
   139  // SetMaxSendMsgSize sets the maximum message size the client can send
   140  func (client *GRPCClient) SetMaxSendMsgSize(size int) {
   141  	client.maxSendMsgSize = size
   142  }
   143  
   144  // SetServerRootCAs sets the list of authorities used to verify server
   145  // certificates based on a list of PEM-encoded X509 certificate authorities
   146  func (client *GRPCClient) SetServerRootCAs(serverRoots [][]byte) error {
   147  
   148  	// NOTE: if no serverRoots are specified, the current cert pool will be
   149  	// replaced with an empty one
   150  	certPool := x509.NewCertPool()
   151  	for _, root := range serverRoots {
   152  		err := AddPemToCertPool(root, certPool)
   153  		if err != nil {
   154  			return errors.WithMessage(err, "error adding root certificate")
   155  		}
   156  	}
   157  	client.tlsConfig.RootCAs = certPool
   158  	return nil
   159  }
   160  
   161  type TLSOption func(tlsConfig *tls.Config)
   162  
   163  func ServerNameOverride(name string) TLSOption {
   164  	return func(tlsConfig *tls.Config) {
   165  		tlsConfig.ServerName = name
   166  	}
   167  }
   168  
   169  func CertPoolOverride(pool *x509.CertPool) TLSOption {
   170  	return func(tlsConfig *tls.Config) {
   171  		tlsConfig.RootCAs = pool
   172  	}
   173  }
   174  
   175  // NewConnection returns a grpc.ClientConn for the target address and
   176  // overrides the server name used to verify the hostname on the
   177  // certificate returned by a server when using TLS
   178  func (client *GRPCClient) NewConnection(address string, tlsOptions ...TLSOption) (*grpc.ClientConn, error) {
   179  
   180  	var dialOpts []grpc.DialOption
   181  	dialOpts = append(dialOpts, client.dialOpts...)
   182  
   183  	// set transport credentials and max send/recv message sizes
   184  	// immediately before creating a connection in order to allow
   185  	// SetServerRootCAs / SetMaxRecvMsgSize / SetMaxSendMsgSize
   186  	//  to take effect on a per connection basis
   187  	if client.tlsConfig != nil {
   188  		dialOpts = append(dialOpts, grpc.WithTransportCredentials(
   189  			&DynamicClientCredentials{
   190  				TLSConfig:  client.tlsConfig,
   191  				TLSOptions: tlsOptions,
   192  			},
   193  		))
   194  	} else {
   195  		dialOpts = append(dialOpts, grpc.WithInsecure())
   196  	}
   197  
   198  	dialOpts = append(dialOpts, grpc.WithDefaultCallOptions(
   199  		grpc.MaxCallRecvMsgSize(client.maxRecvMsgSize),
   200  		grpc.MaxCallSendMsgSize(client.maxSendMsgSize),
   201  	))
   202  
   203  	ctx, cancel := context.WithTimeout(context.Background(), client.timeout)
   204  	defer cancel()
   205  	conn, err := grpc.DialContext(ctx, address, dialOpts...)
   206  	if err != nil {
   207  		return nil, errors.WithMessage(errors.WithStack(err),
   208  			"failed to create new connection")
   209  	}
   210  	return conn, nil
   211  }