github.com/jepp2078/gqlgen@v0.7.2/handler/websocket_test.go (about)

     1  package handler
     2  
     3  import (
     4  	"encoding/json"
     5  	"net/http/httptest"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/gorilla/websocket"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestWebsocket(t *testing.T) {
    15  	next := make(chan struct{})
    16  	h := GraphQL(&executableSchemaStub{next})
    17  
    18  	srv := httptest.NewServer(h)
    19  	defer srv.Close()
    20  
    21  	t.Run("client must send valid json", func(t *testing.T) {
    22  		c := wsConnect(srv.URL)
    23  		defer c.Close()
    24  
    25  		writeRaw(c, "hello")
    26  
    27  		msg := readOp(c)
    28  		require.Equal(t, connectionErrorMsg, msg.Type)
    29  		require.Equal(t, `{"message":"invalid json"}`, string(msg.Payload))
    30  	})
    31  
    32  	t.Run("client can terminate before init", func(t *testing.T) {
    33  		c := wsConnect(srv.URL)
    34  		defer c.Close()
    35  
    36  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    37  
    38  		_, _, err := c.ReadMessage()
    39  		require.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    40  	})
    41  
    42  	t.Run("client must send init first", func(t *testing.T) {
    43  		c := wsConnect(srv.URL)
    44  		defer c.Close()
    45  
    46  		require.NoError(t, c.WriteJSON(&operationMessage{Type: startMsg}))
    47  
    48  		msg := readOp(c)
    49  		require.Equal(t, connectionErrorMsg, msg.Type)
    50  		require.Equal(t, `{"message":"unexpected message start"}`, string(msg.Payload))
    51  	})
    52  
    53  	t.Run("server acks init", func(t *testing.T) {
    54  		c := wsConnect(srv.URL)
    55  		defer c.Close()
    56  
    57  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    58  
    59  		require.Equal(t, connectionAckMsg, readOp(c).Type)
    60  	})
    61  
    62  	t.Run("client can terminate before run", func(t *testing.T) {
    63  		c := wsConnect(srv.URL)
    64  		defer c.Close()
    65  
    66  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    67  		require.Equal(t, connectionAckMsg, readOp(c).Type)
    68  
    69  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionTerminateMsg}))
    70  
    71  		_, _, err := c.ReadMessage()
    72  		require.Equal(t, websocket.CloseNormalClosure, err.(*websocket.CloseError).Code)
    73  	})
    74  
    75  	t.Run("client gets parse errors", func(t *testing.T) {
    76  		c := wsConnect(srv.URL)
    77  		defer c.Close()
    78  
    79  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    80  		require.Equal(t, connectionAckMsg, readOp(c).Type)
    81  
    82  		require.NoError(t, c.WriteJSON(&operationMessage{
    83  			Type:    startMsg,
    84  			ID:      "test_1",
    85  			Payload: json.RawMessage(`{"query": "!"}`),
    86  		}))
    87  
    88  		msg := readOp(c)
    89  		require.Equal(t, errorMsg, msg.Type)
    90  		require.Equal(t, `[{"message":"Unexpected !","locations":[{"line":1,"column":1}]}]`, string(msg.Payload))
    91  	})
    92  
    93  	t.Run("client can receive data", func(t *testing.T) {
    94  		c := wsConnect(srv.URL)
    95  		defer c.Close()
    96  
    97  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
    98  		require.Equal(t, connectionAckMsg, readOp(c).Type)
    99  
   100  		require.NoError(t, c.WriteJSON(&operationMessage{
   101  			Type:    startMsg,
   102  			ID:      "test_1",
   103  			Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`),
   104  		}))
   105  
   106  		next <- struct{}{}
   107  		msg := readOp(c)
   108  		require.Equal(t, dataMsg, msg.Type)
   109  		require.Equal(t, "test_1", msg.ID)
   110  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   111  
   112  		next <- struct{}{}
   113  		msg = readOp(c)
   114  		require.Equal(t, dataMsg, msg.Type)
   115  		require.Equal(t, "test_1", msg.ID)
   116  		require.Equal(t, `{"data":{"name":"test"}}`, string(msg.Payload))
   117  
   118  		require.NoError(t, c.WriteJSON(&operationMessage{Type: stopMsg, ID: "test_1"}))
   119  
   120  		msg = readOp(c)
   121  		require.Equal(t, completeMsg, msg.Type)
   122  		require.Equal(t, "test_1", msg.ID)
   123  	})
   124  }
   125  
   126  func TestWebsocketWithKeepAlive(t *testing.T) {
   127  	next := make(chan struct{})
   128  	h := GraphQL(&executableSchemaStub{next}, WebsocketKeepAliveDuration(10*time.Millisecond))
   129  
   130  	srv := httptest.NewServer(h)
   131  	defer srv.Close()
   132  
   133  	t.Run("client must receive keepalive", func(t *testing.T) {
   134  		c := wsConnect(srv.URL)
   135  		defer c.Close()
   136  
   137  		require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))
   138  		require.Equal(t, connectionAckMsg, readOp(c).Type)
   139  
   140  		require.NoError(t, c.WriteJSON(&operationMessage{
   141  			Type:    startMsg,
   142  			ID:      "test_1",
   143  			Payload: json.RawMessage(`{"query": "subscription { user { title } }"}`),
   144  		}))
   145  
   146  		// keepalive
   147  		msg := readOp(c)
   148  		require.Equal(t, connectionKeepAliveMsg, msg.Type)
   149  
   150  		// server message
   151  		next <- struct{}{}
   152  		msg = readOp(c)
   153  		require.Equal(t, dataMsg, msg.Type)
   154  
   155  		// keepalive
   156  		msg = readOp(c)
   157  		require.Equal(t, connectionKeepAliveMsg, msg.Type)
   158  	})
   159  }
   160  
   161  func wsConnect(url string) *websocket.Conn {
   162  	c, _, err := websocket.DefaultDialer.Dial(strings.Replace(url, "http://", "ws://", -1), nil)
   163  	if err != nil {
   164  		panic(err)
   165  	}
   166  	return c
   167  }
   168  
   169  func writeRaw(conn *websocket.Conn, msg string) {
   170  	if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil {
   171  		panic(err)
   172  	}
   173  }
   174  
   175  func readOp(conn *websocket.Conn) operationMessage {
   176  	var msg operationMessage
   177  	if err := conn.ReadJSON(&msg); err != nil {
   178  		panic(err)
   179  	}
   180  	return msg
   181  }