github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket_graphqlws.go (about)

     1  package transport
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  
     7  	"github.com/gorilla/websocket"
     8  )
     9  
    10  // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
    11  const (
    12  	graphqlwsSubprotocol = "graphql-ws"
    13  
    14  	graphqlwsConnectionInitMsg      = graphqlwsMessageType("connection_init")
    15  	graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate")
    16  	graphqlwsStartMsg               = graphqlwsMessageType("start")
    17  	graphqlwsStopMsg                = graphqlwsMessageType("stop")
    18  	graphqlwsConnectionAckMsg       = graphqlwsMessageType("connection_ack")
    19  	graphqlwsConnectionErrorMsg     = graphqlwsMessageType("connection_error")
    20  	graphqlwsDataMsg                = graphqlwsMessageType("data")
    21  	graphqlwsErrorMsg               = graphqlwsMessageType("error")
    22  	graphqlwsCompleteMsg            = graphqlwsMessageType("complete")
    23  	graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka")
    24  )
    25  
    26  var allGraphqlwsMessageTypes = []graphqlwsMessageType{
    27  	graphqlwsConnectionInitMsg,
    28  	graphqlwsConnectionTerminateMsg,
    29  	graphqlwsStartMsg,
    30  	graphqlwsStopMsg,
    31  	graphqlwsConnectionAckMsg,
    32  	graphqlwsConnectionErrorMsg,
    33  	graphqlwsDataMsg,
    34  	graphqlwsErrorMsg,
    35  	graphqlwsCompleteMsg,
    36  	graphqlwsConnectionKeepAliveMsg,
    37  }
    38  
    39  type (
    40  	graphqlwsMessageExchanger struct {
    41  		c *websocket.Conn
    42  	}
    43  
    44  	graphqlwsMessage struct {
    45  		Payload json.RawMessage      `json:"payload,omitempty"`
    46  		ID      string               `json:"id,omitempty"`
    47  		Type    graphqlwsMessageType `json:"type"`
    48  	}
    49  
    50  	graphqlwsMessageType string
    51  )
    52  
    53  func (me graphqlwsMessageExchanger) NextMessage() (message, error) {
    54  	_, r, err := me.c.NextReader()
    55  	if err != nil {
    56  		return message{}, handleNextReaderError(err)
    57  	}
    58  
    59  	var graphqlwsMessage graphqlwsMessage
    60  	if err := jsonDecode(r, &graphqlwsMessage); err != nil {
    61  		return message{}, errInvalidMsg
    62  	}
    63  
    64  	return graphqlwsMessage.toMessage()
    65  }
    66  
    67  func (me graphqlwsMessageExchanger) Send(m *message) error {
    68  	msg := &graphqlwsMessage{}
    69  	if err := msg.fromMessage(m); err != nil {
    70  		return err
    71  	}
    72  
    73  	return me.c.WriteJSON(msg)
    74  }
    75  
    76  func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) {
    77  	var found bool
    78  	for _, candidate := range allGraphqlwsMessageTypes {
    79  		if string(candidate) == string(text) {
    80  			*t = candidate
    81  			found = true
    82  			break
    83  		}
    84  	}
    85  
    86  	if !found {
    87  		err = fmt.Errorf("invalid message type %s", string(text))
    88  	}
    89  
    90  	return err
    91  }
    92  
    93  func (t graphqlwsMessageType) MarshalText() ([]byte, error) {
    94  	return []byte(string(t)), nil
    95  }
    96  
    97  func (t graphqlwsMessageType) toMessageType() (mt messageType, err error) {
    98  	switch t {
    99  	default:
   100  		err = fmt.Errorf("unknown message type mapping for %s", t)
   101  	case graphqlwsConnectionInitMsg:
   102  		mt = initMessageType
   103  	case graphqlwsConnectionTerminateMsg:
   104  		mt = connectionCloseMessageType
   105  	case graphqlwsStartMsg:
   106  		mt = startMessageType
   107  	case graphqlwsStopMsg:
   108  		mt = stopMessageType
   109  	case graphqlwsConnectionAckMsg:
   110  		mt = connectionAckMessageType
   111  	case graphqlwsConnectionErrorMsg:
   112  		mt = connectionErrorMessageType
   113  	case graphqlwsDataMsg:
   114  		mt = dataMessageType
   115  	case graphqlwsErrorMsg:
   116  		mt = errorMessageType
   117  	case graphqlwsCompleteMsg:
   118  		mt = completeMessageType
   119  	case graphqlwsConnectionKeepAliveMsg:
   120  		mt = keepAliveMessageType
   121  	}
   122  
   123  	return mt, err
   124  }
   125  
   126  func (t *graphqlwsMessageType) fromMessageType(mt messageType) (err error) {
   127  	switch mt {
   128  	default:
   129  		err = fmt.Errorf("failed to convert message %s to %s subprotocol", mt, graphqlwsSubprotocol)
   130  	case initMessageType:
   131  		*t = graphqlwsConnectionInitMsg
   132  	case connectionAckMessageType:
   133  		*t = graphqlwsConnectionAckMsg
   134  	case keepAliveMessageType:
   135  		*t = graphqlwsConnectionKeepAliveMsg
   136  	case connectionErrorMessageType:
   137  		*t = graphqlwsConnectionErrorMsg
   138  	case connectionCloseMessageType:
   139  		*t = graphqlwsConnectionTerminateMsg
   140  	case startMessageType:
   141  		*t = graphqlwsStartMsg
   142  	case stopMessageType:
   143  		*t = graphqlwsStopMsg
   144  	case dataMessageType:
   145  		*t = graphqlwsDataMsg
   146  	case completeMessageType:
   147  		*t = graphqlwsCompleteMsg
   148  	case errorMessageType:
   149  		*t = graphqlwsErrorMsg
   150  	}
   151  
   152  	return err
   153  }
   154  
   155  func (m graphqlwsMessage) toMessage() (message, error) {
   156  	mt, err := m.Type.toMessageType()
   157  	return message{
   158  		payload: m.Payload,
   159  		id:      m.ID,
   160  		t:       mt,
   161  	}, err
   162  }
   163  
   164  func (m *graphqlwsMessage) fromMessage(msg *message) (err error) {
   165  	err = m.Type.fromMessageType(msg.t)
   166  	m.ID = msg.id
   167  	m.Payload = msg.payload
   168  	return err
   169  }