github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/defaults.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/pprof"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/fatih/color"
    13  	"github.com/go-logr/zerologr"
    14  	grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
    15  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
    16  	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
    17  	grpclog "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
    18  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
    19  	"github.com/jzelinskie/cobrautil/v2"
    20  	"github.com/jzelinskie/cobrautil/v2/cobraotel"
    21  	"github.com/jzelinskie/cobrautil/v2/cobrazerolog"
    22  	"github.com/prometheus/client_golang/prometheus"
    23  	"github.com/prometheus/client_golang/prometheus/promhttp"
    24  	"github.com/rs/zerolog"
    25  	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
    26  	"go.opentelemetry.io/otel/trace"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/codes"
    29  
    30  	"github.com/authzed/authzed-go/pkg/requestmeta"
    31  
    32  	"github.com/authzed/spicedb/internal/dispatch"
    33  	"github.com/authzed/spicedb/internal/logging"
    34  	consistencymw "github.com/authzed/spicedb/internal/middleware/consistency"
    35  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    36  	dispatchmw "github.com/authzed/spicedb/internal/middleware/dispatcher"
    37  	"github.com/authzed/spicedb/internal/middleware/servicespecific"
    38  	"github.com/authzed/spicedb/pkg/datastore"
    39  	logmw "github.com/authzed/spicedb/pkg/middleware/logging"
    40  	"github.com/authzed/spicedb/pkg/middleware/requestid"
    41  	"github.com/authzed/spicedb/pkg/middleware/serverversion"
    42  	"github.com/authzed/spicedb/pkg/releases"
    43  	"github.com/authzed/spicedb/pkg/runtime"
    44  )
    45  
    46  var DisableTelemetryHandler *prometheus.Registry
    47  
    48  // ServeExample creates an example usage string with the provided program name.
    49  func ServeExample(programName string) string {
    50  	return fmt.Sprintf(`	%[1]s:
    51  		%[3]s serve --grpc-preshared-key "somerandomkeyhere"
    52  
    53  	%[2]s:
    54  		%[3]s serve --grpc-preshared-key "realkeyhere" --grpc-tls-cert-path path/to/tls/cert --grpc-tls-key-path path/to/tls/key \
    55  			--http-tls-cert-path path/to/tls/cert --http-tls-key-path path/to/tls/key \
    56  			--datastore-engine postgres --datastore-conn-uri "postgres-connection-string-here"
    57  `,
    58  		color.YellowString("No TLS and in-memory"),
    59  		color.GreenString("TLS and a real datastore"),
    60  		programName,
    61  	)
    62  }
    63  
    64  // DefaultPreRunE sets up viper, zerolog, and OpenTelemetry flag handling for a
    65  // command.
    66  func DefaultPreRunE(programName string) cobrautil.CobraRunFunc {
    67  	return cobrautil.CommandStack(
    68  		cobrautil.SyncViperDotEnvPreRunE(programName, "spicedb.env", zerologr.New(&logging.Logger)),
    69  		cobrazerolog.New(
    70  			cobrazerolog.WithTarget(func(logger zerolog.Logger) {
    71  				logging.SetGlobalLogger(logger)
    72  			}),
    73  		).RunE(),
    74  		cobraotel.New("spicedb",
    75  			cobraotel.WithLogger(zerologr.New(&logging.Logger)),
    76  		).RunE(),
    77  		releases.CheckAndLogRunE(),
    78  		runtime.RunE(),
    79  	)
    80  }
    81  
    82  // MetricsHandler sets up an HTTP server that handles serving Prometheus
    83  // metrics and pprof endpoints.
    84  func MetricsHandler(telemetryRegistry *prometheus.Registry, c *Config) http.Handler {
    85  	mux := http.NewServeMux()
    86  
    87  	mux.Handle("/metrics", promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{
    88  		// Opt into OpenMetrics e.g. to support exemplars.
    89  		EnableOpenMetrics: true,
    90  	}))
    91  	if telemetryRegistry != nil {
    92  		mux.Handle("/telemetry", promhttp.HandlerFor(telemetryRegistry, promhttp.HandlerOpts{}))
    93  	}
    94  
    95  	mux.HandleFunc("/debug/pprof/", pprof.Index)
    96  	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
    97  	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
    98  	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
    99  	mux.HandleFunc("/debug/pprof/cmdline", func(w http.ResponseWriter, r *http.Request) {
   100  		w.WriteHeader(http.StatusNotFound)
   101  		fmt.Fprintf(w, "This profile type has been disabled to avoid leaking private command-line arguments")
   102  	})
   103  	mux.HandleFunc("/debug/config", func(w http.ResponseWriter, r *http.Request) {
   104  		if c == nil {
   105  			w.WriteHeader(http.StatusNotFound)
   106  			return
   107  		}
   108  
   109  		json, err := json.MarshalIndent(c.DebugMap(), "", "  ")
   110  		if err != nil {
   111  			w.WriteHeader(http.StatusInternalServerError)
   112  			return
   113  		}
   114  
   115  		fmt.Fprintf(w, "%s", string(json))
   116  	})
   117  
   118  	return mux
   119  }
   120  
   121  var defaultCodeToLevel = grpclog.WithLevels(func(code codes.Code) grpclog.Level {
   122  	if code == codes.DeadlineExceeded {
   123  		// The server has a deadline set, so we consider it a normal condition.
   124  		// This ensures that we don't log them as errors.
   125  		return grpclog.LevelInfo
   126  	}
   127  	return grpclog.DefaultServerCodeToLevel(code)
   128  })
   129  
   130  var dispatchDefaultCodeToLevel = grpclog.WithLevels(func(code codes.Code) grpclog.Level {
   131  	switch code {
   132  	case codes.OK, codes.Canceled:
   133  		return grpclog.LevelDebug
   134  	case codes.NotFound, codes.AlreadyExists, codes.InvalidArgument, codes.Unauthenticated:
   135  		return grpclog.LevelWarn
   136  	default:
   137  		return grpclog.DefaultServerCodeToLevel(code)
   138  	}
   139  })
   140  
   141  var durationFieldOption = grpclog.WithDurationField(func(duration time.Duration) grpclog.Fields {
   142  	return grpclog.Fields{"grpc.time_ms", duration.Milliseconds()}
   143  })
   144  
   145  var traceIDFieldOption = grpclog.WithFieldsFromContext(func(ctx context.Context) grpclog.Fields {
   146  	if span := trace.SpanContextFromContext(ctx); span.IsSampled() {
   147  		return grpclog.Fields{"traceID", span.TraceID().String()}
   148  	}
   149  	return nil
   150  })
   151  
   152  var alwaysDebugOption = grpclog.WithLevels(func(code codes.Code) grpclog.Level {
   153  	return grpclog.LevelDebug
   154  })
   155  
   156  const (
   157  	DefaultMiddlewareRequestID     = "requestid"
   158  	DefaultMiddlewareLog           = "log"
   159  	DefaultMiddlewareGRPCLog       = "grpclog"
   160  	DefaultMiddlewareOTelGRPC      = "otelgrpc"
   161  	DefaultMiddlewareGRPCAuth      = "grpcauth"
   162  	DefaultMiddlewareGRPCProm      = "grpcprom"
   163  	DefaultMiddlewareServerVersion = "serverversion"
   164  
   165  	DefaultInternalMiddlewareDispatch       = "dispatch"
   166  	DefaultInternalMiddlewareDatastore      = "datastore"
   167  	DefaultInternalMiddlewareConsistency    = "consistency"
   168  	DefaultInternalMiddlewareServerSpecific = "servicespecific"
   169  )
   170  
// MiddlewareOption carries the dependencies and toggles consumed by
// DefaultUnaryMiddleware, DefaultStreamingMiddleware and determineEventsToLog.
// All fields are unexported, so values can only be populated from within
// package server.
type MiddlewareOption struct {
	// logger is wrapped with InterceptorLogger for the gRPC call-logging
	// middlewares.
	logger                zerolog.Logger
	// authFunc is passed to the grpcauth interceptors.
	authFunc              grpcauth.AuthFunc
	// enableVersionResponse is passed to the serverversion interceptors.
	enableVersionResponse bool
	// dispatcher is passed to the internal dispatch middleware.
	dispatcher            dispatch.Dispatcher
	// ds is passed to the internal datastore middleware.
	ds                    datastore.Datastore
	// enableRequestLog adds grpclog.PayloadReceived to the logged events
	// (see determineEventsToLog).
	enableRequestLog      bool
	// enableResponseLog adds grpclog.PayloadSent to the logged events
	// (see determineEventsToLog).
	enableResponseLog     bool
	// disableGRPCHistogram is forwarded to GRPCMetrics; when true, the gRPC
	// handling-time histogram is omitted. GRPCMetrics builds its interceptors
	// once per process (sync.Once), so only the value seen by the first call
	// takes effect.
	disableGRPCHistogram  bool
}
   181  
   182  // gRPCMetricsUnaryInterceptor creates the default prometheus metrics interceptor for unary gRPCs
   183  var gRPCMetricsUnaryInterceptor grpc.UnaryServerInterceptor
   184  
   185  // gRPCMetricsStreamingInterceptor creates the default prometheus metrics interceptor for streaming gRPCs
   186  var gRPCMetricsStreamingInterceptor grpc.StreamServerInterceptor
   187  
   188  var serverMetricsOnce sync.Once
   189  
   190  // GRPCMetrics returns the interceptors used for the default gRPC metrics from grpc-ecosystem/go-grpc-middleware
   191  func GRPCMetrics(disableLatencyHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
   192  	serverMetricsOnce.Do(func() {
   193  		gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor = createServerMetrics(disableLatencyHistogram)
   194  	})
   195  
   196  	return gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor
   197  }
   198  
   199  const healthCheckRoute = "/grpc.health.v1.Health/Check"
   200  
   201  func matchesRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool {
   202  	return func(_ context.Context, c interceptors.CallMeta) bool {
   203  		return c.FullMethod() == route
   204  	}
   205  }
   206  
   207  func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool {
   208  	return func(_ context.Context, c interceptors.CallMeta) bool {
   209  		return c.FullMethod() != route
   210  	}
   211  }
   212  
   213  // DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods
   214  func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) {
   215  	grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.disableGRPCHistogram)
   216  	chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   217  		NewUnaryMiddleware().
   218  			WithName(DefaultMiddlewareRequestID).
   219  			WithInterceptor(requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true))).
   220  			Done(),
   221  
   222  		NewUnaryMiddleware().
   223  			WithName(DefaultMiddlewareLog).
   224  			WithInterceptor(logmw.UnaryServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID"))).
   225  			Done(),
   226  
   227  		NewUnaryMiddleware().
   228  			WithName(DefaultMiddlewareOTelGRPC).
   229  			WithInterceptor(otelgrpc.UnaryServerInterceptor()). // nolint: staticcheck
   230  			Done(),
   231  
   232  		NewUnaryMiddleware().
   233  			WithName(DefaultMiddlewareGRPCLog + "-debug").
   234  			WithInterceptor(selector.UnaryServerInterceptor(
   235  				grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
   236  										selector.MatchFunc(matchesRoute(healthCheckRoute)))).
   237  			EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
   238  			Done(),
   239  
   240  		NewUnaryMiddleware().
   241  			WithName(DefaultMiddlewareGRPCLog).
   242  			WithInterceptor(selector.UnaryServerInterceptor(
   243  				grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
   244  										selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
   245  			EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
   246  			Done(),
   247  
   248  		NewUnaryMiddleware().
   249  			WithName(DefaultMiddlewareGRPCProm).
   250  			WithInterceptor(grpcMetricsUnaryInterceptor).
   251  			Done(),
   252  
   253  		NewUnaryMiddleware().
   254  			WithName(DefaultMiddlewareGRPCAuth).
   255  			WithInterceptor(grpcauth.UnaryServerInterceptor(opts.authFunc)).
   256  			EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
   257  			Done(),
   258  
   259  		NewUnaryMiddleware().
   260  			WithName(DefaultMiddlewareServerVersion).
   261  			WithInterceptor(serverversion.UnaryServerInterceptor(opts.enableVersionResponse)).
   262  			Done(),
   263  
   264  		NewUnaryMiddleware().
   265  			WithName(DefaultInternalMiddlewareDispatch).
   266  			WithInternal(true).
   267  			WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.dispatcher)).
   268  			Done(),
   269  
   270  		NewUnaryMiddleware().
   271  			WithName(DefaultInternalMiddlewareDatastore).
   272  			WithInternal(true).
   273  			WithInterceptor(datastoremw.UnaryServerInterceptor(opts.ds)).
   274  			Done(),
   275  
   276  		NewUnaryMiddleware().
   277  			WithName(DefaultInternalMiddlewareConsistency).
   278  			WithInternal(true).
   279  			WithInterceptor(consistencymw.UnaryServerInterceptor()).
   280  			Done(),
   281  
   282  		NewUnaryMiddleware().
   283  			WithName(DefaultInternalMiddlewareServerSpecific).
   284  			WithInternal(true).
   285  			WithInterceptor(servicespecific.UnaryServerInterceptor).
   286  			Done(),
   287  	}...)
   288  	return &chain, err
   289  }
   290  
   291  // DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods
   292  func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) {
   293  	_, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.disableGRPCHistogram)
   294  	chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   295  		NewStreamMiddleware().
   296  			WithName(DefaultMiddlewareRequestID).
   297  			WithInterceptor(requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true))).
   298  			Done(),
   299  
   300  		NewStreamMiddleware().
   301  			WithName(DefaultMiddlewareLog).
   302  			WithInterceptor(logmw.StreamServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID"))).
   303  			Done(),
   304  
   305  		NewStreamMiddleware().
   306  			WithName(DefaultMiddlewareOTelGRPC).
   307  			WithInterceptor(otelgrpc.StreamServerInterceptor()). // nolint: staticcheck
   308  			Done(),
   309  
   310  		NewStreamMiddleware().
   311  			WithName(DefaultMiddlewareGRPCLog + "-debug").
   312  			WithInterceptor(selector.StreamServerInterceptor(
   313  				grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption),
   314  											selector.MatchFunc(matchesRoute(healthCheckRoute)))).
   315  			EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
   316  			Done(),
   317  
   318  		NewStreamMiddleware().
   319  			WithName(DefaultMiddlewareGRPCLog).
   320  			WithInterceptor(selector.StreamServerInterceptor(
   321  				grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption),
   322  											selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))).
   323  			EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs),
   324  			Done(),
   325  
   326  		NewStreamMiddleware().
   327  			WithName(DefaultMiddlewareGRPCProm).
   328  			WithInterceptor(grpcMetricsStreamingInterceptor).
   329  			Done(),
   330  
   331  		NewStreamMiddleware().
   332  			WithName(DefaultMiddlewareGRPCAuth).
   333  			WithInterceptor(grpcauth.StreamServerInterceptor(opts.authFunc)).
   334  			EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
   335  			Done(),
   336  
   337  		NewStreamMiddleware().
   338  			WithName(DefaultMiddlewareServerVersion).
   339  			WithInterceptor(serverversion.StreamServerInterceptor(opts.enableVersionResponse)).
   340  			Done(),
   341  
   342  		NewStreamMiddleware().
   343  			WithName(DefaultInternalMiddlewareDispatch).
   344  			WithInternal(true).
   345  			WithInterceptor(dispatchmw.StreamServerInterceptor(opts.dispatcher)).
   346  			Done(),
   347  
   348  		NewStreamMiddleware().
   349  			WithName(DefaultInternalMiddlewareDatastore).
   350  			WithInternal(true).
   351  			WithInterceptor(datastoremw.StreamServerInterceptor(opts.ds)).
   352  			Done(),
   353  
   354  		NewStreamMiddleware().
   355  			WithName(DefaultInternalMiddlewareConsistency).
   356  			WithInternal(true).
   357  			WithInterceptor(consistencymw.StreamServerInterceptor()).
   358  			Done(),
   359  
   360  		NewStreamMiddleware().
   361  			WithName(DefaultInternalMiddlewareServerSpecific).
   362  			WithInternal(true).
   363  			WithInterceptor(servicespecific.StreamServerInterceptor).
   364  			Done(),
   365  	}...)
   366  	return &chain, err
   367  }
   368  
   369  func determineEventsToLog(opts MiddlewareOption) grpclog.Option {
   370  	eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall}
   371  	if opts.enableRequestLog {
   372  		eventsToLog = append(eventsToLog, grpclog.PayloadReceived)
   373  	}
   374  
   375  	if opts.enableResponseLog {
   376  		eventsToLog = append(eventsToLog, grpclog.PayloadSent)
   377  	}
   378  
   379  	return grpclog.WithLogOnEvents(eventsToLog...)
   380  }
   381  
   382  // DefaultDispatchMiddleware generates the default middleware chain used for the internal dispatch SpiceDB gRPC API
   383  func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore,
   384  	disableGRPCLatencyHistogram bool,
   385  ) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) {
   386  	grpcMetricsUnaryInterceptor, grpcMetricsStreamingInterceptor := GRPCMetrics(disableGRPCLatencyHistogram)
   387  	return []grpc.UnaryServerInterceptor{
   388  			requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true)),
   389  			logmw.UnaryServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID")),
   390  			grpclog.UnaryServerInterceptor(InterceptorLogger(logger), dispatchDefaultCodeToLevel, durationFieldOption, traceIDFieldOption),
   391  			grpcMetricsUnaryInterceptor,
   392  			grpcauth.UnaryServerInterceptor(authFunc),
   393  			datastoremw.UnaryServerInterceptor(ds),
   394  			servicespecific.UnaryServerInterceptor,
   395  		}, []grpc.StreamServerInterceptor{
   396  			requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true)),
   397  			logmw.StreamServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID")),
   398  			grpclog.StreamServerInterceptor(InterceptorLogger(logger), dispatchDefaultCodeToLevel, durationFieldOption, traceIDFieldOption),
   399  			grpcMetricsStreamingInterceptor,
   400  			grpcauth.StreamServerInterceptor(authFunc),
   401  			datastoremw.StreamServerInterceptor(ds),
   402  			servicespecific.StreamServerInterceptor,
   403  		}
   404  }
   405  
   406  func InterceptorLogger(l zerolog.Logger) grpclog.Logger {
   407  	return grpclog.LoggerFunc(func(ctx context.Context, lvl grpclog.Level, msg string, fields ...any) {
   408  		l := l.With().Fields(fields).Logger()
   409  
   410  		switch lvl {
   411  		case grpclog.LevelDebug:
   412  			l.Debug().Msg(msg)
   413  		case grpclog.LevelInfo:
   414  			l.Info().Msg(msg)
   415  		case grpclog.LevelWarn:
   416  			l.Warn().Msg(msg)
   417  		case grpclog.LevelError:
   418  			l.Error().Msg(msg)
   419  		default:
   420  			l.Error().Int("level", int(lvl)).Msg("unknown error level - falling back to info level")
   421  			l.Info().Msg(msg)
   422  		}
   423  	})
   424  }
   425  
   426  // initializes prometheus grpc interceptors with exemplar support enabled
   427  func createServerMetrics(disableHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
   428  	var opts []grpcprom.ServerMetricsOption
   429  	if !disableHistogram {
   430  		opts = append(opts, grpcprom.WithServerHandlingTimeHistogram(
   431  			grpcprom.WithHistogramBuckets([]float64{.001, .003, .006, .010, .018, .024, .032, .042, .056, .075, .100, .178, .316, .562, 1, 5}),
   432  		))
   433  	}
   434  	srvMetrics := grpcprom.NewServerMetrics(opts...)
   435  	// deliberately ignore if these metrics were already registered, so that
   436  	// custom builds of SpiceDB can register these metrics with custom labels
   437  	_ = prometheus.Register(srvMetrics)
   438  
   439  	exemplarFromContext := func(ctx context.Context) prometheus.Labels {
   440  		if span := trace.SpanContextFromContext(ctx); span.IsSampled() {
   441  			return prometheus.Labels{"traceID": span.TraceID().String()}
   442  		}
   443  		return nil
   444  	}
   445  
   446  	exemplarContext := grpcprom.WithExemplarFromContext(exemplarFromContext)
   447  	return srvMetrics.UnaryServerInterceptor(exemplarContext), srvMetrics.StreamServerInterceptor(exemplarContext)
   448  }