github.com/lingyao2333/mo-zero@v1.4.1/rest/engine.go (about) 1 package rest 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "net/http" 8 "sort" 9 "time" 10 11 "github.com/lingyao2333/mo-zero/core/codec" 12 "github.com/lingyao2333/mo-zero/core/load" 13 "github.com/lingyao2333/mo-zero/core/stat" 14 "github.com/lingyao2333/mo-zero/rest/chain" 15 "github.com/lingyao2333/mo-zero/rest/handler" 16 "github.com/lingyao2333/mo-zero/rest/httpx" 17 "github.com/lingyao2333/mo-zero/rest/internal" 18 "github.com/lingyao2333/mo-zero/rest/internal/response" 19 ) 20 21 // use 1000m to represent 100% 22 const topCpuUsage = 1000 23 24 // ErrSignatureConfig is an error that indicates bad config for signature. 25 var ErrSignatureConfig = errors.New("bad config for Signature") 26 27 type engine struct { 28 conf RestConf 29 routes []featuredRoutes 30 unauthorizedCallback handler.UnauthorizedCallback 31 unsignedCallback handler.UnsignedCallback 32 chain chain.Chain 33 middlewares []Middleware 34 shedder load.Shedder 35 priorityShedder load.Shedder 36 tlsConfig *tls.Config 37 } 38 39 func newEngine(c RestConf) *engine { 40 svr := &engine{ 41 conf: c, 42 } 43 if c.CpuThreshold > 0 { 44 svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) 45 svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( 46 (c.CpuThreshold + topCpuUsage) >> 1)) 47 } 48 49 return svr 50 } 51 52 func (ng *engine) addRoutes(r featuredRoutes) { 53 ng.routes = append(ng.routes, r) 54 } 55 56 func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, 57 verifier func(chain.Chain) chain.Chain) chain.Chain { 58 if fr.jwt.enabled { 59 if len(fr.jwt.prevSecret) == 0 { 60 chn = chn.Append(handler.Authorize(fr.jwt.secret, 61 handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) 62 } else { 63 chn = chn.Append(handler.Authorize(fr.jwt.secret, 64 handler.WithPrevSecret(fr.jwt.prevSecret), 65 handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) 66 } 67 } 68 69 return verifier(chn) 70 } 71 72 func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { 73 verifier, err := ng.signatureVerifier(fr.signature) 74 if err != nil { 75 return err 76 } 77 78 for _, route := range fr.routes { 79 if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil { 80 return err 81 } 82 } 83 84 return nil 85 } 86 87 func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, 88 route Route, verifier func(chain.Chain) chain.Chain) error { 89 chn := ng.chain 90 if chn == nil { 91 chn = chain.New( 92 handler.TracingHandler(ng.conf.Name, route.Path), 93 ng.getLogHandler(), 94 handler.PrometheusHandler(route.Path), 95 handler.MaxConns(ng.conf.MaxConns), 96 handler.BreakerHandler(route.Method, route.Path, metrics), 97 handler.SheddingHandler(ng.getShedder(fr.priority), metrics), 98 handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), 99 handler.RecoverHandler, 100 handler.MetricHandler(metrics), 101 handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)), 102 handler.GunzipHandler, 103 ) 104 } 105 106 chn = ng.appendAuthHandler(fr, chn, verifier) 107 108 for _, middleware := range ng.middlewares { 109 chn = chn.Append(convertMiddleware(middleware)) 110 } 111 handle := chn.ThenFunc(route.Handler) 112 113 return router.Handle(route.Method, route.Path, handle) 114 } 115 116 func (ng *engine) bindRoutes(router httpx.Router) error { 117 metrics := ng.createMetrics() 118 119 for _, fr := range ng.routes { 120 if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil { 121 return err 122 } 123 } 124 125 return nil 126 } 127 128 func (ng *engine) checkedMaxBytes(bytes int64) int64 { 129 if bytes > 0 { 130 return bytes 131 } 132 133 return ng.conf.MaxBytes 134 } 135 136 func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { 137 if timeout > 0 { 138 return timeout 139 } 140 141 return time.Duration(ng.conf.Timeout) * time.Millisecond 142 } 143 144 func (ng *engine) createMetrics() *stat.Metrics { 145 var metrics *stat.Metrics 146 147 if len(ng.conf.Name) > 0 { 148 metrics = stat.NewMetrics(ng.conf.Name) 149 } else { 150 metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)) 151 } 152 153 return metrics 154 } 155 156 func (ng *engine) getLogHandler() func(http.Handler) http.Handler { 157 if ng.conf.Verbose { 158 return handler.DetailedLogHandler 159 } 160 161 return handler.LogHandler 162 } 163 164 func (ng *engine) getShedder(priority bool) load.Shedder { 165 if priority && ng.priorityShedder != nil { 166 return ng.priorityShedder 167 } 168 169 return ng.shedder 170 } 171 172 // notFoundHandler returns a middleware that handles 404 not found requests. 173 func (ng *engine) notFoundHandler(next http.Handler) http.Handler { 174 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 175 chn := chain.New( 176 handler.TracingHandler(ng.conf.Name, ""), 177 ng.getLogHandler(), 178 ) 179 180 var h http.Handler 181 if next != nil { 182 h = chn.Then(next) 183 } else { 184 h = chn.Then(http.NotFoundHandler()) 185 } 186 187 cw := response.NewHeaderOnceResponseWriter(w) 188 h.ServeHTTP(cw, r) 189 cw.WriteHeader(http.StatusNotFound) 190 }) 191 } 192 193 func (ng *engine) print() { 194 var routes []string 195 196 for _, fr := range ng.routes { 197 for _, route := range fr.routes { 198 routes = append(routes, fmt.Sprintf("%s %s", route.Method, route.Path)) 199 } 200 } 201 202 sort.Strings(routes) 203 204 fmt.Println("Routes:") 205 for _, route := range routes { 206 fmt.Printf(" %s\n", route) 207 } 208 } 209 210 func (ng *engine) setTlsConfig(cfg *tls.Config) { 211 ng.tlsConfig = cfg 212 } 213 214 func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { 215 ng.unauthorizedCallback = callback 216 } 217 218 func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) { 219 ng.unsignedCallback = callback 220 } 221 222 func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) { 223 if !signature.enabled { 224 return func(chn chain.Chain) chain.Chain { 225 return chn 226 }, nil 227 } 228 229 if len(signature.PrivateKeys) == 0 { 230 if signature.Strict { 231 return nil, ErrSignatureConfig 232 } 233 234 return func(chn chain.Chain) chain.Chain { 235 return chn 236 }, nil 237 } 238 239 decrypters := make(map[string]codec.RsaDecrypter) 240 for _, key := range signature.PrivateKeys { 241 fingerprint := key.Fingerprint 242 file := key.KeyFile 243 decrypter, err := codec.NewRsaDecrypter(file) 244 if err != nil { 245 return nil, err 246 } 247 248 decrypters[fingerprint] = decrypter 249 } 250 251 return func(chn chain.Chain) chain.Chain { 252 if ng.unsignedCallback != nil { 253 return chn.Append(handler.ContentSecurityHandler( 254 decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) 255 } 256 257 return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict)) 258 }, nil 259 } 260 261 func (ng *engine) start(router httpx.Router) error { 262 if err := ng.bindRoutes(router); err != nil { 263 return err 264 } 265 266 if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { 267 return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, ng.withTimeout()) 268 } 269 270 return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, 271 ng.conf.KeyFile, router, func(svr *http.Server) { 272 if ng.tlsConfig != nil { 273 svr.TLSConfig = ng.tlsConfig 274 } 275 }, ng.withTimeout()) 276 } 277 278 func (ng *engine) use(middleware Middleware) { 279 ng.middlewares = append(ng.middlewares, middleware) 280 } 281 282 func (ng *engine) withTimeout() internal.StartOption { 283 return func(svr *http.Server) { 284 timeout := ng.conf.Timeout 285 if timeout > 0 { 286 // factor 0.8, to avoid clients send longer content-length than the actual content, 287 // without this timeout setting, the server will time out and respond 503 Service Unavailable, 288 // which triggers the circuit breaker. 289 svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5 290 // factor 0.9, to avoid clients not reading the response 291 // without this timeout setting, the server will time out and respond 503 Service Unavailable, 292 // which triggers the circuit breaker. 293 svr.WriteTimeout = 9 * time.Duration(timeout) * time.Millisecond / 10 294 } 295 } 296 } 297 298 func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { 299 return func(next http.Handler) http.Handler { 300 return ware(next.ServeHTTP) 301 } 302 }