github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/server.go (about)

     1  // Copyright 2021 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"fmt"
    11  	"math"
    12  	"net"
    13  	"net/http"
    14  	"os"
    15  	"os/signal"
    16  	"runtime"
    17  	"strings"
    18  	"time"
    19  
    20  	"github.com/go-logr/logr"
    21  	"golang.org/x/net/http2"
    22  	"golang.org/x/net/http2/h2c"
    23  	"golang.org/x/net/trace"
    24  	"google.golang.org/grpc"
    25  	"google.golang.org/grpc/credentials/insecure"
    26  )
    27  
    28  // NewOSSignalContext tries to gracefully handle OS closure.
    29  func NewOSSignalContext(ctx context.Context) (context.Context, func()) {
    30  	// trap Ctrl+C and call cancel on the context
    31  	ctx, cancel := context.WithCancel(ctx)
    32  	c := make(chan os.Signal, 1)
    33  	signal.Notify(c, os.Interrupt)
    34  	go func() {
    35  		select {
    36  		case <-c:
    37  			cancel()
    38  		case <-ctx.Done():
    39  		}
    40  	}()
    41  
    42  	return ctx, func() {
    43  		signal.Stop(c)
    44  		cancel()
    45  	}
    46  }
    47  
// Server serves a Mux over HTTP, routing gRPC, gRPC-Web and plain HTTP
// requests to the appropriate handler. Construct one with NewServer, run it
// with Serve, and stop it with Shutdown or Close.
type Server struct {
	// opts is the configuration assembled from ServerOption values.
	opts serverOptions
	// mux is the Mux the server was built around.
	mux  *Mux

	// gs handles native gRPC requests through its ServeHTTP method.
	gs  *grpc.Server
	// hs is the net/http server that accepts connections in Serve.
	hs  *http.Server
	// h2s is the HTTP/2 server used by the h2c handler and wired into hs.
	h2s *http2.Server

	// events is nil unless grpc.EnableTracing was true when NewServer ran;
	// Shutdown finishes it and resets it to nil.
	events trace.EventLog
}
    58  
    59  // NewServer creates a new Proxy server.
    60  func NewServer(mux *Mux, opts ...ServerOption) (*Server, error) {
    61  	if mux == nil {
    62  		return nil, fmt.Errorf("invalid mux must not be nil")
    63  	}
    64  
    65  	var svrOpts serverOptions
    66  	for _, opt := range opts {
    67  		if err := opt(&svrOpts); err != nil {
    68  			return nil, err
    69  		}
    70  	}
    71  	if svrOpts.tlsConfig == nil && !svrOpts.insecure {
    72  		return nil, fmt.Errorf("credentials must be set")
    73  	}
    74  
    75  	svrOpts.serveMux = http.NewServeMux()
    76  	if len(svrOpts.muxPatterns) == 0 {
    77  		svrOpts.muxPatterns = []string{"/"}
    78  	}
    79  	for _, pattern := range svrOpts.muxPatterns {
    80  		prefix := strings.TrimSuffix(pattern, "/")
    81  		if len(prefix) > 0 {
    82  			svrOpts.serveMux.Handle(prefix+"/", http.StripPrefix(prefix, mux))
    83  		} else {
    84  			svrOpts.serveMux.Handle("/", mux)
    85  		}
    86  	}
    87  
    88  	// TODO: use our own flag?
    89  	// grpc.EnableTracing sets tracing for the golang.org/x/net/trace
    90  	var events trace.EventLog
    91  	if grpc.EnableTracing {
    92  		_, file, line, _ := runtime.Caller(1)
    93  		events = trace.NewEventLog("larking.Server", fmt.Sprintf("%s:%d", file, line))
    94  	}
    95  
    96  	var grpcOpts []grpc.ServerOption
    97  
    98  	grpcOpts = append(grpcOpts, grpc.UnknownServiceHandler(mux.StreamHandler()))
    99  	if i := mux.opts.unaryInterceptor; i != nil {
   100  		grpcOpts = append(grpcOpts, grpc.UnaryInterceptor(i))
   101  	}
   102  	if i := mux.opts.streamInterceptor; i != nil {
   103  		grpcOpts = append(grpcOpts, grpc.StreamInterceptor(i))
   104  	}
   105  	if h := mux.opts.statsHandler; h != nil {
   106  		grpcOpts = append(grpcOpts, grpc.StatsHandler(h))
   107  	}
   108  
   109  	// TLS termination controlled by listeners in Serve.
   110  	creds := insecure.NewCredentials()
   111  	grpcOpts = append(grpcOpts, grpc.Creds(creds))
   112  
   113  	gs := grpc.NewServer(grpcOpts...)
   114  	// Register local gRPC services
   115  	for sd, ss := range mux.services {
   116  		gs.RegisterService(sd, ss)
   117  	}
   118  	serveWeb := createGRPCWebHandler(gs)
   119  	index := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   120  		contentType := r.Header.Get("content-type")
   121  		if strings.HasPrefix(contentType, grpcWeb) {
   122  			serveWeb(w, r)
   123  		} else if r.ProtoMajor == 2 && strings.HasPrefix(contentType, grpcBase) {
   124  			gs.ServeHTTP(w, r)
   125  		} else {
   126  			svrOpts.serveMux.ServeHTTP(w, r)
   127  		}
   128  	})
   129  	h2s := &http2.Server{}
   130  	hs := &http.Server{
   131  		Handler:   h2c.NewHandler(index, h2s),
   132  		TLSConfig: svrOpts.tlsConfig,
   133  	}
   134  	if err := http2.ConfigureServer(hs, h2s); err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	return &Server{
   139  		opts:   svrOpts,
   140  		mux:    mux,
   141  		gs:     gs,
   142  		hs:     hs,
   143  		h2s:    h2s,
   144  		events: events,
   145  	}, nil
   146  }
   147  
   148  // Serve accepts incoming connections on the listener.
   149  // Serve will return always return a non-nil error, http.ErrServerClosed.
   150  func (s *Server) Serve(l net.Listener) error {
   151  	if config := s.opts.tlsConfig; config != nil {
   152  		l = tls.NewListener(l, config)
   153  	}
   154  	return s.hs.Serve(l)
   155  }
   156  
   157  func (s *Server) Shutdown(ctx context.Context) error {
   158  	if s.events != nil {
   159  		s.events.Finish()
   160  		s.events = nil
   161  	}
   162  	if err := s.hs.Shutdown(ctx); err != nil {
   163  		return err
   164  	}
   165  	return nil
   166  }
   167  
   168  func (s *Server) Close() error {
   169  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   170  	defer cancel()
   171  	return s.Shutdown(ctx)
   172  }
   173  
   174  const (
   175  	defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
   176  	defaultServerMaxSendMessageSize    = math.MaxInt32
   177  	defaultServerConnectionTimeout     = 120 * time.Second
   178  )
   179  
// serverOptions holds the configuration that NewServer assembles by applying
// each ServerOption in order.
type serverOptions struct {
	// tlsConfig, when non-nil, makes Serve wrap its listener with TLS.
	tlsConfig *tls.Config
	// insecure must be set (via InsecureServerOption) for NewServer to accept
	// a nil tlsConfig.
	insecure  bool
	// log is set by LogOption. NOTE(review): not read in this file.
	log       logr.Logger

	// muxPatterns are the ServeMux patterns the Mux is mounted under;
	// NewServer uses "/" when none are given.
	muxPatterns []string
	// serveMux routes requests that are neither gRPC nor gRPC-Web;
	// HTTPHandlerOption and NewServer register handlers on it.
	serveMux    *http.ServeMux
}
   188  
// ServerOption is similar to grpc.ServerOption. NewServer applies each option
// to its serverOptions in order; a non-nil error aborts NewServer and is
// returned to its caller.
type ServerOption func(*serverOptions) error
   191  
   192  func TLSCredsOption(c *tls.Config) ServerOption {
   193  	return func(opts *serverOptions) error {
   194  		opts.tlsConfig = c
   195  		return nil
   196  	}
   197  }
   198  
   199  func InsecureServerOption() ServerOption {
   200  	return func(opts *serverOptions) error {
   201  		opts.insecure = true
   202  		return nil
   203  	}
   204  }
   205  
   206  //func LarkingServerOption(threads map[string]string) ServerOption {
   207  //	return func(opts *serverOptions) error {
   208  //		opts.larkingEnabled = true
   209  //		opts.larkingThreads = threads
   210  //		return nil
   211  //	}
   212  //}
   213  
   214  func LogOption(log logr.Logger) ServerOption {
   215  	return func(opts *serverOptions) error {
   216  		opts.log = log
   217  		return nil
   218  	}
   219  }
   220  
   221  func MuxHandleOption(patterns ...string) ServerOption {
   222  	return func(opts *serverOptions) error {
   223  		if opts.muxPatterns != nil {
   224  			return fmt.Errorf("duplicate mux patterns registered")
   225  		}
   226  		opts.muxPatterns = patterns
   227  		return nil
   228  	}
   229  }
   230  
   231  func HTTPHandlerOption(pattern string, handler http.Handler) ServerOption {
   232  	return func(opts *serverOptions) error {
   233  		if opts.serveMux == nil {
   234  			opts.serveMux = http.NewServeMux()
   235  		}
   236  		opts.serveMux.Handle(pattern, handler)
   237  		return nil
   238  	}
   239  }
   240  
   241  //func AdminOption(addr string) ServerOption {
   242  //	return func(opts *serverOptions) {
   243  //
   244  //	}
   245  //}
   246  
   247  //func (s *Server) RegisterService(desc *grpc.ServiceDesc, impl interface{}) {
   248  //	s.gs.RegisterService(desc, impl)
   249  //	if s.opts.mux != nil {
   250  //		s.opts.mux.RegisterService(desc, impl)
   251  //	}
   252  //}