github.com/lingyao2333/mo-zero@v1.4.1/rest/engine.go (about)

     1  package rest
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"sort"
     9  	"time"
    10  
    11  	"github.com/lingyao2333/mo-zero/core/codec"
    12  	"github.com/lingyao2333/mo-zero/core/load"
    13  	"github.com/lingyao2333/mo-zero/core/stat"
    14  	"github.com/lingyao2333/mo-zero/rest/chain"
    15  	"github.com/lingyao2333/mo-zero/rest/handler"
    16  	"github.com/lingyao2333/mo-zero/rest/httpx"
    17  	"github.com/lingyao2333/mo-zero/rest/internal"
    18  	"github.com/lingyao2333/mo-zero/rest/internal/response"
    19  )
    20  
    21  // use 1000m to represent 100%
    22  const topCpuUsage = 1000
    23  
    24  // ErrSignatureConfig is an error that indicates bad config for signature.
    25  var ErrSignatureConfig = errors.New("bad config for Signature")
    26  
    27  type engine struct {
    28  	conf                 RestConf
    29  	routes               []featuredRoutes
    30  	unauthorizedCallback handler.UnauthorizedCallback
    31  	unsignedCallback     handler.UnsignedCallback
    32  	chain                chain.Chain
    33  	middlewares          []Middleware
    34  	shedder              load.Shedder
    35  	priorityShedder      load.Shedder
    36  	tlsConfig            *tls.Config
    37  }
    38  
    39  func newEngine(c RestConf) *engine {
    40  	svr := &engine{
    41  		conf: c,
    42  	}
    43  	if c.CpuThreshold > 0 {
    44  		svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
    45  		svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
    46  			(c.CpuThreshold + topCpuUsage) >> 1))
    47  	}
    48  
    49  	return svr
    50  }
    51  
    52  func (ng *engine) addRoutes(r featuredRoutes) {
    53  	ng.routes = append(ng.routes, r)
    54  }
    55  
    56  func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
    57  	verifier func(chain.Chain) chain.Chain) chain.Chain {
    58  	if fr.jwt.enabled {
    59  		if len(fr.jwt.prevSecret) == 0 {
    60  			chn = chn.Append(handler.Authorize(fr.jwt.secret,
    61  				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
    62  		} else {
    63  			chn = chn.Append(handler.Authorize(fr.jwt.secret,
    64  				handler.WithPrevSecret(fr.jwt.prevSecret),
    65  				handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
    66  		}
    67  	}
    68  
    69  	return verifier(chn)
    70  }
    71  
    72  func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
    73  	verifier, err := ng.signatureVerifier(fr.signature)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	for _, route := range fr.routes {
    79  		if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
    80  			return err
    81  		}
    82  	}
    83  
    84  	return nil
    85  }
    86  
    87  func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
    88  	route Route, verifier func(chain.Chain) chain.Chain) error {
    89  	chn := ng.chain
    90  	if chn == nil {
    91  		chn = chain.New(
    92  			handler.TracingHandler(ng.conf.Name, route.Path),
    93  			ng.getLogHandler(),
    94  			handler.PrometheusHandler(route.Path),
    95  			handler.MaxConns(ng.conf.MaxConns),
    96  			handler.BreakerHandler(route.Method, route.Path, metrics),
    97  			handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
    98  			handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
    99  			handler.RecoverHandler,
   100  			handler.MetricHandler(metrics),
   101  			handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
   102  			handler.GunzipHandler,
   103  		)
   104  	}
   105  
   106  	chn = ng.appendAuthHandler(fr, chn, verifier)
   107  
   108  	for _, middleware := range ng.middlewares {
   109  		chn = chn.Append(convertMiddleware(middleware))
   110  	}
   111  	handle := chn.ThenFunc(route.Handler)
   112  
   113  	return router.Handle(route.Method, route.Path, handle)
   114  }
   115  
   116  func (ng *engine) bindRoutes(router httpx.Router) error {
   117  	metrics := ng.createMetrics()
   118  
   119  	for _, fr := range ng.routes {
   120  		if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
   121  			return err
   122  		}
   123  	}
   124  
   125  	return nil
   126  }
   127  
   128  func (ng *engine) checkedMaxBytes(bytes int64) int64 {
   129  	if bytes > 0 {
   130  		return bytes
   131  	}
   132  
   133  	return ng.conf.MaxBytes
   134  }
   135  
   136  func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
   137  	if timeout > 0 {
   138  		return timeout
   139  	}
   140  
   141  	return time.Duration(ng.conf.Timeout) * time.Millisecond
   142  }
   143  
   144  func (ng *engine) createMetrics() *stat.Metrics {
   145  	var metrics *stat.Metrics
   146  
   147  	if len(ng.conf.Name) > 0 {
   148  		metrics = stat.NewMetrics(ng.conf.Name)
   149  	} else {
   150  		metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
   151  	}
   152  
   153  	return metrics
   154  }
   155  
   156  func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
   157  	if ng.conf.Verbose {
   158  		return handler.DetailedLogHandler
   159  	}
   160  
   161  	return handler.LogHandler
   162  }
   163  
   164  func (ng *engine) getShedder(priority bool) load.Shedder {
   165  	if priority && ng.priorityShedder != nil {
   166  		return ng.priorityShedder
   167  	}
   168  
   169  	return ng.shedder
   170  }
   171  
   172  // notFoundHandler returns a middleware that handles 404 not found requests.
   173  func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
   174  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   175  		chn := chain.New(
   176  			handler.TracingHandler(ng.conf.Name, ""),
   177  			ng.getLogHandler(),
   178  		)
   179  
   180  		var h http.Handler
   181  		if next != nil {
   182  			h = chn.Then(next)
   183  		} else {
   184  			h = chn.Then(http.NotFoundHandler())
   185  		}
   186  
   187  		cw := response.NewHeaderOnceResponseWriter(w)
   188  		h.ServeHTTP(cw, r)
   189  		cw.WriteHeader(http.StatusNotFound)
   190  	})
   191  }
   192  
   193  func (ng *engine) print() {
   194  	var routes []string
   195  
   196  	for _, fr := range ng.routes {
   197  		for _, route := range fr.routes {
   198  			routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path))
   199  		}
   200  	}
   201  
   202  	sort.Strings(routes)
   203  
   204  	fmt.Println("Routes:")
   205  	for _, route := range routes {
   206  		fmt.Printf("  %s\n", route)
   207  	}
   208  }
   209  
   210  func (ng *engine) setTlsConfig(cfg *tls.Config) {
   211  	ng.tlsConfig = cfg
   212  }
   213  
   214  func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
   215  	ng.unauthorizedCallback = callback
   216  }
   217  
   218  func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
   219  	ng.unsignedCallback = callback
   220  }
   221  
   222  func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
   223  	if !signature.enabled {
   224  		return func(chn chain.Chain) chain.Chain {
   225  			return chn
   226  		}, nil
   227  	}
   228  
   229  	if len(signature.PrivateKeys) == 0 {
   230  		if signature.Strict {
   231  			return nil, ErrSignatureConfig
   232  		}
   233  
   234  		return func(chn chain.Chain) chain.Chain {
   235  			return chn
   236  		}, nil
   237  	}
   238  
   239  	decrypters := make(map[string]codec.RsaDecrypter)
   240  	for _, key := range signature.PrivateKeys {
   241  		fingerprint := key.Fingerprint
   242  		file := key.KeyFile
   243  		decrypter, err := codec.NewRsaDecrypter(file)
   244  		if err != nil {
   245  			return nil, err
   246  		}
   247  
   248  		decrypters[fingerprint] = decrypter
   249  	}
   250  
   251  	return func(chn chain.Chain) chain.Chain {
   252  		if ng.unsignedCallback != nil {
   253  			return chn.Append(handler.ContentSecurityHandler(
   254  				decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
   255  		}
   256  
   257  		return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
   258  	}, nil
   259  }
   260  
   261  func (ng *engine) start(router httpx.Router) error {
   262  	if err := ng.bindRoutes(router); err != nil {
   263  		return err
   264  	}
   265  
   266  	if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
   267  		return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout())
   268  	}
   269  
   270  	return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
   271  		ng.conf.KeyFile, router, func(svr *http.Server) {
   272  			if ng.tlsConfig != nil {
   273  				svr.TLSConfig = ng.tlsConfig
   274  			}
   275  		}, ng.withTimeout())
   276  }
   277  
   278  func (ng *engine) use(middleware Middleware) {
   279  	ng.middlewares = append(ng.middlewares, middleware)
   280  }
   281  
   282  func (ng *engine) withTimeout() internal.StartOption {
   283  	return func(svr *http.Server) {
   284  		timeout := ng.conf.Timeout
   285  		if timeout > 0 {
   286  			// factor 0.8, to avoid clients send longer content-length than the actual content,
   287  			// without this timeout setting, the server will time out and respond 503 Service Unavailable,
   288  			// which triggers the circuit breaker.
   289  			svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
   290  			// factor 0.9, to avoid clients not reading the response
   291  			// without this timeout setting, the server will time out and respond 503 Service Unavailable,
   292  			// which triggers the circuit breaker.
   293  			svr.WriteTimeout = 9 * time.Duration(timeout) * time.Millisecond / 10
   294  		}
   295  	}
   296  }
   297  
   298  func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
   299  	return func(next http.Handler) http.Handler {
   300  		return ware(next.ServeHTTP)
   301  	}
   302  }