github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/server_test.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package http1 18 19 import ( 20 "bytes" 21 "context" 22 "errors" 23 "strings" 24 "sync" 25 "testing" 26 "time" 27 28 inStats "github.com/cloudwego/hertz/internal/stats" 29 "github.com/cloudwego/hertz/pkg/app" 30 errs "github.com/cloudwego/hertz/pkg/common/errors" 31 "github.com/cloudwego/hertz/pkg/common/test/assert" 32 "github.com/cloudwego/hertz/pkg/common/test/mock" 33 "github.com/cloudwego/hertz/pkg/common/tracer" 34 "github.com/cloudwego/hertz/pkg/common/tracer/stats" 35 "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" 36 "github.com/cloudwego/hertz/pkg/network" 37 "github.com/cloudwego/hertz/pkg/protocol" 38 "github.com/cloudwego/hertz/pkg/protocol/consts" 39 "github.com/cloudwego/hertz/pkg/protocol/http1/resp" 40 ) 41 42 var pool = &sync.Pool{New: func() interface{} { 43 return &eventStack{} 44 }} 45 46 func TestTraceEventCompleted(t *testing.T) { 47 server := &Server{} 48 server.eventStackPool = pool 49 server.EnableTrace = true 50 reqCtx := &app.RequestContext{} 51 server.Core = &mockCore{ 52 ctxPool: &sync.Pool{New: func() interface{} { 53 ti := traceinfo.NewTraceInfo() 54 ti.Stats().SetLevel(2) 55 reqCtx.SetTraceInfo(&mockTraceInfo{ti}) 56 return reqCtx 57 }}, 58 controller: &inStats.Controller{}, 59 } 60 err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) 61 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 62 traceInfo := reqCtx.GetTraceInfo() 63 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) 64 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) 65 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) 66 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) 67 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) 68 assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil()) 69 assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil()) 70 assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil()) 71 assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil()) 72 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) 73 assert.Nil(t, traceInfo.Stats().Error()) 74 } 75 76 func TestTraceEventReadHeaderError(t *testing.T) { 77 server := &Server{} 78 server.eventStackPool = pool 79 server.EnableTrace = true 80 reqCtx := &app.RequestContext{} 81 server.Core = &mockCore{ 82 ctxPool: &sync.Pool{New: func() interface{} { 83 ti := traceinfo.NewTraceInfo() 84 ti.Stats().SetLevel(2) 85 reqCtx.SetTraceInfo(&mockTraceInfo{ti}) 86 return reqCtx 87 }}, 88 controller: &inStats.Controller{}, 89 } 90 err := server.Serve(context.TODO(), mock.NewConn("ErrorFirstLine\r\n\r\n")) 91 assert.NotNil(t, err) 92 traceInfo := reqCtx.GetTraceInfo() 93 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) 94 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) 95 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) 96 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart)) 97 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish)) 98 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart)) 99 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish)) 100 assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart)) 101 assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish)) 102 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) 103 } 104 105 func TestTraceEventReadBodyError(t *testing.T) { 106 server := &Server{} 107 server.eventStackPool = pool 108 server.EnableTrace = true 109 server.GetOnly = true 110 reqCtx := &app.RequestContext{} 111 server.Core = &mockCore{ 112 ctxPool: &sync.Pool{New: func() interface{} { 113 ti := traceinfo.NewTraceInfo() 114 ti.Stats().SetLevel(2) 115 reqCtx.SetTraceInfo(&mockTraceInfo{ti}) 116 return reqCtx 117 }}, 118 controller: &inStats.Controller{}, 119 } 120 err := server.Serve(context.TODO(), mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n")) 121 assert.NotNil(t, err) 122 123 traceInfo := reqCtx.GetTraceInfo() 124 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) 125 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) 126 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) 127 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) 128 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) 129 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart)) 130 assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish)) 131 assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart)) 132 assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish)) 133 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) 134 } 135 136 func TestTraceEventWriteError(t *testing.T) { 137 server := &Server{} 138 server.eventStackPool = pool 139 server.EnableTrace = true 140 reqCtx := &app.RequestContext{} 141 server.Core = &mockCore{ 142 ctxPool: &sync.Pool{New: func() interface{} { 143 ti := traceinfo.NewTraceInfo() 144 ti.Stats().SetLevel(2) 145 reqCtx.SetTraceInfo(&mockTraceInfo{ti}) 146 return reqCtx 147 }}, 148 controller: &inStats.Controller{}, 149 } 150 err := server.Serve( 151 context.TODO(), 152 &mockErrorWriter{ 153 mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n"), 154 }, 155 ) 156 assert.NotNil(t, err) 157 traceInfo := reqCtx.GetTraceInfo() 158 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) 159 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) 160 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) 161 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) 162 assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) 163 assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil()) 164 assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil()) 165 assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil()) 166 assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil()) 167 assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) 168 } 169 170 func TestEventStack(t *testing.T) { 171 // Create a stack. 172 s := &eventStack{} 173 assert.True(t, s.isEmpty()) 174 175 count := 0 176 177 // Push 10 events. 178 for i := 0; i < 10; i++ { 179 s.push(func(ti traceinfo.TraceInfo, err error) { 180 count += 1 181 }) 182 } 183 184 assert.False(t, s.isEmpty()) 185 // Pop 10 events and process them. 186 for last := s.pop(); last != nil; last = s.pop() { 187 last(nil, nil) 188 } 189 190 assert.DeepEqual(t, 10, count) 191 192 // Pop an empty stack. 193 e := s.pop() 194 if e != nil { 195 t.Fatalf("should be nil") 196 } 197 } 198 199 func TestDefaultWriter(t *testing.T) { 200 server := &Server{} 201 reqCtx := &app.RequestContext{} 202 server.Core = &mockCore{ 203 ctxPool: &sync.Pool{New: func() interface{} { 204 return reqCtx 205 }}, 206 mockHandler: func(c context.Context, ctx *app.RequestContext) { 207 ctx.Write([]byte("hello, hertz")) 208 ctx.Flush() 209 }, 210 } 211 defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") 212 err := server.Serve(context.TODO(), defaultConn) 213 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 214 defaultResponseResult := defaultConn.WriterRecorder() 215 assert.DeepEqual(t, 0, defaultResponseResult.Len()) // all data is flushed so the buffer length is 0 216 response := protocol.AcquireResponse() 217 resp.Read(response, defaultResponseResult) 218 assert.DeepEqual(t, "hello, hertz", string(response.Body())) 219 } 220 221 func TestServerDisableReqCtxPool(t *testing.T) { 222 server := &Server{} 223 reqCtx := &app.RequestContext{} 224 server.Core = &mockCore{ 225 ctxPool: &sync.Pool{New: func() interface{} { 226 reqCtx.Set("POOL_KEY", "in pool") 227 return reqCtx 228 }}, 229 mockHandler: func(c context.Context, ctx *app.RequestContext) { 230 if ctx.GetString("POOL_KEY") != "in pool" { 231 t.Fatal("reqCtx is not in pool") 232 } 233 }, 234 isRunning: true, 235 } 236 defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") 237 err := server.Serve(context.TODO(), defaultConn) 238 assert.Nil(t, err) 239 disabaleRequestContextPool = true 240 defer func() { 241 // reset global variable 242 disabaleRequestContextPool = false 243 }() 244 server.Core = &mockCore{ 245 ctxPool: &sync.Pool{New: func() interface{} { 246 reqCtx.Set("POOL_KEY", "in pool") 247 return reqCtx 248 }}, 249 mockHandler: func(c context.Context, ctx *app.RequestContext) { 250 if len(ctx.GetString("POOL_KEY")) != 0 { 251 t.Fatal("must not get pool key") 252 } 253 }, 254 isRunning: true, 255 } 256 defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") 257 err = server.Serve(context.TODO(), defaultConn) 258 assert.Nil(t, err) 259 } 260 261 func TestHijackResponseWriter(t *testing.T) { 262 server := &Server{} 263 reqCtx := &app.RequestContext{} 264 buf := new(bytes.Buffer) 265 isFinal := false 266 server.Core = &mockCore{ 267 ctxPool: &sync.Pool{New: func() interface{} { 268 return reqCtx 269 }}, 270 mockHandler: func(c context.Context, ctx *app.RequestContext) { 271 // response before write will be dropped 272 ctx.Write([]byte("invalid data")) 273 274 ctx.Response.HijackWriter(&mock.ExtWriter{ 275 Buf: buf, 276 IsFinal: &isFinal, 277 }) 278 279 ctx.Write([]byte("hello, hertz")) 280 ctx.Flush() 281 }, 282 } 283 defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") 284 err := server.Serve(context.TODO(), defaultConn) 285 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 286 defaultResponseResult := defaultConn.WriterRecorder() 287 response := protocol.AcquireResponse() 288 resp.Read(response, defaultResponseResult) 289 assert.DeepEqual(t, 0, len(response.Body())) 290 assert.DeepEqual(t, "hello, hertz", buf.String()) 291 assert.True(t, isFinal) 292 } 293 294 func TestHijackHandler(t *testing.T) { 295 server := NewServer() 296 reqCtx := &app.RequestContext{} 297 originReadTimeout := time.Second 298 hijackReadTimeout := 200 * time.Millisecond 299 reqCtx.SetHijackHandler(func(c network.Conn) { 300 c.SetReadTimeout(hijackReadTimeout) // hijack read timeout 301 }) 302 303 server.Core = &mockCore{ 304 ctxPool: &sync.Pool{New: func() interface{} { 305 return reqCtx 306 }}, 307 } 308 309 server.HijackConnHandle = func(c network.Conn, h app.HijackHandler) { 310 h(c) 311 } 312 313 defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") 314 defaultConn.SetReadTimeout(originReadTimeout) 315 assert.DeepEqual(t, originReadTimeout, defaultConn.GetReadTimeout()) 316 err := server.Serve(context.TODO(), defaultConn) 317 assert.True(t, errors.Is(err, errs.ErrHijacked)) 318 assert.DeepEqual(t, hijackReadTimeout, defaultConn.GetReadTimeout()) 319 } 320 321 func TestKeepAlive(t *testing.T) { 322 server := NewServer() 323 reqCtx := &app.RequestContext{} 324 times := 0 325 server.Core = &mockCore{ 326 ctxPool: &sync.Pool{New: func() interface{} { 327 return reqCtx 328 }}, 329 isRunning: true, 330 mockHandler: func(c context.Context, ctx *app.RequestContext) { 331 times++ 332 if string(ctx.Path()) == "/close" { 333 ctx.SetConnectionClose() 334 } 335 }, 336 } 337 server.IdleTimeout = time.Second 338 339 var s strings.Builder 340 s.WriteString("GET / HTTP/1.1\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") 341 s.WriteString("GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") // set connection close 342 343 defaultConn := mock.NewConn(s.String()) 344 err := server.Serve(context.TODO(), defaultConn) 345 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 346 assert.DeepEqual(t, times, 2) 347 } 348 349 func TestExpect100Continue(t *testing.T) { 350 server := &Server{} 351 reqCtx := &app.RequestContext{} 352 server.Core = &mockCore{ 353 ctxPool: &sync.Pool{New: func() interface{} { 354 return reqCtx 355 }}, 356 mockHandler: func(c context.Context, ctx *app.RequestContext) { 357 data, err := ctx.Body() 358 if err == nil { 359 ctx.Write(data) 360 } 361 }, 362 } 363 364 defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") 365 err := server.Serve(context.TODO(), defaultConn) 366 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 367 defaultResponseResult := defaultConn.WriterRecorder() 368 assert.DeepEqual(t, 0, defaultResponseResult.Len()) 369 response := protocol.AcquireResponse() 370 resp.Read(response, defaultResponseResult) 371 assert.DeepEqual(t, "12345", string(response.Body())) 372 } 373 374 func TestExpect100ContinueHandler(t *testing.T) { 375 server := &Server{} 376 reqCtx := &app.RequestContext{} 377 server.Core = &mockCore{ 378 ctxPool: &sync.Pool{New: func() interface{} { 379 return reqCtx 380 }}, 381 mockHandler: func(c context.Context, ctx *app.RequestContext) { 382 data, err := ctx.Body() 383 if err == nil { 384 ctx.Write(data) 385 } 386 }, 387 } 388 server.ContinueHandler = func(header *protocol.RequestHeader) bool { 389 return false 390 } 391 392 defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") 393 err := server.Serve(context.TODO(), defaultConn) 394 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 395 defaultResponseResult := defaultConn.WriterRecorder() 396 assert.DeepEqual(t, 0, defaultResponseResult.Len()) 397 response := protocol.AcquireResponse() 398 resp.Read(response, defaultResponseResult) 399 assert.DeepEqual(t, consts.StatusExpectationFailed, response.StatusCode()) 400 assert.DeepEqual(t, "", string(response.Body())) 401 } 402 403 type mockController struct { 404 FinishTimes int 405 } 406 407 func (m *mockController) Append(col tracer.Tracer) {} 408 409 func (m *mockController) DoStart(ctx context.Context, c *app.RequestContext) context.Context { 410 return ctx 411 } 412 413 func (m *mockController) DoFinish(ctx context.Context, c *app.RequestContext, err error) { 414 m.FinishTimes++ 415 } 416 417 func (m *mockController) HasTracer() bool { return true } 418 419 func (m *mockController) reset() { m.FinishTimes = 0 } 420 421 func TestTraceDoFinishTimes(t *testing.T) { 422 server := &Server{} 423 server.eventStackPool = pool 424 server.EnableTrace = true 425 reqCtx := &app.RequestContext{} 426 controller := &mockController{} 427 server.Core = &mockCore{ 428 ctxPool: &sync.Pool{New: func() interface{} { 429 ti := traceinfo.NewTraceInfo() 430 ti.Stats().SetLevel(2) 431 reqCtx.SetTraceInfo(&mockTraceInfo{ti}) 432 return reqCtx 433 }}, 434 controller: controller, 435 } 436 // for disableKeepAlive case 437 server.DisableKeepalive = true 438 err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) 439 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 440 assert.DeepEqual(t, 1, controller.FinishTimes) 441 // for IdleTimeout==0 case 442 server.IdleTimeout = 0 443 controller.reset() 444 err = server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) 445 assert.True(t, errors.Is(err, errs.ErrShortConnection)) 446 assert.DeepEqual(t, 1, controller.FinishTimes) 447 } 448 449 type mockCore struct { 450 ctxPool *sync.Pool 451 controller tracer.Controller 452 mockHandler func(c context.Context, ctx *app.RequestContext) 453 isRunning bool 454 } 455 456 func (m *mockCore) IsRunning() bool { 457 return m.isRunning 458 } 459 460 func (m *mockCore) GetCtxPool() *sync.Pool { 461 return m.ctxPool 462 } 463 464 func (m *mockCore) ServeHTTP(c context.Context, ctx *app.RequestContext) { 465 if m.mockHandler != nil { 466 m.mockHandler(c, ctx) 467 } 468 } 469 470 func (m *mockCore) GetTracer() tracer.Controller { 471 return m.controller 472 } 473 474 type mockTraceInfo struct { 475 traceinfo.TraceInfo 476 } 477 478 func (m *mockTraceInfo) Reset() {} 479 480 type mockErrorWriter struct { 481 network.Conn 482 } 483 484 func (errorWriter *mockErrorWriter) Flush() error { 485 return errors.New("error") 486 } 487 488 func TestShouldRecordInTraceError(t *testing.T) { 489 assert.False(t, shouldRecordInTraceError(nil)) 490 assert.False(t, shouldRecordInTraceError(errHijacked)) 491 assert.False(t, shouldRecordInTraceError(errIdleTimeout)) 492 assert.False(t, shouldRecordInTraceError(errShortConnection)) 493 494 assert.True(t, shouldRecordInTraceError(errTimeout)) 495 assert.True(t, shouldRecordInTraceError(errors.New("foo error"))) 496 }