// Source: github.com/cloudwego/hertz@v0.9.3/pkg/route/engine.go
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * The MIT License (MIT)
 * Copyright (c) 2014 Manuel Martínez-Almeida
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors.
All CloudWeGo 37 * Modifications are Copyright 2022 CloudWeGo Authors 38 */ 39 40 package route 41 42 import ( 43 "bytes" 44 "context" 45 "crypto/tls" 46 "errors" 47 "fmt" 48 "html/template" 49 "io" 50 "path/filepath" 51 "reflect" 52 "runtime" 53 "strings" 54 "sync" 55 "sync/atomic" 56 57 "github.com/cloudwego/hertz/internal/bytesconv" 58 "github.com/cloudwego/hertz/internal/bytestr" 59 "github.com/cloudwego/hertz/internal/nocopy" 60 internalStats "github.com/cloudwego/hertz/internal/stats" 61 "github.com/cloudwego/hertz/pkg/app" 62 "github.com/cloudwego/hertz/pkg/app/server/binding" 63 "github.com/cloudwego/hertz/pkg/app/server/render" 64 "github.com/cloudwego/hertz/pkg/common/config" 65 errs "github.com/cloudwego/hertz/pkg/common/errors" 66 "github.com/cloudwego/hertz/pkg/common/hlog" 67 "github.com/cloudwego/hertz/pkg/common/tracer" 68 "github.com/cloudwego/hertz/pkg/common/tracer/stats" 69 "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" 70 "github.com/cloudwego/hertz/pkg/common/utils" 71 "github.com/cloudwego/hertz/pkg/network" 72 "github.com/cloudwego/hertz/pkg/network/standard" 73 "github.com/cloudwego/hertz/pkg/protocol" 74 "github.com/cloudwego/hertz/pkg/protocol/consts" 75 "github.com/cloudwego/hertz/pkg/protocol/http1" 76 "github.com/cloudwego/hertz/pkg/protocol/http1/factory" 77 "github.com/cloudwego/hertz/pkg/protocol/suite" 78 "github.com/cloudwego/hertz/pkg/route/param" 79 ) 80 81 const unknownTransporterName = "unknown" 82 83 var ( 84 defaultTransporter = standard.NewTransporter 85 86 errInitFailed = errs.NewPrivate("engine has been init already") 87 errAlreadyRunning = errs.NewPrivate("engine is already running") 88 errStatusNotRunning = errs.NewPrivate("engine is not running") 89 90 default404Body = []byte("404 page not found") 91 default405Body = []byte("405 method not allowed") 92 default400Body = []byte("400 bad request") 93 94 requiredHostBody = []byte("missing required Host header") 95 ) 96 97 type hijackConn struct { 98 network.Conn 
99 e *Engine 100 } 101 102 type CtxCallback func(ctx context.Context) 103 104 type CtxErrCallback func(ctx context.Context) error 105 106 // RouteInfo represents a request route's specification which contains method and path and its handler. 107 type RouteInfo struct { 108 Method string 109 Path string 110 Handler string 111 HandlerFunc app.HandlerFunc 112 } 113 114 // RoutesInfo defines a RouteInfo array. 115 type RoutesInfo []RouteInfo 116 117 type Engine struct { 118 noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used 119 120 // engine name 121 Name string 122 serverName atomic.Value 123 124 // Options for route and protocol server 125 options *config.Options 126 127 // route 128 RouterGroup 129 trees MethodTrees 130 131 maxParams uint16 132 133 allNoMethod app.HandlersChain 134 allNoRoute app.HandlersChain 135 noRoute app.HandlersChain 136 noMethod app.HandlersChain 137 138 // For render HTML 139 delims render.Delims 140 funcMap template.FuncMap 141 htmlRender render.HTMLRender 142 143 // NoHijackConnPool will control whether invite pool to acquire/release the hijackConn or not. 144 // If it is difficult to guarantee that hijackConn will not be closed repeatedly, set it to true. 145 NoHijackConnPool bool 146 hijackConnPool sync.Pool 147 // KeepHijackedConns is an opt-in disable of connection 148 // close by hertz after connections' HijackHandler returns. 149 // This allows to save goroutines, e.g. when hertz used to upgrade 150 // http connections to WS and connection goes to another handler, 151 // which will close it when needed. 
152 KeepHijackedConns bool 153 154 // underlying transport 155 transport network.Transporter 156 157 // trace 158 tracerCtl tracer.Controller 159 enableTrace bool 160 161 // protocol layer management 162 protocolSuite *suite.Config 163 protocolServers map[string]protocol.Server 164 protocolStreamServers map[string]protocol.StreamServer 165 166 // RequestContext pool 167 ctxPool sync.Pool 168 169 // Function to handle panics recovered from http handlers. 170 // It should be used to generate an error page and return the http error code 171 // 500 (Internal Server Error). 172 // The handler can be used to keep your server from crashing because of 173 // unrecovered panics. 174 PanicHandler app.HandlerFunc 175 176 // ContinueHandler is called after receiving the Expect 100 Continue Header 177 // 178 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 179 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1 180 // Using ContinueHandler a server can make decisioning on whether or not 181 // to read a potentially large request body based on the headers 182 // 183 // The default is to automatically read request bodies of Expect 100 Continue requests 184 // like they are normal requests 185 ContinueHandler func(header *protocol.RequestHeader) bool 186 187 // Indicates the engine status (Init/Running/Shutdown/Closed). 
188 status uint32 189 190 // Hook functions get triggered sequentially when engine start 191 OnRun []CtxErrCallback 192 193 // Hook functions get triggered simultaneously when engine shutdown 194 OnShutdown []CtxCallback 195 196 // Custom Functions 197 clientIPFunc app.ClientIP 198 formValueFunc app.FormValueFunc 199 200 // Custom Binder and Validator 201 binder binding.Binder 202 validator binding.StructValidator 203 } 204 205 func (engine *Engine) IsTraceEnable() bool { 206 return engine.enableTrace 207 } 208 209 func (engine *Engine) GetCtxPool() *sync.Pool { 210 return &engine.ctxPool 211 } 212 213 func (engine *Engine) GetOptions() *config.Options { 214 return engine.options 215 } 216 217 // SetTransporter only sets the global default value for the transporter. 218 // Use WithTransporter during engine creation to set the transporter for the engine. 219 func SetTransporter(transporter func(options *config.Options) network.Transporter) { 220 defaultTransporter = transporter 221 } 222 223 func (engine *Engine) GetTransporterName() (tName string) { 224 return getTransporterName(engine.transport) 225 } 226 227 func getTransporterName(transporter network.Transporter) (tName string) { 228 defer func() { 229 err := recover() 230 if err != nil || tName == "" { 231 tName = unknownTransporterName 232 } 233 }() 234 t := reflect.ValueOf(transporter).Type().String() 235 tName = strings.Split(strings.TrimPrefix(t, "*"), ".")[0] 236 return tName 237 } 238 239 // Deprecated: This only get the global default transporter - may not be the real one used by the engine. 240 // Use engine.GetTransporterName for the real transporter used. 
241 func GetTransporterName() (tName string) { 242 defer func() { 243 err := recover() 244 if err != nil || tName == "" { 245 tName = unknownTransporterName 246 } 247 }() 248 fName := runtime.FuncForPC(reflect.ValueOf(defaultTransporter).Pointer()).Name() 249 fSlice := strings.Split(fName, "/") 250 name := fSlice[len(fSlice)-1] 251 fSlice = strings.Split(name, ".") 252 tName = fSlice[0] 253 return 254 } 255 256 func (engine *Engine) IsStreamRequestBody() bool { 257 return engine.options.StreamRequestBody 258 } 259 260 func (engine *Engine) IsRunning() bool { 261 return atomic.LoadUint32(&engine.status) == statusRunning 262 } 263 264 func (engine *Engine) HijackConnHandle(c network.Conn, h app.HijackHandler) { 265 engine.hijackConnHandler(c, h) 266 } 267 268 func (engine *Engine) GetTracer() tracer.Controller { 269 return engine.tracerCtl 270 } 271 272 const ( 273 _ uint32 = iota 274 statusInitialized 275 statusRunning 276 statusShutdown 277 statusClosed 278 ) 279 280 // NewContext make a pure RequestContext without any http request/response information 281 // 282 // Set the Request filed before use it for handlers 283 func (engine *Engine) NewContext() *app.RequestContext { 284 return app.NewContext(engine.maxParams) 285 } 286 287 // Shutdown starts the server's graceful exit by next steps: 288 // 289 // 1. Trigger OnShutdown hooks concurrently and wait them until wait timeout or finish 290 // 2. Close the net listener, which means new connection won't be accepted 291 // 3. Wait all connections get closed: 292 // One connection gets closed after reaching out the shorter time of processing 293 // one request (in hand or next incoming), idleTimeout or ExitWaitTime 294 // 4. 
Exit 295 func (engine *Engine) Shutdown(ctx context.Context) (err error) { 296 if atomic.LoadUint32(&engine.status) != statusRunning { 297 return errStatusNotRunning 298 } 299 if !atomic.CompareAndSwapUint32(&engine.status, statusRunning, statusShutdown) { 300 return 301 } 302 303 ch := make(chan struct{}) 304 // trigger hooks if any 305 go engine.executeOnShutdownHooks(ctx, ch) 306 307 defer func() { 308 // ensure that the hook is executed until wait timeout or finish 309 select { 310 case <-ctx.Done(): 311 hlog.SystemLogger().Infof("Execute OnShutdownHooks timeout: error=%v", ctx.Err()) 312 return 313 case <-ch: 314 hlog.SystemLogger().Info("Execute OnShutdownHooks finish") 315 return 316 } 317 }() 318 319 if opt := engine.options; opt != nil && opt.Registry != nil { 320 if err = opt.Registry.Deregister(opt.RegistryInfo); err != nil { 321 hlog.SystemLogger().Errorf("Deregister error=%v", err) 322 return err 323 } 324 } 325 326 // call transport shutdown 327 if err := engine.transport.Shutdown(ctx); err != ctx.Err() { 328 return err 329 } 330 331 return 332 } 333 334 func (engine *Engine) executeOnShutdownHooks(ctx context.Context, ch chan struct{}) { 335 wg := sync.WaitGroup{} 336 for i := range engine.OnShutdown { 337 wg.Add(1) 338 go func(index int) { 339 defer wg.Done() 340 engine.OnShutdown[index](ctx) 341 }(i) 342 } 343 wg.Wait() 344 ch <- struct{}{} 345 } 346 347 func (engine *Engine) Run() (err error) { 348 if err = engine.Init(); err != nil { 349 return err 350 } 351 352 // trigger hooks if any 353 ctx := context.Background() 354 for i := range engine.OnRun { 355 if err = engine.OnRun[i](ctx); err != nil { 356 return err 357 } 358 } 359 360 if err = engine.MarkAsRunning(); err != nil { 361 return err 362 } 363 defer atomic.StoreUint32(&engine.status, statusClosed) 364 365 return engine.listenAndServe() 366 } 367 368 func (engine *Engine) Init() error { 369 // add built-in http1 server by default 370 if !engine.HasServer(suite.HTTP1) { 371 
engine.AddProtocol(suite.HTTP1, factory.NewServerFactory(newHttp1OptionFromEngine(engine))) 372 } 373 374 serverMap, streamServerMap, err := engine.protocolSuite.LoadAll(engine) 375 if err != nil { 376 return errs.New(err, errs.ErrorTypePrivate, "LoadAll protocol suite error") 377 } 378 379 engine.protocolServers = serverMap 380 engine.protocolStreamServers = streamServerMap 381 382 if engine.alpnEnable() { 383 engine.options.TLS.NextProtos = append(engine.options.TLS.NextProtos, suite.HTTP1) 384 } 385 386 if !atomic.CompareAndSwapUint32(&engine.status, 0, statusInitialized) { 387 return errInitFailed 388 } 389 return nil 390 } 391 392 func (engine *Engine) alpnEnable() bool { 393 return engine.options.TLS != nil && engine.options.ALPN 394 } 395 396 func (engine *Engine) listenAndServe() error { 397 hlog.SystemLogger().Infof("Using network library=%s", engine.GetTransporterName()) 398 return engine.transport.ListenAndServe(engine.onData) 399 } 400 401 func (c *hijackConn) Close() error { 402 if !c.e.KeepHijackedConns { 403 // when we do not keep hijacked connections, 404 // it is closed in hijackConnHandler. 
405 return nil 406 } 407 408 conn := c.Conn 409 c.e.releaseHijackConn(c) 410 return conn.Close() 411 } 412 413 func (engine *Engine) getNextProto(conn network.Conn) (proto string, err error) { 414 if tlsConn, ok := conn.(network.ConnTLSer); ok { 415 if engine.options.ReadTimeout > 0 { 416 if err := conn.SetReadTimeout(engine.options.ReadTimeout); err != nil { 417 hlog.SystemLogger().Errorf("BUG: error in SetReadDeadline=%s: error=%s", engine.options.ReadTimeout, err) 418 } 419 } 420 err = tlsConn.Handshake() 421 if err == nil { 422 proto = tlsConn.ConnectionState().NegotiatedProtocol 423 } 424 } 425 return 426 } 427 428 func (engine *Engine) onData(c context.Context, conn interface{}) (err error) { 429 switch conn := conn.(type) { 430 case network.Conn: 431 err = engine.Serve(c, conn) 432 case network.StreamConn: 433 err = engine.ServeStream(c, conn) 434 } 435 return 436 } 437 438 func errProcess(conn io.Closer, err error) { 439 if err == nil { 440 return 441 } 442 443 defer func() { 444 if err != nil { 445 conn.Close() 446 } 447 }() 448 449 // Quiet close the connection 450 if errors.Is(err, errs.ErrShortConnection) || errors.Is(err, errs.ErrIdleTimeout) { 451 return 452 } 453 454 // Do not process the hijack connection error 455 if errors.Is(err, errs.ErrHijacked) { 456 err = nil 457 return 458 } 459 460 // Get remote address 461 rip := getRemoteAddrFromCloser(conn) 462 463 // Handle Specific error 464 if hsp, ok := conn.(network.HandleSpecificError); ok { 465 if hsp.HandleSpecificError(err, rip) { 466 return 467 } 468 } 469 // other errors 470 hlog.SystemLogger().Errorf(hlog.EngineErrorFormat, err.Error(), rip) 471 } 472 473 func getRemoteAddrFromCloser(conn io.Closer) string { 474 if c, ok := conn.(network.Conn); ok { 475 if addr := c.RemoteAddr(); addr != nil { 476 return addr.String() 477 } 478 } 479 return "" 480 } 481 482 func (engine *Engine) Close() error { 483 if engine.htmlRender != nil { 484 engine.htmlRender.Close() //nolint:errcheck 485 } 486 return 
engine.transport.Close() 487 } 488 489 func (engine *Engine) GetServerName() []byte { 490 v := engine.serverName.Load() 491 var serverName []byte 492 if v == nil { 493 serverName = []byte(engine.Name) 494 if len(serverName) == 0 { 495 serverName = bytestr.DefaultServerName 496 } 497 engine.serverName.Store(serverName) 498 } else { 499 serverName = v.([]byte) 500 } 501 return serverName 502 } 503 504 func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) { 505 defer func() { 506 errProcess(conn, err) 507 }() 508 509 // H2C path 510 if engine.options.H2C { 511 // protocol sniffer 512 buf, _ := conn.Peek(len(bytestr.StrClientPreface)) 513 if bytes.Equal(buf, bytestr.StrClientPreface) && engine.protocolServers[suite.HTTP2] != nil { 514 return engine.protocolServers[suite.HTTP2].Serve(c, conn) 515 } 516 hlog.SystemLogger().Warn("HTTP2 server is not loaded, request is going to fallback to HTTP1 server") 517 } 518 519 // ALPN path 520 if engine.options.ALPN && engine.options.TLS != nil { 521 proto, err1 := engine.getNextProto(conn) 522 if err1 != nil { 523 // The client closes the connection when handshake. So just ignore it. 
524 if err1 == io.EOF { 525 return nil 526 } 527 if re, ok := err1.(tls.RecordHeaderError); ok && re.Conn != nil && utils.TLSRecordHeaderLooksLikeHTTP(re.RecordHeader) { 528 io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n") 529 re.Conn.Close() 530 return re 531 } 532 return err1 533 } 534 if server, ok := engine.protocolServers[proto]; ok { 535 return server.Serve(c, conn) 536 } 537 } 538 539 // HTTP1 path 540 err = engine.protocolServers[suite.HTTP1].Serve(c, conn) 541 542 return 543 } 544 545 func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) error { 546 // ALPN path 547 if engine.options.ALPN && engine.options.TLS != nil { 548 version := conn.GetVersion() 549 nextProtocol := versionToALNP(version) 550 if server, ok := engine.protocolStreamServers[nextProtocol]; ok { 551 return server.Serve(ctx, conn) 552 } 553 } 554 555 // default path 556 if server, ok := engine.protocolStreamServers[suite.HTTP3]; ok { 557 return server.Serve(ctx, conn) 558 } 559 return errs.ErrNotSupportProtocol 560 } 561 562 func (engine *Engine) initBinderAndValidator(opt *config.Options) { 563 // init validator 564 if opt.CustomValidator != nil { 565 customValidator, ok := opt.CustomValidator.(binding.StructValidator) 566 if !ok { 567 panic("customized validator does not implement binding.StructValidator") 568 } 569 engine.validator = customValidator 570 } else { 571 engine.validator = binding.NewValidator(binding.NewValidateConfig()) 572 if opt.ValidateConfig != nil { 573 vConf, ok := opt.ValidateConfig.(*binding.ValidateConfig) 574 if !ok { 575 panic("opt.ValidateConfig is not the '*binding.ValidateConfig' type") 576 } 577 engine.validator = binding.NewValidator(vConf) 578 } 579 } 580 581 if opt.CustomBinder != nil { 582 customBinder, ok := opt.CustomBinder.(binding.Binder) 583 if !ok { 584 panic("customized binder can not implement binding.Binder") 585 } 586 engine.binder = customBinder 587 return 588 
} 589 // Init binder. Due to the existence of the "BindAndValidate" interface, the Validator needs to be injected here. 590 defaultBindConfig := binding.NewBindConfig() 591 defaultBindConfig.Validator = engine.validator 592 engine.binder = binding.NewDefaultBinder(defaultBindConfig) 593 if opt.BindConfig != nil { 594 bConf, ok := opt.BindConfig.(*binding.BindConfig) 595 if !ok { 596 panic("opt.BindConfig is not the '*binding.BindConfig' type") 597 } 598 if bConf.Validator == nil { 599 bConf.Validator = engine.validator 600 } 601 engine.binder = binding.NewDefaultBinder(bConf) 602 } 603 } 604 605 func NewEngine(opt *config.Options) *Engine { 606 engine := &Engine{ 607 trees: make(MethodTrees, 0, 9), 608 RouterGroup: RouterGroup{ 609 Handlers: nil, 610 basePath: opt.BasePath, 611 root: true, 612 }, 613 transport: defaultTransporter(opt), 614 tracerCtl: &internalStats.Controller{}, 615 protocolServers: make(map[string]protocol.Server), 616 protocolStreamServers: make(map[string]protocol.StreamServer), 617 enableTrace: true, 618 options: opt, 619 } 620 engine.initBinderAndValidator(opt) 621 if opt.TransporterNewer != nil { 622 engine.transport = opt.TransporterNewer(opt) 623 } 624 engine.RouterGroup.engine = engine 625 626 traceLevel := initTrace(engine) 627 628 // prepare RequestContext pool 629 engine.ctxPool.New = func() interface{} { 630 ctx := engine.allocateContext() 631 if engine.enableTrace { 632 ti := traceinfo.NewTraceInfo() 633 ti.Stats().SetLevel(traceLevel) 634 ctx.SetTraceInfo(ti) 635 } 636 return ctx 637 } 638 639 // Init protocolSuite 640 engine.protocolSuite = suite.New() 641 642 return engine 643 } 644 645 func initTrace(engine *Engine) stats.Level { 646 for _, ti := range engine.options.Tracers { 647 if tracer, ok := ti.(tracer.Tracer); ok { 648 engine.tracerCtl.Append(tracer) 649 } 650 } 651 652 if !engine.tracerCtl.HasTracer() { 653 engine.enableTrace = false 654 } 655 656 traceLevel := stats.LevelDetailed 657 if tl, ok := 
engine.options.TraceLevel.(stats.Level); ok { 658 traceLevel = tl 659 } 660 return traceLevel 661 } 662 663 func debugPrintRoute(httpMethod, absolutePath string, handlers app.HandlersChain) { 664 nuHandlers := len(handlers) 665 handlerName := app.GetHandlerName(handlers.Last()) 666 if handlerName == "" { 667 handlerName = utils.NameOfFunction(handlers.Last()) 668 } 669 hlog.SystemLogger().Debugf("Method=%-6s absolutePath=%-25s --> handlerName=%s (num=%d handlers)", httpMethod, absolutePath, handlerName, nuHandlers) 670 } 671 672 func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) { 673 if len(path) == 0 { 674 panic("path should not be ''") 675 } 676 utils.Assert(path[0] == '/', "path must begin with '/'") 677 utils.Assert(method != "", "HTTP method can not be empty") 678 utils.Assert(len(handlers) > 0, "there must be at least one handler") 679 680 if !engine.options.DisablePrintRoute { 681 debugPrintRoute(method, path, handlers) 682 } 683 684 methodRouter := engine.trees.get(method) 685 if methodRouter == nil { 686 methodRouter = &router{method: method, root: &node{}, hasTsrHandler: make(map[string]bool)} 687 engine.trees = append(engine.trees, methodRouter) 688 } 689 methodRouter.addRoute(path, handlers) 690 691 // Update maxParams 692 if paramsCount := countParams(path); paramsCount > engine.maxParams { 693 engine.maxParams = paramsCount 694 } 695 } 696 697 func (engine *Engine) PrintRoute(method string) { 698 root := engine.trees.get(method) 699 printNode(root.root, 0) 700 } 701 702 // debug use 703 func printNode(node *node, level int) { 704 fmt.Println("node.prefix: " + node.prefix) 705 fmt.Println("node.ppath: " + node.ppath) 706 fmt.Printf("level: %#v\n\n", level) 707 for i := 0; i < len(node.children); i++ { 708 printNode(node.children[i], level+1) 709 } 710 } 711 712 func (engine *Engine) recv(ctx *app.RequestContext) { 713 if rcv := recover(); rcv != nil { 714 engine.PanicHandler(context.Background(), ctx) 715 } 716 } 717 718 
// ServeHTTP makes the router implement the Handler interface. 719 func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { 720 ctx.SetBinder(engine.binder) 721 ctx.SetValidator(engine.validator) 722 if engine.PanicHandler != nil { 723 defer engine.recv(ctx) 724 } 725 726 rPath := string(ctx.Request.URI().Path()) 727 728 // align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2 729 if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect { 730 ctx.SetHandlers(engine.Handlers) 731 serveError(c, ctx, consts.StatusBadRequest, requiredHostBody) 732 return 733 } 734 735 httpMethod := bytesconv.B2s(ctx.Request.Header.Method()) 736 unescape := false 737 if engine.options.UseRawPath { 738 rPath = string(ctx.Request.URI().PathOriginal()) 739 unescape = engine.options.UnescapePathValues 740 } 741 742 if engine.options.RemoveExtraSlash { 743 rPath = utils.CleanPath(rPath) 744 } 745 746 // Follow RFC7230#section-5.3 747 if rPath == "" || rPath[0] != '/' { 748 ctx.SetHandlers(engine.Handlers) 749 serveError(c, ctx, consts.StatusBadRequest, default400Body) 750 return 751 } 752 753 // if Params is re-assigned in HandlerFunc and the capacity is not enough we need to realloc 754 maxParams := int(engine.maxParams) 755 if cap(ctx.Params) < maxParams { 756 ctx.Params = make(param.Params, 0, maxParams) 757 } 758 759 // Find root of the tree for the given HTTP method 760 t := engine.trees 761 paramsPointer := &ctx.Params 762 for i, tl := 0, len(t); i < tl; i++ { 763 if t[i].method != httpMethod { 764 continue 765 } 766 // Find route in tree 767 value := t[i].find(rPath, paramsPointer, unescape) 768 769 if value.handlers != nil { 770 ctx.SetHandlers(value.handlers) 771 ctx.SetFullPath(value.fullPath) 772 ctx.Next(c) 773 return 774 } 775 if httpMethod != consts.MethodConnect && rPath != "/" { 776 if value.tsr && engine.options.RedirectTrailingSlash { 777 redirectTrailingSlash(ctx) 778 
return 779 } 780 if engine.options.RedirectFixedPath && redirectFixedPath(ctx, t[i].root, engine.options.RedirectFixedPath) { 781 return 782 } 783 } 784 break 785 } 786 787 if engine.options.HandleMethodNotAllowed { 788 for _, tree := range engine.trees { 789 if tree.method == httpMethod { 790 continue 791 } 792 if value := tree.find(rPath, paramsPointer, unescape); value.handlers != nil { 793 ctx.SetHandlers(engine.allNoMethod) 794 serveError(c, ctx, consts.StatusMethodNotAllowed, default405Body) 795 return 796 } 797 } 798 } 799 ctx.SetHandlers(engine.allNoRoute) 800 serveError(c, ctx, consts.StatusNotFound, default404Body) 801 } 802 803 func (engine *Engine) allocateContext() *app.RequestContext { 804 ctx := engine.NewContext() 805 ctx.Request.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) 806 ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) 807 ctx.SetClientIPFunc(engine.clientIPFunc) 808 ctx.SetFormValueFunc(engine.formValueFunc) 809 return ctx 810 } 811 812 func serveError(c context.Context, ctx *app.RequestContext, code int, defaultMessage []byte) { 813 ctx.SetStatusCode(code) 814 ctx.Next(c) 815 if ctx.Response.StatusCode() == code { 816 // if body exists(maybe customized by users), leave it alone. 817 if ctx.Response.HasBodyBytes() || ctx.Response.IsBodyStream() { 818 return 819 } 820 ctx.Response.Header.Set("Content-Type", "text/plain") 821 ctx.Response.SetBody(defaultMessage) 822 } 823 } 824 825 func trailingSlashURL(ts string) string { 826 tmpURI := ts + "/" 827 if length := len(ts); length > 1 && ts[length-1] == '/' { 828 tmpURI = ts[:length-1] 829 } 830 return tmpURI 831 } 832 833 func redirectTrailingSlash(c *app.RequestContext) { 834 p := bytesconv.B2s(c.Request.URI().Path()) 835 if prefix := utils.CleanPath(bytesconv.B2s(c.Request.Header.Peek("X-Forwarded-Prefix"))); prefix != "." 
{ 836 p = prefix + "/" + p 837 } 838 839 tmpURI := trailingSlashURL(p) 840 841 query := c.Request.URI().QueryString() 842 843 if len(query) > 0 { 844 tmpURI = tmpURI + "?" + bytesconv.B2s(query) 845 } 846 847 c.Request.SetRequestURI(tmpURI) 848 redirectRequest(c) 849 } 850 851 func redirectRequest(c *app.RequestContext) { 852 code := consts.StatusMovedPermanently // Permanent redirect, request with GET method 853 if bytesconv.B2s(c.Request.Header.Method()) != consts.MethodGet { 854 code = consts.StatusTemporaryRedirect 855 } 856 857 c.Redirect(code, c.Request.URI().RequestURI()) 858 } 859 860 func redirectFixedPath(c *app.RequestContext, root *node, trailingSlash bool) bool { 861 rPath := bytesconv.B2s(c.Request.URI().Path()) 862 if fixedPath, ok := root.findCaseInsensitivePath(utils.CleanPath(rPath), trailingSlash); ok { 863 c.Request.SetRequestURI(bytesconv.B2s(fixedPath)) 864 redirectRequest(c) 865 return true 866 } 867 return false 868 } 869 870 // NoRoute adds handlers for NoRoute. It returns a 404 code by default. 871 func (engine *Engine) NoRoute(handlers ...app.HandlerFunc) { 872 engine.noRoute = handlers 873 engine.rebuild404Handlers() 874 } 875 876 // NoMethod sets the handlers called when the HTTP method does not match. 877 func (engine *Engine) NoMethod(handlers ...app.HandlerFunc) { 878 engine.noMethod = handlers 879 engine.rebuild405Handlers() 880 } 881 882 func (engine *Engine) rebuild404Handlers() { 883 engine.allNoRoute = engine.combineHandlers(engine.noRoute) 884 } 885 886 func (engine *Engine) rebuild405Handlers() { 887 engine.allNoMethod = engine.combineHandlers(engine.noMethod) 888 } 889 890 // Use attaches a global middleware to the router. ie. the middleware attached though Use() will be 891 // included in the handlers chain for every single request. Even 404, 405, static files... 892 // 893 // For example, this is the right place for a logger or error management middleware. 
894 func (engine *Engine) Use(middleware ...app.HandlerFunc) IRoutes { 895 engine.RouterGroup.Use(middleware...) 896 engine.rebuild404Handlers() 897 engine.rebuild405Handlers() 898 return engine 899 } 900 901 // LoadHTMLGlob loads HTML files identified by glob pattern 902 // and associates the result with HTML renderer. 903 func (engine *Engine) LoadHTMLGlob(pattern string) { 904 tmpl := template.Must(template.New(""). 905 Delims(engine.delims.Left, engine.delims.Right). 906 Funcs(engine.funcMap). 907 ParseGlob(pattern)) 908 909 if engine.options.AutoReloadRender { 910 files, err := filepath.Glob(pattern) 911 if err != nil { 912 hlog.SystemLogger().Errorf("LoadHTMLGlob: %v", err) 913 return 914 } 915 engine.SetAutoReloadHTMLTemplate(tmpl, files) 916 return 917 } 918 919 engine.SetHTMLTemplate(tmpl) 920 } 921 922 // LoadHTMLFiles loads a slice of HTML files 923 // and associates the result with HTML renderer. 924 func (engine *Engine) LoadHTMLFiles(files ...string) { 925 tmpl := template.Must(template.New(""). 926 Delims(engine.delims.Left, engine.delims.Right). 927 Funcs(engine.funcMap). 928 ParseFiles(files...)) 929 930 if engine.options.AutoReloadRender { 931 engine.SetAutoReloadHTMLTemplate(tmpl, files) 932 return 933 } 934 935 engine.SetHTMLTemplate(tmpl) 936 } 937 938 // SetHTMLTemplate associate a template with HTML renderer. 939 func (engine *Engine) SetHTMLTemplate(tmpl *template.Template) { 940 engine.htmlRender = render.HTMLProduction{Template: tmpl.Funcs(engine.funcMap)} 941 } 942 943 // SetAutoReloadHTMLTemplate associate a template with HTML renderer. 944 func (engine *Engine) SetAutoReloadHTMLTemplate(tmpl *template.Template, files []string) { 945 engine.htmlRender = &render.HTMLDebug{ 946 Template: tmpl, 947 Files: files, 948 FuncMap: engine.funcMap, 949 Delims: engine.delims, 950 RefreshInterval: engine.options.AutoReloadInterval, 951 } 952 } 953 954 // SetFuncMap sets the funcMap used for template.funcMap. 
955 func (engine *Engine) SetFuncMap(funcMap template.FuncMap) { 956 engine.funcMap = funcMap 957 } 958 959 func (engine *Engine) SetClientIPFunc(f app.ClientIP) { 960 engine.clientIPFunc = f 961 } 962 963 func (engine *Engine) SetFormValueFunc(f app.FormValueFunc) { 964 engine.formValueFunc = f 965 } 966 967 // Delims sets template left and right delims and returns an Engine instance. 968 func (engine *Engine) Delims(left, right string) *Engine { 969 engine.delims = render.Delims{Left: left, Right: right} 970 return engine 971 } 972 973 func (engine *Engine) acquireHijackConn(c network.Conn) *hijackConn { 974 if engine.NoHijackConnPool { 975 return &hijackConn{ 976 Conn: c, 977 e: engine, 978 } 979 } 980 v := engine.hijackConnPool.Get() 981 if v == nil { 982 return &hijackConn{ 983 Conn: c, 984 e: engine, 985 } 986 } 987 hjc := v.(*hijackConn) 988 hjc.Conn = c 989 return hjc 990 } 991 992 func (engine *Engine) releaseHijackConn(hjc *hijackConn) { 993 if engine.NoHijackConnPool { 994 return 995 } 996 hjc.Conn = nil 997 engine.hijackConnPool.Put(hjc) 998 } 999 1000 func (engine *Engine) hijackConnHandler(c network.Conn, h app.HijackHandler) { 1001 hjc := engine.acquireHijackConn(c) 1002 h(hjc) 1003 1004 if !engine.KeepHijackedConns { 1005 c.Close() 1006 engine.releaseHijackConn(hjc) 1007 } 1008 } 1009 1010 // Routes returns a slice of registered routes, including some useful information, such as: 1011 // the http method, path and the handler name. 1012 func (engine *Engine) Routes() (routes RoutesInfo) { 1013 for _, tree := range engine.trees { 1014 routes = iterate(tree.method, routes, tree.root) 1015 } 1016 1017 return routes 1018 } 1019 1020 func (engine *Engine) AddProtocol(protocol string, factory interface{}) { 1021 engine.protocolSuite.Add(protocol, factory) 1022 } 1023 1024 // SetAltHeader sets the value of "Alt-Svc" header for protocols other than targetProtocol. 
1025 func (engine *Engine) SetAltHeader(targetProtocol, altHeaderValue string) { 1026 engine.protocolSuite.SetAltHeader(targetProtocol, altHeaderValue) 1027 } 1028 1029 func (engine *Engine) HasServer(name string) bool { 1030 return engine.protocolSuite.Get(name) != nil 1031 } 1032 1033 // iterate iterates the method tree by depth firstly. 1034 func iterate(method string, routes RoutesInfo, root *node) RoutesInfo { 1035 if len(root.handlers) > 0 { 1036 handlerFunc := root.handlers.Last() 1037 routes = append(routes, RouteInfo{ 1038 Method: method, 1039 Path: root.ppath, 1040 Handler: utils.NameOfFunction(handlerFunc), 1041 HandlerFunc: handlerFunc, 1042 }) 1043 } 1044 1045 for _, child := range root.children { 1046 routes = iterate(method, routes, child) 1047 } 1048 1049 if root.paramChild != nil { 1050 routes = iterate(method, routes, root.paramChild) 1051 } 1052 1053 if root.anyChild != nil { 1054 routes = iterate(method, routes, root.anyChild) 1055 } 1056 return routes 1057 } 1058 1059 // for built-in http1 impl only. 
1060 func newHttp1OptionFromEngine(engine *Engine) *http1.Option { 1061 opt := &http1.Option{ 1062 StreamRequestBody: engine.options.StreamRequestBody, 1063 GetOnly: engine.options.GetOnly, 1064 DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, 1065 DisableKeepalive: engine.options.DisableKeepalive, 1066 NoDefaultServerHeader: engine.options.NoDefaultServerHeader, 1067 MaxRequestBodySize: engine.options.MaxRequestBodySize, 1068 IdleTimeout: engine.options.IdleTimeout, 1069 ReadTimeout: engine.options.ReadTimeout, 1070 ServerName: engine.GetServerName(), 1071 ContinueHandler: engine.ContinueHandler, 1072 TLS: engine.options.TLS, 1073 HTMLRender: engine.htmlRender, 1074 EnableTrace: engine.IsTraceEnable(), 1075 HijackConnHandle: engine.HijackConnHandle, 1076 DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing, 1077 NoDefaultDate: engine.options.NoDefaultDate, 1078 NoDefaultContentType: engine.options.NoDefaultContentType, 1079 } 1080 // Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero. 1081 // Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons. 1082 if opt.IdleTimeout == 0 && engine.GetTransporterName() == "standard" { 1083 opt.IdleTimeout = -1 1084 } 1085 return opt 1086 } 1087 1088 func versionToALNP(v uint32) string { 1089 if v == network.Version1 || v == network.Version2 { 1090 return suite.HTTP3 1091 } 1092 if v == network.VersionTLS || v == network.VersionDraft29 { 1093 return suite.HTTP3Draft29 1094 } 1095 return "" 1096 } 1097 1098 // MarkAsRunning will mark the status of the hertz engine as "running". 1099 // Warning: do not call this method by yourself, unless you know what you are doing. 1100 func (engine *Engine) MarkAsRunning() (err error) { 1101 if !atomic.CompareAndSwapUint32(&engine.status, statusInitialized, statusRunning) { 1102 return errAlreadyRunning 1103 } 1104 return nil 1105 }