github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/handler/transport/websocket_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  	"github.com/mstephano/gqlgen-schemagen/client"
    15  	"github.com/mstephano/gqlgen-schemagen/graphql"
    16  	"github.com/mstephano/gqlgen-schemagen/graphql/handler"
    17  	"github.com/mstephano/gqlgen-schemagen/graphql/handler/testserver"
    18  	"github.com/mstephano/gqlgen-schemagen/graphql/handler/transport"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/vektah/gqlparser/v2"
    22  	"github.com/vektah/gqlparser/v2/ast"
    23  )
    24  
    25  type ckey string
    26  
    27  func TestWebsocket(t *testing.T) {
    28  	handler := testserver.New()
    29  	handler.AddTransport(transport.Websocket{})
    30  
    31  	srv := httptest.NewServer(handler)
    32  	defer srv.Close()
    33  
    34  	t.Run("client must send valid json", func(t *testing.T) {
    35  		c := wsConnect(srv.URL)
    36  		defer c.Close()
    37  
    38  		writeRaw(c, "hello")
    39  
    40  		msg := readOp(c)
    41  		assert.Equal(t, "connection_error", msg.Type)
    42  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    43  	})
    44  
    45  	t.Run("client can terminate before init", func(t *testing.T) {
    46  		c := wsConnect(srv.URL)
    47  		defer c.Close()
    48  
    49  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    50  
    51  		_, _, err := c.ReadMessage()
    52  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    53  	})
    54  
    55  	t.Run("client must send init first", func(t *testing.T) {
    56  		c := wsConnect(srv.URL)
    57  		defer c.Close()
    58  
    59  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    60  
    61  		msg := readOp(c)
    62  		assert.Equal(t, connectionErrorMsg, msg.Type)
    63  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    64  	})
    65  
    66  	t.Run("server acks init", func(t *testing.T) {
    67  		c := wsConnect(srv.URL)
    68  		defer c.Close()
    69  
    70  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    71  
    72  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    73  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    74  	})
    75  
    76  	t.Run("client can terminate before run", func(t *testing.T) {
    77  		c := wsConnect(srv.URL)
    78  		defer c.Close()
    79  
    80  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    81  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    82  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    83  
    84  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    85  
    86  		_, _, err := c.ReadMessage()
    87  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    88  	})
    89  
    90  	t.Run("client gets parse errors", func(t *testing.T) {
    91  		c := wsConnect(srv.URL)
    92  		defer c.Close()
    93  
    94  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    95  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    96  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    97  
    98  		require.NoError(t, c.WriteJSON(&operationMessage{
    99  			Type:    startMsg,
   100  			ID:      "test_1",
   101  			Payload: json.RawMessage(`{"query": "!"}`),
   102  		}))
   103  
   104  		msg := readOp(c)
   105  		assert.Equal(t, errorMsg, msg.Type)
   106  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   107  	})
   108  
   109  	t.Run("client can receive data", func(t *testing.T) {
   110  		c := wsConnect(srv.URL)
   111  		defer c.Close()
   112  
   113  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   114  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   115  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   116  
   117  		require.NoError(t, c.WriteJSON(&operationMessage{
   118  			Type:    startMsg,
   119  			ID:      "test_1",
   120  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   121  		}))
   122  
   123  		handler.SendNextSubscriptionMessage()
   124  		msg := readOp(c)
   125  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   126  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   127  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   128  
   129  		handler.SendNextSubscriptionMessage()
   130  		msg = readOp(c)
   131  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   132  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   133  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   134  
   135  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   136  
   137  		msg = readOp(c)
   138  		require.Equal(t, completeMsg, msg.Type)
   139  		require.Equal(t, "test_1", msg.ID)
   140  
   141  		// At this point we should be done and should not receive another message.
   142  		c.SetReadDeadline(time.Now().UTC().Add(1 * time.Millisecond))
   143  
   144  		err := c.ReadJSON(&msg)
   145  		if err == nil {
   146  			// This should not send a second close message for the same id.
   147  			require.NotEqual(t, completeMsg, msg.Type)
   148  			require.NotEqual(t, "test_1", msg.ID)
   149  		} else {
   150  			assert.Contains(t, err.Error(), "timeout")
   151  		}
   152  	})
   153  }
   154  
   155  func TestWebsocketWithKeepAlive(t *testing.T) {
   156  	h := testserver.New()
   157  	h.AddTransport(transport.Websocket{
   158  		KeepAlivePingInterval: 100 * time.Millisecond,
   159  	})
   160  
   161  	srv := httptest.NewServer(h)
   162  	defer srv.Close()
   163  
   164  	c := wsConnect(srv.URL)
   165  	defer c.Close()
   166  
   167  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   168  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   169  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   170  
   171  	require.NoError(t, c.WriteJSON(&operationMessage{
   172  		Type:    startMsg,
   173  		ID:      "test_1",
   174  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   175  	}))
   176  
   177  	// keepalive
   178  	msg := readOp(c)
   179  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   180  
   181  	// server message
   182  	h.SendNextSubscriptionMessage()
   183  	msg = readOp(c)
   184  	assert.Equal(t, dataMsg, msg.Type)
   185  
   186  	// keepalive
   187  	msg = readOp(c)
   188  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   189  }
   190  
   191  func TestWebsocketInitFunc(t *testing.T) {
   192  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   193  		h := testserver.New()
   194  		h.AddTransport(transport.Websocket{})
   195  		srv := httptest.NewServer(h)
   196  		defer srv.Close()
   197  
   198  		c := wsConnect(srv.URL)
   199  		defer c.Close()
   200  
   201  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   202  
   203  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   204  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   205  	})
   206  
   207  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   208  		h := testserver.New()
   209  		h.AddTransport(transport.Websocket{
   210  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   211  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   212  			},
   213  		})
   214  		srv := httptest.NewServer(h)
   215  		defer srv.Close()
   216  
   217  		c := wsConnect(srv.URL)
   218  		defer c.Close()
   219  
   220  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   221  
   222  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   223  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   224  	})
   225  
   226  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   227  		h := testserver.New()
   228  		h.AddTransport(transport.Websocket{
   229  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   230  				return ctx, errors.New("invalid init payload")
   231  			},
   232  		})
   233  		srv := httptest.NewServer(h)
   234  		defer srv.Close()
   235  
   236  		c := wsConnect(srv.URL)
   237  		defer c.Close()
   238  
   239  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   240  
   241  		msg := readOp(c)
   242  		assert.Equal(t, connectionErrorMsg, msg.Type)
   243  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   244  	})
   245  
   246  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   247  		es := &graphql.ExecutableSchemaMock{
   248  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   249  				assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
   250  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   251  			},
   252  			SchemaFunc: func() *ast.Schema {
   253  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   254  				schema { query: Query }
   255  				type Query {
   256  					empty: String
   257  				}
   258  			`})
   259  			},
   260  		}
   261  		h := handler.New(es)
   262  
   263  		h.AddTransport(transport.Websocket{
   264  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   265  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   266  			},
   267  		})
   268  
   269  		c := client.New(h)
   270  
   271  		socket := c.Websocket("{ empty } ")
   272  		defer socket.Close()
   273  		var resp struct {
   274  			Empty string
   275  		}
   276  		err := socket.Next(&resp)
   277  		require.NoError(t, err)
   278  		assert.Equal(t, "ok", resp.Empty)
   279  	})
   280  
   281  	t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) {
   282  		h := testserver.New()
   283  		var cancel func()
   284  		h.AddTransport(transport.Websocket{
   285  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
   286  				newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
   287  				return
   288  			},
   289  		})
   290  		srv := httptest.NewServer(h)
   291  		defer srv.Close()
   292  
   293  		c := wsConnect(srv.URL)
   294  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   295  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   296  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   297  
   298  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   299  		defer cancel()
   300  
   301  		time.Sleep(time.Millisecond * 10)
   302  		m := readOp(c)
   303  		assert.Equal(t, m.Type, connectionErrorMsg)
   304  		assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
   305  	})
   306  }
   307  
   308  func TestWebSocketInitTimeout(t *testing.T) {
   309  	t.Run("times out if no init message is received within the configured duration", func(t *testing.T) {
   310  		h := testserver.New()
   311  		h.AddTransport(transport.Websocket{
   312  			InitTimeout: 5 * time.Millisecond,
   313  		})
   314  		srv := httptest.NewServer(h)
   315  		defer srv.Close()
   316  
   317  		c := wsConnect(srv.URL)
   318  		defer c.Close()
   319  
   320  		var msg operationMessage
   321  		err := c.ReadJSON(&msg)
   322  		assert.Error(t, err)
   323  		assert.Contains(t, err.Error(), "timeout")
   324  	})
   325  
   326  	t.Run("keeps waiting for an init message if no time out is configured", func(t *testing.T) {
   327  		h := testserver.New()
   328  		h.AddTransport(transport.Websocket{})
   329  		srv := httptest.NewServer(h)
   330  		defer srv.Close()
   331  
   332  		c := wsConnect(srv.URL)
   333  		defer c.Close()
   334  
   335  		done := make(chan interface{}, 1)
   336  		go func() {
   337  			var msg operationMessage
   338  			_ = c.ReadJSON(&msg)
   339  			done <- 1
   340  		}()
   341  
   342  		select {
   343  		case <-done:
   344  			assert.Fail(t, "web socket read operation finished while it shouldn't have")
   345  		case <-time.After(100 * time.Millisecond):
   346  			// Success! I guess? Can't really wait forever to see if the read waits forever...
   347  		}
   348  	})
   349  }
   350  
   351  func TestWebSocketErrorFunc(t *testing.T) {
   352  	t.Run("the error handler gets called when an error occurs", func(t *testing.T) {
   353  		errFuncCalled := make(chan bool, 1)
   354  		h := testserver.New()
   355  		h.AddTransport(transport.Websocket{
   356  			ErrorFunc: func(_ context.Context, err error) {
   357  				require.Error(t, err)
   358  				assert.Equal(t, err.Error(), "websocket read: invalid message received")
   359  				assert.IsType(t, transport.WebsocketError{}, err)
   360  				assert.True(t, err.(transport.WebsocketError).IsReadError)
   361  				errFuncCalled <- true
   362  			},
   363  		})
   364  
   365  		srv := httptest.NewServer(h)
   366  		defer srv.Close()
   367  
   368  		c := wsConnect(srv.URL)
   369  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   370  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   371  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   372  		require.NoError(t, c.WriteMessage(websocket.TextMessage, []byte("mark my words, you will regret this")))
   373  
   374  		select {
   375  		case res := <-errFuncCalled:
   376  			assert.True(t, res)
   377  		case <-time.NewTimer(time.Millisecond * 20).C:
   378  			assert.Fail(t, "The fail handler was not called in time")
   379  		}
   380  	})
   381  
   382  	t.Run("init func errors do not call the error handler", func(t *testing.T) {
   383  		h := testserver.New()
   384  		h.AddTransport(transport.Websocket{
   385  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
   386  				return ctx, errors.New("this is not what we agreed upon")
   387  			},
   388  			ErrorFunc: func(_ context.Context, err error) {
   389  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   390  			},
   391  		})
   392  		srv := httptest.NewServer(h)
   393  		defer srv.Close()
   394  
   395  		c := wsConnect(srv.URL)
   396  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   397  		time.Sleep(time.Millisecond * 20)
   398  	})
   399  
   400  	t.Run("init func context closes do not call the error handler", func(t *testing.T) {
   401  		h := testserver.New()
   402  		h.AddTransport(transport.Websocket{
   403  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
   404  				newCtx, cancel := context.WithCancel(ctx)
   405  				time.AfterFunc(time.Millisecond*5, cancel)
   406  				return newCtx, nil
   407  			},
   408  			ErrorFunc: func(_ context.Context, err error) {
   409  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   410  			},
   411  		})
   412  		srv := httptest.NewServer(h)
   413  		defer srv.Close()
   414  
   415  		c := wsConnect(srv.URL)
   416  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   417  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   418  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   419  		time.Sleep(time.Millisecond * 20)
   420  	})
   421  
   422  	t.Run("init func context deadlines do not call the error handler", func(t *testing.T) {
   423  		h := testserver.New()
   424  		var cancel func()
   425  		h.AddTransport(transport.Websocket{
   426  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
   427  				newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
   428  				return newCtx, nil
   429  			},
   430  			ErrorFunc: func(_ context.Context, err error) {
   431  				assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
   432  			},
   433  		})
   434  		srv := httptest.NewServer(h)
   435  		defer srv.Close()
   436  
   437  		c := wsConnect(srv.URL)
   438  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   439  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   440  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   441  
   442  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   443  		defer cancel()
   444  
   445  		time.Sleep(time.Millisecond * 20)
   446  	})
   447  }
   448  
   449  func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
   450  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   451  		h := testserver.New()
   452  		h.AddTransport(ws)
   453  		return h, httptest.NewServer(h)
   454  	}
   455  
   456  	t.Run("server acks init", func(t *testing.T) {
   457  		_, srv := initialize(transport.Websocket{})
   458  		defer srv.Close()
   459  
   460  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   461  		defer c.Close()
   462  
   463  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   464  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   465  	})
   466  
   467  	t.Run("client can receive data", func(t *testing.T) {
   468  		handler, srv := initialize(transport.Websocket{})
   469  		defer srv.Close()
   470  
   471  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   472  		defer c.Close()
   473  
   474  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   475  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   476  
   477  		require.NoError(t, c.WriteJSON(&operationMessage{
   478  			Type:    graphqltransportwsSubscribeMsg,
   479  			ID:      "test_1",
   480  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   481  		}))
   482  
   483  		handler.SendNextSubscriptionMessage()
   484  		msg := readOp(c)
   485  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   486  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   487  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   488  
   489  		handler.SendNextSubscriptionMessage()
   490  		msg = readOp(c)
   491  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   492  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   493  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   494  
   495  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"}))
   496  
   497  		msg = readOp(c)
   498  		require.Equal(t, graphqltransportwsCompleteMsg, msg.Type)
   499  		require.Equal(t, "test_1", msg.ID)
   500  	})
   501  
   502  	t.Run("receives no graphql-ws keep alive messages", func(t *testing.T) {
   503  		_, srv := initialize(transport.Websocket{KeepAlivePingInterval: 5 * time.Millisecond})
   504  		defer srv.Close()
   505  
   506  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   507  		defer c.Close()
   508  
   509  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   510  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   511  
   512  		// If the keep-alives are sent, this deadline will not be used, and no timeout error will be found
   513  		c.SetReadDeadline(time.Now().UTC().Add(50 * time.Millisecond))
   514  		var msg operationMessage
   515  		err := c.ReadJSON(&msg)
   516  		require.Error(t, err)
   517  		assert.Contains(t, err.Error(), "timeout")
   518  	})
   519  }
   520  
   521  func TestWebsocketWithPingPongInterval(t *testing.T) {
   522  	initialize := func(ws transport.Websocket) (*testserver.TestServer, *httptest.Server) {
   523  		h := testserver.New()
   524  		h.AddTransport(ws)
   525  		return h, httptest.NewServer(h)
   526  	}
   527  
   528  	t.Run("client receives ping and responds with pong", func(t *testing.T) {
   529  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   530  		defer srv.Close()
   531  
   532  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   533  		defer c.Close()
   534  
   535  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   536  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   537  
   538  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   539  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
   540  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   541  	})
   542  
   543  	t.Run("client sends ping and expects pong", func(t *testing.T) {
   544  		_, srv := initialize(transport.Websocket{PingPongInterval: 10 * time.Millisecond})
   545  		defer srv.Close()
   546  
   547  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   548  		defer c.Close()
   549  
   550  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   551  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   552  
   553  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
   554  		assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
   555  	})
   556  
   557  	t.Run("ping-pongs are not sent when the graphql-ws sub protocol is used", func(t *testing.T) {
   558  		// Regression test
   559  		// ---
   560  		// Before the refactor, the code would try to convert a ping message to a graphql-ws message type
   561  		// But since this message type does not exist in the graphql-ws sub protocol, it would fail
   562  
   563  		_, srv := initialize(transport.Websocket{
   564  			PingPongInterval:      5 * time.Millisecond,
   565  			KeepAlivePingInterval: 10 * time.Millisecond,
   566  		})
   567  		defer srv.Close()
   568  
   569  		// Create connection
   570  		c := wsConnect(srv.URL)
   571  		defer c.Close()
   572  
   573  		// Initialize connection
   574  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   575  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   576  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   577  
   578  		// Wait for a few more keep alives to be sure nothing goes wrong
   579  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   580  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   581  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   582  	})
   583  }
   584  
   585  func wsConnect(url string) *websocket.Conn {
   586  	return wsConnectWithSubprocotol(url, "")
   587  }
   588  
   589  func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn {
   590  	h := make(http.Header)
   591  	if subprocotol != "" {
   592  		h.Add("Sec-WebSocket-Protocol", subprocotol)
   593  	}
   594  
   595  	c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h)
   596  	if err != nil {
   597  		panic(err)
   598  	}
   599  	_ = resp.Body.Close()
   600  
   601  	return c
   602  }
   603  
   604  func writeRaw(conn *websocket.Conn, msg string) {
   605  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   606  		panic(err)
   607  	}
   608  }
   609  
   610  func readOp(conn *websocket.Conn) operationMessage {
   611  	var msg operationMessage
   612  	if err := conn.ReadJSON(&msg); err != nil {
   613  		panic(err)
   614  	}
   615  	return msg
   616  }
   617  
   618  // copied out from websocket_graphqlws.go to keep these private
   619  
   620  const (
   621  	connectionInitMsg      = "connection_init"      // Client -> Server
   622  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   623  	startMsg               = "start"                // Client -> Server
   624  	stopMsg                = "stop"                 // Client -> Server
   625  	connectionAckMsg       = "connection_ack"       // Server -> Client
   626  	connectionErrorMsg     = "connection_error"     // Server -> Client
   627  	dataMsg                = "data"                 // Server -> Client
   628  	errorMsg               = "error"                // Server -> Client
   629  	completeMsg            = "complete"             // Server -> Client
   630  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   631  )
   632  
   633  // copied out from websocket_graphql_transport_ws.go to keep these private
   634  
   635  const (
   636  	graphqltransportwsSubprotocol = "graphql-transport-ws"
   637  
   638  	graphqltransportwsConnectionInitMsg = "connection_init"
   639  	graphqltransportwsConnectionAckMsg  = "connection_ack"
   640  	graphqltransportwsSubscribeMsg      = "subscribe"
   641  	graphqltransportwsNextMsg           = "next"
   642  	graphqltransportwsCompleteMsg       = "complete"
   643  	graphqltransportwsPingMsg           = "ping"
   644  	graphqltransportwsPongMsg           = "pong"
   645  )
   646  
   647  type operationMessage struct {
   648  	Payload json.RawMessage `json:"payload,omitempty"`
   649  	ID      string          `json:"id,omitempty"`
   650  	Type    string          `json:"type"`
   651  }