github.com/maeglindeveloper/gqlgen@v0.13.1-0.20210413081235-57808b12a0a0/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/99designs/gqlgen/graphql"
    14  	"github.com/99designs/gqlgen/graphql/errcode"
    15  	"github.com/gorilla/websocket"
    16  	"github.com/vektah/gqlparser/v2/gqlerror"
    17  )
    18  
    19  const (
    20  	connectionInitMsg      = "connection_init"      // Client -> Server
    21  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    22  	startMsg               = "start"                // Client -> Server
    23  	stopMsg                = "stop"                 // Client -> Server
    24  	connectionAckMsg       = "connection_ack"       // Server -> Client
    25  	connectionErrorMsg     = "connection_error"     // Server -> Client
    26  	dataMsg                = "data"                 // Server -> Client
    27  	errorMsg               = "error"                // Server -> Client
    28  	completeMsg            = "complete"             // Server -> Client
    29  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    30  )
    31  
    32  type (
    33  	Websocket struct {
    34  		Upgrader              websocket.Upgrader
    35  		InitFunc              WebsocketInitFunc
    36  		KeepAlivePingInterval time.Duration
    37  		PingPongInterval      time.Duration
    38  	}
    39  	wsConnection struct {
    40  		Websocket
    41  		ctx             context.Context
    42  		conn            *websocket.Conn
    43  		active          map[string]context.CancelFunc
    44  		mu              sync.Mutex
    45  		keepAliveTicker *time.Ticker
    46  		pingPongTicker  *time.Ticker
    47  		exec            graphql.GraphExecutor
    48  
    49  		initPayload InitPayload
    50  	}
    51  	operationMessage struct {
    52  		Payload json.RawMessage `json:"payload,omitempty"`
    53  		ID      string          `json:"id,omitempty"`
    54  		Type    string          `json:"type"`
    55  	}
    56  	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    57  )
    58  
    59  var _ graphql.Transport = Websocket{}
    60  
    61  func (t Websocket) Supports(r *http.Request) bool {
    62  	return r.Header.Get("Upgrade") != ""
    63  }
    64  
    65  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    66  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{
    67  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    68  	})
    69  	if err != nil {
    70  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    71  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    72  		return
    73  	}
    74  
    75  	conn := wsConnection{
    76  		active:    map[string]context.CancelFunc{},
    77  		conn:      ws,
    78  		ctx:       r.Context(),
    79  		exec:      exec,
    80  		Websocket: t,
    81  	}
    82  
    83  	if !conn.init() {
    84  		return
    85  	}
    86  
    87  	conn.run()
    88  }
    89  
    90  func (c *wsConnection) init() bool {
    91  	message := c.readOp()
    92  	if message == nil {
    93  		c.close(websocket.CloseProtocolError, "decoding error")
    94  		return false
    95  	}
    96  
    97  	switch message.Type {
    98  	case connectionInitMsg:
    99  		if len(message.Payload) > 0 {
   100  			c.initPayload = make(InitPayload)
   101  			err := json.Unmarshal(message.Payload, &c.initPayload)
   102  			if err != nil {
   103  				return false
   104  			}
   105  		}
   106  
   107  		if c.InitFunc != nil {
   108  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   109  			if err != nil {
   110  				c.sendConnectionError(err.Error())
   111  				c.close(websocket.CloseNormalClosure, "terminated")
   112  				return false
   113  			}
   114  			c.ctx = ctx
   115  		}
   116  
   117  		c.write(&operationMessage{Type: connectionAckMsg})
   118  		c.write(&operationMessage{Type: connectionKeepAliveMsg})
   119  	case connectionTerminateMsg:
   120  		c.close(websocket.CloseNormalClosure, "terminated")
   121  		return false
   122  	default:
   123  		c.sendConnectionError("unexpected message %s", message.Type)
   124  		c.close(websocket.CloseProtocolError, "unexpected message")
   125  		return false
   126  	}
   127  
   128  	return true
   129  }
   130  
   131  func (c *wsConnection) write(msg *operationMessage) {
   132  	c.mu.Lock()
   133  	c.conn.WriteJSON(msg)
   134  	c.mu.Unlock()
   135  }
   136  
   137  func (c *wsConnection) run() {
   138  	// We create a cancellation that will shutdown the keep-alive when we leave
   139  	// this function.
   140  	ctx, cancel := context.WithCancel(c.ctx)
   141  	defer func() {
   142  		cancel()
   143  	}()
   144  
   145  	// Create a timer that will fire every interval to keep the connection alive.
   146  	if c.KeepAlivePingInterval != 0 {
   147  		c.mu.Lock()
   148  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   149  		c.mu.Unlock()
   150  		go c.keepAlive(ctx)
   151  	}
   152  
   153  	// Create a timer that will fire every interval a ping message that should
   154  	// receive a pong (SetPongHandler in init() function)
   155  	if c.PingPongInterval != 0 {
   156  
   157  		pongWait := 2 * c.PingPongInterval
   158  		c.conn.SetReadDeadline(time.Now().Add(pongWait))
   159  		c.conn.SetPongHandler(func(string) error {
   160  			return c.conn.SetReadDeadline(time.Now().UTC().Add(pongWait))
   161  		})
   162  
   163  		c.mu.Lock()
   164  		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
   165  		c.mu.Unlock()
   166  
   167  		go c.ping(ctx)
   168  	}
   169  
   170  	for {
   171  		start := graphql.Now()
   172  		message := c.readOp()
   173  		if message == nil {
   174  			c.close(websocket.CloseAbnormalClosure, "unexpected closure")
   175  			return
   176  		}
   177  
   178  		switch message.Type {
   179  		case startMsg:
   180  			c.subscribe(start, message)
   181  		case stopMsg:
   182  			c.mu.Lock()
   183  			closer := c.active[message.ID]
   184  			c.mu.Unlock()
   185  			if closer != nil {
   186  				closer()
   187  			}
   188  		case connectionTerminateMsg:
   189  			c.close(websocket.CloseNormalClosure, "terminated")
   190  			return
   191  		default:
   192  			c.sendConnectionError("unexpected message %s", message.Type)
   193  			c.close(websocket.CloseProtocolError, "unexpected message")
   194  			return
   195  		}
   196  	}
   197  }
   198  
   199  func (c *wsConnection) keepAlive(ctx context.Context) {
   200  	for {
   201  		select {
   202  		case <-ctx.Done():
   203  			c.keepAliveTicker.Stop()
   204  			return
   205  		case <-c.keepAliveTicker.C:
   206  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   207  		}
   208  	}
   209  }
   210  
   211  func (c *wsConnection) ping(ctx context.Context) {
   212  	for {
   213  		select {
   214  		case <-ctx.Done():
   215  			c.pingPongTicker.Stop()
   216  			return
   217  		case <-c.pingPongTicker.C:
   218  			c.mu.Lock()
   219  			c.conn.WriteMessage(websocket.PingMessage, nil)
   220  			c.mu.Unlock()
   221  		}
   222  	}
   223  }
   224  
   225  func (c *wsConnection) subscribe(start time.Time, message *operationMessage) {
   226  	ctx := graphql.StartOperationTrace(c.ctx)
   227  	var params *graphql.RawParams
   228  	if err := jsonDecode(bytes.NewReader(message.Payload), &params); err != nil {
   229  		c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"})
   230  		c.complete(message.ID)
   231  		return
   232  	}
   233  
   234  	params.ReadTime = graphql.TraceTiming{
   235  		Start: start,
   236  		End:   graphql.Now(),
   237  	}
   238  
   239  	rc, err := c.exec.CreateOperationContext(ctx, params)
   240  	if err != nil {
   241  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   242  		switch errcode.GetErrorKind(err) {
   243  		case errcode.KindProtocol:
   244  			c.sendError(message.ID, resp.Errors...)
   245  		default:
   246  			c.sendResponse(message.ID, &graphql.Response{Errors: err})
   247  		}
   248  
   249  		c.complete(message.ID)
   250  		return
   251  	}
   252  
   253  	ctx = graphql.WithOperationContext(ctx, rc)
   254  
   255  	if c.initPayload != nil {
   256  		ctx = withInitPayload(ctx, c.initPayload)
   257  	}
   258  
   259  	ctx, cancel := context.WithCancel(ctx)
   260  	c.mu.Lock()
   261  	c.active[message.ID] = cancel
   262  	c.mu.Unlock()
   263  
   264  	go func() {
   265  		defer func() {
   266  			if r := recover(); r != nil {
   267  				userErr := rc.Recover(ctx, r)
   268  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   269  			}
   270  		}()
   271  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   272  		for {
   273  			response := responses(ctx)
   274  			if response == nil {
   275  				break
   276  			}
   277  
   278  			c.sendResponse(message.ID, response)
   279  		}
   280  		c.complete(message.ID)
   281  
   282  		c.mu.Lock()
   283  		delete(c.active, message.ID)
   284  		c.mu.Unlock()
   285  		cancel()
   286  	}()
   287  }
   288  
   289  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   290  	b, err := json.Marshal(response)
   291  	if err != nil {
   292  		panic(err)
   293  	}
   294  	c.write(&operationMessage{
   295  		Payload: b,
   296  		ID:      id,
   297  		Type:    dataMsg,
   298  	})
   299  }
   300  
   301  func (c *wsConnection) complete(id string) {
   302  	c.write(&operationMessage{ID: id, Type: completeMsg})
   303  }
   304  
   305  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   306  	errs := make([]error, len(errors))
   307  	for i, err := range errors {
   308  		errs[i] = err
   309  	}
   310  	b, err := json.Marshal(errs)
   311  	if err != nil {
   312  		panic(err)
   313  	}
   314  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   315  }
   316  
   317  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   318  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   319  	if err != nil {
   320  		panic(err)
   321  	}
   322  
   323  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   324  }
   325  
   326  func (c *wsConnection) readOp() *operationMessage {
   327  	_, r, err := c.conn.NextReader()
   328  	if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
   329  		return nil
   330  	} else if err != nil {
   331  		c.sendConnectionError("invalid json: %T %s", err, err.Error())
   332  		return nil
   333  	}
   334  	message := operationMessage{}
   335  	if err := jsonDecode(r, &message); err != nil {
   336  		c.sendConnectionError("invalid json")
   337  		return nil
   338  	}
   339  
   340  	return &message
   341  }
   342  
   343  func (c *wsConnection) close(closeCode int, message string) {
   344  	c.mu.Lock()
   345  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   346  	for _, closer := range c.active {
   347  		closer()
   348  	}
   349  	c.mu.Unlock()
   350  	_ = c.conn.Close()
   351  }