github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/apipluspower/gqlgen/client"
    14  	"github.com/apipluspower/gqlgen/graphql"
    15  	"github.com/apipluspower/gqlgen/graphql/handler"
    16  	"github.com/apipluspower/gqlgen/graphql/handler/testserver"
    17  	"github.com/apipluspower/gqlgen/graphql/handler/transport"
    18  	"github.com/gorilla/websocket"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/vektah/gqlparser/v2"
    22  	"github.com/vektah/gqlparser/v2/ast"
    23  )
    24  
    25  type ckey string
    26  
    27  func TestWebsocket(t *testing.T) {
    28  	handler := testserver.New()
    29  	handler.AddTransport(transport.Websocket{})
    30  
    31  	srv := httptest.NewServer(handler)
    32  	defer srv.Close()
    33  
    34  	t.Run("client must send valid json", func(t *testing.T) {
    35  		c := wsConnect(srv.URL)
    36  		defer c.Close()
    37  
    38  		writeRaw(c, "hello")
    39  
    40  		msg := readOp(c)
    41  		assert.Equal(t, "connection_error", msg.Type)
    42  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    43  	})
    44  
    45  	t.Run("client can terminate before init", func(t *testing.T) {
    46  		c := wsConnect(srv.URL)
    47  		defer c.Close()
    48  
    49  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    50  
    51  		_, _, err := c.ReadMessage()
    52  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    53  	})
    54  
    55  	t.Run("client must send init first", func(t *testing.T) {
    56  		c := wsConnect(srv.URL)
    57  		defer c.Close()
    58  
    59  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    60  
    61  		msg := readOp(c)
    62  		assert.Equal(t, connectionErrorMsg, msg.Type)
    63  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    64  	})
    65  
    66  	t.Run("server acks init", func(t *testing.T) {
    67  		c := wsConnect(srv.URL)
    68  		defer c.Close()
    69  
    70  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    71  
    72  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    73  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    74  	})
    75  
    76  	t.Run("client can terminate before run", func(t *testing.T) {
    77  		c := wsConnect(srv.URL)
    78  		defer c.Close()
    79  
    80  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    81  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    82  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    83  
    84  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    85  
    86  		_, _, err := c.ReadMessage()
    87  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    88  	})
    89  
    90  	t.Run("client gets parse errors", func(t *testing.T) {
    91  		c := wsConnect(srv.URL)
    92  		defer c.Close()
    93  
    94  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    95  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    96  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    97  
    98  		require.NoError(t, c.WriteJSON(&operationMessage{
    99  			Type:    startMsg,
   100  			ID:      "test_1",
   101  			Payload: json.RawMessage(`{"query": "!"}`),
   102  		}))
   103  
   104  		msg := readOp(c)
   105  		assert.Equal(t, errorMsg, msg.Type)
   106  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   107  	})
   108  
   109  	t.Run("client can receive data", func(t *testing.T) {
   110  		c := wsConnect(srv.URL)
   111  		defer c.Close()
   112  
   113  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   114  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   115  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   116  
   117  		require.NoError(t, c.WriteJSON(&operationMessage{
   118  			Type:    startMsg,
   119  			ID:      "test_1",
   120  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   121  		}))
   122  
   123  		handler.SendNextSubscriptionMessage()
   124  		msg := readOp(c)
   125  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   126  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   127  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   128  
   129  		handler.SendNextSubscriptionMessage()
   130  		msg = readOp(c)
   131  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   132  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   133  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   134  
   135  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   136  
   137  		msg = readOp(c)
   138  		require.Equal(t, completeMsg, msg.Type)
   139  		require.Equal(t, "test_1", msg.ID)
   140  	})
   141  }
   142  
   143  func TestWebsocketWithKeepAlive(t *testing.T) {
   144  	h := testserver.New()
   145  	h.AddTransport(transport.Websocket{
   146  		KeepAlivePingInterval: 100 * time.Millisecond,
   147  	})
   148  
   149  	srv := httptest.NewServer(h)
   150  	defer srv.Close()
   151  
   152  	c := wsConnect(srv.URL)
   153  	defer c.Close()
   154  
   155  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   156  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   157  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   158  
   159  	require.NoError(t, c.WriteJSON(&operationMessage{
   160  		Type:    startMsg,
   161  		ID:      "test_1",
   162  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   163  	}))
   164  
   165  	// keepalive
   166  	msg := readOp(c)
   167  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   168  
   169  	// server message
   170  	h.SendNextSubscriptionMessage()
   171  	msg = readOp(c)
   172  	assert.Equal(t, dataMsg, msg.Type)
   173  
   174  	// keepalive
   175  	msg = readOp(c)
   176  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   177  }
   178  
   179  func TestWebsocketInitFunc(t *testing.T) {
   180  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   181  		h := testserver.New()
   182  		h.AddTransport(transport.Websocket{})
   183  		srv := httptest.NewServer(h)
   184  		defer srv.Close()
   185  
   186  		c := wsConnect(srv.URL)
   187  		defer c.Close()
   188  
   189  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   190  
   191  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   192  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   193  	})
   194  
   195  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   196  		h := testserver.New()
   197  		h.AddTransport(transport.Websocket{
   198  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   199  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   200  			},
   201  		})
   202  		srv := httptest.NewServer(h)
   203  		defer srv.Close()
   204  
   205  		c := wsConnect(srv.URL)
   206  		defer c.Close()
   207  
   208  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   209  
   210  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   211  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   212  	})
   213  
   214  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   215  		h := testserver.New()
   216  		h.AddTransport(transport.Websocket{
   217  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   218  				return ctx, errors.New("invalid init payload")
   219  			},
   220  		})
   221  		srv := httptest.NewServer(h)
   222  		defer srv.Close()
   223  
   224  		c := wsConnect(srv.URL)
   225  		defer c.Close()
   226  
   227  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   228  
   229  		msg := readOp(c)
   230  		assert.Equal(t, connectionErrorMsg, msg.Type)
   231  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   232  	})
   233  
   234  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   235  		es := &graphql.ExecutableSchemaMock{
   236  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   237  				assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
   238  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   239  			},
   240  			SchemaFunc: func() *ast.Schema {
   241  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   242  				schema { query: Query }
   243  				type Query {
   244  					empty: String
   245  				}
   246  			`})
   247  			},
   248  		}
   249  		h := handler.New(es)
   250  
   251  		h.AddTransport(transport.Websocket{
   252  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   253  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   254  			},
   255  		})
   256  
   257  		c := client.New(h)
   258  
   259  		socket := c.Websocket("{ empty } ")
   260  		defer socket.Close()
   261  		var resp struct {
   262  			Empty string
   263  		}
   264  		err := socket.Next(&resp)
   265  		require.NoError(t, err)
   266  		assert.Equal(t, "ok", resp.Empty)
   267  	})
   268  
   269  	t.Run("can set a deadline on a websocket connection and close it with a reason", func(t *testing.T) {
   270  		h := testserver.New()
   271  		var cancel func()
   272  		h.AddTransport(transport.Websocket{
   273  			InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
   274  				newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
   275  				return
   276  			},
   277  		})
   278  		srv := httptest.NewServer(h)
   279  		defer srv.Close()
   280  
   281  		c := wsConnect(srv.URL)
   282  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   283  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   284  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   285  
   286  		// Cancel should contain an actual value now, so let's call it when we exit this scope (to make the linter happy)
   287  		defer cancel()
   288  
   289  		time.Sleep(time.Millisecond * 10)
   290  		m := readOp(c)
   291  		assert.Equal(t, m.Type, connectionErrorMsg)
   292  		assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
   293  	})
   294  }
   295  
   296  func TestWebsocketGraphqltransportwsSubprotocol(t *testing.T) {
   297  	handler := testserver.New()
   298  	handler.AddTransport(transport.Websocket{})
   299  
   300  	srv := httptest.NewServer(handler)
   301  	defer srv.Close()
   302  
   303  	t.Run("server acks init", func(t *testing.T) {
   304  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   305  		defer c.Close()
   306  
   307  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   308  
   309  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   310  	})
   311  
   312  	t.Run("client can receive data", func(t *testing.T) {
   313  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   314  		defer c.Close()
   315  
   316  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   317  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   318  
   319  		require.NoError(t, c.WriteJSON(&operationMessage{
   320  			Type:    graphqltransportwsSubscribeMsg,
   321  			ID:      "test_1",
   322  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   323  		}))
   324  
   325  		handler.SendNextSubscriptionMessage()
   326  		msg := readOp(c)
   327  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   328  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   329  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   330  
   331  		handler.SendNextSubscriptionMessage()
   332  		msg = readOp(c)
   333  		require.Equal(t, graphqltransportwsNextMsg, msg.Type, string(msg.Payload))
   334  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   335  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   336  
   337  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsCompleteMsg, ID: "test_1"}))
   338  
   339  		msg = readOp(c)
   340  		require.Equal(t, graphqltransportwsCompleteMsg, msg.Type)
   341  		require.Equal(t, "test_1", msg.ID)
   342  	})
   343  }
   344  
   345  func TestWebsocketWithPingPongInterval(t *testing.T) {
   346  	handler := testserver.New()
   347  	handler.AddTransport(transport.Websocket{
   348  		PingPongInterval: time.Second * 1,
   349  	})
   350  
   351  	srv := httptest.NewServer(handler)
   352  	defer srv.Close()
   353  
   354  	t.Run("client receives ping and responds with pong", func(t *testing.T) {
   355  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   356  		defer c.Close()
   357  
   358  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   359  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   360  
   361  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   362  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPongMsg}))
   363  		assert.Equal(t, graphqltransportwsPingMsg, readOp(c).Type)
   364  	})
   365  
   366  	t.Run("client sends ping and expects pong", func(t *testing.T) {
   367  		c := wsConnectWithSubprocotol(srv.URL, graphqltransportwsSubprotocol)
   368  		defer c.Close()
   369  
   370  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsConnectionInitMsg}))
   371  		assert.Equal(t, graphqltransportwsConnectionAckMsg, readOp(c).Type)
   372  
   373  		require.NoError(t, c.WriteJSON(&operationMessage{Type: graphqltransportwsPingMsg}))
   374  		assert.Equal(t, graphqltransportwsPongMsg, readOp(c).Type)
   375  	})
   376  }
   377  
   378  func wsConnect(url string) *websocket.Conn {
   379  	return wsConnectWithSubprocotol(url, "")
   380  }
   381  
   382  func wsConnectWithSubprocotol(url, subprocotol string) *websocket.Conn {
   383  	h := make(http.Header)
   384  	if subprocotol != "" {
   385  		h.Add("Sec-WebSocket-Protocol", subprocotol)
   386  	}
   387  
   388  	c, resp, err := websocket.DefaultDialer.Dial(strings.ReplaceAll(url, "http://", "ws://"), h)
   389  	if err != nil {
   390  		panic(err)
   391  	}
   392  	_ = resp.Body.Close()
   393  
   394  	return c
   395  }
   396  
   397  func writeRaw(conn *websocket.Conn, msg string) {
   398  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   399  		panic(err)
   400  	}
   401  }
   402  
   403  func readOp(conn *websocket.Conn) operationMessage {
   404  	var msg operationMessage
   405  	if err := conn.ReadJSON(&msg); err != nil {
   406  		panic(err)
   407  	}
   408  	return msg
   409  }
   410  
   411  // copied out from websocket_graphqlws.go to keep these private
   412  
   413  const (
   414  	connectionInitMsg      = "connection_init"      // Client -> Server
   415  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   416  	startMsg               = "start"                // Client -> Server
   417  	stopMsg                = "stop"                 // Client -> Server
   418  	connectionAckMsg       = "connection_ack"       // Server -> Client
   419  	connectionErrorMsg     = "connection_error"     // Server -> Client
   420  	dataMsg                = "data"                 // Server -> Client
   421  	errorMsg               = "error"                // Server -> Client
   422  	completeMsg            = "complete"             // Server -> Client
   423  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   424  )
   425  
   426  // copied out from websocket_graphql_transport_ws.go to keep these private
   427  
   428  const (
   429  	graphqltransportwsSubprotocol = "graphql-transport-ws"
   430  
   431  	graphqltransportwsConnectionInitMsg = "connection_init"
   432  	graphqltransportwsConnectionAckMsg  = "connection_ack"
   433  	graphqltransportwsSubscribeMsg      = "subscribe"
   434  	graphqltransportwsNextMsg           = "next"
   435  	graphqltransportwsCompleteMsg       = "complete"
   436  	graphqltransportwsPingMsg           = "ping"
   437  	graphqltransportwsPongMsg           = "pong"
   438  )
   439  
   440  type operationMessage struct {
   441  	Payload json.RawMessage `json:"payload,omitempty"`
   442  	ID      string          `json:"id,omitempty"`
   443  	Type    string          `json:"type"`
   444  }