github.com/openfga/openfga@v1.5.4-rc1/cmd/run/run.go (about) 1 // Package run contains the command to run an OpenFGA server. 2 package run 3 4 import ( 5 "context" 6 "errors" 7 "fmt" 8 "html/template" 9 "net" 10 "net/http" 11 "net/http/pprof" 12 "os" 13 "os/signal" 14 goruntime "runtime" 15 "strconv" 16 "strings" 17 "syscall" 18 "time" 19 20 "github.com/cenkalti/backoff/v4" 21 grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" 22 grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" 23 grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" 24 "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 25 grpc_prometheus "github.com/jon-whit/go-grpc-prometheus" 26 openfgav1 "github.com/openfga/api/proto/openfga/v1" 27 "github.com/prometheus/client_golang/prometheus/promhttp" 28 "github.com/rs/cors" 29 "github.com/spf13/cobra" 30 "github.com/spf13/viper" 31 "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" 32 "go.opentelemetry.io/otel" 33 semconv "go.opentelemetry.io/otel/semconv/v1.12.0" 34 "go.opentelemetry.io/otel/trace/noop" 35 "go.uber.org/zap" 36 "google.golang.org/grpc" 37 "google.golang.org/grpc/credentials" 38 "google.golang.org/grpc/credentials/insecure" 39 healthv1pb "google.golang.org/grpc/health/grpc_health_v1" 40 "google.golang.org/grpc/reflection" 41 "google.golang.org/grpc/status" 42 43 "github.com/openfga/openfga/pkg/gateway" 44 45 "github.com/openfga/openfga/assets" 46 "github.com/openfga/openfga/internal/authn" 47 "github.com/openfga/openfga/internal/authn/oidc" 48 "github.com/openfga/openfga/internal/authn/presharedkey" 49 "github.com/openfga/openfga/internal/build" 50 authnmw "github.com/openfga/openfga/internal/middleware/authn" 51 serverconfig "github.com/openfga/openfga/internal/server/config" 52 "github.com/openfga/openfga/pkg/logger" 53 "github.com/openfga/openfga/pkg/middleware" 54 httpmiddleware "github.com/openfga/openfga/pkg/middleware/http" 55 "github.com/openfga/openfga/pkg/middleware/logging" 56 "github.com/openfga/openfga/pkg/middleware/recovery" 57 "github.com/openfga/openfga/pkg/middleware/requestid" 58 "github.com/openfga/openfga/pkg/middleware/storeid" 59 "github.com/openfga/openfga/pkg/middleware/validator" 60 "github.com/openfga/openfga/pkg/server" 61 serverErrors "github.com/openfga/openfga/pkg/server/errors" 62 "github.com/openfga/openfga/pkg/server/health" 63 "github.com/openfga/openfga/pkg/storage" 64 "github.com/openfga/openfga/pkg/storage/memory" 65 "github.com/openfga/openfga/pkg/storage/mysql" 66 "github.com/openfga/openfga/pkg/storage/postgres" 67 "github.com/openfga/openfga/pkg/storage/sqlcommon" 68 "github.com/openfga/openfga/pkg/storage/storagewrappers" 69 "github.com/openfga/openfga/pkg/telemetry" 70 ) 71 72 const ( 73 datastoreEngineFlag = "datastore-engine" 74 datastoreURIFlag = "datastore-uri" 75 ) 76 77 func NewRunCommand() *cobra.Command { 78 cmd := &cobra.Command{ 79 Use: "run", 80 Short: "Run the OpenFGA server", 81 Long: "Run the OpenFGA server.", 82 Run: run, 83 Args: cobra.NoArgs, 84 } 85 86 defaultConfig := serverconfig.DefaultConfig() 87 flags := cmd.Flags() 88 89 flags.StringSlice("experimentals", defaultConfig.Experimentals, "a list of experimental features to enable. Allowed values: `enable-list-users`") 90 91 flags.String("grpc-addr", defaultConfig.GRPC.Addr, "the host:port address to serve the grpc server on") 92 93 flags.Bool("grpc-tls-enabled", defaultConfig.GRPC.TLS.Enabled, "enable/disable transport layer security (TLS)") 94 95 flags.String("grpc-tls-cert", defaultConfig.GRPC.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection") 96 97 flags.String("grpc-tls-key", defaultConfig.GRPC.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection") 98 99 cmd.MarkFlagsRequiredTogether("grpc-tls-enabled", "grpc-tls-cert", "grpc-tls-key") 100 101 flags.Bool("http-enabled", defaultConfig.HTTP.Enabled, "enable/disable the OpenFGA HTTP server") 102 103 flags.String("http-addr", defaultConfig.HTTP.Addr, "the host:port address to serve the HTTP server on") 104 105 flags.Bool("http-tls-enabled", defaultConfig.HTTP.TLS.Enabled, "enable/disable transport layer security (TLS)") 106 107 flags.String("http-tls-cert", defaultConfig.HTTP.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection") 108 109 flags.String("http-tls-key", defaultConfig.HTTP.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection") 110 111 cmd.MarkFlagsRequiredTogether("http-tls-enabled", "http-tls-cert", "http-tls-key") 112 113 flags.Duration("http-upstream-timeout", defaultConfig.HTTP.UpstreamTimeout, "the timeout duration for proxying HTTP requests upstream to the grpc endpoint") 114 115 flags.StringSlice("http-cors-allowed-origins", defaultConfig.HTTP.CORSAllowedOrigins, "specifies the CORS allowed origins") 116 117 flags.StringSlice("http-cors-allowed-headers", defaultConfig.HTTP.CORSAllowedHeaders, "specifies the CORS allowed headers") 118 119 flags.String("authn-method", defaultConfig.Authn.Method, "the authentication method to use") 120 121 flags.StringSlice("authn-preshared-keys", defaultConfig.Authn.Keys, "one or more preshared keys to use for authentication") 122 123 flags.String("authn-oidc-audience", defaultConfig.Authn.Audience, "the OIDC audience of the tokens being signed by the authorization server") 124 125 flags.String("authn-oidc-issuer", defaultConfig.Authn.Issuer, "the OIDC issuer (authorization server) signing the tokens, and where the keys will be fetched from") 126 127 flags.StringSlice("authn-oidc-issuer-aliases", defaultConfig.Authn.IssuerAliases, "the OIDC issuer DNS aliases that will be accepted as valid when verifying tokens") 128 129 flags.String("datastore-engine", defaultConfig.Datastore.Engine, "the datastore engine that will be used for persistence") 130 131 flags.String("datastore-uri", defaultConfig.Datastore.URI, "the connection uri to use to connect to the datastore (for any engine other than 'memory')") 132 133 flags.String("datastore-username", "", "the connection username to use to connect to the datastore (overwrites any username provided in the connection uri)") 134 135 flags.String("datastore-password", "", "the connection password to use to connect to the datastore (overwrites any password provided in the connection uri)") 136 137 flags.Int("datastore-max-cache-size", defaultConfig.Datastore.MaxCacheSize, "the maximum number of cache keys that the storage cache can store before evicting old keys") 138 139 flags.Int("datastore-max-open-conns", defaultConfig.Datastore.MaxOpenConns, "the maximum number of open connections to the datastore") 140 141 flags.Int("datastore-max-idle-conns", defaultConfig.Datastore.MaxIdleConns, "the maximum number of connections to the datastore in the idle connection pool") 142 143 flags.Duration("datastore-conn-max-idle-time", defaultConfig.Datastore.ConnMaxIdleTime, "the maximum amount of time a connection to the datastore may be idle") 144 145 flags.Duration("datastore-conn-max-lifetime", defaultConfig.Datastore.ConnMaxLifetime, "the maximum amount of time a connection to the datastore may be reused") 146 147 flags.Bool("datastore-metrics-enabled", defaultConfig.Datastore.Metrics.Enabled, "enable/disable sql metrics") 148 149 flags.Bool("playground-enabled", defaultConfig.Playground.Enabled, "enable/disable the OpenFGA Playground") 150 151 flags.Int("playground-port", defaultConfig.Playground.Port, "the port to serve the local OpenFGA Playground on") 152 153 flags.Bool("profiler-enabled", defaultConfig.Profiler.Enabled, "enable/disable pprof profiling") 154 155 flags.String("profiler-addr", defaultConfig.Profiler.Addr, "the host:port address to serve the pprof profiler server on") 156 157 flags.String("log-format", defaultConfig.Log.Format, "the log format to output logs in") 158 159 flags.String("log-level", defaultConfig.Log.Level, "the log level to use") 160 161 flags.String("log-timestamp-format", defaultConfig.Log.TimestampFormat, "the timestamp format to use for log messages") 162 163 flags.Bool("trace-enabled", defaultConfig.Trace.Enabled, "enable tracing") 164 165 flags.String("trace-otlp-endpoint", defaultConfig.Trace.OTLP.Endpoint, "the endpoint of the trace collector") 166 167 flags.Bool("trace-otlp-tls-enabled", defaultConfig.Trace.OTLP.TLS.Enabled, "use TLS connection for trace collector") 168 169 flags.Float64("trace-sample-ratio", defaultConfig.Trace.SampleRatio, "the fraction of traces to sample. 1 means all, 0 means none.") 170 171 flags.String("trace-service-name", defaultConfig.Trace.ServiceName, "the service name included in sampled traces.") 172 173 flags.Bool("metrics-enabled", defaultConfig.Metrics.Enabled, "enable/disable prometheus metrics on the '/metrics' endpoint") 174 175 flags.String("metrics-addr", defaultConfig.Metrics.Addr, "the host:port address to serve the prometheus metrics server on") 176 177 flags.Bool("metrics-enable-rpc-histograms", defaultConfig.Metrics.EnableRPCHistograms, "enables prometheus histogram metrics for RPC latency distributions") 178 179 flags.Int("max-tuples-per-write", defaultConfig.MaxTuplesPerWrite, "the maximum allowed number of tuples per Write transaction") 180 181 flags.Int("max-types-per-authorization-model", defaultConfig.MaxTypesPerAuthorizationModel, "the maximum allowed number of type definitions per authorization model") 182 183 flags.Int("max-authorization-model-size-in-bytes", defaultConfig.MaxAuthorizationModelSizeInBytes, "the maximum size in bytes allowed for persisting an Authorization Model.") 184 185 flags.Uint32("max-concurrent-reads-for-list-users", defaultConfig.MaxConcurrentReadsForListUsers, "the maximum allowed number of concurrent datastore reads in a single ListUsers query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.") 186 187 flags.Uint32("max-concurrent-reads-for-list-objects", defaultConfig.MaxConcurrentReadsForListObjects, "the maximum allowed number of concurrent datastore reads in a single ListObjects or StreamedListObjects query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.") 188 189 flags.Uint32("max-concurrent-reads-for-check", defaultConfig.MaxConcurrentReadsForCheck, "the maximum allowed number of concurrent datastore reads in a single Check query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.") 190 191 flags.Int("changelog-horizon-offset", defaultConfig.ChangelogHorizonOffset, "the offset (in minutes) from the current time. Changes that occur after this offset will not be included in the response of ReadChanges") 192 193 flags.Uint32("resolve-node-limit", defaultConfig.ResolveNodeLimit, "maximum resolution depth to attempt before throwing an error (defines how deeply nested an authorization model can be before a query errors out).") 194 195 flags.Uint32("resolve-node-breadth-limit", defaultConfig.ResolveNodeBreadthLimit, "defines how many nodes on a given level can be evaluated concurrently in a Check resolution tree") 196 197 flags.Duration("listObjects-deadline", defaultConfig.ListObjectsDeadline, "the timeout deadline for serving ListObjects and StreamedListObjects requests") 198 199 flags.Uint32("listObjects-max-results", defaultConfig.ListObjectsMaxResults, "the maximum results to return in non-streaming ListObjects API responses. If 0, all results can be returned") 200 201 flags.Duration("listUsers-deadline", defaultConfig.ListUsersDeadline, "the timeout deadline for serving ListUsers requests. If 0, there is no deadline") 202 203 flags.Uint32("listUsers-max-results", defaultConfig.ListUsersMaxResults, "the maximum results to return in ListUsers API responses. If 0, all results can be returned") 204 205 flags.Bool("check-query-cache-enabled", defaultConfig.CheckQueryCache.Enabled, "when executing Check and ListObjects requests, enables caching. This will turn Check and ListObjects responses into eventually consistent responses") 206 207 flags.Uint32("check-query-cache-limit", defaultConfig.CheckQueryCache.Limit, "if caching of Check and ListObjects calls is enabled, this is the size limit of the cache") 208 209 flags.Duration("check-query-cache-ttl", defaultConfig.CheckQueryCache.TTL, "if caching of Check and ListObjects is enabled, this is the TTL of each value") 210 211 // Unfortunately UintSlice/IntSlice does not work well when used as environment variable, we need to stick with string slice and convert back to integer 212 flags.StringSlice("request-duration-datastore-query-count-buckets", defaultConfig.RequestDurationDatastoreQueryCountBuckets, "datastore query count buckets used in labelling request_duration_ms.") 213 214 flags.StringSlice("request-duration-dispatch-count-buckets", defaultConfig.RequestDurationDispatchCountBuckets, "dispatch count (i.e number of concurrent traversals to resolve a query) buckets used in labelling request_duration_ms.") 215 216 flags.Bool("dispatch-throttling-enabled", defaultConfig.DispatchThrottling.Enabled, "enable throttling when request's number of dispatches is high. Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch threshold over requests whose dispatch count exceeds the configured threshold.") 217 218 flags.Duration("dispatch-throttling-frequency", defaultConfig.DispatchThrottling.Frequency, "defines how frequent dispatch throttling will be evaluated. Frequency controls how frequently throttled dispatch requests are dispatched.") 219 220 flags.Uint32("dispatch-throttling-threshold", defaultConfig.DispatchThrottling.Threshold, "define the default threshold on number of dispatches above which requests will be throttled.") 221 222 flags.Uint32("dispatch-throttling-max-threshold", defaultConfig.DispatchThrottling.MaxThreshold, "define the maximum dispatch threshold beyond which requests will be throttled. 0 will use the 'dispatch-throttling-threshold' value as maximum") 223 224 flags.Duration("request-timeout", defaultConfig.RequestTimeout, "configures request timeout. If both HTTP upstream timeout and request timeout are specified, request timeout will be used.") 225 226 // NOTE: if you add a new flag here, update the function below, too 227 228 cmd.PreRun = bindRunFlagsFunc(flags) 229 230 return cmd 231 } 232 233 // ReadConfig returns the OpenFGA server configuration based on the values provided in the server's 'config.yaml' file. 234 // The 'config.yaml' file is loaded from '/etc/openfga', '$HOME/.openfga', or the current working directory. If no configuration 235 // file is present, the default values are returned. 236 func ReadConfig() (*serverconfig.Config, error) { 237 config := serverconfig.DefaultConfig() 238 239 viper.SetTypeByDefaultValue(true) 240 err := viper.ReadInConfig() 241 if err != nil { 242 if _, ok := err.(viper.ConfigFileNotFoundError); !ok { 243 return nil, fmt.Errorf("failed to load server config: %w", err) 244 } 245 } 246 247 if err := viper.Unmarshal(config); err != nil { 248 return nil, fmt.Errorf("failed to unmarshal server config: %w", err) 249 } 250 251 return config, nil 252 } 253 254 func run(_ *cobra.Command, _ []string) { 255 config, err := ReadConfig() 256 if err != nil { 257 panic(err) 258 } 259 260 if err := config.Verify(); err != nil { 261 panic(err) 262 } 263 264 logger := logger.MustNewLogger(config.Log.Format, config.Log.Level, config.Log.TimestampFormat) 265 266 serverCtx := &ServerContext{Logger: logger} 267 if err := serverCtx.Run(context.Background(), config); err != nil { 268 panic(err) 269 } 270 } 271 272 type ServerContext struct { 273 Logger logger.Logger 274 } 275 276 func convertStringArrayToUintArray(stringArray []string) []uint { 277 uintArray := []uint{} 278 for _, val := range stringArray { 279 // note that we have already validated whether the array item is non-negative integer 280 valInt, err := strconv.Atoi(val) 281 if err == nil { 282 uintArray = append(uintArray, uint(valInt)) 283 } 284 } 285 return uintArray 286 } 287 288 func (s *ServerContext) telemetryConfig(ctx context.Context, config *serverconfig.Config) func() { 289 var tracerProviderCloser func() 290 291 if config.Trace.Enabled { 292 s.Logger.Info(fmt.Sprintf("๐ต tracing enabled: sampling ratio is %v and sending traces to '%s', tls: %t", config.Trace.SampleRatio, config.Trace.OTLP.Endpoint, config.Trace.OTLP.TLS.Enabled)) 293 294 options := []telemetry.TracerOption{ 295 telemetry.WithOTLPEndpoint( 296 config.Trace.OTLP.Endpoint, 297 ), 298 telemetry.WithAttributes( 299 semconv.ServiceNameKey.String(config.Trace.ServiceName), 300 semconv.ServiceVersionKey.String(build.Version), 301 ), 302 telemetry.WithSamplingRatio(config.Trace.SampleRatio), 303 } 304 305 if !config.Trace.OTLP.TLS.Enabled { 306 options = append(options, telemetry.WithOTLPInsecure()) 307 } 308 309 tp := telemetry.MustNewTracerProvider(options...) 310 tracerProviderCloser = func() { 311 _ = tp.ForceFlush(ctx) 312 _ = tp.Shutdown(ctx) 313 } 314 } else { 315 otel.SetTracerProvider(noop.NewTracerProvider()) 316 } 317 return tracerProviderCloser 318 } 319 320 func (s *ServerContext) datastoreConfig(config *serverconfig.Config) (storage.OpenFGADatastore, error) { 321 datastoreOptions := []sqlcommon.DatastoreOption{ 322 sqlcommon.WithUsername(config.Datastore.Username), 323 sqlcommon.WithPassword(config.Datastore.Password), 324 sqlcommon.WithLogger(s.Logger), 325 sqlcommon.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite), 326 sqlcommon.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel), 327 sqlcommon.WithMaxOpenConns(config.Datastore.MaxOpenConns), 328 sqlcommon.WithMaxIdleConns(config.Datastore.MaxIdleConns), 329 sqlcommon.WithConnMaxIdleTime(config.Datastore.ConnMaxIdleTime), 330 sqlcommon.WithConnMaxLifetime(config.Datastore.ConnMaxLifetime), 331 } 332 333 if config.Datastore.Metrics.Enabled { 334 datastoreOptions = append(datastoreOptions, sqlcommon.WithMetrics()) 335 } 336 337 dsCfg := sqlcommon.NewConfig(datastoreOptions...) 338 339 var datastore storage.OpenFGADatastore 340 var err error 341 switch config.Datastore.Engine { 342 case "memory": 343 opts := []memory.StorageOption{ 344 memory.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel), 345 memory.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite), 346 } 347 datastore = memory.New(opts...) 348 case "mysql": 349 datastore, err = mysql.New(config.Datastore.URI, dsCfg) 350 if err != nil { 351 return nil, fmt.Errorf("initialize mysql datastore: %w", err) 352 } 353 case "postgres": 354 datastore, err = postgres.New(config.Datastore.URI, dsCfg) 355 if err != nil { 356 return nil, fmt.Errorf("initialize postgres datastore: %w", err) 357 } 358 default: 359 return nil, fmt.Errorf("storage engine '%s' is unsupported", config.Datastore.Engine) 360 } 361 datastore = storagewrappers.NewCachedOpenFGADatastore(storagewrappers.NewContextWrapper(datastore), config.Datastore.MaxCacheSize) 362 363 s.Logger.Info(fmt.Sprintf("using '%v' storage engine", config.Datastore.Engine)) 364 return datastore, nil 365 } 366 367 func (s *ServerContext) authenticatorConfig(config *serverconfig.Config) (authn.Authenticator, error) { 368 var authenticator authn.Authenticator 369 var err error 370 371 switch config.Authn.Method { 372 case "none": 373 s.Logger.Warn("authentication is disabled") 374 authenticator = authn.NoopAuthenticator{} 375 case "preshared": 376 s.Logger.Info("using 'preshared' authentication") 377 authenticator, err = presharedkey.NewPresharedKeyAuthenticator(config.Authn.Keys) 378 case "oidc": 379 s.Logger.Info("using 'oidc' authentication") 380 authenticator, err = oidc.NewRemoteOidcAuthenticator(config.Authn.Issuer, config.Authn.IssuerAliases, config.Authn.Audience) 381 default: 382 return nil, fmt.Errorf("unsupported authentication method '%v'", config.Authn.Method) 383 } 384 if err != nil { 385 return nil, fmt.Errorf("failed to initialize authenticator: %w", err) 386 } 387 return authenticator, nil 388 } 389 390 // Run returns an error if the server was unable to start successfully. 391 // If it started and terminated successfully, it returns a nil error. 392 func (s *ServerContext) Run(ctx context.Context, config *serverconfig.Config) error { 393 tracerProviderCloser := s.telemetryConfig(ctx, config) 394 395 s.Logger.Info(fmt.Sprintf("๐งช experimental features enabled: %v", config.Experimentals)) 396 397 var experimentals []server.ExperimentalFeatureFlag 398 for _, feature := range config.Experimentals { 399 experimentals = append(experimentals, server.ExperimentalFeatureFlag(feature)) 400 } 401 402 datastore, err := s.datastoreConfig(config) 403 if err != nil { 404 return err 405 } 406 407 authenticator, err := s.authenticatorConfig(config) 408 409 if err != nil { 410 return err 411 } 412 413 serverOpts := []grpc.ServerOption{ 414 grpc.MaxRecvMsgSize(serverconfig.DefaultMaxRPCMessageSizeInBytes), 415 grpc.ChainUnaryInterceptor( 416 []grpc.UnaryServerInterceptor{ 417 grpc_recovery.UnaryServerInterceptor( // panic middleware must be 1st in chain 418 grpc_recovery.WithRecoveryHandlerContext( 419 recovery.PanicRecoveryHandler(s.Logger), 420 ), 421 ), 422 grpc_ctxtags.UnaryServerInterceptor(), // needed for logging 423 requestid.NewUnaryInterceptor(), // add request_id to ctxtags 424 }..., 425 ), 426 grpc.ChainStreamInterceptor( 427 []grpc.StreamServerInterceptor{ 428 grpc_recovery.StreamServerInterceptor( 429 grpc_recovery.WithRecoveryHandlerContext( 430 recovery.PanicRecoveryHandler(s.Logger), 431 ), 432 ), 433 requestid.NewStreamingInterceptor(), 434 }..., 435 ), 436 } 437 438 if config.RequestTimeout > 0 { 439 timeoutMiddleware := middleware.NewTimeoutInterceptor(config.RequestTimeout, s.Logger) 440 441 serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(timeoutMiddleware.NewUnaryTimeoutInterceptor())) 442 serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(timeoutMiddleware.NewStreamTimeoutInterceptor())) 443 } 444 445 serverOpts = append(serverOpts, 446 grpc.ChainUnaryInterceptor( 447 []grpc.UnaryServerInterceptor{ 448 storeid.NewUnaryInterceptor(), // if available, add store_id to ctxtags 449 logging.NewLoggingInterceptor(s.Logger), // needed to log invalid requests 450 validator.UnaryServerInterceptor(), 451 }..., 452 ), 453 grpc.ChainStreamInterceptor( 454 []grpc.StreamServerInterceptor{ 455 validator.StreamServerInterceptor(), 456 grpc_ctxtags.StreamServerInterceptor(), 457 }..., 458 ), 459 ) 460 461 if config.Metrics.Enabled { 462 serverOpts = append(serverOpts, 463 grpc.ChainUnaryInterceptor(grpc_prometheus.UnaryServerInterceptor), 464 grpc.ChainStreamInterceptor(grpc_prometheus.StreamServerInterceptor)) 465 466 if config.Metrics.EnableRPCHistograms { 467 grpc_prometheus.EnableHandlingTimeHistogram() 468 } 469 } 470 471 if config.Trace.Enabled { 472 serverOpts = append(serverOpts, grpc.StatsHandler(otelgrpc.NewServerHandler())) 473 } 474 475 serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor( 476 []grpc.UnaryServerInterceptor{ 477 grpcauth.UnaryServerInterceptor(authnmw.AuthFunc(authenticator)), 478 }...), 479 grpc.ChainStreamInterceptor( 480 []grpc.StreamServerInterceptor{ 481 grpcauth.StreamServerInterceptor(authnmw.AuthFunc(authenticator)), 482 // The following interceptors wrap the server stream with our own 483 // wrapper and must come last. 484 storeid.NewStreamingInterceptor(), 485 logging.NewStreamingLoggingInterceptor(s.Logger), 486 }..., 487 ), 488 ) 489 490 if config.GRPC.TLS.Enabled { 491 if config.GRPC.TLS.CertPath == "" || config.GRPC.TLS.KeyPath == "" { 492 return errors.New("'grpc.tls.cert' and 'grpc.tls.key' configs must be set") 493 } 494 creds, err := credentials.NewServerTLSFromFile(config.GRPC.TLS.CertPath, config.GRPC.TLS.KeyPath) 495 if err != nil { 496 return err 497 } 498 499 serverOpts = append(serverOpts, grpc.Creds(creds)) 500 501 s.Logger.Info("grpc TLS is enabled, serving connections using the provided certificate") 502 } else { 503 s.Logger.Warn("grpc TLS is disabled, serving connections using insecure plaintext") 504 } 505 506 if config.Profiler.Enabled { 507 mux := http.NewServeMux() 508 mux.HandleFunc("/debug/pprof/", pprof.Index) 509 mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 510 mux.HandleFunc("/debug/pprof/profile", pprof.Profile) 511 mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) 512 mux.HandleFunc("/debug/pprof/trace", pprof.Trace) 513 514 go func() { 515 s.Logger.Info(fmt.Sprintf("๐ฌ starting pprof profiler on '%s'", config.Profiler.Addr)) 516 517 if err := http.ListenAndServe(config.Profiler.Addr, mux); err != nil { 518 if err != http.ErrServerClosed { 519 s.Logger.Fatal("failed to start pprof profiler", zap.Error(err)) 520 } 521 } 522 }() 523 } 524 525 if config.Metrics.Enabled { 526 s.Logger.Info(fmt.Sprintf("๐ starting metrics server on '%s'", config.Metrics.Addr)) 527 528 go func() { 529 mux := http.NewServeMux() 530 mux.Handle("/metrics", promhttp.Handler()) 531 if err := http.ListenAndServe(config.Metrics.Addr, mux); err != nil { 532 if err != http.ErrServerClosed { 533 s.Logger.Fatal("failed to start prometheus metrics server", zap.Error(err)) 534 } 535 } 536 }() 537 } 538 539 svr := server.MustNewServerWithOpts( 540 server.WithDatastore(datastore), 541 server.WithLogger(s.Logger), 542 server.WithTransport(gateway.NewRPCTransport(s.Logger)), 543 server.WithResolveNodeLimit(config.ResolveNodeLimit), 544 server.WithResolveNodeBreadthLimit(config.ResolveNodeBreadthLimit), 545 server.WithChangelogHorizonOffset(config.ChangelogHorizonOffset), 546 server.WithListObjectsDeadline(config.ListObjectsDeadline), 547 server.WithListObjectsMaxResults(config.ListObjectsMaxResults), 548 server.WithListUsersDeadline(config.ListUsersDeadline), 549 server.WithListUsersMaxResults(config.ListUsersMaxResults), 550 server.WithMaxConcurrentReadsForListObjects(config.MaxConcurrentReadsForListObjects), 551 server.WithMaxConcurrentReadsForCheck(config.MaxConcurrentReadsForCheck), 552 server.WithMaxConcurrentReadsForListUsers(config.MaxConcurrentReadsForListUsers), 553 server.WithCheckQueryCacheEnabled(config.CheckQueryCache.Enabled), 554 server.WithCheckQueryCacheLimit(config.CheckQueryCache.Limit), 555 server.WithCheckQueryCacheTTL(config.CheckQueryCache.TTL), 556 server.WithRequestDurationByQueryHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDatastoreQueryCountBuckets)), 557 server.WithRequestDurationByDispatchCountHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDispatchCountBuckets)), 558 server.WithMaxAuthorizationModelSizeInBytes(config.MaxAuthorizationModelSizeInBytes), 559 server.WithDispatchThrottlingCheckResolverEnabled(config.DispatchThrottling.Enabled), 560 server.WithDispatchThrottlingCheckResolverFrequency(config.DispatchThrottling.Frequency), 561 server.WithDispatchThrottlingCheckResolverThreshold(config.DispatchThrottling.Threshold), 562 server.WithDispatchThrottlingCheckResolverMaxThreshold(config.DispatchThrottling.MaxThreshold), 563 server.WithExperimentals(experimentals...), 564 ) 565 566 s.Logger.Info( 567 "๐ starting openfga service...", 568 zap.String("version", build.Version), 569 zap.String("date", build.Date), 570 zap.String("commit", build.Commit), 571 zap.String("go-version", goruntime.Version()), 572 ) 573 574 // nosemgrep: grpc-server-insecure-connection 575 grpcServer := grpc.NewServer(serverOpts...) 576 openfgav1.RegisterOpenFGAServiceServer(grpcServer, svr) 577 healthServer := &health.Checker{TargetService: svr, TargetServiceName: openfgav1.OpenFGAService_ServiceDesc.ServiceName} 578 healthv1pb.RegisterHealthServer(grpcServer, healthServer) 579 reflection.Register(grpcServer) 580 581 lis, err := net.Listen("tcp", config.GRPC.Addr) 582 if err != nil { 583 return fmt.Errorf("failed to listen: %w", err) 584 } 585 586 go func() { 587 if err := grpcServer.Serve(lis); err != nil { 588 if !errors.Is(err, grpc.ErrServerStopped) { 589 s.Logger.Fatal("failed to start grpc server", zap.Error(err)) 590 } 591 592 s.Logger.Info("grpc server shut down..") 593 } 594 }() 595 s.Logger.Info(fmt.Sprintf("grpc server listening on '%s'...", config.GRPC.Addr)) 596 597 var httpServer *http.Server 598 if config.HTTP.Enabled { 599 runtime.DefaultContextTimeout = serverconfig.DefaultContextTimeout(config) 600 601 dialOpts := []grpc.DialOption{ 602 grpc.WithBlock(), 603 } 604 if config.GRPC.TLS.Enabled { 605 creds, err := credentials.NewClientTLSFromFile(config.GRPC.TLS.CertPath, "") 606 if err != nil { 607 s.Logger.Fatal("", zap.Error(err)) 608 } 609 dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds)) 610 } else { 611 dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) 612 } 613 614 timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 615 defer cancel() 616 617 conn, err := grpc.DialContext(timeoutCtx, config.GRPC.Addr, dialOpts...) 618 if err != nil { 619 s.Logger.Fatal("", zap.Error(err)) 620 } 621 defer conn.Close() 622 623 muxOpts := []runtime.ServeMuxOption{ 624 runtime.WithForwardResponseOption(httpmiddleware.HTTPResponseModifier), 625 runtime.WithErrorHandler(func(c context.Context, sr *runtime.ServeMux, mm runtime.Marshaler, w http.ResponseWriter, r *http.Request, e error) { 626 intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e)) 627 httpmiddleware.CustomHTTPErrorHandler(c, w, r, serverErrors.NewEncodedError(intCode, e.Error())) 628 }), 629 runtime.WithStreamErrorHandler(func(ctx context.Context, e error) *status.Status { 630 intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e)) 631 encodedErr := serverErrors.NewEncodedError(intCode, e.Error()) 632 return status.Convert(encodedErr) 633 }), 634 runtime.WithHealthzEndpoint(healthv1pb.NewHealthClient(conn)), 635 runtime.WithOutgoingHeaderMatcher(func(s string) (string, bool) { return s, true }), 636 } 637 mux := runtime.NewServeMux(muxOpts...) 638 if err := openfgav1.RegisterOpenFGAServiceHandler(ctx, mux, conn); err != nil { 639 return err 640 } 641 642 httpServer = &http.Server{ 643 Addr: config.HTTP.Addr, 644 Handler: recovery.HTTPPanicRecoveryHandler(cors.New(cors.Options{ 645 AllowedOrigins: config.HTTP.CORSAllowedOrigins, 646 AllowCredentials: true, 647 AllowedHeaders: config.HTTP.CORSAllowedHeaders, 648 AllowedMethods: []string{http.MethodGet, http.MethodPost, 649 http.MethodHead, http.MethodPatch, http.MethodDelete, http.MethodPut}, 650 }).Handler(mux), s.Logger), 651 } 652 653 go func() { 654 var err error 655 if config.HTTP.TLS.Enabled { 656 if config.HTTP.TLS.CertPath == "" || config.HTTP.TLS.KeyPath == "" { 657 s.Logger.Fatal("'http.tls.cert' and 'http.tls.key' configs must be set") 658 } 659 err = httpServer.ListenAndServeTLS(config.HTTP.TLS.CertPath, config.HTTP.TLS.KeyPath) 660 } else { 661 err = httpServer.ListenAndServe() 662 } 663 if err != http.ErrServerClosed { 664 s.Logger.Fatal("HTTP server closed with unexpected error", zap.Error(err)) 665 } 666 }() 667 s.Logger.Info(fmt.Sprintf("HTTP server listening on '%s'...", httpServer.Addr)) 668 } 669 670 var playground *http.Server 671 if config.Playground.Enabled { 672 if !config.HTTP.Enabled { 673 return errors.New("the HTTP server must be enabled to run the openfga playground") 674 } 675 676 authMethod := config.Authn.Method 677 if !(authMethod == "none" || authMethod == "preshared") { 678 return errors.New("the playground only supports authn methods 'none' and 'preshared'") 679 } 680 681 playgroundAddr := fmt.Sprintf(":%d", config.Playground.Port) 682 s.Logger.Info(fmt.Sprintf("๐ starting openfga playground on http://localhost%s/playground", playgroundAddr)) 683 684 tmpl, err := template.ParseFS(assets.EmbedPlayground, "playground/index.html") 685 if err != nil { 686 return fmt.Errorf("failed to parse playground index.html as Go template: %w", err) 687 } 688 689 fileServer := http.FileServer(http.FS(assets.EmbedPlayground)) 690 691 policy := backoff.NewExponentialBackOff() 692 policy.MaxElapsedTime = 3 * time.Second 693 694 var conn net.Conn 695 err = backoff.Retry( 696 func() error { 697 conn, err = net.Dial("tcp", config.HTTP.Addr) 698 return err 699 }, 700 policy, 701 ) 702 if err != nil { 703 return fmt.Errorf("failed to establish playground connection to HTTP server: %w", err) 704 } 705 706 playgroundAPIToken := "" 707 if authMethod == "preshared" { 708 playgroundAPIToken = config.Authn.AuthnPresharedKeyConfig.Keys[0] 709 } 710 711 mux := http.NewServeMux() 712 mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 713 if strings.HasPrefix(r.URL.Path, "/playground") { 714 if r.URL.Path == "/playground" || r.URL.Path == "/playground/index.html" { 715 err = tmpl.Execute(w, struct { 716 HTTPServerURL string 717 PlaygroundAPIToken string 718 }{ 719 HTTPServerURL: conn.RemoteAddr().String(), 720 PlaygroundAPIToken: playgroundAPIToken, 721 }) 722 if err != nil { 723 w.WriteHeader(http.StatusInternalServerError) 724 s.Logger.Error("failed to execute/render the playground web template", zap.Error(err)) 725 } 726 727 return 728 } 729 730 fileServer.ServeHTTP(w, r) 731 return 732 } 733 734 http.NotFound(w, r) 735 })) 736 737 playground = &http.Server{Addr: playgroundAddr, Handler: mux} 738 739 go func() { 740 err = playground.ListenAndServe() 741 if err != http.ErrServerClosed { 742 s.Logger.Fatal("failed to start the openfga playground server", zap.Error(err)) 743 } 744 s.Logger.Info("shutdown the openfga playground server") 745 }() 746 } 747 748 done := make(chan os.Signal, 1) 749 signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) 750 751 select { 752 case <-done: 753 case <-ctx.Done(): 754 } 755 s.Logger.Info("attempting to shutdown gracefully") 756 757 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 758 defer cancel() 759 760 if playground != nil { 761 if err := playground.Shutdown(ctx); err != nil { 762 s.Logger.Info("failed to gracefully shutdown playground server", zap.Error(err)) 763 } 764 } 765 766 if httpServer != nil { 767 if err := httpServer.Shutdown(ctx); err != nil { 768 s.Logger.Info("failed to shutdown the http server", zap.Error(err)) 769 } 770 } 771 772 grpcServer.GracefulStop() 773 774 svr.Close() 775 776 authenticator.Close() 777 778 datastore.Close() 779 780 if tracerProviderCloser != nil { 781 tracerProviderCloser() 782 } 783 784 s.Logger.Info("server exited. goodbye ๐") 785 786 return nil 787 }