github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/ext/common.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * The MIT License (MIT)
    17   *
    18   * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
    19   *
    20   * Permission is hereby granted, free of charge, to any person obtaining a copy
    21   * of this software and associated documentation files (the "Software"), to deal
    22   * in the Software without restriction, including without limitation the rights
    23   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    24   * copies of the Software, and to permit persons to whom the Software is
    25   * furnished to do so, subject to the following conditions:
    26   *
    27   * The above copyright notice and this permission notice shall be included in
    28   * all copies or substantial portions of the Software.
    29   *
    30   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    31   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    32   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    33   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    34   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    35   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    36   * THE SOFTWARE.
    37   *
    38   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    39   * Modifications are Copyright 2022 CloudWeGo Authors.
    40   */
    41  
    42  package ext
    43  
    44  import (
    45  	"bytes"
    46  	"errors"
    47  	"fmt"
    48  	"io"
    49  	"strings"
    50  
    51  	"github.com/cloudwego/hertz/internal/bytesconv"
    52  	"github.com/cloudwego/hertz/internal/bytestr"
    53  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    54  	"github.com/cloudwego/hertz/pkg/common/hlog"
    55  	"github.com/cloudwego/hertz/pkg/common/utils"
    56  	"github.com/cloudwego/hertz/pkg/network"
    57  	"github.com/cloudwego/hertz/pkg/protocol"
    58  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    59  )
    60  
    61  const maxContentLengthInStream = 8 * 1024
    62  
    63  var errBrokenChunk = errs.NewPublic("cannot find crlf at the end of chunk").SetMeta("when read body chunk")
    64  
    65  func MustPeekBuffered(r network.Reader) []byte {
    66  	l := r.Len()
    67  	buf, err := r.Peek(l)
    68  	if len(buf) == 0 || err != nil {
    69  		panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err))
    70  	}
    71  
    72  	return buf
    73  }
    74  
    75  func MustDiscard(r network.Reader, n int) {
    76  	if err := r.Skip(n); err != nil {
    77  		panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err))
    78  	}
    79  }
    80  
    81  func ReadRawHeaders(dst, buf []byte) ([]byte, int, error) {
    82  	n := bytes.IndexByte(buf, '\n')
    83  	if n < 0 {
    84  		return dst[:0], 0, errNeedMore
    85  	}
    86  	if (n == 1 && buf[0] == '\r') || n == 0 {
    87  		// empty headers
    88  		return dst, n + 1, nil
    89  	}
    90  
    91  	n++
    92  	b := buf
    93  	m := n
    94  	for {
    95  		b = b[m:]
    96  		m = bytes.IndexByte(b, '\n')
    97  		if m < 0 {
    98  			return dst, 0, errNeedMore
    99  		}
   100  		m++
   101  		n += m
   102  		if (m == 2 && b[0] == '\r') || m == 1 {
   103  			dst = append(dst, buf[:n]...)
   104  			return dst, n, nil
   105  		}
   106  	}
   107  }
   108  
   109  func WriteBodyChunked(w network.Writer, r io.Reader) error {
   110  	vbuf := utils.CopyBufPool.Get()
   111  	buf := vbuf.([]byte)
   112  
   113  	var err error
   114  	var n int
   115  	for {
   116  		n, err = r.Read(buf)
   117  		if n == 0 {
   118  			if err == nil {
   119  				panic("BUG: io.Reader returned 0, nil")
   120  			}
   121  
   122  			if !errors.Is(err, io.EOF) {
   123  				hlog.SystemLogger().Warnf("writing chunked response body encountered an error from the reader, "+
   124  					"this may cause the short of the content in response body, error: %s", err.Error())
   125  			}
   126  
   127  			if err = WriteChunk(w, buf[:0], true); err != nil {
   128  				break
   129  			}
   130  
   131  			err = nil
   132  			break
   133  		}
   134  		if err = WriteChunk(w, buf[:n], true); err != nil {
   135  			break
   136  		}
   137  	}
   138  
   139  	utils.CopyBufPool.Put(vbuf)
   140  	return err
   141  }
   142  
   143  func WriteBodyFixedSize(w network.Writer, r io.Reader, size int64) error {
   144  	if size == 0 {
   145  		return nil
   146  	}
   147  	if size > consts.MaxSmallFileSize {
   148  		if err := w.Flush(); err != nil {
   149  			return err
   150  		}
   151  	}
   152  
   153  	if size > 0 {
   154  		r = io.LimitReader(r, size)
   155  	}
   156  
   157  	n, err := utils.CopyZeroAlloc(w, r)
   158  
   159  	if n != size && err == nil {
   160  		err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size)
   161  	}
   162  	return err
   163  }
   164  
   165  func appendBodyFixedSize(r network.Reader, dst []byte, n int) ([]byte, error) {
   166  	if n == 0 {
   167  		return dst, nil
   168  	}
   169  
   170  	offset := len(dst)
   171  	dstLen := offset + n
   172  	if cap(dst) < dstLen {
   173  		b := make([]byte, round2(dstLen))
   174  		copy(b, dst)
   175  		dst = b
   176  	}
   177  	dst = dst[:dstLen]
   178  
   179  	// Peek can get all data, otherwise it will through error
   180  	buf, err := r.Peek(n)
   181  	if err != nil {
   182  		if err == io.EOF {
   183  			err = io.ErrUnexpectedEOF
   184  		}
   185  		return dst[:offset], err
   186  	}
   187  	copy(dst[offset:], buf)
   188  	r.Skip(len(buf)) // nolint: errcheck
   189  	return dst, nil
   190  }
   191  
   192  func readBodyIdentity(r network.Reader, maxBodySize int, dst []byte) ([]byte, error) {
   193  	dst = dst[:cap(dst)]
   194  	if len(dst) == 0 {
   195  		dst = make([]byte, 1024)
   196  	}
   197  	offset := 0
   198  	for {
   199  		nn := r.Len()
   200  
   201  		if nn == 0 {
   202  			_, err := r.Peek(1)
   203  			if err != nil {
   204  				return dst[:offset], nil
   205  			}
   206  			nn = r.Len()
   207  		}
   208  		if nn >= (len(dst) - offset) {
   209  			nn = len(dst) - offset
   210  		}
   211  
   212  		buf, err := r.Peek(nn)
   213  		if err != nil {
   214  			return dst[:offset], err
   215  		}
   216  		copy(dst[offset:], buf)
   217  		r.Skip(nn) // nolint: errcheck
   218  
   219  		offset += nn
   220  		if maxBodySize > 0 && offset > maxBodySize {
   221  			return dst[:offset], errBodyTooLarge
   222  		}
   223  		if len(dst) == offset {
   224  			n := round2(2 * offset)
   225  			if maxBodySize > 0 && n > maxBodySize {
   226  				n = maxBodySize + 1
   227  			}
   228  			b := make([]byte, n)
   229  			copy(b, dst)
   230  			dst = b
   231  		}
   232  	}
   233  }
   234  
   235  func ReadBody(r network.Reader, contentLength, maxBodySize int, dst []byte) ([]byte, error) {
   236  	dst = dst[:0]
   237  	if contentLength >= 0 {
   238  		if maxBodySize > 0 && contentLength > maxBodySize {
   239  			return dst, errBodyTooLarge
   240  		}
   241  		return appendBodyFixedSize(r, dst, contentLength)
   242  	}
   243  	if contentLength == -1 {
   244  		return readBodyChunked(r, maxBodySize, dst)
   245  	}
   246  	return readBodyIdentity(r, maxBodySize, dst)
   247  }
   248  
   249  func LimitedReaderSize(r io.Reader) int64 {
   250  	lr, ok := r.(*io.LimitedReader)
   251  	if !ok {
   252  		return -1
   253  	}
   254  	return lr.N
   255  }
   256  
   257  func readBodyChunked(r network.Reader, maxBodySize int, dst []byte) ([]byte, error) {
   258  	if len(dst) > 0 {
   259  		panic("BUG: expected zero-length buffer")
   260  	}
   261  
   262  	strCRLFLen := len(bytestr.StrCRLF)
   263  	for {
   264  		chunkSize, err := utils.ParseChunkSize(r)
   265  		if err != nil {
   266  			return dst, err
   267  		}
   268  		// If it is the end of chunk, Read CRLF after reading trailer
   269  		if chunkSize == 0 {
   270  			return dst, nil
   271  		}
   272  		if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize {
   273  			return dst, errBodyTooLarge
   274  		}
   275  		dst, err = appendBodyFixedSize(r, dst, chunkSize+strCRLFLen)
   276  		if err != nil {
   277  			return dst, err
   278  		}
   279  		if !bytes.Equal(dst[len(dst)-strCRLFLen:], bytestr.StrCRLF) {
   280  			return dst, errBrokenChunk
   281  		}
   282  		dst = dst[:len(dst)-strCRLFLen]
   283  	}
   284  }
   285  
   286  func round2(n int) int {
   287  	if n <= 0 {
   288  		return 0
   289  	}
   290  	n--
   291  	x := uint(0)
   292  	for n > 0 {
   293  		n >>= 1
   294  		x++
   295  	}
   296  	return 1 << x
   297  }
   298  
   299  func WriteChunk(w network.Writer, b []byte, withFlush bool) (err error) {
   300  	n := len(b)
   301  	if err = bytesconv.WriteHexInt(w, n); err != nil {
   302  		return err
   303  	}
   304  
   305  	w.WriteBinary(bytestr.StrCRLF) //nolint:errcheck
   306  	if _, err = w.WriteBinary(b); err != nil {
   307  		return err
   308  	}
   309  
   310  	// If it is the end of chunk, write CRLF after writing trailer
   311  	if n > 0 {
   312  		w.WriteBinary(bytestr.StrCRLF) //nolint:errcheck
   313  	}
   314  
   315  	if !withFlush {
   316  		return nil
   317  	}
   318  	err = w.Flush()
   319  	return
   320  }
   321  
   322  func isOnlyCRLF(b []byte) bool {
   323  	for _, ch := range b {
   324  		if ch != '\r' && ch != '\n' {
   325  			return false
   326  		}
   327  	}
   328  	return true
   329  }
   330  
   331  func BufferSnippet(b []byte) string {
   332  	n := len(b)
   333  	start := 20
   334  	end := n - start
   335  	if start >= end {
   336  		start = n
   337  		end = n
   338  	}
   339  	bStart, bEnd := b[:start], b[end:]
   340  	if len(bEnd) == 0 {
   341  		return fmt.Sprintf("%q", b)
   342  	}
   343  	return fmt.Sprintf("%q...%q", bStart, bEnd)
   344  }
   345  
   346  func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl int) {
   347  	nv = ov
   348  	length := len(ov)
   349  	if length <= 0 {
   350  		return
   351  	}
   352  	write := 0
   353  	shrunk := 0
   354  	lineStart := false
   355  	for read := 0; read < length; read++ {
   356  		c := ov[read]
   357  		if c == '\r' || c == '\n' {
   358  			shrunk++
   359  			if c == '\n' {
   360  				lineStart = true
   361  			}
   362  			continue
   363  		} else if lineStart && c == '\t' {
   364  			c = ' '
   365  		} else {
   366  			lineStart = false
   367  		}
   368  		nv[write] = c
   369  		write++
   370  	}
   371  
   372  	nv = nv[:write]
   373  	copy(ob[write:], ob[write+shrunk:])
   374  
   375  	// Check if we need to skip \r\n or just \n
   376  	skip := 0
   377  	if ob[write] == '\r' {
   378  		if ob[write+1] == '\n' {
   379  			skip += 2
   380  		} else {
   381  			skip++
   382  		}
   383  	} else if ob[write] == '\n' {
   384  		skip++
   385  	}
   386  
   387  	nb = ob[write+skip : len(ob)-shrunk]
   388  	nhl = headerLength - shrunk
   389  	return
   390  }
   391  
   392  func stripSpace(b []byte) []byte {
   393  	for len(b) > 0 && b[0] == ' ' {
   394  		b = b[1:]
   395  	}
   396  	for len(b) > 0 && b[len(b)-1] == ' ' {
   397  		b = b[:len(b)-1]
   398  	}
   399  	return b
   400  }
   401  
   402  func SkipTrailer(r network.Reader) error {
   403  	n := 1
   404  	for {
   405  		err := trySkipTrailer(r, n)
   406  		if err == nil {
   407  			return nil
   408  		}
   409  		if !errors.Is(err, errs.ErrNeedMore) {
   410  			return err
   411  		}
   412  		// No more data available on the wire, try block peek(by netpoll)
   413  		if n == r.Len() {
   414  			n++
   415  
   416  			continue
   417  		}
   418  		n = r.Len()
   419  	}
   420  }
   421  
   422  func trySkipTrailer(r network.Reader, n int) error {
   423  	b, err := r.Peek(n)
   424  	if len(b) == 0 {
   425  		// Return ErrTimeout on any timeout.
   426  		if err != nil && strings.Contains(err.Error(), "timeout") {
   427  			return errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read response header")
   428  		}
   429  
   430  		if n == 1 || err == io.EOF {
   431  			return io.EOF
   432  		}
   433  
   434  		return errs.NewPublicf("error when reading request trailer: %w", err)
   435  	}
   436  	b = MustPeekBuffered(r)
   437  	headersLen, errParse := skipTrailer(b)
   438  	if errParse != nil {
   439  		if err == io.EOF {
   440  			return err
   441  		}
   442  		return HeaderError("response", err, errParse, b)
   443  	}
   444  	MustDiscard(r, headersLen)
   445  	return nil
   446  }
   447  
   448  func skipTrailer(buf []byte) (int, error) {
   449  	skip := 0
   450  	strCRLFLen := len(bytestr.StrCRLF)
   451  	for {
   452  		index := bytes.Index(buf, bytestr.StrCRLF)
   453  		if index == -1 {
   454  			return 0, errs.ErrNeedMore
   455  		}
   456  
   457  		buf = buf[index+strCRLFLen:]
   458  		skip += index + strCRLFLen
   459  
   460  		if index == 0 {
   461  			return skip, nil
   462  		}
   463  	}
   464  }
   465  
   466  func ReadTrailer(t *protocol.Trailer, r network.Reader) error {
   467  	n := 1
   468  	for {
   469  		err := tryReadTrailer(t, r, n)
   470  		if err == nil {
   471  			return nil
   472  		}
   473  		if !errors.Is(err, errs.ErrNeedMore) {
   474  			t.ResetSkipNormalize()
   475  			return err
   476  		}
   477  		// No more data available on the wire, try block peek(by netpoll)
   478  		if n == r.Len() {
   479  			n++
   480  
   481  			continue
   482  		}
   483  		n = r.Len()
   484  	}
   485  }
   486  
   487  func tryReadTrailer(t *protocol.Trailer, r network.Reader, n int) error {
   488  	b, err := r.Peek(n)
   489  	if len(b) == 0 {
   490  		// Return ErrTimeout on any timeout.
   491  		if err != nil && strings.Contains(err.Error(), "timeout") {
   492  			return errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read response header")
   493  		}
   494  
   495  		if n == 1 || err == io.EOF {
   496  			return io.EOF
   497  		}
   498  
   499  		return errs.NewPublicf("error when reading request trailer: %w", err)
   500  	}
   501  	b = MustPeekBuffered(r)
   502  	headersLen, errParse := parseTrailer(t, b)
   503  	if errParse != nil {
   504  		if err == io.EOF {
   505  			return err
   506  		}
   507  		return HeaderError("response", err, errParse, b)
   508  	}
   509  	MustDiscard(r, headersLen)
   510  	return nil
   511  }
   512  
   513  func parseTrailer(t *protocol.Trailer, buf []byte) (int, error) {
   514  	// Skip any 0 length chunk.
   515  	if buf[0] == '0' {
   516  		skip := len(bytestr.StrCRLF) + 1
   517  		if len(buf) < skip {
   518  			return 0, io.EOF
   519  		}
   520  		buf = buf[skip:]
   521  	}
   522  
   523  	var s HeaderScanner
   524  	s.B = buf
   525  	s.DisableNormalizing = t.IsDisableNormalizing()
   526  	var err error
   527  	for s.Next() {
   528  		if len(s.Key) > 0 {
   529  			if bytes.IndexByte(s.Key, ' ') != -1 || bytes.IndexByte(s.Key, '\t') != -1 {
   530  				err = fmt.Errorf("invalid trailer key %q", s.Key)
   531  				continue
   532  			}
   533  			err = t.UpdateArgBytes(s.Key, s.Value)
   534  		}
   535  	}
   536  	if s.Err != nil {
   537  		return 0, s.Err
   538  	}
   539  	if err != nil {
   540  		return 0, err
   541  	}
   542  	return s.HLen, nil
   543  }
   544  
   545  // writeTrailer writes response trailer to w
   546  func WriteTrailer(t *protocol.Trailer, w network.Writer) error {
   547  	_, err := w.WriteBinary(t.Header())
   548  	return err
   549  }