github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/util/util.go (about) 1 package util 2 3 //go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . GRPCServerConfig HTTPServerConfig 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "errors" 10 "fmt" 11 "net" 12 "net/http" 13 "time" 14 15 "github.com/jzelinskie/stringz" 16 "github.com/rs/zerolog" 17 "github.com/spf13/cobra" 18 "github.com/spf13/pflag" 19 "google.golang.org/grpc" 20 "google.golang.org/grpc/credentials" 21 "google.golang.org/grpc/credentials/insecure" 22 "google.golang.org/grpc/keepalive" 23 "google.golang.org/grpc/test/bufconn" 24 25 // Register Snappy S2 compression 26 _ "github.com/mostynb/go-grpc-compression/experimental/s2" 27 28 "sigs.k8s.io/controller-runtime/pkg/certwatcher" 29 // Register cert watcher metrics 30 _ "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics" 31 32 log "github.com/authzed/spicedb/internal/logging" 33 "github.com/authzed/spicedb/pkg/x509util" 34 ) 35 36 const BufferedNetwork string = "buffnet" 37 38 type GRPCServerConfig struct { 39 Address string `debugmap:"visible"` 40 Network string `debugmap:"visible"` 41 TLSCertPath string `debugmap:"visible"` 42 TLSKeyPath string `debugmap:"visible"` 43 MaxConnAge time.Duration `debugmap:"visible"` 44 Enabled bool `debugmap:"visible"` 45 BufferSize int `debugmap:"visible"` 46 ClientCAPath string `debugmap:"visible"` 47 MaxWorkers uint32 `debugmap:"visible"` 48 49 flagPrefix string 50 } 51 52 // RegisterGRPCServerFlags adds the following flags for use with 53 // GrpcServerFromFlags: 54 // - "$PREFIX-addr" 55 // - "$PREFIX-tls-cert-path" 56 // - "$PREFIX-tls-key-path" 57 // - "$PREFIX-max-conn-age" 58 func RegisterGRPCServerFlags(flags *pflag.FlagSet, config *GRPCServerConfig, flagPrefix, serviceName, defaultAddr string, defaultEnabled bool) { 59 flagPrefix = stringz.DefaultEmpty(flagPrefix, "grpc") 60 serviceName = stringz.DefaultEmpty(serviceName, "grpc") 61 defaultAddr = stringz.DefaultEmpty(defaultAddr, ":50051") 62 config.flagPrefix = flagPrefix 63 64 flags.StringVar(&config.Address, flagPrefix+"-addr", defaultAddr, "address to listen on to serve "+serviceName) 65 flags.StringVar(&config.Network, flagPrefix+"-network", "tcp", "network type to serve "+serviceName+` ("tcp", "tcp4", "tcp6", "unix", "unixpacket")`) 66 flags.StringVar(&config.TLSCertPath, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName) 67 flags.StringVar(&config.TLSKeyPath, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName) 68 flags.DurationVar(&config.MaxConnAge, flagPrefix+"-max-conn-age", 30*time.Second, "how long a connection serving "+serviceName+" should be able to live") 69 flags.BoolVar(&config.Enabled, flagPrefix+"-enabled", defaultEnabled, "enable "+serviceName+" gRPC server") 70 flags.Uint32Var(&config.MaxWorkers, flagPrefix+"-max-workers", 0, "set the number of workers for this server (0 value means 1 worker per request)") 71 } 72 73 type ( 74 DialFunc func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) 75 NetDialFunc func(ctx context.Context, s string) (net.Conn, error) 76 ) 77 78 // Complete takes a set of default options and returns a completed server 79 func (c *GRPCServerConfig) Complete(level zerolog.Level, svcRegistrationFn func(server *grpc.Server), opts ...grpc.ServerOption) (RunnableGRPCServer, error) { 80 if !c.Enabled { 81 return &disabledGrpcServer{}, nil 82 } 83 if c.BufferSize == 0 { 84 c.BufferSize = 1024 * 1024 85 } 86 opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{ 87 MaxConnectionAge: c.MaxConnAge, 88 }), grpc.NumStreamWorkers(c.MaxWorkers)) 89 90 tlsOpts, certWatcher, err := c.tlsOpts() 91 if err != nil { 92 return nil, err 93 } 94 opts = append(opts, tlsOpts...) 95 96 clientCreds, err := c.clientCreds() 97 if err != nil { 98 return nil, err 99 } 100 101 l, dial, netDial, err := c.listenerAndDialer() 102 if err != nil { 103 return nil, fmt.Errorf("failed to listen on addr for gRPC server: %w", err) 104 } 105 log.WithLevel(level). 106 Str("addr", c.Address). 107 Str("network", c.Network). 108 Str("service", c.flagPrefix). 109 Uint32("workers", c.MaxWorkers). 110 Bool("insecure", c.TLSCertPath == "" && c.TLSKeyPath == ""). 111 Msg("grpc server started serving") 112 113 srv := grpc.NewServer(opts...) 114 svcRegistrationFn(srv) 115 return &completedGRPCServer{ 116 opts: opts, 117 listener: l, 118 svcRegistrationFn: svcRegistrationFn, 119 listenFunc: func() error { 120 return srv.Serve(l) 121 }, 122 dial: dial, 123 netDial: netDial, 124 prestopFunc: func() { 125 log.WithLevel(level). 126 Str("addr", c.Address). 127 Str("network", c.Network). 128 Str("service", c.flagPrefix). 129 Msg("grpc server stopped serving") 130 }, 131 stopFunc: srv.GracefulStop, 132 creds: clientCreds, 133 certWatcher: certWatcher, 134 }, nil 135 } 136 137 func (c *GRPCServerConfig) listenerAndDialer() (net.Listener, DialFunc, NetDialFunc, error) { 138 if c.Network == BufferedNetwork { 139 bl := bufconn.Listen(c.BufferSize) 140 return bl, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { 141 opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { 142 return bl.DialContext(ctx) 143 })) 144 145 return grpc.DialContext(ctx, BufferedNetwork, opts...) 146 }, func(ctx context.Context, s string) (net.Conn, error) { 147 return bl.DialContext(ctx) 148 }, nil 149 } 150 l, err := net.Listen(c.Network, c.Address) 151 if err != nil { 152 return nil, nil, nil, err 153 } 154 return l, func(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { 155 return grpc.DialContext(ctx, c.Address, opts...) 156 }, nil, nil 157 } 158 159 func (c *GRPCServerConfig) tlsOpts() ([]grpc.ServerOption, *certwatcher.CertWatcher, error) { 160 switch { 161 case c.TLSCertPath == "" && c.TLSKeyPath == "": 162 return nil, nil, nil 163 case c.TLSCertPath != "" && c.TLSKeyPath != "": 164 watcher, err := certwatcher.New(c.TLSCertPath, c.TLSKeyPath) 165 if err != nil { 166 return nil, nil, err 167 } 168 creds := credentials.NewTLS(&tls.Config{ 169 GetCertificate: watcher.GetCertificate, 170 MinVersion: tls.VersionTLS12, 171 }) 172 return []grpc.ServerOption{grpc.Creds(creds)}, watcher, nil 173 default: 174 return nil, nil, nil 175 } 176 } 177 178 func (c *GRPCServerConfig) clientCreds() (credentials.TransportCredentials, error) { 179 switch { 180 case c.TLSCertPath == "" && c.TLSKeyPath == "": 181 return insecure.NewCredentials(), nil 182 case c.TLSCertPath != "" && c.TLSKeyPath != "": 183 var err error 184 var pool *x509.CertPool 185 if c.ClientCAPath != "" { 186 pool, err = x509util.CustomCertPool(c.ClientCAPath) 187 } else { 188 pool, err = x509.SystemCertPool() 189 } 190 if err != nil { 191 return nil, err 192 } 193 194 return credentials.NewTLS(&tls.Config{RootCAs: pool, MinVersion: tls.VersionTLS12}), nil 195 default: 196 return nil, nil 197 } 198 } 199 200 type RunnableGRPCServer interface { 201 WithOpts(opts ...grpc.ServerOption) RunnableGRPCServer 202 Listen(ctx context.Context) func() error 203 DialContext(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) 204 NetDialContext(ctx context.Context, s string) (net.Conn, error) 205 Insecure() bool 206 GracefulStop() 207 } 208 209 type completedGRPCServer struct { 210 opts []grpc.ServerOption 211 listener net.Listener 212 svcRegistrationFn func(*grpc.Server) 213 listenFunc func() error 214 prestopFunc func() 215 stopFunc func() 216 dial func(context.Context, ...grpc.DialOption) (*grpc.ClientConn, error) 217 netDial func(ctx context.Context, s string) (net.Conn, error) 218 creds credentials.TransportCredentials 219 certWatcher *certwatcher.CertWatcher 220 } 221 222 // WithOpts adds to the options for running the server 223 func (c *completedGRPCServer) WithOpts(opts ...grpc.ServerOption) RunnableGRPCServer { 224 c.opts = append(c.opts, opts...) 225 srv := grpc.NewServer(c.opts...) 226 c.svcRegistrationFn(srv) 227 c.listenFunc = func() error { 228 return srv.Serve(c.listener) 229 } 230 c.stopFunc = srv.GracefulStop 231 return c 232 } 233 234 // Listen runs a configured server 235 func (c *completedGRPCServer) Listen(ctx context.Context) func() error { 236 if c.certWatcher != nil { 237 go func() { 238 if err := c.certWatcher.Start(ctx); err != nil { 239 log.Ctx(ctx).Error().Err(err).Msg("error watching tls certs") 240 } 241 }() 242 } 243 return c.listenFunc 244 } 245 246 // DialContext starts a connection to grpc server 247 func (c *completedGRPCServer) DialContext(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) { 248 opts = append(opts, grpc.WithTransportCredentials(c.creds)) 249 return c.dial(ctx, opts...) 250 } 251 252 // NetDialContext returns a low level net.Conn connection to the server 253 func (c *completedGRPCServer) NetDialContext(ctx context.Context, s string) (net.Conn, error) { 254 return c.netDial(ctx, s) 255 } 256 257 // Insecure returns true if the server is configured without TLS enabled 258 func (c *completedGRPCServer) Insecure() bool { 259 return c.creds.Info().SecurityProtocol == "insecure" 260 } 261 262 // GracefulStop stops a running server 263 func (c *completedGRPCServer) GracefulStop() { 264 c.prestopFunc() 265 c.stopFunc() 266 } 267 268 type disabledGrpcServer struct{} 269 270 // WithOpts adds to the options for running the server 271 func (d *disabledGrpcServer) WithOpts(_ ...grpc.ServerOption) RunnableGRPCServer { 272 return d 273 } 274 275 // Listen runs a configured server 276 func (d *disabledGrpcServer) Listen(_ context.Context) func() error { 277 return func() error { 278 return nil 279 } 280 } 281 282 // Insecure returns true if the server is configured without TLS enabled 283 func (d *disabledGrpcServer) Insecure() bool { 284 return true 285 } 286 287 // DialContext starts a connection to grpc server 288 func (d *disabledGrpcServer) DialContext(_ context.Context, _ ...grpc.DialOption) (*grpc.ClientConn, error) { 289 return nil, nil 290 } 291 292 // NetDialContext starts a connection to grpc server 293 func (d *disabledGrpcServer) NetDialContext(_ context.Context, _ string) (net.Conn, error) { 294 return nil, nil 295 } 296 297 // GracefulStop stops a running server 298 func (d *disabledGrpcServer) GracefulStop() {} 299 300 type HTTPServerConfig struct { 301 HTTPAddress string `debugmap:"visible"` 302 HTTPTLSCertPath string `debugmap:"visible"` 303 HTTPTLSKeyPath string `debugmap:"visible"` 304 HTTPEnabled bool `debugmap:"visible"` 305 306 flagPrefix string 307 } 308 309 func (c *HTTPServerConfig) Complete(level zerolog.Level, handler http.Handler) (RunnableHTTPServer, error) { 310 if !c.HTTPEnabled { 311 return &disabledHTTPServer{}, nil 312 } 313 srv := &http.Server{ 314 Addr: c.HTTPAddress, 315 Handler: handler, 316 ReadHeaderTimeout: 5 * time.Second, 317 } 318 var serveFunc func() error 319 switch { 320 case c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == "": 321 serveFunc = func() error { 322 log.WithLevel(level). 323 Str("addr", srv.Addr). 324 Str("service", c.flagPrefix). 325 Bool("insecure", c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == ""). 326 Msg("http server started serving") 327 return srv.ListenAndServe() 328 } 329 330 case c.HTTPTLSCertPath != "" && c.HTTPTLSKeyPath != "": 331 watcher, err := certwatcher.New(c.HTTPTLSCertPath, c.HTTPTLSKeyPath) 332 if err != nil { 333 return nil, err 334 } 335 336 listener, err := tls.Listen("tcp", srv.Addr, &tls.Config{ 337 GetCertificate: watcher.GetCertificate, 338 MinVersion: tls.VersionTLS12, 339 }) 340 if err != nil { 341 return nil, err 342 } 343 serveFunc = func() error { 344 log.WithLevel(level). 345 Str("addr", srv.Addr). 346 Str("prefix", c.flagPrefix). 347 Bool("insecure", c.HTTPTLSCertPath == "" && c.HTTPTLSKeyPath == ""). 348 Msg("http server started serving") 349 return srv.Serve(listener) 350 } 351 default: 352 return nil, fmt.Errorf("failed to start http server: must provide both --%s-tls-cert-path and --%s-tls-key-path", 353 c.flagPrefix, 354 c.flagPrefix, 355 ) 356 } 357 358 return &completedHTTPServer{ 359 srvFunc: func() error { 360 if err := serveFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) { 361 return fmt.Errorf("failed while serving http: %w", err) 362 } 363 return nil 364 }, 365 closeFunc: func() { 366 if err := srv.Close(); err != nil { 367 log.Error().Str("addr", srv.Addr).Str("service", c.flagPrefix).Err(err).Msg("error stopping http server") 368 } 369 log.WithLevel(level).Str("addr", srv.Addr).Str("service", c.flagPrefix).Msg("http server stopped serving") 370 }, 371 enabled: c.HTTPEnabled, 372 }, nil 373 } 374 375 type RunnableHTTPServer interface { 376 ListenAndServe() error 377 Close() 378 } 379 380 type completedHTTPServer struct { 381 srvFunc func() error 382 closeFunc func() 383 enabled bool 384 } 385 386 func (c *completedHTTPServer) ListenAndServe() error { 387 if !c.enabled { 388 return nil 389 } 390 return c.srvFunc() 391 } 392 393 func (c *completedHTTPServer) Close() { 394 c.closeFunc() 395 } 396 397 // RegisterHTTPServerFlags adds the following flags for use with 398 // HttpServerFromFlags: 399 // - "$PREFIX-addr" 400 // - "$PREFIX-tls-cert-path" 401 // - "$PREFIX-tls-key-path" 402 // - "$PREFIX-enabled" 403 func RegisterHTTPServerFlags(flags *pflag.FlagSet, config *HTTPServerConfig, flagPrefix, serviceName, defaultAddr string, defaultEnabled bool) { 404 flagPrefix = stringz.DefaultEmpty(flagPrefix, "http") 405 serviceName = stringz.DefaultEmpty(serviceName, "http") 406 defaultAddr = stringz.DefaultEmpty(defaultAddr, ":8443") 407 config.flagPrefix = flagPrefix 408 flags.StringVar(&config.HTTPAddress, flagPrefix+"-addr", defaultAddr, "address to listen on to serve "+serviceName) 409 flags.StringVar(&config.HTTPTLSCertPath, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName) 410 flags.StringVar(&config.HTTPTLSKeyPath, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName) 411 flags.BoolVar(&config.HTTPEnabled, flagPrefix+"-enabled", defaultEnabled, "enable http "+serviceName+" server") 412 } 413 414 // RegisterDeprecatedHTTPServerFlags registers a set of HTTP server flags as fully deprecated, for a removed HTTP service. 415 func RegisterDeprecatedHTTPServerFlags(cmd *cobra.Command, flagPrefix, serviceName string) error { 416 ignored1 := "" 417 ignored2 := "" 418 ignored3 := "" 419 ignored4 := false 420 flags := cmd.Flags() 421 422 flags.StringVar(&ignored1, flagPrefix+"-addr", "", "address to listen on to serve "+serviceName) 423 flags.StringVar(&ignored2, flagPrefix+"-tls-cert-path", "", "local path to the TLS certificate used to serve "+serviceName) 424 flags.StringVar(&ignored3, flagPrefix+"-tls-key-path", "", "local path to the TLS key used to serve "+serviceName) 425 flags.BoolVar(&ignored4, flagPrefix+"-enabled", false, "enable http "+serviceName+" server") 426 427 if err := cmd.Flags().MarkDeprecated(flagPrefix+"-addr", "service has been removed; flag is a no-op"); err != nil { 428 return fmt.Errorf("failed to mark flag as deprecated: %w", err) 429 } 430 if err := cmd.Flags().MarkDeprecated(flagPrefix+"-tls-cert-path", "service has been removed; flag is a no-op"); err != nil { 431 return fmt.Errorf("failed to mark flag as deprecated: %w", err) 432 } 433 if err := cmd.Flags().MarkDeprecated(flagPrefix+"-tls-key-path", "service has been removed; flag is a no-op"); err != nil { 434 return fmt.Errorf("failed to mark flag as deprecated: %w", err) 435 } 436 if err := cmd.Flags().MarkDeprecated(flagPrefix+"-enabled", "service has been removed; flag is a no-op"); err != nil { 437 return fmt.Errorf("failed to mark flag as deprecated: %w", err) 438 } 439 440 return nil 441 } 442 443 type disabledHTTPServer struct{} 444 445 func (d *disabledHTTPServer) ListenAndServe() error { 446 return nil 447 } 448 449 func (d *disabledHTTPServer) Close() {}