github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/gateway.go (about)

// Package gateway defines a grpc-gateway server that serves HTTP-JSON traffic and acts as a proxy between HTTP and gRPC.
     2  package gateway
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"path"
    10  	"strings"
    11  	"time"
    12  
    13  	gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    14  	"github.com/pkg/errors"
    15  	"github.com/prysmaticlabs/prysm/shared"
    16  	"github.com/rs/cors"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/connectivity"
    19  	"google.golang.org/grpc/credentials"
    20  )
    21  
    22  var _ shared.Service = (*Gateway)(nil)
    23  
    24  // PbMux serves grpc-gateway requests for selected patterns using registered protobuf handlers.
    25  type PbMux struct {
    26  	Registrations []PbHandlerRegistration // Protobuf registrations to be registered in Mux.
    27  	Patterns      []string                // URL patterns that will be handled by Mux.
    28  	Mux           *gwruntime.ServeMux     // The mux that will be used for grpc-gateway requests.
    29  }
    30  
    31  // PbHandlerRegistration is a function that registers a protobuf handler.
    32  type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error
    33  
    34  // MuxHandler is a function that implements the mux handler functionality.
    35  type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request)
    36  
    37  // Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server.
    38  type Gateway struct {
    39  	conn                         *grpc.ClientConn
    40  	pbHandlers                   []PbMux
    41  	muxHandler                   MuxHandler
    42  	maxCallRecvMsgSize           uint64
    43  	mux                          *http.ServeMux
    44  	server                       *http.Server
    45  	cancel                       context.CancelFunc
    46  	remoteCert                   string
    47  	gatewayAddr                  string
    48  	apiMiddlewareAddr            string
    49  	apiMiddlewareEndpointFactory EndpointFactory
    50  	ctx                          context.Context
    51  	startFailure                 error
    52  	remoteAddr                   string
    53  	allowedOrigins               []string
    54  }
    55  
    56  // New returns a new instance of the Gateway.
    57  func New(
    58  	ctx context.Context,
    59  	pbHandlers []PbMux,
    60  	muxHandler MuxHandler,
    61  	remoteAddr,
    62  	gatewayAddress string,
    63  ) *Gateway {
    64  	g := &Gateway{
    65  		pbHandlers:     pbHandlers,
    66  		muxHandler:     muxHandler,
    67  		mux:            http.NewServeMux(),
    68  		gatewayAddr:    gatewayAddress,
    69  		ctx:            ctx,
    70  		remoteAddr:     remoteAddr,
    71  		allowedOrigins: []string{},
    72  	}
    73  	return g
    74  }
    75  
    76  // WithMux allows adding a custom http.ServeMux to the gateway.
    77  func (g *Gateway) WithMux(m *http.ServeMux) *Gateway {
    78  	g.mux = m
    79  	return g
    80  }
    81  
    82  // WithAllowedOrigins allows adding a set of allowed origins to the gateway.
    83  func (g *Gateway) WithAllowedOrigins(origins []string) *Gateway {
    84  	g.allowedOrigins = origins
    85  	return g
    86  }
    87  
    88  // WithRemoteCert allows adding a custom certificate to the gateway,
    89  func (g *Gateway) WithRemoteCert(cert string) *Gateway {
    90  	g.remoteCert = cert
    91  	return g
    92  }
    93  
    94  // WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size.
    95  func (g *Gateway) WithMaxCallRecvMsgSize(size uint64) *Gateway {
    96  	g.maxCallRecvMsgSize = size
    97  	return g
    98  }
    99  
   100  // WithApiMiddleware allows adding API Middleware proxy to the gateway.
   101  func (g *Gateway) WithApiMiddleware(address string, endpointFactory EndpointFactory) *Gateway {
   102  	g.apiMiddlewareAddr = address
   103  	g.apiMiddlewareEndpointFactory = endpointFactory
   104  	return g
   105  }
   106  
   107  // Start the gateway service.
   108  func (g *Gateway) Start() {
   109  	ctx, cancel := context.WithCancel(g.ctx)
   110  	g.cancel = cancel
   111  
   112  	conn, err := g.dial(ctx, "tcp", g.remoteAddr)
   113  	if err != nil {
   114  		log.WithError(err).Error("Failed to connect to gRPC server")
   115  		g.startFailure = err
   116  		return
   117  	}
   118  	g.conn = conn
   119  
   120  	for _, h := range g.pbHandlers {
   121  		for _, r := range h.Registrations {
   122  			if err := r(ctx, h.Mux, g.conn); err != nil {
   123  				log.WithError(err).Error("Failed to register handler")
   124  				g.startFailure = err
   125  				return
   126  			}
   127  		}
   128  		for _, p := range h.Patterns {
   129  			g.mux.Handle(p, h.Mux)
   130  		}
   131  	}
   132  
   133  	corsMux := g.corsMiddleware(g.mux)
   134  
   135  	if g.muxHandler != nil {
   136  		g.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
   137  			g.muxHandler(corsMux, w, r)
   138  		})
   139  	}
   140  
   141  	g.server = &http.Server{
   142  		Addr:    g.gatewayAddr,
   143  		Handler: corsMux,
   144  	}
   145  
   146  	go func() {
   147  		log.WithField("address", g.gatewayAddr).Info("Starting gRPC gateway")
   148  		if err := g.server.ListenAndServe(); err != http.ErrServerClosed {
   149  			log.WithError(err).Error("Failed to start gRPC gateway")
   150  			g.startFailure = err
   151  			return
   152  		}
   153  	}()
   154  
   155  	if g.apiMiddlewareAddr != "" && g.apiMiddlewareEndpointFactory != nil && !g.apiMiddlewareEndpointFactory.IsNil() {
   156  		go g.registerApiMiddleware()
   157  	}
   158  }
   159  
   160  // Status of grpc gateway. Returns an error if this service is unhealthy.
   161  func (g *Gateway) Status() error {
   162  	if g.startFailure != nil {
   163  		return g.startFailure
   164  	}
   165  
   166  	if s := g.conn.GetState(); s != connectivity.Ready {
   167  		return fmt.Errorf("grpc server is %s", s)
   168  	}
   169  
   170  	return nil
   171  }
   172  
   173  // Stop the gateway with a graceful shutdown.
   174  func (g *Gateway) Stop() error {
   175  	if g.server != nil {
   176  		shutdownCtx, shutdownCancel := context.WithTimeout(g.ctx, 2*time.Second)
   177  		defer shutdownCancel()
   178  		if err := g.server.Shutdown(shutdownCtx); err != nil {
   179  			if errors.Is(err, context.DeadlineExceeded) {
   180  				log.Warn("Existing connections terminated")
   181  			} else {
   182  				log.WithError(err).Error("Failed to gracefully shut down server")
   183  			}
   184  		}
   185  	}
   186  
   187  	if g.cancel != nil {
   188  		g.cancel()
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func (g *Gateway) corsMiddleware(h http.Handler) http.Handler {
   195  	c := cors.New(cors.Options{
   196  		AllowedOrigins:   g.allowedOrigins,
   197  		AllowedMethods:   []string{http.MethodPost, http.MethodGet, http.MethodOptions},
   198  		AllowCredentials: true,
   199  		MaxAge:           600,
   200  		AllowedHeaders:   []string{"*"},
   201  	})
   202  	return c.Handler(h)
   203  }
   204  
   205  const swaggerDir = "proto/beacon/rpc/v1/"
   206  
   207  // SwaggerServer returns swagger specification files located under "/swagger/"
   208  func SwaggerServer() http.HandlerFunc {
   209  	return func(w http.ResponseWriter, r *http.Request) {
   210  		if !strings.HasSuffix(r.URL.Path, ".swagger.json") {
   211  			log.Debugf("Not found: %s", r.URL.Path)
   212  			http.NotFound(w, r)
   213  			return
   214  		}
   215  
   216  		log.Debugf("Serving %s\n", r.URL.Path)
   217  		p := strings.TrimPrefix(r.URL.Path, "/swagger/")
   218  		p = path.Join(swaggerDir, p)
   219  		http.ServeFile(w, r, p)
   220  	}
   221  }
   222  
   223  // dial the gRPC server.
   224  func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientConn, error) {
   225  	switch network {
   226  	case "tcp":
   227  		return g.dialTCP(ctx, addr)
   228  	case "unix":
   229  		return g.dialUnix(ctx, addr)
   230  	default:
   231  		return nil, fmt.Errorf("unsupported network type %q", network)
   232  	}
   233  }
   234  
   235  // dialTCP creates a client connection via TCP.
   236  // "addr" must be a valid TCP address with a port number.
   237  func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) {
   238  	security := grpc.WithInsecure()
   239  	if len(g.remoteCert) > 0 {
   240  		creds, err := credentials.NewClientTLSFromFile(g.remoteCert, "")
   241  		if err != nil {
   242  			return nil, err
   243  		}
   244  		security = grpc.WithTransportCredentials(creds)
   245  	}
   246  	opts := []grpc.DialOption{
   247  		security,
   248  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))),
   249  	}
   250  
   251  	return grpc.DialContext(ctx, addr, opts...)
   252  }
   253  
   254  // dialUnix creates a client connection via a unix domain socket.
   255  // "addr" must be a valid path to the socket.
   256  func (g *Gateway) dialUnix(ctx context.Context, addr string) (*grpc.ClientConn, error) {
   257  	d := func(addr string, timeout time.Duration) (net.Conn, error) {
   258  		return net.DialTimeout("unix", addr, timeout)
   259  	}
   260  	f := func(ctx context.Context, addr string) (net.Conn, error) {
   261  		if deadline, ok := ctx.Deadline(); ok {
   262  			return d(addr, time.Until(deadline))
   263  		}
   264  		return d(addr, 0)
   265  	}
   266  	opts := []grpc.DialOption{
   267  		grpc.WithInsecure(),
   268  		grpc.WithContextDialer(f),
   269  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))),
   270  	}
   271  	return grpc.DialContext(ctx, addr, opts...)
   272  }
   273  
   274  func (g *Gateway) registerApiMiddleware() {
   275  	proxy := &ApiProxyMiddleware{
   276  		GatewayAddress:  g.gatewayAddr,
   277  		ProxyAddress:    g.apiMiddlewareAddr,
   278  		EndpointCreator: g.apiMiddlewareEndpointFactory,
   279  	}
   280  	log.WithField("API middleware address", g.apiMiddlewareAddr).Info("Starting API middleware")
   281  	if err := proxy.Run(); err != http.ErrServerClosed {
   282  		log.WithError(err).Error("Failed to start API middleware")
   283  		g.startFailure = err
   284  		return
   285  	}
   286  }