github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/remotesrv/server.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package remotesrv
    16  
import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"

	"github.com/sirupsen/logrus"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"
	"google.golang.org/grpc"

	remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
	"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
    33  
// Server serves the remotesapi ChunkStoreService over gRPC together with an
// HTTP file handler. The two run either on separate listeners or, when their
// listen addresses are equal, multiplexed on a single HTTP listener.
// Construct it with NewServer, obtain listeners with Listeners, run it with
// Serve, and stop it with GracefulStop.
type Server struct {
	// wg counts the goroutines that Serve starts. NewServer adds the expected
	// count up front; Serve and GracefulStop both wait on it.
	wg       sync.WaitGroup
	// stopChan is closed by GracefulStop to tell Serve's shutdown goroutines
	// to stop the servers.
	stopChan chan struct{}

	// Addresses passed to net.Listen / tls.Listen by Listeners. If they are
	// equal, only one (HTTP) listener is created.
	grpcListenAddr string
	httpListenAddr string

	// grpcSrv hosts the ChunkStoreService. In the multiplexed case it is
	// reached through grpcSrv.ServeHTTP from the HTTP handler.
	grpcSrv *grpc.Server
	// httpSrv serves the file handler (possibly wrapped by an interceptor
	// and/or the gRPC multiplexer).
	httpSrv http.Server

	// grpcHttpReqsWG tracks gRPC requests currently being served through the
	// multiplexed HTTP handler so shutdown can wait for them to finish.
	grpcHttpReqsWG sync.WaitGroup

	// tlsConfig, when non-nil, makes Listeners create TLS listeners.
	tlsConfig *tls.Config
}
    48  
    49  func (s *Server) GracefulStop() {
    50  	close(s.stopChan)
    51  	s.wg.Wait()
    52  }
    53  
// ServerArgs configures NewServer.
type ServerArgs struct {
	// Logger is handed to the chunk store service and the file handler. If
	// nil, NewServer substitutes an entry on logrus's standard logger.
	Logger   *logrus.Entry
	// HttpHost is passed to NewHttpFSBackedChunkStore; presumably the host
	// used when building URLs returned to clients — TODO confirm.
	HttpHost string

	// HttpListenAddr and GrpcListenAddr are the TCP addresses the HTTP and
	// gRPC servers listen on. When they are equal, a single listener is used
	// and gRPC requests are routed to the gRPC server from the HTTP handler.
	HttpListenAddr string
	GrpcListenAddr string

	// FS and DBCache are shared by the chunk store service and the file
	// handler.
	FS       filesys.Filesys
	DBCache  DBCache
	// ReadOnly wraps the chunk store service in ReadOnlyChunkStore and is
	// also passed to the file handler.
	ReadOnly bool
	// Options are appended after a default grpc.MaxRecvMsgSize(128 MiB)
	// option when the gRPC server is constructed.
	Options  []grpc.ServerOption

	// ConcurrencyControl is passed to NewHttpFSBackedChunkStore.
	ConcurrencyControl remotesapi.PushConcurrencyControl

	// HttpInterceptor, if non-nil, wraps the HTTP file handler. In the
	// multiplexed case gRPC requests are routed before reaching the wrapped
	// handler, so they do not pass through the interceptor.
	HttpInterceptor func(http.Handler) http.Handler

	// If supplied, the listener(s) returned from Listeners() will be TLS
	// listeners. The scheme used in the URLs returned from the gRPC server
	// will be https.
	TLSConfig *tls.Config
}
    75  
    76  func NewServer(args ServerArgs) (*Server, error) {
    77  	if args.Logger == nil {
    78  		args.Logger = logrus.NewEntry(logrus.StandardLogger())
    79  	}
    80  
    81  	s := new(Server)
    82  	s.stopChan = make(chan struct{})
    83  
    84  	sealer, err := NewSingleSymmetricKeySealer()
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	scheme := "http"
    90  	if args.TLSConfig != nil {
    91  		scheme = "https"
    92  	}
    93  	s.tlsConfig = args.TLSConfig
    94  
    95  	s.wg.Add(2)
    96  	s.grpcListenAddr = args.GrpcListenAddr
    97  	s.grpcSrv = grpc.NewServer(append([]grpc.ServerOption{grpc.MaxRecvMsgSize(128 * 1024 * 1024)}, args.Options...)...)
    98  	var chnkSt remotesapi.ChunkStoreServiceServer = NewHttpFSBackedChunkStore(args.Logger, args.HttpHost, args.DBCache, args.FS, scheme, args.ConcurrencyControl, sealer)
    99  	if args.ReadOnly {
   100  		chnkSt = ReadOnlyChunkStore{chnkSt}
   101  	}
   102  	remotesapi.RegisterChunkStoreServiceServer(s.grpcSrv, chnkSt)
   103  
   104  	var handler http.Handler = newFileHandler(args.Logger, args.DBCache, args.FS, args.ReadOnly, sealer)
   105  	if args.HttpInterceptor != nil {
   106  		handler = args.HttpInterceptor(handler)
   107  	}
   108  	if args.HttpListenAddr == args.GrpcListenAddr {
   109  		handler = s.grpcMultiplexHandler(s.grpcSrv, handler)
   110  	} else {
   111  		s.wg.Add(2)
   112  	}
   113  
   114  	s.httpListenAddr = args.HttpListenAddr
   115  	s.httpSrv = http.Server{
   116  		Addr:    args.HttpListenAddr,
   117  		Handler: handler,
   118  	}
   119  
   120  	return s, nil
   121  }
   122  
   123  func (s *Server) grpcMultiplexHandler(grpcSrv *grpc.Server, handler http.Handler) http.Handler {
   124  	h2s := &http2.Server{}
   125  	newHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   126  		if r.ProtoMajor == 2 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") {
   127  			s.grpcHttpReqsWG.Add(1)
   128  			defer s.grpcHttpReqsWG.Done()
   129  			grpcSrv.ServeHTTP(w, r)
   130  		} else {
   131  			handler.ServeHTTP(w, r)
   132  		}
   133  	})
   134  	return h2c.NewHandler(newHandler, h2s)
   135  }
   136  
   137  type Listeners struct {
   138  	http net.Listener
   139  	grpc net.Listener
   140  }
   141  
   142  func (l Listeners) Close() error {
   143  	if l.http != nil {
   144  		err := l.http.Close()
   145  		if err != nil {
   146  			if l.grpc != nil {
   147  				l.grpc.Close()
   148  			}
   149  			return err
   150  		}
   151  	}
   152  	if l.grpc != nil {
   153  		return l.grpc.Close()
   154  	}
   155  	return nil
   156  }
   157  
   158  func (s *Server) Listeners() (Listeners, error) {
   159  	var httpListener net.Listener
   160  	var grpcListener net.Listener
   161  	var err error
   162  	if s.tlsConfig != nil {
   163  		httpListener, err = tls.Listen("tcp", s.httpListenAddr, s.tlsConfig)
   164  	} else {
   165  		httpListener, err = net.Listen("tcp", s.httpListenAddr)
   166  	}
   167  	if err != nil {
   168  		return Listeners{}, err
   169  	}
   170  	if s.httpListenAddr == s.grpcListenAddr {
   171  		return Listeners{http: httpListener}, nil
   172  	}
   173  	if s.tlsConfig != nil {
   174  		grpcListener, err = tls.Listen("tcp", s.grpcListenAddr, s.tlsConfig)
   175  	} else {
   176  		grpcListener, err = net.Listen("tcp", s.grpcListenAddr)
   177  	}
   178  	if err != nil {
   179  		httpListener.Close()
   180  		return Listeners{}, err
   181  	}
   182  	return Listeners{http: httpListener, grpc: grpcListener}, nil
   183  }
   184  
// GrpcServer returns the gRPC server backing s. It can be used to register
// additional services on the server, and should only be accessed before
// Serve is called.
func (s *Server) GrpcServer() *grpc.Server {
	return s.grpcSrv
}
   190  
   191  func (s *Server) Serve(listeners Listeners) {
   192  	if listeners.grpc != nil {
   193  		go func() {
   194  			defer s.wg.Done()
   195  			logrus.Println("Starting grpc server on", s.grpcListenAddr)
   196  			err := s.grpcSrv.Serve(listeners.grpc)
   197  			logrus.Println("grpc server exited. error:", err)
   198  		}()
   199  		go func() {
   200  			defer s.wg.Done()
   201  			<-s.stopChan
   202  			logrus.Traceln("Calling grpcSrv.GracefulStop")
   203  			s.grpcSrv.GracefulStop()
   204  			logrus.Traceln("Finished calling grpcSrv.GracefulStop")
   205  		}()
   206  	}
   207  
   208  	go func() {
   209  		defer s.wg.Done()
   210  		logrus.Println("Starting http server on", s.httpListenAddr)
   211  		err := s.httpSrv.Serve(listeners.http)
   212  		logrus.Println("http server exited. exit error:", err)
   213  	}()
   214  	go func() {
   215  		defer s.wg.Done()
   216  		<-s.stopChan
   217  		logrus.Traceln("Calling httpSrv.Shutdown")
   218  		s.httpSrv.Shutdown(context.Background())
   219  		logrus.Traceln("Finished calling httpSrv.Shutdown")
   220  
   221  		// If we are multiplexing HTTP and gRPC requests on the same
   222  		// listener, we need to stop the gRPC server here as well. We
   223  		// cannot stop it gracefully, but if we stop it forcefully
   224  		// here, we guarantee all the handler threads are cleaned up
   225  		// before we return.
   226  		if listeners.grpc == nil {
   227  			logrus.Traceln("Calling grpcSrv.Stop")
   228  			s.grpcSrv.Stop()
   229  			s.grpcHttpReqsWG.Wait()
   230  			logrus.Traceln("Finished calling grpcSrv.Stop")
   231  		}
   232  	}()
   233  
   234  	s.wg.Wait()
   235  }