github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/util/util.go (about)

     1  package util
     2  
     3  //go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . GRPCServerConfig HTTPServerConfig
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"time"
    14  
    15  	"github.com/jzelinskie/stringz"
    16  	"github.com/rs/zerolog"
    17  	"github.com/spf13/cobra"
    18  	"github.com/spf13/pflag"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/credentials"
    21  	"google.golang.org/grpc/credentials/insecure"
    22  	"google.golang.org/grpc/keepalive"
    23  	"google.golang.org/grpc/test/bufconn"
    24  
    25  	// Register Snappy S2 compression
    26  	_ "github.com/mostynb/go-grpc-compression/experimental/s2"
    27  
    28  	"sigs.k8s.io/controller-runtime/pkg/certwatcher"
    29  	// Register cert watcher metrics
    30  	_ "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
    31  
    32  	log "github.com/authzed/spicedb/internal/logging"
    33  	"github.com/authzed/spicedb/pkg/x509util"
    34  )
    35  
    36  const BufferedNetwork string = "buffnet"
    37  
// GRPCServerConfig describes a single gRPC server: where it listens, how it is
// secured, and how its connections are managed. RegisterGRPCServerFlags binds
// most fields to command-line flags; Complete turns the config into a
// RunnableGRPCServer.
type GRPCServerConfig struct {
	// Address is the address to listen on; for non-buffered networks it is
	// also the target used by the in-process gRPC dialer.
	Address      string        `debugmap:"visible"`
	// Network is passed to net.Listen, or is BufferedNetwork to use an
	// in-memory bufconn listener instead.
	Network      string        `debugmap:"visible"`
	// TLS is enabled only when both TLSCertPath and TLSKeyPath are set.
	TLSCertPath  string        `debugmap:"visible"`
	TLSKeyPath   string        `debugmap:"visible"`
	// MaxConnAge is used as the keepalive MaxConnectionAge for the server.
	MaxConnAge   time.Duration `debugmap:"visible"`
	// Enabled selects between a real server and a no-op one in Complete.
	Enabled      bool          `debugmap:"visible"`
	// BufferSize is the bufconn buffer size used with BufferedNetwork.
	// Complete defaults it to 1 MiB when zero. It is not registered as a flag
	// by RegisterGRPCServerFlags.
	BufferSize   int           `debugmap:"visible"`
	// ClientCAPath, when set and TLS is enabled, names the CA file used as the
	// root CAs for in-process client connections (DialContext). When empty,
	// the system cert pool is used. It is not registered as a flag by
	// RegisterGRPCServerFlags.
	ClientCAPath string        `debugmap:"visible"`
	// MaxWorkers is passed to grpc.NumStreamWorkers.
	MaxWorkers   uint32        `debugmap:"visible"`

	// flagPrefix is the flag prefix recorded by RegisterGRPCServerFlags; it
	// is used to tag log messages.
	flagPrefix string
}
    51  
    52  // RegisterGRPCServerFlags adds the following flags for use with
    53  // GrpcServerFromFlags:
    54  // - "$PREFIX-addr"
    55  // - "$PREFIX-tls-cert-path"
    56  // - "$PREFIX-tls-key-path"
    57  // - "$PREFIX-max-conn-age"
    58  func RegisterGRPCServerFlags(flags *pflag.FlagSet, config *GRPCServerConfig, flagPrefix, serviceName, defaultAddr string, defaultEnabled bool) {
    59  	flagPrefix = stringz.DefaultEmpty(flagPrefix, "grpc")
    60  	serviceName = stringz.DefaultEmpty(serviceName, "grpc")
    61  	defaultAddr = stringz.DefaultEmpty(defaultAddr, ":50051")
    62  	config.flagPrefix = flagPrefix
    63  
    64  	flags.StringVar(&config.Address, flagPrefix+"-addr", defaultAddr, "address to listen on to serve "+serviceName)
    65  	flags.StringVar(&config.Network, flagPrefix+"-network", "tcp", "network type to serve "+serviceName+` ("tcp", "tcp4", "tcp6", "unix", "unixpacket")`)
    66  	flags.StringVar(&config.TLSCertPath, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName)
    67  	flags.StringVar(&config.TLSKeyPath, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName)
    68  	flags.DurationVar(&config.MaxConnAge, flagPrefix+"-max-conn-age", 30*time.Second, "how long a connection serving "+serviceName+" should be able to live")
    69  	flags.BoolVar(&config.Enabled, flagPrefix+"-enabled", defaultEnabled, "enable "+serviceName+" gRPC server")
    70  	flags.Uint32Var(&config.MaxWorkers, flagPrefix+"-max-workers", 0, "set the number of workers for this server (0 value means 1 worker per request)")
    71  }
    72  
    73  type (
    74  	DialFunc    func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error)
    75  	NetDialFunc func(ctx context.Context, s string) (net.Conn, error)
    76  )
    77  
    78  // Complete takes a set of default options and returns a completed server
    79  func (c *GRPCServerConfig) Complete(level zerolog.Level, svcRegistrationFn func(server *grpc.Server), opts ...grpc.ServerOption) (RunnableGRPCServer, error) {
    80  	if !c.Enabled {
    81  		return &disabledGrpcServer{}, nil
    82  	}
    83  	if c.BufferSize == 0 {
    84  		c.BufferSize = 1024 * 1024
    85  	}
    86  	opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{
    87  		MaxConnectionAge: c.MaxConnAge,
    88  	}), grpc.NumStreamWorkers(c.MaxWorkers))
    89  
    90  	tlsOpts, certWatcher, err := c.tlsOpts()
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	opts = append(opts, tlsOpts...)
    95  
    96  	clientCreds, err := c.clientCreds()
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	l, dial, netDial, err := c.listenerAndDialer()
   102  	if err != nil {
   103  		return nil, fmt.Errorf("failed to listen on addr for gRPC server: %w", err)
   104  	}
   105  	log.WithLevel(level).
   106  		Str("addr", c.Address).
   107  		Str("network", c.Network).
   108  		Str("service", c.flagPrefix).
   109  		Uint32("workers", c.MaxWorkers).
   110  		Bool("insecure", c.TLSCertPath == "" && c.TLSKeyPath == "").
   111  		Msg("grpc server started serving")
   112  
   113  	srv := grpc.NewServer(opts...)
   114  	svcRegistrationFn(srv)
   115  	return &completedGRPCServer{
   116  		opts:              opts,
   117  		listener:          l,
   118  		svcRegistrationFn: svcRegistrationFn,
   119  		listenFunc: func() error {
   120  			return srv.Serve(l)
   121  		},
   122  		dial:    dial,
   123  		netDial: netDial,
   124  		prestopFunc: func() {
   125  			log.WithLevel(level).
   126  				Str("addr", c.Address).
   127  				Str("network", c.Network).
   128  				Str("service", c.flagPrefix).
   129  				Msg("grpc server stopped serving")
   130  		},
   131  		stopFunc:    srv.GracefulStop,
   132  		creds:       clientCreds,
   133  		certWatcher: certWatcher,
   134  	}, nil
   135  }
   136  
   137  func (c *GRPCServerConfig) listenerAndDialer() (net.Listener, DialFunc, NetDialFunc, error) {
   138  	if c.Network == BufferedNetwork {
   139  		bl := bufconn.Listen(c.BufferSize)
   140  		return bl, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
   141  				opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
   142  					return bl.DialContext(ctx)
   143  				}))
   144  
   145  				return grpc.DialContext(ctx, BufferedNetwork, opts...)
   146  			}, func(ctx context.Context, s string) (net.Conn, error) {
   147  				return bl.DialContext(ctx)
   148  			}, nil
   149  	}
   150  	l, err := net.Listen(c.Network, c.Address)
   151  	if err != nil {
   152  		return nil, nil, nil, err
   153  	}
   154  	return l, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
   155  		return grpc.DialContext(ctx, c.Address, opts...)
   156  	}, nil, nil
   157  }
   158  
   159  func (c *GRPCServerConfig) tlsOpts() ([]grpc.ServerOption, *certwatcher.CertWatcher, error) {
   160  	switch {
   161  	case c.TLSCertPath == "" && c.TLSKeyPath == "":
   162  		return nil, nil, nil
   163  	case c.TLSCertPath != "" && c.TLSKeyPath != "":
   164  		watcher, err := certwatcher.New(c.TLSCertPath, c.TLSKeyPath)
   165  		if err != nil {
   166  			return nil, nil, err
   167  		}
   168  		creds := credentials.NewTLS(&tls.Config{
   169  			GetCertificate: watcher.GetCertificate,
   170  			MinVersion:     tls.VersionTLS12,
   171  		})
   172  		return []grpc.ServerOption{grpc.Creds(creds)}, watcher, nil
   173  	default:
   174  		return nil, nil, nil
   175  	}
   176  }
   177  
   178  func (c *GRPCServerConfig) clientCreds() (credentials.TransportCredentials, error) {
   179  	switch {
   180  	case c.TLSCertPath == "" && c.TLSKeyPath == "":
   181  		return insecure.NewCredentials(), nil
   182  	case c.TLSCertPath != "" && c.TLSKeyPath != "":
   183  		var err error
   184  		var pool *x509.CertPool
   185  		if c.ClientCAPath != "" {
   186  			pool, err = x509util.CustomCertPool(c.ClientCAPath)
   187  		} else {
   188  			pool, err = x509.SystemCertPool()
   189  		}
   190  		if err != nil {
   191  			return nil, err
   192  		}
   193  
   194  		return credentials.NewTLS(&tls.Config{RootCAs: pool, MinVersion: tls.VersionTLS12}), nil
   195  	default:
   196  		return nil, nil
   197  	}
   198  }
   199  
   200  type RunnableGRPCServer interface {
   201  	WithOpts(opts ...grpc.ServerOption) RunnableGRPCServer
   202  	Listen(ctx context.Context) func() error
   203  	DialContext(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error)
   204  	NetDialContext(ctx context.Context, s string) (net.Conn, error)
   205  	Insecure() bool
   206  	GracefulStop()
   207  }
   208  
   209  type completedGRPCServer struct {
   210  	opts              []grpc.ServerOption
   211  	listener          net.Listener
   212  	svcRegistrationFn func(*grpc.Server)
   213  	listenFunc        func() error
   214  	prestopFunc       func()
   215  	stopFunc          func()
   216  	dial              func(context.Context, ...grpc.DialOption) (*grpc.ClientConn, error)
   217  	netDial           func(ctx context.Context, s string) (net.Conn, error)
   218  	creds             credentials.TransportCredentials
   219  	certWatcher       *certwatcher.CertWatcher
   220  }
   221  
   222  // WithOpts adds to the options for running the server
   223  func (c *completedGRPCServer) WithOpts(opts ...grpc.ServerOption) RunnableGRPCServer {
   224  	c.opts = append(c.opts, opts...)
   225  	srv := grpc.NewServer(c.opts...)
   226  	c.svcRegistrationFn(srv)
   227  	c.listenFunc = func() error {
   228  		return srv.Serve(c.listener)
   229  	}
   230  	c.stopFunc = srv.GracefulStop
   231  	return c
   232  }
   233  
   234  // Listen runs a configured server
   235  func (c *completedGRPCServer) Listen(ctx context.Context) func() error {
   236  	if c.certWatcher != nil {
   237  		go func() {
   238  			if err := c.certWatcher.Start(ctx); err != nil {
   239  				log.Ctx(ctx).Error().Err(err).Msg("error watching tls certs")
   240  			}
   241  		}()
   242  	}
   243  	return c.listenFunc
   244  }
   245  
   246  // DialContext starts a connection to grpc server
   247  func (c *completedGRPCServer) DialContext(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
   248  	opts = append(opts, grpc.WithTransportCredentials(c.creds))
   249  	return c.dial(ctx, opts...)
   250  }
   251  
   252  // NetDialContext returns a low level net.Conn connection to the server
   253  func (c *completedGRPCServer) NetDialContext(ctx context.Context, s string) (net.Conn, error) {
   254  	return c.netDial(ctx, s)
   255  }
   256  
   257  // Insecure returns true if the server is configured without TLS enabled
   258  func (c *completedGRPCServer) Insecure() bool {
   259  	return c.creds.Info().SecurityProtocol == "insecure"
   260  }
   261  
   262  // GracefulStop stops a running server
   263  func (c *completedGRPCServer) GracefulStop() {
   264  	c.prestopFunc()
   265  	c.stopFunc()
   266  }
   267  
   268  type disabledGrpcServer struct{}
   269  
   270  // WithOpts adds to the options for running the server
   271  func (d *disabledGrpcServer) WithOpts(_ ...grpc.ServerOption) RunnableGRPCServer {
   272  	return d
   273  }
   274  
   275  // Listen runs a configured server
   276  func (d *disabledGrpcServer) Listen(_ context.Context) func() error {
   277  	return func() error {
   278  		return nil
   279  	}
   280  }
   281  
   282  // Insecure returns true if the server is configured without TLS enabled
   283  func (d *disabledGrpcServer) Insecure() bool {
   284  	return true
   285  }
   286  
   287  // DialContext starts a connection to grpc server
   288  func (d *disabledGrpcServer) DialContext(_ context.Context, _ ...grpc.DialOption) (*grpc.ClientConn, error) {
   289  	return nil, nil
   290  }
   291  
   292  // NetDialContext starts a connection to grpc server
   293  func (d *disabledGrpcServer) NetDialContext(_ context.Context, _ string) (net.Conn, error) {
   294  	return nil, nil
   295  }
   296  
   297  // GracefulStop stops a running server
   298  func (d *disabledGrpcServer) GracefulStop() {}
   299  
// HTTPServerConfig describes a single HTTP server. RegisterHTTPServerFlags
// binds its fields to command-line flags; Complete turns it into a
// RunnableHTTPServer.
type HTTPServerConfig struct {
	// HTTPAddress is the address the server listens on.
	HTTPAddress     string `debugmap:"visible"`
	// HTTPTLSCertPath and HTTPTLSKeyPath must both be set (TLS) or both be
	// empty (plaintext); Complete rejects any other combination.
	HTTPTLSCertPath string `debugmap:"visible"`
	HTTPTLSKeyPath  string `debugmap:"visible"`
	// HTTPEnabled selects between a real server and a no-op one in Complete.
	HTTPEnabled     bool   `debugmap:"visible"`

	// flagPrefix is the flag prefix recorded by RegisterHTTPServerFlags; it is
	// used in Complete's log and error messages.
	flagPrefix string
}
   308  
   309  func (c *HTTPServerConfig) Complete(level zerolog.Level, handler http.Handler) (RunnableHTTPServer, error) {
   310  	if !c.HTTPEnabled {
   311  		return &disabledHTTPServer{}, nil
   312  	}
   313  	srv := &http.Server{
   314  		Addr:              c.HTTPAddress,
   315  		Handler:           handler,
   316  		ReadHeaderTimeout: 5 * time.Second,
   317  	}
   318  	var serveFunc func() error
   319  	switch {
   320  	case c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == "":
   321  		serveFunc = func() error {
   322  			log.WithLevel(level).
   323  				Str("addr", srv.Addr).
   324  				Str("service", c.flagPrefix).
   325  				Bool("insecure", c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == "").
   326  				Msg("http server started serving")
   327  			return srv.ListenAndServe()
   328  		}
   329  
   330  	case c.HTTPTLSCertPath != "" && c.HTTPTLSKeyPath != "":
   331  		watcher, err := certwatcher.New(c.HTTPTLSCertPath, c.HTTPTLSKeyPath)
   332  		if err != nil {
   333  			return nil, err
   334  		}
   335  
   336  		listener, err := tls.Listen("tcp", srv.Addr, &tls.Config{
   337  			GetCertificate: watcher.GetCertificate,
   338  			MinVersion:     tls.VersionTLS12,
   339  		})
   340  		if err != nil {
   341  			return nil, err
   342  		}
   343  		serveFunc = func() error {
   344  			log.WithLevel(level).
   345  				Str("addr", srv.Addr).
   346  				Str("prefix", c.flagPrefix).
   347  				Bool("insecure", c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == "").
   348  				Msg("http server started serving")
   349  			return srv.Serve(listener)
   350  		}
   351  	default:
   352  		return nil, fmt.Errorf("failed to start http server: must provide both --%s-tls-cert-path and --%s-tls-key-path",
   353  			c.flagPrefix,
   354  			c.flagPrefix,
   355  		)
   356  	}
   357  
   358  	return &completedHTTPServer{
   359  		srvFunc: func() error {
   360  			if err := serveFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) {
   361  				return fmt.Errorf("failed while serving http: %w", err)
   362  			}
   363  			return nil
   364  		},
   365  		closeFunc: func() {
   366  			if err := srv.Close(); err != nil {
   367  				log.Error().Str("addr", srv.Addr).Str("service", c.flagPrefix).Err(err).Msg("error stopping http server")
   368  			}
   369  			log.WithLevel(level).Str("addr", srv.Addr).Str("service", c.flagPrefix).Msg("http server stopped serving")
   370  		},
   371  		enabled: c.HTTPEnabled,
   372  	}, nil
   373  }
   374  
   375  type RunnableHTTPServer interface {
   376  	ListenAndServe() error
   377  	Close()
   378  }
   379  
   380  type completedHTTPServer struct {
   381  	srvFunc   func() error
   382  	closeFunc func()
   383  	enabled   bool
   384  }
   385  
   386  func (c *completedHTTPServer) ListenAndServe() error {
   387  	if !c.enabled {
   388  		return nil
   389  	}
   390  	return c.srvFunc()
   391  }
   392  
   393  func (c *completedHTTPServer) Close() {
   394  	c.closeFunc()
   395  }
   396  
   397  // RegisterHTTPServerFlags adds the following flags for use with
   398  // HttpServerFromFlags:
   399  // - "$PREFIX-addr"
   400  // - "$PREFIX-tls-cert-path"
   401  // - "$PREFIX-tls-key-path"
   402  // - "$PREFIX-enabled"
   403  func RegisterHTTPServerFlags(flags *pflag.FlagSet, config *HTTPServerConfig, flagPrefix, serviceName, defaultAddr string, defaultEnabled bool) {
   404  	flagPrefix = stringz.DefaultEmpty(flagPrefix, "http")
   405  	serviceName = stringz.DefaultEmpty(serviceName, "http")
   406  	defaultAddr = stringz.DefaultEmpty(defaultAddr, ":8443")
   407  	config.flagPrefix = flagPrefix
   408  	flags.StringVar(&config.HTTPAddress, flagPrefix+"-addr", defaultAddr, "address to listen on to serve "+serviceName)
   409  	flags.StringVar(&config.HTTPTLSCertPath, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName)
   410  	flags.StringVar(&config.HTTPTLSKeyPath, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName)
   411  	flags.BoolVar(&config.HTTPEnabled, flagPrefix+"-enabled", defaultEnabled, "enable http "+serviceName+" server")
   412  }
   413  
   414  // RegisterDeprecatedHTTPServerFlags registers a set of HTTP server flags as fully deprecated, for a removed HTTP service.
   415  func RegisterDeprecatedHTTPServerFlags(cmd *cobra.Command, flagPrefix, serviceName string) error {
   416  	ignored1 := ""
   417  	ignored2 := ""
   418  	ignored3 := ""
   419  	ignored4 := false
   420  	flags := cmd.Flags()
   421  
   422  	flags.StringVar(&ignored1, flagPrefix+"-addr", "", "address to listen on to serve "+serviceName)
   423  	flags.StringVar(&ignored2, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName)
   424  	flags.StringVar(&ignored3, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName)
   425  	flags.BoolVar(&ignored4, flagPrefix+"-enabled", false, "enable http "+serviceName+" server")
   426  
   427  	if err := cmd.Flags().MarkDeprecated(flagPrefix+"-addr", "service has been removed; flag is a no-op"); err != nil {
   428  		return fmt.Errorf("failed to mark flag as deprecated: %w", err)
   429  	}
   430  	if err := cmd.Flags().MarkDeprecated(flagPrefix+"-tls-cert-path", "service has been removed; flag is a no-op"); err != nil {
   431  		return fmt.Errorf("failed to mark flag as deprecated: %w", err)
   432  	}
   433  	if err := cmd.Flags().MarkDeprecated(flagPrefix+"-tls-key-path", "service has been removed; flag is a no-op"); err != nil {
   434  		return fmt.Errorf("failed to mark flag as deprecated: %w", err)
   435  	}
   436  	if err := cmd.Flags().MarkDeprecated(flagPrefix+"-enabled", "service has been removed; flag is a no-op"); err != nil {
   437  		return fmt.Errorf("failed to mark flag as deprecated: %w", err)
   438  	}
   439  
   440  	return nil
   441  }
   442  
   443  type disabledHTTPServer struct{}
   444  
   445  func (d *disabledHTTPServer) ListenAndServe() error {
   446  	return nil
   447  }
   448  
   449  func (d *disabledHTTPServer) Close() {}