github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/defaults.go (about) 1 package server 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "net/http/pprof" 9 "sync" 10 "time" 11 12 "github.com/fatih/color" 13 "github.com/go-logr/zerologr" 14 grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" 15 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" 16 grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" 17 grpclog "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" 18 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" 19 "github.com/jzelinskie/cobrautil/v2" 20 "github.com/jzelinskie/cobrautil/v2/cobraotel" 21 "github.com/jzelinskie/cobrautil/v2/cobrazerolog" 22 "github.com/prometheus/client_golang/prometheus" 23 "github.com/prometheus/client_golang/prometheus/promhttp" 24 "github.com/rs/zerolog" 25 "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" 26 "go.opentelemetry.io/otel/trace" 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/codes" 29 30 "github.com/authzed/authzed-go/pkg/requestmeta" 31 32 "github.com/authzed/spicedb/internal/dispatch" 33 "github.com/authzed/spicedb/internal/logging" 34 consistencymw "github.com/authzed/spicedb/internal/middleware/consistency" 35 datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" 36 dispatchmw "github.com/authzed/spicedb/internal/middleware/dispatcher" 37 "github.com/authzed/spicedb/internal/middleware/servicespecific" 38 "github.com/authzed/spicedb/pkg/datastore" 39 logmw "github.com/authzed/spicedb/pkg/middleware/logging" 40 "github.com/authzed/spicedb/pkg/middleware/requestid" 41 "github.com/authzed/spicedb/pkg/middleware/serverversion" 42 "github.com/authzed/spicedb/pkg/releases" 43 "github.com/authzed/spicedb/pkg/runtime" 44 ) 45 46 var DisableTelemetryHandler *prometheus.Registry 47 48 // ServeExample creates an example usage string with the provided program name. 49 func ServeExample(programName string) string { 50 return fmt.Sprintf(` %[1]s: 51 %[3]s serve --grpc-preshared-key "somerandomkeyhere" 52 53 %[2]s: 54 %[3]s serve --grpc-preshared-key "realkeyhere" --grpc-tls-cert-path path/to/tls/cert --grpc-tls-key-path path/to/tls/key \ 55 --http-tls-cert-path path/to/tls/cert --http-tls-key-path path/to/tls/key \ 56 --datastore-engine postgres --datastore-conn-uri "postgres-connection-string-here" 57 `, 58 color.YellowString("No TLS and in-memory"), 59 color.GreenString("TLS and a real datastore"), 60 programName, 61 ) 62 } 63 64 // DefaultPreRunE sets up viper, zerolog, and OpenTelemetry flag handling for a 65 // command. 66 func DefaultPreRunE(programName string) cobrautil.CobraRunFunc { 67 return cobrautil.CommandStack( 68 cobrautil.SyncViperDotEnvPreRunE(programName, "spicedb.env", zerologr.New(&logging.Logger)), 69 cobrazerolog.New( 70 cobrazerolog.WithTarget(func(logger zerolog.Logger) { 71 logging.SetGlobalLogger(logger) 72 }), 73 ).RunE(), 74 cobraotel.New("spicedb", 75 cobraotel.WithLogger(zerologr.New(&logging.Logger)), 76 ).RunE(), 77 releases.CheckAndLogRunE(), 78 runtime.RunE(), 79 ) 80 } 81 82 // MetricsHandler sets up an HTTP server that handles serving Prometheus 83 // metrics and pprof endpoints. 84 func MetricsHandler(telemetryRegistry *prometheus.Registry, c *Config) http.Handler { 85 mux := http.NewServeMux() 86 87 mux.Handle("/metrics", promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ 88 // Opt into OpenMetrics e.g. to support exemplars. 89 EnableOpenMetrics: true, 90 })) 91 if telemetryRegistry != nil { 92 mux.Handle("/telemetry", promhttp.HandlerFor(telemetryRegistry, promhttp.HandlerOpts{})) 93 } 94 95 mux.HandleFunc("/debug/pprof/", pprof.Index) 96 mux.HandleFunc("/debug/pprof/profile", pprof.Profile) 97 mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 98 mux.HandleFunc("/debug/pprof/trace", pprof.Trace) 99 mux.HandleFunc("/debug/pprof/cmdline", func(w http.ResponseWriter, r *http.Request) { 100 w.WriteHeader(http.StatusNotFound) 101 fmt.Fprintf(w, "This profile type has been disabled to avoid leaking private command-line arguments") 102 }) 103 mux.HandleFunc("/debug/config", func(w http.ResponseWriter, r *http.Request) { 104 if c == nil { 105 w.WriteHeader(http.StatusNotFound) 106 return 107 } 108 109 json, err := json.MarshalIndent(c.DebugMap(), "", " ") 110 if err != nil { 111 w.WriteHeader(http.StatusInternalServerError) 112 return 113 } 114 115 fmt.Fprintf(w, "%s", string(json)) 116 }) 117 118 return mux 119 } 120 121 var defaultCodeToLevel = grpclog.WithLevels(func(code codes.Code) grpclog.Level { 122 if code == codes.DeadlineExceeded { 123 // The server has a deadline set, so we consider it a normal condition. 124 // This ensures that we don't log them as errors. 125 return grpclog.LevelInfo 126 } 127 return grpclog.DefaultServerCodeToLevel(code) 128 }) 129 130 var dispatchDefaultCodeToLevel = grpclog.WithLevels(func(code codes.Code) grpclog.Level { 131 switch code { 132 case codes.OK, codes.Canceled: 133 return grpclog.LevelDebug 134 case codes.NotFound, codes.AlreadyExists, codes.InvalidArgument, codes.Unauthenticated: 135 return grpclog.LevelWarn 136 default: 137 return grpclog.DefaultServerCodeToLevel(code) 138 } 139 }) 140 141 var durationFieldOption = grpclog.WithDurationField(func(duration time.Duration) grpclog.Fields { 142 return grpclog.Fields{"grpc.time_ms", duration.Milliseconds()} 143 }) 144 145 var traceIDFieldOption = grpclog.WithFieldsFromContext(func(ctx context.Context) grpclog.Fields { 146 if span := trace.SpanContextFromContext(ctx); span.IsSampled() { 147 return grpclog.Fields{"traceID", span.TraceID().String()} 148 } 149 return nil 150 }) 151 152 var alwaysDebugOption = grpclog.WithLevels(func(code codes.Code) grpclog.Level { 153 return grpclog.LevelDebug 154 }) 155 156 const ( 157 DefaultMiddlewareRequestID = "requestid" 158 DefaultMiddlewareLog = "log" 159 DefaultMiddlewareGRPCLog = "grpclog" 160 DefaultMiddlewareOTelGRPC = "otelgrpc" 161 DefaultMiddlewareGRPCAuth = "grpcauth" 162 DefaultMiddlewareGRPCProm = "grpcprom" 163 DefaultMiddlewareServerVersion = "serverversion" 164 165 DefaultInternalMiddlewareDispatch = "dispatch" 166 DefaultInternalMiddlewareDatastore = "datastore" 167 DefaultInternalMiddlewareConsistency = "consistency" 168 DefaultInternalMiddlewareServerSpecific = "servicespecific" 169 ) 170 171 type MiddlewareOption struct { 172 logger zerolog.Logger 173 authFunc grpcauth.AuthFunc 174 enableVersionResponse bool 175 dispatcher dispatch.Dispatcher 176 ds datastore.Datastore 177 enableRequestLog bool 178 enableResponseLog bool 179 disableGRPCHistogram bool 180 } 181 182 // gRPCMetricsUnaryInterceptor creates the default prometheus metrics interceptor for unary gRPCs 183 var gRPCMetricsUnaryInterceptor grpc.UnaryServerInterceptor 184 185 // gRPCMetricsStreamingInterceptor creates the default prometheus metrics interceptor for streaming gRPCs 186 var gRPCMetricsStreamingInterceptor grpc.StreamServerInterceptor 187 188 var serverMetricsOnce sync.Once 189 190 // GRPCMetrics returns the interceptors used for the default gRPC metrics from grpc-ecosystem/go-grpc-middleware 191 func GRPCMetrics(disableLatencyHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { 192 serverMetricsOnce.Do(func() { 193 gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor = createServerMetrics(disableLatencyHistogram) 194 }) 195 196 return gRPCMetricsUnaryInterceptor, gRPCMetricsStreamingInterceptor 197 } 198 199 const healthCheckRoute = "/grpc.health.v1.Health/Check" 200 201 func matchesRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool { 202 return func(_ context.Context, c interceptors.CallMeta) bool { 203 return c.FullMethod() == route 204 } 205 } 206 207 func doesNotMatchRoute(route string) func(_ context.Context, c interceptors.CallMeta) bool { 208 return func(_ context.Context, c interceptors.CallMeta) bool { 209 return c.FullMethod() != route 210 } 211 } 212 213 // DefaultUnaryMiddleware generates the default middleware chain used for the public SpiceDB Unary gRPC methods 214 func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryServerInterceptor], error) { 215 grpcMetricsUnaryInterceptor, _ := GRPCMetrics(opts.disableGRPCHistogram) 216 chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ 217 NewUnaryMiddleware(). 218 WithName(DefaultMiddlewareRequestID). 219 WithInterceptor(requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true))). 220 Done(), 221 222 NewUnaryMiddleware(). 223 WithName(DefaultMiddlewareLog). 224 WithInterceptor(logmw.UnaryServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID"))). 225 Done(), 226 227 NewUnaryMiddleware(). 228 WithName(DefaultMiddlewareOTelGRPC). 229 WithInterceptor(otelgrpc.UnaryServerInterceptor()). // nolint: staticcheck 230 Done(), 231 232 NewUnaryMiddleware(). 233 WithName(DefaultMiddlewareGRPCLog + "-debug"). 234 WithInterceptor(selector.UnaryServerInterceptor( 235 grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), 236 selector.MatchFunc(matchesRoute(healthCheckRoute)))). 237 EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), 238 Done(), 239 240 NewUnaryMiddleware(). 241 WithName(DefaultMiddlewareGRPCLog). 242 WithInterceptor(selector.UnaryServerInterceptor( 243 grpclog.UnaryServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), 244 selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). 245 EnsureAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), 246 Done(), 247 248 NewUnaryMiddleware(). 249 WithName(DefaultMiddlewareGRPCProm). 250 WithInterceptor(grpcMetricsUnaryInterceptor). 251 Done(), 252 253 NewUnaryMiddleware(). 254 WithName(DefaultMiddlewareGRPCAuth). 255 WithInterceptor(grpcauth.UnaryServerInterceptor(opts.authFunc)). 256 EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures 257 Done(), 258 259 NewUnaryMiddleware(). 260 WithName(DefaultMiddlewareServerVersion). 261 WithInterceptor(serverversion.UnaryServerInterceptor(opts.enableVersionResponse)). 262 Done(), 263 264 NewUnaryMiddleware(). 265 WithName(DefaultInternalMiddlewareDispatch). 266 WithInternal(true). 267 WithInterceptor(dispatchmw.UnaryServerInterceptor(opts.dispatcher)). 268 Done(), 269 270 NewUnaryMiddleware(). 271 WithName(DefaultInternalMiddlewareDatastore). 272 WithInternal(true). 273 WithInterceptor(datastoremw.UnaryServerInterceptor(opts.ds)). 274 Done(), 275 276 NewUnaryMiddleware(). 277 WithName(DefaultInternalMiddlewareConsistency). 278 WithInternal(true). 279 WithInterceptor(consistencymw.UnaryServerInterceptor()). 280 Done(), 281 282 NewUnaryMiddleware(). 283 WithName(DefaultInternalMiddlewareServerSpecific). 284 WithInternal(true). 285 WithInterceptor(servicespecific.UnaryServerInterceptor). 286 Done(), 287 }...) 288 return &chain, err 289 } 290 291 // DefaultStreamingMiddleware generates the default middleware chain used for the public SpiceDB Streaming gRPC methods 292 func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.StreamServerInterceptor], error) { 293 _, grpcMetricsStreamingInterceptor := GRPCMetrics(opts.disableGRPCHistogram) 294 chain, err := NewMiddlewareChain([]ReferenceableMiddleware[grpc.StreamServerInterceptor]{ 295 NewStreamMiddleware(). 296 WithName(DefaultMiddlewareRequestID). 297 WithInterceptor(requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true))). 298 Done(), 299 300 NewStreamMiddleware(). 301 WithName(DefaultMiddlewareLog). 302 WithInterceptor(logmw.StreamServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID"))). 303 Done(), 304 305 NewStreamMiddleware(). 306 WithName(DefaultMiddlewareOTelGRPC). 307 WithInterceptor(otelgrpc.StreamServerInterceptor()). // nolint: staticcheck 308 Done(), 309 310 NewStreamMiddleware(). 311 WithName(DefaultMiddlewareGRPCLog + "-debug"). 312 WithInterceptor(selector.StreamServerInterceptor( 313 grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), alwaysDebugOption, durationFieldOption, traceIDFieldOption), 314 selector.MatchFunc(matchesRoute(healthCheckRoute)))). 315 EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), 316 Done(), 317 318 NewStreamMiddleware(). 319 WithName(DefaultMiddlewareGRPCLog). 320 WithInterceptor(selector.StreamServerInterceptor( 321 grpclog.StreamServerInterceptor(InterceptorLogger(opts.logger), determineEventsToLog(opts), defaultCodeToLevel, durationFieldOption, traceIDFieldOption), 322 selector.MatchFunc(doesNotMatchRoute(healthCheckRoute)))). 323 EnsureInterceptorAlreadyExecuted(DefaultMiddlewareOTelGRPC). // dependency so that OTel traceID is injected in logs), 324 Done(), 325 326 NewStreamMiddleware(). 327 WithName(DefaultMiddlewareGRPCProm). 328 WithInterceptor(grpcMetricsStreamingInterceptor). 329 Done(), 330 331 NewStreamMiddleware(). 332 WithName(DefaultMiddlewareGRPCAuth). 333 WithInterceptor(grpcauth.StreamServerInterceptor(opts.authFunc)). 334 EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures 335 Done(), 336 337 NewStreamMiddleware(). 338 WithName(DefaultMiddlewareServerVersion). 339 WithInterceptor(serverversion.StreamServerInterceptor(opts.enableVersionResponse)). 340 Done(), 341 342 NewStreamMiddleware(). 343 WithName(DefaultInternalMiddlewareDispatch). 344 WithInternal(true). 345 WithInterceptor(dispatchmw.StreamServerInterceptor(opts.dispatcher)). 346 Done(), 347 348 NewStreamMiddleware(). 349 WithName(DefaultInternalMiddlewareDatastore). 350 WithInternal(true). 351 WithInterceptor(datastoremw.StreamServerInterceptor(opts.ds)). 352 Done(), 353 354 NewStreamMiddleware(). 355 WithName(DefaultInternalMiddlewareConsistency). 356 WithInternal(true). 357 WithInterceptor(consistencymw.StreamServerInterceptor()). 358 Done(), 359 360 NewStreamMiddleware(). 361 WithName(DefaultInternalMiddlewareServerSpecific). 362 WithInternal(true). 363 WithInterceptor(servicespecific.StreamServerInterceptor). 364 Done(), 365 }...) 366 return &chain, err 367 } 368 369 func determineEventsToLog(opts MiddlewareOption) grpclog.Option { 370 eventsToLog := []grpclog.LoggableEvent{grpclog.FinishCall} 371 if opts.enableRequestLog { 372 eventsToLog = append(eventsToLog, grpclog.PayloadReceived) 373 } 374 375 if opts.enableResponseLog { 376 eventsToLog = append(eventsToLog, grpclog.PayloadSent) 377 } 378 379 return grpclog.WithLogOnEvents(eventsToLog...) 380 } 381 382 // DefaultDispatchMiddleware generates the default middleware chain used for the internal dispatch SpiceDB gRPC API 383 func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore, 384 disableGRPCLatencyHistogram bool, 385 ) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) { 386 grpcMetricsUnaryInterceptor, grpcMetricsStreamingInterceptor := GRPCMetrics(disableGRPCLatencyHistogram) 387 return []grpc.UnaryServerInterceptor{ 388 requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true)), 389 logmw.UnaryServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID")), 390 grpclog.UnaryServerInterceptor(InterceptorLogger(logger), dispatchDefaultCodeToLevel, durationFieldOption, traceIDFieldOption), 391 grpcMetricsUnaryInterceptor, 392 grpcauth.UnaryServerInterceptor(authFunc), 393 datastoremw.UnaryServerInterceptor(ds), 394 servicespecific.UnaryServerInterceptor, 395 }, []grpc.StreamServerInterceptor{ 396 requestid.StreamServerInterceptor(requestid.GenerateIfMissing(true)), 397 logmw.StreamServerInterceptor(logmw.ExtractMetadataField(string(requestmeta.RequestIDKey), "requestID")), 398 grpclog.StreamServerInterceptor(InterceptorLogger(logger), dispatchDefaultCodeToLevel, durationFieldOption, traceIDFieldOption), 399 grpcMetricsStreamingInterceptor, 400 grpcauth.StreamServerInterceptor(authFunc), 401 datastoremw.StreamServerInterceptor(ds), 402 servicespecific.StreamServerInterceptor, 403 } 404 } 405 406 func InterceptorLogger(l zerolog.Logger) grpclog.Logger { 407 return grpclog.LoggerFunc(func(ctx context.Context, lvl grpclog.Level, msg string, fields ...any) { 408 l := l.With().Fields(fields).Logger() 409 410 switch lvl { 411 case grpclog.LevelDebug: 412 l.Debug().Msg(msg) 413 case grpclog.LevelInfo: 414 l.Info().Msg(msg) 415 case grpclog.LevelWarn: 416 l.Warn().Msg(msg) 417 case grpclog.LevelError: 418 l.Error().Msg(msg) 419 default: 420 l.Error().Int("level", int(lvl)).Msg("unknown error level - falling back to info level") 421 l.Info().Msg(msg) 422 } 423 }) 424 } 425 426 // initializes prometheus grpc interceptors with exemplar support enabled 427 func createServerMetrics(disableHistogram bool) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) { 428 var opts []grpcprom.ServerMetricsOption 429 if !disableHistogram { 430 opts = append(opts, grpcprom.WithServerHandlingTimeHistogram( 431 grpcprom.WithHistogramBuckets([]float64{.001, .003, .006, .010, .018, .024, .032, .042, .056, .075, .100, .178, .316, .562, 1, 5}), 432 )) 433 } 434 srvMetrics := grpcprom.NewServerMetrics(opts...) 435 // deliberately ignore if these metrics were already registered, so that 436 // custom builds of SpiceDB can register these metrics with custom labels 437 _ = prometheus.Register(srvMetrics) 438 439 exemplarFromContext := func(ctx context.Context) prometheus.Labels { 440 if span := trace.SpanContextFromContext(ctx); span.IsSampled() { 441 return prometheus.Labels{"traceID": span.TraceID().String()} 442 } 443 return nil 444 } 445 446 exemplarContext := grpcprom.WithExemplarFromContext(exemplarFromContext) 447 return srvMetrics.UnaryServerInterceptor(exemplarContext), srvMetrics.StreamServerInterceptor(exemplarContext) 448 }