github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/gateway.go (about) 1 // Package gateway defines a grpc-gateway server that serves HTTP-JSON traffic and acts a proxy between HTTP and gRPC. 2 package gateway 3 4 import ( 5 "context" 6 "fmt" 7 "net" 8 "net/http" 9 "path" 10 "strings" 11 "time" 12 13 gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 14 "github.com/pkg/errors" 15 "github.com/prysmaticlabs/prysm/shared" 16 "github.com/rs/cors" 17 "google.golang.org/grpc" 18 "google.golang.org/grpc/connectivity" 19 "google.golang.org/grpc/credentials" 20 ) 21 22 var _ shared.Service = (*Gateway)(nil) 23 24 // PbMux serves grpc-gateway requests for selected patterns using registered protobuf handlers. 25 type PbMux struct { 26 Registrations []PbHandlerRegistration // Protobuf registrations to be registered in Mux. 27 Patterns []string // URL patterns that will be handled by Mux. 28 Mux *gwruntime.ServeMux // The mux that will be used for grpc-gateway requests. 29 } 30 31 // PbHandlerRegistration is a function that registers a protobuf handler. 32 type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error 33 34 // MuxHandler is a function that implements the mux handler functionality. 35 type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request) 36 37 // Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server. 38 type Gateway struct { 39 conn *grpc.ClientConn 40 pbHandlers []PbMux 41 muxHandler MuxHandler 42 maxCallRecvMsgSize uint64 43 mux *http.ServeMux 44 server *http.Server 45 cancel context.CancelFunc 46 remoteCert string 47 gatewayAddr string 48 apiMiddlewareAddr string 49 apiMiddlewareEndpointFactory EndpointFactory 50 ctx context.Context 51 startFailure error 52 remoteAddr string 53 allowedOrigins []string 54 } 55 56 // New returns a new instance of the Gateway. 57 func New( 58 ctx context.Context, 59 pbHandlers []PbMux, 60 muxHandler MuxHandler, 61 remoteAddr, 62 gatewayAddress string, 63 ) *Gateway { 64 g := &Gateway{ 65 pbHandlers: pbHandlers, 66 muxHandler: muxHandler, 67 mux: http.NewServeMux(), 68 gatewayAddr: gatewayAddress, 69 ctx: ctx, 70 remoteAddr: remoteAddr, 71 allowedOrigins: []string{}, 72 } 73 return g 74 } 75 76 // WithMux allows adding a custom http.ServeMux to the gateway. 77 func (g *Gateway) WithMux(m *http.ServeMux) *Gateway { 78 g.mux = m 79 return g 80 } 81 82 // WithAllowedOrigins allows adding a set of allowed origins to the gateway. 83 func (g *Gateway) WithAllowedOrigins(origins []string) *Gateway { 84 g.allowedOrigins = origins 85 return g 86 } 87 88 // WithRemoteCert allows adding a custom certificate to the gateway, 89 func (g *Gateway) WithRemoteCert(cert string) *Gateway { 90 g.remoteCert = cert 91 return g 92 } 93 94 // WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size. 95 func (g *Gateway) WithMaxCallRecvMsgSize(size uint64) *Gateway { 96 g.maxCallRecvMsgSize = size 97 return g 98 } 99 100 // WithApiMiddleware allows adding API Middleware proxy to the gateway. 101 func (g *Gateway) WithApiMiddleware(address string, endpointFactory EndpointFactory) *Gateway { 102 g.apiMiddlewareAddr = address 103 g.apiMiddlewareEndpointFactory = endpointFactory 104 return g 105 } 106 107 // Start the gateway service. 108 func (g *Gateway) Start() { 109 ctx, cancel := context.WithCancel(g.ctx) 110 g.cancel = cancel 111 112 conn, err := g.dial(ctx, "tcp", g.remoteAddr) 113 if err != nil { 114 log.WithError(err).Error("Failed to connect to gRPC server") 115 g.startFailure = err 116 return 117 } 118 g.conn = conn 119 120 for _, h := range g.pbHandlers { 121 for _, r := range h.Registrations { 122 if err := r(ctx, h.Mux, g.conn); err != nil { 123 log.WithError(err).Error("Failed to register handler") 124 g.startFailure = err 125 return 126 } 127 } 128 for _, p := range h.Patterns { 129 g.mux.Handle(p, h.Mux) 130 } 131 } 132 133 corsMux := g.corsMiddleware(g.mux) 134 135 if g.muxHandler != nil { 136 g.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 137 g.muxHandler(corsMux, w, r) 138 }) 139 } 140 141 g.server = &http.Server{ 142 Addr: g.gatewayAddr, 143 Handler: corsMux, 144 } 145 146 go func() { 147 log.WithField("address", g.gatewayAddr).Info("Starting gRPC gateway") 148 if err := g.server.ListenAndServe(); err != http.ErrServerClosed { 149 log.WithError(err).Error("Failed to start gRPC gateway") 150 g.startFailure = err 151 return 152 } 153 }() 154 155 if g.apiMiddlewareAddr != "" && g.apiMiddlewareEndpointFactory != nil && !g.apiMiddlewareEndpointFactory.IsNil() { 156 go g.registerApiMiddleware() 157 } 158 } 159 160 // Status of grpc gateway. Returns an error if this service is unhealthy. 161 func (g *Gateway) Status() error { 162 if g.startFailure != nil { 163 return g.startFailure 164 } 165 166 if s := g.conn.GetState(); s != connectivity.Ready { 167 return fmt.Errorf("grpc server is %s", s) 168 } 169 170 return nil 171 } 172 173 // Stop the gateway with a graceful shutdown. 174 func (g *Gateway) Stop() error { 175 if g.server != nil { 176 shutdownCtx, shutdownCancel := context.WithTimeout(g.ctx, 2*time.Second) 177 defer shutdownCancel() 178 if err := g.server.Shutdown(shutdownCtx); err != nil { 179 if errors.Is(err, context.DeadlineExceeded) { 180 log.Warn("Existing connections terminated") 181 } else { 182 log.WithError(err).Error("Failed to gracefully shut down server") 183 } 184 } 185 } 186 187 if g.cancel != nil { 188 g.cancel() 189 } 190 191 return nil 192 } 193 194 func (g *Gateway) corsMiddleware(h http.Handler) http.Handler { 195 c := cors.New(cors.Options{ 196 AllowedOrigins: g.allowedOrigins, 197 AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodOptions}, 198 AllowCredentials: true, 199 MaxAge: 600, 200 AllowedHeaders: []string{"*"}, 201 }) 202 return c.Handler(h) 203 } 204 205 const swaggerDir = "proto/beacon/rpc/v1/" 206 207 // SwaggerServer returns swagger specification files located under "/swagger/" 208 func SwaggerServer() http.HandlerFunc { 209 return func(w http.ResponseWriter, r *http.Request) { 210 if !strings.HasSuffix(r.URL.Path, ".swagger.json") { 211 log.Debugf("Not found: %s", r.URL.Path) 212 http.NotFound(w, r) 213 return 214 } 215 216 log.Debugf("Serving %s\n", r.URL.Path) 217 p := strings.TrimPrefix(r.URL.Path, "/swagger/") 218 p = path.Join(swaggerDir, p) 219 http.ServeFile(w, r, p) 220 } 221 } 222 223 // dial the gRPC server. 224 func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientConn, error) { 225 switch network { 226 case "tcp": 227 return g.dialTCP(ctx, addr) 228 case "unix": 229 return g.dialUnix(ctx, addr) 230 default: 231 return nil, fmt.Errorf("unsupported network type %q", network) 232 } 233 } 234 235 // dialTCP creates a client connection via TCP. 236 // "addr" must be a valid TCP address with a port number. 237 func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) { 238 security := grpc.WithInsecure() 239 if len(g.remoteCert) > 0 { 240 creds, err := credentials.NewClientTLSFromFile(g.remoteCert, "") 241 if err != nil { 242 return nil, err 243 } 244 security = grpc.WithTransportCredentials(creds) 245 } 246 opts := []grpc.DialOption{ 247 security, 248 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), 249 } 250 251 return grpc.DialContext(ctx, addr, opts...) 252 } 253 254 // dialUnix creates a client connection via a unix domain socket. 255 // "addr" must be a valid path to the socket. 256 func (g *Gateway) dialUnix(ctx context.Context, addr string) (*grpc.ClientConn, error) { 257 d := func(addr string, timeout time.Duration) (net.Conn, error) { 258 return net.DialTimeout("unix", addr, timeout) 259 } 260 f := func(ctx context.Context, addr string) (net.Conn, error) { 261 if deadline, ok := ctx.Deadline(); ok { 262 return d(addr, time.Until(deadline)) 263 } 264 return d(addr, 0) 265 } 266 opts := []grpc.DialOption{ 267 grpc.WithInsecure(), 268 grpc.WithContextDialer(f), 269 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), 270 } 271 return grpc.DialContext(ctx, addr, opts...) 272 } 273 274 func (g *Gateway) registerApiMiddleware() { 275 proxy := &ApiProxyMiddleware{ 276 GatewayAddress: g.gatewayAddr, 277 ProxyAddress: g.apiMiddlewareAddr, 278 EndpointCreator: g.apiMiddlewareEndpointFactory, 279 } 280 log.WithField("API middleware address", g.apiMiddlewareAddr).Info("Starting API middleware") 281 if err := proxy.Run(); err != http.ErrServerClosed { 282 log.WithError(err).Error("Failed to start API middleware") 283 g.startFailure = err 284 return 285 } 286 }