github.com/TBD54566975/ftl@v0.219.0/internal/rpc/server.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  	"time"
    11  
    12  	"connectrpc.com/connect"
    13  	"connectrpc.com/grpcreflect"
    14  	"github.com/alecthomas/concurrency"
    15  	"github.com/alecthomas/types/pubsub"
    16  	"golang.org/x/net/http2"
    17  	"golang.org/x/net/http2/h2c"
    18  )
    19  
    20  const ShutdownGracePeriod = time.Second * 5
    21  
// serverOptions is the mutable state that Option functions operate on while a
// Server is being constructed.
type serverOptions struct {
	// mux is the router on which every handler is registered.
	mux             *http.ServeMux
	// reflectionPaths collects the handler paths (with surrounding slashes
	// trimmed) that are handed to the gRPC reflection service.
	reflectionPaths []string
}
    26  
// Option configures a server under construction by mutating its serverOptions.
// NewServer applies options in the order they are passed.
type Option func(*serverOptions)
    28  
// GRPCServerConstructor builds a handler for a Pingable service implementation.
// It returns the path the handler should be mounted at and the handler itself.
type GRPCServerConstructor[Iface Pingable] func(svc Iface, opts ...connect.HandlerOption) (string, http.Handler)

// RawGRPCServerConstructor is the same shape as GRPCServerConstructor but
// places no Pingable constraint on the service interface.
type RawGRPCServerConstructor[Iface any] func(svc Iface, opts ...connect.HandlerOption) (string, http.Handler)
    31  
    32  // GRPC is a convenience function for registering a GRPC server with default options.
    33  // TODO(aat): Do we need pingable here?
    34  func GRPC[Iface, Impl Pingable](constructor GRPCServerConstructor[Iface], impl Impl, options ...connect.HandlerOption) Option {
    35  	return func(o *serverOptions) {
    36  		options = append(options, DefaultHandlerOptions()...)
    37  		path, handler := constructor(any(impl).(Iface), options...)
    38  		o.reflectionPaths = append(o.reflectionPaths, strings.Trim(path, "/"))
    39  		o.mux.Handle(path, handler)
    40  	}
    41  }
    42  
    43  // RawGRPC is a convenience function for registering a GRPC server with default options without Pingable.
    44  func RawGRPC[Iface, Impl any](constructor RawGRPCServerConstructor[Iface], impl Impl, options ...connect.HandlerOption) Option {
    45  	return func(o *serverOptions) {
    46  		options = append(options, DefaultHandlerOptions()...)
    47  		path, handler := constructor(any(impl).(Iface), options...)
    48  		o.reflectionPaths = append(o.reflectionPaths, strings.Trim(path, "/"))
    49  		o.mux.Handle(path, handler)
    50  	}
    51  }
    52  
    53  // HTTP adds a HTTP route to the server.
    54  func HTTP(prefix string, handler http.Handler) Option {
    55  	return func(o *serverOptions) {
    56  		o.mux.Handle(prefix, handler)
    57  	}
    58  }
    59  
// Server couples an http.Server with the URL it should listen on and a topic
// that announces the address once the listener is bound.
type Server struct {
	// listen is the requested listen URL. Serve rewrites its Host with the
	// listener's real address when port 0 was requested.
	listen *url.URL
	Bind   *pubsub.Topic[*url.URL] // Will be updated with the actual bind address.
	// Server is the underlying HTTP server that Serve runs and shuts down.
	Server *http.Server
}
    65  
    66  func NewServer(ctx context.Context, listen *url.URL, options ...Option) (*Server, error) {
    67  	opts := &serverOptions{
    68  		mux: http.NewServeMux(),
    69  	}
    70  
    71  	opts.mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    72  		w.WriteHeader(http.StatusOK)
    73  	}))
    74  
    75  	for _, option := range options {
    76  		option(opts)
    77  	}
    78  
    79  	// Register reflection services.
    80  	reflector := grpcreflect.NewStaticReflector(opts.reflectionPaths...)
    81  	opts.mux.Handle(grpcreflect.NewHandlerV1(reflector))
    82  	opts.mux.Handle(grpcreflect.NewHandlerV1Alpha(reflector))
    83  	root := ContextValuesMiddleware(ctx, opts.mux)
    84  
    85  	http1Server := &http.Server{
    86  		Handler:           h2c.NewHandler(root, &http2.Server{}),
    87  		ReadHeaderTimeout: time.Second * 30,
    88  		BaseContext:       func(net.Listener) context.Context { return ctx },
    89  	}
    90  
    91  	return &Server{
    92  		listen: listen,
    93  		Bind:   pubsub.New[*url.URL](),
    94  		Server: http1Server,
    95  	}, nil
    96  }
    97  
    98  // Serve runs the server, updating .Bind with the actual bind address.
    99  func (s *Server) Serve(ctx context.Context) error {
   100  	listener, err := net.Listen("tcp", s.listen.Host)
   101  	if err != nil {
   102  		return err
   103  	}
   104  	if s.listen.Port() == "0" {
   105  		s.listen.Host = listener.Addr().String()
   106  	}
   107  	s.Bind.Publish(s.listen)
   108  
   109  	tree, _ := concurrency.New(ctx)
   110  
   111  	// Shutdown server on context cancellation.
   112  	tree.Go(func(ctx context.Context) error {
   113  		<-ctx.Done()
   114  		ctx, cancel := context.WithTimeout(context.Background(), ShutdownGracePeriod)
   115  		defer cancel()
   116  		err := s.Server.Shutdown(ctx)
   117  		if err == nil {
   118  			return nil
   119  		}
   120  		if errors.Is(err, context.Canceled) {
   121  			_ = s.Server.Close()
   122  			return err
   123  		}
   124  		return err
   125  	})
   126  
   127  	// Start server.
   128  	tree.Go(func(ctx context.Context) error {
   129  		err = s.Server.Serve(listener)
   130  		if errors.Is(err, http.ErrServerClosed) {
   131  			return nil
   132  		}
   133  		return err
   134  	})
   135  
   136  	return tree.Wait()
   137  }
   138  
   139  // Serve starts a HTTP and Connect gRPC server with sane defaults for FTL.
   140  //
   141  // Blocks until the context is cancelled.
   142  func Serve(ctx context.Context, listen *url.URL, options ...Option) error {
   143  	server, err := NewServer(ctx, listen, options...)
   144  	if err != nil {
   145  		return err
   146  	}
   147  	return server.Serve(ctx)
   148  }