code.vegaprotocol.io/vega@v0.79.0/datanode/gateway/server/server.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package server
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"net"
    22  	"net/http"
    23  	"strconv"
    24  	"strings"
    25  
    26  	"code.vegaprotocol.io/vega/datanode/gateway"
    27  	gql "code.vegaprotocol.io/vega/datanode/gateway/graphql"
    28  	"code.vegaprotocol.io/vega/datanode/gateway/rest"
    29  	libhttp "code.vegaprotocol.io/vega/libs/http"
    30  	"code.vegaprotocol.io/vega/logging"
    31  	"code.vegaprotocol.io/vega/paths"
    32  
    33  	"github.com/rs/cors"
    34  	"golang.org/x/sync/errgroup"
    35  )
    36  
    37  type Server struct {
    38  	cfg       *gateway.Config
    39  	log       *logging.Logger
    40  	vegaPaths paths.Paths
    41  
    42  	rest *rest.ProxyServer
    43  	gql  *gql.GraphServer
    44  
    45  	srv *http.Server
    46  }
    47  
    48  const namedLogger = "gateway"
    49  
    50  func New(cfg gateway.Config, log *logging.Logger, vegaPaths paths.Paths) *Server {
    51  	log = log.Named(namedLogger)
    52  	log.SetLevel(cfg.Level.Get())
    53  
    54  	return &Server{
    55  		log:       log,
    56  		cfg:       &cfg,
    57  		vegaPaths: vegaPaths,
    58  	}
    59  }
    60  
    61  func (srv *Server) Start(ctx context.Context) error {
    62  	eg, ctx := errgroup.WithContext(ctx)
    63  
    64  	// <--- cors support - configure for production
    65  	corsOptions := libhttp.CORSOptions(srv.cfg.CORS)
    66  	corz := cors.New(corsOptions)
    67  	// cors support - configure for production --->
    68  
    69  	var gqlHandler, restHandler http.Handler
    70  	if srv.cfg.GraphQL.Enabled {
    71  		var err error
    72  		srv.gql, err = gql.New(srv.log, *srv.cfg, srv.vegaPaths)
    73  		if err != nil {
    74  			return err
    75  		}
    76  		gqlHandler, err = srv.gql.Start()
    77  		if err != nil {
    78  			return err
    79  		}
    80  	}
    81  
    82  	if srv.cfg.REST.Enabled {
    83  		srv.rest = rest.NewProxyServer(srv.log, *srv.cfg, srv.vegaPaths)
    84  
    85  		var err error
    86  		restHandler, err = srv.rest.Start(ctx)
    87  		if err != nil {
    88  			return err
    89  		}
    90  	}
    91  
    92  	handlr := corz.Handler(
    93  		&Handler{
    94  			gqlPrefix:   srv.cfg.GraphQL.Endpoint,
    95  			gqlHandler:  gqlHandler,
    96  			restHandler: restHandler,
    97  		},
    98  	)
    99  
   100  	port := srv.cfg.Port
   101  	ip := srv.cfg.IP
   102  
   103  	srv.log.Info("Starting http based API", logging.String("addr", ip), logging.Int("port", port))
   104  
   105  	addr := net.JoinHostPort(ip, strconv.Itoa(port))
   106  
   107  	tlsConfig, fallback, err := gateway.GenerateTlsConfig(srv.cfg, srv.vegaPaths)
   108  	if err != nil {
   109  		return fmt.Errorf("problem with HTTPS configuration: %w", err)
   110  	}
   111  	srv.srv = &http.Server{
   112  		Addr:      addr,
   113  		Handler:   handlr,
   114  		TLSConfig: tlsConfig,
   115  	}
   116  
   117  	var fallbacksrv *http.Server
   118  	if srv.cfg.REST.Enabled || srv.cfg.GraphQL.Enabled {
   119  		eg.Go(func() error {
   120  			if srv.srv.TLSConfig != nil {
   121  				if fallback != nil {
   122  					eg.Go(func() error {
   123  						fallbacksrv = &http.Server{Addr: ":http", Handler: fallback}
   124  						// serve HTTP, which will redirect automatically to HTTPS
   125  						err := fallbacksrv.ListenAndServe()
   126  						if err != nil && err != http.ErrServerClosed {
   127  							return fmt.Errorf("failed start fallback http server: %w", err)
   128  						}
   129  						return nil
   130  					})
   131  				}
   132  				err = srv.srv.ListenAndServeTLS("", "")
   133  			} else {
   134  				srv.log.Warn("GraphQL server is not configured to use HTTPS, which is required for subscriptions to work. Please see README.md for help configuring")
   135  				err = srv.srv.ListenAndServe()
   136  			}
   137  			if err != nil && err != http.ErrServerClosed {
   138  				return fmt.Errorf("failed to listen and serve on graphQL server: %w", err)
   139  			}
   140  
   141  			return nil
   142  		})
   143  
   144  		eg.Go(func() error {
   145  			<-ctx.Done()
   146  			srv.stop()
   147  			if fallbacksrv != nil {
   148  				fallbacksrv.Shutdown(context.Background())
   149  			}
   150  			return nil
   151  		})
   152  	}
   153  
   154  	return eg.Wait()
   155  }
   156  
   157  // stop stops the server.
   158  func (srv *Server) stop() {
   159  	if srv.srv != nil {
   160  		srv.log.Info("stopping http based API")
   161  
   162  		if err := srv.srv.Shutdown(context.Background()); err != nil {
   163  			srv.log.Error("Failed to stop http based API cleanly",
   164  				logging.Error(err))
   165  		}
   166  	}
   167  }
   168  
   169  type Handler struct {
   170  	gqlPrefix   string
   171  	restHandler http.Handler
   172  	gqlHandler  http.Handler
   173  }
   174  
   175  func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   176  	if strings.HasPrefix(r.URL.Path, h.gqlPrefix) {
   177  		if h.gqlHandler != nil {
   178  			h.gqlHandler.ServeHTTP(w, r)
   179  			return
   180  		}
   181  	} else if h.restHandler != nil {
   182  		h.restHandler.ServeHTTP(w, r)
   183  		return
   184  	}
   185  
   186  	// cover for unknow routes, or disabled servers
   187  	http.NotFound(w, r)
   188  }