github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"net/http"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/apipluspower/gqlgen/graphql"
    15  	"github.com/apipluspower/gqlgen/graphql/errcode"
    16  	"github.com/gorilla/websocket"
    17  	"github.com/vektah/gqlparser/v2/gqlerror"
    18  )
    19  
    20  type (
    21  	Websocket struct {
    22  		Upgrader              websocket.Upgrader
    23  		InitFunc              WebsocketInitFunc
    24  		KeepAlivePingInterval time.Duration
    25  		PingPongInterval      time.Duration
    26  
    27  		didInjectSubprotocols bool
    28  	}
    29  	wsConnection struct {
    30  		Websocket
    31  		ctx             context.Context
    32  		conn            *websocket.Conn
    33  		me              messageExchanger
    34  		active          map[string]context.CancelFunc
    35  		mu              sync.Mutex
    36  		keepAliveTicker *time.Ticker
    37  		pingPongTicker  *time.Ticker
    38  		exec            graphql.GraphExecutor
    39  
    40  		initPayload InitPayload
    41  	}
    42  
    43  	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    44  )
    45  
    46  var _ graphql.Transport = Websocket{}
    47  
    48  func (t Websocket) Supports(r *http.Request) bool {
    49  	return r.Header.Get("Upgrade") != ""
    50  }
    51  
    52  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    53  	t.injectGraphQLWSSubprotocols()
    54  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
    55  	if err != nil {
    56  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    57  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    58  		return
    59  	}
    60  
    61  	var me messageExchanger
    62  	switch ws.Subprotocol() {
    63  	default:
    64  		msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
    65  		ws.WriteMessage(websocket.CloseMessage, msg)
    66  		return
    67  	case graphqlwsSubprotocol, "":
    68  		// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
    69  		// "graphql-ws" by default
    70  		me = graphqlwsMessageExchanger{c: ws}
    71  	case graphqltransportwsSubprotocol:
    72  		me = graphqltransportwsMessageExchanger{c: ws}
    73  	}
    74  
    75  	conn := wsConnection{
    76  		active:    map[string]context.CancelFunc{},
    77  		conn:      ws,
    78  		ctx:       r.Context(),
    79  		exec:      exec,
    80  		me:        me,
    81  		Websocket: t,
    82  	}
    83  
    84  	if !conn.init() {
    85  		return
    86  	}
    87  
    88  	conn.run()
    89  }
    90  
    91  func (c *wsConnection) init() bool {
    92  	m, err := c.me.NextMessage()
    93  	if err != nil {
    94  		if err == errInvalidMsg {
    95  			c.sendConnectionError("invalid json")
    96  		}
    97  
    98  		c.close(websocket.CloseProtocolError, "decoding error")
    99  		return false
   100  	}
   101  
   102  	switch m.t {
   103  	case initMessageType:
   104  		if len(m.payload) > 0 {
   105  			c.initPayload = make(InitPayload)
   106  			err := json.Unmarshal(m.payload, &c.initPayload)
   107  			if err != nil {
   108  				return false
   109  			}
   110  		}
   111  
   112  		if c.InitFunc != nil {
   113  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   114  			if err != nil {
   115  				c.sendConnectionError(err.Error())
   116  				c.close(websocket.CloseNormalClosure, "terminated")
   117  				return false
   118  			}
   119  			c.ctx = ctx
   120  		}
   121  
   122  		c.write(&message{t: connectionAckMessageType})
   123  		c.write(&message{t: keepAliveMessageType})
   124  	case connectionCloseMessageType:
   125  		c.close(websocket.CloseNormalClosure, "terminated")
   126  		return false
   127  	default:
   128  		c.sendConnectionError("unexpected message %s", m.t)
   129  		c.close(websocket.CloseProtocolError, "unexpected message")
   130  		return false
   131  	}
   132  
   133  	return true
   134  }
   135  
   136  func (c *wsConnection) write(msg *message) {
   137  	c.mu.Lock()
   138  	// TODO: missing error handling here, err from previous implementation
   139  	// was ignored
   140  	_ = c.me.Send(msg)
   141  	c.mu.Unlock()
   142  }
   143  
   144  func (c *wsConnection) run() {
   145  	// We create a cancellation that will shutdown the keep-alive when we leave
   146  	// this function.
   147  	ctx, cancel := context.WithCancel(c.ctx)
   148  	defer func() {
   149  		cancel()
   150  		c.close(websocket.CloseAbnormalClosure, "unexpected closure")
   151  	}()
   152  
   153  	// Create a timer that will fire every interval to keep the connection alive.
   154  	if c.KeepAlivePingInterval != 0 {
   155  		c.mu.Lock()
   156  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   157  		c.mu.Unlock()
   158  
   159  		go c.keepAlive(ctx)
   160  	}
   161  
   162  	// Create a timer that will fire every interval a ping message that should
   163  	// receive a pong (SetPongHandler in init() function)
   164  	if c.PingPongInterval != 0 {
   165  		c.mu.Lock()
   166  		c.pingPongTicker = time.NewTicker(c.PingPongInterval)
   167  		c.mu.Unlock()
   168  
   169  		c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   170  		go c.ping(ctx)
   171  	}
   172  
   173  	// Close the connection when the context is cancelled.
   174  	// Will optionally send a "close reason" that is retrieved from the context.
   175  	go c.closeOnCancel(ctx)
   176  
   177  	for {
   178  		start := graphql.Now()
   179  		m, err := c.me.NextMessage()
   180  		if err != nil {
   181  			// TODO: better error handling here
   182  			return
   183  		}
   184  
   185  		switch m.t {
   186  		case startMessageType:
   187  			c.subscribe(start, &m)
   188  		case stopMessageType:
   189  			c.mu.Lock()
   190  			closer := c.active[m.id]
   191  			c.mu.Unlock()
   192  			if closer != nil {
   193  				closer()
   194  			}
   195  		case connectionCloseMessageType:
   196  			c.close(websocket.CloseNormalClosure, "terminated")
   197  			return
   198  		case pingMesageType:
   199  			c.write(&message{t: pongMessageType, payload: m.payload})
   200  		case pongMessageType:
   201  			c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
   202  		default:
   203  			c.sendConnectionError("unexpected message %s", m.t)
   204  			c.close(websocket.CloseProtocolError, "unexpected message")
   205  			return
   206  		}
   207  	}
   208  }
   209  
   210  func (c *wsConnection) keepAlive(ctx context.Context) {
   211  	for {
   212  		select {
   213  		case <-ctx.Done():
   214  			c.keepAliveTicker.Stop()
   215  			return
   216  		case <-c.keepAliveTicker.C:
   217  			c.write(&message{t: keepAliveMessageType})
   218  		}
   219  	}
   220  }
   221  
   222  func (c *wsConnection) ping(ctx context.Context) {
   223  	for {
   224  		select {
   225  		case <-ctx.Done():
   226  			c.pingPongTicker.Stop()
   227  			return
   228  		case <-c.pingPongTicker.C:
   229  			c.write(&message{t: pingMesageType, payload: json.RawMessage{}})
   230  		}
   231  	}
   232  }
   233  
   234  func (c *wsConnection) closeOnCancel(ctx context.Context) {
   235  	<-ctx.Done()
   236  
   237  	if r := closeReasonForContext(ctx); r != "" {
   238  		c.sendConnectionError(r)
   239  	}
   240  	c.close(websocket.CloseNormalClosure, "terminated")
   241  }
   242  
   243  func (c *wsConnection) subscribe(start time.Time, msg *message) {
   244  	ctx := graphql.StartOperationTrace(c.ctx)
   245  	var params *graphql.RawParams
   246  	if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
   247  		c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
   248  		c.complete(msg.id)
   249  		return
   250  	}
   251  
   252  	params.ReadTime = graphql.TraceTiming{
   253  		Start: start,
   254  		End:   graphql.Now(),
   255  	}
   256  
   257  	rc, err := c.exec.CreateOperationContext(ctx, params)
   258  	if err != nil {
   259  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   260  		switch errcode.GetErrorKind(err) {
   261  		case errcode.KindProtocol:
   262  			c.sendError(msg.id, resp.Errors...)
   263  		default:
   264  			c.sendResponse(msg.id, &graphql.Response{Errors: err})
   265  		}
   266  
   267  		c.complete(msg.id)
   268  		return
   269  	}
   270  
   271  	ctx = graphql.WithOperationContext(ctx, rc)
   272  
   273  	if c.initPayload != nil {
   274  		ctx = withInitPayload(ctx, c.initPayload)
   275  	}
   276  
   277  	ctx, cancel := context.WithCancel(ctx)
   278  	c.mu.Lock()
   279  	c.active[msg.id] = cancel
   280  	c.mu.Unlock()
   281  
   282  	go func() {
   283  		defer func() {
   284  			if r := recover(); r != nil {
   285  				err := rc.Recover(ctx, r)
   286  				var gqlerr *gqlerror.Error
   287  				if !errors.As(err, &gqlerr) {
   288  					gqlerr = &gqlerror.Error{}
   289  					if err != nil {
   290  						gqlerr.Message = err.Error()
   291  					}
   292  				}
   293  				c.sendError(msg.id, gqlerr)
   294  			}
   295  			c.complete(msg.id)
   296  			c.mu.Lock()
   297  			delete(c.active, msg.id)
   298  			c.mu.Unlock()
   299  			cancel()
   300  		}()
   301  
   302  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   303  		for {
   304  			response := responses(ctx)
   305  			if response == nil {
   306  				break
   307  			}
   308  
   309  			c.sendResponse(msg.id, response)
   310  		}
   311  		c.complete(msg.id)
   312  
   313  		c.mu.Lock()
   314  		delete(c.active, msg.id)
   315  		c.mu.Unlock()
   316  		cancel()
   317  	}()
   318  }
   319  
   320  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   321  	b, err := json.Marshal(response)
   322  	if err != nil {
   323  		panic(err)
   324  	}
   325  	c.write(&message{
   326  		payload: b,
   327  		id:      id,
   328  		t:       dataMessageType,
   329  	})
   330  }
   331  
   332  func (c *wsConnection) complete(id string) {
   333  	c.write(&message{id: id, t: completeMessageType})
   334  }
   335  
   336  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   337  	errs := make([]error, len(errors))
   338  	for i, err := range errors {
   339  		errs[i] = err
   340  	}
   341  	b, err := json.Marshal(errs)
   342  	if err != nil {
   343  		panic(err)
   344  	}
   345  	c.write(&message{t: errorMessageType, id: id, payload: b})
   346  }
   347  
   348  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   349  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   350  	if err != nil {
   351  		panic(err)
   352  	}
   353  
   354  	c.write(&message{t: connectionErrorMessageType, payload: b})
   355  }
   356  
   357  func (c *wsConnection) close(closeCode int, message string) {
   358  	c.mu.Lock()
   359  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   360  	for _, closer := range c.active {
   361  		closer()
   362  	}
   363  	c.mu.Unlock()
   364  	_ = c.conn.Close()
   365  }