github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/ext/stream.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * The MIT License (MIT)
    17   *
    18   * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
    19   *
    20   * Permission is hereby granted, free of charge, to any person obtaining a copy
    21   * of this software and associated documentation files (the "Software"), to deal
    22   * in the Software without restriction, including without limitation the rights
    23   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    24   * copies of the Software, and to permit persons to whom the Software is
    25   * furnished to do so, subject to the following conditions:
    26   *
    27   * The above copyright notice and this permission notice shall be included in
    28   * all copies or substantial portions of the Software.
    29   *
    30   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    31   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    32   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    33   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    34   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    35   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    36   * THE SOFTWARE.
    37   *
    38   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    39   * Modifications are Copyright 2022 CloudWeGo Authors.
    40   */
    41  
    42  package ext
    43  
    44  import (
    45  	"bytes"
    46  	"io"
    47  	"sync"
    48  
    49  	"github.com/cloudwego/hertz/internal/bytestr"
    50  	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
    51  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    52  	"github.com/cloudwego/hertz/pkg/common/utils"
    53  	"github.com/cloudwego/hertz/pkg/network"
    54  	"github.com/cloudwego/hertz/pkg/protocol"
    55  )
    56  
    57  var (
    58  	errChunkedStream = errs.New(errs.ErrChunkedStream, errs.ErrorTypePublic, nil)
    59  
    60  	bodyStreamPool = sync.Pool{
    61  		New: func() interface{} {
    62  			return &bodyStream{}
    63  		},
    64  	}
    65  )
    66  
    67  // Deprecated: Use github.com/cloudwego/hertz/pkg/protocol.NoBody instead.
    68  var NoBody = protocol.NoBody
    69  
    70  type bodyStream struct {
    71  	prefetchedBytes *bytes.Reader
    72  	reader          network.Reader
    73  	trailer         *protocol.Trailer
    74  	offset          int
    75  	contentLength   int
    76  	chunkLeft       int
    77  	// whether the chunk has reached the EOF
    78  	chunkEOF bool
    79  }
    80  
    81  func ReadBodyWithStreaming(zr network.Reader, contentLength, maxBodySize int, dst []byte) (b []byte, err error) {
    82  	if contentLength == -1 {
    83  		// handled in requestStream.Read()
    84  		return b, errChunkedStream
    85  	}
    86  	dst = dst[:0]
    87  
    88  	if maxBodySize <= 0 {
    89  		maxBodySize = maxContentLengthInStream
    90  	}
    91  	readN := maxBodySize
    92  	if readN > contentLength {
    93  		readN = contentLength
    94  	}
    95  	if readN > maxContentLengthInStream {
    96  		readN = maxContentLengthInStream
    97  	}
    98  
    99  	if contentLength >= 0 && maxBodySize >= contentLength {
   100  		b, err = appendBodyFixedSize(zr, dst, readN)
   101  	} else {
   102  		b, err = readBodyIdentity(zr, readN, dst)
   103  	}
   104  
   105  	if err != nil {
   106  		return b, err
   107  	}
   108  	if contentLength > maxBodySize {
   109  		return b, errBodyTooLarge
   110  	}
   111  	return b, nil
   112  }
   113  
   114  func AcquireBodyStream(b *bytebufferpool.ByteBuffer, r network.Reader, t *protocol.Trailer, contentLength int) io.Reader {
   115  	rs := bodyStreamPool.Get().(*bodyStream)
   116  	rs.prefetchedBytes = bytes.NewReader(b.B)
   117  	rs.reader = r
   118  	rs.contentLength = contentLength
   119  	rs.trailer = t
   120  	rs.chunkEOF = false
   121  
   122  	return rs
   123  }
   124  
   125  func (rs *bodyStream) Read(p []byte) (int, error) {
   126  	defer func() {
   127  		if rs.reader != nil {
   128  			rs.reader.Release() //nolint:errcheck
   129  		}
   130  	}()
   131  	if rs.contentLength == -1 {
   132  		if rs.chunkEOF {
   133  			return 0, io.EOF
   134  		}
   135  
   136  		if rs.chunkLeft == 0 {
   137  			chunkSize, err := utils.ParseChunkSize(rs.reader)
   138  			if err != nil {
   139  				return 0, err
   140  			}
   141  			if chunkSize == 0 {
   142  				err = ReadTrailer(rs.trailer, rs.reader)
   143  				if err == nil {
   144  					rs.chunkEOF = true
   145  					err = io.EOF
   146  				}
   147  				return 0, err
   148  			}
   149  
   150  			rs.chunkLeft = chunkSize
   151  		}
   152  		bytesToRead := len(p)
   153  
   154  		if bytesToRead > rs.chunkLeft {
   155  			bytesToRead = rs.chunkLeft
   156  		}
   157  
   158  		src, err := rs.reader.Peek(bytesToRead)
   159  		copied := copy(p, src)
   160  		rs.reader.Skip(copied) // nolint: errcheck
   161  		rs.chunkLeft -= copied
   162  
   163  		if err != nil {
   164  			if err == io.EOF {
   165  				err = io.ErrUnexpectedEOF
   166  			}
   167  			return copied, err
   168  		}
   169  
   170  		if rs.chunkLeft == 0 {
   171  			err = utils.SkipCRLF(rs.reader)
   172  			if err == io.EOF {
   173  				err = io.ErrUnexpectedEOF
   174  			}
   175  		}
   176  
   177  		return copied, err
   178  	}
   179  	if rs.offset == rs.contentLength {
   180  		return 0, io.EOF
   181  	}
   182  	var n int
   183  	var err error
   184  	// read from the pre-read buffer
   185  	if int(rs.prefetchedBytes.Size()) > rs.offset {
   186  		n, err = rs.prefetchedBytes.Read(p)
   187  		rs.offset += n
   188  		if rs.offset == rs.contentLength {
   189  			return n, io.EOF
   190  		}
   191  		if err != nil || len(p) == n {
   192  			return n, err
   193  		}
   194  	}
   195  
   196  	// read from the wire
   197  	m := len(p) - n
   198  	remain := rs.contentLength - rs.offset
   199  
   200  	if m > remain {
   201  		m = remain
   202  	}
   203  
   204  	if conn, ok := rs.reader.(io.Reader); ok {
   205  		m, err = conn.Read(p[n:])
   206  	} else {
   207  		var tmp []byte
   208  		tmp, err = rs.reader.Peek(m)
   209  		m = copy(p[n:], tmp)
   210  		rs.reader.Skip(m) // nolint: errcheck
   211  	}
   212  	rs.offset += m
   213  	n += m
   214  
   215  	if err != nil {
   216  		// the data on stream may be incomplete
   217  		if err == io.EOF {
   218  			if rs.offset != rs.contentLength && rs.contentLength != -2 {
   219  				err = io.ErrUnexpectedEOF
   220  			}
   221  			// ensure that skipRest works fine
   222  			rs.offset = rs.contentLength
   223  		}
   224  		return n, err
   225  	}
   226  	if rs.offset == rs.contentLength {
   227  		err = io.EOF
   228  	}
   229  	return n, err
   230  }
   231  
   232  func (rs *bodyStream) skipRest() error {
   233  	// The body length doesn't exceed the maxContentLengthInStream or
   234  	// the bodyStream has been skip rest
   235  	if rs.prefetchedBytes == nil {
   236  		return nil
   237  	}
   238  
   239  	// the request is chunked encoding
   240  	if rs.contentLength == -1 {
   241  		if rs.chunkEOF {
   242  			return nil
   243  		}
   244  
   245  		strCRLFLen := len(bytestr.StrCRLF)
   246  		for {
   247  			chunkSize, err := utils.ParseChunkSize(rs.reader)
   248  			if err != nil {
   249  				return err
   250  			}
   251  
   252  			if chunkSize == 0 {
   253  				rs.chunkEOF = true
   254  				return SkipTrailer(rs.reader)
   255  			}
   256  
   257  			err = rs.reader.Skip(chunkSize)
   258  			if err != nil {
   259  				return err
   260  			}
   261  
   262  			crlf, err := rs.reader.Peek(strCRLFLen)
   263  			if err != nil {
   264  				return err
   265  			}
   266  
   267  			if !bytes.Equal(crlf, bytestr.StrCRLF) {
   268  				return errBrokenChunk
   269  			}
   270  
   271  			err = rs.reader.Skip(strCRLFLen)
   272  			if err != nil {
   273  				return err
   274  			}
   275  		}
   276  	}
   277  	// max value of pSize is 8193, it's safe.
   278  	pSize := int(rs.prefetchedBytes.Size())
   279  	if rs.contentLength <= pSize || rs.offset == rs.contentLength {
   280  		return nil
   281  	}
   282  
   283  	needSkipLen := 0
   284  	if rs.offset > pSize {
   285  		needSkipLen = rs.contentLength - rs.offset
   286  	} else {
   287  		needSkipLen = rs.contentLength - pSize
   288  	}
   289  
   290  	// must skip size
   291  	for {
   292  		skip := rs.reader.Len()
   293  		if skip == 0 {
   294  			_, err := rs.reader.Peek(1)
   295  			if err != nil {
   296  				return err
   297  			}
   298  			skip = rs.reader.Len()
   299  		}
   300  		if skip > needSkipLen {
   301  			skip = needSkipLen
   302  		}
   303  		rs.reader.Skip(skip)
   304  		needSkipLen -= skip
   305  		if needSkipLen == 0 {
   306  			return nil
   307  		}
   308  	}
   309  }
   310  
   311  // ReleaseBodyStream releases the body stream.
   312  // Error of skipRest may be returned if there is one.
   313  //
   314  // NOTE: Be careful to use this method unless you know what it's for.
   315  func ReleaseBodyStream(requestReader io.Reader) (err error) {
   316  	if rs, ok := requestReader.(*bodyStream); ok {
   317  		err = rs.skipRest()
   318  		rs.reset()
   319  		bodyStreamPool.Put(rs)
   320  	}
   321  	return
   322  }
   323  
   324  func (rs *bodyStream) reset() {
   325  	rs.prefetchedBytes = nil
   326  	rs.offset = 0
   327  	rs.reader = nil
   328  	rs.trailer = nil
   329  	rs.chunkEOF = false
   330  	rs.chunkLeft = 0
   331  	rs.contentLength = 0
   332  }