github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/gorilla/websocket"
    16  	"github.com/mstephano/gqlgen-schemagen/graphql"
    17  	"github.com/mstephano/gqlgen-schemagen/graphql/errcode"
    18  	"github.com/vektah/gqlparser/v2/gqlerror"
    19  )
    20  
    21  type (
    22  	Websocket struct {
    23  		Upgrader              websocket.Upgrader
    24  		InitFunc              WebsocketInitFunc
    25  		InitTimeout           time.Duration
    26  		ErrorFunc             WebsocketErrorFunc
    27  		KeepAlivePingInterval time.Duration
    28  		PingPongInterval      time.Duration
    29  
    30  		didInjectSubprotocols bool
    31  	}
    32  	wsConnection struct {
    33  		Websocket
    34  		ctx             context.Context
    35  		conn            *websocket.Conn
    36  		me              messageExchanger
    37  		active          map[string]context.CancelFunc
    38  		mu              sync.Mutex
    39  		keepAliveTicker *time.Ticker
    40  		pingPongTicker  *time.Ticker
    41  		exec            graphql.GraphExecutor
    42  
    43  		initPayload InitPayload
    44  	}
    45  
    46  	WebsocketInitFunc  func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    47  	WebsocketErrorFunc func(ctx context.Context, err error)
    48  )
    49  
    50  var errReadTimeout = errors.New("read timeout")
    51  
    52  type WebsocketError struct {
    53  	Err error
    54  
    55  	// IsReadError flags whether the error occurred on read or write to the websocket
    56  	IsReadError bool
    57  }
    58  
    59  func (e WebsocketError) Error() string {
    60  	if e.IsReadError {
    61  		return fmt.Sprintf("websocket read: %v", e.Err)
    62  	}
    63  	return fmt.Sprintf("websocket write: %v", e.Err)
    64  }
    65  
    66  var (
    67  	_ graphql.Transport = Websocket{}
    68  	_ error             = WebsocketError{}
    69  )
    70  
    71  func (t Websocket) Supports(r *http.Request) bool {
    72  	return r.Header.Get("Upgrade") != ""
    73  }
    74  
    75  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    76  	t.injectGraphQLWSSubprotocols()
    77  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
    78  	if err != nil {
    79  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    80  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    81  		return
    82  	}
    83  
    84  	var me messageExchanger
    85  	switch ws.Subprotocol() {
    86  	default:
    87  		msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
    88  		ws.WriteMessage(websocket.CloseMessage, msg)
    89  		return
    90  	case graphqlwsSubprotocol, "":
    91  		// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
    92  		// "graphql-ws" by default
    93  		me = graphqlwsMessageExchanger{c: ws}
    94  	case graphqltransportwsSubprotocol:
    95  		me = graphqltransportwsMessageExchanger{c: ws}
    96  	}
    97  
    98  	conn := wsConnection{
    99  		active:    map[string]context.CancelFunc{},
   100  		conn:      ws,
   101  		ctx:       r.Context(),
   102  		exec:      exec,
   103  		me:        me,
   104  		Websocket: t,
   105  	}
   106  
   107  	if !conn.init() {
   108  		return
   109  	}
   110  
   111  	conn.run()
   112  }
   113  
   114  func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
   115  	if c.ErrorFunc != nil && err != nil {
   116  		c.ErrorFunc(c.ctx, WebsocketError{
   117  			Err:         err,
   118  			IsReadError: isReadError,
   119  		})
   120  	}
   121  }
   122  
   123  func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
   124  	messages, errs := make(chan message, 1), make(chan error, 1)
   125  
   126  	go func() {
   127  		if m, err := c.me.NextMessage(); err != nil {
   128  			errs <- err
   129  		} else {
   130  			messages <- m
   131  		}
   132  	}()
   133  
   134  	select {
   135  	case m := <-messages:
   136  		return m, nil
   137  	case err := <-errs:
   138  		return message{}, err
   139  	case <-time.After(timeout):
   140  		return message{}, errReadTimeout
   141  	}
   142  }
   143  
   144  func (c *wsConnection) init() bool {
   145  	var m message
   146  	var err error
   147  
   148  	if c.InitTimeout != 0 {
   149  		m, err = c.nextMessageWithTimeout(c.InitTimeout)
   150  	} else {
   151  		m, err = c.me.NextMessage()
   152  	}
   153  
   154  	if err != nil {
   155  		if err == errReadTimeout {
   156  			c.close(websocket.CloseProtocolError, "connection initialisation timeout")
   157  			return false
   158  		}
   159  
   160  		if err == errInvalidMsg {
   161  			c.sendConnectionError("invalid json")
   162  		}
   163  
   164  		c.close(websocket.CloseProtocolError, "decoding error")
   165  		return false
   166  	}
   167  
   168  	switch m.t {
   169  	case initMessageType:
   170  		if len(m.payload) > 0 {
   171  			c.initPayload = make(InitPayload)
   172  			err := json.Unmarshal(m.payload, &c.initPayload)
   173  			if err != nil {
   174  				return false
   175  			}
   176  		}
   177  
   178  		if c.InitFunc != nil {
   179  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   180  			if err != nil {
   181  				c.sendConnectionError(err.Error())
   182  				c.close(websocket.CloseNormalClosure, "terminated")
   183  				return false
   184  			}
   185  			c.ctx = ctx
   186  		}
   187  
   188  		c.write(&message{t: connectionAckMessageType})
   189  		c.write(&message{t: keepAliveMessageType})
   190  	case connectionCloseMessageType:
   191  		c.close(websocket.CloseNormalClosure, "terminated")
   192  		return false
   193  	default:
   194  		c.sendConnectionError("unexpected message %s", m.t)
   195  		c.close(websocket.CloseProtocolError, "unexpected message")
   196  		return false
   197  	}
   198  
   199  	return true
   200  }
   201  
   202  func (c *wsConnection) write(msg *message) {
   203  	c.mu.Lock()
   204  	c.handlePossibleError(c.me.Send(msg), false)
   205  	c.mu.Unlock()
   206  }
   207  
   208  func (c *wsConnection) run() {
   209  	// We create a cancellation that will shutdown the keep-alive when we leave
   210  	// this function.
   211  	ctx, cancel := context.WithCancel(c.ctx)
   212  	defer func() {
   213  		cancel()
   214  		c.close(websocket.CloseAbnormalClosure, "unexpected closure")
   215  	}()
   216  
   217  	// If we're running in graphql-ws mode, create a timer that will trigger a
   218  	// keep alive message every interval
   219  	if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
   220  		c.mu.Lock()
   221  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   222  		c.mu.Unlock()
   223  
   224  		go c.keepAlive(ctx)
   225  	}
   226  
   227  	// If we're running in graphql-transport-ws mode, create a timer that will
   228  	// trigger a ping message every interval
   229  	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
   230  		c.mu.Lock()
   231  		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
   232  		c.mu.Unlock()
   233  
   234  		// Note: when the connection is closed by this deadline, the client
   235  		// will receive an "invalid close code"
   236  		c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   237  		go c.ping(ctx)
   238  	}
   239  
   240  	// Close the connection when the context is cancelled.
   241  	// Will optionally send a "close reason" that is retrieved from the context.
   242  	go c.closeOnCancel(ctx)
   243  
   244  	for {
   245  		start := graphql.Now()
   246  		m, err := c.me.NextMessage()
   247  		if err != nil {
   248  			// If the connection got closed by us, don't report the error
   249  			if !errors.Is(err, net.ErrClosed) {
   250  				c.handlePossibleError(err, true)
   251  			}
   252  			return
   253  		}
   254  
   255  		switch m.t {
   256  		case startMessageType:
   257  			c.subscribe(start, &m)
   258  		case stopMessageType:
   259  			c.mu.Lock()
   260  			closer := c.active[m.id]
   261  			c.mu.Unlock()
   262  			if closer != nil {
   263  				closer()
   264  			}
   265  		case connectionCloseMessageType:
   266  			c.close(websocket.CloseNormalClosure, "terminated")
   267  			return
   268  		case pingMessageType:
   269  			c.write(&message{t: pongMessageType, payload: m.payload})
   270  		case pongMessageType:
   271  			c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   272  		default:
   273  			c.sendConnectionError("unexpected message %s", m.t)
   274  			c.close(websocket.CloseProtocolError, "unexpected message")
   275  			return
   276  		}
   277  	}
   278  }
   279  
   280  func (c *wsConnection) keepAlive(ctx context.Context) {
   281  	for {
   282  		select {
   283  		case <-ctx.Done():
   284  			c.keepAliveTicker.Stop()
   285  			return
   286  		case <-c.keepAliveTicker.C:
   287  			c.write(&message{t: keepAliveMessageType})
   288  		}
   289  	}
   290  }
   291  
   292  func (c *wsConnection) ping(ctx context.Context) {
   293  	for {
   294  		select {
   295  		case <-ctx.Done():
   296  			c.pingPongTicker.Stop()
   297  			return
   298  		case <-c.pingPongTicker.C:
   299  			c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
   300  		}
   301  	}
   302  }
   303  
   304  func (c *wsConnection) closeOnCancel(ctx context.Context) {
   305  	<-ctx.Done()
   306  
   307  	if r := closeReasonForContext(ctx); r != "" {
   308  		c.sendConnectionError(r)
   309  	}
   310  	c.close(websocket.CloseNormalClosure, "terminated")
   311  }
   312  
   313  func (c *wsConnection) subscribe(start time.Time, msg *message) {
   314  	ctx := graphql.StartOperationTrace(c.ctx)
   315  	var params *graphql.RawParams
   316  	if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
   317  		c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
   318  		c.complete(msg.id)
   319  		return
   320  	}
   321  
   322  	params.ReadTime = graphql.TraceTiming{
   323  		Start: start,
   324  		End:   graphql.Now(),
   325  	}
   326  
   327  	rc, err := c.exec.CreateOperationContext(ctx, params)
   328  	if err != nil {
   329  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   330  		switch errcode.GetErrorKind(err) {
   331  		case errcode.KindProtocol:
   332  			c.sendError(msg.id, resp.Errors...)
   333  		default:
   334  			c.sendResponse(msg.id, &graphql.Response{Errors: err})
   335  		}
   336  
   337  		c.complete(msg.id)
   338  		return
   339  	}
   340  
   341  	ctx = graphql.WithOperationContext(ctx, rc)
   342  
   343  	if c.initPayload != nil {
   344  		ctx = withInitPayload(ctx, c.initPayload)
   345  	}
   346  
   347  	ctx, cancel := context.WithCancel(ctx)
   348  	c.mu.Lock()
   349  	c.active[msg.id] = cancel
   350  	c.mu.Unlock()
   351  
   352  	go func() {
   353  		defer func() {
   354  			if r := recover(); r != nil {
   355  				err := rc.Recover(ctx, r)
   356  				var gqlerr *gqlerror.Error
   357  				if !errors.As(err, &gqlerr) {
   358  					gqlerr = &gqlerror.Error{}
   359  					if err != nil {
   360  						gqlerr.Message = err.Error()
   361  					}
   362  				}
   363  				c.sendError(msg.id, gqlerr)
   364  			}
   365  			c.complete(msg.id)
   366  			c.mu.Lock()
   367  			delete(c.active, msg.id)
   368  			c.mu.Unlock()
   369  			cancel()
   370  		}()
   371  
   372  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   373  		for {
   374  			response := responses(ctx)
   375  			if response == nil {
   376  				break
   377  			}
   378  
   379  			c.sendResponse(msg.id, response)
   380  		}
   381  
   382  		// complete and context cancel comes from the defer
   383  	}()
   384  }
   385  
   386  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   387  	b, err := json.Marshal(response)
   388  	if err != nil {
   389  		panic(err)
   390  	}
   391  	c.write(&message{
   392  		payload: b,
   393  		id:      id,
   394  		t:       dataMessageType,
   395  	})
   396  }
   397  
   398  func (c *wsConnection) complete(id string) {
   399  	c.write(&message{id: id, t: completeMessageType})
   400  }
   401  
   402  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   403  	errs := make([]error, len(errors))
   404  	for i, err := range errors {
   405  		errs[i] = err
   406  	}
   407  	b, err := json.Marshal(errs)
   408  	if err != nil {
   409  		panic(err)
   410  	}
   411  	c.write(&message{t: errorMessageType, id: id, payload: b})
   412  }
   413  
   414  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   415  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   416  	if err != nil {
   417  		panic(err)
   418  	}
   419  
   420  	c.write(&message{t: connectionErrorMessageType, payload: b})
   421  }
   422  
   423  func (c *wsConnection) close(closeCode int, message string) {
   424  	c.mu.Lock()
   425  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   426  	for _, closer := range c.active {
   427  		closer()
   428  	}
   429  	c.mu.Unlock()
   430  	_ = c.conn.Close()
   431  }