github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/graphql/handler/transport/websocket_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/animeshon/gqlgen/client"
    13  	"github.com/animeshon/gqlgen/graphql"
    14  	"github.com/animeshon/gqlgen/graphql/handler"
    15  	"github.com/animeshon/gqlgen/graphql/handler/testserver"
    16  	"github.com/animeshon/gqlgen/graphql/handler/transport"
    17  	"github.com/gorilla/websocket"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/vektah/gqlparser/v2"
    21  	"github.com/vektah/gqlparser/v2/ast"
    22  )
    23  
    24  type ckey string
    25  
    26  func TestWebsocket(t *testing.T) {
    27  	handler := testserver.New()
    28  	handler.AddTransport(transport.Websocket{})
    29  
    30  	srv := httptest.NewServer(handler)
    31  	defer srv.Close()
    32  
    33  	t.Run("client must send valid json", func(t *testing.T) {
    34  		c := wsConnect(srv.URL)
    35  		defer c.Close()
    36  
    37  		writeRaw(c, "hello")
    38  
    39  		msg := readOp(c)
    40  		assert.Equal(t, "connection_error", msg.Type)
    41  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    42  	})
    43  
    44  	t.Run("client can terminate before init", func(t *testing.T) {
    45  		c := wsConnect(srv.URL)
    46  		defer c.Close()
    47  
    48  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    49  
    50  		_, _, err := c.ReadMessage()
    51  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    52  	})
    53  
    54  	t.Run("client must send init first", func(t *testing.T) {
    55  		c := wsConnect(srv.URL)
    56  		defer c.Close()
    57  
    58  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    59  
    60  		msg := readOp(c)
    61  		assert.Equal(t, connectionErrorMsg, msg.Type)
    62  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    63  	})
    64  
    65  	t.Run("server acks init", func(t *testing.T) {
    66  		c := wsConnect(srv.URL)
    67  		defer c.Close()
    68  
    69  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    70  
    71  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    72  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    73  	})
    74  
    75  	t.Run("client can terminate before run", func(t *testing.T) {
    76  		c := wsConnect(srv.URL)
    77  		defer c.Close()
    78  
    79  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    80  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    81  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    82  
    83  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    84  
    85  		_, _, err := c.ReadMessage()
    86  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    87  	})
    88  
    89  	t.Run("client gets parse errors", func(t *testing.T) {
    90  		c := wsConnect(srv.URL)
    91  		defer c.Close()
    92  
    93  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    94  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    95  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    96  
    97  		require.NoError(t, c.WriteJSON(&operationMessage{
    98  			Type:    startMsg,
    99  			ID:      "test_1",
   100  			Payload: json.RawMessage(`{"query": "!"}`),
   101  		}))
   102  
   103  		msg := readOp(c)
   104  		assert.Equal(t, errorMsg, msg.Type)
   105  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   106  	})
   107  
   108  	t.Run("client can receive data", func(t *testing.T) {
   109  		c := wsConnect(srv.URL)
   110  		defer c.Close()
   111  
   112  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   113  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   114  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   115  
   116  		require.NoError(t, c.WriteJSON(&operationMessage{
   117  			Type:    startMsg,
   118  			ID:      "test_1",
   119  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   120  		}))
   121  
   122  		handler.SendNextSubscriptionMessage()
   123  		msg := readOp(c)
   124  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   125  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   126  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   127  
   128  		handler.SendNextSubscriptionMessage()
   129  		msg = readOp(c)
   130  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   131  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   132  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   133  
   134  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   135  
   136  		msg = readOp(c)
   137  		require.Equal(t, completeMsg, msg.Type)
   138  		require.Equal(t, "test_1", msg.ID)
   139  	})
   140  }
   141  
   142  func TestWebsocketWithKeepAlive(t *testing.T) {
   143  
   144  	h := testserver.New()
   145  	h.AddTransport(transport.Websocket{
   146  		KeepAlivePingInterval: 100 * time.Millisecond,
   147  	})
   148  
   149  	srv := httptest.NewServer(h)
   150  	defer srv.Close()
   151  
   152  	c := wsConnect(srv.URL)
   153  	defer c.Close()
   154  
   155  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   156  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   157  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   158  
   159  	require.NoError(t, c.WriteJSON(&operationMessage{
   160  		Type:    startMsg,
   161  		ID:      "test_1",
   162  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   163  	}))
   164  
   165  	// keepalive
   166  	msg := readOp(c)
   167  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   168  
   169  	// server message
   170  	h.SendNextSubscriptionMessage()
   171  	msg = readOp(c)
   172  	assert.Equal(t, dataMsg, msg.Type)
   173  
   174  	// keepalive
   175  	msg = readOp(c)
   176  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   177  }
   178  
   179  func TestWebsocketInitFunc(t *testing.T) {
   180  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   181  		h := testserver.New()
   182  		h.AddTransport(transport.Websocket{})
   183  		srv := httptest.NewServer(h)
   184  		defer srv.Close()
   185  
   186  		c := wsConnect(srv.URL)
   187  		defer c.Close()
   188  
   189  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   190  
   191  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   192  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   193  	})
   194  
   195  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   196  		h := testserver.New()
   197  		h.AddTransport(transport.Websocket{
   198  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   199  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   200  			},
   201  		})
   202  		srv := httptest.NewServer(h)
   203  		defer srv.Close()
   204  
   205  		c := wsConnect(srv.URL)
   206  		defer c.Close()
   207  
   208  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   209  
   210  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   211  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   212  	})
   213  
   214  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   215  		h := testserver.New()
   216  		h.AddTransport(transport.Websocket{
   217  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   218  				return ctx, errors.New("invalid init payload")
   219  			},
   220  		})
   221  		srv := httptest.NewServer(h)
   222  		defer srv.Close()
   223  
   224  		c := wsConnect(srv.URL)
   225  		defer c.Close()
   226  
   227  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   228  
   229  		msg := readOp(c)
   230  		assert.Equal(t, connectionErrorMsg, msg.Type)
   231  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   232  	})
   233  
   234  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   235  		es := &graphql.ExecutableSchemaMock{
   236  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   237  				assert.Equal(t, "newvalue", ctx.Value(ckey("newkey")))
   238  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   239  			},
   240  			SchemaFunc: func() *ast.Schema {
   241  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   242  				schema { query: Query }
   243  				type Query {
   244  					empty: String
   245  				}
   246  			`})
   247  			},
   248  		}
   249  		h := handler.New(es)
   250  
   251  		h.AddTransport(transport.Websocket{
   252  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   253  				return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
   254  			},
   255  		})
   256  
   257  		c := client.New(h)
   258  
   259  		socket := c.Websocket("{ empty } ")
   260  		defer socket.Close()
   261  		var resp struct {
   262  			Empty string
   263  		}
   264  		err := socket.Next(&resp)
   265  		require.NoError(t, err)
   266  		assert.Equal(t, "ok", resp.Empty)
   267  	})
   268  }
   269  
   270  func wsConnect(url string) *websocket.Conn {
   271  	c, resp, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
   272  	if err != nil {
   273  		panic(err)
   274  	}
   275  	_ = resp.Body.Close()
   276  
   277  	return c
   278  }
   279  
   280  func writeRaw(conn *websocket.Conn, msg string) {
   281  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   282  		panic(err)
   283  	}
   284  }
   285  
   286  func readOp(conn *websocket.Conn) operationMessage {
   287  	var msg operationMessage
   288  	if err := conn.ReadJSON(&msg); err != nil {
   289  		panic(err)
   290  	}
   291  	return msg
   292  }
   293  
// Copied from websocket.go, where these identifiers are unexported, so that the
// tests in this external test package can use them.
   295  
   296  const (
   297  	connectionInitMsg      = "connection_init"      // Client -> Server
   298  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   299  	startMsg               = "start"                // Client -> Server
   300  	stopMsg                = "stop"                 // Client -> Server
   301  	connectionAckMsg       = "connection_ack"       // Server -> Client
   302  	connectionErrorMsg     = "connection_error"     // Server -> Client
   303  	dataMsg                = "data"                 // Server -> Client
   304  	errorMsg               = "error"                // Server -> Client
   305  	completeMsg            = "complete"             // Server -> Client
   306  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   307  )
   308  
   309  type operationMessage struct {
   310  	Payload json.RawMessage `json:"payload,omitempty"`
   311  	ID      string          `json:"id,omitempty"`
   312  	Type    string          `json:"type"`
   313  }