github.com/fortexxx/gqlgen@v0.10.3-0.20191216030626-ca5ea8b21ead/graphql/handler/transport/websocket.go (about)

     1  package transport
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/99designs/gqlgen/graphql"
    14  	"github.com/gorilla/websocket"
    15  	"github.com/vektah/gqlparser/gqlerror"
    16  )
    17  
    18  const (
    19  	connectionInitMsg      = "connection_init"      // Client -> Server
    20  	connectionTerminateMsg = "connection_terminate" // Client -> Server
    21  	startMsg               = "start"                // Client -> Server
    22  	stopMsg                = "stop"                 // Client -> Server
    23  	connectionAckMsg       = "connection_ack"       // Server -> Client
    24  	connectionErrorMsg     = "connection_error"     // Server -> Client
    25  	dataMsg                = "data"                 // Server -> Client
    26  	errorMsg               = "error"                // Server -> Client
    27  	completeMsg            = "complete"             // Server -> Client
    28  	connectionKeepAliveMsg = "ka"                   // Server -> Client
    29  )
    30  
    31  type (
    32  	Websocket struct {
    33  		Upgrader              websocket.Upgrader
    34  		InitFunc              WebsocketInitFunc
    35  		KeepAlivePingInterval time.Duration
    36  	}
    37  	wsConnection struct {
    38  		Websocket
    39  		ctx             context.Context
    40  		conn            *websocket.Conn
    41  		active          map[string]context.CancelFunc
    42  		mu              sync.Mutex
    43  		keepAliveTicker *time.Ticker
    44  		exec            graphql.GraphExecutor
    45  
    46  		initPayload InitPayload
    47  	}
    48  	operationMessage struct {
    49  		Payload json.RawMessage `json:"payload,omitempty"`
    50  		ID      string          `json:"id,omitempty"`
    51  		Type    string          `json:"type"`
    52  	}
    53  	WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
    54  )
    55  
    56  var _ graphql.Transport = Websocket{}
    57  
    58  func (t Websocket) Supports(r *http.Request) bool {
    59  	return r.Header.Get("Upgrade") != ""
    60  }
    61  
    62  func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
    63  	ws, err := t.Upgrader.Upgrade(w, r, http.Header{
    64  		"Sec-Websocket-Protocol": []string{"graphql-ws"},
    65  	})
    66  	if err != nil {
    67  		log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
    68  		SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
    69  		return
    70  	}
    71  
    72  	conn := wsConnection{
    73  		active:    map[string]context.CancelFunc{},
    74  		conn:      ws,
    75  		ctx:       r.Context(),
    76  		exec:      exec,
    77  		Websocket: t,
    78  	}
    79  
    80  	if !conn.init() {
    81  		return
    82  	}
    83  
    84  	conn.run()
    85  }
    86  
    87  func (c *wsConnection) init() bool {
    88  	message := c.readOp()
    89  	if message == nil {
    90  		c.close(websocket.CloseProtocolError, "decoding error")
    91  		return false
    92  	}
    93  
    94  	switch message.Type {
    95  	case connectionInitMsg:
    96  		if len(message.Payload) > 0 {
    97  			c.initPayload = make(InitPayload)
    98  			err := json.Unmarshal(message.Payload, &c.initPayload)
    99  			if err != nil {
   100  				return false
   101  			}
   102  		}
   103  
   104  		if c.InitFunc != nil {
   105  			ctx, err := c.InitFunc(c.ctx, c.initPayload)
   106  			if err != nil {
   107  				c.sendConnectionError(err.Error())
   108  				c.close(websocket.CloseNormalClosure, "terminated")
   109  				return false
   110  			}
   111  			c.ctx = ctx
   112  		}
   113  
   114  		c.write(&operationMessage{Type: connectionAckMsg})
   115  		c.write(&operationMessage{Type: connectionKeepAliveMsg})
   116  	case connectionTerminateMsg:
   117  		c.close(websocket.CloseNormalClosure, "terminated")
   118  		return false
   119  	default:
   120  		c.sendConnectionError("unexpected message %s", message.Type)
   121  		c.close(websocket.CloseProtocolError, "unexpected message")
   122  		return false
   123  	}
   124  
   125  	return true
   126  }
   127  
   128  func (c *wsConnection) write(msg *operationMessage) {
   129  	c.mu.Lock()
   130  	c.conn.WriteJSON(msg)
   131  	c.mu.Unlock()
   132  }
   133  
   134  func (c *wsConnection) run() {
   135  	// We create a cancellation that will shutdown the keep-alive when we leave
   136  	// this function.
   137  	ctx, cancel := context.WithCancel(c.ctx)
   138  	defer cancel()
   139  
   140  	// Create a timer that will fire every interval to keep the connection alive.
   141  	if c.KeepAlivePingInterval != 0 {
   142  		c.mu.Lock()
   143  		c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
   144  		c.mu.Unlock()
   145  
   146  		go c.keepAlive(ctx)
   147  	}
   148  
   149  	for {
   150  		message := c.readOp()
   151  		if message == nil {
   152  			return
   153  		}
   154  
   155  		switch message.Type {
   156  		case startMsg:
   157  			if !c.subscribe(message) {
   158  				return
   159  			}
   160  		case stopMsg:
   161  			c.mu.Lock()
   162  			closer := c.active[message.ID]
   163  			c.mu.Unlock()
   164  			if closer == nil {
   165  				c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
   166  				continue
   167  			}
   168  
   169  			closer()
   170  		case connectionTerminateMsg:
   171  			c.close(websocket.CloseNormalClosure, "terminated")
   172  			return
   173  		default:
   174  			c.sendConnectionError("unexpected message %s", message.Type)
   175  			c.close(websocket.CloseProtocolError, "unexpected message")
   176  			return
   177  		}
   178  	}
   179  }
   180  
   181  func (c *wsConnection) keepAlive(ctx context.Context) {
   182  	for {
   183  		select {
   184  		case <-ctx.Done():
   185  			c.keepAliveTicker.Stop()
   186  			return
   187  		case <-c.keepAliveTicker.C:
   188  			c.write(&operationMessage{Type: connectionKeepAliveMsg})
   189  		}
   190  	}
   191  }
   192  
   193  func (c *wsConnection) subscribe(message *operationMessage) bool {
   194  	ctx := graphql.StartOperationTrace(c.ctx)
   195  	var params *graphql.RawParams
   196  	if err := jsonDecode(bytes.NewReader(message.Payload), &params); err != nil {
   197  		c.sendConnectionError("invalid json")
   198  		return false
   199  	}
   200  
   201  	rc, err := c.exec.CreateOperationContext(ctx, params)
   202  	if err != nil {
   203  		resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
   204  		c.sendError(message.ID, resp.Errors...)
   205  		return false
   206  	}
   207  
   208  	ctx = graphql.WithOperationContext(ctx, rc)
   209  
   210  	if c.initPayload != nil {
   211  		ctx = withInitPayload(ctx, c.initPayload)
   212  	}
   213  
   214  	ctx, cancel := context.WithCancel(ctx)
   215  	c.mu.Lock()
   216  	c.active[message.ID] = cancel
   217  	c.mu.Unlock()
   218  
   219  	go func() {
   220  		defer func() {
   221  			if r := recover(); r != nil {
   222  				userErr := rc.Recover(ctx, r)
   223  				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
   224  			}
   225  		}()
   226  		responses, ctx := c.exec.DispatchOperation(ctx, rc)
   227  		for {
   228  			response := responses(ctx)
   229  			if response == nil {
   230  				break
   231  			}
   232  
   233  			b, err := json.Marshal(response)
   234  			if err != nil {
   235  				panic(err)
   236  			}
   237  			c.write(&operationMessage{
   238  				Payload: b,
   239  				ID:      message.ID,
   240  				Type:    dataMsg,
   241  			})
   242  		}
   243  		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
   244  
   245  		c.mu.Lock()
   246  		delete(c.active, message.ID)
   247  		c.mu.Unlock()
   248  		cancel()
   249  	}()
   250  
   251  	return true
   252  }
   253  
   254  func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
   255  	errs := make([]error, len(errors))
   256  	for i, err := range errors {
   257  		errs[i] = err
   258  	}
   259  	b, err := json.Marshal(errs)
   260  	if err != nil {
   261  		panic(err)
   262  	}
   263  	c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
   264  }
   265  
   266  func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
   267  	b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
   268  	if err != nil {
   269  		panic(err)
   270  	}
   271  
   272  	c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
   273  }
   274  
   275  func (c *wsConnection) readOp() *operationMessage {
   276  	_, r, err := c.conn.NextReader()
   277  	if err != nil {
   278  		c.sendConnectionError("invalid json")
   279  		return nil
   280  	}
   281  	message := operationMessage{}
   282  	if err := jsonDecode(r, &message); err != nil {
   283  		c.sendConnectionError("invalid json")
   284  		return nil
   285  	}
   286  
   287  	return &message
   288  }
   289  
   290  func (c *wsConnection) close(closeCode int, message string) {
   291  	c.mu.Lock()
   292  	_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
   293  	c.mu.Unlock()
   294  	_ = c.conn.Close()
   295  }