github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/server.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package http1 18 19 import ( 20 "context" 21 "crypto/tls" 22 "errors" 23 "io" 24 "net" 25 "sync" 26 "time" 27 28 "github.com/cloudwego/hertz/internal/bytestr" 29 internalStats "github.com/cloudwego/hertz/internal/stats" 30 "github.com/cloudwego/hertz/pkg/app" 31 "github.com/cloudwego/hertz/pkg/app/server/render" 32 errs "github.com/cloudwego/hertz/pkg/common/errors" 33 "github.com/cloudwego/hertz/pkg/common/tracer/stats" 34 "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" 35 "github.com/cloudwego/hertz/pkg/common/utils" 36 "github.com/cloudwego/hertz/pkg/network" 37 "github.com/cloudwego/hertz/pkg/protocol" 38 "github.com/cloudwego/hertz/pkg/protocol/consts" 39 "github.com/cloudwego/hertz/pkg/protocol/http1/ext" 40 "github.com/cloudwego/hertz/pkg/protocol/http1/req" 41 "github.com/cloudwego/hertz/pkg/protocol/http1/resp" 42 "github.com/cloudwego/hertz/pkg/protocol/suite" 43 ) 44 45 func init() { 46 if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil { 47 disabaleRequestContextPool = b 48 } 49 } 50 51 // NextProtoTLS is the NPN/ALPN protocol negotiated during 52 // HTTP/1.1's TLS setup. 53 // Also used for server addressing 54 const NextProtoTLS = suite.HTTP1 55 56 var ( 57 errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil) 58 errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil) 59 errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") 60 errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") 61 62 disabaleRequestContextPool = false 63 ) 64 65 type Option struct { 66 StreamRequestBody bool 67 GetOnly bool 68 NoDefaultDate bool 69 NoDefaultContentType bool 70 DisablePreParseMultipartForm bool 71 DisableKeepalive bool 72 NoDefaultServerHeader bool 73 DisableHeaderNamesNormalizing bool 74 MaxRequestBodySize int 75 IdleTimeout time.Duration 76 ReadTimeout time.Duration 77 ServerName []byte 78 TLS *tls.Config 79 HTMLRender render.HTMLRender 80 EnableTrace bool 81 ContinueHandler func(header *protocol.RequestHeader) bool 82 HijackConnHandle func(c network.Conn, h app.HijackHandler) 83 } 84 85 type Server struct { 86 Option 87 Core suite.Core 88 89 eventStackPool *sync.Pool 90 } 91 92 func (s Server) getRequestContext() *app.RequestContext { 93 if disabaleRequestContextPool { 94 return &app.RequestContext{} 95 } 96 return s.Core.GetCtxPool().Get().(*app.RequestContext) 97 } 98 99 func (s Server) putRequestContext(ctx *app.RequestContext) { 100 if disabaleRequestContextPool { 101 return 102 } 103 ctx.Reset() 104 s.Core.GetCtxPool().Put(ctx) 105 } 106 107 func (s Server) Serve(c context.Context, conn network.Conn) (err error) { 108 var ( 109 zr network.Reader 110 zw network.Writer 111 112 serverName []byte 113 isHTTP11 bool 114 connectionClose bool 115 116 continueReadingRequest = true 117 118 hijackHandler app.HijackHandler 119 120 // HTTP1 path 121 // 1. Get a request context 122 // 2. Prepare it 123 // 3. Process it 124 // 4. Reset and recycle(in pooled mode) 125 ctx = s.getRequestContext() 126 127 traceCtl = s.Core.GetTracer() 128 eventsToTrigger *eventStack 129 130 // Use a new variable to hold the standard context to avoid modify the initial 131 // context. 132 cc = c 133 ) 134 135 if s.EnableTrace { 136 eventsToTrigger = s.eventStackPool.Get().(*eventStack) 137 } 138 139 defer func() { 140 if s.EnableTrace { 141 // in case of error, we need to trigger all events 142 if eventsToTrigger != nil { 143 for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() { 144 last(ctx.GetTraceInfo(), err) 145 } 146 s.eventStackPool.Put(eventsToTrigger) 147 } 148 if shouldRecordInTraceError(err) { 149 traceCtl.DoFinish(cc, ctx, err) 150 } else { 151 traceCtl.DoFinish(cc, ctx, nil) 152 } 153 } 154 155 // Hijack may release and close the connection already 156 if zr != nil && !errors.Is(err, errs.ErrHijacked) { 157 zr.Release() //nolint:errcheck 158 zr = nil 159 } 160 161 if ctx.IsExiled() { 162 return 163 } 164 165 s.putRequestContext(ctx) 166 }() 167 168 ctx.HTMLRender = s.HTMLRender 169 ctx.SetConn(conn) 170 ctx.Request.SetIsTLS(s.TLS != nil) 171 ctx.SetEnableTrace(s.EnableTrace) 172 173 if !s.NoDefaultServerHeader { 174 serverName = s.ServerName 175 } 176 177 connRequestNum := uint64(0) 178 179 for { 180 connRequestNum++ 181 182 if zr == nil { 183 zr = ctx.GetReader() 184 } 185 186 // If this is a keep-alive connection we want to try and read the first bytes 187 // within the idle time. 188 if connRequestNum > 1 { 189 ctx.GetConn().SetReadTimeout(s.IdleTimeout) //nolint:errcheck 190 191 _, err = zr.Peek(4) 192 // This is not the first request, and we haven't read a single byte 193 // of a new request yet. This means it's just a keep-alive connection 194 // closing down either because the remote closed it or because 195 // or a read timeout on our side. Either way just close the connection 196 // and don't return any error response. 197 if err != nil { 198 err = errIdleTimeout 199 return 200 } 201 202 // Reset the real read timeout for the coming request 203 ctx.GetConn().SetReadTimeout(s.ReadTimeout) //nolint:errcheck 204 } 205 206 if s.EnableTrace { 207 cc = traceCtl.DoStart(c, ctx) 208 internalStats.Record(ctx.GetTraceInfo(), stats.ReadHeaderStart, err) 209 eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { 210 internalStats.Record(ti, stats.ReadHeaderFinish, err) 211 }) 212 } 213 214 ctx.Response.Header.SetNoDefaultDate(s.NoDefaultDate) 215 ctx.Response.Header.SetNoDefaultContentType(s.NoDefaultContentType) 216 217 if s.DisableHeaderNamesNormalizing { 218 ctx.Request.Header.DisableNormalizing() 219 ctx.Response.Header.DisableNormalizing() 220 } 221 222 // Read Headers 223 if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil { 224 if s.EnableTrace { 225 // read header finished 226 if last := eventsToTrigger.pop(); last != nil { 227 last(ctx.GetTraceInfo(), err) 228 } 229 internalStats.Record(ctx.GetTraceInfo(), stats.ReadBodyStart, err) 230 eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { 231 internalStats.Record(ti, stats.ReadBodyFinish, err) 232 }) 233 } 234 // Read body 235 if s.StreamRequestBody { 236 err = req.ReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) 237 } else { 238 err = req.ReadLimitBody(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) 239 } 240 } 241 242 if s.EnableTrace { 243 if ctx.Request.Header.ContentLength() >= 0 { 244 ctx.GetTraceInfo().Stats().SetRecvSize(len(ctx.Request.Header.RawHeaders()) + ctx.Request.Header.ContentLength()) 245 } else { 246 ctx.GetTraceInfo().Stats().SetRecvSize(0) 247 } 248 // read body finished 249 if last := eventsToTrigger.pop(); last != nil { 250 last(ctx.GetTraceInfo(), err) 251 } 252 } 253 254 if err != nil { 255 if errors.Is(err, errs.ErrNothingRead) { 256 return nil 257 } 258 259 if err == io.EOF { 260 return errUnexpectedEOF 261 } 262 writeErrorResponse(zw, ctx, serverName, err) 263 return 264 } 265 266 // 'Expect: 100-continue' request handling. 267 // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details. 268 if ctx.Request.MayContinue() { 269 // Allow the ability to deny reading the incoming request body 270 if s.ContinueHandler != nil { 271 if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest { 272 ctx.SetStatusCode(consts.StatusExpectationFailed) 273 } 274 } 275 276 if continueReadingRequest { 277 zw = ctx.GetWriter() 278 // Send 'HTTP/1.1 100 Continue' response. 279 _, err = zw.WriteBinary(bytestr.StrResponseContinue) 280 if err != nil { 281 return 282 } 283 err = zw.Flush() 284 if err != nil { 285 return 286 } 287 288 // Read body. 289 if zr == nil { 290 zr = ctx.GetReader() 291 } 292 if s.StreamRequestBody { 293 err = req.ContinueReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm) 294 } else { 295 err = req.ContinueReadBody(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm) 296 } 297 if err != nil { 298 writeErrorResponse(zw, ctx, serverName, err) 299 return 300 } 301 } 302 } 303 304 connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose() 305 isHTTP11 = ctx.Request.Header.IsHTTP11() 306 307 if serverName != nil { 308 ctx.Response.Header.SetServerBytes(serverName) 309 } 310 if s.EnableTrace { 311 internalStats.Record(ctx.GetTraceInfo(), stats.ServerHandleStart, err) 312 eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { 313 internalStats.Record(ti, stats.ServerHandleFinish, err) 314 }) 315 } 316 // Handle the request 317 // 318 // NOTE: All middlewares and business handler will be executed in this. And at this point, the request has been parsed 319 // and the route has been matched. 320 s.Core.ServeHTTP(cc, ctx) 321 if s.EnableTrace { 322 // application layer handle finished 323 if last := eventsToTrigger.pop(); last != nil { 324 last(ctx.GetTraceInfo(), err) 325 } 326 } 327 328 // exit check 329 if !s.Core.IsRunning() { 330 connectionClose = true 331 } 332 333 if !ctx.IsGet() && ctx.IsHead() { 334 ctx.Response.SkipBody = true 335 } 336 337 hijackHandler = ctx.GetHijackHandler() 338 ctx.SetHijackHandler(nil) 339 340 connectionClose = connectionClose || ctx.Response.ConnectionClose() 341 if connectionClose { 342 ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrClose) 343 } else if !isHTTP11 { 344 ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrKeepAlive) 345 } 346 347 if zw == nil { 348 zw = ctx.GetWriter() 349 } 350 if s.EnableTrace { 351 internalStats.Record(ctx.GetTraceInfo(), stats.WriteStart, err) 352 eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { 353 internalStats.Record(ti, stats.WriteFinish, err) 354 }) 355 } 356 if err = writeResponse(ctx, zw); err != nil { 357 return 358 } 359 360 if s.EnableTrace { 361 if ctx.Response.Header.ContentLength() > 0 { 362 ctx.GetTraceInfo().Stats().SetSendSize(ctx.Response.Header.GetHeaderLength() + ctx.Response.Header.ContentLength()) 363 } else { 364 ctx.GetTraceInfo().Stats().SetSendSize(0) 365 } 366 } 367 368 // Release the zeroCopyReader before flush to prevent data race 369 if zr != nil { 370 zr.Release() //nolint:errcheck 371 zr = nil 372 } 373 // Flush the response. 374 if err = zw.Flush(); err != nil { 375 return 376 } 377 if s.EnableTrace { 378 // write finished 379 if last := eventsToTrigger.pop(); last != nil { 380 last(ctx.GetTraceInfo(), err) 381 } 382 } 383 384 // Release request body stream 385 if ctx.Request.IsBodyStream() { 386 err = ext.ReleaseBodyStream(ctx.RequestBodyStream()) 387 if err != nil { 388 return 389 } 390 } 391 392 if hijackHandler != nil { 393 // Hijacked conn process the timeout by itself 394 err = ctx.GetConn().SetReadTimeout(0) 395 if err != nil { 396 return 397 } 398 399 // Hijack and block the connection until the hijackHandler return 400 s.HijackConnHandle(ctx.GetConn(), hijackHandler) 401 err = errHijacked 402 return 403 } 404 405 if connectionClose { 406 return errShortConnection 407 } 408 // Back to network layer to trigger. 409 // For now, only netpoll network mode has this feature. 410 if s.IdleTimeout == 0 { 411 return 412 } 413 // general case 414 if s.EnableTrace { 415 if shouldRecordInTraceError(err) { 416 traceCtl.DoFinish(cc, ctx, err) 417 } else { 418 traceCtl.DoFinish(cc, ctx, nil) 419 } 420 } 421 422 ctx.ResetWithoutConn() 423 } 424 } 425 426 func NewServer() *Server { 427 return &Server{ 428 eventStackPool: &sync.Pool{ 429 New: func() interface{} { 430 return &eventStack{} 431 }, 432 }, 433 } 434 } 435 436 func writeErrorResponse(zw network.Writer, ctx *app.RequestContext, serverName []byte, err error) network.Writer { 437 errorHandler := defaultErrorHandler 438 439 errorHandler(ctx, err) 440 441 if serverName != nil { 442 ctx.Response.Header.SetServerBytes(serverName) 443 } 444 ctx.SetConnectionClose() 445 if zw == nil { 446 zw = ctx.GetWriter() 447 } 448 writeResponse(ctx, zw) //nolint:errcheck 449 zw.Flush() //nolint:errcheck 450 return zw 451 } 452 453 func writeResponse(ctx *app.RequestContext, w network.Writer) error { 454 // Skip default response writing logic if it has been hijacked 455 if ctx.Response.GetHijackWriter() != nil { 456 return ctx.Response.GetHijackWriter().Finalize() 457 } 458 459 err := resp.Write(&ctx.Response, w) 460 if err != nil { 461 return err 462 } 463 464 return err 465 } 466 467 func defaultErrorHandler(ctx *app.RequestContext, err error) { 468 if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { 469 ctx.AbortWithMsg("Request timeout", consts.StatusRequestTimeout) 470 } else if errors.Is(err, errs.ErrBodyTooLarge) { 471 ctx.AbortWithMsg("Request Entity Too Large", consts.StatusRequestEntityTooLarge) 472 } else { 473 ctx.AbortWithMsg("Error when parsing request", consts.StatusBadRequest) 474 } 475 } 476 477 type eventStack []func(ti traceinfo.TraceInfo, err error) 478 479 func (e *eventStack) isEmpty() bool { 480 return len(*e) == 0 481 } 482 483 func (e *eventStack) push(f func(ti traceinfo.TraceInfo, err error)) { 484 *e = append(*e, f) 485 } 486 487 func (e *eventStack) pop() func(ti traceinfo.TraceInfo, err error) { 488 if e.isEmpty() { 489 return nil 490 } 491 last := (*e)[len(*e)-1] 492 *e = (*e)[:len(*e)-1] 493 return last 494 } 495 496 func shouldRecordInTraceError(err error) bool { 497 if err == nil { 498 return false 499 } 500 501 if errors.Is(err, errs.ErrIdleTimeout) { 502 return false 503 } 504 505 if errors.Is(err, errs.ErrHijacked) { 506 return false 507 } 508 509 if errors.Is(err, errs.ErrShortConnection) { 510 return false 511 } 512 513 return true 514 }