github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/mux.go (about) 1 /* 2 Copyright [2014] - [2023] The Last.Backend authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package http 18 19 import ( 20 "context" 21 "fmt" 22 "mime" 23 "net/http" 24 "regexp" 25 "sync" 26 27 "github.com/gorilla/mux" 28 "github.com/lastbackend/toolkit/pkg/runtime" 29 "github.com/lastbackend/toolkit/pkg/server" 30 "github.com/lastbackend/toolkit/pkg/server/http/errors" 31 "github.com/lastbackend/toolkit/pkg/server/http/marshaler" 32 "github.com/lastbackend/toolkit/pkg/server/http/websockets" 33 ) 34 35 const ( 36 defaultPrefix = "http" 37 ) 38 39 var ( 40 acceptHeader = http.CanonicalHeaderKey("Accept") 41 contentTypeHeader = http.CanonicalHeaderKey("Content-Type") 42 ) 43 44 type httpServer struct { 45 runtime runtime.Runtime 46 47 sync.RWMutex 48 49 opts Config 50 51 prefix string 52 isRunning bool 53 54 handlers map[string]server.HTTPServerHandler 55 marshalerMap map[string]marshaler.Marshaler 56 57 // fn for init user-defined service 58 // fn for server registration 59 service interface{} 60 61 middlewares *Middlewares 62 63 corsHandlerFunc http.HandlerFunc 64 65 wsManager *websockets.Manager 66 67 server *http.Server 68 exit chan chan error 69 70 r *mux.Router 71 } 72 73 func NewServer(name string, runtime runtime.Runtime, options *server.HTTPServerOptions) server.HTTPServer { 74 75 s := &httpServer{ 76 runtime: runtime, 77 prefix: defaultPrefix, 78 marshalerMap: GetMarshalerMap(), 79 exit: make(chan chan error), 80 81 corsHandlerFunc: corsHandlerFunc, 82 83 middlewares: newMiddlewares(runtime.Log()), 84 wsManager: websockets.NewManager(runtime.Log()), 85 handlers: make(map[string]server.HTTPServerHandler, 0), 86 87 r: mux.NewRouter(), 88 } 89 90 name = regexp.MustCompile(`[^_a-zA-Z0-9 ]+`).ReplaceAllString(name, "_") 91 92 if name != "" { 93 s.prefix = name 94 } 95 96 if err := runtime.Config().Parse(&s.opts, s.prefix); err != nil { 97 return nil 98 } 99 100 if options != nil { 101 s.parseOptions(options) 102 } 103 104 return s 105 } 106 107 func (s *httpServer) parseOptions(options *server.HTTPServerOptions) { 108 109 if options != nil { 110 if options.Host != "" { 111 s.opts.Host = options.Host 112 } 113 114 if options.Port > 0 { 115 s.opts.Port = options.Port 116 } 117 118 if options.TLSConfig != nil { 119 s.opts.TLSConfig = options.TLSConfig 120 } 121 } 122 } 123 124 func (s *httpServer) Info() server.ServerInfo { 125 return server.ServerInfo{ 126 Kind: server.ServerKindHTTPServer, 127 Host: s.opts.Host, 128 Port: s.opts.Port, 129 TLSConfig: s.opts.TLSConfig, 130 } 131 } 132 133 func (s *httpServer) Start(_ context.Context) error { 134 135 s.RLock() 136 if s.isRunning { 137 s.RUnlock() 138 return nil 139 } 140 s.RUnlock() 141 142 if s.opts.EnableCORS { 143 s.r.Methods(http.MethodOptions).HandlerFunc(s.corsHandlerFunc) 144 s.middlewares.global = append(s.middlewares.global, corsMiddlewareKind) 145 s.middlewares.Add(&corsMiddleware{handler: s.corsHandlerFunc}) 146 } 147 148 s.r.NotFoundHandler = s.methodNotFoundHandler() 149 s.r.MethodNotAllowedHandler = s.methodNotAllowedHandler() 150 151 s.server = &http.Server{ 152 Addr: fmt.Sprintf("%s:%d", s.opts.Host, s.opts.Port), 153 Handler: s.r, 154 TLSConfig: s.opts.TLSConfig, 155 } 156 157 for _, h := range s.handlers { 158 if err := s.registerHandler(h); err != nil { 159 return err 160 } 161 } 162 163 s.Lock() 164 s.isRunning = true 165 s.Unlock() 166 167 go func() { 168 s.runtime.Log().V(5).Infof("server [http] [%s] started", s.server.Addr) 169 if err := s.server.ListenAndServe(); err != http.ErrServerClosed { 170 s.runtime.Log().Errorf("server [http] [%s] start error: %v", s.server.Addr, err) 171 } 172 s.runtime.Log().V(5).Infof("server [http] [%s] stopped", s.server.Addr) 173 s.Lock() 174 s.isRunning = false 175 s.Unlock() 176 }() 177 178 return nil 179 } 180 181 func (s *httpServer) registerHandler(h server.HTTPServerHandler) error { 182 s.runtime.Log().V(5).Infof("register [http] route: %s", h.Path) 183 184 handler, err := s.middlewares.apply(h) 185 if err != nil { 186 return err 187 } 188 s.r.Handle(h.Path, handler).Methods(h.Method) 189 190 s.runtime.Log().V(5).Infof("bind handler: method: %s, path: %s", h.Method, h.Path) 191 192 return nil 193 } 194 195 func (s *httpServer) Stop(ctx context.Context) error { 196 s.runtime.Log().V(5).Infof("server [http] [%s] stop call start", s.server.Addr) 197 198 if err := s.server.Shutdown(ctx); err != nil { 199 s.runtime.Log().Errorf("server [http] [%s] stop call error: %v", s.server.Addr, err) 200 return err 201 } 202 203 s.runtime.Log().V(5).Infof("server [http] [%s] stop call end", s.server.Addr) 204 return nil 205 } 206 207 func (s *httpServer) UseMiddleware(middlewares ...server.KindMiddleware) { 208 s.middlewares.SetGlobal(middlewares...) 209 } 210 211 func (s *httpServer) UseMarshaler(contentType string, marshaler marshaler.Marshaler) error { 212 contentType, _, err := mime.ParseMediaType(contentType) 213 if err != nil { 214 return err 215 } 216 s.marshalerMap[contentType] = marshaler 217 return nil 218 } 219 220 func (s *httpServer) GetMiddlewares() []interface{} { 221 return s.middlewares.constructors 222 } 223 224 func (s *httpServer) GetConstructor() interface{} { 225 return s.constructor 226 } 227 228 func (s *httpServer) SetMiddleware(middleware any) { 229 s.middlewares.AddConstructor(middleware) 230 } 231 232 func (s *httpServer) AddHandler(method string, path string, h http.HandlerFunc, opts ...server.HTTPServerOption) { 233 key := fmt.Sprintf("%s:%s", method, path) 234 if !s.isRunning { 235 s.handlers[key] = server.HTTPServerHandler{Method: method, Path: path, Handler: h, Options: opts} 236 } else { 237 _ = s.registerHandler(server.HTTPServerHandler{Method: method, Path: path, Handler: h, Options: opts}) 238 } 239 } 240 241 func (s *httpServer) SetCorsHandlerFunc(hf http.HandlerFunc) { 242 s.corsHandlerFunc = hf 243 } 244 245 func (s *httpServer) SetErrorHandlerFunc(hf func(http.ResponseWriter, error)) { 246 errors.GrpcErrorHandlerFunc = hf 247 } 248 249 func (s *httpServer) Subscribe(event string, h websockets.EventHandler) { 250 s.wsManager.AddEventHandler(event, h) 251 } 252 253 func (s *httpServer) ServerWS(w http.ResponseWriter, r *http.Request) { 254 s.wsManager.ServeWS(w, r) 255 } 256 257 // SetService - set user-defined handlers 258 func (s *httpServer) SetService(service interface{}) { 259 s.service = service 260 return 261 } 262 263 // GetService - set user-defined handlers 264 func (s *httpServer) GetService() interface{} { 265 return s.service 266 } 267 268 func (s *httpServer) constructor(mws ...server.HttpServerMiddleware) { 269 for _, mw := range mws { 270 s.middlewares.Add(mw) 271 } 272 } 273 274 func GetMarshaler(s server.HTTPServer, req *http.Request) (inbound, outbound marshaler.Marshaler) { 275 for _, acceptVal := range req.Header[acceptHeader] { 276 if m, ok := s.(*httpServer).marshalerMap[acceptVal]; ok { 277 outbound = m 278 break 279 } 280 } 281 282 for _, contentTypeVal := range req.Header[contentTypeHeader] { 283 contentType, _, err := mime.ParseMediaType(contentTypeVal) 284 if err != nil { 285 continue 286 } 287 if m, ok := s.(*httpServer).marshalerMap[contentType]; ok { 288 inbound = m 289 break 290 } 291 } 292 293 if inbound == nil { 294 inbound = DefaultMarshaler 295 } 296 if outbound == nil { 297 outbound = inbound 298 } 299 300 return inbound, outbound 301 }