git.sr.ht/~sircmpwn/gqlgen@v0.0.0-20200522192042-c84d29a1c940/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"git.sr.ht/~sircmpwn/gqlgen/graphql"
    14  	"git.sr.ht/~sircmpwn/gqlgen/graphql/errcode"
    15  	"github.com/gorilla/websocket"
    16  	"github.com/vektah/gqlparser/v2/gqlerror"
    17  )
    18  
    19  const (
    20  	connectionInitMsg      = "connection_init"      // Client -> Server
    21  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    22  	startMsg               = "start"                // Client -> Server
    23  	stopMsg                = "stop"                 // Client -> Server
    24  	connectionAckMsg       = "connection_ack"       // Server -> Client
    25  	connectionErrorMsg     = "connection_error"     // Server -> Client
    26  	dataMsg                = "data"                 // Server -> Client
    27  	errorMsg               = "error"                // Server -> Client
    28  	completeMsg            = "complete"             // Server -> Client
    29  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    30  )
    31  
    32  type (
    33  	Websocket struct {
    34  		Upgrader              websocket.Upgrader
    35  		InitFunc              WebsocketInitFunc
    36  		KeepAlivePingInterval time.Duration
    37  	}
    38  	wsConnection struct {
    39  		Websocket
    40  		ctx             context.Context
    41  		conn            *websocket.Conn
    42  		active          map[string]context.CancelFunc
    43  		mu              sync.Mutex
    44  		keepAliveTicker *time.Ticker
    45  		exec            graphql.GraphExecutor
    46  
    47  		initPayload InitPayload
    48  	}
    49  	operationMessage struct {
    50  		Payload json.RawMessage `json:"payload,omitempty"`
    51  		ID      string          `json:"id,omitempty"`
    52  		Type    string          `json:"type"`
    53  	}
    54  	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    55  )
    56  
    57  var _ graphql.Transport = Websocket{}
    58  
    59  func (t Websocket) Supports(r *http.Request) bool {
    60  	return r.Header.Get("Upgrade") != ""
    61  }
    62  
    63  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    64  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{
    65  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    66  	})
    67  	if err != nil {
    68  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    69  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    70  		return
    71  	}
    72  
    73  	conn := wsConnection{
    74  		active:    map[string]context.CancelFunc{},
    75  		conn:      ws,
    76  		ctx:       r.Context(),
    77  		exec:      exec,
    78  		Websocket: t,
    79  	}
    80  
    81  	if !conn.init() {
    82  		return
    83  	}
    84  
    85  	conn.run()
    86  }
    87  
    88  func (c *wsConnection) init() bool {
    89  	message := c.readOp()
    90  	if message == nil {
    91  		c.close(websocket.CloseProtocolError, "decoding error")
    92  		return false
    93  	}
    94  
    95  	switch message.Type {
    96  	case connectionInitMsg:
    97  		if len(message.Payload) > 0 {
    98  			c.initPayload = make(InitPayload)
    99  			err := json.Unmarshal(message.Payload, &c.initPayload)
   100  			if err != nil {
   101  				return false
   102  			}
   103  		}
   104  
   105  		if c.InitFunc != nil {
   106  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   107  			if err != nil {
   108  				c.sendConnectionError(err.Error())
   109  				c.close(websocket.CloseNormalClosure, "terminated")
   110  				return false
   111  			}
   112  			c.ctx = ctx
   113  		}
   114  
   115  		c.write(&operationMessage{Type: connectionAckMsg})
   116  		c.write(&operationMessage{Type: connectionKeepAliveMsg})
   117  	case connectionTerminateMsg:
   118  		c.close(websocket.CloseNormalClosure, "terminated")
   119  		return false
   120  	default:
   121  		c.sendConnectionError("unexpected message %s", message.Type)
   122  		c.close(websocket.CloseProtocolError, "unexpected message")
   123  		return false
   124  	}
   125  
   126  	return true
   127  }
   128  
   129  func (c *wsConnection) write(msg *operationMessage) {
   130  	c.mu.Lock()
   131  	c.conn.WriteJSON(msg)
   132  	c.mu.Unlock()
   133  }
   134  
   135  func (c *wsConnection) run() {
   136  	// We create a cancellation that will shutdown the keep-alive when we leave
   137  	// this function.
   138  	ctx, cancel := context.WithCancel(c.ctx)
   139  	defer cancel()
   140  
   141  	// Create a timer that will fire every interval to keep the connection alive.
   142  	if c.KeepAlivePingInterval != 0 {
   143  		c.mu.Lock()
   144  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   145  		c.mu.Unlock()
   146  
   147  		go c.keepAlive(ctx)
   148  	}
   149  
   150  	for {
   151  		start := graphql.Now()
   152  		message := c.readOp()
   153  		if message == nil {
   154  			return
   155  		}
   156  
   157  		switch message.Type {
   158  		case startMsg:
   159  			c.subscribe(start, message)
   160  		case stopMsg:
   161  			c.mu.Lock()
   162  			closer := c.active[message.ID]
   163  			c.mu.Unlock()
   164  			if closer != nil {
   165  				closer()
   166  			}
   167  		case connectionTerminateMsg:
   168  			c.close(websocket.CloseNormalClosure, "terminated")
   169  			return
   170  		default:
   171  			c.sendConnectionError("unexpected message %s", message.Type)
   172  			c.close(websocket.CloseProtocolError, "unexpected message")
   173  			return
   174  		}
   175  	}
   176  }
   177  
   178  func (c *wsConnection) keepAlive(ctx context.Context) {
   179  	for {
   180  		select {
   181  		case <-ctx.Done():
   182  			c.keepAliveTicker.Stop()
   183  			return
   184  		case <-c.keepAliveTicker.C:
   185  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   186  		}
   187  	}
   188  }
   189  
   190  func (c *wsConnection) subscribe(start time.Time, message *operationMessage) {
   191  	ctx := graphql.StartOperationTrace(c.ctx)
   192  	var params *graphql.RawParams
   193  	if err := jsonDecode(bytes.NewReader(message.Payload), &params); err != nil {
   194  		c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"})
   195  		c.complete(message.ID)
   196  		return
   197  	}
   198  
   199  	params.ReadTime = graphql.TraceTiming{
   200  		Start: start,
   201  		End:   graphql.Now(),
   202  	}
   203  
   204  	rc, err := c.exec.CreateOperationContext(ctx, params)
   205  	if err != nil {
   206  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   207  		switch errcode.GetErrorKind(err) {
   208  		case errcode.KindProtocol:
   209  			c.sendError(message.ID, resp.Errors...)
   210  		default:
   211  			c.sendResponse(message.ID, &graphql.Response{Errors: err})
   212  		}
   213  
   214  		c.complete(message.ID)
   215  		return
   216  	}
   217  
   218  	ctx = graphql.WithOperationContext(ctx, rc)
   219  
   220  	if c.initPayload != nil {
   221  		ctx = withInitPayload(ctx, c.initPayload)
   222  	}
   223  
   224  	ctx, cancel := context.WithCancel(ctx)
   225  	c.mu.Lock()
   226  	c.active[message.ID] = cancel
   227  	c.mu.Unlock()
   228  
   229  	go func() {
   230  		defer func() {
   231  			if r := recover(); r != nil {
   232  				userErr := rc.Recover(ctx, r)
   233  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   234  			}
   235  		}()
   236  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   237  		for {
   238  			response := responses(ctx)
   239  			if response == nil {
   240  				break
   241  			}
   242  
   243  			c.sendResponse(message.ID, response)
   244  		}
   245  		c.complete(message.ID)
   246  
   247  		c.mu.Lock()
   248  		delete(c.active, message.ID)
   249  		c.mu.Unlock()
   250  		cancel()
   251  	}()
   252  }
   253  
   254  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   255  	b, err := json.Marshal(response)
   256  	if err != nil {
   257  		panic(err)
   258  	}
   259  	c.write(&operationMessage{
   260  		Payload: b,
   261  		ID:      id,
   262  		Type:    dataMsg,
   263  	})
   264  }
   265  
   266  func (c *wsConnection) complete(id string) {
   267  	c.write(&operationMessage{ID: id, Type: completeMsg})
   268  }
   269  
   270  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   271  	errs := make([]error, len(errors))
   272  	for i, err := range errors {
   273  		errs[i] = err
   274  	}
   275  	b, err := json.Marshal(errs)
   276  	if err != nil {
   277  		panic(err)
   278  	}
   279  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   280  }
   281  
   282  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   283  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   284  	if err != nil {
   285  		panic(err)
   286  	}
   287  
   288  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   289  }
   290  
   291  func (c *wsConnection) readOp() *operationMessage {
   292  	_, r, err := c.conn.NextReader()
   293  	if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
   294  		return nil
   295  	} else if err != nil {
   296  		c.sendConnectionError("invalid json: %T %s", err, err.Error())
   297  		return nil
   298  	}
   299  	message := operationMessage{}
   300  	if err := jsonDecode(r, &message); err != nil {
   301  		c.sendConnectionError("invalid json")
   302  		return nil
   303  	}
   304  
   305  	return &message
   306  }
   307  
   308  func (c *wsConnection) close(closeCode int, message string) {
   309  	c.mu.Lock()
   310  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   311  	c.mu.Unlock()
   312  	_ = c.conn.Close()
   313  }