github.com/geneva/gqlgen@v0.17.7-0.20230801155730-7b9317164836/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"net/http"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/geneva/gqlgen/graphql"
    16  	"github.com/geneva/gqlgen/graphql/errcode"
    17  	"github.com/gorilla/websocket"
    18  	"github.com/vektah/gqlparser/v2/gqlerror"
    19  )
    20  
    21  type (
    22  	Websocket struct {
    23  		Upgrader              websocket.Upgrader
    24  		InitFunc              WebsocketInitFunc
    25  		InitTimeout           time.Duration
    26  		ErrorFunc             WebsocketErrorFunc
    27  		CloseFunc             WebsocketCloseFunc
    28  		KeepAlivePingInterval time.Duration
    29  		PingPongInterval      time.Duration
    30  
    31  		didInjectSubprotocols bool
    32  	}
    33  	wsConnection struct {
    34  		Websocket
    35  		ctx             context.Context
    36  		conn            *websocket.Conn
    37  		me              messageExchanger
    38  		active          map[string]context.CancelFunc
    39  		mu              sync.Mutex
    40  		keepAliveTicker *time.Ticker
    41  		pingPongTicker  *time.Ticker
    42  		exec            graphql.GraphExecutor
    43  
    44  		initPayload InitPayload
    45  	}
    46  
    47  	WebsocketInitFunc  func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    48  	WebsocketErrorFunc func(ctx context.Context, err error)
    49  
    50  	// Callback called when websocket is closed.
    51  	WebsocketCloseFunc func(ctx context.Context, closeCode int)
    52  )
    53  
    54  var errReadTimeout = errors.New("read timeout")
    55  
    56  type WebsocketError struct {
    57  	Err error
    58  
    59  	// IsReadError flags whether the error occurred on read or write to the websocket
    60  	IsReadError bool
    61  }
    62  
    63  func (e WebsocketError) Error() string {
    64  	if e.IsReadError {
    65  		return fmt.Sprintf("websocket read: %v", e.Err)
    66  	}
    67  	return fmt.Sprintf("websocket write: %v", e.Err)
    68  }
    69  
    70  var (
    71  	_ graphql.Transport = Websocket{}
    72  	_ error             = WebsocketError{}
    73  )
    74  
    75  func (t Websocket) Supports(r *http.Request) bool {
    76  	return r.Header.Get("Upgrade") != ""
    77  }
    78  
    79  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    80  	t.injectGraphQLWSSubprotocols()
    81  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
    82  	if err != nil {
    83  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    84  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    85  		return
    86  	}
    87  
    88  	var me messageExchanger
    89  	switch ws.Subprotocol() {
    90  	default:
    91  		msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
    92  		ws.WriteMessage(websocket.CloseMessage, msg)
    93  		return
    94  	case graphqlwsSubprotocol, "":
    95  		// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
    96  		// "graphql-ws" by default
    97  		me = graphqlwsMessageExchanger{c: ws}
    98  	case graphqltransportwsSubprotocol:
    99  		me = graphqltransportwsMessageExchanger{c: ws}
   100  	}
   101  
   102  	conn := wsConnection{
   103  		active:    map[string]context.CancelFunc{},
   104  		conn:      ws,
   105  		ctx:       r.Context(),
   106  		exec:      exec,
   107  		me:        me,
   108  		Websocket: t,
   109  	}
   110  
   111  	if !conn.init() {
   112  		return
   113  	}
   114  
   115  	conn.run()
   116  }
   117  
   118  func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
   119  	if c.ErrorFunc != nil && err != nil {
   120  		c.ErrorFunc(c.ctx, WebsocketError{
   121  			Err:         err,
   122  			IsReadError: isReadError,
   123  		})
   124  	}
   125  }
   126  
   127  func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
   128  	messages, errs := make(chan message, 1), make(chan error, 1)
   129  
   130  	go func() {
   131  		if m, err := c.me.NextMessage(); err != nil {
   132  			errs <- err
   133  		} else {
   134  			messages <- m
   135  		}
   136  	}()
   137  
   138  	select {
   139  	case m := <-messages:
   140  		return m, nil
   141  	case err := <-errs:
   142  		return message{}, err
   143  	case <-time.After(timeout):
   144  		return message{}, errReadTimeout
   145  	}
   146  }
   147  
   148  func (c *wsConnection) init() bool {
   149  	var m message
   150  	var err error
   151  
   152  	if c.InitTimeout != 0 {
   153  		m, err = c.nextMessageWithTimeout(c.InitTimeout)
   154  	} else {
   155  		m, err = c.me.NextMessage()
   156  	}
   157  
   158  	if err != nil {
   159  		if err == errReadTimeout {
   160  			c.close(websocket.CloseProtocolError, "connection initialisation timeout")
   161  			return false
   162  		}
   163  
   164  		if err == errInvalidMsg {
   165  			c.sendConnectionError("invalid json")
   166  		}
   167  
   168  		c.close(websocket.CloseProtocolError, "decoding error")
   169  		return false
   170  	}
   171  
   172  	switch m.t {
   173  	case initMessageType:
   174  		if len(m.payload) > 0 {
   175  			c.initPayload = make(InitPayload)
   176  			err := json.Unmarshal(m.payload, &c.initPayload)
   177  			if err != nil {
   178  				return false
   179  			}
   180  		}
   181  
   182  		if c.InitFunc != nil {
   183  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   184  			if err != nil {
   185  				c.sendConnectionError(err.Error())
   186  				c.close(websocket.CloseNormalClosure, "terminated")
   187  				return false
   188  			}
   189  			c.ctx = ctx
   190  		}
   191  
   192  		c.write(&message{t: connectionAckMessageType})
   193  		c.write(&message{t: keepAliveMessageType})
   194  	case connectionCloseMessageType:
   195  		c.close(websocket.CloseNormalClosure, "terminated")
   196  		return false
   197  	default:
   198  		c.sendConnectionError("unexpected message %s", m.t)
   199  		c.close(websocket.CloseProtocolError, "unexpected message")
   200  		return false
   201  	}
   202  
   203  	return true
   204  }
   205  
   206  func (c *wsConnection) write(msg *message) {
   207  	c.mu.Lock()
   208  	c.handlePossibleError(c.me.Send(msg), false)
   209  	c.mu.Unlock()
   210  }
   211  
   212  func (c *wsConnection) run() {
   213  	// We create a cancellation that will shutdown the keep-alive when we leave
   214  	// this function.
   215  	ctx, cancel := context.WithCancel(c.ctx)
   216  	defer func() {
   217  		cancel()
   218  		c.close(websocket.CloseAbnormalClosure, "unexpected closure")
   219  	}()
   220  
   221  	// If we're running in graphql-ws mode, create a timer that will trigger a
   222  	// keep alive message every interval
   223  	if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
   224  		c.mu.Lock()
   225  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   226  		c.mu.Unlock()
   227  
   228  		go c.keepAlive(ctx)
   229  	}
   230  
   231  	// If we're running in graphql-transport-ws mode, create a timer that will
   232  	// trigger a ping message every interval
   233  	if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
   234  		c.mu.Lock()
   235  		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
   236  		c.mu.Unlock()
   237  
   238  		// Note: when the connection is closed by this deadline, the client
   239  		// will receive an "invalid close code"
   240  		c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   241  		go c.ping(ctx)
   242  	}
   243  
   244  	// Close the connection when the context is cancelled.
   245  	// Will optionally send a "close reason" that is retrieved from the context.
   246  	go c.closeOnCancel(ctx)
   247  
   248  	for {
   249  		start := graphql.Now()
   250  		m, err := c.me.NextMessage()
   251  		if err != nil {
   252  			// If the connection got closed by us, don't report the error
   253  			if !errors.Is(err, net.ErrClosed) {
   254  				c.handlePossibleError(err, true)
   255  			}
   256  			return
   257  		}
   258  
   259  		switch m.t {
   260  		case startMessageType:
   261  			c.subscribe(start, &m)
   262  		case stopMessageType:
   263  			c.mu.Lock()
   264  			closer := c.active[m.id]
   265  			c.mu.Unlock()
   266  			if closer != nil {
   267  				closer()
   268  			}
   269  		case connectionCloseMessageType:
   270  			c.close(websocket.CloseNormalClosure, "terminated")
   271  			return
   272  		case pingMessageType:
   273  			c.write(&message{t: pongMessageType, payload: m.payload})
   274  		case pongMessageType:
   275  			c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   276  		default:
   277  			c.sendConnectionError("unexpected message %s", m.t)
   278  			c.close(websocket.CloseProtocolError, "unexpected message")
   279  			return
   280  		}
   281  	}
   282  }
   283  
   284  func (c *wsConnection) keepAlive(ctx context.Context) {
   285  	for {
   286  		select {
   287  		case <-ctx.Done():
   288  			c.keepAliveTicker.Stop()
   289  			return
   290  		case <-c.keepAliveTicker.C:
   291  			c.write(&message{t: keepAliveMessageType})
   292  		}
   293  	}
   294  }
   295  
   296  func (c *wsConnection) ping(ctx context.Context) {
   297  	for {
   298  		select {
   299  		case <-ctx.Done():
   300  			c.pingPongTicker.Stop()
   301  			return
   302  		case <-c.pingPongTicker.C:
   303  			c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
   304  		}
   305  	}
   306  }
   307  
   308  func (c *wsConnection) closeOnCancel(ctx context.Context) {
   309  	<-ctx.Done()
   310  
   311  	if r := closeReasonForContext(ctx); r != "" {
   312  		c.sendConnectionError(r)
   313  	}
   314  	c.close(websocket.CloseNormalClosure, "terminated")
   315  }
   316  
   317  func (c *wsConnection) subscribe(start time.Time, msg *message) {
   318  	ctx := graphql.StartOperationTrace(c.ctx)
   319  	var params *graphql.RawParams
   320  	if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
   321  		c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
   322  		c.complete(msg.id)
   323  		return
   324  	}
   325  
   326  	params.ReadTime = graphql.TraceTiming{
   327  		Start: start,
   328  		End:   graphql.Now(),
   329  	}
   330  
   331  	rc, err := c.exec.CreateOperationContext(ctx, params)
   332  	if err != nil {
   333  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   334  		switch errcode.GetErrorKind(err) {
   335  		case errcode.KindProtocol:
   336  			c.sendError(msg.id, resp.Errors...)
   337  		default:
   338  			c.sendResponse(msg.id, &graphql.Response{Errors: err})
   339  		}
   340  
   341  		c.complete(msg.id)
   342  		return
   343  	}
   344  
   345  	ctx = graphql.WithOperationContext(ctx, rc)
   346  
   347  	if c.initPayload != nil {
   348  		ctx = withInitPayload(ctx, c.initPayload)
   349  	}
   350  
   351  	ctx, cancel := context.WithCancel(ctx)
   352  	c.mu.Lock()
   353  	c.active[msg.id] = cancel
   354  	c.mu.Unlock()
   355  
   356  	go func() {
   357  		ctx = withSubscriptionErrorContext(ctx)
   358  		defer func() {
   359  			if r := recover(); r != nil {
   360  				err := rc.Recover(ctx, r)
   361  				var gqlerr *gqlerror.Error
   362  				if !errors.As(err, &gqlerr) {
   363  					gqlerr = &gqlerror.Error{}
   364  					if err != nil {
   365  						gqlerr.Message = err.Error()
   366  					}
   367  				}
   368  				c.sendError(msg.id, gqlerr)
   369  			}
   370  			if errs := getSubscriptionError(ctx); len(errs) != 0 {
   371  				c.sendError(msg.id, errs...)
   372  			} else {
   373  				c.complete(msg.id)
   374  			}
   375  			c.mu.Lock()
   376  			delete(c.active, msg.id)
   377  			c.mu.Unlock()
   378  			cancel()
   379  		}()
   380  
   381  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   382  		for {
   383  			response := responses(ctx)
   384  			if response == nil {
   385  				break
   386  			}
   387  
   388  			c.sendResponse(msg.id, response)
   389  		}
   390  
   391  		// complete and context cancel comes from the defer
   392  	}()
   393  }
   394  
   395  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   396  	b, err := json.Marshal(response)
   397  	if err != nil {
   398  		panic(err)
   399  	}
   400  	c.write(&message{
   401  		payload: b,
   402  		id:      id,
   403  		t:       dataMessageType,
   404  	})
   405  }
   406  
   407  func (c *wsConnection) complete(id string) {
   408  	c.write(&message{id: id, t: completeMessageType})
   409  }
   410  
   411  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   412  	errs := make([]error, len(errors))
   413  	for i, err := range errors {
   414  		errs[i] = err
   415  	}
   416  	b, err := json.Marshal(errs)
   417  	if err != nil {
   418  		panic(err)
   419  	}
   420  	c.write(&message{t: errorMessageType, id: id, payload: b})
   421  }
   422  
   423  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   424  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   425  	if err != nil {
   426  		panic(err)
   427  	}
   428  
   429  	c.write(&message{t: connectionErrorMessageType, payload: b})
   430  }
   431  
   432  func (c *wsConnection) close(closeCode int, message string) {
   433  	c.mu.Lock()
   434  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   435  	for _, closer := range c.active {
   436  		closer()
   437  	}
   438  	c.mu.Unlock()
   439  	_ = c.conn.Close()
   440  
   441  	if c.CloseFunc != nil {
   442  		c.CloseFunc(c.ctx, closeCode)
   443  	}
   444  }