github.com/lingyao2333/mo-zero@v1.4.1/rest/server.go (about)

     1  package rest
     2  
import (
	"crypto/tls"
	"errors"
	"log"
	"net/http"
	"path"
	"time"

	"github.com/lingyao2333/mo-zero/core/logx"
	"github.com/lingyao2333/mo-zero/rest/chain"
	"github.com/lingyao2333/mo-zero/rest/handler"
	"github.com/lingyao2333/mo-zero/rest/httpx"
	"github.com/lingyao2333/mo-zero/rest/internal/cors"
	"github.com/lingyao2333/mo-zero/rest/router"
)
    17  
    18  type (
    19  	// RunOption defines the method to customize a Server.
    20  	RunOption func(*Server)
    21  
    22  	// A Server is a http server.
    23  	Server struct {
    24  		ngin   *engine
    25  		router httpx.Router
    26  	}
    27  )
    28  
    29  // MustNewServer returns a server with given config of c and options defined in opts.
    30  // Be aware that later RunOption might overwrite previous one that write the same option.
    31  // The process will exit if error occurs.
    32  func MustNewServer(c RestConf, opts ...RunOption) *Server {
    33  	server, err := NewServer(c, opts...)
    34  	if err != nil {
    35  		log.Fatal(err)
    36  	}
    37  
    38  	return server
    39  }
    40  
    41  // NewServer returns a server with given config of c and options defined in opts.
    42  // Be aware that later RunOption might overwrite previous one that write the same option.
    43  func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
    44  	if err := c.SetUp(); err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	server := &Server{
    49  		ngin:   newEngine(c),
    50  		router: router.NewRouter(),
    51  	}
    52  
    53  	opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...)
    54  	for _, opt := range opts {
    55  		opt(server)
    56  	}
    57  
    58  	return server, nil
    59  }
    60  
    61  // AddRoutes add given routes into the Server.
    62  func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
    63  	r := featuredRoutes{
    64  		routes: rs,
    65  	}
    66  	for _, opt := range opts {
    67  		opt(&r)
    68  	}
    69  	s.ngin.addRoutes(r)
    70  }
    71  
    72  // AddRoute adds given route into the Server.
    73  func (s *Server) AddRoute(r Route, opts ...RouteOption) {
    74  	s.AddRoutes([]Route{r}, opts...)
    75  }
    76  
    77  // PrintRoutes prints the added routes to stdout.
    78  func (s *Server) PrintRoutes() {
    79  	s.ngin.print()
    80  }
    81  
    82  // Routes returns the HTTP routers that registered in the server.
    83  func (s *Server) Routes() []Route {
    84  	var routes []Route
    85  
    86  	for _, r := range s.ngin.routes {
    87  		routes = append(routes, r.routes...)
    88  	}
    89  
    90  	return routes
    91  }
    92  
    93  // Start starts the Server.
    94  // Graceful shutdown is enabled by default.
    95  // Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
    96  func (s *Server) Start() {
    97  	handleError(s.ngin.start(s.router))
    98  }
    99  
   100  // Stop stops the Server.
   101  func (s *Server) Stop() {
   102  	logx.Close()
   103  }
   104  
   105  // Use adds the given middleware in the Server.
   106  func (s *Server) Use(middleware Middleware) {
   107  	s.ngin.use(middleware)
   108  }
   109  
   110  // ToMiddleware converts the given handler to a Middleware.
   111  func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
   112  	return func(handle http.HandlerFunc) http.HandlerFunc {
   113  		return handler(handle).ServeHTTP
   114  	}
   115  }
   116  
   117  // WithChain returns a RunOption that uses the given chain to replace the default chain.
   118  // JWT auth middleware and the middlewares that added by svr.Use() will be appended.
   119  func WithChain(chn chain.Chain) RunOption {
   120  	return func(svr *Server) {
   121  		svr.ngin.chain = chn
   122  	}
   123  }
   124  
   125  // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
   126  func WithCors(origin ...string) RunOption {
   127  	return func(server *Server) {
   128  		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
   129  		server.router = newCorsRouter(server.router, nil, origin...)
   130  	}
   131  }
   132  
   133  // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
   134  // fn lets caller customizing the response.
   135  func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),
   136  	origin ...string) RunOption {
   137  	return func(server *Server) {
   138  		server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
   139  		server.router = newCorsRouter(server.router, middlewareFn, origin...)
   140  	}
   141  }
   142  
   143  // WithJwt returns a func to enable jwt authentication in given route.
   144  func WithJwt(secret string) RouteOption {
   145  	return func(r *featuredRoutes) {
   146  		validateSecret(secret)
   147  		r.jwt.enabled = true
   148  		r.jwt.secret = secret
   149  	}
   150  }
   151  
   152  // WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
   153  // Which means old and new jwt secrets work together for a period.
   154  func WithJwtTransition(secret, prevSecret string) RouteOption {
   155  	return func(r *featuredRoutes) {
   156  		// why not validate prevSecret, because prevSecret is an already used one,
   157  		// even it not meet our requirement, we still need to allow the transition.
   158  		validateSecret(secret)
   159  		r.jwt.enabled = true
   160  		r.jwt.secret = secret
   161  		r.jwt.prevSecret = prevSecret
   162  	}
   163  }
   164  
   165  // WithMaxBytes returns a RouteOption to set maxBytes with the given value.
   166  func WithMaxBytes(maxBytes int64) RouteOption {
   167  	return func(r *featuredRoutes) {
   168  		r.maxBytes = maxBytes
   169  	}
   170  }
   171  
   172  // WithMiddlewares adds given middlewares to given routes.
   173  func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
   174  	for i := len(ms) - 1; i >= 0; i-- {
   175  		rs = WithMiddleware(ms[i], rs...)
   176  	}
   177  	return rs
   178  }
   179  
   180  // WithMiddleware adds given middleware to given route.
   181  func WithMiddleware(middleware Middleware, rs ...Route) []Route {
   182  	routes := make([]Route, len(rs))
   183  
   184  	for i := range rs {
   185  		route := rs[i]
   186  		routes[i] = Route{
   187  			Method:  route.Method,
   188  			Path:    route.Path,
   189  			Handler: middleware(route.Handler),
   190  		}
   191  	}
   192  
   193  	return routes
   194  }
   195  
   196  // WithNotFoundHandler returns a RunOption with not found handler set to given handler.
   197  func WithNotFoundHandler(handler http.Handler) RunOption {
   198  	return func(server *Server) {
   199  		notFoundHandler := server.ngin.notFoundHandler(handler)
   200  		server.router.SetNotFoundHandler(notFoundHandler)
   201  	}
   202  }
   203  
   204  // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
   205  func WithNotAllowedHandler(handler http.Handler) RunOption {
   206  	return func(server *Server) {
   207  		server.router.SetNotAllowedHandler(handler)
   208  	}
   209  }
   210  
   211  // WithPrefix adds group as a prefix to the route paths.
   212  func WithPrefix(group string) RouteOption {
   213  	return func(r *featuredRoutes) {
   214  		var routes []Route
   215  		for _, rt := range r.routes {
   216  			p := path.Join(group, rt.Path)
   217  			routes = append(routes, Route{
   218  				Method:  rt.Method,
   219  				Path:    p,
   220  				Handler: rt.Handler,
   221  			})
   222  		}
   223  		r.routes = routes
   224  	}
   225  }
   226  
   227  // WithPriority returns a RunOption with priority.
   228  func WithPriority() RouteOption {
   229  	return func(r *featuredRoutes) {
   230  		r.priority = true
   231  	}
   232  }
   233  
   234  // WithRouter returns a RunOption that make server run with given router.
   235  func WithRouter(router httpx.Router) RunOption {
   236  	return func(server *Server) {
   237  		server.router = router
   238  	}
   239  }
   240  
   241  // WithSignature returns a RouteOption to enable signature verification.
   242  func WithSignature(signature SignatureConf) RouteOption {
   243  	return func(r *featuredRoutes) {
   244  		r.signature.enabled = true
   245  		r.signature.Strict = signature.Strict
   246  		r.signature.Expiry = signature.Expiry
   247  		r.signature.PrivateKeys = signature.PrivateKeys
   248  	}
   249  }
   250  
   251  // WithTimeout returns a RouteOption to set timeout with given value.
   252  func WithTimeout(timeout time.Duration) RouteOption {
   253  	return func(r *featuredRoutes) {
   254  		r.timeout = timeout
   255  	}
   256  }
   257  
   258  // WithTLSConfig returns a RunOption that with given tls config.
   259  func WithTLSConfig(cfg *tls.Config) RunOption {
   260  	return func(svr *Server) {
   261  		svr.ngin.setTlsConfig(cfg)
   262  	}
   263  }
   264  
   265  // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
   266  func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
   267  	return func(svr *Server) {
   268  		svr.ngin.setUnauthorizedCallback(callback)
   269  	}
   270  }
   271  
   272  // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
   273  func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
   274  	return func(svr *Server) {
   275  		svr.ngin.setUnsignedCallback(callback)
   276  	}
   277  }
   278  
   279  func handleError(err error) {
   280  	// ErrServerClosed means the server is closed manually
   281  	if err == nil || err == http.ErrServerClosed {
   282  		return
   283  	}
   284  
   285  	logx.Error(err)
   286  	panic(err)
   287  }
   288  
   289  func validateSecret(secret string) {
   290  	if len(secret) < 8 {
   291  		panic("secret's length can't be less than 8")
   292  	}
   293  }
   294  
   295  type corsRouter struct {
   296  	httpx.Router
   297  	middleware Middleware
   298  }
   299  
   300  func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router {
   301  	return &corsRouter{
   302  		Router:     router,
   303  		middleware: cors.Middleware(headerFn, origins...),
   304  	}
   305  }
   306  
   307  func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   308  	c.middleware(c.Router.ServeHTTP)(w, r)
   309  }