git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/graphql/handler/transport/websocket_test.go (about)

     1  package transport_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"git.sr.ht/~sircmpwn/gqlgen/client"
    13  	"git.sr.ht/~sircmpwn/gqlgen/graphql"
    14  	"git.sr.ht/~sircmpwn/gqlgen/graphql/handler"
    15  	"git.sr.ht/~sircmpwn/gqlgen/graphql/handler/testserver"
    16  	"git.sr.ht/~sircmpwn/gqlgen/graphql/handler/transport"
    17  	"github.com/gorilla/websocket"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/vektah/gqlparser/v2"
    21  	"github.com/vektah/gqlparser/v2/ast"
    22  )
    23  
    24  func TestWebsocket(t *testing.T) {
    25  	handler := testserver.New()
    26  	handler.AddTransport(transport.Websocket{})
    27  
    28  	srv := httptest.NewServer(handler)
    29  	defer srv.Close()
    30  
    31  	t.Run("client must send valid json", func(t *testing.T) {
    32  		c := wsConnect(srv.URL)
    33  		defer c.Close()
    34  
    35  		writeRaw(c, "hello")
    36  
    37  		msg := readOp(c)
    38  		assert.Equal(t, "connection_error", msg.Type)
    39  		assert.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    40  	})
    41  
    42  	t.Run("client can terminate before init", func(t *testing.T) {
    43  		c := wsConnect(srv.URL)
    44  		defer c.Close()
    45  
    46  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    47  
    48  		_, _, err := c.ReadMessage()
    49  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    50  	})
    51  
    52  	t.Run("client must send init first", func(t *testing.T) {
    53  		c := wsConnect(srv.URL)
    54  		defer c.Close()
    55  
    56  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    57  
    58  		msg := readOp(c)
    59  		assert.Equal(t, connectionErrorMsg, msg.Type)
    60  		assert.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    61  	})
    62  
    63  	t.Run("server acks init", func(t *testing.T) {
    64  		c := wsConnect(srv.URL)
    65  		defer c.Close()
    66  
    67  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    68  
    69  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    70  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    71  	})
    72  
    73  	t.Run("client can terminate before run", func(t *testing.T) {
    74  		c := wsConnect(srv.URL)
    75  		defer c.Close()
    76  
    77  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    78  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    79  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    80  
    81  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    82  
    83  		_, _, err := c.ReadMessage()
    84  		assert.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    85  	})
    86  
    87  	t.Run("client gets parse errors", func(t *testing.T) {
    88  		c := wsConnect(srv.URL)
    89  		defer c.Close()
    90  
    91  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    92  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
    93  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
    94  
    95  		require.NoError(t, c.WriteJSON(&operationMessage{
    96  			Type:    startMsg,
    97  			ID:      "test_1",
    98  			Payload: json.RawMessage(`{"query": "!"}`),
    99  		}))
   100  
   101  		msg := readOp(c)
   102  		assert.Equal(t, errorMsg, msg.Type)
   103  		assert.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]`, string(msg.Payload))
   104  	})
   105  
   106  	t.Run("client can receive data", func(t *testing.T) {
   107  		c := wsConnect(srv.URL)
   108  		defer c.Close()
   109  
   110  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   111  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   112  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   113  
   114  		require.NoError(t, c.WriteJSON(&operationMessage{
   115  			Type:    startMsg,
   116  			ID:      "test_1",
   117  			Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   118  		}))
   119  
   120  		handler.SendNextSubscriptionMessage()
   121  		msg := readOp(c)
   122  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   123  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   124  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   125  
   126  		handler.SendNextSubscriptionMessage()
   127  		msg = readOp(c)
   128  		require.Equal(t, dataMsg, msg.Type, string(msg.Payload))
   129  		require.Equal(t, "test_1", msg.ID, string(msg.Payload))
   130  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   131  
   132  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   133  
   134  		msg = readOp(c)
   135  		require.Equal(t, completeMsg, msg.Type)
   136  		require.Equal(t, "test_1", msg.ID)
   137  	})
   138  }
   139  
   140  func TestWebsocketWithKeepAlive(t *testing.T) {
   141  
   142  	h := testserver.New()
   143  	h.AddTransport(transport.Websocket{
   144  		KeepAlivePingInterval: 100 * time.Millisecond,
   145  	})
   146  
   147  	srv := httptest.NewServer(h)
   148  	defer srv.Close()
   149  
   150  	c := wsConnect(srv.URL)
   151  	defer c.Close()
   152  
   153  	require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   154  	assert.Equal(t, connectionAckMsg, readOp(c).Type)
   155  	assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   156  
   157  	require.NoError(t, c.WriteJSON(&operationMessage{
   158  		Type:    startMsg,
   159  		ID:      "test_1",
   160  		Payload: json.RawMessage(`{"query": "subscription { name }"}`),
   161  	}))
   162  
   163  	// keepalive
   164  	msg := readOp(c)
   165  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   166  
   167  	// server message
   168  	h.SendNextSubscriptionMessage()
   169  	msg = readOp(c)
   170  	assert.Equal(t, dataMsg, msg.Type)
   171  
   172  	// keepalive
   173  	msg = readOp(c)
   174  	assert.Equal(t, connectionKeepAliveMsg, msg.Type)
   175  }
   176  
   177  func TestWebsocketInitFunc(t *testing.T) {
   178  	t.Run("accept connection if WebsocketInitFunc is NOT provided", func(t *testing.T) {
   179  		h := testserver.New()
   180  		h.AddTransport(transport.Websocket{})
   181  		srv := httptest.NewServer(h)
   182  		defer srv.Close()
   183  
   184  		c := wsConnect(srv.URL)
   185  		defer c.Close()
   186  
   187  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   188  
   189  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   190  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   191  	})
   192  
   193  	t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   194  		h := testserver.New()
   195  		h.AddTransport(transport.Websocket{
   196  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   197  				return context.WithValue(ctx, "newkey", "newvalue"), nil
   198  			},
   199  		})
   200  		srv := httptest.NewServer(h)
   201  		defer srv.Close()
   202  
   203  		c := wsConnect(srv.URL)
   204  		defer c.Close()
   205  
   206  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   207  
   208  		assert.Equal(t, connectionAckMsg, readOp(c).Type)
   209  		assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
   210  	})
   211  
   212  	t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
   213  		h := testserver.New()
   214  		h.AddTransport(transport.Websocket{
   215  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   216  				return ctx, errors.New("invalid init payload")
   217  			},
   218  		})
   219  		srv := httptest.NewServer(h)
   220  		defer srv.Close()
   221  
   222  		c := wsConnect(srv.URL)
   223  		defer c.Close()
   224  
   225  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   226  
   227  		msg := readOp(c)
   228  		assert.Equal(t, connectionErrorMsg, msg.Type)
   229  		assert.Equal(t, `{"message":"invalid init payload"}`, string(msg.Payload))
   230  	})
   231  
   232  	t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
   233  		es := &graphql.ExecutableSchemaMock{
   234  			ExecFunc: func(ctx context.Context) graphql.ResponseHandler {
   235  				assert.Equal(t, "newvalue", ctx.Value("newkey"))
   236  				return graphql.OneShot(&graphql.Response{Data: []byte(`{"empty":"ok"}`)})
   237  			},
   238  			SchemaFunc: func() *ast.Schema {
   239  				return gqlparser.MustLoadSchema(&ast.Source{Input: `
   240  				schema { query: Query }
   241  				type Query {
   242  					empty: String
   243  				}
   244  			`})
   245  			},
   246  		}
   247  		h := handler.New(es)
   248  
   249  		h.AddTransport(transport.Websocket{
   250  			InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
   251  				return context.WithValue(ctx, "newkey", "newvalue"), nil
   252  			},
   253  		})
   254  
   255  		c := client.New(h)
   256  
   257  		socket := c.Websocket("{ empty } ")
   258  		defer socket.Close()
   259  		var resp struct {
   260  			Empty string
   261  		}
   262  		err := socket.Next(&resp)
   263  		require.NoError(t, err)
   264  		assert.Equal(t, "ok", resp.Empty)
   265  	})
   266  }
   267  
   268  func wsConnect(url string) *websocket.Conn {
   269  	c, resp, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
   270  	if err != nil {
   271  		panic(err)
   272  	}
   273  	_ = resp.Body.Close()
   274  
   275  	return c
   276  }
   277  
   278  func writeRaw(conn *websocket.Conn, msg string) {
   279  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   280  		panic(err)
   281  	}
   282  }
   283  
   284  func readOp(conn *websocket.Conn) operationMessage {
   285  	var msg operationMessage
   286  	if err := conn.ReadJSON(&msg); err != nil {
   287  		panic(err)
   288  	}
   289  	return msg
   290  }
   291  
   292  // copied out from weboscket.go to keep these private
   293  
   294  const (
   295  	connectionInitMsg      = "connection_init"      // Client -> Server
   296  	connectionTerminateMsg = "connection_terminate" // Client -> Server
   297  	startMsg               = "start"                // Client -> Server
   298  	stopMsg                = "stop"                 // Client -> Server
   299  	connectionAckMsg       = "connection_ack"       // Server -> Client
   300  	connectionErrorMsg     = "connection_error"     // Server -> Client
   301  	dataMsg                = "data"                 // Server -> Client
   302  	errorMsg               = "error"                // Server -> Client
   303  	completeMsg            = "complete"             // Server -> Client
   304  	connectionKeepAliveMsg = "ka"                   // Server -> Client
   305  )
   306  
   307  type operationMessage struct {
   308  	Payload json.RawMessage `json:"payload,omitempty"`
   309  	ID      string          `json:"id,omitempty"`
   310  	Type    string          `json:"type"`
   311  }