github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/animeshon/gqlgen/graphql"
    14  	"github.com/animeshon/gqlgen/graphql/errcode"
    15  	"github.com/gorilla/websocket"
    16  	"github.com/vektah/gqlparser/v2/gqlerror"
    17  )
    18  
    19  const (
    20  	connectionInitMsg      = "connection_init"      // Client -> Server
    21  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    22  	startMsg               = "start"                // Client -> Server
    23  	stopMsg                = "stop"                 // Client -> Server
    24  	connectionAckMsg       = "connection_ack"       // Server -> Client
    25  	connectionErrorMsg     = "connection_error"     // Server -> Client
    26  	dataMsg                = "data"                 // Server -> Client
    27  	errorMsg               = "error"                // Server -> Client
    28  	completeMsg            = "complete"             // Server -> Client
    29  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    30  )
    31  
    32  type (
    33  	Websocket struct {
    34  		Upgrader              websocket.Upgrader
    35  		InitFunc              WebsocketInitFunc
    36  		KeepAlivePingInterval time.Duration
    37  	}
    38  	wsConnection struct {
    39  		Websocket
    40  		ctx             context.Context
    41  		conn            *websocket.Conn
    42  		active          map[string]context.CancelFunc
    43  		mu              sync.Mutex
    44  		keepAliveTicker *time.Ticker
    45  		exec            graphql.GraphExecutor
    46  
    47  		initPayload InitPayload
    48  	}
    49  	operationMessage struct {
    50  		Payload json.RawMessage `json:"payload,omitempty"`
    51  		ID      string          `json:"id,omitempty"`
    52  		Type    string          `json:"type"`
    53  	}
    54  	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    55  )
    56  
    57  var _ graphql.Transport = Websocket{}
    58  
    59  func (t Websocket) Supports(r *http.Request) bool {
    60  	return r.Header.Get("Upgrade") != ""
    61  }
    62  
    63  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    64  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{
    65  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    66  	})
    67  	if err != nil {
    68  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    69  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    70  		return
    71  	}
    72  
    73  	conn := wsConnection{
    74  		active:    map[string]context.CancelFunc{},
    75  		conn:      ws,
    76  		ctx:       r.Context(),
    77  		exec:      exec,
    78  		Websocket: t,
    79  	}
    80  
    81  	if !conn.init() {
    82  		return
    83  	}
    84  
    85  	conn.run()
    86  }
    87  
    88  func (c *wsConnection) init() bool {
    89  	message := c.readOp()
    90  	if message == nil {
    91  		c.close(websocket.CloseProtocolError, "decoding error")
    92  		return false
    93  	}
    94  
    95  	switch message.Type {
    96  	case connectionInitMsg:
    97  		if len(message.Payload) > 0 {
    98  			c.initPayload = make(InitPayload)
    99  			err := json.Unmarshal(message.Payload, &c.initPayload)
   100  			if err != nil {
   101  				return false
   102  			}
   103  		}
   104  
   105  		if c.InitFunc != nil {
   106  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   107  			if err != nil {
   108  				c.sendConnectionError(err.Error())
   109  				c.close(websocket.CloseNormalClosure, "terminated")
   110  				return false
   111  			}
   112  			c.ctx = ctx
   113  		}
   114  
   115  		c.write(&operationMessage{Type: connectionAckMsg})
   116  		c.write(&operationMessage{Type: connectionKeepAliveMsg})
   117  	case connectionTerminateMsg:
   118  		c.close(websocket.CloseNormalClosure, "terminated")
   119  		return false
   120  	default:
   121  		c.sendConnectionError("unexpected message %s", message.Type)
   122  		c.close(websocket.CloseProtocolError, "unexpected message")
   123  		return false
   124  	}
   125  
   126  	return true
   127  }
   128  
   129  func (c *wsConnection) write(msg *operationMessage) {
   130  	c.mu.Lock()
   131  	c.conn.WriteJSON(msg)
   132  	c.mu.Unlock()
   133  }
   134  
   135  func (c *wsConnection) run() {
   136  	// We create a cancellation that will shutdown the keep-alive when we leave
   137  	// this function.
   138  	ctx, cancel := context.WithCancel(c.ctx)
   139  	defer func() {
   140  		cancel()
   141  		c.close(websocket.CloseAbnormalClosure, "unexpected closure")
   142  	}()
   143  
   144  	// Create a timer that will fire every interval to keep the connection alive.
   145  	if c.KeepAlivePingInterval != 0 {
   146  		c.mu.Lock()
   147  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   148  		c.mu.Unlock()
   149  
   150  		go c.keepAlive(ctx)
   151  	}
   152  
   153  	for {
   154  		start := graphql.Now()
   155  		message := c.readOp()
   156  		if message == nil {
   157  			return
   158  		}
   159  
   160  		switch message.Type {
   161  		case startMsg:
   162  			c.subscribe(start, message)
   163  		case stopMsg:
   164  			c.mu.Lock()
   165  			closer := c.active[message.ID]
   166  			c.mu.Unlock()
   167  			if closer != nil {
   168  				closer()
   169  			}
   170  		case connectionTerminateMsg:
   171  			c.close(websocket.CloseNormalClosure, "terminated")
   172  			return
   173  		default:
   174  			c.sendConnectionError("unexpected message %s", message.Type)
   175  			c.close(websocket.CloseProtocolError, "unexpected message")
   176  			return
   177  		}
   178  	}
   179  }
   180  
   181  func (c *wsConnection) keepAlive(ctx context.Context) {
   182  	for {
   183  		select {
   184  		case <-ctx.Done():
   185  			c.keepAliveTicker.Stop()
   186  			return
   187  		case <-c.keepAliveTicker.C:
   188  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   189  		}
   190  	}
   191  }
   192  
   193  func (c *wsConnection) subscribe(start time.Time, message *operationMessage) {
   194  	ctx := graphql.StartOperationTrace(c.ctx)
   195  	var params *graphql.RawParams
   196  	if err := jsonDecode(bytes.NewReader(message.Payload), &params); err != nil {
   197  		c.sendError(message.ID, &gqlerror.Error{Message: "invalid json"})
   198  		c.complete(message.ID)
   199  		return
   200  	}
   201  
   202  	params.ReadTime = graphql.TraceTiming{
   203  		Start: start,
   204  		End:   graphql.Now(),
   205  	}
   206  
   207  	rc, err := c.exec.CreateOperationContext(ctx, params)
   208  	if err != nil {
   209  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   210  		switch errcode.GetErrorKind(err) {
   211  		case errcode.KindProtocol:
   212  			c.sendError(message.ID, resp.Errors...)
   213  		default:
   214  			c.sendResponse(message.ID, &graphql.Response{Errors: err})
   215  		}
   216  
   217  		c.complete(message.ID)
   218  		return
   219  	}
   220  
   221  	ctx = graphql.WithOperationContext(ctx, rc)
   222  
   223  	if c.initPayload != nil {
   224  		ctx = withInitPayload(ctx, c.initPayload)
   225  	}
   226  
   227  	ctx, cancel := context.WithCancel(ctx)
   228  	c.mu.Lock()
   229  	c.active[message.ID] = cancel
   230  	c.mu.Unlock()
   231  
   232  	go func() {
   233  		defer func() {
   234  			if r := recover(); r != nil {
   235  				userErr := rc.Recover(ctx, r)
   236  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   237  			}
   238  		}()
   239  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   240  		for {
   241  			response := responses(ctx)
   242  			if response == nil {
   243  				break
   244  			}
   245  
   246  			c.sendResponse(message.ID, response)
   247  		}
   248  		c.complete(message.ID)
   249  
   250  		c.mu.Lock()
   251  		delete(c.active, message.ID)
   252  		c.mu.Unlock()
   253  		cancel()
   254  	}()
   255  }
   256  
   257  func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
   258  	b, err := json.Marshal(response)
   259  	if err != nil {
   260  		panic(err)
   261  	}
   262  	c.write(&operationMessage{
   263  		Payload: b,
   264  		ID:      id,
   265  		Type:    dataMsg,
   266  	})
   267  }
   268  
   269  func (c *wsConnection) complete(id string) {
   270  	c.write(&operationMessage{ID: id, Type: completeMsg})
   271  }
   272  
   273  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   274  	errs := make([]error, len(errors))
   275  	for i, err := range errors {
   276  		errs[i] = err
   277  	}
   278  	b, err := json.Marshal(errs)
   279  	if err != nil {
   280  		panic(err)
   281  	}
   282  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   283  }
   284  
   285  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   286  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   287  	if err != nil {
   288  		panic(err)
   289  	}
   290  
   291  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   292  }
   293  
   294  func (c *wsConnection) readOp() *operationMessage {
   295  	_, r, err := c.conn.NextReader()
   296  	if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
   297  		return nil
   298  	} else if err != nil {
   299  		c.sendConnectionError("invalid json: %T %s", err, err.Error())
   300  		return nil
   301  	}
   302  	message := operationMessage{}
   303  	if err := jsonDecode(r, &message); err != nil {
   304  		c.sendConnectionError("invalid json")
   305  		return nil
   306  	}
   307  
   308  	return &message
   309  }
   310  
   311  func (c *wsConnection) close(closeCode int, message string) {
   312  	c.mu.Lock()
   313  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   314  	c.mu.Unlock()
   315  	_ = c.conn.Close()
   316  }