github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/server.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package http1
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"errors"
    23  	"io"
    24  	"net"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/cloudwego/hertz/internal/bytestr"
    29  	internalStats "github.com/cloudwego/hertz/internal/stats"
    30  	"github.com/cloudwego/hertz/pkg/app"
    31  	"github.com/cloudwego/hertz/pkg/app/server/render"
    32  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    33  	"github.com/cloudwego/hertz/pkg/common/tracer/stats"
    34  	"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
    35  	"github.com/cloudwego/hertz/pkg/common/utils"
    36  	"github.com/cloudwego/hertz/pkg/network"
    37  	"github.com/cloudwego/hertz/pkg/protocol"
    38  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    39  	"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
    40  	"github.com/cloudwego/hertz/pkg/protocol/http1/req"
    41  	"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
    42  	"github.com/cloudwego/hertz/pkg/protocol/suite"
    43  )
    44  
    45  func init() {
    46  	if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil {
    47  		disabaleRequestContextPool = b
    48  	}
    49  }
    50  
// NextProtoTLS is the NPN/ALPN protocol identifier negotiated during the TLS
// handshake for HTTP/1.1.
//
// It has the same value as suite.HTTP1, so it also serves as the name this
// protocol server is addressed by (presumably the key under which it is
// registered in the protocol suite — confirm).
const NextProtoTLS = suite.HTTP1
    55  
var (
	// errHijacked is returned by Server.Serve after the connection has been
	// handed to HijackConnHandle.
	errHijacked        = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil)
	// errIdleTimeout is returned by Server.Serve when a keep-alive connection
	// yields no bytes for its next request (remote close or idle read timeout).
	errIdleTimeout     = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil)
	// errShortConnection is returned by Server.Serve after a response that
	// closes the connection has been written.
	errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
	// errUnexpectedEOF is returned instead of io.EOF when the connection ends
	// while a request is being read.
	errUnexpectedEOF   = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")

	// disabaleRequestContextPool (sic) makes the server allocate a fresh
	// RequestContext per connection instead of using the core's pool. It is
	// set in init from the HERTZ_DISABLE_REQUEST_CONTEXT_POOL environment
	// variable.
	disabaleRequestContextPool = false
)
    64  
// Option holds the HTTP/1.1 protocol-level settings consulted by Server.Serve.
type Option struct {
	// StreamRequestBody selects the streaming body readers
	// (req.ReadBodyStream / req.ContinueReadBodyStream) instead of
	// req.ReadLimitBody / req.ContinueReadBody.
	StreamRequestBody             bool
	// GetOnly is forwarded unchanged to the request body readers.
	GetOnly                       bool
	// NoDefaultDate is applied to every response header via SetNoDefaultDate.
	NoDefaultDate                 bool
	// NoDefaultContentType is applied to every response header via
	// SetNoDefaultContentType.
	NoDefaultContentType          bool
	// DisablePreParseMultipartForm is negated and passed to the body readers
	// as their "pre-parse multipart form" flag.
	DisablePreParseMultipartForm  bool
	// DisableKeepalive makes every response carry `Connection: close` and
	// ends Serve after the first request.
	DisableKeepalive              bool
	// NoDefaultServerHeader suppresses the Server response header that would
	// otherwise be set from ServerName.
	NoDefaultServerHeader         bool
	// DisableHeaderNamesNormalizing turns off header-name normalizing on both
	// the request and the response header.
	DisableHeaderNamesNormalizing bool
	// MaxRequestBodySize is passed to the body readers as their body size limit.
	MaxRequestBodySize            int
	// IdleTimeout is the read timeout used while waiting for the next request
	// on a keep-alive connection. When it is zero, Serve returns after each
	// request and leaves further reads to the network layer.
	IdleTimeout                   time.Duration
	// ReadTimeout is restored on a keep-alive connection once the first bytes
	// of its next request have arrived. NOTE(review): Serve does not set it
	// for the first request on a connection; presumably the network layer
	// does — confirm.
	ReadTimeout                   time.Duration
	// ServerName is used as the Server response header value unless
	// NoDefaultServerHeader is set.
	ServerName                    []byte
	// TLS is only checked for nil-ness here: a non-nil value marks requests as TLS.
	TLS                           *tls.Config
	// HTMLRender is assigned to the connection's RequestContext.
	HTMLRender                    render.HTMLRender
	// EnableTrace turns on tracer DoStart/DoFinish calls and per-phase stats
	// events (read header, read body, server handle, write).
	EnableTrace                   bool
	// ContinueHandler, if set, decides whether a request carrying
	// `Expect: 100-continue` is accepted. Returning false sets status 417
	// (Expectation Failed) and the request body is not read.
	ContinueHandler               func(header *protocol.RequestHeader) bool
	// HijackConnHandle is called with the connection and the hijack handler
	// after a hijacking request's response has been flushed; Serve blocks until
	// it returns. It must be non-nil if any handler hijacks a connection.
	HijackConnHandle              func(c network.Conn, h app.HijackHandler)
}
    84  
// Server is the HTTP/1.1 protocol server. It embeds Option for its settings
// and relies on Core for request handling (ServeHTTP), the RequestContext
// pool, the tracer and the running-state check.
//
// Construct it with NewServer: eventStackPool must be non-nil whenever
// EnableTrace is set, because Serve takes an eventStack from it for each
// connection.
type Server struct {
	Option
	Core suite.Core

	// eventStackPool recycles the per-connection stacks of pending trace
	// callbacks.
	eventStackPool *sync.Pool
}
    91  
    92  func (s Server) getRequestContext() *app.RequestContext {
    93  	if disabaleRequestContextPool {
    94  		return &app.RequestContext{}
    95  	}
    96  	return s.Core.GetCtxPool().Get().(*app.RequestContext)
    97  }
    98  
    99  func (s Server) putRequestContext(ctx *app.RequestContext) {
   100  	if disabaleRequestContextPool {
   101  		return
   102  	}
   103  	ctx.Reset()
   104  	s.Core.GetCtxPool().Put(ctx)
   105  }
   106  
   107  func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
   108  	var (
   109  		zr network.Reader
   110  		zw network.Writer
   111  
   112  		serverName      []byte
   113  		isHTTP11        bool
   114  		connectionClose bool
   115  
   116  		continueReadingRequest = true
   117  
   118  		hijackHandler app.HijackHandler
   119  
   120  		// HTTP1 path
   121  		// 1. Get a request context
   122  		// 2. Prepare it
   123  		// 3. Process it
   124  		// 4. Reset and recycle(in pooled mode)
   125  		ctx = s.getRequestContext()
   126  
   127  		traceCtl        = s.Core.GetTracer()
   128  		eventsToTrigger *eventStack
   129  
   130  		// Use a new variable to hold the standard context to avoid modify the initial
   131  		// context.
   132  		cc = c
   133  	)
   134  
   135  	if s.EnableTrace {
   136  		eventsToTrigger = s.eventStackPool.Get().(*eventStack)
   137  	}
   138  
   139  	defer func() {
   140  		if s.EnableTrace {
   141  			// in case of error, we need to trigger all events
   142  			if eventsToTrigger != nil {
   143  				for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() {
   144  					last(ctx.GetTraceInfo(), err)
   145  				}
   146  				s.eventStackPool.Put(eventsToTrigger)
   147  			}
   148  			if shouldRecordInTraceError(err) {
   149  				traceCtl.DoFinish(cc, ctx, err)
   150  			} else {
   151  				traceCtl.DoFinish(cc, ctx, nil)
   152  			}
   153  		}
   154  
   155  		// Hijack may release and close the connection already
   156  		if zr != nil && !errors.Is(err, errs.ErrHijacked) {
   157  			zr.Release() //nolint:errcheck
   158  			zr = nil
   159  		}
   160  
   161  		if ctx.IsExiled() {
   162  			return
   163  		}
   164  
   165  		s.putRequestContext(ctx)
   166  	}()
   167  
   168  	ctx.HTMLRender = s.HTMLRender
   169  	ctx.SetConn(conn)
   170  	ctx.Request.SetIsTLS(s.TLS != nil)
   171  	ctx.SetEnableTrace(s.EnableTrace)
   172  
   173  	if !s.NoDefaultServerHeader {
   174  		serverName = s.ServerName
   175  	}
   176  
   177  	connRequestNum := uint64(0)
   178  
   179  	for {
   180  		connRequestNum++
   181  
   182  		if zr == nil {
   183  			zr = ctx.GetReader()
   184  		}
   185  
   186  		// If this is a keep-alive connection we want to try and read the first bytes
   187  		// within the idle time.
   188  		if connRequestNum > 1 {
   189  			ctx.GetConn().SetReadTimeout(s.IdleTimeout) //nolint:errcheck
   190  
   191  			_, err = zr.Peek(4)
   192  			// This is not the first request, and we haven't read a single byte
   193  			// of a new request yet. This means it's just a keep-alive connection
   194  			// closing down either because the remote closed it or because
   195  			// or a read timeout on our side. Either way just close the connection
   196  			// and don't return any error response.
   197  			if err != nil {
   198  				err = errIdleTimeout
   199  				return
   200  			}
   201  
   202  			// Reset the real read timeout for the coming request
   203  			ctx.GetConn().SetReadTimeout(s.ReadTimeout) //nolint:errcheck
   204  		}
   205  
   206  		if s.EnableTrace {
   207  			cc = traceCtl.DoStart(c, ctx)
   208  			internalStats.Record(ctx.GetTraceInfo(), stats.ReadHeaderStart, err)
   209  			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
   210  				internalStats.Record(ti, stats.ReadHeaderFinish, err)
   211  			})
   212  		}
   213  
   214  		ctx.Response.Header.SetNoDefaultDate(s.NoDefaultDate)
   215  		ctx.Response.Header.SetNoDefaultContentType(s.NoDefaultContentType)
   216  
   217  		if s.DisableHeaderNamesNormalizing {
   218  			ctx.Request.Header.DisableNormalizing()
   219  			ctx.Response.Header.DisableNormalizing()
   220  		}
   221  
   222  		// Read Headers
   223  		if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil {
   224  			if s.EnableTrace {
   225  				// read header finished
   226  				if last := eventsToTrigger.pop(); last != nil {
   227  					last(ctx.GetTraceInfo(), err)
   228  				}
   229  				internalStats.Record(ctx.GetTraceInfo(), stats.ReadBodyStart, err)
   230  				eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
   231  					internalStats.Record(ti, stats.ReadBodyFinish, err)
   232  				})
   233  			}
   234  			// Read body
   235  			if s.StreamRequestBody {
   236  				err = req.ReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
   237  			} else {
   238  				err = req.ReadLimitBody(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
   239  			}
   240  		}
   241  
   242  		if s.EnableTrace {
   243  			if ctx.Request.Header.ContentLength() >= 0 {
   244  				ctx.GetTraceInfo().Stats().SetRecvSize(len(ctx.Request.Header.RawHeaders()) + ctx.Request.Header.ContentLength())
   245  			} else {
   246  				ctx.GetTraceInfo().Stats().SetRecvSize(0)
   247  			}
   248  			// read body finished
   249  			if last := eventsToTrigger.pop(); last != nil {
   250  				last(ctx.GetTraceInfo(), err)
   251  			}
   252  		}
   253  
   254  		if err != nil {
   255  			if errors.Is(err, errs.ErrNothingRead) {
   256  				return nil
   257  			}
   258  
   259  			if err == io.EOF {
   260  				return errUnexpectedEOF
   261  			}
   262  			writeErrorResponse(zw, ctx, serverName, err)
   263  			return
   264  		}
   265  
   266  		// 'Expect: 100-continue' request handling.
   267  		// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
   268  		if ctx.Request.MayContinue() {
   269  			// Allow the ability to deny reading the incoming request body
   270  			if s.ContinueHandler != nil {
   271  				if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
   272  					ctx.SetStatusCode(consts.StatusExpectationFailed)
   273  				}
   274  			}
   275  
   276  			if continueReadingRequest {
   277  				zw = ctx.GetWriter()
   278  				// Send 'HTTP/1.1 100 Continue' response.
   279  				_, err = zw.WriteBinary(bytestr.StrResponseContinue)
   280  				if err != nil {
   281  					return
   282  				}
   283  				err = zw.Flush()
   284  				if err != nil {
   285  					return
   286  				}
   287  
   288  				// Read body.
   289  				if zr == nil {
   290  					zr = ctx.GetReader()
   291  				}
   292  				if s.StreamRequestBody {
   293  					err = req.ContinueReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
   294  				} else {
   295  					err = req.ContinueReadBody(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
   296  				}
   297  				if err != nil {
   298  					writeErrorResponse(zw, ctx, serverName, err)
   299  					return
   300  				}
   301  			}
   302  		}
   303  
   304  		connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
   305  		isHTTP11 = ctx.Request.Header.IsHTTP11()
   306  
   307  		if serverName != nil {
   308  			ctx.Response.Header.SetServerBytes(serverName)
   309  		}
   310  		if s.EnableTrace {
   311  			internalStats.Record(ctx.GetTraceInfo(), stats.ServerHandleStart, err)
   312  			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
   313  				internalStats.Record(ti, stats.ServerHandleFinish, err)
   314  			})
   315  		}
   316  		// Handle the request
   317  		//
   318  		// NOTE: All middlewares and business handler will be executed in this. And at this point, the request has been parsed
   319  		// and the route has been matched.
   320  		s.Core.ServeHTTP(cc, ctx)
   321  		if s.EnableTrace {
   322  			// application layer handle finished
   323  			if last := eventsToTrigger.pop(); last != nil {
   324  				last(ctx.GetTraceInfo(), err)
   325  			}
   326  		}
   327  
   328  		// exit check
   329  		if !s.Core.IsRunning() {
   330  			connectionClose = true
   331  		}
   332  
   333  		if !ctx.IsGet() && ctx.IsHead() {
   334  			ctx.Response.SkipBody = true
   335  		}
   336  
   337  		hijackHandler = ctx.GetHijackHandler()
   338  		ctx.SetHijackHandler(nil)
   339  
   340  		connectionClose = connectionClose || ctx.Response.ConnectionClose()
   341  		if connectionClose {
   342  			ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrClose)
   343  		} else if !isHTTP11 {
   344  			ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrKeepAlive)
   345  		}
   346  
   347  		if zw == nil {
   348  			zw = ctx.GetWriter()
   349  		}
   350  		if s.EnableTrace {
   351  			internalStats.Record(ctx.GetTraceInfo(), stats.WriteStart, err)
   352  			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
   353  				internalStats.Record(ti, stats.WriteFinish, err)
   354  			})
   355  		}
   356  		if err = writeResponse(ctx, zw); err != nil {
   357  			return
   358  		}
   359  
   360  		if s.EnableTrace {
   361  			if ctx.Response.Header.ContentLength() > 0 {
   362  				ctx.GetTraceInfo().Stats().SetSendSize(ctx.Response.Header.GetHeaderLength() + ctx.Response.Header.ContentLength())
   363  			} else {
   364  				ctx.GetTraceInfo().Stats().SetSendSize(0)
   365  			}
   366  		}
   367  
   368  		// Release the zeroCopyReader before flush to prevent data race
   369  		if zr != nil {
   370  			zr.Release() //nolint:errcheck
   371  			zr = nil
   372  		}
   373  		// Flush the response.
   374  		if err = zw.Flush(); err != nil {
   375  			return
   376  		}
   377  		if s.EnableTrace {
   378  			// write finished
   379  			if last := eventsToTrigger.pop(); last != nil {
   380  				last(ctx.GetTraceInfo(), err)
   381  			}
   382  		}
   383  
   384  		// Release request body stream
   385  		if ctx.Request.IsBodyStream() {
   386  			err = ext.ReleaseBodyStream(ctx.RequestBodyStream())
   387  			if err != nil {
   388  				return
   389  			}
   390  		}
   391  
   392  		if hijackHandler != nil {
   393  			// Hijacked conn process the timeout by itself
   394  			err = ctx.GetConn().SetReadTimeout(0)
   395  			if err != nil {
   396  				return
   397  			}
   398  
   399  			// Hijack and block the connection until the hijackHandler return
   400  			s.HijackConnHandle(ctx.GetConn(), hijackHandler)
   401  			err = errHijacked
   402  			return
   403  		}
   404  
   405  		if connectionClose {
   406  			return errShortConnection
   407  		}
   408  		// Back to network layer to trigger.
   409  		// For now, only netpoll network mode has this feature.
   410  		if s.IdleTimeout == 0 {
   411  			return
   412  		}
   413  		// general case
   414  		if s.EnableTrace {
   415  			if shouldRecordInTraceError(err) {
   416  				traceCtl.DoFinish(cc, ctx, err)
   417  			} else {
   418  				traceCtl.DoFinish(cc, ctx, nil)
   419  			}
   420  		}
   421  
   422  		ctx.ResetWithoutConn()
   423  	}
   424  }
   425  
   426  func NewServer() *Server {
   427  	return &Server{
   428  		eventStackPool: &sync.Pool{
   429  			New: func() interface{} {
   430  				return &eventStack{}
   431  			},
   432  		},
   433  	}
   434  }
   435  
   436  func writeErrorResponse(zw network.Writer, ctx *app.RequestContext, serverName []byte, err error) network.Writer {
   437  	errorHandler := defaultErrorHandler
   438  
   439  	errorHandler(ctx, err)
   440  
   441  	if serverName != nil {
   442  		ctx.Response.Header.SetServerBytes(serverName)
   443  	}
   444  	ctx.SetConnectionClose()
   445  	if zw == nil {
   446  		zw = ctx.GetWriter()
   447  	}
   448  	writeResponse(ctx, zw) //nolint:errcheck
   449  	zw.Flush()             //nolint:errcheck
   450  	return zw
   451  }
   452  
   453  func writeResponse(ctx *app.RequestContext, w network.Writer) error {
   454  	// Skip default response writing logic if it has been hijacked
   455  	if ctx.Response.GetHijackWriter() != nil {
   456  		return ctx.Response.GetHijackWriter().Finalize()
   457  	}
   458  
   459  	err := resp.Write(&ctx.Response, w)
   460  	if err != nil {
   461  		return err
   462  	}
   463  
   464  	return err
   465  }
   466  
   467  func defaultErrorHandler(ctx *app.RequestContext, err error) {
   468  	if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
   469  		ctx.AbortWithMsg("Request timeout", consts.StatusRequestTimeout)
   470  	} else if errors.Is(err, errs.ErrBodyTooLarge) {
   471  		ctx.AbortWithMsg("Request Entity Too Large", consts.StatusRequestEntityTooLarge)
   472  	} else {
   473  		ctx.AbortWithMsg("Error when parsing request", consts.StatusBadRequest)
   474  	}
   475  }
   476  
   477  type eventStack []func(ti traceinfo.TraceInfo, err error)
   478  
   479  func (e *eventStack) isEmpty() bool {
   480  	return len(*e) == 0
   481  }
   482  
   483  func (e *eventStack) push(f func(ti traceinfo.TraceInfo, err error)) {
   484  	*e = append(*e, f)
   485  }
   486  
   487  func (e *eventStack) pop() func(ti traceinfo.TraceInfo, err error) {
   488  	if e.isEmpty() {
   489  		return nil
   490  	}
   491  	last := (*e)[len(*e)-1]
   492  	*e = (*e)[:len(*e)-1]
   493  	return last
   494  }
   495  
   496  func shouldRecordInTraceError(err error) bool {
   497  	if err == nil {
   498  		return false
   499  	}
   500  
   501  	if errors.Is(err, errs.ErrIdleTimeout) {
   502  		return false
   503  	}
   504  
   505  	if errors.Is(err, errs.ErrHijacked) {
   506  		return false
   507  	}
   508  
   509  	if errors.Is(err, errs.ErrShortConnection) {
   510  		return false
   511  	}
   512  
   513  	return true
   514  }