github.com/apipluspower/gqlgen@v0.15.2/graphql/handler/transport/websocket_graphqlws.go (about) 1 package transport 2 3 import ( 4 "encoding/json" 5 "fmt" 6 7 "github.com/gorilla/websocket" 8 ) 9 10 // https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md 11 const ( 12 graphqlwsSubprotocol = "graphql-ws" 13 14 graphqlwsConnectionInitMsg = graphqlwsMessageType("connection_init") 15 graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate") 16 graphqlwsStartMsg = graphqlwsMessageType("start") 17 graphqlwsStopMsg = graphqlwsMessageType("stop") 18 graphqlwsConnectionAckMsg = graphqlwsMessageType("connection_ack") 19 graphqlwsConnectionErrorMsg = graphqlwsMessageType("connection_error") 20 graphqlwsDataMsg = graphqlwsMessageType("data") 21 graphqlwsErrorMsg = graphqlwsMessageType("error") 22 graphqlwsCompleteMsg = graphqlwsMessageType("complete") 23 graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka") 24 ) 25 26 var allGraphqlwsMessageTypes = []graphqlwsMessageType{ 27 graphqlwsConnectionInitMsg, 28 graphqlwsConnectionTerminateMsg, 29 graphqlwsStartMsg, 30 graphqlwsStopMsg, 31 graphqlwsConnectionAckMsg, 32 graphqlwsConnectionErrorMsg, 33 graphqlwsDataMsg, 34 graphqlwsErrorMsg, 35 graphqlwsCompleteMsg, 36 graphqlwsConnectionKeepAliveMsg, 37 } 38 39 type ( 40 graphqlwsMessageExchanger struct { 41 c *websocket.Conn 42 } 43 44 graphqlwsMessage struct { 45 Payload json.RawMessage `json:"payload,omitempty"` 46 ID string `json:"id,omitempty"` 47 Type graphqlwsMessageType `json:"type"` 48 } 49 50 graphqlwsMessageType string 51 ) 52 53 func (me graphqlwsMessageExchanger) NextMessage() (message, error) { 54 _, r, err := me.c.NextReader() 55 if err != nil { 56 return message{}, handleNextReaderError(err) 57 } 58 59 var graphqlwsMessage graphqlwsMessage 60 if err := jsonDecode(r, &graphqlwsMessage); err != nil { 61 return message{}, errInvalidMsg 62 } 63 64 return graphqlwsMessage.toMessage() 65 } 66 67 func (me graphqlwsMessageExchanger) Send(m *message) error { 68 msg := &graphqlwsMessage{} 69 if err := msg.fromMessage(m); err != nil { 70 return err 71 } 72 73 return me.c.WriteJSON(msg) 74 } 75 76 func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) { 77 var found bool 78 for _, candidate := range allGraphqlwsMessageTypes { 79 if string(candidate) == string(text) { 80 *t = candidate 81 found = true 82 break 83 } 84 } 85 86 if !found { 87 err = fmt.Errorf("invalid message type %s", string(text)) 88 } 89 90 return err 91 } 92 93 func (t graphqlwsMessageType) MarshalText() ([]byte, error) { 94 return []byte(string(t)), nil 95 } 96 97 func (t graphqlwsMessageType) toMessageType() (mt messageType, err error) { 98 switch t { 99 default: 100 err = fmt.Errorf("unknown message type mapping for %s", t) 101 case graphqlwsConnectionInitMsg: 102 mt = initMessageType 103 case graphqlwsConnectionTerminateMsg: 104 mt = connectionCloseMessageType 105 case graphqlwsStartMsg: 106 mt = startMessageType 107 case graphqlwsStopMsg: 108 mt = stopMessageType 109 case graphqlwsConnectionAckMsg: 110 mt = connectionAckMessageType 111 case graphqlwsConnectionErrorMsg: 112 mt = connectionErrorMessageType 113 case graphqlwsDataMsg: 114 mt = dataMessageType 115 case graphqlwsErrorMsg: 116 mt = errorMessageType 117 case graphqlwsCompleteMsg: 118 mt = completeMessageType 119 case graphqlwsConnectionKeepAliveMsg: 120 mt = keepAliveMessageType 121 } 122 123 return mt, err 124 } 125 126 func (t *graphqlwsMessageType) fromMessageType(mt messageType) (err error) { 127 switch mt { 128 default: 129 err = fmt.Errorf("failed to convert message %s to %s subprotocol", mt, graphqlwsSubprotocol) 130 case initMessageType: 131 *t = graphqlwsConnectionInitMsg 132 case connectionAckMessageType: 133 *t = graphqlwsConnectionAckMsg 134 case keepAliveMessageType: 135 *t = graphqlwsConnectionKeepAliveMsg 136 case connectionErrorMessageType: 137 *t = graphqlwsConnectionErrorMsg 138 case connectionCloseMessageType: 139 *t = graphqlwsConnectionTerminateMsg 140 case startMessageType: 141 *t = graphqlwsStartMsg 142 case stopMessageType: 143 *t = graphqlwsStopMsg 144 case dataMessageType: 145 *t = graphqlwsDataMsg 146 case completeMessageType: 147 *t = graphqlwsCompleteMsg 148 case errorMessageType: 149 *t = graphqlwsErrorMsg 150 } 151 152 return err 153 } 154 155 func (m graphqlwsMessage) toMessage() (message, error) { 156 mt, err := m.Type.toMessageType() 157 return message{ 158 payload: m.Payload, 159 id: m.ID, 160 t: mt, 161 }, err 162 } 163 164 func (m *graphqlwsMessage) fromMessage(msg *message) (err error) { 165 err = m.Type.fromMessageType(msg.t) 166 m.ID = msg.id 167 m.Payload = msg.payload 168 return err 169 }