github.com/99designs/gqlgen@v0.17.45/graphql/handler/transport/websocket_graphql_transport_ws.go (about) 1 package transport 2 3 import ( 4 "encoding/json" 5 "fmt" 6 7 "github.com/gorilla/websocket" 8 ) 9 10 // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 11 const ( 12 graphqltransportwsSubprotocol = "graphql-transport-ws" 13 14 graphqltransportwsConnectionInitMsg = graphqltransportwsMessageType("connection_init") 15 graphqltransportwsConnectionAckMsg = graphqltransportwsMessageType("connection_ack") 16 graphqltransportwsSubscribeMsg = graphqltransportwsMessageType("subscribe") 17 graphqltransportwsNextMsg = graphqltransportwsMessageType("next") 18 graphqltransportwsErrorMsg = graphqltransportwsMessageType("error") 19 graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete") 20 graphqltransportwsPingMsg = graphqltransportwsMessageType("ping") 21 graphqltransportwsPongMsg = graphqltransportwsMessageType("pong") 22 ) 23 24 var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{ 25 graphqltransportwsConnectionInitMsg, 26 graphqltransportwsConnectionAckMsg, 27 graphqltransportwsSubscribeMsg, 28 graphqltransportwsNextMsg, 29 graphqltransportwsErrorMsg, 30 graphqltransportwsCompleteMsg, 31 graphqltransportwsPingMsg, 32 graphqltransportwsPongMsg, 33 } 34 35 type ( 36 graphqltransportwsMessageExchanger struct { 37 c *websocket.Conn 38 } 39 40 graphqltransportwsMessage struct { 41 Payload json.RawMessage `json:"payload,omitempty"` 42 ID string `json:"id,omitempty"` 43 Type graphqltransportwsMessageType `json:"type"` 44 noOp bool 45 } 46 47 graphqltransportwsMessageType string 48 ) 49 50 func (me graphqltransportwsMessageExchanger) NextMessage() (message, error) { 51 _, r, err := me.c.NextReader() 52 if err != nil { 53 return message{}, handleNextReaderError(err) 54 } 55 56 var graphqltransportwsMessage graphqltransportwsMessage 57 if err := jsonDecode(r, &graphqltransportwsMessage); err != nil { 58 return message{}, errInvalidMsg 59 } 60 61 return graphqltransportwsMessage.toMessage() 62 } 63 64 func (me graphqltransportwsMessageExchanger) Send(m *message) error { 65 msg := &graphqltransportwsMessage{} 66 if err := msg.fromMessage(m); err != nil { 67 return err 68 } 69 70 if msg.noOp { 71 return nil 72 } 73 74 return me.c.WriteJSON(msg) 75 } 76 77 func (t *graphqltransportwsMessageType) UnmarshalText(text []byte) (err error) { 78 var found bool 79 for _, candidate := range allGraphqltransportwsMessageTypes { 80 if string(candidate) == string(text) { 81 *t = candidate 82 found = true 83 break 84 } 85 } 86 87 if !found { 88 err = fmt.Errorf("invalid message type %s", string(text)) 89 } 90 91 return err 92 } 93 94 func (t graphqltransportwsMessageType) MarshalText() ([]byte, error) { 95 return []byte(string(t)), nil 96 } 97 98 func (m graphqltransportwsMessage) toMessage() (message, error) { 99 var t messageType 100 var err error 101 switch m.Type { 102 default: 103 err = fmt.Errorf("invalid client->server message type %s", m.Type) 104 case graphqltransportwsConnectionInitMsg: 105 t = initMessageType 106 case graphqltransportwsSubscribeMsg: 107 t = startMessageType 108 case graphqltransportwsCompleteMsg: 109 t = stopMessageType 110 case graphqltransportwsPingMsg: 111 t = pingMessageType 112 case graphqltransportwsPongMsg: 113 t = pongMessageType 114 } 115 116 return message{ 117 payload: m.Payload, 118 id: m.ID, 119 t: t, 120 }, err 121 } 122 123 func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) { 124 m.ID = msg.id 125 m.Payload = msg.payload 126 127 switch msg.t { 128 default: 129 err = fmt.Errorf("invalid server->client message type %s", msg.t) 130 case connectionAckMessageType: 131 m.Type = graphqltransportwsConnectionAckMsg 132 case keepAliveMessageType: 133 m.noOp = true 134 case connectionErrorMessageType: 135 m.noOp = true 136 case dataMessageType: 137 m.Type = graphqltransportwsNextMsg 138 case completeMessageType: 139 m.Type = graphqltransportwsCompleteMsg 140 case errorMessageType: 141 m.Type = graphqltransportwsErrorMsg 142 case pingMessageType: 143 m.Type = graphqltransportwsPingMsg 144 case pongMessageType: 145 m.Type = graphqltransportwsPongMsg 146 } 147 148 return err 149 }