goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/websocket/websocket_test.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"regexp"
     9  	"strings"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  	"goyave.dev/goyave/v5"
    16  	"goyave.dev/goyave/v5/config"
    17  	"goyave.dev/goyave/v5/slog"
    18  	"goyave.dev/goyave/v5/util/errors"
    19  	"goyave.dev/goyave/v5/util/testutil"
    20  
    21  	ws "github.com/gorilla/websocket"
    22  
    23  	stdslog "log/slog"
    24  )
    25  
    26  func prepareTestConfig() goyave.Options {
    27  	cfg := config.LoadDefault()
    28  	cfg.Set("server.port", 0)
    29  	cfg.Set("server.websocketCloseTimeout", 1)
    30  	cfg.Set("app.debug", false)
    31  	return goyave.Options{Config: cfg}
    32  }
    33  
    34  type testController struct {
    35  	goyave.Component
    36  	t  *testing.T
    37  	wg *sync.WaitGroup
    38  
    39  	serve          func(conn *Conn, r *goyave.Request) error
    40  	checkOrigin    func(r *goyave.Request) bool
    41  	upgradeHeaders func(r *goyave.Request) http.Header
    42  }
    43  
    44  func (c *testController) CheckOrigin(r *goyave.Request) bool {
    45  	if c.checkOrigin != nil {
    46  		return c.checkOrigin(r)
    47  	}
    48  	return false
    49  }
    50  
    51  func (c *testController) UpgradeHeaders(r *goyave.Request) http.Header {
    52  	if c.upgradeHeaders != nil {
    53  		return c.upgradeHeaders(r)
    54  	}
    55  	return http.Header{}
    56  }
    57  
    58  func (c *testController) Serve(conn *Conn, r *goyave.Request) error {
    59  	c.wg.Add(1)
    60  	defer c.wg.Done()
    61  	if c.serve != nil {
    62  		return c.serve(conn, r)
    63  	}
    64  	for {
    65  		mt, message, err := conn.ReadMessage()
    66  		if err != nil {
    67  			if IsCloseError(err) {
    68  				return err
    69  			}
    70  			err = fmt.Errorf("read: %w", err)
    71  			assert.Error(c.t, err)
    72  			return err
    73  		}
    74  		err = conn.WriteMessage(mt, message)
    75  		if err != nil {
    76  			err = fmt.Errorf("write: %w", err)
    77  			assert.Error(c.t, err)
    78  			return err
    79  		}
    80  	}
    81  }
    82  
    83  type testControllerWithErrorHandler struct {
    84  	testController
    85  
    86  	onUpgradeError func(response *goyave.Response, request *goyave.Request, status int, reason error)
    87  	onError        func(c *testControllerWithErrorHandler, request *goyave.Request, err error)
    88  }
    89  
    90  func (c *testControllerWithErrorHandler) OnUpgradeError(response *goyave.Response, request *goyave.Request, status int, reason error) {
    91  	if c.onUpgradeError != nil {
    92  		c.onUpgradeError(response, request, status, reason)
    93  		return
    94  	}
    95  }
    96  
    97  func (c *testControllerWithErrorHandler) OnError(request *goyave.Request, err error) {
    98  	if c.onError != nil {
    99  		c.onError(c, request, err)
   100  		return
   101  	}
   102  }
   103  
   104  type testControllerRegistrer struct {
   105  	testController
   106  
   107  	registerRoute func(*goyave.Router, goyave.Handler)
   108  }
   109  
   110  func (c *testControllerRegistrer) RegisterRoute(router *goyave.Router, handler goyave.Handler) {
   111  	if c.registerRoute != nil {
   112  		c.registerRoute(router, handler)
   113  	}
   114  }
   115  
   116  func TestIsCloseError(t *testing.T) {
   117  	cases := []struct {
   118  		err  error
   119  		want bool
   120  	}{
   121  		{err: &ws.CloseError{Code: ws.CloseNormalClosure}, want: true},
   122  		{err: &ws.CloseError{Code: ws.CloseGoingAway}, want: true},
   123  		{err: &ws.CloseError{Code: ws.CloseNoStatusReceived}, want: true},
   124  		{err: fmt.Errorf("wrap: %w", &ws.CloseError{Code: ws.CloseNoStatusReceived}), want: true},
   125  		{err: &ws.CloseError{Code: ws.CloseAbnormalClosure}, want: false},
   126  		{err: &ws.CloseError{Code: ws.CloseProtocolError}, want: false},
   127  		{err: fmt.Errorf("wrap: %w", &ws.CloseError{Code: ws.CloseProtocolError}), want: false},
   128  		{err: errors.New(&ws.CloseError{Code: ws.CloseNormalClosure}), want: true},
   129  		{err: errors.New(&ws.CloseError{Code: ws.CloseProtocolError}), want: false},
   130  	}
   131  
   132  	for _, c := range cases {
   133  		c := c
   134  		t.Run(c.err.Error(), func(t *testing.T) {
   135  			assert.Equal(t, c.want, IsCloseError(c.err))
   136  		})
   137  	}
   138  }
   139  
   140  func TestAdapterOnError(t *testing.T) {
   141  	req := testutil.NewTestRequest(http.MethodGet, "/websocket", nil)
   142  	resp, _ := testutil.NewTestResponse(req)
   143  	reasonErr := fmt.Errorf("test adapter error")
   144  	executed := false
   145  	a := adapter{
   146  		upgradeErrorHandler: func(response *goyave.Response, request *goyave.Request, status int, reason error) {
   147  			assert.Equal(t, req, request)
   148  			assert.Equal(t, resp, response)
   149  			assert.Equal(t, http.StatusBadRequest, status)
   150  			assert.Equal(t, reasonErr, reason)
   151  			executed = true
   152  		},
   153  		request: req,
   154  	}
   155  
   156  	a.onError(resp, req.Request(), http.StatusBadRequest, reasonErr)
   157  	assert.True(t, executed)
   158  	assert.Equal(t, "13", resp.Header().Get("Sec-Websocket-Version"))
   159  
   160  	assert.Panics(t, func() {
   161  		a.onError(resp, req.Request(), http.StatusInternalServerError, reasonErr)
   162  	})
   163  }
   164  
   165  func TestGetCheckOriginFunction(t *testing.T) {
   166  	req := testutil.NewTestRequest(http.MethodGet, "/websocket", nil)
   167  	a := adapter{
   168  		request: req,
   169  	}
   170  	assert.Nil(t, a.getCheckOriginFunc())
   171  
   172  	executed := false
   173  	a.checkOrigin = func(r *goyave.Request) bool {
   174  		assert.Equal(t, req, r)
   175  		executed = true
   176  		return true
   177  	}
   178  
   179  	f := a.getCheckOriginFunc()
   180  	assert.NotNil(t, f)
   181  	assert.True(t, f(req.Request()))
   182  	assert.True(t, executed)
   183  }
   184  
   185  func TestDefaultUpgradeErrorHandler(t *testing.T) {
   186  
   187  	cases := []struct {
   188  		config    func() goyave.Options
   189  		expect    func(*testing.T, map[string]string)
   190  		reasonErr error
   191  		desc      string
   192  	}{
   193  		{
   194  			desc:      "debug_on",
   195  			config:    func() goyave.Options { return goyave.Options{Config: config.LoadDefault()} },
   196  			reasonErr: fmt.Errorf("test upgrade error handler"),
   197  			expect: func(t *testing.T, body map[string]string) {
   198  				assert.Equal(t, map[string]string{"error": "test upgrade error handler"}, body)
   199  			},
   200  		},
   201  		{
   202  			desc:      "debug_off",
   203  			config:    prepareTestConfig,
   204  			reasonErr: fmt.Errorf("test upgrade error handler"),
   205  			expect: func(t *testing.T, body map[string]string) {
   206  				assert.Equal(t, map[string]string{"error": http.StatusText(http.StatusBadRequest)}, body)
   207  			},
   208  		},
   209  	}
   210  
   211  	for _, c := range cases {
   212  		c := c
   213  		t.Run(c.desc, func(t *testing.T) {
   214  			server := testutil.NewTestServerWithOptions(t, c.config())
   215  			req := server.NewTestRequest(http.MethodGet, "/websocket", nil)
   216  			resp, recorder := server.NewTestResponse(req)
   217  
   218  			upgrader := &Upgrader{}
   219  			upgrader.Init(server.Server)
   220  			upgrader.defaultUpgradeErrorHandler(resp, req, http.StatusBadRequest, c.reasonErr)
   221  
   222  			result := recorder.Result()
   223  			assert.Equal(t, "application/json; charset=utf-8", result.Header.Get("Content-Type"))
   224  			assert.Equal(t, http.StatusBadRequest, result.StatusCode)
   225  			body, err := testutil.ReadJSONBody[map[string]string](result.Body)
   226  			assert.NoError(t, result.Body.Close())
   227  			assert.NoError(t, err)
   228  			c.expect(t, body)
   229  		})
   230  	}
   231  }
   232  
   233  func TestMakeUpgrader(t *testing.T) {
   234  	upgrader := Upgrader{}
   235  
   236  	req := testutil.NewTestRequest(http.MethodGet, "/websocket", nil)
   237  	u := upgrader.makeUpgrader(req)
   238  
   239  	assert.Equal(t, upgrader.Settings.HandshakeTimeout, u.HandshakeTimeout)
   240  	assert.Equal(t, upgrader.Settings.ReadBufferSize, u.ReadBufferSize)
   241  	assert.Equal(t, upgrader.Settings.WriteBufferSize, u.WriteBufferSize)
   242  	assert.Equal(t, upgrader.Settings.WriteBufferPool, u.WriteBufferPool)
   243  	assert.Equal(t, upgrader.Settings.Subprotocols, u.Subprotocols)
   244  	assert.Equal(t, upgrader.Settings.EnableCompression, u.EnableCompression)
   245  	assert.NotNil(t, u.Error)
   246  	assert.Nil(t, u.CheckOrigin)
   247  
   248  	upgrader.Settings.EnableCompression = true
   249  	u = upgrader.makeUpgrader(req)
   250  	assert.Equal(t, upgrader.Settings.EnableCompression, u.EnableCompression)
   251  
   252  	upgradeErrorExecuted := false
   253  	checkOriginExecuted := false
   254  	upgrader.Controller = &testControllerWithErrorHandler{
   255  		onUpgradeError: func(_ *goyave.Response, _ *goyave.Request, _ int, _ error) {
   256  			upgradeErrorExecuted = true
   257  		},
   258  		testController: testController{
   259  			checkOrigin: func(_ *goyave.Request) bool {
   260  				checkOriginExecuted = true
   261  				return true
   262  			},
   263  		},
   264  	}
   265  
   266  	u = upgrader.makeUpgrader(req)
   267  	assert.True(t, u.CheckOrigin(req.Request()))
   268  	assert.True(t, checkOriginExecuted)
   269  
   270  	resp, _ := testutil.NewTestResponse(req)
   271  	u.Error(resp, nil, 0, nil)
   272  	assert.True(t, upgradeErrorExecuted)
   273  }
   274  
   275  func TestUpgrade(t *testing.T) {
   276  	// Server shutdown doesn't wait for Hijacked connections to
   277  	// terminate before returning.
   278  	wg := sync.WaitGroup{}
   279  	wg.Add(2)
   280  
   281  	var routeURL string
   282  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   283  	server.RegisterRoutes(func(_ *goyave.Server, r *goyave.Router) {
   284  		upgrader := New(&testController{
   285  			t:  t,
   286  			wg: &wg,
   287  			checkOrigin: func(_ *goyave.Request) bool {
   288  				return true
   289  			},
   290  			upgradeHeaders: func(_ *goyave.Request) http.Header {
   291  				headers := http.Header{}
   292  				headers.Add("X-Test", "Value")
   293  				return headers
   294  			},
   295  		})
   296  		r.Subrouter("/websocket").Controller(upgrader)
   297  	})
   298  
   299  	server.RegisterStartupHook(func(s *goyave.Server) {
   300  		defer func() {
   301  			server.Stop()
   302  			wg.Done()
   303  		}()
   304  		route := s.Router().GetSubrouters()[0].GetRoutes()[0]
   305  		routeURL = "ws" + strings.TrimPrefix(route.BuildURL(), "http")
   306  
   307  		conn, resp, err := ws.DefaultDialer.Dial(routeURL, nil)
   308  		assert.Equal(t, "Value", resp.Header.Get("X-Test"))
   309  		assert.NoError(t, err, fmt.Sprintf("RESPONSE STATUS: %d, RESPONSE HEADERS: %v", resp.StatusCode, resp.Header))
   310  		assert.NoError(t, resp.Body.Close())
   311  		defer func() {
   312  			assert.NoError(t, conn.Close())
   313  		}()
   314  
   315  		message := []byte("hello world")
   316  		assert.NoError(t, conn.WriteMessage(ws.TextMessage, message))
   317  
   318  		messageType, data, err := conn.ReadMessage()
   319  		assert.NoError(t, err)
   320  		assert.Equal(t, ws.TextMessage, messageType)
   321  		assert.Equal(t, message, data)
   322  
   323  		m := ws.FormatCloseMessage(ws.CloseNormalClosure, "Connection closed by client")
   324  		assert.NoError(t, conn.WriteControl(ws.CloseMessage, m, time.Now().Add(time.Second)))
   325  	})
   326  
   327  	go func() {
   328  		assert.NoError(t, server.Start())
   329  		wg.Done()
   330  	}()
   331  	wg.Wait()
   332  }
   333  
   334  func TestUpgradeError(t *testing.T) {
   335  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   336  	server.RegisterRoutes(func(_ *goyave.Server, r *goyave.Router) {
   337  		upgrader := New(&testController{
   338  			t: t,
   339  			checkOrigin: func(_ *goyave.Request) bool {
   340  				return true
   341  			},
   342  		})
   343  		r.Subrouter("/websocket").Controller(upgrader)
   344  	})
   345  
   346  	resp := server.TestRequest(httptest.NewRequest(http.MethodGet, "/websocket", nil))
   347  
   348  	assert.Equal(t, "application/json; charset=utf-8", resp.Header.Get("Content-Type"))
   349  	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   350  
   351  	body, err := testutil.ReadJSONBody[map[string]string](resp.Body)
   352  	assert.NoError(t, resp.Body.Close())
   353  	assert.NoError(t, err)
   354  	assert.Equal(t, map[string]string{"error": http.StatusText(http.StatusBadRequest)}, body)
   355  }
   356  
   357  func TestRegistrer(t *testing.T) {
   358  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   359  	upgrader := New(&testControllerRegistrer{
   360  		registerRoute: func(router *goyave.Router, handler goyave.Handler) {
   361  			router.Get("", handler).SetMeta("key", "value").Name("websocket")
   362  		},
   363  	})
   364  	router := server.Router()
   365  	router.Subrouter("/websocket").Controller(upgrader)
   366  
   367  	route := router.GetRoute("websocket")
   368  	if !assert.NotNil(t, route) {
   369  		return
   370  	}
   371  
   372  	assert.Equal(t, "value", route.Meta["key"])
   373  }
   374  
   375  func TestConnCloseHandshakeTimeout(t *testing.T) {
   376  	c := newConn(&ws.Conn{}, 0)
   377  
   378  	c.SetCloseHandshakeTimeout(time.Second * 2)
   379  	assert.Equal(t, time.Second*2, c.closeTimeout)
   380  	assert.Equal(t, time.Second*2, c.GetCloseHandshakeTimeout())
   381  }
   382  
   383  func TestCloseHandshakeTimeout(t *testing.T) {
   384  	wg := sync.WaitGroup{}
   385  	wg.Add(2)
   386  
   387  	var routeURL string
   388  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   389  	server.RegisterRoutes(func(_ *goyave.Server, r *goyave.Router) {
   390  		upgrader := New(&testController{
   391  			t:  t,
   392  			wg: &wg,
   393  			serve: func(_ *Conn, _ *goyave.Request) error {
   394  				return nil // Immediately return to trigger the close handshake
   395  			},
   396  			checkOrigin: func(_ *goyave.Request) bool {
   397  				return true
   398  			},
   399  		})
   400  		r.Subrouter("/websocket").Controller(upgrader)
   401  	})
   402  
   403  	server.RegisterStartupHook(func(s *goyave.Server) {
   404  		defer func() {
   405  			server.Stop()
   406  			wg.Done()
   407  		}()
   408  		route := s.Router().GetSubrouters()[0].GetRoutes()[0]
   409  		routeURL = "ws" + strings.TrimPrefix(route.BuildURL(), "http")
   410  
   411  		conn, resp, err := ws.DefaultDialer.Dial(routeURL, nil)
   412  		assert.NoError(t, resp.Body.Close())
   413  		assert.NoError(t, err, fmt.Sprintf("RESPONSE STATUS: %d, RESPONSE HEADERS: %v", resp.StatusCode, resp.Header))
   414  		defer func() {
   415  			assert.NoError(t, conn.Close())
   416  		}()
   417  		time.Sleep(1500 * time.Millisecond)
   418  
   419  		messageType, _, err := conn.ReadMessage()
   420  		assert.Error(t, err)
   421  
   422  		// The server has sent the close handshake payload with NormalClosureMessage
   423  		assert.Equal(t, &ws.CloseError{Code: ws.CloseNormalClosure, Text: NormalClosureMessage}, err)
   424  		assert.Equal(t, -1, messageType)
   425  	})
   426  
   427  	go func() {
   428  		assert.NoError(t, server.Start())
   429  		wg.Done()
   430  	}()
   431  	wg.Wait()
   432  }
   433  
   434  func TestCloseHandler(t *testing.T) {
   435  	c := newConn(&ws.Conn{}, 1*time.Second)
   436  
   437  	assert.NoError(t, c.closeHandler(ws.CloseNormalClosure, ""))
   438  	select {
   439  	case <-c.waitClose:
   440  	default:
   441  		assert.Fail(t, "Expected waitClose to not be empty")
   442  	}
   443  }
   444  
   445  func TestGracefulClose(t *testing.T) {
   446  
   447  	cases := []struct {
   448  		expectedError *ws.CloseError
   449  		serve         func(conn *Conn, r *goyave.Request) error
   450  		errorHandler  func(c *testControllerWithErrorHandler, request *goyave.Request, err error)
   451  		expectedLogs  *regexp.Regexp
   452  		desc          string
   453  	}{
   454  		{
   455  			desc: "recovery",
   456  			serve: func(_ *Conn, _ *goyave.Request) error {
   457  				panic("websocket handler panic")
   458  			},
   459  			expectedError: &ws.CloseError{Code: ws.CloseInternalServerErr, Text: http.StatusText(http.StatusInternalServerError)},
   460  			expectedLogs:  regexp.MustCompile(`{"time":"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,9}((\+\d{2}:\d{2})|Z)?","level":"ERROR","msg":"websocket handler panic","trace":".+"}\n`),
   461  		},
   462  		{
   463  			desc: "normal_error",
   464  			serve: func(_ *Conn, _ *goyave.Request) error {
   465  				return errors.New("websocket handler error")
   466  			},
   467  			expectedError: &ws.CloseError{Code: ws.CloseInternalServerErr, Text: http.StatusText(http.StatusInternalServerError)},
   468  			expectedLogs:  regexp.MustCompile(`{"time":"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,9}((\+\d{2}:\d{2})|Z)?","level":"ERROR","msg":"websocket handler error","trace":".+"}\n`),
   469  		},
   470  		{
   471  			desc: "erro_handler",
   472  			serve: func(_ *Conn, _ *goyave.Request) error {
   473  				return errors.New("websocket handler error")
   474  			},
   475  			errorHandler: func(c *testControllerWithErrorHandler, _ *goyave.Request, _ error) {
   476  				c.Logger().Info("message override")
   477  			},
   478  			expectedError: &ws.CloseError{Code: ws.CloseInternalServerErr, Text: http.StatusText(http.StatusInternalServerError)},
   479  			expectedLogs:  regexp.MustCompile(`{"time":"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,9}((\+\d{2}:\d{2})|Z)?","level":"INFO","msg":"message override"}\n`),
   480  		},
   481  		{
   482  			desc: "normal_serve_closure",
   483  			serve: func(_ *Conn, _ *goyave.Request) error {
   484  				return nil
   485  			},
   486  			expectedError: &ws.CloseError{Code: ws.CloseNormalClosure, Text: NormalClosureMessage},
   487  		},
   488  	}
   489  
   490  	for _, c := range cases {
   491  		c := c
   492  		t.Run(c.desc, func(t *testing.T) {
   493  			t.Parallel()
   494  			wg := sync.WaitGroup{}
   495  			wg.Add(2)
   496  
   497  			var routeURL string
   498  			server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   499  			server.RegisterRoutes(func(_ *goyave.Server, r *goyave.Router) {
   500  				var ctrl Controller = &testController{
   501  					t:     t,
   502  					wg:    &wg,
   503  					serve: c.serve,
   504  
   505  					checkOrigin: func(_ *goyave.Request) bool {
   506  						return true
   507  					},
   508  				}
   509  				if c.errorHandler != nil {
   510  					ctrl = &testControllerWithErrorHandler{
   511  						testController: *ctrl.(*testController),
   512  						onError:        c.errorHandler,
   513  					}
   514  				}
   515  				upgrader := New(ctrl)
   516  				r.Subrouter("/websocket").Controller(upgrader)
   517  			})
   518  			buf := &bytes.Buffer{}
   519  			server.Logger = slog.New(stdslog.NewJSONHandler(buf, &stdslog.HandlerOptions{Level: stdslog.LevelInfo}))
   520  
   521  			server.RegisterStartupHook(func(s *goyave.Server) {
   522  				defer func() {
   523  					server.Stop()
   524  					wg.Done()
   525  				}()
   526  				route := s.Router().GetSubrouters()[0].GetRoutes()[0]
   527  				routeURL = "ws" + strings.TrimPrefix(route.BuildURL(), "http")
   528  
   529  				testGracefulClose(t, routeURL, c.expectedError)
   530  			})
   531  
   532  			go func() {
   533  				assert.NoError(t, server.Start())
   534  				wg.Done()
   535  			}()
   536  			wg.Wait()
   537  			if c.expectedLogs != nil {
   538  				assert.Regexp(t, c.expectedLogs, buf.String())
   539  			}
   540  		})
   541  	}
   542  
   543  }
   544  
   545  func testGracefulClose(t *testing.T, routeURL string, expectedError *ws.CloseError) {
   546  	conn, resp, err := ws.DefaultDialer.Dial(routeURL, nil)
   547  	assert.NoError(t, resp.Body.Close())
   548  	assert.NoError(t, err, fmt.Sprintf("RESPONSE STATUS: %d, RESPONSE HEADERS: %v", resp.StatusCode, resp.Header))
   549  	defer func() {
   550  		assert.NoError(t, conn.Close())
   551  	}()
   552  
   553  	messageType, _, err := conn.ReadMessage()
   554  	assert.Equal(t, expectedError, err)
   555  
   556  	// advanceFrame returns noFrame (-1) when a close frame is received
   557  	assert.Equal(t, -1, messageType)
   558  }
   559  
   560  func TestCloseConnectionClosed(t *testing.T) {
   561  	wg := sync.WaitGroup{}
   562  	wg.Add(3)
   563  
   564  	var routeURL string
   565  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   566  	server.RegisterRoutes(func(s *goyave.Server, r *goyave.Router) {
   567  		upgrader := New(&testController{
   568  			t: t,
   569  			checkOrigin: func(_ *goyave.Request) bool {
   570  				return true
   571  			},
   572  		})
   573  		upgrader.Init(s)
   574  		upgrader.Controller.Init(s)
   575  		r.Subrouter("/websocket").Get("", func(response *goyave.Response, request *goyave.Request) {
   576  			defer wg.Done()
   577  			c, err := upgrader.makeUpgrader(request).Upgrade(response, request.Request(), nil)
   578  			assert.NoError(t, err)
   579  			response.Status(http.StatusSwitchingProtocols)
   580  
   581  			conn := newConn(c, time.Second)
   582  
   583  			assert.NoError(t, conn.Conn.Close()) // Connection closed right away, server wont be able to write anymore
   584  			err = conn.CloseNormal()
   585  			assert.Error(t, err)
   586  			assert.Contains(t, err.Error(), "use of closed network connection")
   587  		})
   588  	})
   589  
   590  	server.RegisterStartupHook(func(s *goyave.Server) {
   591  		defer func() {
   592  			server.Stop()
   593  			wg.Done()
   594  		}()
   595  		route := s.Router().GetSubrouters()[0].GetRoutes()[0]
   596  		routeURL = "ws" + strings.TrimPrefix(route.BuildURL(), "http")
   597  
   598  		conn, resp, err := ws.DefaultDialer.Dial(routeURL, nil)
   599  		assert.NoError(t, resp.Body.Close())
   600  		assert.NoError(t, err, fmt.Sprintf("RESPONSE STATUS: %d, RESPONSE HEADERS: %v", resp.StatusCode, resp.Header))
   601  		defer func() {
   602  			assert.NoError(t, conn.Close())
   603  		}()
   604  
   605  		_, _, err = conn.ReadMessage()
   606  		assert.Error(t, err)
   607  		assert.Equal(t, &ws.CloseError{Code: ws.CloseAbnormalClosure, Text: "unexpected EOF"}, err)
   608  	})
   609  
   610  	go func() {
   611  		assert.NoError(t, server.Start())
   612  		wg.Done()
   613  	}()
   614  	wg.Wait()
   615  }
   616  
   617  func TestCloseWriteTimeout(t *testing.T) {
   618  	wg := sync.WaitGroup{}
   619  	wg.Add(3)
   620  
   621  	var routeURL string
   622  	server := testutil.NewTestServerWithOptions(t, prepareTestConfig())
   623  	server.RegisterRoutes(func(s *goyave.Server, r *goyave.Router) {
   624  		upgrader := New(&testController{
   625  			t: t,
   626  			checkOrigin: func(_ *goyave.Request) bool {
   627  				return true
   628  			},
   629  		})
   630  		upgrader.Init(s)
   631  		upgrader.Controller.Init(s)
   632  		r.Subrouter("/websocket").Get("", func(response *goyave.Response, request *goyave.Request) {
   633  			defer wg.Done()
   634  			c, err := upgrader.makeUpgrader(request).Upgrade(response, request.Request(), nil)
   635  			assert.NoError(t, err)
   636  			response.Status(http.StatusSwitchingProtocols)
   637  
   638  			conn := newConn(c, time.Second)
   639  			conn.closeTimeout = -1 * time.Second
   640  
   641  			// No error expected, the connection should close as normal without waiting
   642  			assert.NoError(t, conn.CloseNormal())
   643  		})
   644  	})
   645  
   646  	server.RegisterStartupHook(func(s *goyave.Server) {
   647  		defer func() {
   648  			server.Stop()
   649  			wg.Done()
   650  		}()
   651  		route := s.Router().GetSubrouters()[0].GetRoutes()[0]
   652  		routeURL = "ws" + strings.TrimPrefix(route.BuildURL(), "http")
   653  
   654  		conn, resp, err := ws.DefaultDialer.Dial(routeURL, nil)
   655  		assert.NoError(t, resp.Body.Close())
   656  		assert.NoError(t, err, fmt.Sprintf("RESPONSE STATUS: %d, RESPONSE HEADERS: %v", resp.StatusCode, resp.Header))
   657  		assert.NoError(t, conn.Close())
   658  	})
   659  
   660  	go func() {
   661  		assert.NoError(t, server.Start())
   662  		wg.Done()
   663  	}()
   664  	wg.Wait()
   665  }