github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/server.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"encoding/pem"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    19  	"github.com/pkg/errors"
    20  	"google.golang.org/grpc"
    21  	"google.golang.org/grpc/health"
    22  	healthpb "google.golang.org/grpc/health/grpc_health_v1"
    23  )
    24  
// GRPCServer bundles a grpc.Server with the listener it serves on, the
// certificate and TLS configuration it presents to clients (when TLS is
// enabled), and an optional gRPC health service.
type GRPCServer struct {
	// Listen address for the server, taken from the listener's Addr()
	address string
	// Listener for handling network requests
	listener net.Listener
	// GRPC server
	server *grpc.Server
	// Certificate presented by the server for TLS communication,
	// stored as an atomic reference holding a tls.Certificate value
	// so it can be replaced while the server is running
	serverCertificate atomic.Value
	// lock serializing updates to the client root CAs
	// (held for the duration of SetClientRootCAs)
	lock *sync.Mutex
	// TLS configuration used by the grpc server; nil when TLS is not enabled
	tls *TLSConfig
	// Server for gRPC Health Check Protocol; nil unless health checks
	// were enabled in the server configuration
	healthServer *health.Server
}
    42  
    43  // NewGRPCServer creates a new implementation of a GRPCServer given a
    44  // listen address
    45  func NewGRPCServer(address string, serverConfig ServerConfig) (*GRPCServer, error) {
    46  	if address == "" {
    47  		return nil, errors.New("missing address parameter")
    48  	}
    49  	// create our listener
    50  	lis, err := net.Listen("tcp", address)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	return NewGRPCServerFromListener(lis, serverConfig)
    55  }
    56  
    57  // NewGRPCServerFromListener creates a new implementation of a GRPCServer given
    58  // an existing net.Listener instance using default keepalive
    59  func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig) (*GRPCServer, error) {
    60  	grpcServer := &GRPCServer{
    61  		address:  listener.Addr().String(),
    62  		listener: listener,
    63  		lock:     &sync.Mutex{},
    64  	}
    65  
    66  	// set up our server options
    67  	var serverOpts []grpc.ServerOption
    68  
    69  	secureConfig := serverConfig.SecOpts
    70  	if secureConfig.UseTLS {
    71  		// both key and cert are required
    72  		if secureConfig.Key != nil && secureConfig.Certificate != nil {
    73  			// load server public and private keys
    74  			cert, err := tls.X509KeyPair(secureConfig.Certificate, secureConfig.Key)
    75  			if err != nil {
    76  				return nil, err
    77  			}
    78  
    79  			grpcServer.serverCertificate.Store(cert)
    80  
    81  			// set up our TLS config
    82  			if len(secureConfig.CipherSuites) == 0 {
    83  				secureConfig.CipherSuites = DefaultTLSCipherSuites
    84  			}
    85  			getCert := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
    86  				cert := grpcServer.serverCertificate.Load().(tls.Certificate)
    87  				return &cert, nil
    88  			}
    89  
    90  			grpcServer.tls = NewTLSConfig(&tls.Config{
    91  				VerifyPeerCertificate:  secureConfig.VerifyCertificate,
    92  				GetCertificate:         getCert,
    93  				SessionTicketsDisabled: true,
    94  				CipherSuites:           secureConfig.CipherSuites,
    95  			})
    96  
    97  			if serverConfig.SecOpts.TimeShift > 0 {
    98  				timeShift := serverConfig.SecOpts.TimeShift
    99  				grpcServer.tls.config.Time = func() time.Time {
   100  					return time.Now().Add((-1) * timeShift)
   101  				}
   102  			}
   103  			grpcServer.tls.config.ClientAuth = tls.RequestClientCert
   104  			// check if client authentication is required
   105  			if secureConfig.RequireClientCert {
   106  				// require TLS client auth
   107  				grpcServer.tls.config.ClientAuth = tls.RequireAndVerifyClientCert
   108  				// if we have client root CAs, create a certPool
   109  				if len(secureConfig.ClientRootCAs) > 0 {
   110  					grpcServer.tls.config.ClientCAs = x509.NewCertPool()
   111  					for _, clientRootCA := range secureConfig.ClientRootCAs {
   112  						err = grpcServer.appendClientRootCA(clientRootCA)
   113  						if err != nil {
   114  							return nil, err
   115  						}
   116  					}
   117  				}
   118  			}
   119  
   120  			// create credentials and add to server options
   121  			creds := NewServerTransportCredentials(grpcServer.tls, serverConfig.Logger)
   122  			serverOpts = append(serverOpts, grpc.Creds(creds))
   123  		} else {
   124  			return nil, errors.New("serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
   125  		}
   126  	}
   127  
   128  	// set max send and recv msg sizes
   129  	maxSendMsgSize := DefaultMaxSendMsgSize
   130  	if serverConfig.MaxSendMsgSize != 0 {
   131  		maxSendMsgSize = serverConfig.MaxSendMsgSize
   132  	}
   133  	maxRecvMsgSize := DefaultMaxRecvMsgSize
   134  	if serverConfig.MaxRecvMsgSize != 0 {
   135  		maxRecvMsgSize = serverConfig.MaxRecvMsgSize
   136  	}
   137  	serverOpts = append(serverOpts, grpc.MaxSendMsgSize(maxSendMsgSize))
   138  	serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(maxRecvMsgSize))
   139  	// set the keepalive options
   140  	serverOpts = append(serverOpts, serverConfig.KaOpts.ServerKeepaliveOptions()...)
   141  	// set connection timeout
   142  	if serverConfig.ConnectionTimeout <= 0 {
   143  		serverConfig.ConnectionTimeout = DefaultConnectionTimeout
   144  	}
   145  	serverOpts = append(
   146  		serverOpts,
   147  		grpc.ConnectionTimeout(serverConfig.ConnectionTimeout))
   148  	// set the interceptors
   149  	if len(serverConfig.StreamInterceptors) > 0 {
   150  		serverOpts = append(
   151  			serverOpts,
   152  			grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(serverConfig.StreamInterceptors...)),
   153  		)
   154  	}
   155  
   156  	if len(serverConfig.UnaryInterceptors) > 0 {
   157  		serverOpts = append(
   158  			serverOpts,
   159  			grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(serverConfig.UnaryInterceptors...)),
   160  		)
   161  	}
   162  
   163  	if serverConfig.ServerStatsHandler != nil {
   164  		serverOpts = append(serverOpts, grpc.StatsHandler(serverConfig.ServerStatsHandler))
   165  	}
   166  
   167  	grpcServer.server = grpc.NewServer(serverOpts...)
   168  
   169  	if serverConfig.HealthCheckEnabled {
   170  		grpcServer.healthServer = health.NewServer()
   171  		healthpb.RegisterHealthServer(grpcServer.server, grpcServer.healthServer)
   172  	}
   173  
   174  	return grpcServer, nil
   175  }
   176  
// SetServerCertificate assigns the current TLS certificate to be the peer's server certificate.
// The value is stored atomically, so the GetCertificate callback installed by
// NewGRPCServerFromListener picks it up the next time it loads the certificate.
func (gServer *GRPCServer) SetServerCertificate(cert tls.Certificate) {
	gServer.serverCertificate.Store(cert)
}
   181  
// Address returns the listen address for this GRPCServer instance, as
// recorded from the listener when the server was created.
func (gServer *GRPCServer) Address() string {
	return gServer.address
}
   186  
// Listener returns the net.Listener for the GRPCServer instance, i.e. the
// listener that Start passes to the underlying grpc.Server.
func (gServer *GRPCServer) Listener() net.Listener {
	return gServer.listener
}
   191  
// Server returns the grpc.Server for the GRPCServer instance. This is the
// same server that Start serves and Stop stops.
func (gServer *GRPCServer) Server() *grpc.Server {
	return gServer.server
}
   196  
   197  // ServerCertificate returns the tls.Certificate used by the grpc.Server
   198  func (gServer *GRPCServer) ServerCertificate() tls.Certificate {
   199  	return gServer.serverCertificate.Load().(tls.Certificate)
   200  }
   201  
// TLSEnabled is a flag indicating whether or not TLS is enabled for the
// GRPCServer instance. It reports true exactly when a TLS configuration was
// created for the server, which NewGRPCServerFromListener only does when the
// server configuration asks for TLS.
func (gServer *GRPCServer) TLSEnabled() bool {
	return gServer.tls != nil
}
   207  
   208  // MutualTLSRequired is a flag indicating whether or not client certificates
   209  // are required for this GRPCServer instance
   210  func (gServer *GRPCServer) MutualTLSRequired() bool {
   211  	return gServer.TLSEnabled() &&
   212  		gServer.tls.Config().ClientAuth == tls.RequireAndVerifyClientCert
   213  }
   214  
   215  // Start starts the underlying grpc.Server
   216  func (gServer *GRPCServer) Start() error {
   217  	// if health check is enabled, set the health status for all registered services
   218  	if gServer.healthServer != nil {
   219  		for name := range gServer.server.GetServiceInfo() {
   220  			gServer.healthServer.SetServingStatus(
   221  				name,
   222  				healthpb.HealthCheckResponse_SERVING,
   223  			)
   224  		}
   225  
   226  		gServer.healthServer.SetServingStatus(
   227  			"",
   228  			healthpb.HealthCheckResponse_SERVING,
   229  		)
   230  	}
   231  	return gServer.server.Serve(gServer.listener)
   232  }
   233  
// Stop stops the underlying grpc.Server by delegating to its Stop method.
func (gServer *GRPCServer) Stop() {
	gServer.server.Stop()
}
   238  
   239  // internal function to add a PEM-encoded clientRootCA
   240  func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {
   241  	certs, err := pemToX509Certs(clientRoot)
   242  	if err != nil {
   243  		return errors.WithMessage(err, "failed to append client root certificate(s)")
   244  	}
   245  
   246  	if len(certs) < 1 {
   247  		return errors.New("no client root certificates found")
   248  	}
   249  
   250  	for _, cert := range certs {
   251  		gServer.tls.AddClientRootCA(cert)
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  // parse PEM-encoded certs
   258  func pemToX509Certs(pemCerts []byte) ([]*x509.Certificate, error) {
   259  	var certs []*x509.Certificate
   260  
   261  	// it's possible that multiple certs are encoded
   262  	for len(pemCerts) > 0 {
   263  		var block *pem.Block
   264  		block, pemCerts = pem.Decode(pemCerts)
   265  		if block == nil {
   266  			break
   267  		}
   268  
   269  		cert, err := x509.ParseCertificate(block.Bytes)
   270  		if err != nil {
   271  			return nil, err
   272  		}
   273  
   274  		certs = append(certs, cert)
   275  	}
   276  
   277  	return certs, nil
   278  }
   279  
   280  // SetClientRootCAs sets the list of authorities used to verify client
   281  // certificates based on a list of PEM-encoded X509 certificate authorities
   282  func (gServer *GRPCServer) SetClientRootCAs(clientRoots [][]byte) error {
   283  	gServer.lock.Lock()
   284  	defer gServer.lock.Unlock()
   285  
   286  	certPool := x509.NewCertPool()
   287  	for _, clientRoot := range clientRoots {
   288  		if !certPool.AppendCertsFromPEM(clientRoot) {
   289  			return errors.New("failed to set client root certificate(s)")
   290  		}
   291  	}
   292  	gServer.tls.SetClientCAs(certPool)
   293  	return nil
   294  }