github.com/shippio/gqlgen@v0.0.0-20220912092219-633ea699ef07/graphql/handler/transport/websocket_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/99designs/gqlgen/client"
    14  	"github.com/99designs/gqlgen/graphql"
    15  	"github.com/99designs/gqlgen/graphql/handler"
    16  	"github.com/99designs/gqlgen/graphql/handler/testserver"
    17  	"github.com/99designs/gqlgen/graphql/handler/transport"
    18  	"github.com/gorilla/websocket"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/vektah/gqlparser/v2"
    22  	"github.com/vektah/gqlparser/v2/ast"
    23  )
    24  
    25  type ckey string
    26  
    27  func TestWebsocket(t *testing.T) {
    28  	handler := testserver.New()
    29  	handler.AddTransport(transport.Websocket{})
    30  
    31  	srv := httptest.NewServer(handler)
    32  	defer srv.Close()
    33  
    34  	t.Run("client must send valid json", func(t *testing.T) {
    35  		c := wsConnect(srv.URL)
    36  		defer c.Close()
    37  
    38  		writeRaw(c, "hello")
    39  
    40  		msg := readOp(c)
    41  		assert.Equal(t, "connection_error", msg.Type)
    42  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    43  	})
    44  
    45  	t.Run("client can terminate before init", func(t *testing.T) {
    46  		c := wsConnect(srv.URL)
    47  		defer c.Close()
    48  
    49  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    50  
    51  		_, _, err := c.ReadMessage()
    52  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    53  	})
    54  
    55  	t.Run("client must send init first", func(t *testing.T) {
    56  		c := wsConnect(srv.URL)
    57  		defer c.Close()
    58  
    59  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    60  
    61  		msg := readOp(c)
    62  		assert.Equal(t, connectionErrorMsg, msg.Type)
    63  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    64  	})
    65  
    66  	t.Run("server acks init", func(t *testing.T) {
    67  		c := wsConnect(srv.URL)
    68  		defer c.Close()
    69  
    70  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    71  
    72  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    73  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    74  	})
    75  
    76  	t.Run("client can terminate before run", func(t *testing.T) {
    77  		c := wsConnect(srv.URL)
    78  		defer c.Close()
    79  
    80  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    81  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    82  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    83  
    84  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    85  
    86  		_, _, err := c.ReadMessage()
    87  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    88  	})
    89  
    90  	t.Run("client gets parse errors", func(t *testing.T) {
    91  		c := wsConnect(srv.URL)
    92  		defer c.Close()
    93  
    94  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    95  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    96  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    97  
    98  		require.NoError(t, c.WriteJSON(&operationMessage{
    99  			Type:    startMsg,
   100  			ID:      "test_1",
   101  			Payload: json.RawMessage(`{"query": "!"}`),
   102  		}))
   103  
   104  		msg := readOp(c)
   105  		assert.Equal(t, errorMsg, msg.Type)
   106  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   107  	})
   108  
   109  	t.Run("client can receive data", func(t *testing.T) {
   110  		c := wsConnect(srv.URL)
   111  		defer c.Close()
   112  
   113  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   114  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   115  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   116  
   117  		require.NoError(t, c.WriteJSON(&operationMessage{
   118  			Type:    startMsg,
   119  			ID:      "test_1",
   120  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   121  		}))
   122  
   123  		handler.SendNextSubscriptionMessage()
   124  		msg := readOp(c)
   125  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   126  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   127  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   128  
   129  		handler.SendNextSubscriptionMessage()
   130  		msg = readOp(c)
   131  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   132  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   133  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   134  
   135  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   136  
   137  		msg = readOp(c)
   138  		require.Equal(t, completeMsg, msg.Type)
   139  		require.Equal(t, "test_1", msg.ID)
   140  
   141  		// At this point we should be done and should not receive another message.
   142  		c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond))
   143  
   144  		err := c.ReadJSON(&msg)
   145  		if err == nil {
   146  			// This should not send a second close message for the same id.
   147  			require.NotEqual(t, completeMsg, msg.Type)
   148  			require.NotEqual(t, "test_1", msg.ID)
   149  		} else {
   150  			assert.Contains(t, err.Error(), "timeout")
   151  		}
   152  
   153  	})
   154  }
   155  
   156  func TestWebsocketWithKeepAlive(t *testing.T) {
   157  	h := testserver.New()
   158  	h.AddTransport(transport.Websocket{
   159  		KeepAlivePingInterval: 100 * time.Millisecond,
   160  	})
   161  
   162  	srv := httptest.NewServer(h)
   163  	defer srv.Close()
   164  
   165  	c := wsConnect(srv.URL)
   166  	defer c.Close()
   167  
   168  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   169  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   170  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   171  
   172  	require.NoError(t, c.WriteJSON(&operationMessage{
   173  		Type:    startMsg,
   174  		ID:      "test_1",
   175  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   176  	}))
   177  
   178  	// keepalive
   179  	msg := readOp(c)
   180  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   181  
   182  	// server message
   183  	h.SendNextSubscriptionMessage()
   184  	msg = readOp(c)
   185  	assert.Equal(t, dataMsg, msg.Type)
   186  
   187  	// keepalive
   188  	msg = readOp(c)
   189  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   190  }
   191  
   192  func TestWebsocketInitFunc(t *testing.T) {
   193  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   194  		h := testserver.New()
   195  		h.AddTransport(transport.Websocket{})
   196  		srv := httptest.NewServer(h)
   197  		defer srv.Close()
   198  
   199  		c := wsConnect(srv.URL)
   200  		defer c.Close()
   201  
   202  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   203  
   204  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   205  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   206  	})
   207  
   208  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   209  		h := testserver.New()
   210  		h.AddTransport(transport.Websocket{
   211  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   212  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   213  			},
   214  		})
   215  		srv := httptest.NewServer(h)
   216  		defer srv.Close()
   217  
   218  		c := wsConnect(srv.URL)
   219  		defer c.Close()
   220  
   221  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   222  
   223  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   224  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   225  	})
   226  
   227  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   228  		h := testserver.New()
   229  		h.AddTransport(transport.Websocket{
   230  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   231  				return ctx, errors.New("invalid init payload")
   232  			},
   233  		})
   234  		srv := httptest.NewServer(h)
   235  		defer srv.Close()
   236  
   237  		c := wsConnect(srv.URL)
   238  		defer c.Close()
   239  
   240  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   241  
   242  		msg := readOp(c)
   243  		assert.Equal(t, connectionErrorMsg, msg.Type)
   244  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   245  	})
   246  
   247  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   248  		es := &graphql.ExecutableSchemaMock{
   249  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   250  				assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
   251  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   252  			},
   253  			SchemaFunc: func() *ast.Schema {
   254  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   255  				schema { query: Query }
   256  				type Query {
   257  					empty: String
   258  				}
   259  			`})
   260  			},
   261  		}
   262  		h := handler.New(es)
   263  
   264  		h.AddTransport(transport.Websocket{
   265  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   266  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   267  			},
   268  		})
   269  
   270  		c := client.New(h)
   271  
   272  		socket := c.Websocket("{ empty } ")
   273  		defer socket.Close()
   274  		var resp struct {
   275  			Empty string
   276  		}
   277  		err := socket.Next(&resp)
   278  		require.NoError(t, err)
   279  		assert.Equal(t, "ok", resp.Empty)
   280  	})
   281  
   282  	t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) {
   283  		h := testserver.New()
   284  		var cancel func()
   285  		h.AddTransport(transport.Websocket{
   286  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
   287  				newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
   288  				return
   289  			},
   290  		})
   291  		srv := httptest.NewServer(h)
   292  		defer srv.Close()
   293  
   294  		c := wsConnect(srv.URL)
   295  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   296  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   297  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   298  
   299  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   300  		defer cancel()
   301  
   302  		time.Sleep(time.Millisecond * 10)
   303  		m := readOp(c)
   304  		assert.Equal(t, m.Type, connectionErrorMsg)
   305  		assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
   306  	})
   307  }
   308  
   309  func TestWebSocketInitTimeout(t *testing.T) {
   310  	t.Run("times out if no init message is received within the configured duration", func(t *testing.T) {
   311  		h := testserver.New()
   312  		h.AddTransport(transport.Websocket{
   313  			InitTimeout: 5 * time.Millisecond,
   314  		})
   315  		srv := httptest.NewServer(h)
   316  		defer srv.Close()
   317  
   318  		c := wsConnect(srv.URL)
   319  		defer c.Close()
   320  
   321  		var msg operationMessage
   322  		err := c.ReadJSON(&msg)
   323  		assert.Error(t, err)
   324  		assert.Contains(t, err.Error(), "timeout")
   325  	})
   326  
   327  	t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) {
   328  		h := testserver.New()
   329  		h.AddTransport(transport.Websocket{})
   330  		srv := httptest.NewServer(h)
   331  		defer srv.Close()
   332  
   333  		c := wsConnect(srv.URL)
   334  		defer c.Close()
   335  
   336  		done := make(chan interface{}, 1)
   337  		go func() {
   338  			var msg operationMessage
   339  			_ = c.ReadJSON(&msg)
   340  			done <- 1
   341  		}()
   342  
   343  		select {
   344  		case <-done:
   345  			assert.Fail(t, "web socket read operation finished while it shouldn't have")
   346  		case <-time.After(100 * time.Millisecond):
   347  			// Success! I guess? Can't really wait forever to see if the read waits forever...
   348  		}
   349  	})
   350  }
   351  
   352  func TestWebSocketErrorFunc(t *testing.T) {
   353  	t.Run("the error handler gets called when an error occurs", func(t *testing.T) {
   354  		errFuncCalled := make(chan bool, 1)
   355  		h := testserver.New()
   356  		h.AddTransport(transport.Websocket{
   357  			ErrorFunc: func(_ context.Context, err error) {
   358  				require.Error(t, err)
   359  				assert.Equal(t, err.Error(), "websocket read: invalid message received")
   360  				assert.IsType(t, transport.WebsocketError{}, err)
   361  				assert.True(t, err.(transport.WebsocketError).IsReadError)
   362  				errFuncCalled <- true
   363  			},
   364  		})
   365  
   366  		srv := httptest.NewServer(h)
   367  		defer srv.Close()
   368  
   369  		c := wsConnect(srv.URL)
   370  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   371  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   372  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   373  		require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this")))
   374  
   375  		select {
   376  		case res := <-errFuncCalled:
   377  			assert.True(t, res)
   378  		case <-time.NewTimer(time.Millisecond * 20).C:
   379  			assert.Fail(t, "The fail handler was not called in time")
   380  		}
   381  	})
   382  
   383  	t.Run("init func errors do not call the error handler", func(t *testing.T) {
   384  		h := testserver.New()
   385  		h.AddTransport(transport.Websocket{
   386  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
   387  				return ctx, errors.New("this is not what we agreed upon")
   388  			},
   389  			ErrorFunc: func(_ context.Context, err error) {
   390  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   391  			},
   392  		})
   393  		srv := httptest.NewServer(h)
   394  		defer srv.Close()
   395  
   396  		c := wsConnect(srv.URL)
   397  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   398  		time.Sleep(time.Millisecond * 20)
   399  	})
   400  
   401  	t.Run("init func context closes do not call the error handler", func(t *testing.T) {
   402  		h := testserver.New()
   403  		h.AddTransport(transport.Websocket{
   404  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
   405  				newCtx, cancel := context.WithCancel(ctx)
   406  				time.AfterFunc(time.Millisecond*5, cancel)
   407  				return newCtx, nil
   408  			},
   409  			ErrorFunc: func(_ context.Context, err error) {
   410  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   411  			},
   412  		})
   413  		srv := httptest.NewServer(h)
   414  		defer srv.Close()
   415  
   416  		c := wsConnect(srv.URL)
   417  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   418  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   419  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   420  		time.Sleep(time.Millisecond * 20)
   421  	})
   422  
   423  	t.Run("init func context deadlines do not call the error handler", func(t *testing.T) {
   424  		h := testserver.New()
   425  		var cancel func()
   426  		h.AddTransport(transport.Websocket{
   427  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
   428  				newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
   429  				return newCtx, nil
   430  			},
   431  			ErrorFunc: func(_ context.Context, err error) {
   432  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   433  			},
   434  		})
   435  		srv := httptest.NewServer(h)
   436  		defer srv.Close()
   437  
   438  		c := wsConnect(srv.URL)
   439  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   440  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   441  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   442  
   443  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   444  		defer cancel()
   445  
   446  		time.Sleep(time.Millisecond * 20)
   447  	})
   448  }
   449  
   450  func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
   451  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   452  		h := testserver.New()
   453  		h.AddTransport(ws)
   454  		return h, httptest.NewServer(h)
   455  	}
   456  
   457  	t.Run("server acks init", func(t *testing.T) {
   458  		_, srv := initialize(transport.Websocket{})
   459  		defer srv.Close()
   460  
   461  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   462  		defer c.Close()
   463  
   464  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   465  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   466  	})
   467  
   468  	t.Run("client can receive data", func(t *testing.T) {
   469  		handler, srv := initialize(transport.Websocket{})
   470  		defer srv.Close()
   471  
   472  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   473  		defer c.Close()
   474  
   475  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   476  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   477  
   478  		require.NoError(t, c.WriteJSON(&operationMessage{
   479  			Type:    graphqltransportwsSubscribeMsg,
   480  			ID:      "test_1",
   481  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   482  		}))
   483  
   484  		handler.SendNextSubscriptionMessage()
   485  		msg := readOp(c)
   486  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   487  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   488  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   489  
   490  		handler.SendNextSubscriptionMessage()
   491  		msg = readOp(c)
   492  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   493  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   494  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   495  
   496  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"}))
   497  
   498  		msg = readOp(c)
   499  		require.Equal(t, graphqltransportwsCompleteMsg, msg.Type)
   500  		require.Equal(t, "test_1", msg.ID)
   501  	})
   502  
   503  	t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) {
   504  		_, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond})
   505  		defer srv.Close()
   506  
   507  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   508  		defer c.Close()
   509  
   510  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   511  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   512  
   513  		// If the keep-alives are sent, this deadline will not be used, and no timeout error will be found
   514  		c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond))
   515  		var msg operationMessage
   516  		err := c.ReadJSON(&msg)
   517  		require.Error(t, err)
   518  		assert.Contains(t, err.Error(), "timeout")
   519  	})
   520  }
   521  
   522  func TestWebsocketWithPingPongInterval(t *testing.T) {
   523  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   524  		h := testserver.New()
   525  		h.AddTransport(ws)
   526  		return h, httptest.NewServer(h)
   527  	}
   528  
   529  	t.Run("client receives ping and responds with pong", func(t *testing.T) {
   530  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   531  		defer srv.Close()
   532  
   533  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   534  		defer c.Close()
   535  
   536  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   537  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   538  
   539  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   540  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
   541  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   542  	})
   543  
   544  	t.Run("client sends ping and expects pong", func(t *testing.T) {
   545  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   546  		defer srv.Close()
   547  
   548  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   549  		defer c.Close()
   550  
   551  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   552  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   553  
   554  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
   555  		assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
   556  	})
   557  
   558  	t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) {
   559  		// Regression test
   560  		// ---
   561  		// Before the refactor, the code would try to convert a ping message to a graphql-ws message type
   562  		// But since this message type does not exist in the graphql-ws sub protocol, it would fail
   563  
   564  		_, srv := initialize(transport.Websocket{
   565  			PingPongInterval:      5 * time.Millisecond,
   566  			KeepAlivePingInterval: 10 * time.Millisecond,
   567  		})
   568  		defer srv.Close()
   569  
   570  		// Create connection
   571  		c := wsConnect(srv.URL)
   572  		defer c.Close()
   573  
   574  		// Initialize connection
   575  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   576  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   577  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   578  
   579  		// Wait for a few more keep alives to be sure nothing goes wrong
   580  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   581  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   582  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   583  	})
   584  }
   585  
   586  func wsConnect(url string) *websocket.Conn {
   587  	return wsConnectWithSubprocotol(url, "")
   588  }
   589  
   590  func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn {
   591  	h := make(http.Header)
   592  	if subprocotol != "" {
   593  		h.Add("Sec-WebSocket-Protocol", subprocotol)
   594  	}
   595  
   596  	c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h)
   597  	if err != nil {
   598  		panic(err)
   599  	}
   600  	_ = resp.Body.Close()
   601  
   602  	return c
   603  }
   604  
   605  func writeRaw(conn *websocket.Conn, msg string) {
   606  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   607  		panic(err)
   608  	}
   609  }
   610  
   611  func readOp(conn *websocket.Conn) operationMessage {
   612  	var msg operationMessage
   613  	if err := conn.ReadJSON(&msg); err != nil {
   614  		panic(err)
   615  	}
   616  	return msg
   617  }
   618  
   619  // copied out from websocket_graphqlws.go to keep these private
   620  
   621  const (
   622  	connectionInitMsg      = "connection_init"      // Client -> Server
   623  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   624  	startMsg               = "start"                // Client -> Server
   625  	stopMsg                = "stop"                 // Client -> Server
   626  	connectionAckMsg       = "connection_ack"       // Server -> Client
   627  	connectionErrorMsg     = "connection_error"     // Server -> Client
   628  	dataMsg                = "data"                 // Server -> Client
   629  	errorMsg               = "error"                // Server -> Client
   630  	completeMsg            = "complete"             // Server -> Client
   631  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   632  )
   633  
   634  // copied out from websocket_graphql_transport_ws.go to keep these private
   635  
   636  const (
   637  	graphqltransportwsSubprotocol = "graphql-transport-ws"
   638  
   639  	graphqltransportwsConnectionInitMsg = "connection_init"
   640  	graphqltransportwsConnectionAckMsg  = "connection_ack"
   641  	graphqltransportwsSubscribeMsg      = "subscribe"
   642  	graphqltransportwsNextMsg           = "next"
   643  	graphqltransportwsCompleteMsg       = "complete"
   644  	graphqltransportwsPingMsg           = "ping"
   645  	graphqltransportwsPongMsg           = "pong"
   646  )
   647  
   648  type operationMessage struct {
   649  	Payload json.RawMessage `json:"payload,omitempty"`
   650  	ID      string          `json:"id,omitempty"`
   651  	Type    string          `json:"type"`
   652  }