github.com/lingyao2333/mo-zero@v1.4.1/rest/server.go (about) 1 package rest 2 3 import ( 4 "crypto/tls" 5 "log" 6 "net/http" 7 "path" 8 "time" 9 10 "github.com/lingyao2333/mo-zero/core/logx" 11 "github.com/lingyao2333/mo-zero/rest/chain" 12 "github.com/lingyao2333/mo-zero/rest/handler" 13 "github.com/lingyao2333/mo-zero/rest/httpx" 14 "github.com/lingyao2333/mo-zero/rest/internal/cors" 15 "github.com/lingyao2333/mo-zero/rest/router" 16 ) 17 18 type ( 19 // RunOption defines the method to customize a Server. 20 RunOption func(*Server) 21 22 // A Server is a http server. 23 Server struct { 24 ngin *engine 25 router httpx.Router 26 } 27 ) 28 29 // MustNewServer returns a server with given config of c and options defined in opts. 30 // Be aware that later RunOption might overwrite previous one that write the same option. 31 // The process will exit if error occurs. 32 func MustNewServer(c RestConf, opts ...RunOption) *Server { 33 server, err := NewServer(c, opts...) 34 if err != nil { 35 log.Fatal(err) 36 } 37 38 return server 39 } 40 41 // NewServer returns a server with given config of c and options defined in opts. 42 // Be aware that later RunOption might overwrite previous one that write the same option. 43 func NewServer(c RestConf, opts ...RunOption) (*Server, error) { 44 if err := c.SetUp(); err != nil { 45 return nil, err 46 } 47 48 server := &Server{ 49 ngin: newEngine(c), 50 router: router.NewRouter(), 51 } 52 53 opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) 54 for _, opt := range opts { 55 opt(server) 56 } 57 58 return server, nil 59 } 60 61 // AddRoutes add given routes into the Server. 62 func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) { 63 r := featuredRoutes{ 64 routes: rs, 65 } 66 for _, opt := range opts { 67 opt(&r) 68 } 69 s.ngin.addRoutes(r) 70 } 71 72 // AddRoute adds given route into the Server. 73 func (s *Server) AddRoute(r Route, opts ...RouteOption) { 74 s.AddRoutes([]Route{r}, opts...) 75 } 76 77 // PrintRoutes prints the added routes to stdout. 78 func (s *Server) PrintRoutes() { 79 s.ngin.print() 80 } 81 82 // Routes returns the HTTP routers that registered in the server. 83 func (s *Server) Routes() []Route { 84 var routes []Route 85 86 for _, r := range s.ngin.routes { 87 routes = append(routes, r.routes...) 88 } 89 90 return routes 91 } 92 93 // Start starts the Server. 94 // Graceful shutdown is enabled by default. 95 // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. 96 func (s *Server) Start() { 97 handleError(s.ngin.start(s.router)) 98 } 99 100 // Stop stops the Server. 101 func (s *Server) Stop() { 102 logx.Close() 103 } 104 105 // Use adds the given middleware in the Server. 106 func (s *Server) Use(middleware Middleware) { 107 s.ngin.use(middleware) 108 } 109 110 // ToMiddleware converts the given handler to a Middleware. 111 func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { 112 return func(handle http.HandlerFunc) http.HandlerFunc { 113 return handler(handle).ServeHTTP 114 } 115 } 116 117 // WithChain returns a RunOption that uses the given chain to replace the default chain. 118 // JWT auth middleware and the middlewares that added by svr.Use() will be appended. 119 func WithChain(chn chain.Chain) RunOption { 120 return func(svr *Server) { 121 svr.ngin.chain = chn 122 } 123 } 124 125 // WithCors returns a func to enable CORS for given origin, or default to all origins (*). 126 func WithCors(origin ...string) RunOption { 127 return func(server *Server) { 128 server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...)) 129 server.router = newCorsRouter(server.router, nil, origin...) 130 } 131 } 132 133 // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), 134 // fn lets caller customizing the response. 135 func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter), 136 origin ...string) RunOption { 137 return func(server *Server) { 138 server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...)) 139 server.router = newCorsRouter(server.router, middlewareFn, origin...) 140 } 141 } 142 143 // WithJwt returns a func to enable jwt authentication in given route. 144 func WithJwt(secret string) RouteOption { 145 return func(r *featuredRoutes) { 146 validateSecret(secret) 147 r.jwt.enabled = true 148 r.jwt.secret = secret 149 } 150 } 151 152 // WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition. 153 // Which means old and new jwt secrets work together for a period. 154 func WithJwtTransition(secret, prevSecret string) RouteOption { 155 return func(r *featuredRoutes) { 156 // why not validate prevSecret, because prevSecret is an already used one, 157 // even it not meet our requirement, we still need to allow the transition. 158 validateSecret(secret) 159 r.jwt.enabled = true 160 r.jwt.secret = secret 161 r.jwt.prevSecret = prevSecret 162 } 163 } 164 165 // WithMaxBytes returns a RouteOption to set maxBytes with the given value. 166 func WithMaxBytes(maxBytes int64) RouteOption { 167 return func(r *featuredRoutes) { 168 r.maxBytes = maxBytes 169 } 170 } 171 172 // WithMiddlewares adds given middlewares to given routes. 173 func WithMiddlewares(ms []Middleware, rs ...Route) []Route { 174 for i := len(ms) - 1; i >= 0; i-- { 175 rs = WithMiddleware(ms[i], rs...) 176 } 177 return rs 178 } 179 180 // WithMiddleware adds given middleware to given route. 181 func WithMiddleware(middleware Middleware, rs ...Route) []Route { 182 routes := make([]Route, len(rs)) 183 184 for i := range rs { 185 route := rs[i] 186 routes[i] = Route{ 187 Method: route.Method, 188 Path: route.Path, 189 Handler: middleware(route.Handler), 190 } 191 } 192 193 return routes 194 } 195 196 // WithNotFoundHandler returns a RunOption with not found handler set to given handler. 197 func WithNotFoundHandler(handler http.Handler) RunOption { 198 return func(server *Server) { 199 notFoundHandler := server.ngin.notFoundHandler(handler) 200 server.router.SetNotFoundHandler(notFoundHandler) 201 } 202 } 203 204 // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler. 205 func WithNotAllowedHandler(handler http.Handler) RunOption { 206 return func(server *Server) { 207 server.router.SetNotAllowedHandler(handler) 208 } 209 } 210 211 // WithPrefix adds group as a prefix to the route paths. 212 func WithPrefix(group string) RouteOption { 213 return func(r *featuredRoutes) { 214 var routes []Route 215 for _, rt := range r.routes { 216 p := path.Join(group, rt.Path) 217 routes = append(routes, Route{ 218 Method: rt.Method, 219 Path: p, 220 Handler: rt.Handler, 221 }) 222 } 223 r.routes = routes 224 } 225 } 226 227 // WithPriority returns a RunOption with priority. 228 func WithPriority() RouteOption { 229 return func(r *featuredRoutes) { 230 r.priority = true 231 } 232 } 233 234 // WithRouter returns a RunOption that make server run with given router. 235 func WithRouter(router httpx.Router) RunOption { 236 return func(server *Server) { 237 server.router = router 238 } 239 } 240 241 // WithSignature returns a RouteOption to enable signature verification. 242 func WithSignature(signature SignatureConf) RouteOption { 243 return func(r *featuredRoutes) { 244 r.signature.enabled = true 245 r.signature.Strict = signature.Strict 246 r.signature.Expiry = signature.Expiry 247 r.signature.PrivateKeys = signature.PrivateKeys 248 } 249 } 250 251 // WithTimeout returns a RouteOption to set timeout with given value. 252 func WithTimeout(timeout time.Duration) RouteOption { 253 return func(r *featuredRoutes) { 254 r.timeout = timeout 255 } 256 } 257 258 // WithTLSConfig returns a RunOption that with given tls config. 259 func WithTLSConfig(cfg *tls.Config) RunOption { 260 return func(svr *Server) { 261 svr.ngin.setTlsConfig(cfg) 262 } 263 } 264 265 // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. 266 func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { 267 return func(svr *Server) { 268 svr.ngin.setUnauthorizedCallback(callback) 269 } 270 } 271 272 // WithUnsignedCallback returns a RunOption that with given unsigned callback set. 273 func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { 274 return func(svr *Server) { 275 svr.ngin.setUnsignedCallback(callback) 276 } 277 } 278 279 func handleError(err error) { 280 // ErrServerClosed means the server is closed manually 281 if err == nil || err == http.ErrServerClosed { 282 return 283 } 284 285 logx.Error(err) 286 panic(err) 287 } 288 289 func validateSecret(secret string) { 290 if len(secret) < 8 { 291 panic("secret's length can't be less than 8") 292 } 293 } 294 295 type corsRouter struct { 296 httpx.Router 297 middleware Middleware 298 } 299 300 func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router { 301 return &corsRouter{ 302 Router: router, 303 middleware: cors.Middleware(headerFn, origins...), 304 } 305 } 306 307 func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 308 c.middleware(c.Router.ServeHTTP)(w, r) 309 }