github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/server_test.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package http1
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"strings"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	inStats "github.com/cloudwego/hertz/internal/stats"
    29  	"github.com/cloudwego/hertz/pkg/app"
    30  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    31  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    32  	"github.com/cloudwego/hertz/pkg/common/test/mock"
    33  	"github.com/cloudwego/hertz/pkg/common/tracer"
    34  	"github.com/cloudwego/hertz/pkg/common/tracer/stats"
    35  	"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
    36  	"github.com/cloudwego/hertz/pkg/network"
    37  	"github.com/cloudwego/hertz/pkg/protocol"
    38  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    39  	"github.com/cloudwego/hertz/pkg/protocol/http1/resp"
    40  )
    41  
    42  var pool = &sync.Pool{New: func() interface{} {
    43  	return &eventStack{}
    44  }}
    45  
    46  func TestTraceEventCompleted(t *testing.T) {
    47  	server := &Server{}
    48  	server.eventStackPool = pool
    49  	server.EnableTrace = true
    50  	reqCtx := &app.RequestContext{}
    51  	server.Core = &mockCore{
    52  		ctxPool: &sync.Pool{New: func() interface{} {
    53  			ti := traceinfo.NewTraceInfo()
    54  			ti.Stats().SetLevel(2)
    55  			reqCtx.SetTraceInfo(&mockTraceInfo{ti})
    56  			return reqCtx
    57  		}},
    58  		controller: &inStats.Controller{},
    59  	}
    60  	err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n"))
    61  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
    62  	traceInfo := reqCtx.GetTraceInfo()
    63  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil())
    64  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil())
    65  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil())
    66  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil())
    67  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil())
    68  	assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil())
    69  	assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil())
    70  	assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil())
    71  	assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil())
    72  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil())
    73  	assert.Nil(t, traceInfo.Stats().Error())
    74  }
    75  
    76  func TestTraceEventReadHeaderError(t *testing.T) {
    77  	server := &Server{}
    78  	server.eventStackPool = pool
    79  	server.EnableTrace = true
    80  	reqCtx := &app.RequestContext{}
    81  	server.Core = &mockCore{
    82  		ctxPool: &sync.Pool{New: func() interface{} {
    83  			ti := traceinfo.NewTraceInfo()
    84  			ti.Stats().SetLevel(2)
    85  			reqCtx.SetTraceInfo(&mockTraceInfo{ti})
    86  			return reqCtx
    87  		}},
    88  		controller: &inStats.Controller{},
    89  	}
    90  	err := server.Serve(context.TODO(), mock.NewConn("ErrorFirstLine\r\n\r\n"))
    91  	assert.NotNil(t, err)
    92  	traceInfo := reqCtx.GetTraceInfo()
    93  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil())
    94  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil())
    95  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil())
    96  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart))
    97  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish))
    98  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart))
    99  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish))
   100  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart))
   101  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish))
   102  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil())
   103  }
   104  
   105  func TestTraceEventReadBodyError(t *testing.T) {
   106  	server := &Server{}
   107  	server.eventStackPool = pool
   108  	server.EnableTrace = true
   109  	server.GetOnly = true
   110  	reqCtx := &app.RequestContext{}
   111  	server.Core = &mockCore{
   112  		ctxPool: &sync.Pool{New: func() interface{} {
   113  			ti := traceinfo.NewTraceInfo()
   114  			ti.Stats().SetLevel(2)
   115  			reqCtx.SetTraceInfo(&mockTraceInfo{ti})
   116  			return reqCtx
   117  		}},
   118  		controller: &inStats.Controller{},
   119  	}
   120  	err := server.Serve(context.TODO(), mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n"))
   121  	assert.NotNil(t, err)
   122  
   123  	traceInfo := reqCtx.GetTraceInfo()
   124  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil())
   125  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil())
   126  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil())
   127  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil())
   128  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil())
   129  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart))
   130  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish))
   131  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart))
   132  	assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish))
   133  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil())
   134  }
   135  
   136  func TestTraceEventWriteError(t *testing.T) {
   137  	server := &Server{}
   138  	server.eventStackPool = pool
   139  	server.EnableTrace = true
   140  	reqCtx := &app.RequestContext{}
   141  	server.Core = &mockCore{
   142  		ctxPool: &sync.Pool{New: func() interface{} {
   143  			ti := traceinfo.NewTraceInfo()
   144  			ti.Stats().SetLevel(2)
   145  			reqCtx.SetTraceInfo(&mockTraceInfo{ti})
   146  			return reqCtx
   147  		}},
   148  		controller: &inStats.Controller{},
   149  	}
   150  	err := server.Serve(
   151  		context.TODO(),
   152  		&mockErrorWriter{
   153  			mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n"),
   154  		},
   155  	)
   156  	assert.NotNil(t, err)
   157  	traceInfo := reqCtx.GetTraceInfo()
   158  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil())
   159  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil())
   160  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil())
   161  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil())
   162  	assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil())
   163  	assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil())
   164  	assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil())
   165  	assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil())
   166  	assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil())
   167  	assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil())
   168  }
   169  
   170  func TestEventStack(t *testing.T) {
   171  	// Create a stack.
   172  	s := &eventStack{}
   173  	assert.True(t, s.isEmpty())
   174  
   175  	count := 0
   176  
   177  	// Push 10 events.
   178  	for i := 0; i < 10; i++ {
   179  		s.push(func(ti traceinfo.TraceInfo, err error) {
   180  			count += 1
   181  		})
   182  	}
   183  
   184  	assert.False(t, s.isEmpty())
   185  	// Pop 10 events and process them.
   186  	for last := s.pop(); last != nil; last = s.pop() {
   187  		last(nil, nil)
   188  	}
   189  
   190  	assert.DeepEqual(t, 10, count)
   191  
   192  	// Pop an empty stack.
   193  	e := s.pop()
   194  	if e != nil {
   195  		t.Fatalf("should be nil")
   196  	}
   197  }
   198  
   199  func TestDefaultWriter(t *testing.T) {
   200  	server := &Server{}
   201  	reqCtx := &app.RequestContext{}
   202  	server.Core = &mockCore{
   203  		ctxPool: &sync.Pool{New: func() interface{} {
   204  			return reqCtx
   205  		}},
   206  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   207  			ctx.Write([]byte("hello, hertz"))
   208  			ctx.Flush()
   209  		},
   210  	}
   211  	defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
   212  	err := server.Serve(context.TODO(), defaultConn)
   213  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   214  	defaultResponseResult := defaultConn.WriterRecorder()
   215  	assert.DeepEqual(t, 0, defaultResponseResult.Len()) // all data is flushed so the buffer length is 0
   216  	response := protocol.AcquireResponse()
   217  	resp.Read(response, defaultResponseResult)
   218  	assert.DeepEqual(t, "hello, hertz", string(response.Body()))
   219  }
   220  
   221  func TestServerDisableReqCtxPool(t *testing.T) {
   222  	server := &Server{}
   223  	reqCtx := &app.RequestContext{}
   224  	server.Core = &mockCore{
   225  		ctxPool: &sync.Pool{New: func() interface{} {
   226  			reqCtx.Set("POOL_KEY", "in pool")
   227  			return reqCtx
   228  		}},
   229  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   230  			if ctx.GetString("POOL_KEY") != "in pool" {
   231  				t.Fatal("reqCtx is not in pool")
   232  			}
   233  		},
   234  		isRunning: true,
   235  	}
   236  	defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
   237  	err := server.Serve(context.TODO(), defaultConn)
   238  	assert.Nil(t, err)
   239  	disabaleRequestContextPool = true
   240  	defer func() {
   241  		// reset global variable
   242  		disabaleRequestContextPool = false
   243  	}()
   244  	server.Core = &mockCore{
   245  		ctxPool: &sync.Pool{New: func() interface{} {
   246  			reqCtx.Set("POOL_KEY", "in pool")
   247  			return reqCtx
   248  		}},
   249  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   250  			if len(ctx.GetString("POOL_KEY")) != 0 {
   251  				t.Fatal("must not get pool key")
   252  			}
   253  		},
   254  		isRunning: true,
   255  	}
   256  	defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
   257  	err = server.Serve(context.TODO(), defaultConn)
   258  	assert.Nil(t, err)
   259  }
   260  
   261  func TestHijackResponseWriter(t *testing.T) {
   262  	server := &Server{}
   263  	reqCtx := &app.RequestContext{}
   264  	buf := new(bytes.Buffer)
   265  	isFinal := false
   266  	server.Core = &mockCore{
   267  		ctxPool: &sync.Pool{New: func() interface{} {
   268  			return reqCtx
   269  		}},
   270  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   271  			// response before write will be dropped
   272  			ctx.Write([]byte("invalid data"))
   273  
   274  			ctx.Response.HijackWriter(&mock.ExtWriter{
   275  				Buf:     buf,
   276  				IsFinal: &isFinal,
   277  			})
   278  
   279  			ctx.Write([]byte("hello, hertz"))
   280  			ctx.Flush()
   281  		},
   282  	}
   283  	defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
   284  	err := server.Serve(context.TODO(), defaultConn)
   285  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   286  	defaultResponseResult := defaultConn.WriterRecorder()
   287  	response := protocol.AcquireResponse()
   288  	resp.Read(response, defaultResponseResult)
   289  	assert.DeepEqual(t, 0, len(response.Body()))
   290  	assert.DeepEqual(t, "hello, hertz", buf.String())
   291  	assert.True(t, isFinal)
   292  }
   293  
   294  func TestHijackHandler(t *testing.T) {
   295  	server := NewServer()
   296  	reqCtx := &app.RequestContext{}
   297  	originReadTimeout := time.Second
   298  	hijackReadTimeout := 200 * time.Millisecond
   299  	reqCtx.SetHijackHandler(func(c network.Conn) {
   300  		c.SetReadTimeout(hijackReadTimeout) // hijack read timeout
   301  	})
   302  
   303  	server.Core = &mockCore{
   304  		ctxPool: &sync.Pool{New: func() interface{} {
   305  			return reqCtx
   306  		}},
   307  	}
   308  
   309  	server.HijackConnHandle = func(c network.Conn, h app.HijackHandler) {
   310  		h(c)
   311  	}
   312  
   313  	defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
   314  	defaultConn.SetReadTimeout(originReadTimeout)
   315  	assert.DeepEqual(t, originReadTimeout, defaultConn.GetReadTimeout())
   316  	err := server.Serve(context.TODO(), defaultConn)
   317  	assert.True(t, errors.Is(err, errs.ErrHijacked))
   318  	assert.DeepEqual(t, hijackReadTimeout, defaultConn.GetReadTimeout())
   319  }
   320  
   321  func TestKeepAlive(t *testing.T) {
   322  	server := NewServer()
   323  	reqCtx := &app.RequestContext{}
   324  	times := 0
   325  	server.Core = &mockCore{
   326  		ctxPool: &sync.Pool{New: func() interface{} {
   327  			return reqCtx
   328  		}},
   329  		isRunning: true,
   330  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   331  			times++
   332  			if string(ctx.Path()) == "/close" {
   333  				ctx.SetConnectionClose()
   334  			}
   335  		},
   336  	}
   337  	server.IdleTimeout = time.Second
   338  
   339  	var s strings.Builder
   340  	s.WriteString("GET / HTTP/1.1\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
   341  	s.WriteString("GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") // set connection close
   342  
   343  	defaultConn := mock.NewConn(s.String())
   344  	err := server.Serve(context.TODO(), defaultConn)
   345  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   346  	assert.DeepEqual(t, times, 2)
   347  }
   348  
   349  func TestExpect100Continue(t *testing.T) {
   350  	server := &Server{}
   351  	reqCtx := &app.RequestContext{}
   352  	server.Core = &mockCore{
   353  		ctxPool: &sync.Pool{New: func() interface{} {
   354  			return reqCtx
   355  		}},
   356  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   357  			data, err := ctx.Body()
   358  			if err == nil {
   359  				ctx.Write(data)
   360  			}
   361  		},
   362  	}
   363  
   364  	defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
   365  	err := server.Serve(context.TODO(), defaultConn)
   366  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   367  	defaultResponseResult := defaultConn.WriterRecorder()
   368  	assert.DeepEqual(t, 0, defaultResponseResult.Len())
   369  	response := protocol.AcquireResponse()
   370  	resp.Read(response, defaultResponseResult)
   371  	assert.DeepEqual(t, "12345", string(response.Body()))
   372  }
   373  
   374  func TestExpect100ContinueHandler(t *testing.T) {
   375  	server := &Server{}
   376  	reqCtx := &app.RequestContext{}
   377  	server.Core = &mockCore{
   378  		ctxPool: &sync.Pool{New: func() interface{} {
   379  			return reqCtx
   380  		}},
   381  		mockHandler: func(c context.Context, ctx *app.RequestContext) {
   382  			data, err := ctx.Body()
   383  			if err == nil {
   384  				ctx.Write(data)
   385  			}
   386  		},
   387  	}
   388  	server.ContinueHandler = func(header *protocol.RequestHeader) bool {
   389  		return false
   390  	}
   391  
   392  	defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
   393  	err := server.Serve(context.TODO(), defaultConn)
   394  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   395  	defaultResponseResult := defaultConn.WriterRecorder()
   396  	assert.DeepEqual(t, 0, defaultResponseResult.Len())
   397  	response := protocol.AcquireResponse()
   398  	resp.Read(response, defaultResponseResult)
   399  	assert.DeepEqual(t, consts.StatusExpectationFailed, response.StatusCode())
   400  	assert.DeepEqual(t, "", string(response.Body()))
   401  }
   402  
   403  type mockController struct {
   404  	FinishTimes int
   405  }
   406  
   407  func (m *mockController) Append(col tracer.Tracer) {}
   408  
   409  func (m *mockController) DoStart(ctx context.Context, c *app.RequestContext) context.Context {
   410  	return ctx
   411  }
   412  
   413  func (m *mockController) DoFinish(ctx context.Context, c *app.RequestContext, err error) {
   414  	m.FinishTimes++
   415  }
   416  
   417  func (m *mockController) HasTracer() bool { return true }
   418  
   419  func (m *mockController) reset() { m.FinishTimes = 0 }
   420  
   421  func TestTraceDoFinishTimes(t *testing.T) {
   422  	server := &Server{}
   423  	server.eventStackPool = pool
   424  	server.EnableTrace = true
   425  	reqCtx := &app.RequestContext{}
   426  	controller := &mockController{}
   427  	server.Core = &mockCore{
   428  		ctxPool: &sync.Pool{New: func() interface{} {
   429  			ti := traceinfo.NewTraceInfo()
   430  			ti.Stats().SetLevel(2)
   431  			reqCtx.SetTraceInfo(&mockTraceInfo{ti})
   432  			return reqCtx
   433  		}},
   434  		controller: controller,
   435  	}
   436  	// for disableKeepAlive case
   437  	server.DisableKeepalive = true
   438  	err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n"))
   439  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   440  	assert.DeepEqual(t, 1, controller.FinishTimes)
   441  	// for IdleTimeout==0 case
   442  	server.IdleTimeout = 0
   443  	controller.reset()
   444  	err = server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n"))
   445  	assert.True(t, errors.Is(err, errs.ErrShortConnection))
   446  	assert.DeepEqual(t, 1, controller.FinishTimes)
   447  }
   448  
   449  type mockCore struct {
   450  	ctxPool     *sync.Pool
   451  	controller  tracer.Controller
   452  	mockHandler func(c context.Context, ctx *app.RequestContext)
   453  	isRunning   bool
   454  }
   455  
   456  func (m *mockCore) IsRunning() bool {
   457  	return m.isRunning
   458  }
   459  
   460  func (m *mockCore) GetCtxPool() *sync.Pool {
   461  	return m.ctxPool
   462  }
   463  
   464  func (m *mockCore) ServeHTTP(c context.Context, ctx *app.RequestContext) {
   465  	if m.mockHandler != nil {
   466  		m.mockHandler(c, ctx)
   467  	}
   468  }
   469  
   470  func (m *mockCore) GetTracer() tracer.Controller {
   471  	return m.controller
   472  }
   473  
   474  type mockTraceInfo struct {
   475  	traceinfo.TraceInfo
   476  }
   477  
   478  func (m *mockTraceInfo) Reset() {}
   479  
   480  type mockErrorWriter struct {
   481  	network.Conn
   482  }
   483  
   484  func (errorWriter *mockErrorWriter) Flush() error {
   485  	return errors.New("error")
   486  }
   487  
   488  func TestShouldRecordInTraceError(t *testing.T) {
   489  	assert.False(t, shouldRecordInTraceError(nil))
   490  	assert.False(t, shouldRecordInTraceError(errHijacked))
   491  	assert.False(t, shouldRecordInTraceError(errIdleTimeout))
   492  	assert.False(t, shouldRecordInTraceError(errShortConnection))
   493  
   494  	assert.True(t, shouldRecordInTraceError(errTimeout))
   495  	assert.True(t, shouldRecordInTraceError(errors.New("foo error")))
   496  }