github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/req/request.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * The MIT License (MIT)
    17   *
    18   * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
    19   *
    20   * Permission is hereby granted, free of charge, to any person obtaining a copy
    21   * of this software and associated documentation files (the "Software"), to deal
    22   * in the Software without restriction, including without limitation the rights
    23   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    24   * copies of the Software, and to permit persons to whom the Software is
    25   * furnished to do so, subject to the following conditions:
    26   *
    27   * The above copyright notice and this permission notice shall be included in
    28   * all copies or substantial portions of the Software.
    29   *
    30   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    31   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    32   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    33   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    34   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    35   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    36   * THE SOFTWARE.
    37   *
    38   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    39   * Modifications are Copyright 2022 CloudWeGo Authors.
    40   */
    41  
    42  package req
    43  
    44  import (
    45  	"bytes"
    46  	"encoding/base64"
    47  	"errors"
    48  	"fmt"
    49  	"io"
    50  	"mime/multipart"
    51  
    52  	"github.com/cloudwego/hertz/internal/bytestr"
    53  	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
    54  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    55  	"github.com/cloudwego/hertz/pkg/network"
    56  	"github.com/cloudwego/hertz/pkg/protocol"
    57  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    58  	"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
    59  )
    60  
// Package-level errors returned by the functions in this file:
//   - errRequestHostRequired: returned by write when the request URI has an empty host.
//   - errGetOnly: returned by ReadBodyStream and ReadLimitBody when getOnly is set
//     and the request method is not GET.
//   - errBodyTooLarge: returned by ContinueReadBody when a known Content-Length
//     exceeds a positive maxBodySize.
var (
	errRequestHostRequired = errs.NewPublic("missing required Host header in request")
	errGetOnly             = errs.NewPublic("non-GET request received")
	errBodyTooLarge        = errs.New(errs.ErrBodyTooLarge, errs.ErrorTypePublic, "http1/req")
)
    66  
    67  type h1Request struct {
    68  	*protocol.Request
    69  }
    70  
    71  // String returns request representation.
    72  //
    73  // Returns error message instead of request representation on error.
    74  //
    75  // Use Write instead of String for performance-critical code.
    76  func (h1Req *h1Request) String() string {
    77  	w := bytebufferpool.Get()
    78  	zw := network.NewWriter(w)
    79  	if err := Write(h1Req.Request, zw); err != nil {
    80  		return err.Error()
    81  	}
    82  	if err := zw.Flush(); err != nil {
    83  		return err.Error()
    84  	}
    85  	s := string(w.B)
    86  	bytebufferpool.Put(w)
    87  	return s
    88  }
    89  
    90  func GetHTTP1Request(req *protocol.Request) fmt.Stringer {
    91  	return &h1Request{req}
    92  }
    93  
    94  // ReadHeaderAndLimitBody reads request from the given r, limiting the body size.
    95  //
    96  // If maxBodySize > 0 and the body size exceeds maxBodySize,
    97  // then errBodyTooLarge is returned.
    98  //
    99  // RemoveMultipartFormFiles or Reset must be called after
   100  // reading multipart/form-data request in order to delete temporarily
   101  // uploaded files.
   102  //
   103  // If MayContinue returns true, the caller must:
   104  //
   105  //   - Either send StatusExpectationFailed response if request headers don't
   106  //     satisfy the caller.
   107  //   - Or send StatusContinue response before reading request body
   108  //     with ContinueReadBody.
   109  //   - Or close the connection.
   110  //
   111  // io.EOF is returned if r is closed before reading the first header byte.
   112  func ReadHeaderAndLimitBody(req *protocol.Request, r network.Reader, maxBodySize int, preParse ...bool) error {
   113  	var parse bool
   114  	if len(preParse) == 0 {
   115  		parse = true
   116  	} else {
   117  		parse = preParse[0]
   118  	}
   119  	req.ResetSkipHeader()
   120  
   121  	if err := ReadHeader(&req.Header, r); err != nil {
   122  		return err
   123  	}
   124  
   125  	return ReadLimitBody(req, r, maxBodySize, false, parse)
   126  }
   127  
   128  // Read reads request (including body) from the given r.
   129  //
   130  // RemoveMultipartFormFiles or Reset must be called after
   131  // reading multipart/form-data request in order to delete temporarily
   132  // uploaded files.
   133  //
   134  // If MayContinue returns true, the caller must:
   135  //
   136  //   - Either send StatusExpectationFailed response if request headers don't
   137  //     satisfy the caller.
   138  //   - Or send StatusContinue response before reading request body
   139  //     with ContinueReadBody.
   140  //   - Or close the connection.
   141  //
   142  // io.EOF is returned if r is closed before reading the first header byte.
   143  func Read(req *protocol.Request, r network.Reader, preParse ...bool) error {
   144  	return ReadHeaderAndLimitBody(req, r, 0, preParse...)
   145  }
   146  
   147  // Write writes request to w.
   148  //
   149  // Write doesn't flush request to w for performance reasons.
   150  //
   151  // See also WriteTo.
   152  func Write(req *protocol.Request, w network.Writer) error {
   153  	return write(req, w, false)
   154  }
   155  
   156  // ProxyWrite is like Write but writes the request in the form
   157  // expected by an HTTP proxy. In particular, ProxyWrite writes the
   158  // initial Request-URI line of the request with an absolute URI, per
   159  // section 5.3 of RFC 7230, including the scheme and host.
   160  func ProxyWrite(req *protocol.Request, w network.Writer) error {
   161  	return write(req, w, true)
   162  }
   163  
   164  // write writes request to w.
   165  // It supports proxy situation.
   166  func write(req *protocol.Request, w network.Writer, usingProxy bool) error {
   167  	if len(req.Header.Host()) == 0 || req.IsURIParsed() {
   168  		uri := req.URI()
   169  		host := uri.Host()
   170  		if len(host) == 0 {
   171  			return errRequestHostRequired
   172  		}
   173  
   174  		if len(req.Header.Host()) == 0 {
   175  			req.Header.SetHostBytes(host)
   176  		}
   177  
   178  		ruri := uri.RequestURI()
   179  		if bytes.Equal(req.Method(), bytestr.StrConnect) {
   180  			ruri = uri.Host()
   181  		} else if usingProxy {
   182  			ruri = uri.FullURI()
   183  		}
   184  
   185  		req.Header.SetRequestURIBytes(ruri)
   186  
   187  		if len(uri.Username()) > 0 {
   188  			// RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key
   189  			// So we are free to use RequestHeader.bufKV.value as a scratch pad for
   190  			// the base64 encoding.
   191  			nl := len(uri.Username()) + len(uri.Password()) + 1
   192  			nb := nl + len(bytestr.StrBasicSpace)
   193  			tl := nb + base64.StdEncoding.EncodedLen(nl)
   194  
   195  			req.Header.InitBufValue(tl)
   196  			buf := req.Header.GetBufValue()[:0]
   197  			buf = append(buf, uri.Username()...)
   198  			buf = append(buf, bytestr.StrColon...)
   199  			buf = append(buf, uri.Password()...)
   200  			buf = append(buf, bytestr.StrBasicSpace...)
   201  			base64.StdEncoding.Encode(buf[nb:tl], buf[:nl])
   202  			req.Header.SetBytesKV(bytestr.StrAuthorization, buf[nl:tl])
   203  		}
   204  	}
   205  
   206  	if req.IsBodyStream() {
   207  		return writeBodyStream(req, w)
   208  	}
   209  
   210  	body := req.BodyBytes()
   211  	err := handleMultipart(req)
   212  	if err != nil {
   213  		return fmt.Errorf("error when handle multipart: %s", err)
   214  	}
   215  	if req.OnlyMultipartForm() {
   216  		m, _ := req.MultipartForm() // req.multipartForm != nil
   217  		body, err = protocol.MarshalMultipartForm(m, req.MultipartFormBoundary())
   218  		if err != nil {
   219  			return fmt.Errorf("error when marshaling multipart form: %s", err)
   220  		}
   221  		req.Header.SetMultipartFormBoundary(req.MultipartFormBoundary())
   222  	}
   223  
   224  	hasBody := false
   225  	if len(body) == 0 {
   226  		body = req.PostArgString()
   227  	}
   228  	if len(body) != 0 || !req.Header.IgnoreBody() {
   229  		hasBody = true
   230  		req.Header.SetContentLength(len(body))
   231  	}
   232  
   233  	header := req.Header.Header()
   234  	if _, err := w.WriteBinary(header); err != nil {
   235  		return err
   236  	}
   237  
   238  	// Write body
   239  	if hasBody {
   240  		w.WriteBinary(body) //nolint:errcheck
   241  	} else if len(body) > 0 {
   242  		return fmt.Errorf("non-zero body for non-POST request. body=%q", body)
   243  	}
   244  	return nil
   245  }
   246  
   247  // ContinueReadBodyStream reads request body in stream if request header contains
   248  // 'Expect: 100-continue'.
   249  //
   250  // The caller must send StatusContinue response before calling this method.
   251  //
   252  // If maxBodySize > 0 and the body size exceeds maxBodySize,
   253  // then errBodyTooLarge is returned.
   254  func ContinueReadBodyStream(req *protocol.Request, zr network.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
   255  	var err error
   256  	contentLength := req.Header.ContentLength()
   257  	if contentLength > 0 {
   258  		if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
   259  			// Pre-read multipart form data of known length.
   260  			// This way we limit memory usage for large file uploads, since their contents
   261  			// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
   262  			req.SetMultipartFormBoundary(string(req.Header.MultipartFormBoundary()))
   263  			if len(req.MultipartFormBoundary()) > 0 && len(req.Header.PeekContentEncoding()) == 0 {
   264  				err := protocol.ParseMultipartForm(zr.(io.Reader), req, contentLength, consts.DefaultMaxInMemoryFileSize)
   265  				if err != nil {
   266  					req.Reset()
   267  				}
   268  				return err
   269  			}
   270  		}
   271  	}
   272  
   273  	if contentLength == -2 {
   274  		// identity body has no sense for http requests, since
   275  		// the end of body is determined by connection close.
   276  		// So just ignore request body for requests without
   277  		// 'Content-Length' and 'Transfer-Encoding' headers.
   278  
   279  		// refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
   280  		if !req.Header.IgnoreBody() {
   281  			req.Header.SetContentLength(0)
   282  		}
   283  		return nil
   284  	}
   285  
   286  	bodyBuf := req.BodyBuffer()
   287  	bodyBuf.Reset()
   288  	bodyBuf.B, err = ext.ReadBodyWithStreaming(zr, contentLength, maxBodySize, bodyBuf.B)
   289  	if err != nil {
   290  		if errors.Is(err, errs.ErrBodyTooLarge) {
   291  			req.Header.SetContentLength(contentLength)
   292  			req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength))
   293  
   294  			return nil
   295  		}
   296  		if errors.Is(err, errs.ErrChunkedStream) {
   297  			req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength))
   298  			return nil
   299  		}
   300  		req.Reset()
   301  		return err
   302  	}
   303  
   304  	req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength))
   305  	return nil
   306  }
   307  
   308  func ContinueReadBody(req *protocol.Request, r network.Reader, maxBodySize int, preParseMultipartForm ...bool) error {
   309  	var err error
   310  	contentLength := req.Header.ContentLength()
   311  	if contentLength > 0 {
   312  		if maxBodySize > 0 && contentLength > maxBodySize {
   313  			return errBodyTooLarge
   314  		}
   315  
   316  		if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] {
   317  			// Pre-read multipart form data of known length.
   318  			// This way we limit memory usage for large file uploads, since their contents
   319  			// is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
   320  			req.SetMultipartFormBoundary(string(req.Header.MultipartFormBoundary()))
   321  			if len(req.MultipartFormBoundary()) > 0 && len(req.Header.PeekContentEncoding()) == 0 {
   322  				err := protocol.ParseMultipartForm(r.(io.Reader), req, contentLength, consts.DefaultMaxInMemoryFileSize)
   323  				if err != nil {
   324  					req.Reset()
   325  				}
   326  				return err
   327  			}
   328  		}
   329  
   330  		// This optimization is just suitable for ping-pong case and the ext.ReadBody is
   331  		// a common function, so we just handle this situation before ext.ReadBody
   332  		buf, err := r.Peek(contentLength)
   333  		if err != nil {
   334  			return err
   335  		}
   336  		r.Skip(contentLength) // nolint: errcheck
   337  		req.SetBodyRaw(buf)
   338  		return nil
   339  	}
   340  
   341  	if contentLength == -2 {
   342  		// identity body has no sense for http requests, since
   343  		// the end of body is determined by connection close.
   344  		// So just ignore request body for requests without
   345  		// 'Content-Length' and 'Transfer-Encoding' headers.
   346  
   347  		// refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
   348  		if !req.Header.IgnoreBody() {
   349  			req.Header.SetContentLength(0)
   350  		}
   351  		return nil
   352  	}
   353  
   354  	bodyBuf := req.BodyBuffer()
   355  	bodyBuf.Reset()
   356  	bodyBuf.B, err = ext.ReadBody(r, contentLength, maxBodySize, bodyBuf.B)
   357  	if err != nil {
   358  		req.Reset()
   359  		return err
   360  	}
   361  
   362  	if req.Header.ContentLength() == -1 {
   363  		err = ext.ReadTrailer(req.Header.Trailer(), r)
   364  		if err != nil && err != io.EOF {
   365  			return err
   366  		}
   367  	}
   368  
   369  	req.Header.SetContentLength(len(bodyBuf.B))
   370  	return nil
   371  }
   372  
   373  func ReadBodyStream(req *protocol.Request, zr network.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error {
   374  	if getOnly && !req.Header.IsGet() {
   375  		return errGetOnly
   376  	}
   377  
   378  	if req.MayContinue() {
   379  		// 'Expect: 100-continue' header found. Let the caller deciding
   380  		// whether to read request body or
   381  		// to return StatusExpectationFailed.
   382  		return nil
   383  	}
   384  
   385  	return ContinueReadBodyStream(req, zr, maxBodySize, preParseMultipartForm)
   386  }
   387  
   388  func ReadLimitBody(req *protocol.Request, r network.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error {
   389  	// Do not reset the request here - the caller must reset it before
   390  	// calling this method.
   391  	if getOnly && !req.Header.IsGet() {
   392  		return errGetOnly
   393  	}
   394  
   395  	if req.MayContinue() {
   396  		// 'Expect: 100-continue' header found. Let the caller deciding
   397  		// whether to read request body or
   398  		// to return StatusExpectationFailed.
   399  		return nil
   400  	}
   401  
   402  	return ContinueReadBody(req, r, maxBodySize, preParseMultipartForm)
   403  }
   404  
   405  func writeBodyStream(req *protocol.Request, w network.Writer) error {
   406  	var err error
   407  
   408  	contentLength := req.Header.ContentLength()
   409  	if contentLength < 0 {
   410  		lrSize := ext.LimitedReaderSize(req.BodyStream())
   411  		if lrSize >= 0 {
   412  			contentLength = int(lrSize)
   413  			if int64(contentLength) != lrSize {
   414  				contentLength = -1
   415  			}
   416  			if contentLength >= 0 {
   417  				req.Header.SetContentLength(contentLength)
   418  			}
   419  		}
   420  	}
   421  	if contentLength >= 0 {
   422  		if err = WriteHeader(&req.Header, w); err == nil {
   423  			err = ext.WriteBodyFixedSize(w, req.BodyStream(), int64(contentLength))
   424  		}
   425  	} else {
   426  		req.Header.SetContentLength(-1)
   427  		err = WriteHeader(&req.Header, w)
   428  		if err == nil {
   429  			err = ext.WriteBodyChunked(w, req.BodyStream())
   430  		}
   431  		if err == nil {
   432  			err = ext.WriteTrailer(req.Header.Trailer(), w)
   433  		}
   434  	}
   435  	err1 := req.CloseBodyStream()
   436  	if err == nil {
   437  		err = err1
   438  	}
   439  	return err
   440  }
   441  
   442  func handleMultipart(req *protocol.Request) error {
   443  	if len(req.MultipartFiles()) == 0 && len(req.MultipartFields()) == 0 {
   444  		return nil
   445  	}
   446  	var err error
   447  	bodyBuffer := &bytes.Buffer{}
   448  	w := multipart.NewWriter(bodyBuffer)
   449  	if len(req.MultipartFiles()) > 0 {
   450  		for _, f := range req.MultipartFiles() {
   451  			if f.Reader != nil {
   452  				err = protocol.WriteMultipartFormFile(w, f.ParamName, f.Name, f.Reader)
   453  			} else {
   454  				err = protocol.AddFile(w, f.ParamName, f.Name)
   455  			}
   456  			if err != nil {
   457  				return err
   458  			}
   459  		}
   460  	}
   461  
   462  	if len(req.MultipartFields()) > 0 {
   463  		for _, mf := range req.MultipartFields() {
   464  			if err = protocol.AddMultipartFormField(w, mf); err != nil {
   465  				return err
   466  			}
   467  		}
   468  	}
   469  
   470  	req.Header.Set(consts.HeaderContentType, w.FormDataContentType())
   471  	if err = w.Close(); err != nil {
   472  		return err
   473  	}
   474  
   475  	r := multipart.NewReader(bodyBuffer, w.Boundary())
   476  	f, err := r.ReadForm(int64(bodyBuffer.Len()))
   477  	if err != nil {
   478  		return err
   479  	}
   480  	protocol.SetMultipartFormWithBoundary(req, f, w.Boundary())
   481  
   482  	return nil
   483  }