github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/ext/stream_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package ext 17 18 import ( 19 "bytes" 20 "errors" 21 "fmt" 22 "io" 23 "testing" 24 25 "github.com/cloudwego/hertz/pkg/common/bytebufferpool" 26 "github.com/cloudwego/hertz/pkg/common/test/assert" 27 "github.com/cloudwego/hertz/pkg/common/test/mock" 28 "github.com/cloudwego/hertz/pkg/protocol" 29 ) 30 31 func createChunkedBody(body, rest []byte, trailer map[string]string, hasTrailer bool) []byte { 32 var b []byte 33 chunkSize := 1 34 for len(body) > 0 { 35 if chunkSize > len(body) { 36 chunkSize = len(body) 37 } 38 b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) 39 b = append(b, body[:chunkSize]...) 40 b = append(b, []byte("\r\n")...) 41 body = body[chunkSize:] 42 chunkSize++ 43 } 44 if hasTrailer { 45 b = append(b, "0\r\n"...) 46 for k, v := range trailer { 47 b = append(b, k...) 48 b = append(b, ": "...) 49 b = append(b, v...) 50 b = append(b, "\r\n"...) 51 } 52 b = append(b, "\r\n"...) 53 } 54 return append(b, rest...) 55 } 56 57 func testChunkedSkipRest(t *testing.T, data, rest string) { 58 var pool bytebufferpool.Pool 59 reader := mock.NewZeroCopyReader(data) 60 61 bs := AcquireBodyStream(pool.Get(), reader, &protocol.Trailer{}, -1) 62 err := bs.(*bodyStream).skipRest() 63 assert.Nil(t, err) 64 65 rest_data, err := io.ReadAll(reader) 66 assert.Nil(t, err) 67 assert.DeepEqual(t, rest, string(rest_data)) 68 } 69 70 func testChunkedSkipRestWithBodySize(t *testing.T, bodySize int) { 71 body := mock.CreateFixedBody(bodySize) 72 rest := mock.CreateFixedBody(bodySize) 73 data := createChunkedBody(body, rest, map[string]string{"foo": "bar"}, true) 74 75 testChunkedSkipRest(t, string(data), string(rest)) 76 } 77 78 func TestChunkedSkipRest(t *testing.T) { 79 t.Parallel() 80 81 testChunkedSkipRest(t, "0\r\n\r\n", "") 82 testChunkedSkipRest(t, "0\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST") 83 testChunkedSkipRest(t, "0\r\nHertz: test\r\nfoo: bar\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST") 84 85 testChunkedSkipRestWithBodySize(t, 5) 86 87 // medium-size body 88 testChunkedSkipRestWithBodySize(t, 43488) 89 90 // big body 91 testChunkedSkipRestWithBodySize(t, 3*1024*1024) 92 } 93 94 func TestBodyStream_Reset(t *testing.T) { 95 t.Parallel() 96 bs := bodyStream{ 97 prefetchedBytes: bytes.NewReader([]byte("aaa")), 98 reader: mock.NewZeroCopyReader("bbb"), 99 trailer: &protocol.Trailer{}, 100 offset: 10, 101 contentLength: 20, 102 chunkLeft: 50, 103 chunkEOF: true, 104 } 105 106 bs.reset() 107 108 assert.Nil(t, bs.prefetchedBytes) 109 assert.Nil(t, bs.reader) 110 assert.Nil(t, bs.trailer) 111 assert.DeepEqual(t, 0, bs.offset) 112 assert.DeepEqual(t, 0, bs.contentLength) 113 assert.DeepEqual(t, 0, bs.chunkLeft) 114 assert.False(t, bs.chunkEOF) 115 } 116 117 func TestReadBodyWithStreaming(t *testing.T) { 118 t.Run("TestBodyFixedSize", func(t *testing.T) { 119 bodySize := 1024 120 body := mock.CreateFixedBody(bodySize) 121 reader := mock.NewZeroCopyReader(string(body)) 122 dst, err := ReadBodyWithStreaming(reader, bodySize, -1, nil) 123 assert.Nil(t, err) 124 assert.DeepEqual(t, body, dst) 125 }) 126 127 t.Run("TestBodyFixedSizeMaxContentLength", func(t *testing.T) { 128 bodySize := 8 * 1024 * 2 129 body := mock.CreateFixedBody(bodySize) 130 reader := mock.NewZeroCopyReader(string(body)) 131 dst, err := ReadBodyWithStreaming(reader, bodySize, 8*1024*10, nil) 132 assert.Nil(t, err) 133 assert.DeepEqual(t, body[:maxContentLengthInStream], dst) 134 }) 135 136 t.Run("TestBodyIdentity", func(t *testing.T) { 137 bodySize := 1024 138 body := mock.CreateFixedBody(bodySize) 139 reader := mock.NewZeroCopyReader(string(body)) 140 dst, err := ReadBodyWithStreaming(reader, -2, 512, nil) 141 assert.Nil(t, err) 142 assert.DeepEqual(t, body, dst) 143 }) 144 145 t.Run("TestErrBodyTooLarge", func(t *testing.T) { 146 bodySize := 2048 147 body := mock.CreateFixedBody(bodySize) 148 reader := mock.NewZeroCopyReader(string(body)) 149 dst, err := ReadBodyWithStreaming(reader, bodySize, 1024, nil) 150 assert.True(t, errors.Is(err, errBodyTooLarge)) 151 assert.DeepEqual(t, body[:len(dst)], dst) 152 }) 153 154 t.Run("TestErrChunkedStream", func(t *testing.T) { 155 bodySize := 1024 156 body := mock.CreateFixedBody(bodySize) 157 reader := mock.NewZeroCopyReader(string(body)) 158 dst, err := ReadBodyWithStreaming(reader, -1, bodySize, nil) 159 assert.True(t, errors.Is(err, errChunkedStream)) 160 assert.Nil(t, dst) 161 }) 162 } 163 164 func TestBodyStream(t *testing.T) { 165 t.Run("TestBodyStreamPrereadBuffer", func(t *testing.T) { 166 bodySize := 1024 167 body := mock.CreateFixedBody(bodySize) 168 byteBuffer := &bytebufferpool.ByteBuffer{} 169 byteBuffer.Set(body) 170 171 bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(""), nil, len(body)) 172 defer func() { 173 ReleaseBodyStream(bs) 174 }() 175 176 b := make([]byte, bodySize) 177 err := bodyStreamRead(bs, b) 178 assert.Nil(t, err) 179 assert.DeepEqual(t, len(body), len(b)) 180 assert.DeepEqual(t, string(body), string(b)) 181 }) 182 183 t.Run("TestBodyStreamRelease", func(t *testing.T) { 184 bodySize := 1024 185 body := mock.CreateFixedBody(bodySize) 186 byteBuffer := &bytebufferpool.ByteBuffer{} 187 byteBuffer.Set(body) 188 bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(body)), nil, bodySize*2) 189 err := ReleaseBodyStream(bs) 190 assert.Nil(t, err) 191 }) 192 193 t.Run("TestBodyStreamChunked", func(t *testing.T) { 194 bodySize := 5 195 body := mock.CreateFixedBody(bodySize) 196 expectedTrailer := map[string]string{"Foo": "chunked shit"} 197 chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true) 198 199 byteBuffer := &bytebufferpool.ByteBuffer{} 200 byteBuffer.Set(chunkedBody) 201 202 bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(chunkedBody)), &protocol.Trailer{}, -1) 203 defer func() { 204 ReleaseBodyStream(bs) 205 }() 206 207 b := make([]byte, bodySize) 208 err := bodyStreamRead(bs, b) 209 assert.Nil(t, err) 210 assert.DeepEqual(t, len(body), len(b)) 211 assert.DeepEqual(t, string(body), string(b)) 212 }) 213 214 t.Run("TestBodyStreamReadFromWire", func(t *testing.T) { 215 bodySize := 1024 216 body := mock.CreateFixedBody(bodySize) 217 byteBuffer := &bytebufferpool.ByteBuffer{} 218 byteBuffer.Set(body) 219 220 rcBodySize := 128 221 rcBody := mock.CreateFixedBody(rcBodySize) 222 bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(rcBody)), nil, -2) 223 defer func() { 224 ReleaseBodyStream(bs) 225 }() 226 227 b := make([]byte, bodySize) 228 err := bodyStreamRead(bs, b) 229 assert.Nil(t, err) 230 assert.DeepEqual(t, len(body), len(b)) 231 assert.DeepEqual(t, string(body), string(b)) 232 }) 233 } 234 235 func bodyStreamRead(bs io.Reader, b []byte) (err error) { 236 nb := 0 237 for { 238 p := make([]byte, 64) 239 n, rErr := bs.Read(p) 240 if n > 0 { 241 copy(b[nb:], p[:]) 242 nb = nb + n 243 } 244 245 if rErr != nil { 246 if rErr != io.EOF { 247 err = rErr 248 } 249 break 250 } 251 } 252 return 253 }