github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/common/fabhttp/server.go (about)

     1  /*
     2  Copyright hechain All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package fabhttp
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"net"
    13  	"net/http"
    14  	"os"
    15  	"time"
    16  
    17  	"github.com/hechain20/hechain/common/flogging"
    18  	"github.com/hechain20/hechain/common/util"
    19  	"github.com/hechain20/hechain/core/middleware"
    20  )
    21  
    22  //go:generate counterfeiter -o fakes/logger.go -fake-name Logger . Logger
    23  
    24  type Logger interface {
    25  	Warn(args ...interface{})
    26  	Warnf(template string, args ...interface{})
    27  }
    28  
// Options configures a Server created by NewServer.
type Options struct {
	// Logger receives the server's warnings. If nil, NewServer substitutes
	// flogging.MustGetLogger("fabhttp").
	Logger Logger

	// ListenAddress is the TCP address Listen passes to net.Listen; it is
	// also recorded as the http.Server's Addr.
	ListenAddress string

	// TLS controls transport security. When TLS.Config() returns a non-nil
	// *tls.Config, Listen wraps the TCP listener with tls.NewListener; when it
	// returns nil, the server listens in plain text.
	TLS TLS
}
    34  
// Server is an HTTP server whose handlers are registered on an internal
// ServeMux, each wrapped in the middleware chain built by HandlerChain.
type Server struct {
	// logger is Options.Logger, or the flogging fallback chosen by NewServer.
	logger Logger

	// options is the configuration passed to NewServer; nothing in this file
	// modifies it afterwards.
	options Options

	// httpServer serves mux; it is created by initializeServer.
	httpServer *http.Server

	// mux is the handler registry that RegisterHandler adds to.
	mux *http.ServeMux

	// addr is the address the listener actually bound to. Start sets it; it
	// is empty until then.
	addr string
}
    42  
    43  func NewServer(o Options) *Server {
    44  	logger := o.Logger
    45  	if logger == nil {
    46  		logger = flogging.MustGetLogger("fabhttp")
    47  	}
    48  
    49  	server := &Server{
    50  		logger:  logger,
    51  		options: o,
    52  	}
    53  
    54  	server.initializeServer()
    55  
    56  	return server
    57  }
    58  
    59  func (s *Server) Run(signals <-chan os.Signal, ready chan<- struct{}) error {
    60  	err := s.Start()
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	close(ready)
    66  
    67  	<-signals
    68  	return s.Stop()
    69  }
    70  
    71  func (s *Server) Start() error {
    72  	listener, err := s.Listen()
    73  	if err != nil {
    74  		return err
    75  	}
    76  	s.addr = listener.Addr().String()
    77  
    78  	go s.httpServer.Serve(listener)
    79  
    80  	return nil
    81  }
    82  
    83  func (s *Server) Stop() error {
    84  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    85  	defer cancel()
    86  
    87  	return s.httpServer.Shutdown(ctx)
    88  }
    89  
    90  func (s *Server) initializeServer() {
    91  	s.mux = http.NewServeMux()
    92  	s.httpServer = &http.Server{
    93  		Addr:         s.options.ListenAddress,
    94  		Handler:      s.mux,
    95  		ReadTimeout:  10 * time.Second,
    96  		WriteTimeout: 2 * time.Minute,
    97  	}
    98  }
    99  
   100  func (s *Server) HandlerChain(h http.Handler, secure bool) http.Handler {
   101  	if secure {
   102  		return middleware.NewChain(middleware.RequireCert(), middleware.WithRequestID(util.GenerateUUID)).Handler(h)
   103  	}
   104  	return middleware.NewChain(middleware.WithRequestID(util.GenerateUUID)).Handler(h)
   105  }
   106  
   107  // RegisterHandler registers into the ServeMux a handler chain that borrows
   108  // its security properties from the fabhttp.Server. This method is thread
   109  // safe because ServeMux.Handle() is thread safe, and options are immutable.
   110  // This method can be called either before or after Server.Start(). If the
   111  // pattern exists the method panics.
   112  func (s *Server) RegisterHandler(pattern string, handler http.Handler, secure bool) {
   113  	s.mux.Handle(
   114  		pattern,
   115  		s.HandlerChain(
   116  			handler,
   117  			secure,
   118  		),
   119  	)
   120  }
   121  
   122  func (s *Server) Listen() (net.Listener, error) {
   123  	listener, err := net.Listen("tcp", s.options.ListenAddress)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	tlsConfig, err := s.options.TLS.Config()
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	if tlsConfig != nil {
   132  		listener = tls.NewListener(listener, tlsConfig)
   133  	}
   134  	return listener, nil
   135  }
   136  
   137  func (s *Server) Addr() string {
   138  	return s.addr
   139  }
   140  
   141  func (s *Server) Log(keyvals ...interface{}) error {
   142  	s.logger.Warn(keyvals...)
   143  	return nil
   144  }