github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/framework/http/defaulthttpsrv.go (about)

     1  package ddhttp
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/arl/statsviz"
     7  	"github.com/ascarter/requestid"
     8  	"github.com/gorilla/handlers"
     9  	"github.com/gorilla/mux"
    10  	"github.com/klauspost/compress/gzhttp"
    11  	"github.com/olekukonko/tablewriter"
    12  	"github.com/rs/cors"
    13  	configui "github.com/unionj-cloud/go-doudou/framework/http/config"
    14  	"github.com/unionj-cloud/go-doudou/framework/http/model"
    15  	"github.com/unionj-cloud/go-doudou/framework/http/onlinedoc"
    16  	"github.com/unionj-cloud/go-doudou/framework/http/prometheus"
    17  	"github.com/unionj-cloud/go-doudou/framework/http/registry"
    18  	"github.com/unionj-cloud/go-doudou/framework/internal/banner"
    19  	"github.com/unionj-cloud/go-doudou/framework/internal/config"
    20  	"github.com/unionj-cloud/go-doudou/toolkit/cast"
    21  	"github.com/unionj-cloud/go-doudou/toolkit/stringutils"
    22  	logger "github.com/unionj-cloud/go-doudou/toolkit/zlogger"
    23  	"net/http"
    24  	"net/http/pprof"
    25  	"os"
    26  	"os/signal"
    27  	"path"
    28  	"strconv"
    29  	"strings"
    30  	"time"
    31  )
    32  
    33  var startAt time.Time
    34  
    35  func init() {
    36  	startAt = time.Now()
    37  }
    38  
    39  type common struct {
    40  	gddRoutes   []model.Route
    41  	debugRoutes []model.Route
    42  	bizRoutes   []model.Route
    43  	Middlewares []mux.MiddlewareFunc
    44  }
    45  
// DefaultHttpSrv wraps gorilla mux router.
//
// It holds two routers: the embedded one serves business routes, while
// rootRouter is the top-level router that is handed to http.Server.
type DefaultHttpSrv struct {
	// Router is the business subrouter; NewDefaultHttpSrv creates it from
	// rootRouter under the configured route root path.
	*mux.Router
	// rootRouter is the parent of every subrouter and the handler the http
	// server actually serves.
	rootRouter *mux.Router
	// common carries the route tables and the middleware chain.
	common
}
    52  
    53  const gddPathPrefix = "/go-doudou/"
    54  const debugPathPrefix = "/debug/"
    55  
    56  var contentTypeShouldbeGzip []string
    57  
    58  func init() {
    59  	contentTypeShouldbeGzip = []string{
    60  		"text/html",
    61  		"text/css",
    62  		"text/plain",
    63  		"text/xml",
    64  		"text/x-component",
    65  		"text/javascript",
    66  		"application/x-javascript",
    67  		"application/javascript",
    68  		"application/json",
    69  		"application/manifest+json",
    70  		"application/vnd.api+json",
    71  		"application/xml",
    72  		"application/xhtml+xml",
    73  		"application/rss+xml",
    74  		"application/atom+xml",
    75  		"application/vnd.ms-fontobject",
    76  		"application/x-font-ttf",
    77  		"application/x-font-opentype",
    78  		"application/x-font-truetype",
    79  		"image/svg+xml",
    80  		"image/x-icon",
    81  		"image/vnd.microsoft.icon",
    82  		"font/ttf",
    83  		"font/eot",
    84  		"font/otf",
    85  		"font/opentype",
    86  	}
    87  }
    88  
    89  // mux.MiddlewareFunc is alias for func(http.Handler) http.Handler
    90  func toMiddlewareFunc(m func(http.Handler) http.HandlerFunc) mux.MiddlewareFunc {
    91  	return func(handler http.Handler) http.Handler {
    92  		return m(handler)
    93  	}
    94  }
    95  
    96  // NewDefaultHttpSrv create a DefaultHttpSrv instance
    97  func NewDefaultHttpSrv() *DefaultHttpSrv {
    98  	rr := config.DefaultGddRouteRootPath
    99  	if stringutils.IsNotEmpty(config.GddRouteRootPath.Load()) {
   100  		rr = config.GddRouteRootPath.Load()
   101  	}
   102  	rootRouter := mux.NewRouter().StrictSlash(true)
   103  	srv := &DefaultHttpSrv{
   104  		Router:     rootRouter.PathPrefix(rr).Subrouter().StrictSlash(true),
   105  		rootRouter: rootRouter,
   106  	}
   107  	srv.Middlewares = append(srv.Middlewares,
   108  		tracing,
   109  		metrics,
   110  	)
   111  	if cast.ToBoolOrDefault(config.GddEnableResponseGzip.Load(), config.DefaultGddEnableResponseGzip) {
   112  		gzipMiddleware, err := gzhttp.NewWrapper(gzhttp.ContentTypes(contentTypeShouldbeGzip))
   113  		if err != nil {
   114  			panic(err)
   115  		}
   116  		srv.Middlewares = append(srv.Middlewares, toMiddlewareFunc(gzipMiddleware))
   117  	}
   118  	if cast.ToBoolOrDefault(config.GddLogReqEnable.Load(), config.DefaultGddLogReqEnable) {
   119  		srv.Middlewares = append(srv.Middlewares, log)
   120  	}
   121  	srv.Middlewares = append(srv.Middlewares,
   122  		requestid.RequestIDHandler,
   123  		handlers.ProxyHeaders,
   124  		fallbackContentType(config.GddFallbackContentType.LoadOrDefault(config.DefaultGddFallbackContentType)),
   125  	)
   126  	return srv
   127  }
   128  
   129  // AddRoute adds routes to router
   130  func (srv *DefaultHttpSrv) AddRoute(route ...model.Route) {
   131  	srv.bizRoutes = append(srv.bizRoutes, route...)
   132  }
   133  
   134  func (srv *common) printRoutes() {
   135  	if !config.CheckDev() {
   136  		return
   137  	}
   138  	logger.Info().Msg("================ Registered Routes ================")
   139  	data := [][]string{}
   140  	rr := config.DefaultGddRouteRootPath
   141  	if stringutils.IsNotEmpty(config.GddRouteRootPath.Load()) {
   142  		rr = config.GddRouteRootPath.Load()
   143  	}
   144  	var all []model.Route
   145  	all = append(all, srv.bizRoutes...)
   146  	all = append(all, srv.gddRoutes...)
   147  	all = append(all, srv.debugRoutes...)
   148  	for _, r := range all {
   149  		if strings.HasPrefix(r.Pattern, gddPathPrefix) || strings.HasPrefix(r.Pattern, debugPathPrefix) {
   150  			data = append(data, []string{r.Name, r.Method, r.Pattern})
   151  		} else {
   152  			data = append(data, []string{r.Name, r.Method, path.Clean(rr + r.Pattern)})
   153  		}
   154  	}
   155  
   156  	tableString := &strings.Builder{}
   157  	table := tablewriter.NewWriter(tableString)
   158  	table.SetHeader([]string{"Name", "Method", "Pattern"})
   159  	for _, v := range data {
   160  		table.Append(v)
   161  	}
   162  	table.Render() // Send output
   163  	rows := strings.Split(strings.TrimSpace(tableString.String()), "\n")
   164  	for _, row := range rows {
   165  		logger.Info().Msg(row)
   166  	}
   167  	logger.Info().Msg("===================================================")
   168  }
   169  
   170  // AddMiddleware adds middlewares to the end of chain
   171  func (srv *DefaultHttpSrv) AddMiddleware(mwf ...func(http.Handler) http.Handler) {
   172  	for _, item := range mwf {
   173  		srv.Middlewares = append(srv.Middlewares, item)
   174  	}
   175  }
   176  
   177  // PreMiddleware adds middlewares to the head of chain
   178  func (srv *DefaultHttpSrv) PreMiddleware(mwf ...func(http.Handler) http.Handler) {
   179  	var middlewares []mux.MiddlewareFunc
   180  	for _, item := range mwf {
   181  		middlewares = append(middlewares, item)
   182  	}
   183  	srv.Middlewares = append(middlewares, srv.Middlewares...)
   184  }
   185  
   186  func (srv *DefaultHttpSrv) newHttpServer() *http.Server {
   187  	write, err := time.ParseDuration(config.GddWriteTimeout.Load())
   188  	if err != nil {
   189  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddWriteTimeout),
   190  			config.GddWriteTimeout.Load(), err.Error(), config.DefaultGddWriteTimeout)
   191  		write, _ = time.ParseDuration(config.DefaultGddWriteTimeout)
   192  	}
   193  
   194  	read, err := time.ParseDuration(config.GddReadTimeout.Load())
   195  	if err != nil {
   196  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddReadTimeout),
   197  			config.GddReadTimeout.Load(), err.Error(), config.DefaultGddReadTimeout)
   198  		read, _ = time.ParseDuration(config.DefaultGddReadTimeout)
   199  	}
   200  
   201  	idle, err := time.ParseDuration(config.GddIdleTimeout.Load())
   202  	if err != nil {
   203  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddIdleTimeout),
   204  			config.GddIdleTimeout.Load(), err.Error(), config.DefaultGddIdleTimeout)
   205  		idle, _ = time.ParseDuration(config.DefaultGddIdleTimeout)
   206  	}
   207  
   208  	httpPort := strconv.Itoa(config.DefaultGddPort)
   209  	if _, err = cast.ToIntE(config.GddPort.Load()); err == nil {
   210  		httpPort = config.GddPort.Load()
   211  	}
   212  	httpHost := config.DefaultGddHost
   213  	if stringutils.IsNotEmpty(config.GddHost.Load()) {
   214  		httpHost = config.GddHost.Load()
   215  	}
   216  	httpServer := &http.Server{
   217  		Addr: strings.Join([]string{httpHost, httpPort}, ":"),
   218  		// Good practice to set timeouts to avoid Slowloris attacks.
   219  		WriteTimeout: write,
   220  		ReadTimeout:  read,
   221  		IdleTimeout:  idle,
   222  		Handler:      srv.rootRouter, // Pass our instance of gorilla/mux in.
   223  	}
   224  
   225  	// Run our server in a goroutine so that it doesn't block.
   226  	go func() {
   227  		logger.Info().Msgf("Http server is listening at %v", httpServer.Addr)
   228  		logger.Info().Msgf("Http server started in %s", time.Since(startAt))
   229  		if err := httpServer.ListenAndServe(); err != nil {
   230  			logger.Error().Err(err).Msg("")
   231  		}
   232  	}()
   233  
   234  	return httpServer
   235  }
   236  
   237  // Run runs http server
   238  func (srv *DefaultHttpSrv) Run() {
   239  	banner.Print()
   240  	manage := cast.ToBoolOrDefault(config.GddManage.Load(), config.DefaultGddManage)
   241  	if manage {
   242  		srv.Middlewares = append([]mux.MiddlewareFunc{prometheus.PrometheusMiddleware}, srv.Middlewares...)
   243  		gddRouter := srv.rootRouter.PathPrefix(gddPathPrefix).Subrouter().StrictSlash(true)
   244  		corsOpts := cors.New(cors.Options{
   245  			AllowedMethods: []string{
   246  				http.MethodGet,
   247  				http.MethodPost,
   248  				http.MethodPut,
   249  				http.MethodPatch,
   250  				http.MethodDelete,
   251  				http.MethodOptions,
   252  				http.MethodHead,
   253  			},
   254  
   255  			AllowedHeaders: []string{
   256  				"*",
   257  			},
   258  
   259  			AllowOriginRequestFunc: func(r *http.Request, origin string) bool {
   260  				if r.URL.Path == fmt.Sprintf("%sopenapi.json", gddPathPrefix) {
   261  					return true
   262  				}
   263  				return false
   264  			},
   265  		})
   266  		gddRouter.Use(metrics)
   267  		gddRouter.Use(corsOpts.Handler)
   268  		gddRouter.Use(basicAuth())
   269  		srv.gddRoutes = append(srv.gddRoutes, onlinedoc.Routes()...)
   270  		srv.gddRoutes = append(srv.gddRoutes, prometheus.Routes()...)
   271  		srv.gddRoutes = append(srv.gddRoutes, registry.Routes()...)
   272  		srv.gddRoutes = append(srv.gddRoutes, configui.Routes()...)
   273  		for _, item := range srv.gddRoutes {
   274  			gddRouter.
   275  				Methods(item.Method, http.MethodOptions).
   276  				Path("/" + strings.TrimPrefix(item.Pattern, gddPathPrefix)).
   277  				Name(item.Name).
   278  				Handler(item.HandlerFunc)
   279  		}
   280  		freq, err := time.ParseDuration(config.GddStatsFreq.Load())
   281  		if err != nil {
   282  			logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddStatsFreq),
   283  				config.GddStatsFreq.Load(), err.Error(), config.DefaultGddStatsFreq)
   284  			freq, _ = time.ParseDuration(config.DefaultGddStatsFreq)
   285  		}
   286  		srv.gddRoutes = append(srv.gddRoutes, []model.Route{
   287  			{
   288  				Name:    "GetStatsvizWs",
   289  				Method:  "GET",
   290  				Pattern: gddPathPrefix + "statsviz/ws",
   291  			},
   292  			{
   293  				Name:    "GetStatsviz",
   294  				Method:  "GET",
   295  				Pattern: gddPathPrefix + "statsviz/",
   296  			},
   297  		}...)
   298  		gddRouter.
   299  			Methods(http.MethodGet).
   300  			Path("/statsviz/ws").
   301  			Name("GetStatsvizWs").
   302  			HandlerFunc(statsviz.NewWsHandler(freq))
   303  		gddRouter.
   304  			Methods(http.MethodGet).
   305  			PathPrefix("/statsviz/").
   306  			Name("GetStatsviz").
   307  			Handler(statsviz.IndexAtRoot(gddPathPrefix + "statsviz/"))
   308  		srv.debugRoutes = append(srv.debugRoutes, []model.Route{
   309  			{
   310  				Name:    "GetDebugPprofCmdline",
   311  				Method:  "GET",
   312  				Pattern: debugPathPrefix + "pprof/cmdline",
   313  			},
   314  			{
   315  				Name:    "GetDebugPprofProfile",
   316  				Method:  "GET",
   317  				Pattern: debugPathPrefix + "pprof/profile",
   318  			},
   319  			{
   320  				Name:    "GetDebugPprofSymbol",
   321  				Method:  "GET",
   322  				Pattern: debugPathPrefix + "pprof/symbol",
   323  			},
   324  			{
   325  				Name:    "GetDebugPprofTrace",
   326  				Method:  "GET",
   327  				Pattern: debugPathPrefix + "pprof/trace",
   328  			},
   329  			{
   330  				Name:    "GetDebugPprofIndex",
   331  				Method:  "GET",
   332  				Pattern: debugPathPrefix + "pprof/",
   333  			},
   334  		}...)
   335  		debugRouter := srv.rootRouter.PathPrefix(debugPathPrefix).Subrouter().StrictSlash(true)
   336  		debugRouter.Use(metrics)
   337  		debugRouter.Use(corsOpts.Handler)
   338  		debugRouter.Use(basicAuth())
   339  		debugRouter.Methods(http.MethodGet).Path("/pprof/cmdline").Name("GetDebugPprofCmdline").HandlerFunc(pprof.Cmdline)
   340  		debugRouter.Methods(http.MethodGet).Path("/pprof/profile").Name("GetDebugPprofProfile").HandlerFunc(pprof.Profile)
   341  		debugRouter.Methods(http.MethodGet).Path("/pprof/symbol").Name("GetDebugPprofSymbol").HandlerFunc(pprof.Symbol)
   342  		debugRouter.Methods(http.MethodGet).Path("/pprof/trace").Name("GetDebugPprofTrace").HandlerFunc(pprof.Trace)
   343  		debugRouter.Methods(http.MethodGet).PathPrefix("/pprof/").Name("GetDebugPprofIndex").HandlerFunc(pprof.Index)
   344  	}
   345  	srv.Middlewares = append(srv.Middlewares, recovery)
   346  	srv.Use(srv.Middlewares...)
   347  	for _, item := range srv.bizRoutes {
   348  		srv.
   349  			Methods(item.Method, http.MethodOptions).
   350  			Path(item.Pattern).
   351  			Name(item.Name).
   352  			Handler(item.HandlerFunc)
   353  	}
   354  	srv.rootRouter.NotFoundHandler = srv.rootRouter.NewRoute().BuildOnly().HandlerFunc(http.NotFound).GetHandler()
   355  	srv.rootRouter.MethodNotAllowedHandler = srv.rootRouter.NewRoute().BuildOnly().HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   356  		w.WriteHeader(http.StatusMethodNotAllowed)
   357  		w.Write([]byte("405 method not allowed"))
   358  	}).GetHandler()
   359  	for i := len(srv.Middlewares) - 1; i >= 0; i-- {
   360  		srv.rootRouter.NotFoundHandler = srv.Middlewares[i].Middleware(srv.rootRouter.NotFoundHandler)
   361  		srv.rootRouter.MethodNotAllowedHandler = srv.Middlewares[i].Middleware(srv.rootRouter.MethodNotAllowedHandler)
   362  	}
   363  	srv.printRoutes()
   364  	httpServer := srv.newHttpServer()
   365  	defer func() {
   366  		grace, err := time.ParseDuration(config.GddGraceTimeout.Load())
   367  		if err != nil {
   368  			logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddGraceTimeout),
   369  				config.GddGraceTimeout.Load(), err.Error(), config.DefaultGddGraceTimeout)
   370  			grace, _ = time.ParseDuration(config.DefaultGddGraceTimeout)
   371  		}
   372  		logger.Info().Msgf("Http server is gracefully shutting down in %s", grace)
   373  
   374  		ctx, cancel := context.WithTimeout(context.Background(), grace)
   375  		defer cancel()
   376  		// Doesn't block if no connections, but will otherwise wait
   377  		// until the timeout deadline.
   378  		httpServer.Shutdown(ctx)
   379  	}()
   380  
   381  	c := make(chan os.Signal, 1)
   382  	// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
   383  	// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
   384  	signal.Notify(c, os.Interrupt)
   385  
   386  	// Block until we receive our signal.
   387  	<-c
   388  }