github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/ext/stream_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  package ext
    17  
    18  import (
    19  	"bytes"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"testing"
    24  
    25  	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
    26  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    27  	"github.com/cloudwego/hertz/pkg/common/test/mock"
    28  	"github.com/cloudwego/hertz/pkg/protocol"
    29  )
    30  
    31  func createChunkedBody(body, rest []byte, trailer map[string]string, hasTrailer bool) []byte {
    32  	var b []byte
    33  	chunkSize := 1
    34  	for len(body) > 0 {
    35  		if chunkSize > len(body) {
    36  			chunkSize = len(body)
    37  		}
    38  		b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...)
    39  		b = append(b, body[:chunkSize]...)
    40  		b = append(b, []byte("\r\n")...)
    41  		body = body[chunkSize:]
    42  		chunkSize++
    43  	}
    44  	if hasTrailer {
    45  		b = append(b, "0\r\n"...)
    46  		for k, v := range trailer {
    47  			b = append(b, k...)
    48  			b = append(b, ": "...)
    49  			b = append(b, v...)
    50  			b = append(b, "\r\n"...)
    51  		}
    52  		b = append(b, "\r\n"...)
    53  	}
    54  	return append(b, rest...)
    55  }
    56  
    57  func testChunkedSkipRest(t *testing.T, data, rest string) {
    58  	var pool bytebufferpool.Pool
    59  	reader := mock.NewZeroCopyReader(data)
    60  
    61  	bs := AcquireBodyStream(pool.Get(), reader, &protocol.Trailer{}, -1)
    62  	err := bs.(*bodyStream).skipRest()
    63  	assert.Nil(t, err)
    64  
    65  	rest_data, err := io.ReadAll(reader)
    66  	assert.Nil(t, err)
    67  	assert.DeepEqual(t, rest, string(rest_data))
    68  }
    69  
    70  func testChunkedSkipRestWithBodySize(t *testing.T, bodySize int) {
    71  	body := mock.CreateFixedBody(bodySize)
    72  	rest := mock.CreateFixedBody(bodySize)
    73  	data := createChunkedBody(body, rest, map[string]string{"foo": "bar"}, true)
    74  
    75  	testChunkedSkipRest(t, string(data), string(rest))
    76  }
    77  
    78  func TestChunkedSkipRest(t *testing.T) {
    79  	t.Parallel()
    80  
    81  	testChunkedSkipRest(t, "0\r\n\r\n", "")
    82  	testChunkedSkipRest(t, "0\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST")
    83  	testChunkedSkipRest(t, "0\r\nHertz: test\r\nfoo: bar\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST")
    84  
    85  	testChunkedSkipRestWithBodySize(t, 5)
    86  
    87  	// medium-size body
    88  	testChunkedSkipRestWithBodySize(t, 43488)
    89  
    90  	// big body
    91  	testChunkedSkipRestWithBodySize(t, 3*1024*1024)
    92  }
    93  
    94  func TestBodyStream_Reset(t *testing.T) {
    95  	t.Parallel()
    96  	bs := bodyStream{
    97  		prefetchedBytes: bytes.NewReader([]byte("aaa")),
    98  		reader:          mock.NewZeroCopyReader("bbb"),
    99  		trailer:         &protocol.Trailer{},
   100  		offset:          10,
   101  		contentLength:   20,
   102  		chunkLeft:       50,
   103  		chunkEOF:        true,
   104  	}
   105  
   106  	bs.reset()
   107  
   108  	assert.Nil(t, bs.prefetchedBytes)
   109  	assert.Nil(t, bs.reader)
   110  	assert.Nil(t, bs.trailer)
   111  	assert.DeepEqual(t, 0, bs.offset)
   112  	assert.DeepEqual(t, 0, bs.contentLength)
   113  	assert.DeepEqual(t, 0, bs.chunkLeft)
   114  	assert.False(t, bs.chunkEOF)
   115  }
   116  
   117  func TestReadBodyWithStreaming(t *testing.T) {
   118  	t.Run("TestBodyFixedSize", func(t *testing.T) {
   119  		bodySize := 1024
   120  		body := mock.CreateFixedBody(bodySize)
   121  		reader := mock.NewZeroCopyReader(string(body))
   122  		dst, err := ReadBodyWithStreaming(reader, bodySize, -1, nil)
   123  		assert.Nil(t, err)
   124  		assert.DeepEqual(t, body, dst)
   125  	})
   126  
   127  	t.Run("TestBodyFixedSizeMaxContentLength", func(t *testing.T) {
   128  		bodySize := 8 * 1024 * 2
   129  		body := mock.CreateFixedBody(bodySize)
   130  		reader := mock.NewZeroCopyReader(string(body))
   131  		dst, err := ReadBodyWithStreaming(reader, bodySize, 8*1024*10, nil)
   132  		assert.Nil(t, err)
   133  		assert.DeepEqual(t, body[:maxContentLengthInStream], dst)
   134  	})
   135  
   136  	t.Run("TestBodyIdentity", func(t *testing.T) {
   137  		bodySize := 1024
   138  		body := mock.CreateFixedBody(bodySize)
   139  		reader := mock.NewZeroCopyReader(string(body))
   140  		dst, err := ReadBodyWithStreaming(reader, -2, 512, nil)
   141  		assert.Nil(t, err)
   142  		assert.DeepEqual(t, body, dst)
   143  	})
   144  
   145  	t.Run("TestErrBodyTooLarge", func(t *testing.T) {
   146  		bodySize := 2048
   147  		body := mock.CreateFixedBody(bodySize)
   148  		reader := mock.NewZeroCopyReader(string(body))
   149  		dst, err := ReadBodyWithStreaming(reader, bodySize, 1024, nil)
   150  		assert.True(t, errors.Is(err, errBodyTooLarge))
   151  		assert.DeepEqual(t, body[:len(dst)], dst)
   152  	})
   153  
   154  	t.Run("TestErrChunkedStream", func(t *testing.T) {
   155  		bodySize := 1024
   156  		body := mock.CreateFixedBody(bodySize)
   157  		reader := mock.NewZeroCopyReader(string(body))
   158  		dst, err := ReadBodyWithStreaming(reader, -1, bodySize, nil)
   159  		assert.True(t, errors.Is(err, errChunkedStream))
   160  		assert.Nil(t, dst)
   161  	})
   162  }
   163  
   164  func TestBodyStream(t *testing.T) {
   165  	t.Run("TestBodyStreamPrereadBuffer", func(t *testing.T) {
   166  		bodySize := 1024
   167  		body := mock.CreateFixedBody(bodySize)
   168  		byteBuffer := &bytebufferpool.ByteBuffer{}
   169  		byteBuffer.Set(body)
   170  
   171  		bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(""), nil, len(body))
   172  		defer func() {
   173  			ReleaseBodyStream(bs)
   174  		}()
   175  
   176  		b := make([]byte, bodySize)
   177  		err := bodyStreamRead(bs, b)
   178  		assert.Nil(t, err)
   179  		assert.DeepEqual(t, len(body), len(b))
   180  		assert.DeepEqual(t, string(body), string(b))
   181  	})
   182  
   183  	t.Run("TestBodyStreamRelease", func(t *testing.T) {
   184  		bodySize := 1024
   185  		body := mock.CreateFixedBody(bodySize)
   186  		byteBuffer := &bytebufferpool.ByteBuffer{}
   187  		byteBuffer.Set(body)
   188  		bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(body)), nil, bodySize*2)
   189  		err := ReleaseBodyStream(bs)
   190  		assert.Nil(t, err)
   191  	})
   192  
   193  	t.Run("TestBodyStreamChunked", func(t *testing.T) {
   194  		bodySize := 5
   195  		body := mock.CreateFixedBody(bodySize)
   196  		expectedTrailer := map[string]string{"Foo": "chunked shit"}
   197  		chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true)
   198  
   199  		byteBuffer := &bytebufferpool.ByteBuffer{}
   200  		byteBuffer.Set(chunkedBody)
   201  
   202  		bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(chunkedBody)), &protocol.Trailer{}, -1)
   203  		defer func() {
   204  			ReleaseBodyStream(bs)
   205  		}()
   206  
   207  		b := make([]byte, bodySize)
   208  		err := bodyStreamRead(bs, b)
   209  		assert.Nil(t, err)
   210  		assert.DeepEqual(t, len(body), len(b))
   211  		assert.DeepEqual(t, string(body), string(b))
   212  	})
   213  
   214  	t.Run("TestBodyStreamReadFromWire", func(t *testing.T) {
   215  		bodySize := 1024
   216  		body := mock.CreateFixedBody(bodySize)
   217  		byteBuffer := &bytebufferpool.ByteBuffer{}
   218  		byteBuffer.Set(body)
   219  
   220  		rcBodySize := 128
   221  		rcBody := mock.CreateFixedBody(rcBodySize)
   222  		bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(rcBody)), nil, -2)
   223  		defer func() {
   224  			ReleaseBodyStream(bs)
   225  		}()
   226  
   227  		b := make([]byte, bodySize)
   228  		err := bodyStreamRead(bs, b)
   229  		assert.Nil(t, err)
   230  		assert.DeepEqual(t, len(body), len(b))
   231  		assert.DeepEqual(t, string(body), string(b))
   232  	})
   233  }
   234  
   235  func bodyStreamRead(bs io.Reader, b []byte) (err error) {
   236  	nb := 0
   237  	for {
   238  		p := make([]byte, 64)
   239  		n, rErr := bs.Read(p)
   240  		if n > 0 {
   241  			copy(b[nb:], p[:])
   242  			nb = nb + n
   243  		}
   244  
   245  		if rErr != nil {
   246  			if rErr != io.EOF {
   247  				err = rErr
   248  			}
   249  			break
   250  		}
   251  	}
   252  	return
   253  }