github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/ext/common_test.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ext 18 19 import ( 20 "bytes" 21 "errors" 22 "io" 23 "strings" 24 "testing" 25 26 errs "github.com/cloudwego/hertz/pkg/common/errors" 27 "github.com/cloudwego/hertz/pkg/common/hlog" 28 "github.com/cloudwego/hertz/pkg/common/test/assert" 29 "github.com/cloudwego/hertz/pkg/common/test/mock" 30 "github.com/cloudwego/hertz/pkg/protocol" 31 "github.com/cloudwego/netpoll" 32 ) 33 34 func Test_stripSpace(t *testing.T) { 35 a := stripSpace([]byte(" a")) 36 b := stripSpace([]byte("b ")) 37 c := stripSpace([]byte(" c ")) 38 assert.DeepEqual(t, []byte("a"), a) 39 assert.DeepEqual(t, []byte("b"), b) 40 assert.DeepEqual(t, []byte("c"), c) 41 } 42 43 func Test_bufferSnippet(t *testing.T) { 44 a := make([]byte, 39) 45 b := make([]byte, 41) 46 assert.False(t, strings.Contains(BufferSnippet(a), "\"...\"")) 47 assert.True(t, strings.Contains(BufferSnippet(b), "\"...\"")) 48 } 49 50 func Test_isOnlyCRLF(t *testing.T) { 51 assert.True(t, isOnlyCRLF([]byte("\r\n"))) 52 assert.True(t, isOnlyCRLF([]byte("\n"))) 53 } 54 55 func TestReadTrailer(t *testing.T) { 56 exceptedTrailers := map[string]string{"Hertz": "test"} 57 zr := mock.NewZeroCopyReader("0\r\nHertz: test\r\n\r\n") 58 trailer := protocol.Trailer{} 59 keys := make([]string, 0, len(exceptedTrailers)) 60 for k := range exceptedTrailers { 61 keys = append(keys, k) 62 } 63 trailer.SetTrailers([]byte(strings.Join(keys, ", "))) 64 err := ReadTrailer(&trailer, zr) 65 if err != nil { 66 t.Fatalf("Cannot read trailer: %v", err) 67 } 68 69 for k, v := range exceptedTrailers { 70 got := trailer.Peek(k) 71 if !bytes.Equal(got, []byte(v)) { 72 t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got) 73 } 74 } 75 } 76 77 func TestReadTrailerError(t *testing.T) { 78 // with bad trailer 79 zr := mock.NewZeroCopyReader("0\r\nHertz: test\r\nContent-Type: aaa\r\n\r\n") 80 trailer := protocol.Trailer{} 81 err := ReadTrailer(&trailer, zr) 82 if err == nil { 83 t.Fatalf("expecting error.") 84 } 85 86 // eof 87 er := mock.EOFReader{} 88 trailer = protocol.Trailer{} 89 err = ReadTrailer(&trailer, &er) 90 assert.DeepEqual(t, io.EOF, err) 91 } 92 93 func TestReadTrailer1(t *testing.T) { 94 exceptedTrailers := map[string]string{} 95 zr := mock.NewZeroCopyReader("0\r\n\r\n") 96 trailer := protocol.Trailer{} 97 err := ReadTrailer(&trailer, zr) 98 if err != nil { 99 t.Fatalf("Cannot read trailer: %v", err) 100 } 101 102 for k, v := range exceptedTrailers { 103 got := trailer.Peek(k) 104 if !bytes.Equal(got, []byte(v)) { 105 t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got) 106 } 107 } 108 } 109 110 func TestReadRawHeaders(t *testing.T) { 111 s := "HTTP/1.1 200 OK\r\n" + 112 "EmptyValue1:\r\n" + 113 "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + 114 "Foo: Bar\r\n" + 115 "Multi-Line: one;\r\n two\r\n" + 116 "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + 117 "Content-Length: 5\r\n\r\n" + 118 "HELLOaaa" 119 120 var dst []byte 121 rawHeaders, index, err := ReadRawHeaders(dst, []byte(s)) 122 assert.Nil(t, err) 123 assert.DeepEqual(t, s[:index], string(rawHeaders)) 124 } 125 126 func TestBodyChunked(t *testing.T) { 127 var log bytes.Buffer 128 hlog.SetOutput(&log) 129 130 body := "foobar baz aaa bbb ccc" 131 chunk := "16\r\nfoobar baz aaa bbb ccc\r\n0\r\n" 132 b := bytes.NewBufferString(body) 133 134 var w bytes.Buffer 135 zw := netpoll.NewWriter(&w) 136 WriteBodyChunked(zw, b) 137 138 assert.DeepEqual(t, chunk, w.String()) 139 140 zr := mock.NewZeroCopyReader(chunk) 141 rb, err := ReadBody(zr, -1, 0, nil) 142 assert.Nil(t, err) 143 assert.DeepEqual(t, body, string(rb)) 144 145 assert.DeepEqual(t, 0, log.Len()) 146 } 147 148 func TestBrokenBodyChunked(t *testing.T) { 149 brokenReader := mock.NewBrokenConn("") 150 var log bytes.Buffer 151 hlog.SetOutput(&log) 152 153 var w bytes.Buffer 154 zw := netpoll.NewWriter(&w) 155 err := WriteBodyChunked(zw, brokenReader) 156 assert.Nil(t, err) 157 158 assert.DeepEqual(t, []byte("0\r\n"), w.Bytes()) 159 assert.True(t, bytes.Contains(log.Bytes(), []byte("writing chunked response body encountered an error from the reader"))) 160 } 161 162 func TestBodyFixedSize(t *testing.T) { 163 body := mock.CreateFixedBody(10) 164 b := bytes.NewBuffer(body) 165 166 var w bytes.Buffer 167 zw := netpoll.NewWriter(&w) 168 WriteBodyFixedSize(zw, b, int64(len(body))) 169 170 assert.DeepEqual(t, body, w.Bytes()) 171 172 zr := mock.NewZeroCopyReader(string(body)) 173 rb, err := ReadBody(zr, len(body), 0, nil) 174 assert.Nil(t, err) 175 assert.DeepEqual(t, body, rb) 176 } 177 178 func TestBodyFixedSizeQuickPath(t *testing.T) { 179 conn := mock.NewBrokenConn("") 180 err := WriteBodyFixedSize(conn.Writer(), conn, 0) 181 assert.Nil(t, err) 182 } 183 184 func TestBodyIdentity(t *testing.T) { 185 body := mock.CreateFixedBody(1024) 186 zr := mock.NewZeroCopyReader(string(body)) 187 rb, err := ReadBody(zr, -2, 0, nil) 188 assert.Nil(t, err) 189 assert.DeepEqual(t, string(body), string(rb)) 190 } 191 192 func TestBodySkipTrailer(t *testing.T) { 193 t.Run("TestBodySkipTrailer", func(t *testing.T) { 194 body := mock.CreateFixedBody(10) 195 trailer := map[string]string{"Foo": "chunked shit"} 196 chunkedBody := mock.CreateChunkedBody(body, trailer, true) 197 r := mock.NewSlowReadConn(string(chunkedBody)) 198 err := SkipTrailer(r) 199 assert.Nil(t, err) 200 _, err = r.ReadByte() 201 assert.NotNil(t, err) 202 assert.True(t, errors.Is(err, netpoll.ErrEOF)) 203 }) 204 205 t.Run("TestBodySkipTrailerError", func(t *testing.T) { 206 // timeout error 207 sr := mock.NewSlowReadConn("") 208 err := SkipTrailer(sr) 209 assert.NotNil(t, err) 210 assert.True(t, errors.Is(err, errs.ErrTimeout)) 211 // EOF error 212 er := &mock.EOFReader{} 213 err = SkipTrailer(er) 214 assert.NotNil(t, err) 215 assert.True(t, errors.Is(err, io.EOF)) 216 }) 217 }