github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/framework/http/httproutersrv.go (about)

     1  package ddhttp
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/arl/statsviz"
     7  	"github.com/ascarter/requestid"
     8  	"github.com/gorilla/handlers"
     9  	"github.com/gorilla/mux"
    10  	"github.com/klauspost/compress/gzhttp"
    11  	"github.com/rs/cors"
    12  	configui "github.com/unionj-cloud/go-doudou/framework/http/config"
    13  	"github.com/unionj-cloud/go-doudou/framework/http/httprouter"
    14  	"github.com/unionj-cloud/go-doudou/framework/http/model"
    15  	"github.com/unionj-cloud/go-doudou/framework/http/onlinedoc"
    16  	"github.com/unionj-cloud/go-doudou/framework/http/prometheus"
    17  	"github.com/unionj-cloud/go-doudou/framework/http/registry"
    18  	"github.com/unionj-cloud/go-doudou/framework/internal/banner"
    19  	"github.com/unionj-cloud/go-doudou/framework/internal/config"
    20  	"github.com/unionj-cloud/go-doudou/toolkit/cast"
    21  	"github.com/unionj-cloud/go-doudou/toolkit/stringutils"
    22  	logger "github.com/unionj-cloud/go-doudou/toolkit/zlogger"
    23  	"net/http"
    24  	"net/http/pprof"
    25  	"os"
    26  	"os/signal"
    27  	"strconv"
    28  	"strings"
    29  	"time"
    30  )
    31  
    32  // HttpRouterSrv wraps httpRouter router
    33  type HttpRouterSrv struct {
    34  	Router     *httprouter.RouteGroup
    35  	rootRouter *httprouter.Router
    36  	common
    37  }
    38  
    39  // NewHttpRouterSrv create a HttpRouterSrv instance
    40  func NewHttpRouterSrv() *HttpRouterSrv {
    41  	rr := config.DefaultGddRouteRootPath
    42  	if stringutils.IsNotEmpty(config.GddRouteRootPath.Load()) {
    43  		rr = config.GddRouteRootPath.Load()
    44  	}
    45  	if stringutils.IsEmpty(rr) {
    46  		rr = "/"
    47  	}
    48  	rootRouter := httprouter.New()
    49  	rootRouter.SaveMatchedRoutePath = cast.ToBoolOrDefault(config.GddRouterSaveMatchedRoutePath.Load(), config.DefaultGddRouterSaveMatchedRoutePath)
    50  	srv := &HttpRouterSrv{
    51  		Router:     rootRouter.NewGroup(rr),
    52  		rootRouter: rootRouter,
    53  	}
    54  	srv.Middlewares = append(srv.Middlewares,
    55  		tracing,
    56  		metrics,
    57  	)
    58  	if cast.ToBoolOrDefault(config.GddEnableResponseGzip.Load(), config.DefaultGddEnableResponseGzip) {
    59  		gzipMiddleware, err := gzhttp.NewWrapper(gzhttp.ContentTypes(contentTypeShouldbeGzip))
    60  		if err != nil {
    61  			panic(err)
    62  		}
    63  		srv.Middlewares = append(srv.Middlewares, toMiddlewareFunc(gzipMiddleware))
    64  	}
    65  	if cast.ToBoolOrDefault(config.GddLogReqEnable.Load(), config.DefaultGddLogReqEnable) {
    66  		srv.Middlewares = append(srv.Middlewares, log)
    67  	}
    68  	srv.Middlewares = append(srv.Middlewares,
    69  		requestid.RequestIDHandler,
    70  		handlers.ProxyHeaders,
    71  		fallbackContentType(config.GddFallbackContentType.LoadOrDefault(config.DefaultGddFallbackContentType)),
    72  	)
    73  	return srv
    74  }
    75  
    76  // AddRoute adds routes to router
    77  func (srv *HttpRouterSrv) AddRoute(route ...model.Route) {
    78  	srv.bizRoutes = append(srv.bizRoutes, route...)
    79  }
    80  
    81  // AddMiddleware adds middlewares to the end of chain
    82  func (srv *HttpRouterSrv) AddMiddleware(mwf ...func(http.Handler) http.Handler) {
    83  	for _, item := range mwf {
    84  		srv.Middlewares = append(srv.Middlewares, item)
    85  	}
    86  }
    87  
    88  // PreMiddleware adds middlewares to the head of chain
    89  func (srv *HttpRouterSrv) PreMiddleware(mwf ...func(http.Handler) http.Handler) {
    90  	var middlewares []mux.MiddlewareFunc
    91  	for _, item := range mwf {
    92  		middlewares = append(middlewares, item)
    93  	}
    94  	srv.Middlewares = append(middlewares, srv.Middlewares...)
    95  }
    96  
    97  // RootRouter returns pointer type of httprouter.Router for directly putting into http.ListenAndServe as http.Handler implementation
    98  func (srv *HttpRouterSrv) RootRouter() *httprouter.Router {
    99  	return srv.rootRouter
   100  }
   101  
   102  func (srv *HttpRouterSrv) newHttpServer() *http.Server {
   103  	write, err := time.ParseDuration(config.GddWriteTimeout.Load())
   104  	if err != nil {
   105  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddWriteTimeout),
   106  			config.GddWriteTimeout.Load(), err.Error(), config.DefaultGddWriteTimeout)
   107  		write, _ = time.ParseDuration(config.DefaultGddWriteTimeout)
   108  	}
   109  
   110  	read, err := time.ParseDuration(config.GddReadTimeout.Load())
   111  	if err != nil {
   112  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddReadTimeout),
   113  			config.GddReadTimeout.Load(), err.Error(), config.DefaultGddReadTimeout)
   114  		read, _ = time.ParseDuration(config.DefaultGddReadTimeout)
   115  	}
   116  
   117  	idle, err := time.ParseDuration(config.GddIdleTimeout.Load())
   118  	if err != nil {
   119  		logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddIdleTimeout),
   120  			config.GddIdleTimeout.Load(), err.Error(), config.DefaultGddIdleTimeout)
   121  		idle, _ = time.ParseDuration(config.DefaultGddIdleTimeout)
   122  	}
   123  
   124  	httpPort := strconv.Itoa(config.DefaultGddPort)
   125  	if _, err = cast.ToIntE(config.GddPort.Load()); err == nil {
   126  		httpPort = config.GddPort.Load()
   127  	}
   128  	httpHost := config.DefaultGddHost
   129  	if stringutils.IsNotEmpty(config.GddHost.Load()) {
   130  		httpHost = config.GddHost.Load()
   131  	}
   132  	httpServer := &http.Server{
   133  		Addr: strings.Join([]string{httpHost, httpPort}, ":"),
   134  		// Good practice to set timeouts to avoid Slowloris attacks.
   135  		WriteTimeout: write,
   136  		ReadTimeout:  read,
   137  		IdleTimeout:  idle,
   138  		Handler:      srv.rootRouter, // Pass our instance of httprouter.Router in.
   139  	}
   140  
   141  	// Run our server in a goroutine so that it doesn't block.
   142  	go func() {
   143  		logger.Info().Msgf("Http server is listening at %v", httpServer.Addr)
   144  		logger.Info().Msgf("Http server started in %s", time.Since(startAt))
   145  		if err := httpServer.ListenAndServe(); err != nil {
   146  			logger.Error().Err(err).Msg("")
   147  		}
   148  	}()
   149  
   150  	return httpServer
   151  }
   152  
   153  // Run runs http server
   154  func (srv *HttpRouterSrv) Run() {
   155  	banner.Print()
   156  	manage := cast.ToBoolOrDefault(config.GddManage.Load(), config.DefaultGddManage)
   157  	if manage {
   158  		srv.Middlewares = append([]mux.MiddlewareFunc{prometheus.PrometheusMiddleware}, srv.Middlewares...)
   159  		gddRouter := srv.rootRouter.NewGroup(gddPathPrefix)
   160  		corsOpts := cors.New(cors.Options{
   161  			AllowedMethods: []string{
   162  				http.MethodGet,
   163  				http.MethodPost,
   164  				http.MethodPut,
   165  				http.MethodPatch,
   166  				http.MethodDelete,
   167  				http.MethodOptions,
   168  				http.MethodHead,
   169  			},
   170  
   171  			AllowedHeaders: []string{
   172  				"*",
   173  			},
   174  
   175  			AllowOriginRequestFunc: func(r *http.Request, origin string) bool {
   176  				if r.URL.Path == fmt.Sprintf("%sopenapi.json", gddPathPrefix) {
   177  					return true
   178  				}
   179  				return false
   180  			},
   181  		})
   182  		basicAuthMiddle := mux.MiddlewareFunc(basicAuth())
   183  		gddMiddlewares := []mux.MiddlewareFunc{metrics, corsOpts.Handler, basicAuthMiddle}
   184  		srv.gddRoutes = append(srv.gddRoutes, onlinedoc.Routes()...)
   185  		srv.gddRoutes = append(srv.gddRoutes, prometheus.Routes()...)
   186  		srv.gddRoutes = append(srv.gddRoutes, registry.Routes()...)
   187  		srv.gddRoutes = append(srv.gddRoutes, configui.Routes()...)
   188  		freq, err := time.ParseDuration(config.GddStatsFreq.Load())
   189  		if err != nil {
   190  			logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddStatsFreq),
   191  				config.GddStatsFreq.Load(), err.Error(), config.DefaultGddStatsFreq)
   192  			freq, _ = time.ParseDuration(config.DefaultGddStatsFreq)
   193  		}
   194  		_ = freq
   195  		srv.gddRoutes = append(srv.gddRoutes, []model.Route{
   196  			{
   197  				Name:    "GetStatsvizWs",
   198  				Method:  http.MethodGet,
   199  				Pattern: gddPathPrefix + "statsviz/ws",
   200  			},
   201  			{
   202  				Name:    "GetStatsviz",
   203  				Method:  http.MethodGet,
   204  				Pattern: gddPathPrefix + "statsviz/*",
   205  				HandlerFunc: func(writer http.ResponseWriter, request *http.Request) {
   206  					if strings.HasSuffix(request.URL.Path, "/ws") {
   207  						statsviz.Ws(writer, request)
   208  						return
   209  					}
   210  					statsviz.IndexAtRoot(gddPathPrefix+"statsviz/").ServeHTTP(writer, request)
   211  				},
   212  			},
   213  		}...)
   214  		for _, item := range srv.gddRoutes {
   215  			if item.HandlerFunc == nil {
   216  				continue
   217  			}
   218  			h := http.Handler(item.HandlerFunc)
   219  			for i := len(gddMiddlewares) - 1; i >= 0; i-- {
   220  				h = gddMiddlewares[i].Middleware(h)
   221  			}
   222  			gddRouter.Handler(item.Method, "/"+strings.TrimPrefix(item.Pattern, gddPathPrefix), h, item.Name)
   223  		}
   224  		srv.debugRoutes = append(srv.debugRoutes, []model.Route{
   225  			{
   226  				Name:    "GetDebugPprofCmdline",
   227  				Method:  http.MethodGet,
   228  				Pattern: debugPathPrefix + "pprof/cmdline",
   229  			},
   230  			{
   231  				Name:    "GetDebugPprofProfile",
   232  				Method:  http.MethodGet,
   233  				Pattern: debugPathPrefix + "pprof/profile",
   234  			},
   235  			{
   236  				Name:    "GetDebugPprofSymbol",
   237  				Method:  http.MethodGet,
   238  				Pattern: debugPathPrefix + "pprof/symbol",
   239  			},
   240  			{
   241  				Name:    "GetDebugPprofTrace",
   242  				Method:  http.MethodGet,
   243  				Pattern: debugPathPrefix + "pprof/trace",
   244  			},
   245  			{
   246  				Name:    "GetDebugPprofIndex",
   247  				Method:  http.MethodGet,
   248  				Pattern: debugPathPrefix + "pprof/*",
   249  				HandlerFunc: func(writer http.ResponseWriter, request *http.Request) {
   250  					lastSegment := request.URL.Path[strings.LastIndex(request.URL.Path, "/"):]
   251  					switch lastSegment {
   252  					case "/cmdline":
   253  						pprof.Cmdline(writer, request)
   254  						return
   255  					case "/profile":
   256  						pprof.Profile(writer, request)
   257  						return
   258  					case "/symbol":
   259  						pprof.Symbol(writer, request)
   260  						return
   261  					case "/trace":
   262  						pprof.Trace(writer, request)
   263  						return
   264  					}
   265  					pprof.Index(writer, request)
   266  				},
   267  			},
   268  		}...)
   269  		debugRouter := srv.rootRouter.NewGroup(debugPathPrefix)
   270  		for _, item := range srv.debugRoutes {
   271  			if item.HandlerFunc == nil {
   272  				continue
   273  			}
   274  			h := http.Handler(item.HandlerFunc)
   275  			for i := len(gddMiddlewares) - 1; i >= 0; i-- {
   276  				h = gddMiddlewares[i].Middleware(h)
   277  			}
   278  			debugRouter.Handler(item.Method, "/"+strings.TrimPrefix(item.Pattern, debugPathPrefix), h, item.Name)
   279  		}
   280  	}
   281  	srv.Middlewares = append(srv.Middlewares, recovery)
   282  	for _, item := range srv.bizRoutes {
   283  		h := http.Handler(item.HandlerFunc)
   284  		for i := len(srv.Middlewares) - 1; i >= 0; i-- {
   285  			h = srv.Middlewares[i].Middleware(h)
   286  		}
   287  		srv.Router.Handler(item.Method, item.Pattern, h, item.Name)
   288  	}
   289  	srv.rootRouter.NotFound = http.HandlerFunc(http.NotFound)
   290  	srv.rootRouter.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   291  		w.WriteHeader(http.StatusMethodNotAllowed)
   292  		w.Write([]byte("405 method not allowed"))
   293  	})
   294  	for i := len(srv.Middlewares) - 1; i >= 0; i-- {
   295  		srv.rootRouter.NotFound = srv.Middlewares[i].Middleware(srv.rootRouter.NotFound)
   296  		srv.rootRouter.MethodNotAllowed = srv.Middlewares[i].Middleware(srv.rootRouter.MethodNotAllowed)
   297  	}
   298  	srv.printRoutes()
   299  	httpServer := srv.newHttpServer()
   300  	defer func() {
   301  		grace, err := time.ParseDuration(config.GddGraceTimeout.Load())
   302  		if err != nil {
   303  			logger.Debug().Msgf("Parse %s %s as time.Duration failed: %s, use default %s instead.\n", string(config.GddGraceTimeout),
   304  				config.GddGraceTimeout.Load(), err.Error(), config.DefaultGddGraceTimeout)
   305  			grace, _ = time.ParseDuration(config.DefaultGddGraceTimeout)
   306  		}
   307  		logger.Info().Msgf("Http server is gracefully shutting down in %s", grace)
   308  
   309  		ctx, cancel := context.WithTimeout(context.Background(), grace)
   310  		defer cancel()
   311  		// Doesn't block if no connections, but will otherwise wait
   312  		// until the timeout deadline.
   313  		httpServer.Shutdown(ctx)
   314  	}()
   315  
   316  	c := make(chan os.Signal, 1)
   317  	// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
   318  	// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
   319  	signal.Notify(c, os.Interrupt)
   320  
   321  	// Block until we receive our signal.
   322  	<-c
   323  }