github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/server.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm 8 9 import ( 10 "crypto/tls" 11 "crypto/x509" 12 "encoding/pem" 13 "net" 14 "sync" 15 "sync/atomic" 16 "time" 17 18 grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" 19 "github.com/pkg/errors" 20 "google.golang.org/grpc" 21 "google.golang.org/grpc/health" 22 healthpb "google.golang.org/grpc/health/grpc_health_v1" 23 ) 24 25 type GRPCServer struct { 26 // Listen address for the server specified as hostname:port 27 address string 28 // Listener for handling network requests 29 listener net.Listener 30 // GRPC server 31 server *grpc.Server 32 // Certificate presented by the server for TLS communication 33 // stored as an atomic reference 34 serverCertificate atomic.Value 35 // lock to protect concurrent access to append / remove 36 lock *sync.Mutex 37 // TLS configuration used by the grpc server 38 tls *TLSConfig 39 // Server for gRPC Health Check Protocol. 40 healthServer *health.Server 41 } 42 43 // NewGRPCServer creates a new implementation of a GRPCServer given a 44 // listen address 45 func NewGRPCServer(address string, serverConfig ServerConfig) (*GRPCServer, error) { 46 if address == "" { 47 return nil, errors.New("missing address parameter") 48 } 49 // create our listener 50 lis, err := net.Listen("tcp", address) 51 if err != nil { 52 return nil, err 53 } 54 return NewGRPCServerFromListener(lis, serverConfig) 55 } 56 57 // NewGRPCServerFromListener creates a new implementation of a GRPCServer given 58 // an existing net.Listener instance using default keepalive 59 func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig) (*GRPCServer, error) { 60 grpcServer := &GRPCServer{ 61 address: listener.Addr().String(), 62 listener: listener, 63 lock: &sync.Mutex{}, 64 } 65 66 // set up our server options 67 var serverOpts []grpc.ServerOption 68 69 secureConfig := serverConfig.SecOpts 70 if secureConfig.UseTLS { 71 // both key and cert are required 72 if secureConfig.Key != nil && secureConfig.Certificate != nil { 73 // load server public and private keys 74 cert, err := tls.X509KeyPair(secureConfig.Certificate, secureConfig.Key) 75 if err != nil { 76 return nil, err 77 } 78 79 grpcServer.serverCertificate.Store(cert) 80 81 // set up our TLS config 82 if len(secureConfig.CipherSuites) == 0 { 83 secureConfig.CipherSuites = DefaultTLSCipherSuites 84 } 85 getCert := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { 86 cert := grpcServer.serverCertificate.Load().(tls.Certificate) 87 return &cert, nil 88 } 89 90 grpcServer.tls = NewTLSConfig(&tls.Config{ 91 VerifyPeerCertificate: secureConfig.VerifyCertificate, 92 GetCertificate: getCert, 93 SessionTicketsDisabled: true, 94 CipherSuites: secureConfig.CipherSuites, 95 }) 96 97 if serverConfig.SecOpts.TimeShift > 0 { 98 timeShift := serverConfig.SecOpts.TimeShift 99 grpcServer.tls.config.Time = func() time.Time { 100 return time.Now().Add((-1) * timeShift) 101 } 102 } 103 grpcServer.tls.config.ClientAuth = tls.RequestClientCert 104 // check if client authentication is required 105 if secureConfig.RequireClientCert { 106 // require TLS client auth 107 grpcServer.tls.config.ClientAuth = tls.RequireAndVerifyClientCert 108 // if we have client root CAs, create a certPool 109 if len(secureConfig.ClientRootCAs) > 0 { 110 grpcServer.tls.config.ClientCAs = x509.NewCertPool() 111 for _, clientRootCA := range secureConfig.ClientRootCAs { 112 err = grpcServer.appendClientRootCA(clientRootCA) 113 if err != nil { 114 return nil, err 115 } 116 } 117 } 118 } 119 120 // create credentials and add to server options 121 creds := NewServerTransportCredentials(grpcServer.tls, serverConfig.Logger) 122 serverOpts = append(serverOpts, grpc.Creds(creds)) 123 } else { 124 return nil, errors.New("serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true") 125 } 126 } 127 128 // set max send and recv msg sizes 129 maxSendMsgSize := DefaultMaxSendMsgSize 130 if serverConfig.MaxSendMsgSize != 0 { 131 maxSendMsgSize = serverConfig.MaxSendMsgSize 132 } 133 maxRecvMsgSize := DefaultMaxRecvMsgSize 134 if serverConfig.MaxRecvMsgSize != 0 { 135 maxRecvMsgSize = serverConfig.MaxRecvMsgSize 136 } 137 serverOpts = append(serverOpts, grpc.MaxSendMsgSize(maxSendMsgSize)) 138 serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(maxRecvMsgSize)) 139 // set the keepalive options 140 serverOpts = append(serverOpts, serverConfig.KaOpts.ServerKeepaliveOptions()...) 141 // set connection timeout 142 if serverConfig.ConnectionTimeout <= 0 { 143 serverConfig.ConnectionTimeout = DefaultConnectionTimeout 144 } 145 serverOpts = append( 146 serverOpts, 147 grpc.ConnectionTimeout(serverConfig.ConnectionTimeout)) 148 // set the interceptors 149 if len(serverConfig.StreamInterceptors) > 0 { 150 serverOpts = append( 151 serverOpts, 152 grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(serverConfig.StreamInterceptors...)), 153 ) 154 } 155 156 if len(serverConfig.UnaryInterceptors) > 0 { 157 serverOpts = append( 158 serverOpts, 159 grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(serverConfig.UnaryInterceptors...)), 160 ) 161 } 162 163 if serverConfig.ServerStatsHandler != nil { 164 serverOpts = append(serverOpts, grpc.StatsHandler(serverConfig.ServerStatsHandler)) 165 } 166 167 grpcServer.server = grpc.NewServer(serverOpts...) 168 169 if serverConfig.HealthCheckEnabled { 170 grpcServer.healthServer = health.NewServer() 171 healthpb.RegisterHealthServer(grpcServer.server, grpcServer.healthServer) 172 } 173 174 return grpcServer, nil 175 } 176 177 // SetServerCertificate assigns the current TLS certificate to be the peer's server certificate 178 func (gServer *GRPCServer) SetServerCertificate(cert tls.Certificate) { 179 gServer.serverCertificate.Store(cert) 180 } 181 182 // Address returns the listen address for this GRPCServer instance 183 func (gServer *GRPCServer) Address() string { 184 return gServer.address 185 } 186 187 // Listener returns the net.Listener for the GRPCServer instance 188 func (gServer *GRPCServer) Listener() net.Listener { 189 return gServer.listener 190 } 191 192 // Server returns the grpc.Server for the GRPCServer instance 193 func (gServer *GRPCServer) Server() *grpc.Server { 194 return gServer.server 195 } 196 197 // ServerCertificate returns the tls.Certificate used by the grpc.Server 198 func (gServer *GRPCServer) ServerCertificate() tls.Certificate { 199 return gServer.serverCertificate.Load().(tls.Certificate) 200 } 201 202 // TLSEnabled is a flag indicating whether or not TLS is enabled for the 203 // GRPCServer instance 204 func (gServer *GRPCServer) TLSEnabled() bool { 205 return gServer.tls != nil 206 } 207 208 // MutualTLSRequired is a flag indicating whether or not client certificates 209 // are required for this GRPCServer instance 210 func (gServer *GRPCServer) MutualTLSRequired() bool { 211 return gServer.TLSEnabled() && 212 gServer.tls.Config().ClientAuth == tls.RequireAndVerifyClientCert 213 } 214 215 // Start starts the underlying grpc.Server 216 func (gServer *GRPCServer) Start() error { 217 // if health check is enabled, set the health status for all registered services 218 if gServer.healthServer != nil { 219 for name := range gServer.server.GetServiceInfo() { 220 gServer.healthServer.SetServingStatus( 221 name, 222 healthpb.HealthCheckResponse_SERVING, 223 ) 224 } 225 226 gServer.healthServer.SetServingStatus( 227 "", 228 healthpb.HealthCheckResponse_SERVING, 229 ) 230 } 231 return gServer.server.Serve(gServer.listener) 232 } 233 234 // Stop stops the underlying grpc.Server 235 func (gServer *GRPCServer) Stop() { 236 gServer.server.Stop() 237 } 238 239 // internal function to add a PEM-encoded clientRootCA 240 func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error { 241 certs, err := pemToX509Certs(clientRoot) 242 if err != nil { 243 return errors.WithMessage(err, "failed to append client root certificate(s)") 244 } 245 246 if len(certs) < 1 { 247 return errors.New("no client root certificates found") 248 } 249 250 for _, cert := range certs { 251 gServer.tls.AddClientRootCA(cert) 252 } 253 254 return nil 255 } 256 257 // parse PEM-encoded certs 258 func pemToX509Certs(pemCerts []byte) ([]*x509.Certificate, error) { 259 var certs []*x509.Certificate 260 261 // it's possible that multiple certs are encoded 262 for len(pemCerts) > 0 { 263 var block *pem.Block 264 block, pemCerts = pem.Decode(pemCerts) 265 if block == nil { 266 break 267 } 268 269 cert, err := x509.ParseCertificate(block.Bytes) 270 if err != nil { 271 return nil, err 272 } 273 274 certs = append(certs, cert) 275 } 276 277 return certs, nil 278 } 279 280 // SetClientRootCAs sets the list of authorities used to verify client 281 // certificates based on a list of PEM-encoded X509 certificate authorities 282 func (gServer *GRPCServer) SetClientRootCAs(clientRoots [][]byte) error { 283 gServer.lock.Lock() 284 defer gServer.lock.Unlock() 285 286 certPool := x509.NewCertPool() 287 for _, clientRoot := range clientRoots { 288 if !certPool.AppendCertsFromPEM(clientRoot) { 289 return errors.New("failed to set client root certificate(s)") 290 } 291 } 292 gServer.tls.SetClientCAs(certPool) 293 return nil 294 }