vitess.io/vitess@v0.16.2/go/vt/servenv/grpc_server.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package servenv
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"math"
    24  	"net"
    25  	"time"
    26  
    27  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    28  	grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
    29  	"github.com/spf13/pflag"
    30  	"google.golang.org/grpc"
    31  	"google.golang.org/grpc/credentials"
    32  	"google.golang.org/grpc/health"
    33  	healthpb "google.golang.org/grpc/health/grpc_health_v1"
    34  	"google.golang.org/grpc/keepalive"
    35  	"google.golang.org/grpc/reflection"
    36  
    37  	"vitess.io/vitess/go/trace"
    38  	"vitess.io/vitess/go/vt/grpccommon"
    39  	"vitess.io/vitess/go/vt/grpcoptionaltls"
    40  	"vitess.io/vitess/go/vt/log"
    41  	"vitess.io/vitess/go/vt/vttls"
    42  )
    43  
// This file runs the gRPC server on its own port. Services register
// themselves on it, subject to the service map:
//
//	servenv.RegisterGRPCServerFlags()
//
//	servenv.OnRun(func() {
//	  if servenv.GRPCCheckServiceMap("XXX") {
//	    pb.RegisterXXX(servenv.GRPCServer, XXX)
//	  }
//	})
//
// Note that servenv.GRPCServer can only be used inside servenv.OnRun hooks,
// not before, because it is initialized right before the OnRun hooks run.
    57  var (
    58  	// gRPCAuth specifies which auth plugin to use. Currently only "static" and
    59  	// "mtls" are supported.
    60  	//
    61  	// To expose this flag, call RegisterGRPCAuthServerFlags before ParseFlags.
    62  	gRPCAuth string
    63  
    64  	// GRPCServer is the global server to serve gRPC.
    65  	GRPCServer *grpc.Server
    66  
    67  	authPlugin Authenticator
    68  )
    69  
    70  // Misc. server variables.
    71  var (
    72  	// gRPCPort is the port to listen on for gRPC. If zero, don't listen.
    73  	gRPCPort int
    74  
    75  	// gRPCMaxConnectionAge is the maximum age of a client connection, before GoAway is sent.
    76  	// This is useful for L4 loadbalancing to ensure rebalancing after scaling.
    77  	gRPCMaxConnectionAge = time.Duration(math.MaxInt64)
    78  
    79  	// gRPCMaxConnectionAgeGrace is an additional grace period after GRPCMaxConnectionAge, after which
    80  	// connections are forcibly closed.
    81  	gRPCMaxConnectionAgeGrace = time.Duration(math.MaxInt64)
    82  
    83  	// gRPCInitialConnWindowSize ServerOption that sets window size for a connection.
    84  	// The lower bound for window size is 64K and any value smaller than that will be ignored.
    85  	gRPCInitialConnWindowSize int
    86  
    87  	// gRPCInitialWindowSize ServerOption that sets window size for stream.
    88  	// The lower bound for window size is 64K and any value smaller than that will be ignored.
    89  	gRPCInitialWindowSize int
    90  
    91  	// gRPCKeepAliveEnforcementPolicyMinTime sets the keepalive enforcement policy on the server.
    92  	// This is the minimum amount of time a client should wait before sending a keepalive ping.
    93  	gRPCKeepAliveEnforcementPolicyMinTime = 10 * time.Second
    94  
    95  	// gRPCKeepAliveEnforcementPolicyPermitWithoutStream, if true, instructs the server to allow keepalive pings
    96  	// even when there are no active streams (RPCs). If false, and client sends ping when
    97  	// there are no active streams, server will send GOAWAY and close the connection.
    98  	gRPCKeepAliveEnforcementPolicyPermitWithoutStream bool
    99  )
   100  
   101  // TLS variables.
   102  var (
   103  	// gRPCCert is the cert to use if TLS is enabled.
   104  	gRPCCert string
   105  	// gRPCKey is the key to use if TLS is enabled.
   106  	gRPCKey string
   107  	// gRPCCA is the CA to use if TLS is enabled.
   108  	gRPCCA string
   109  	// gRPCCRL is the CRL (Certificate Revocation List) to use if TLS is
   110  	// enabled.
   111  	gRPCCRL string
   112  	// gRPCEnableOptionalTLS enables an optional TLS mode when a server accepts
   113  	// both TLS and plain-text connections on the same port.
   114  	gRPCEnableOptionalTLS bool
   115  	// gRPCServerCA if specified will combine server cert and server CA.
   116  	gRPCServerCA string
   117  )
   118  
   119  // RegisterGRPCServerFlags registers flags required to run a gRPC server via Run
   120  // or RunDefault.
   121  //
   122  // `go/cmd/*` entrypoints should call this function before
   123  // ParseFlags(WithArgs)? if they wish to run a gRPC server.
   124  func RegisterGRPCServerFlags() {
   125  	OnParse(func(fs *pflag.FlagSet) {
   126  		fs.IntVar(&gRPCPort, "grpc_port", gRPCPort, "Port to listen on for gRPC calls. If zero, do not listen.")
   127  		fs.DurationVar(&gRPCMaxConnectionAge, "grpc_max_connection_age", gRPCMaxConnectionAge, "Maximum age of a client connection before GoAway is sent.")
   128  		fs.DurationVar(&gRPCMaxConnectionAgeGrace, "grpc_max_connection_age_grace", gRPCMaxConnectionAgeGrace, "Additional grace period after grpc_max_connection_age, after which connections are forcibly closed.")
   129  		fs.IntVar(&gRPCInitialConnWindowSize, "grpc_server_initial_conn_window_size", gRPCInitialConnWindowSize, "gRPC server initial connection window size")
   130  		fs.IntVar(&gRPCInitialWindowSize, "grpc_server_initial_window_size", gRPCInitialWindowSize, "gRPC server initial window size")
   131  		fs.DurationVar(&gRPCKeepAliveEnforcementPolicyMinTime, "grpc_server_keepalive_enforcement_policy_min_time", gRPCKeepAliveEnforcementPolicyMinTime, "gRPC server minimum keepalive time")
   132  		fs.BoolVar(&gRPCKeepAliveEnforcementPolicyPermitWithoutStream, "grpc_server_keepalive_enforcement_policy_permit_without_stream", gRPCKeepAliveEnforcementPolicyPermitWithoutStream, "gRPC server permit client keepalive pings even when there are no active streams (RPCs)")
   133  
   134  		fs.StringVar(&gRPCCert, "grpc_cert", gRPCCert, "server certificate to use for gRPC connections, requires grpc_key, enables TLS")
   135  		fs.StringVar(&gRPCKey, "grpc_key", gRPCKey, "server private key to use for gRPC connections, requires grpc_cert, enables TLS")
   136  		fs.StringVar(&gRPCCA, "grpc_ca", gRPCCA, "server CA to use for gRPC connections, requires TLS, and enforces client certificate check")
   137  		fs.StringVar(&gRPCCRL, "grpc_crl", gRPCCRL, "path to a certificate revocation list in PEM format, client certificates will be further verified against this file during TLS handshake")
   138  		fs.BoolVar(&gRPCEnableOptionalTLS, "grpc_enable_optional_tls", gRPCEnableOptionalTLS, "enable optional TLS mode when a server accepts both TLS and plain-text connections on the same port")
   139  		fs.StringVar(&gRPCServerCA, "grpc_server_ca", gRPCServerCA, "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients")
   140  	})
   141  }
   142  
   143  // GRPCCert returns the value of the `--grpc_cert` flag.
   144  func GRPCCert() string {
   145  	return gRPCCert
   146  }
   147  
   148  // GRPCCertificateAuthority returns the value of the `--grpc_ca` flag.
   149  func GRPCCertificateAuthority() string {
   150  	return gRPCCA
   151  }
   152  
   153  // GRPCKey returns the value of the `--grpc_key` flag.
   154  func GRPCKey() string {
   155  	return gRPCKey
   156  }
   157  
   158  // GRPCPort returns the value of the `--grpc_port` flag.
   159  func GRPCPort() int {
   160  	return gRPCPort
   161  }
   162  
   163  // isGRPCEnabled returns true if gRPC server is set
   164  func isGRPCEnabled() bool {
   165  	if gRPCPort != 0 {
   166  		return true
   167  	}
   168  
   169  	if socketFile != "" {
   170  		return true
   171  	}
   172  
   173  	return false
   174  }
   175  
   176  // createGRPCServer create the gRPC server we will be using.
   177  // It has to be called after flags are parsed, but before
   178  // services register themselves.
   179  func createGRPCServer() {
   180  	// skip if not registered
   181  	if !isGRPCEnabled() {
   182  		log.Infof("Skipping gRPC server creation")
   183  		return
   184  	}
   185  
   186  	var opts []grpc.ServerOption
   187  	if gRPCCert != "" && gRPCKey != "" {
   188  		config, err := vttls.ServerConfig(gRPCCert, gRPCKey, gRPCCA, gRPCCRL, gRPCServerCA, tls.VersionTLS12)
   189  		if err != nil {
   190  			log.Exitf("Failed to log gRPC cert/key/ca: %v", err)
   191  		}
   192  
   193  		// create the creds server options
   194  		creds := credentials.NewTLS(config)
   195  		if gRPCEnableOptionalTLS {
   196  			log.Warning("Optional TLS is active. Plain-text connections will be accepted")
   197  			creds = grpcoptionaltls.New(creds)
   198  		}
   199  		opts = []grpc.ServerOption{grpc.Creds(creds)}
   200  	}
   201  	// Override the default max message size for both send and receive
   202  	// (which is 4 MiB in gRPC 1.0.0).
   203  	// Large messages can occur when users try to insert or fetch very big
   204  	// rows. If they hit the limit, they'll see the following error:
   205  	// grpc: received message length XXXXXXX exceeding the max size 4194304
   206  	// Note: For gRPC 1.0.0 it's sufficient to set the limit on the server only
   207  	// because it's not enforced on the client side.
   208  	msgSize := grpccommon.MaxMessageSize()
   209  	log.Infof("Setting grpc max message size to %d", msgSize)
   210  	opts = append(opts, grpc.MaxRecvMsgSize(msgSize))
   211  	opts = append(opts, grpc.MaxSendMsgSize(msgSize))
   212  
   213  	if gRPCInitialConnWindowSize != 0 {
   214  		log.Infof("Setting grpc server initial conn window size to %d", int32(gRPCInitialConnWindowSize))
   215  		opts = append(opts, grpc.InitialConnWindowSize(int32(gRPCInitialConnWindowSize)))
   216  	}
   217  
   218  	if gRPCInitialWindowSize != 0 {
   219  		log.Infof("Setting grpc server initial window size to %d", int32(gRPCInitialWindowSize))
   220  		opts = append(opts, grpc.InitialWindowSize(int32(gRPCInitialWindowSize)))
   221  	}
   222  
   223  	ep := keepalive.EnforcementPolicy{
   224  		MinTime:             gRPCKeepAliveEnforcementPolicyMinTime,
   225  		PermitWithoutStream: gRPCKeepAliveEnforcementPolicyPermitWithoutStream,
   226  	}
   227  	opts = append(opts, grpc.KeepaliveEnforcementPolicy(ep))
   228  
   229  	ka := keepalive.ServerParameters{
   230  		MaxConnectionAge:      gRPCMaxConnectionAge,
   231  		MaxConnectionAgeGrace: gRPCMaxConnectionAgeGrace,
   232  	}
   233  	opts = append(opts, grpc.KeepaliveParams(ka))
   234  
   235  	opts = append(opts, interceptors()...)
   236  
   237  	GRPCServer = grpc.NewServer(opts...)
   238  }
   239  
   240  // We can only set a ServerInterceptor once, so we chain multiple interceptors into one
   241  func interceptors() []grpc.ServerOption {
   242  	interceptors := &serverInterceptorBuilder{}
   243  
   244  	if gRPCAuth != "" {
   245  		log.Infof("enabling auth plugin %v", gRPCAuth)
   246  		pluginInitializer := GetAuthenticator(gRPCAuth)
   247  		authPluginImpl, err := pluginInitializer()
   248  		if err != nil {
   249  			log.Fatalf("Failed to load auth plugin: %v", err)
   250  		}
   251  		authPlugin = authPluginImpl
   252  		interceptors.Add(authenticatingStreamInterceptor, authenticatingUnaryInterceptor)
   253  	}
   254  
   255  	if grpccommon.EnableGRPCPrometheus() {
   256  		interceptors.Add(grpc_prometheus.StreamServerInterceptor, grpc_prometheus.UnaryServerInterceptor)
   257  	}
   258  
   259  	trace.AddGrpcServerOptions(interceptors.Add)
   260  
   261  	return interceptors.Build()
   262  }
   263  
   264  func serveGRPC() {
   265  	if grpccommon.EnableGRPCPrometheus() {
   266  		grpc_prometheus.Register(GRPCServer)
   267  		grpc_prometheus.EnableHandlingTimeHistogram()
   268  	}
   269  	// skip if not registered
   270  	if gRPCPort == 0 {
   271  		return
   272  	}
   273  
   274  	// register reflection to support list calls :)
   275  	reflection.Register(GRPCServer)
   276  
   277  	// register health service to support health checks
   278  	healthServer := health.NewServer()
   279  	healthpb.RegisterHealthServer(GRPCServer, healthServer)
   280  
   281  	for service := range GRPCServer.GetServiceInfo() {
   282  		healthServer.SetServingStatus(service, healthpb.HealthCheckResponse_SERVING)
   283  	}
   284  
   285  	// listen on the port
   286  	log.Infof("Listening for gRPC calls on port %v", gRPCPort)
   287  	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", gRPCPort))
   288  	if err != nil {
   289  		log.Exitf("Cannot listen on port %v for gRPC: %v", gRPCPort, err)
   290  	}
   291  
   292  	// and serve on it
   293  	// NOTE: Before we call Serve(), all services must have registered themselves
   294  	//       with "GRPCServer". This is the case because go/vt/servenv/run.go
   295  	//       runs all OnRun() hooks after createGRPCServer() and before
   296  	//       serveGRPC(). If this was not the case, the binary would crash with
   297  	//       the error "grpc: Server.RegisterService after Server.Serve".
   298  	go func() {
   299  		err := GRPCServer.Serve(listener)
   300  		if err != nil {
   301  			log.Exitf("Failed to start grpc server: %v", err)
   302  		}
   303  	}()
   304  
   305  	OnTermSync(func() {
   306  		log.Info("Initiated graceful stop of gRPC server")
   307  		GRPCServer.GracefulStop()
   308  		log.Info("gRPC server stopped")
   309  	})
   310  }
   311  
   312  // GRPCCheckServiceMap returns if we should register a gRPC service
   313  // (and also logs how to enable / disable it)
   314  func GRPCCheckServiceMap(name string) bool {
   315  	// Silently fail individual services if gRPC is not enabled in
   316  	// the first place (either on a grpc port or on the socket file)
   317  	if !isGRPCEnabled() {
   318  		return false
   319  	}
   320  
   321  	// then check ServiceMap
   322  	return checkServiceMap("grpc", name)
   323  }
   324  
   325  func authenticatingStreamInterceptor(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   326  	newCtx, err := authPlugin.Authenticate(stream.Context(), info.FullMethod)
   327  
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	wrapped := WrapServerStream(stream)
   333  	wrapped.WrappedContext = newCtx
   334  	return handler(srv, wrapped)
   335  }
   336  
   337  func authenticatingUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   338  	newCtx, err := authPlugin.Authenticate(ctx, info.FullMethod)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  
   343  	return handler(newCtx, req)
   344  }
   345  
   346  // WrappedServerStream is based on the service stream wrapper from: https://github.com/grpc-ecosystem/go-grpc-middleware
   347  type WrappedServerStream struct {
   348  	grpc.ServerStream
   349  	WrappedContext context.Context
   350  }
   351  
   352  // Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
   353  func (w *WrappedServerStream) Context() context.Context {
   354  	return w.WrappedContext
   355  }
   356  
   357  // WrapServerStream returns a ServerStream that has the ability to overwrite context.
   358  func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream {
   359  	if existing, ok := stream.(*WrappedServerStream); ok {
   360  		return existing
   361  	}
   362  	return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()}
   363  }
   364  
   365  // serverInterceptorBuilder chains together multiple ServerInterceptors
   366  type serverInterceptorBuilder struct {
   367  	streamInterceptors []grpc.StreamServerInterceptor
   368  	unaryInterceptors  []grpc.UnaryServerInterceptor
   369  }
   370  
   371  // Add adds interceptors to the builder
   372  func (collector *serverInterceptorBuilder) Add(s grpc.StreamServerInterceptor, u grpc.UnaryServerInterceptor) {
   373  	collector.streamInterceptors = append(collector.streamInterceptors, s)
   374  	collector.unaryInterceptors = append(collector.unaryInterceptors, u)
   375  }
   376  
   377  // AddUnary adds a single unary interceptor to the builder
   378  func (collector *serverInterceptorBuilder) AddUnary(u grpc.UnaryServerInterceptor) {
   379  	collector.unaryInterceptors = append(collector.unaryInterceptors, u)
   380  }
   381  
   382  // Build returns DialOptions to add to the grpc.Dial call
   383  func (collector *serverInterceptorBuilder) Build() []grpc.ServerOption {
   384  	log.Infof("Building interceptors with %d unary interceptors and %d stream interceptors", len(collector.unaryInterceptors), len(collector.streamInterceptors))
   385  	switch len(collector.unaryInterceptors) + len(collector.streamInterceptors) {
   386  	case 0:
   387  		return []grpc.ServerOption{}
   388  	default:
   389  		return []grpc.ServerOption{
   390  			grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(collector.unaryInterceptors...)),
   391  			grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(collector.streamInterceptors...)),
   392  		}
   393  	}
   394  }