github.com/letsencrypt/boulder@v0.20251208.0/grpc/server.go (about)

     1  package grpc
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"slices"
    10  	"strings"
    11  	"time"
    12  
    13  	grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
    14  	"github.com/jmhodges/clock"
    15  	"github.com/prometheus/client_golang/prometheus"
    16  	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
    17  	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/filters"
    18  	"google.golang.org/grpc"
    19  	"google.golang.org/grpc/health"
    20  	healthpb "google.golang.org/grpc/health/grpc_health_v1"
    21  	"google.golang.org/grpc/keepalive"
    22  	"google.golang.org/grpc/status"
    23  
    24  	"github.com/letsencrypt/boulder/cmd"
    25  	bcreds "github.com/letsencrypt/boulder/grpc/creds"
    26  	blog "github.com/letsencrypt/boulder/log"
    27  )
    28  
// CodedError is an alias required to appease go vet
var CodedError = status.Errorf

// errNilTLS is the sentinel error returned by Build when it is given a nil
// *tls.Config; serving without TLS is never allowed.
var errNilTLS = errors.New("boulder/grpc: received nil tls.Config")
    33  
// checker is an interface for checking the health of a grpc service
// implementation. Build inspects each registered service implementation for
// this interface and, when present, drives a long-running periodic health
// check against it.
type checker interface {
	// Health returns nil if the service is healthy, or an error if it is not.
	// If the passed context is canceled, it should return immediately with an
	// error.
	Health(context.Context) error
}
    42  
// service represents a single gRPC service that can be registered with a gRPC
// server.
type service struct {
	desc *grpc.ServiceDesc // generated descriptor naming the service and its methods
	impl any               // concrete implementation registered under desc
}
    49  
// serverBuilder implements a builder pattern for constructing new gRPC servers
// and registering gRPC services on those servers.
type serverBuilder struct {
	cfg           *cmd.GRPCServerConfig // server config, including per-service client allowlists
	services      map[string]service    // services added via Add(), keyed by service name
	healthSrv     *health.Server        // created in Build(); nil until then
	checkInterval time.Duration         // health-check period; <=0 means use the 5s default
	logger        blog.Logger
	err           error // first error recorded by Add(); surfaced by Build()
}
    60  
    61  // NewServer returns an object which can be used to build gRPC servers. It takes
    62  // the server's configuration to perform initialization and a logger for deep
    63  // health checks.
    64  func NewServer(c *cmd.GRPCServerConfig, logger blog.Logger) *serverBuilder {
    65  	return &serverBuilder{cfg: c, services: make(map[string]service), logger: logger}
    66  }
    67  
    68  // WithCheckInterval sets the interval at which the server will check the health
    69  // of its registered services. If this is not called, a default interval of 5
    70  // seconds will be used.
    71  func (sb *serverBuilder) WithCheckInterval(i time.Duration) *serverBuilder {
    72  	sb.checkInterval = i
    73  	return sb
    74  }
    75  
    76  // Add registers a new service (consisting of its description and its
    77  // implementation) to the set of services which will be exposed by this server.
    78  // It returns the modified-in-place serverBuilder so that calls can be chained.
    79  // If there is an error adding this service, it will be exposed when .Build() is
    80  // called.
    81  func (sb *serverBuilder) Add(desc *grpc.ServiceDesc, impl any) *serverBuilder {
    82  	if _, found := sb.services[desc.ServiceName]; found {
    83  		// We've already registered a service with this same name, error out.
    84  		sb.err = fmt.Errorf("attempted double-registration of gRPC service %q", desc.ServiceName)
    85  		return sb
    86  	}
    87  	sb.services[desc.ServiceName] = service{desc: desc, impl: impl}
    88  	return sb
    89  }
    90  
// Build creates a gRPC server that uses the provided *tls.Config and exposes
// all of the services added to the builder. It also exposes a health check
// service. It returns one function, start(), which should be used to start
// the server. It spawns a goroutine which will listen for OS signals and
// gracefully stop the server if one is caught, causing the start() function to
// exit.
func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.Registerer, clk clock.Clock) (func() error, error) {
	// Register the health service with the server. This goes through .Add(),
	// so any registration error lands in sb.err like any other.
	sb.healthSrv = health.NewServer()
	sb.Add(&healthpb.Health_ServiceDesc, sb.healthSrv)

	// Check to see if any of the calls to .Add() resulted in an error.
	if sb.err != nil {
		return nil, sb.err
	}

	// Ensure that every configured service also got added via .Add(); a
	// mismatch most likely indicates a typo in the config. The registered
	// names are collected only to make the error message helpful.
	var registeredServices []string
	for r := range sb.services {
		registeredServices = append(registeredServices, r)
	}
	for serviceName := range sb.cfg.Services {
		_, ok := sb.services[serviceName]
		if !ok {
			return nil, fmt.Errorf("gRPC service %q in config does not match any service: %s", serviceName, strings.Join(registeredServices, ", "))
		}
	}

	if tlsConfig == nil {
		return nil, errNilTLS
	}

	// Collect all names which should be allowed to connect to the server at
	// all: the union of all names which are allowlisted for any individual
	// service. acceptedSANs is the set form consumed by the credentials;
	// acceptedSANsSlice is the same names deduplicated into a list.
	acceptedSANs := make(map[string]struct{})
	var acceptedSANsSlice []string
	for _, service := range sb.cfg.Services {
		for _, name := range service.ClientNames {
			acceptedSANs[name] = struct{}{}
			if !slices.Contains(acceptedSANsSlice, name) {
				acceptedSANsSlice = append(acceptedSANsSlice, name)
			}
		}
	}

	// Ensure that the health service has the same ClientNames as the other
	// services, so that health checks can be performed by clients which are
	// allowed to connect to the server.
	// NOTE(review): this indexes sb.cfg.Services by the health service's name
	// and writes through the result; if the config carries no entry under that
	// key this looks like it would dereference a nil entry — confirm callers
	// always configure the health service (or that upstream config handling
	// guarantees the entry exists).
	sb.cfg.Services[healthpb.Health_ServiceDesc.ServiceName].ClientNames = acceptedSANsSlice

	// Wrap the TLS config in credentials which also enforce the accepted-SAN
	// allowlist at connection time.
	creds, err := bcreds.NewServerCredentials(tlsConfig, acceptedSANs)
	if err != nil {
		return nil, err
	}

	// Set up all of our interceptors which handle metrics, traces, error
	// propagation, and more.
	metrics, err := newServerMetrics(statsRegistry)
	if err != nil {
		return nil, err
	}

	// ai enforces the per-service client allowlists; with no services
	// configured there is nothing to enforce, so a no-op stands in.
	var ai serverInterceptor
	if len(sb.cfg.Services) > 0 {
		ai = newServiceAuthChecker(sb.cfg)
	} else {
		ai = &noopServerInterceptor{}
	}

	mi := newServerMetadataInterceptor(metrics, clk)

	// Interceptors run in the order listed: metrics first, then auth, then
	// metadata handling.
	unaryInterceptors := []grpc.UnaryServerInterceptor{
		mi.metrics.grpcMetrics.UnaryServerInterceptor(),
		ai.Unary,
		mi.Unary,
	}

	streamInterceptors := []grpc.StreamServerInterceptor{
		mi.metrics.grpcMetrics.StreamServerInterceptor(),
		ai.Stream,
		mi.Stream,
	}

	// OpenTelemetry tracing is attached via a stats handler, filtered so that
	// health-check RPCs do not generate spans.
	options := []grpc.ServerOption{
		grpc.Creds(creds),
		grpc.ChainUnaryInterceptor(unaryInterceptors...),
		grpc.ChainStreamInterceptor(streamInterceptors...),
		grpc.StatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithFilter(filters.Not(filters.HealthCheck())))),
	}
	if sb.cfg.MaxConnectionAge.Duration > 0 {
		options = append(options,
			grpc.KeepaliveParams(keepalive.ServerParameters{
				MaxConnectionAge: sb.cfg.MaxConnectionAge.Duration,
			}))
	}

	// Create the server itself and register all of our services on it.
	server := grpc.NewServer(options...)
	for _, service := range sb.services {
		server.RegisterService(service.desc, service.impl)
	}

	if sb.cfg.Address == "" {
		return nil, errors.New("GRPC listen address not configured")
	}
	sb.logger.Infof("grpc listening on %s", sb.cfg.Address)

	// Finally return the functions which will start and stop the server.
	listener, err := net.Listen("tcp", sb.cfg.Address)
	if err != nil {
		return nil, err
	}

	start := func() error {
		return server.Serve(listener)
	}

	// Initialize long-running health checks of all services which implement the
	// checker interface. A non-positive interval falls back to the 5s default.
	if sb.checkInterval <= 0 {
		sb.checkInterval = 5 * time.Second
	}
	healthCtx, stopHealthChecks := context.WithCancel(context.Background())
	for _, s := range sb.services {
		check, ok := s.impl.(checker)
		if !ok {
			continue
		}
		sb.initLongRunningCheck(healthCtx, s.desc.ServiceName, check.Health)
	}

	// Start a goroutine which listens for a termination signal, and then
	// gracefully stops the gRPC server. This in turn causes the start() function
	// to exit, allowing its caller (generally a main() function) to exit.
	go cmd.CatchSignals(func() {
		stopHealthChecks()
		sb.healthSrv.Shutdown()
		server.GracefulStop()
	})

	return start, nil
}
   234  
   235  // initLongRunningCheck initializes a goroutine which will periodically check
   236  // the health of the provided service and update the health server accordingly.
   237  //
   238  // TODO(#8255): Remove the service parameter and instead rely on transitioning
   239  // the overall health of the server (e.g. "") instead of individual services.
   240  func (sb *serverBuilder) initLongRunningCheck(shutdownCtx context.Context, service string, checkImpl func(context.Context) error) {
   241  	// Set the initial health status for the service.
   242  	sb.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
   243  	sb.healthSrv.SetServingStatus(service, healthpb.HealthCheckResponse_NOT_SERVING)
   244  
   245  	// check is a helper function that checks the health of the service and, if
   246  	// necessary, updates its status in the health server.
   247  	checkAndMaybeUpdate := func(checkCtx context.Context, last healthpb.HealthCheckResponse_ServingStatus) healthpb.HealthCheckResponse_ServingStatus {
   248  		// Make a context with a timeout at 90% of the interval.
   249  		checkImplCtx, cancel := context.WithTimeout(checkCtx, sb.checkInterval*9/10)
   250  		defer cancel()
   251  
   252  		var next healthpb.HealthCheckResponse_ServingStatus
   253  		err := checkImpl(checkImplCtx)
   254  		if err != nil {
   255  			next = healthpb.HealthCheckResponse_NOT_SERVING
   256  		} else {
   257  			next = healthpb.HealthCheckResponse_SERVING
   258  		}
   259  
   260  		if last == next {
   261  			// No change in health status.
   262  			return next
   263  		}
   264  
   265  		if next != healthpb.HealthCheckResponse_SERVING {
   266  			sb.logger.Errf("transitioning overall health from %q to %q, due to: %s", last, next, err)
   267  			sb.logger.Errf("transitioning health of %q from %q to %q, due to: %s", service, last, next, err)
   268  		} else {
   269  			sb.logger.Infof("transitioning overall health from %q to %q", last, next)
   270  			sb.logger.Infof("transitioning health of %q from %q to %q", service, last, next)
   271  		}
   272  		sb.healthSrv.SetServingStatus("", next)
   273  		sb.healthSrv.SetServingStatus(service, next)
   274  		return next
   275  	}
   276  
   277  	go func() {
   278  		ticker := time.NewTicker(sb.checkInterval)
   279  		defer ticker.Stop()
   280  
   281  		// Assume the service is not healthy to start.
   282  		last := healthpb.HealthCheckResponse_NOT_SERVING
   283  
   284  		// Check immediately, and then at the specified interval.
   285  		last = checkAndMaybeUpdate(shutdownCtx, last)
   286  		for {
   287  			select {
   288  			case <-shutdownCtx.Done():
   289  				// The server is shutting down.
   290  				return
   291  			case <-ticker.C:
   292  				last = checkAndMaybeUpdate(shutdownCtx, last)
   293  			}
   294  		}
   295  	}()
   296  }
   297  
// serverMetrics is a struct type used to return a few registered metrics from
// `newServerMetrics`
type serverMetrics struct {
	grpcMetrics *grpc_prometheus.ServerMetrics // per-RPC handling metrics, with timing histogram
	rpcLag      prometheus.Histogram           // client-send to server-receive latency ("grpc_lag")
}
   304  
   305  // newServerMetrics registers metrics with a registry. It constructs and
   306  // registers a *grpc_prometheus.ServerMetrics with timing histogram enabled as
   307  // well as a prometheus Histogram for RPC latency. If called more than once on a
   308  // single registry, it will gracefully avoid registering duplicate metrics.
   309  func newServerMetrics(stats prometheus.Registerer) (serverMetrics, error) {
   310  	// Create the grpc prometheus server metrics instance and register it
   311  	grpcMetrics := grpc_prometheus.NewServerMetrics(
   312  		grpc_prometheus.WithServerHandlingTimeHistogram(
   313  			grpc_prometheus.WithHistogramBuckets([]float64{.01, .025, .05, .1, .5, 1, 2.5, 5, 10, 45, 90}),
   314  		),
   315  	)
   316  	err := stats.Register(grpcMetrics)
   317  	if err != nil {
   318  		are := prometheus.AlreadyRegisteredError{}
   319  		if errors.As(err, &are) {
   320  			grpcMetrics = are.ExistingCollector.(*grpc_prometheus.ServerMetrics)
   321  		} else {
   322  			return serverMetrics{}, err
   323  		}
   324  	}
   325  
   326  	// rpcLag is a prometheus histogram tracking the difference between the time
   327  	// the client sent an RPC and the time the server received it. Create and
   328  	// register it.
   329  	rpcLag := prometheus.NewHistogram(
   330  		prometheus.HistogramOpts{
   331  			Name: "grpc_lag",
   332  			Help: "Delta between client RPC send time and server RPC receipt time",
   333  		})
   334  	err = stats.Register(rpcLag)
   335  	if err != nil {
   336  		are := prometheus.AlreadyRegisteredError{}
   337  		if errors.As(err, &are) {
   338  			rpcLag = are.ExistingCollector.(prometheus.Histogram)
   339  		} else {
   340  			return serverMetrics{}, err
   341  		}
   342  	}
   343  
   344  	return serverMetrics{
   345  		grpcMetrics: grpcMetrics,
   346  		rpcLag:      rpcLag,
   347  	}, nil
   348  }