github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/mux.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package http
    18  
import (
	"context"
	"fmt"
	"mime"
	"net/http"
	"regexp"
	"sort"
	"strings"
	"sync"

	"github.com/gorilla/mux"
	"github.com/lastbackend/toolkit/pkg/runtime"
	"github.com/lastbackend/toolkit/pkg/server"
	"github.com/lastbackend/toolkit/pkg/server/http/errors"
	"github.com/lastbackend/toolkit/pkg/server/http/marshaler"
	"github.com/lastbackend/toolkit/pkg/server/http/websockets"
)
    34  
    35  const (
    36  	defaultPrefix = "http"
    37  )
    38  
    39  var (
    40  	acceptHeader      = http.CanonicalHeaderKey("Accept")
    41  	contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
    42  )
    43  
    44  type httpServer struct {
    45  	runtime runtime.Runtime
    46  
    47  	sync.RWMutex
    48  
    49  	opts Config
    50  
    51  	prefix    string
    52  	isRunning bool
    53  
    54  	handlers     map[string]server.HTTPServerHandler
    55  	marshalerMap map[string]marshaler.Marshaler
    56  
    57  	// fn for init user-defined service
    58  	// fn for server registration
    59  	service interface{}
    60  
    61  	middlewares *Middlewares
    62  
    63  	corsHandlerFunc http.HandlerFunc
    64  
    65  	wsManager *websockets.Manager
    66  
    67  	server *http.Server
    68  	exit   chan chan error
    69  
    70  	r *mux.Router
    71  }
    72  
    73  func NewServer(name string, runtime runtime.Runtime, options *server.HTTPServerOptions) server.HTTPServer {
    74  
    75  	s := &httpServer{
    76  		runtime:      runtime,
    77  		prefix:       defaultPrefix,
    78  		marshalerMap: GetMarshalerMap(),
    79  		exit:         make(chan chan error),
    80  
    81  		corsHandlerFunc: corsHandlerFunc,
    82  
    83  		middlewares: newMiddlewares(runtime.Log()),
    84  		wsManager:   websockets.NewManager(runtime.Log()),
    85  		handlers:    make(map[string]server.HTTPServerHandler, 0),
    86  
    87  		r: mux.NewRouter(),
    88  	}
    89  
    90  	name = regexp.MustCompile(`[^_a-zA-Z0-9 ]+`).ReplaceAllString(name, "_")
    91  
    92  	if name != "" {
    93  		s.prefix = name
    94  	}
    95  
    96  	if err := runtime.Config().Parse(&s.opts, s.prefix); err != nil {
    97  		return nil
    98  	}
    99  
   100  	if options != nil {
   101  		s.parseOptions(options)
   102  	}
   103  
   104  	return s
   105  }
   106  
   107  func (s *httpServer) parseOptions(options *server.HTTPServerOptions) {
   108  
   109  	if options != nil {
   110  		if options.Host != "" {
   111  			s.opts.Host = options.Host
   112  		}
   113  
   114  		if options.Port > 0 {
   115  			s.opts.Port = options.Port
   116  		}
   117  
   118  		if options.TLSConfig != nil {
   119  			s.opts.TLSConfig = options.TLSConfig
   120  		}
   121  	}
   122  }
   123  
   124  func (s *httpServer) Info() server.ServerInfo {
   125  	return server.ServerInfo{
   126  		Kind:      server.ServerKindHTTPServer,
   127  		Host:      s.opts.Host,
   128  		Port:      s.opts.Port,
   129  		TLSConfig: s.opts.TLSConfig,
   130  	}
   131  }
   132  
   133  func (s *httpServer) Start(_ context.Context) error {
   134  
   135  	s.RLock()
   136  	if s.isRunning {
   137  		s.RUnlock()
   138  		return nil
   139  	}
   140  	s.RUnlock()
   141  
   142  	if s.opts.EnableCORS {
   143  		s.r.Methods(http.MethodOptions).HandlerFunc(s.corsHandlerFunc)
   144  		s.middlewares.global = append(s.middlewares.global, corsMiddlewareKind)
   145  		s.middlewares.Add(&corsMiddleware{handler: s.corsHandlerFunc})
   146  	}
   147  
   148  	s.r.NotFoundHandler = s.methodNotFoundHandler()
   149  	s.r.MethodNotAllowedHandler = s.methodNotAllowedHandler()
   150  
   151  	s.server = &http.Server{
   152  		Addr:      fmt.Sprintf("%s:%d", s.opts.Host, s.opts.Port),
   153  		Handler:   s.r,
   154  		TLSConfig: s.opts.TLSConfig,
   155  	}
   156  
   157  	for _, h := range s.handlers {
   158  		if err := s.registerHandler(h); err != nil {
   159  			return err
   160  		}
   161  	}
   162  
   163  	s.Lock()
   164  	s.isRunning = true
   165  	s.Unlock()
   166  
   167  	go func() {
   168  		s.runtime.Log().V(5).Infof("server [http] [%s] started", s.server.Addr)
   169  		if err := s.server.ListenAndServe(); err != http.ErrServerClosed {
   170  			s.runtime.Log().Errorf("server [http] [%s] start error: %v", s.server.Addr, err)
   171  		}
   172  		s.runtime.Log().V(5).Infof("server [http] [%s] stopped", s.server.Addr)
   173  		s.Lock()
   174  		s.isRunning = false
   175  		s.Unlock()
   176  	}()
   177  
   178  	return nil
   179  }
   180  
   181  func (s *httpServer) registerHandler(h server.HTTPServerHandler) error {
   182  	s.runtime.Log().V(5).Infof("register [http] route: %s", h.Path)
   183  
   184  	handler, err := s.middlewares.apply(h)
   185  	if err != nil {
   186  		return err
   187  	}
   188  	s.r.Handle(h.Path, handler).Methods(h.Method)
   189  
   190  	s.runtime.Log().V(5).Infof("bind handler: method: %s, path: %s", h.Method, h.Path)
   191  
   192  	return nil
   193  }
   194  
   195  func (s *httpServer) Stop(ctx context.Context) error {
   196  	s.runtime.Log().V(5).Infof("server [http] [%s] stop call start", s.server.Addr)
   197  
   198  	if err := s.server.Shutdown(ctx); err != nil {
   199  		s.runtime.Log().Errorf("server [http] [%s] stop call error: %v", s.server.Addr, err)
   200  		return err
   201  	}
   202  
   203  	s.runtime.Log().V(5).Infof("server [http] [%s] stop call end", s.server.Addr)
   204  	return nil
   205  }
   206  
   207  func (s *httpServer) UseMiddleware(middlewares ...server.KindMiddleware) {
   208  	s.middlewares.SetGlobal(middlewares...)
   209  }
   210  
   211  func (s *httpServer) UseMarshaler(contentType string, marshaler marshaler.Marshaler) error {
   212  	contentType, _, err := mime.ParseMediaType(contentType)
   213  	if err != nil {
   214  		return err
   215  	}
   216  	s.marshalerMap[contentType] = marshaler
   217  	return nil
   218  }
   219  
   220  func (s *httpServer) GetMiddlewares() []interface{} {
   221  	return s.middlewares.constructors
   222  }
   223  
   224  func (s *httpServer) GetConstructor() interface{} {
   225  	return s.constructor
   226  }
   227  
   228  func (s *httpServer) SetMiddleware(middleware any) {
   229  	s.middlewares.AddConstructor(middleware)
   230  }
   231  
   232  func (s *httpServer) AddHandler(method string, path string, h http.HandlerFunc, opts ...server.HTTPServerOption) {
   233  	key := fmt.Sprintf("%s:%s", method, path)
   234  	if !s.isRunning {
   235  		s.handlers[key] = server.HTTPServerHandler{Method: method, Path: path, Handler: h, Options: opts}
   236  	} else {
   237  		_ = s.registerHandler(server.HTTPServerHandler{Method: method, Path: path, Handler: h, Options: opts})
   238  	}
   239  }
   240  
   241  func (s *httpServer) SetCorsHandlerFunc(hf http.HandlerFunc) {
   242  	s.corsHandlerFunc = hf
   243  }
   244  
   245  func (s *httpServer) SetErrorHandlerFunc(hf func(http.ResponseWriter, error)) {
   246  	errors.GrpcErrorHandlerFunc = hf
   247  }
   248  
   249  func (s *httpServer) Subscribe(event string, h websockets.EventHandler) {
   250  	s.wsManager.AddEventHandler(event, h)
   251  }
   252  
   253  func (s *httpServer) ServerWS(w http.ResponseWriter, r *http.Request) {
   254  	s.wsManager.ServeWS(w, r)
   255  }
   256  
   257  // SetService - set user-defined handlers
   258  func (s *httpServer) SetService(service interface{}) {
   259  	s.service = service
   260  	return
   261  }
   262  
   263  // GetService - set user-defined handlers
   264  func (s *httpServer) GetService() interface{} {
   265  	return s.service
   266  }
   267  
   268  func (s *httpServer) constructor(mws ...server.HttpServerMiddleware) {
   269  	for _, mw := range mws {
   270  		s.middlewares.Add(mw)
   271  	}
   272  }
   273  
   274  func GetMarshaler(s server.HTTPServer, req *http.Request) (inbound, outbound marshaler.Marshaler) {
   275  	for _, acceptVal := range req.Header[acceptHeader] {
   276  		if m, ok := s.(*httpServer).marshalerMap[acceptVal]; ok {
   277  			outbound = m
   278  			break
   279  		}
   280  	}
   281  
   282  	for _, contentTypeVal := range req.Header[contentTypeHeader] {
   283  		contentType, _, err := mime.ParseMediaType(contentTypeVal)
   284  		if err != nil {
   285  			continue
   286  		}
   287  		if m, ok := s.(*httpServer).marshalerMap[contentType]; ok {
   288  			inbound = m
   289  			break
   290  		}
   291  	}
   292  
   293  	if inbound == nil {
   294  		inbound = DefaultMarshaler
   295  	}
   296  	if outbound == nil {
   297  		outbound = inbound
   298  	}
   299  
   300  	return inbound, outbound
   301  }