github.com/openfga/openfga@v1.5.4-rc1/cmd/run/run.go (about)

     1  // Package run contains the command to run an OpenFGA server.
     2  package run
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"html/template"
     9  	"net"
    10  	"net/http"
    11  	"net/http/pprof"
    12  	"os"
    13  	"os/signal"
    14  	goruntime "runtime"
    15  	"strconv"
    16  	"strings"
    17  	"syscall"
    18  	"time"
    19  
    20  	"github.com/cenkalti/backoff/v4"
    21  	grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
    22  	grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
    23  	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
    24  	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    25  	grpc_prometheus "github.com/jon-whit/go-grpc-prometheus"
    26  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    27  	"github.com/prometheus/client_golang/prometheus/promhttp"
    28  	"github.com/rs/cors"
    29  	"github.com/spf13/cobra"
    30  	"github.com/spf13/viper"
    31  	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
    32  	"go.opentelemetry.io/otel"
    33  	semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
    34  	"go.opentelemetry.io/otel/trace/noop"
    35  	"go.uber.org/zap"
    36  	"google.golang.org/grpc"
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/credentials/insecure"
    39  	healthv1pb "google.golang.org/grpc/health/grpc_health_v1"
    40  	"google.golang.org/grpc/reflection"
    41  	"google.golang.org/grpc/status"
    42  
    43  	"github.com/openfga/openfga/pkg/gateway"
    44  
    45  	"github.com/openfga/openfga/assets"
    46  	"github.com/openfga/openfga/internal/authn"
    47  	"github.com/openfga/openfga/internal/authn/oidc"
    48  	"github.com/openfga/openfga/internal/authn/presharedkey"
    49  	"github.com/openfga/openfga/internal/build"
    50  	authnmw "github.com/openfga/openfga/internal/middleware/authn"
    51  	serverconfig "github.com/openfga/openfga/internal/server/config"
    52  	"github.com/openfga/openfga/pkg/logger"
    53  	"github.com/openfga/openfga/pkg/middleware"
    54  	httpmiddleware "github.com/openfga/openfga/pkg/middleware/http"
    55  	"github.com/openfga/openfga/pkg/middleware/logging"
    56  	"github.com/openfga/openfga/pkg/middleware/recovery"
    57  	"github.com/openfga/openfga/pkg/middleware/requestid"
    58  	"github.com/openfga/openfga/pkg/middleware/storeid"
    59  	"github.com/openfga/openfga/pkg/middleware/validator"
    60  	"github.com/openfga/openfga/pkg/server"
    61  	serverErrors "github.com/openfga/openfga/pkg/server/errors"
    62  	"github.com/openfga/openfga/pkg/server/health"
    63  	"github.com/openfga/openfga/pkg/storage"
    64  	"github.com/openfga/openfga/pkg/storage/memory"
    65  	"github.com/openfga/openfga/pkg/storage/mysql"
    66  	"github.com/openfga/openfga/pkg/storage/postgres"
    67  	"github.com/openfga/openfga/pkg/storage/sqlcommon"
    68  	"github.com/openfga/openfga/pkg/storage/storagewrappers"
    69  	"github.com/openfga/openfga/pkg/telemetry"
    70  )
    71  
    72  const (
    73  	datastoreEngineFlag = "datastore-engine"
    74  	datastoreURIFlag    = "datastore-uri"
    75  )
    76  
    77  func NewRunCommand() *cobra.Command {
    78  	cmd := &cobra.Command{
    79  		Use:   "run",
    80  		Short: "Run the OpenFGA server",
    81  		Long:  "Run the OpenFGA server.",
    82  		Run:   run,
    83  		Args:  cobra.NoArgs,
    84  	}
    85  
    86  	defaultConfig := serverconfig.DefaultConfig()
    87  	flags := cmd.Flags()
    88  
    89  	flags.StringSlice("experimentals", defaultConfig.Experimentals, "a list of experimental features to enable. Allowed values: `enable-list-users`")
    90  
    91  	flags.String("grpc-addr", defaultConfig.GRPC.Addr, "the host:port address to serve the grpc server on")
    92  
    93  	flags.Bool("grpc-tls-enabled", defaultConfig.GRPC.TLS.Enabled, "enable/disable transport layer security (TLS)")
    94  
    95  	flags.String("grpc-tls-cert", defaultConfig.GRPC.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection")
    96  
    97  	flags.String("grpc-tls-key", defaultConfig.GRPC.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection")
    98  
    99  	cmd.MarkFlagsRequiredTogether("grpc-tls-enabled", "grpc-tls-cert", "grpc-tls-key")
   100  
   101  	flags.Bool("http-enabled", defaultConfig.HTTP.Enabled, "enable/disable the OpenFGA HTTP server")
   102  
   103  	flags.String("http-addr", defaultConfig.HTTP.Addr, "the host:port address to serve the HTTP server on")
   104  
   105  	flags.Bool("http-tls-enabled", defaultConfig.HTTP.TLS.Enabled, "enable/disable transport layer security (TLS)")
   106  
   107  	flags.String("http-tls-cert", defaultConfig.HTTP.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection")
   108  
   109  	flags.String("http-tls-key", defaultConfig.HTTP.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection")
   110  
   111  	cmd.MarkFlagsRequiredTogether("http-tls-enabled", "http-tls-cert", "http-tls-key")
   112  
   113  	flags.Duration("http-upstream-timeout", defaultConfig.HTTP.UpstreamTimeout, "the timeout duration for proxying HTTP requests upstream to the grpc endpoint")
   114  
   115  	flags.StringSlice("http-cors-allowed-origins", defaultConfig.HTTP.CORSAllowedOrigins, "specifies the CORS allowed origins")
   116  
   117  	flags.StringSlice("http-cors-allowed-headers", defaultConfig.HTTP.CORSAllowedHeaders, "specifies the CORS allowed headers")
   118  
   119  	flags.String("authn-method", defaultConfig.Authn.Method, "the authentication method to use")
   120  
   121  	flags.StringSlice("authn-preshared-keys", defaultConfig.Authn.Keys, "one or more preshared keys to use for authentication")
   122  
   123  	flags.String("authn-oidc-audience", defaultConfig.Authn.Audience, "the OIDC audience of the tokens being signed by the authorization server")
   124  
   125  	flags.String("authn-oidc-issuer", defaultConfig.Authn.Issuer, "the OIDC issuer (authorization server) signing the tokens, and where the keys will be fetched from")
   126  
   127  	flags.StringSlice("authn-oidc-issuer-aliases", defaultConfig.Authn.IssuerAliases, "the OIDC issuer DNS aliases that will be accepted as valid when verifying tokens")
   128  
   129  	flags.String("datastore-engine", defaultConfig.Datastore.Engine, "the datastore engine that will be used for persistence")
   130  
   131  	flags.String("datastore-uri", defaultConfig.Datastore.URI, "the connection uri to use to connect to the datastore (for any engine other than 'memory')")
   132  
   133  	flags.String("datastore-username", "", "the connection username to use to connect to the datastore (overwrites any username provided in the connection uri)")
   134  
   135  	flags.String("datastore-password", "", "the connection password to use to connect to the datastore (overwrites any password provided in the connection uri)")
   136  
   137  	flags.Int("datastore-max-cache-size", defaultConfig.Datastore.MaxCacheSize, "the maximum number of cache keys that the storage cache can store before evicting old keys")
   138  
   139  	flags.Int("datastore-max-open-conns", defaultConfig.Datastore.MaxOpenConns, "the maximum number of open connections to the datastore")
   140  
   141  	flags.Int("datastore-max-idle-conns", defaultConfig.Datastore.MaxIdleConns, "the maximum number of connections to the datastore in the idle connection pool")
   142  
   143  	flags.Duration("datastore-conn-max-idle-time", defaultConfig.Datastore.ConnMaxIdleTime, "the maximum amount of time a connection to the datastore may be idle")
   144  
   145  	flags.Duration("datastore-conn-max-lifetime", defaultConfig.Datastore.ConnMaxLifetime, "the maximum amount of time a connection to the datastore may be reused")
   146  
   147  	flags.Bool("datastore-metrics-enabled", defaultConfig.Datastore.Metrics.Enabled, "enable/disable sql metrics")
   148  
   149  	flags.Bool("playground-enabled", defaultConfig.Playground.Enabled, "enable/disable the OpenFGA Playground")
   150  
   151  	flags.Int("playground-port", defaultConfig.Playground.Port, "the port to serve the local OpenFGA Playground on")
   152  
   153  	flags.Bool("profiler-enabled", defaultConfig.Profiler.Enabled, "enable/disable pprof profiling")
   154  
   155  	flags.String("profiler-addr", defaultConfig.Profiler.Addr, "the host:port address to serve the pprof profiler server on")
   156  
   157  	flags.String("log-format", defaultConfig.Log.Format, "the log format to output logs in")
   158  
   159  	flags.String("log-level", defaultConfig.Log.Level, "the log level to use")
   160  
   161  	flags.String("log-timestamp-format", defaultConfig.Log.TimestampFormat, "the timestamp format to use for log messages")
   162  
   163  	flags.Bool("trace-enabled", defaultConfig.Trace.Enabled, "enable tracing")
   164  
   165  	flags.String("trace-otlp-endpoint", defaultConfig.Trace.OTLP.Endpoint, "the endpoint of the trace collector")
   166  
   167  	flags.Bool("trace-otlp-tls-enabled", defaultConfig.Trace.OTLP.TLS.Enabled, "use TLS connection for trace collector")
   168  
   169  	flags.Float64("trace-sample-ratio", defaultConfig.Trace.SampleRatio, "the fraction of traces to sample. 1 means all, 0 means none.")
   170  
   171  	flags.String("trace-service-name", defaultConfig.Trace.ServiceName, "the service name included in sampled traces.")
   172  
   173  	flags.Bool("metrics-enabled", defaultConfig.Metrics.Enabled, "enable/disable prometheus metrics on the '/metrics' endpoint")
   174  
   175  	flags.String("metrics-addr", defaultConfig.Metrics.Addr, "the host:port address to serve the prometheus metrics server on")
   176  
   177  	flags.Bool("metrics-enable-rpc-histograms", defaultConfig.Metrics.EnableRPCHistograms, "enables prometheus histogram metrics for RPC latency distributions")
   178  
   179  	flags.Int("max-tuples-per-write", defaultConfig.MaxTuplesPerWrite, "the maximum allowed number of tuples per Write transaction")
   180  
   181  	flags.Int("max-types-per-authorization-model", defaultConfig.MaxTypesPerAuthorizationModel, "the maximum allowed number of type definitions per authorization model")
   182  
   183  	flags.Int("max-authorization-model-size-in-bytes", defaultConfig.MaxAuthorizationModelSizeInBytes, "the maximum size in bytes allowed for persisting an Authorization Model.")
   184  
   185  	flags.Uint32("max-concurrent-reads-for-list-users", defaultConfig.MaxConcurrentReadsForListUsers, "the maximum allowed number of concurrent datastore reads in a single ListUsers query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
   186  
   187  	flags.Uint32("max-concurrent-reads-for-list-objects", defaultConfig.MaxConcurrentReadsForListObjects, "the maximum allowed number of concurrent datastore reads in a single ListObjects or StreamedListObjects query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
   188  
   189  	flags.Uint32("max-concurrent-reads-for-check", defaultConfig.MaxConcurrentReadsForCheck, "the maximum allowed number of concurrent datastore reads in a single Check query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
   190  
   191  	flags.Int("changelog-horizon-offset", defaultConfig.ChangelogHorizonOffset, "the offset (in minutes) from the current time. Changes that occur after this offset will not be included in the response of ReadChanges")
   192  
   193  	flags.Uint32("resolve-node-limit", defaultConfig.ResolveNodeLimit, "maximum resolution depth to attempt before throwing an error (defines how deeply nested an authorization model can be before a query errors out).")
   194  
   195  	flags.Uint32("resolve-node-breadth-limit", defaultConfig.ResolveNodeBreadthLimit, "defines how many nodes on a given level can be evaluated concurrently in a Check resolution tree")
   196  
   197  	flags.Duration("listObjects-deadline", defaultConfig.ListObjectsDeadline, "the timeout deadline for serving ListObjects and StreamedListObjects requests")
   198  
   199  	flags.Uint32("listObjects-max-results", defaultConfig.ListObjectsMaxResults, "the maximum results to return in non-streaming ListObjects API responses. If 0, all results can be returned")
   200  
   201  	flags.Duration("listUsers-deadline", defaultConfig.ListUsersDeadline, "the timeout deadline for serving ListUsers requests. If 0, there is no deadline")
   202  
   203  	flags.Uint32("listUsers-max-results", defaultConfig.ListUsersMaxResults, "the maximum results to return in ListUsers API responses. If 0, all results can be returned")
   204  
   205  	flags.Bool("check-query-cache-enabled", defaultConfig.CheckQueryCache.Enabled, "when executing Check and ListObjects requests, enables caching. This will turn Check and ListObjects responses into eventually consistent responses")
   206  
   207  	flags.Uint32("check-query-cache-limit", defaultConfig.CheckQueryCache.Limit, "if caching of Check and ListObjects calls is enabled, this is the size limit of the cache")
   208  
   209  	flags.Duration("check-query-cache-ttl", defaultConfig.CheckQueryCache.TTL, "if caching of Check and ListObjects is enabled, this is the TTL of each value")
   210  
   211  	// Unfortunately UintSlice/IntSlice does not work well when used as environment variable, we need to stick with string slice and convert back to integer
   212  	flags.StringSlice("request-duration-datastore-query-count-buckets", defaultConfig.RequestDurationDatastoreQueryCountBuckets, "datastore query count buckets used in labelling request_duration_ms.")
   213  
   214  	flags.StringSlice("request-duration-dispatch-count-buckets", defaultConfig.RequestDurationDispatchCountBuckets, "dispatch count (i.e number of concurrent traversals to resolve a query) buckets used in labelling request_duration_ms.")
   215  
   216  	flags.Bool("dispatch-throttling-enabled", defaultConfig.DispatchThrottling.Enabled, "enable throttling when request's number of dispatches is high. Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch threshold over requests whose dispatch count exceeds the configured threshold.")
   217  
   218  	flags.Duration("dispatch-throttling-frequency", defaultConfig.DispatchThrottling.Frequency, "defines how frequent dispatch throttling will be evaluated. Frequency controls how frequently throttled dispatch requests are dispatched.")
   219  
   220  	flags.Uint32("dispatch-throttling-threshold", defaultConfig.DispatchThrottling.Threshold, "define the default threshold on number of dispatches above which requests will be throttled.")
   221  
   222  	flags.Uint32("dispatch-throttling-max-threshold", defaultConfig.DispatchThrottling.MaxThreshold, "define the maximum dispatch threshold beyond which requests will be throttled. 0 will use the 'dispatch-throttling-threshold' value as maximum")
   223  
   224  	flags.Duration("request-timeout", defaultConfig.RequestTimeout, "configures request timeout.  If both HTTP upstream timeout and request timeout are specified, request timeout will be used.")
   225  
   226  	// NOTE: if you add a new flag here, update the function below, too
   227  
   228  	cmd.PreRun = bindRunFlagsFunc(flags)
   229  
   230  	return cmd
   231  }
   232  
   233  // ReadConfig returns the OpenFGA server configuration based on the values provided in the server's 'config.yaml' file.
   234  // The 'config.yaml' file is loaded from '/etc/openfga', '$HOME/.openfga', or the current working directory. If no configuration
   235  // file is present, the default values are returned.
   236  func ReadConfig() (*serverconfig.Config, error) {
   237  	config := serverconfig.DefaultConfig()
   238  
   239  	viper.SetTypeByDefaultValue(true)
   240  	err := viper.ReadInConfig()
   241  	if err != nil {
   242  		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
   243  			return nil, fmt.Errorf("failed to load server config: %w", err)
   244  		}
   245  	}
   246  
   247  	if err := viper.Unmarshal(config); err != nil {
   248  		return nil, fmt.Errorf("failed to unmarshal server config: %w", err)
   249  	}
   250  
   251  	return config, nil
   252  }
   253  
   254  func run(_ *cobra.Command, _ []string) {
   255  	config, err := ReadConfig()
   256  	if err != nil {
   257  		panic(err)
   258  	}
   259  
   260  	if err := config.Verify(); err != nil {
   261  		panic(err)
   262  	}
   263  
   264  	logger := logger.MustNewLogger(config.Log.Format, config.Log.Level, config.Log.TimestampFormat)
   265  
   266  	serverCtx := &ServerContext{Logger: logger}
   267  	if err := serverCtx.Run(context.Background(), config); err != nil {
   268  		panic(err)
   269  	}
   270  }
   271  
// ServerContext carries the dependencies shared by the server start-up helpers
// in this package (telemetryConfig, datastoreConfig, authenticatorConfig and Run).
type ServerContext struct {
	// Logger is used by those helpers for start-up log output and is also handed
	// to the SQL datastore options, the panic-recovery middleware and the server.
	Logger logger.Logger
}
   275  
   276  func convertStringArrayToUintArray(stringArray []string) []uint {
   277  	uintArray := []uint{}
   278  	for _, val := range stringArray {
   279  		// note that we have already validated whether the array item is non-negative integer
   280  		valInt, err := strconv.Atoi(val)
   281  		if err == nil {
   282  			uintArray = append(uintArray, uint(valInt))
   283  		}
   284  	}
   285  	return uintArray
   286  }
   287  
   288  func (s *ServerContext) telemetryConfig(ctx context.Context, config *serverconfig.Config) func() {
   289  	var tracerProviderCloser func()
   290  
   291  	if config.Trace.Enabled {
   292  		s.Logger.Info(fmt.Sprintf("๐Ÿ•ต tracing enabled: sampling ratio is %v and sending traces to '%s', tls: %t", config.Trace.SampleRatio, config.Trace.OTLP.Endpoint, config.Trace.OTLP.TLS.Enabled))
   293  
   294  		options := []telemetry.TracerOption{
   295  			telemetry.WithOTLPEndpoint(
   296  				config.Trace.OTLP.Endpoint,
   297  			),
   298  			telemetry.WithAttributes(
   299  				semconv.ServiceNameKey.String(config.Trace.ServiceName),
   300  				semconv.ServiceVersionKey.String(build.Version),
   301  			),
   302  			telemetry.WithSamplingRatio(config.Trace.SampleRatio),
   303  		}
   304  
   305  		if !config.Trace.OTLP.TLS.Enabled {
   306  			options = append(options, telemetry.WithOTLPInsecure())
   307  		}
   308  
   309  		tp := telemetry.MustNewTracerProvider(options...)
   310  		tracerProviderCloser = func() {
   311  			_ = tp.ForceFlush(ctx)
   312  			_ = tp.Shutdown(ctx)
   313  		}
   314  	} else {
   315  		otel.SetTracerProvider(noop.NewTracerProvider())
   316  	}
   317  	return tracerProviderCloser
   318  }
   319  
   320  func (s *ServerContext) datastoreConfig(config *serverconfig.Config) (storage.OpenFGADatastore, error) {
   321  	datastoreOptions := []sqlcommon.DatastoreOption{
   322  		sqlcommon.WithUsername(config.Datastore.Username),
   323  		sqlcommon.WithPassword(config.Datastore.Password),
   324  		sqlcommon.WithLogger(s.Logger),
   325  		sqlcommon.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite),
   326  		sqlcommon.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel),
   327  		sqlcommon.WithMaxOpenConns(config.Datastore.MaxOpenConns),
   328  		sqlcommon.WithMaxIdleConns(config.Datastore.MaxIdleConns),
   329  		sqlcommon.WithConnMaxIdleTime(config.Datastore.ConnMaxIdleTime),
   330  		sqlcommon.WithConnMaxLifetime(config.Datastore.ConnMaxLifetime),
   331  	}
   332  
   333  	if config.Datastore.Metrics.Enabled {
   334  		datastoreOptions = append(datastoreOptions, sqlcommon.WithMetrics())
   335  	}
   336  
   337  	dsCfg := sqlcommon.NewConfig(datastoreOptions...)
   338  
   339  	var datastore storage.OpenFGADatastore
   340  	var err error
   341  	switch config.Datastore.Engine {
   342  	case "memory":
   343  		opts := []memory.StorageOption{
   344  			memory.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel),
   345  			memory.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite),
   346  		}
   347  		datastore = memory.New(opts...)
   348  	case "mysql":
   349  		datastore, err = mysql.New(config.Datastore.URI, dsCfg)
   350  		if err != nil {
   351  			return nil, fmt.Errorf("initialize mysql datastore: %w", err)
   352  		}
   353  	case "postgres":
   354  		datastore, err = postgres.New(config.Datastore.URI, dsCfg)
   355  		if err != nil {
   356  			return nil, fmt.Errorf("initialize postgres datastore: %w", err)
   357  		}
   358  	default:
   359  		return nil, fmt.Errorf("storage engine '%s' is unsupported", config.Datastore.Engine)
   360  	}
   361  	datastore = storagewrappers.NewCachedOpenFGADatastore(storagewrappers.NewContextWrapper(datastore), config.Datastore.MaxCacheSize)
   362  
   363  	s.Logger.Info(fmt.Sprintf("using '%v' storage engine", config.Datastore.Engine))
   364  	return datastore, nil
   365  }
   366  
   367  func (s *ServerContext) authenticatorConfig(config *serverconfig.Config) (authn.Authenticator, error) {
   368  	var authenticator authn.Authenticator
   369  	var err error
   370  
   371  	switch config.Authn.Method {
   372  	case "none":
   373  		s.Logger.Warn("authentication is disabled")
   374  		authenticator = authn.NoopAuthenticator{}
   375  	case "preshared":
   376  		s.Logger.Info("using 'preshared' authentication")
   377  		authenticator, err = presharedkey.NewPresharedKeyAuthenticator(config.Authn.Keys)
   378  	case "oidc":
   379  		s.Logger.Info("using 'oidc' authentication")
   380  		authenticator, err = oidc.NewRemoteOidcAuthenticator(config.Authn.Issuer, config.Authn.IssuerAliases, config.Authn.Audience)
   381  	default:
   382  		return nil, fmt.Errorf("unsupported authentication method '%v'", config.Authn.Method)
   383  	}
   384  	if err != nil {
   385  		return nil, fmt.Errorf("failed to initialize authenticator: %w", err)
   386  	}
   387  	return authenticator, nil
   388  }
   389  
   390  // Run returns an error if the server was unable to start successfully.
   391  // If it started and terminated successfully, it returns a nil error.
   392  func (s *ServerContext) Run(ctx context.Context, config *serverconfig.Config) error {
   393  	tracerProviderCloser := s.telemetryConfig(ctx, config)
   394  
   395  	s.Logger.Info(fmt.Sprintf("๐Ÿงช experimental features enabled: %v", config.Experimentals))
   396  
   397  	var experimentals []server.ExperimentalFeatureFlag
   398  	for _, feature := range config.Experimentals {
   399  		experimentals = append(experimentals, server.ExperimentalFeatureFlag(feature))
   400  	}
   401  
   402  	datastore, err := s.datastoreConfig(config)
   403  	if err != nil {
   404  		return err
   405  	}
   406  
   407  	authenticator, err := s.authenticatorConfig(config)
   408  
   409  	if err != nil {
   410  		return err
   411  	}
   412  
   413  	serverOpts := []grpc.ServerOption{
   414  		grpc.MaxRecvMsgSize(serverconfig.DefaultMaxRPCMessageSizeInBytes),
   415  		grpc.ChainUnaryInterceptor(
   416  			[]grpc.UnaryServerInterceptor{
   417  				grpc_recovery.UnaryServerInterceptor( // panic middleware must be 1st in chain
   418  					grpc_recovery.WithRecoveryHandlerContext(
   419  						recovery.PanicRecoveryHandler(s.Logger),
   420  					),
   421  				),
   422  				grpc_ctxtags.UnaryServerInterceptor(), // needed for logging
   423  				requestid.NewUnaryInterceptor(),       // add request_id to ctxtags
   424  			}...,
   425  		),
   426  		grpc.ChainStreamInterceptor(
   427  			[]grpc.StreamServerInterceptor{
   428  				grpc_recovery.StreamServerInterceptor(
   429  					grpc_recovery.WithRecoveryHandlerContext(
   430  						recovery.PanicRecoveryHandler(s.Logger),
   431  					),
   432  				),
   433  				requestid.NewStreamingInterceptor(),
   434  			}...,
   435  		),
   436  	}
   437  
   438  	if config.RequestTimeout > 0 {
   439  		timeoutMiddleware := middleware.NewTimeoutInterceptor(config.RequestTimeout, s.Logger)
   440  
   441  		serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(timeoutMiddleware.NewUnaryTimeoutInterceptor()))
   442  		serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(timeoutMiddleware.NewStreamTimeoutInterceptor()))
   443  	}
   444  
   445  	serverOpts = append(serverOpts,
   446  		grpc.ChainUnaryInterceptor(
   447  			[]grpc.UnaryServerInterceptor{
   448  				storeid.NewUnaryInterceptor(),           // if available, add store_id to ctxtags
   449  				logging.NewLoggingInterceptor(s.Logger), // needed to log invalid requests
   450  				validator.UnaryServerInterceptor(),
   451  			}...,
   452  		),
   453  		grpc.ChainStreamInterceptor(
   454  			[]grpc.StreamServerInterceptor{
   455  				validator.StreamServerInterceptor(),
   456  				grpc_ctxtags.StreamServerInterceptor(),
   457  			}...,
   458  		),
   459  	)
   460  
   461  	if config.Metrics.Enabled {
   462  		serverOpts = append(serverOpts,
   463  			grpc.ChainUnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
   464  			grpc.ChainStreamInterceptor(grpc_prometheus.StreamServerInterceptor))
   465  
   466  		if config.Metrics.EnableRPCHistograms {
   467  			grpc_prometheus.EnableHandlingTimeHistogram()
   468  		}
   469  	}
   470  
   471  	if config.Trace.Enabled {
   472  		serverOpts = append(serverOpts, grpc.StatsHandler(otelgrpc.NewServerHandler()))
   473  	}
   474  
   475  	serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(
   476  		[]grpc.UnaryServerInterceptor{
   477  			grpcauth.UnaryServerInterceptor(authnmw.AuthFunc(authenticator)),
   478  		}...),
   479  		grpc.ChainStreamInterceptor(
   480  			[]grpc.StreamServerInterceptor{
   481  				grpcauth.StreamServerInterceptor(authnmw.AuthFunc(authenticator)),
   482  				// The following interceptors wrap the server stream with our own
   483  				// wrapper and must come last.
   484  				storeid.NewStreamingInterceptor(),
   485  				logging.NewStreamingLoggingInterceptor(s.Logger),
   486  			}...,
   487  		),
   488  	)
   489  
   490  	if config.GRPC.TLS.Enabled {
   491  		if config.GRPC.TLS.CertPath == "" || config.GRPC.TLS.KeyPath == "" {
   492  			return errors.New("'grpc.tls.cert' and 'grpc.tls.key' configs must be set")
   493  		}
   494  		creds, err := credentials.NewServerTLSFromFile(config.GRPC.TLS.CertPath, config.GRPC.TLS.KeyPath)
   495  		if err != nil {
   496  			return err
   497  		}
   498  
   499  		serverOpts = append(serverOpts, grpc.Creds(creds))
   500  
   501  		s.Logger.Info("grpc TLS is enabled, serving connections using the provided certificate")
   502  	} else {
   503  		s.Logger.Warn("grpc TLS is disabled, serving connections using insecure plaintext")
   504  	}
   505  
   506  	if config.Profiler.Enabled {
   507  		mux := http.NewServeMux()
   508  		mux.HandleFunc("/debug/pprof/", pprof.Index)
   509  		mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
   510  		mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
   511  		mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
   512  		mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
   513  
   514  		go func() {
   515  			s.Logger.Info(fmt.Sprintf("๐Ÿ”ฌ starting pprof profiler on '%s'", config.Profiler.Addr))
   516  
   517  			if err := http.ListenAndServe(config.Profiler.Addr, mux); err != nil {
   518  				if err != http.ErrServerClosed {
   519  					s.Logger.Fatal("failed to start pprof profiler", zap.Error(err))
   520  				}
   521  			}
   522  		}()
   523  	}
   524  
   525  	if config.Metrics.Enabled {
   526  		s.Logger.Info(fmt.Sprintf("๐Ÿ“ˆ starting metrics server on '%s'", config.Metrics.Addr))
   527  
   528  		go func() {
   529  			mux := http.NewServeMux()
   530  			mux.Handle("/metrics", promhttp.Handler())
   531  			if err := http.ListenAndServe(config.Metrics.Addr, mux); err != nil {
   532  				if err != http.ErrServerClosed {
   533  					s.Logger.Fatal("failed to start prometheus metrics server", zap.Error(err))
   534  				}
   535  			}
   536  		}()
   537  	}
   538  
   539  	svr := server.MustNewServerWithOpts(
   540  		server.WithDatastore(datastore),
   541  		server.WithLogger(s.Logger),
   542  		server.WithTransport(gateway.NewRPCTransport(s.Logger)),
   543  		server.WithResolveNodeLimit(config.ResolveNodeLimit),
   544  		server.WithResolveNodeBreadthLimit(config.ResolveNodeBreadthLimit),
   545  		server.WithChangelogHorizonOffset(config.ChangelogHorizonOffset),
   546  		server.WithListObjectsDeadline(config.ListObjectsDeadline),
   547  		server.WithListObjectsMaxResults(config.ListObjectsMaxResults),
   548  		server.WithListUsersDeadline(config.ListUsersDeadline),
   549  		server.WithListUsersMaxResults(config.ListUsersMaxResults),
   550  		server.WithMaxConcurrentReadsForListObjects(config.MaxConcurrentReadsForListObjects),
   551  		server.WithMaxConcurrentReadsForCheck(config.MaxConcurrentReadsForCheck),
   552  		server.WithMaxConcurrentReadsForListUsers(config.MaxConcurrentReadsForListUsers),
   553  		server.WithCheckQueryCacheEnabled(config.CheckQueryCache.Enabled),
   554  		server.WithCheckQueryCacheLimit(config.CheckQueryCache.Limit),
   555  		server.WithCheckQueryCacheTTL(config.CheckQueryCache.TTL),
   556  		server.WithRequestDurationByQueryHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDatastoreQueryCountBuckets)),
   557  		server.WithRequestDurationByDispatchCountHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDispatchCountBuckets)),
   558  		server.WithMaxAuthorizationModelSizeInBytes(config.MaxAuthorizationModelSizeInBytes),
   559  		server.WithDispatchThrottlingCheckResolverEnabled(config.DispatchThrottling.Enabled),
   560  		server.WithDispatchThrottlingCheckResolverFrequency(config.DispatchThrottling.Frequency),
   561  		server.WithDispatchThrottlingCheckResolverThreshold(config.DispatchThrottling.Threshold),
   562  		server.WithDispatchThrottlingCheckResolverMaxThreshold(config.DispatchThrottling.MaxThreshold),
   563  		server.WithExperimentals(experimentals...),
   564  	)
   565  
   566  	s.Logger.Info(
   567  		"๐Ÿš€ starting openfga service...",
   568  		zap.String("version", build.Version),
   569  		zap.String("date", build.Date),
   570  		zap.String("commit", build.Commit),
   571  		zap.String("go-version", goruntime.Version()),
   572  	)
   573  
   574  	// nosemgrep: grpc-server-insecure-connection
   575  	grpcServer := grpc.NewServer(serverOpts...)
   576  	openfgav1.RegisterOpenFGAServiceServer(grpcServer, svr)
   577  	healthServer := &health.Checker{TargetService: svr, TargetServiceName: openfgav1.OpenFGAService_ServiceDesc.ServiceName}
   578  	healthv1pb.RegisterHealthServer(grpcServer, healthServer)
   579  	reflection.Register(grpcServer)
   580  
   581  	lis, err := net.Listen("tcp", config.GRPC.Addr)
   582  	if err != nil {
   583  		return fmt.Errorf("failed to listen: %w", err)
   584  	}
   585  
   586  	go func() {
   587  		if err := grpcServer.Serve(lis); err != nil {
   588  			if !errors.Is(err, grpc.ErrServerStopped) {
   589  				s.Logger.Fatal("failed to start grpc server", zap.Error(err))
   590  			}
   591  
   592  			s.Logger.Info("grpc server shut down..")
   593  		}
   594  	}()
   595  	s.Logger.Info(fmt.Sprintf("grpc server listening on '%s'...", config.GRPC.Addr))
   596  
   597  	var httpServer *http.Server
   598  	if config.HTTP.Enabled {
   599  		runtime.DefaultContextTimeout = serverconfig.DefaultContextTimeout(config)
   600  
   601  		dialOpts := []grpc.DialOption{
   602  			grpc.WithBlock(),
   603  		}
   604  		if config.GRPC.TLS.Enabled {
   605  			creds, err := credentials.NewClientTLSFromFile(config.GRPC.TLS.CertPath, "")
   606  			if err != nil {
   607  				s.Logger.Fatal("", zap.Error(err))
   608  			}
   609  			dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds))
   610  		} else {
   611  			dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
   612  		}
   613  
   614  		timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
   615  		defer cancel()
   616  
   617  		conn, err := grpc.DialContext(timeoutCtx, config.GRPC.Addr, dialOpts...)
   618  		if err != nil {
   619  			s.Logger.Fatal("", zap.Error(err))
   620  		}
   621  		defer conn.Close()
   622  
   623  		muxOpts := []runtime.ServeMuxOption{
   624  			runtime.WithForwardResponseOption(httpmiddleware.HTTPResponseModifier),
   625  			runtime.WithErrorHandler(func(c context.Context, sr *runtime.ServeMux, mm runtime.Marshaler, w http.ResponseWriter, r *http.Request, e error) {
   626  				intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e))
   627  				httpmiddleware.CustomHTTPErrorHandler(c, w, r, serverErrors.NewEncodedError(intCode, e.Error()))
   628  			}),
   629  			runtime.WithStreamErrorHandler(func(ctx context.Context, e error) *status.Status {
   630  				intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e))
   631  				encodedErr := serverErrors.NewEncodedError(intCode, e.Error())
   632  				return status.Convert(encodedErr)
   633  			}),
   634  			runtime.WithHealthzEndpoint(healthv1pb.NewHealthClient(conn)),
   635  			runtime.WithOutgoingHeaderMatcher(func(s string) (string, bool) { return s, true }),
   636  		}
   637  		mux := runtime.NewServeMux(muxOpts...)
   638  		if err := openfgav1.RegisterOpenFGAServiceHandler(ctx, mux, conn); err != nil {
   639  			return err
   640  		}
   641  
   642  		httpServer = &http.Server{
   643  			Addr: config.HTTP.Addr,
   644  			Handler: recovery.HTTPPanicRecoveryHandler(cors.New(cors.Options{
   645  				AllowedOrigins:   config.HTTP.CORSAllowedOrigins,
   646  				AllowCredentials: true,
   647  				AllowedHeaders:   config.HTTP.CORSAllowedHeaders,
   648  				AllowedMethods: []string{http.MethodGet, http.MethodPost,
   649  					http.MethodHead, http.MethodPatch, http.MethodDelete, http.MethodPut},
   650  			}).Handler(mux), s.Logger),
   651  		}
   652  
   653  		go func() {
   654  			var err error
   655  			if config.HTTP.TLS.Enabled {
   656  				if config.HTTP.TLS.CertPath == "" || config.HTTP.TLS.KeyPath == "" {
   657  					s.Logger.Fatal("'http.tls.cert' and 'http.tls.key' configs must be set")
   658  				}
   659  				err = httpServer.ListenAndServeTLS(config.HTTP.TLS.CertPath, config.HTTP.TLS.KeyPath)
   660  			} else {
   661  				err = httpServer.ListenAndServe()
   662  			}
   663  			if err != http.ErrServerClosed {
   664  				s.Logger.Fatal("HTTP server closed with unexpected error", zap.Error(err))
   665  			}
   666  		}()
   667  		s.Logger.Info(fmt.Sprintf("HTTP server listening on '%s'...", httpServer.Addr))
   668  	}
   669  
   670  	var playground *http.Server
   671  	if config.Playground.Enabled {
   672  		if !config.HTTP.Enabled {
   673  			return errors.New("the HTTP server must be enabled to run the openfga playground")
   674  		}
   675  
   676  		authMethod := config.Authn.Method
   677  		if !(authMethod == "none" || authMethod == "preshared") {
   678  			return errors.New("the playground only supports authn methods 'none' and 'preshared'")
   679  		}
   680  
   681  		playgroundAddr := fmt.Sprintf(":%d", config.Playground.Port)
   682  		s.Logger.Info(fmt.Sprintf("๐Ÿ› starting openfga playground on http://localhost%s/playground", playgroundAddr))
   683  
   684  		tmpl, err := template.ParseFS(assets.EmbedPlayground, "playground/index.html")
   685  		if err != nil {
   686  			return fmt.Errorf("failed to parse playground index.html as Go template: %w", err)
   687  		}
   688  
   689  		fileServer := http.FileServer(http.FS(assets.EmbedPlayground))
   690  
   691  		policy := backoff.NewExponentialBackOff()
   692  		policy.MaxElapsedTime = 3 * time.Second
   693  
   694  		var conn net.Conn
   695  		err = backoff.Retry(
   696  			func() error {
   697  				conn, err = net.Dial("tcp", config.HTTP.Addr)
   698  				return err
   699  			},
   700  			policy,
   701  		)
   702  		if err != nil {
   703  			return fmt.Errorf("failed to establish playground connection to HTTP server: %w", err)
   704  		}
   705  
   706  		playgroundAPIToken := ""
   707  		if authMethod == "preshared" {
   708  			playgroundAPIToken = config.Authn.AuthnPresharedKeyConfig.Keys[0]
   709  		}
   710  
   711  		mux := http.NewServeMux()
   712  		mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   713  			if strings.HasPrefix(r.URL.Path, "/playground") {
   714  				if r.URL.Path == "/playground" || r.URL.Path == "/playground/index.html" {
   715  					err = tmpl.Execute(w, struct {
   716  						HTTPServerURL      string
   717  						PlaygroundAPIToken string
   718  					}{
   719  						HTTPServerURL:      conn.RemoteAddr().String(),
   720  						PlaygroundAPIToken: playgroundAPIToken,
   721  					})
   722  					if err != nil {
   723  						w.WriteHeader(http.StatusInternalServerError)
   724  						s.Logger.Error("failed to execute/render the playground web template", zap.Error(err))
   725  					}
   726  
   727  					return
   728  				}
   729  
   730  				fileServer.ServeHTTP(w, r)
   731  				return
   732  			}
   733  
   734  			http.NotFound(w, r)
   735  		}))
   736  
   737  		playground = &http.Server{Addr: playgroundAddr, Handler: mux}
   738  
   739  		go func() {
   740  			err = playground.ListenAndServe()
   741  			if err != http.ErrServerClosed {
   742  				s.Logger.Fatal("failed to start the openfga playground server", zap.Error(err))
   743  			}
   744  			s.Logger.Info("shutdown the openfga playground server")
   745  		}()
   746  	}
   747  
   748  	done := make(chan os.Signal, 1)
   749  	signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
   750  
   751  	select {
   752  	case <-done:
   753  	case <-ctx.Done():
   754  	}
   755  	s.Logger.Info("attempting to shutdown gracefully")
   756  
   757  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   758  	defer cancel()
   759  
   760  	if playground != nil {
   761  		if err := playground.Shutdown(ctx); err != nil {
   762  			s.Logger.Info("failed to gracefully shutdown playground server", zap.Error(err))
   763  		}
   764  	}
   765  
   766  	if httpServer != nil {
   767  		if err := httpServer.Shutdown(ctx); err != nil {
   768  			s.Logger.Info("failed to shutdown the http server", zap.Error(err))
   769  		}
   770  	}
   771  
   772  	grpcServer.GracefulStop()
   773  
   774  	svr.Close()
   775  
   776  	authenticator.Close()
   777  
   778  	datastore.Close()
   779  
   780  	if tracerProviderCloser != nil {
   781  		tracerProviderCloser()
   782  	}
   783  
   784  	s.Logger.Info("server exited. goodbye ๐Ÿ‘‹")
   785  
   786  	return nil
   787  }