github.com/humans-group/gqlgen@v0.7.2/handler/websocket.go (about)

     1  package handler
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/99designs/gqlgen/graphql"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/vektah/gqlparser"
    16  	"github.com/vektah/gqlparser/ast"
    17  	"github.com/vektah/gqlparser/gqlerror"
    18  	"github.com/vektah/gqlparser/validator"
    19  )
    20  
    21  const (
    22  	connectionInitMsg      = "connection_init"      // Client -> Server
    23  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    24  	startMsg               = "start"                // Client -> Server
    25  	stopMsg                = "stop"                 // Client -> Server
    26  	connectionAckMsg       = "connection_ack"       // Server -> Client
    27  	connectionErrorMsg     = "connection_error"     // Server -> Client
    28  	dataMsg                = "data"                 // Server -> Client
    29  	errorMsg               = "error"                // Server -> Client
    30  	completeMsg            = "complete"             // Server -> Client
    31  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    32  )
    33  
    34  type operationMessage struct {
    35  	Payload json.RawMessage `json:"payload,omitempty"`
    36  	ID      string          `json:"id,omitempty"`
    37  	Type    string          `json:"type"`
    38  }
    39  
    40  type wsConnection struct {
    41  	ctx             context.Context
    42  	conn            *websocket.Conn
    43  	exec            graphql.ExecutableSchema
    44  	active          map[string]context.CancelFunc
    45  	mu              sync.Mutex
    46  	cfg             *Config
    47  	keepAliveTicker *time.Ticker
    48  
    49  	initPayload InitPayload
    50  }
    51  
    52  func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
    53  	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
    54  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    55  	})
    56  	if err != nil {
    57  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    58  		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    59  		return
    60  	}
    61  
    62  	conn := wsConnection{
    63  		active: map[string]context.CancelFunc{},
    64  		exec:   exec,
    65  		conn:   ws,
    66  		ctx:    r.Context(),
    67  		cfg:    cfg,
    68  	}
    69  
    70  	if !conn.init() {
    71  		return
    72  	}
    73  
    74  	conn.run()
    75  }
    76  
    77  func (c *wsConnection) init() bool {
    78  	message := c.readOp()
    79  	if message == nil {
    80  		c.close(websocket.CloseProtocolError, "decoding error")
    81  		return false
    82  	}
    83  
    84  	switch message.Type {
    85  	case connectionInitMsg:
    86  		if len(message.Payload) > 0 {
    87  			c.initPayload = make(InitPayload)
    88  			err := json.Unmarshal(message.Payload, &c.initPayload)
    89  			if err != nil {
    90  				return false
    91  			}
    92  		}
    93  
    94  		c.write(&operationMessage{Type: connectionAckMsg})
    95  	case connectionTerminateMsg:
    96  		c.close(websocket.CloseNormalClosure, "terminated")
    97  		return false
    98  	default:
    99  		c.sendConnectionError("unexpected message %s", message.Type)
   100  		c.close(websocket.CloseProtocolError, "unexpected message")
   101  		return false
   102  	}
   103  
   104  	return true
   105  }
   106  
   107  func (c *wsConnection) write(msg *operationMessage) {
   108  	c.mu.Lock()
   109  	c.conn.WriteJSON(msg)
   110  	c.mu.Unlock()
   111  }
   112  
   113  func (c *wsConnection) run() {
   114  	// We create a cancellation that will shutdown the keep-alive when we leave
   115  	// this function.
   116  	ctx, cancel := context.WithCancel(c.ctx)
   117  	defer cancel()
   118  
   119  	// Create a timer that will fire every interval to keep the connection alive.
   120  	if c.cfg.connectionKeepAlivePingInterval != 0 {
   121  		c.mu.Lock()
   122  		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
   123  		c.mu.Unlock()
   124  
   125  		go c.keepAlive(ctx)
   126  	}
   127  
   128  	for {
   129  		message := c.readOp()
   130  		if message == nil {
   131  			return
   132  		}
   133  
   134  		switch message.Type {
   135  		case startMsg:
   136  			if !c.subscribe(message) {
   137  				return
   138  			}
   139  		case stopMsg:
   140  			c.mu.Lock()
   141  			closer := c.active[message.ID]
   142  			c.mu.Unlock()
   143  			if closer == nil {
   144  				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
   145  				continue
   146  			}
   147  
   148  			closer()
   149  		case connectionTerminateMsg:
   150  			c.close(websocket.CloseNormalClosure, "terminated")
   151  			return
   152  		default:
   153  			c.sendConnectionError("unexpected message %s", message.Type)
   154  			c.close(websocket.CloseProtocolError, "unexpected message")
   155  			return
   156  		}
   157  	}
   158  }
   159  
   160  func (c *wsConnection) keepAlive(ctx context.Context) {
   161  	for {
   162  		select {
   163  		case <-ctx.Done():
   164  			c.keepAliveTicker.Stop()
   165  			return
   166  		case <-c.keepAliveTicker.C:
   167  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   168  		}
   169  	}
   170  }
   171  
   172  func (c *wsConnection) subscribe(message *operationMessage) bool {
   173  	var reqParams params
   174  	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
   175  		c.sendConnectionError("invalid json")
   176  		return false
   177  	}
   178  
   179  	doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
   180  	if qErr != nil {
   181  		c.sendError(message.ID, qErr...)
   182  		return true
   183  	}
   184  
   185  	op := doc.Operations.ForName(reqParams.OperationName)
   186  	if op == nil {
   187  		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
   188  		return true
   189  	}
   190  
   191  	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
   192  	if err != nil {
   193  		c.sendError(message.ID, err)
   194  		return true
   195  	}
   196  	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
   197  	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
   198  
   199  	if c.initPayload != nil {
   200  		ctx = withInitPayload(ctx, c.initPayload)
   201  	}
   202  
   203  	if op.Operation != ast.Subscription {
   204  		var result *graphql.Response
   205  		if op.Operation == ast.Query {
   206  			result = c.exec.Query(ctx, op)
   207  		} else {
   208  			result = c.exec.Mutation(ctx, op)
   209  		}
   210  
   211  		c.sendData(message.ID, result)
   212  		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
   213  		return true
   214  	}
   215  
   216  	ctx, cancel := context.WithCancel(ctx)
   217  	c.mu.Lock()
   218  	c.active[message.ID] = cancel
   219  	c.mu.Unlock()
   220  	go func() {
   221  		defer func() {
   222  			if r := recover(); r != nil {
   223  				userErr := reqCtx.Recover(ctx, r)
   224  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   225  			}
   226  		}()
   227  		next := c.exec.Subscription(ctx, op)
   228  		for result := next(); result != nil; result = next() {
   229  			c.sendData(message.ID, result)
   230  		}
   231  
   232  		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
   233  
   234  		c.mu.Lock()
   235  		delete(c.active, message.ID)
   236  		c.mu.Unlock()
   237  		cancel()
   238  	}()
   239  
   240  	return true
   241  }
   242  
   243  func (c *wsConnection) sendData(id string, response *graphql.Response) {
   244  	b, err := json.Marshal(response)
   245  	if err != nil {
   246  		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
   247  		return
   248  	}
   249  
   250  	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
   251  }
   252  
   253  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   254  	var errs []error
   255  	for _, err := range errors {
   256  		errs = append(errs, err)
   257  	}
   258  	b, err := json.Marshal(errs)
   259  	if err != nil {
   260  		panic(err)
   261  	}
   262  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   263  }
   264  
   265  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   266  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   267  	if err != nil {
   268  		panic(err)
   269  	}
   270  
   271  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   272  }
   273  
   274  func (c *wsConnection) readOp() *operationMessage {
   275  	_, r, err := c.conn.NextReader()
   276  	if err != nil {
   277  		c.sendConnectionError("invalid json")
   278  		return nil
   279  	}
   280  	message := operationMessage{}
   281  	if err := jsonDecode(r, &message); err != nil {
   282  		c.sendConnectionError("invalid json")
   283  		return nil
   284  	}
   285  
   286  	return &message
   287  }
   288  
   289  func (c *wsConnection) close(closeCode int, message string) {
   290  	c.mu.Lock()
   291  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   292  	c.mu.Unlock()
   293  	_ = c.conn.Close()
   294  }