github.com/HaswinVidanage/gqlgen@v0.8.1-0.20220609041233-69528c1bf712/handler/websocket.go (about)

     1  package handler
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/HaswinVidanage/gqlgen/graphql"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/hashicorp/golang-lru"
    16  	"github.com/vektah/gqlparser"
    17  	"github.com/vektah/gqlparser/ast"
    18  	"github.com/vektah/gqlparser/gqlerror"
    19  	"github.com/vektah/gqlparser/validator"
    20  )
    21  
    22  const (
    23  	connectionInitMsg      = "connection_init"      // Client -> Server
    24  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    25  	startMsg               = "start"                // Client -> Server
    26  	stopMsg                = "stop"                 // Client -> Server
    27  	connectionAckMsg       = "connection_ack"       // Server -> Client
    28  	connectionErrorMsg     = "connection_error"     // Server -> Client
    29  	dataMsg                = "data"                 // Server -> Client
    30  	errorMsg               = "error"                // Server -> Client
    31  	completeMsg            = "complete"             // Server -> Client
    32  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    33  )
    34  
    35  type operationMessage struct {
    36  	Payload json.RawMessage `json:"payload,omitempty"`
    37  	ID      string          `json:"id,omitempty"`
    38  	Type    string          `json:"type"`
    39  }
    40  
    41  type wsConnection struct {
    42  	ctx             context.Context
    43  	conn            *websocket.Conn
    44  	exec            graphql.ExecutableSchema
    45  	active          map[string]context.CancelFunc
    46  	mu              sync.Mutex
    47  	cfg             *Config
    48  	cache           *lru.Cache
    49  	keepAliveTicker *time.Ticker
    50  
    51  	initPayload InitPayload
    52  }
    53  
    54  func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
    55  	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
    56  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    57  	})
    58  	if err != nil {
    59  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    60  		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    61  		return
    62  	}
    63  
    64  	conn := wsConnection{
    65  		active: map[string]context.CancelFunc{},
    66  		exec:   exec,
    67  		conn:   ws,
    68  		ctx:    r.Context(),
    69  		cfg:    cfg,
    70  		cache:  cache,
    71  	}
    72  
    73  	if !conn.init() {
    74  		return
    75  	}
    76  
    77  	conn.run()
    78  }
    79  
    80  func (c *wsConnection) init() bool {
    81  	message := c.readOp()
    82  	if message == nil {
    83  		c.close(websocket.CloseProtocolError, "decoding error")
    84  		return false
    85  	}
    86  
    87  	switch message.Type {
    88  	case connectionInitMsg:
    89  		if len(message.Payload) > 0 {
    90  			c.initPayload = make(InitPayload)
    91  			err := json.Unmarshal(message.Payload, &c.initPayload)
    92  			if err != nil {
    93  				return false
    94  			}
    95  		}
    96  
    97  		c.write(&operationMessage{Type: connectionAckMsg})
    98  	case connectionTerminateMsg:
    99  		c.close(websocket.CloseNormalClosure, "terminated")
   100  		return false
   101  	default:
   102  		c.sendConnectionError("unexpected message %s", message.Type)
   103  		c.close(websocket.CloseProtocolError, "unexpected message")
   104  		return false
   105  	}
   106  
   107  	return true
   108  }
   109  
   110  func (c *wsConnection) write(msg *operationMessage) {
   111  	c.mu.Lock()
   112  	c.conn.WriteJSON(msg)
   113  	c.mu.Unlock()
   114  }
   115  
   116  func (c *wsConnection) run() {
   117  	// We create a cancellation that will shutdown the keep-alive when we leave
   118  	// this function.
   119  	ctx, cancel := context.WithCancel(c.ctx)
   120  	defer cancel()
   121  
   122  	// Create a timer that will fire every interval to keep the connection alive.
   123  	if c.cfg.connectionKeepAlivePingInterval != 0 {
   124  		c.mu.Lock()
   125  		c.keepAliveTicker = time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
   126  		c.mu.Unlock()
   127  
   128  		go c.keepAlive(ctx)
   129  	}
   130  
   131  	for {
   132  		message := c.readOp()
   133  		if message == nil {
   134  			return
   135  		}
   136  
   137  		switch message.Type {
   138  		case startMsg:
   139  			if !c.subscribe(message) {
   140  				return
   141  			}
   142  		case stopMsg:
   143  			c.mu.Lock()
   144  			closer := c.active[message.ID]
   145  			c.mu.Unlock()
   146  			if closer == nil {
   147  				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
   148  				continue
   149  			}
   150  
   151  			closer()
   152  		case connectionTerminateMsg:
   153  			c.close(websocket.CloseNormalClosure, "terminated")
   154  			return
   155  		default:
   156  			c.sendConnectionError("unexpected message %s", message.Type)
   157  			c.close(websocket.CloseProtocolError, "unexpected message")
   158  			return
   159  		}
   160  	}
   161  }
   162  
   163  func (c *wsConnection) keepAlive(ctx context.Context) {
   164  	for {
   165  		select {
   166  		case <-ctx.Done():
   167  			c.keepAliveTicker.Stop()
   168  			return
   169  		case <-c.keepAliveTicker.C:
   170  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   171  		}
   172  	}
   173  }
   174  
   175  func (c *wsConnection) subscribe(message *operationMessage) bool {
   176  	var reqParams params
   177  	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
   178  		c.sendConnectionError("invalid json")
   179  		return false
   180  	}
   181  
   182  	var (
   183  		doc      *ast.QueryDocument
   184  		cacheHit bool
   185  	)
   186  	if c.cache != nil {
   187  		val, ok := c.cache.Get(reqParams.Query)
   188  		if ok {
   189  			doc = val.(*ast.QueryDocument)
   190  			cacheHit = true
   191  		}
   192  	}
   193  	if !cacheHit {
   194  		var qErr gqlerror.List
   195  		doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
   196  		if qErr != nil {
   197  			c.sendError(message.ID, qErr...)
   198  			return true
   199  		}
   200  		if c.cache != nil {
   201  			c.cache.Add(reqParams.Query, doc)
   202  		}
   203  	}
   204  
   205  	op := doc.Operations.ForName(reqParams.OperationName)
   206  	if op == nil {
   207  		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
   208  		return true
   209  	}
   210  
   211  	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
   212  	if err != nil {
   213  		c.sendError(message.ID, err)
   214  		return true
   215  	}
   216  	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
   217  	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
   218  
   219  	if c.initPayload != nil {
   220  		ctx = withInitPayload(ctx, c.initPayload)
   221  	}
   222  
   223  	if op.Operation != ast.Subscription {
   224  		var result *graphql.Response
   225  		if op.Operation == ast.Query {
   226  			result = c.exec.Query(ctx, op)
   227  		} else {
   228  			result = c.exec.Mutation(ctx, op)
   229  		}
   230  
   231  		c.sendData(message.ID, result)
   232  		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
   233  		return true
   234  	}
   235  
   236  	ctx, cancel := context.WithCancel(ctx)
   237  	c.mu.Lock()
   238  	c.active[message.ID] = cancel
   239  	c.mu.Unlock()
   240  	go func() {
   241  		defer func() {
   242  			if r := recover(); r != nil {
   243  				userErr := reqCtx.Recover(ctx, r)
   244  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   245  			}
   246  		}()
   247  		next := c.exec.Subscription(ctx, op)
   248  		for result := next(); result != nil; result = next() {
   249  			c.sendData(message.ID, result)
   250  		}
   251  
   252  		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
   253  
   254  		c.mu.Lock()
   255  		delete(c.active, message.ID)
   256  		c.mu.Unlock()
   257  		cancel()
   258  	}()
   259  
   260  	return true
   261  }
   262  
   263  func (c *wsConnection) sendData(id string, response *graphql.Response) {
   264  	b, err := json.Marshal(response)
   265  	if err != nil {
   266  		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
   267  		return
   268  	}
   269  
   270  	c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
   271  }
   272  
   273  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   274  	var errs []error
   275  	for _, err := range errors {
   276  		errs = append(errs, err)
   277  	}
   278  	b, err := json.Marshal(errs)
   279  	if err != nil {
   280  		panic(err)
   281  	}
   282  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   283  }
   284  
   285  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   286  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   287  	if err != nil {
   288  		panic(err)
   289  	}
   290  
   291  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   292  }
   293  
   294  func (c *wsConnection) readOp() *operationMessage {
   295  	_, r, err := c.conn.NextReader()
   296  	if err != nil {
   297  		c.sendConnectionError("invalid json")
   298  		return nil
   299  	}
   300  	message := operationMessage{}
   301  	if err := jsonDecode(r, &message); err != nil {
   302  		c.sendConnectionError("invalid json")
   303  		return nil
   304  	}
   305  
   306  	return &message
   307  }
   308  
   309  func (c *wsConnection) close(closeCode int, message string) {
   310  	c.mu.Lock()
   311  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   312  	c.mu.Unlock()
   313  	_ = c.conn.Close()
   314  }