github.com/philippseith/signalr@v0.6.3/messagepackhubprotocol.go (about) 1 package signalr 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 10 "github.com/go-kit/log" 11 "github.com/vmihailenco/msgpack/v5" 12 ) 13 14 type messagePackHubProtocol struct { 15 dbg log.Logger 16 } 17 18 func (m *messagePackHubProtocol) ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) ([]interface{}, error) { 19 frames, err := m.readFrames(reader, remainBuf) 20 if err != nil { 21 return nil, err 22 } 23 messages := make([]interface{}, 0) 24 for _, frame := range frames { 25 message, err := m.parseMessage(bytes.NewBuffer(frame)) 26 if err != nil { 27 return nil, err 28 } 29 messages = append(messages, message) 30 } 31 return messages, nil 32 } 33 34 func (m *messagePackHubProtocol) readFrames(reader io.Reader, remainBuf *bytes.Buffer) ([][]byte, error) { 35 frames := make([][]byte, 0) 36 for { 37 // Try to get the frame length 38 frameLenBuf := make([]byte, binary.MaxVarintLen32) 39 n1, err := remainBuf.Read(frameLenBuf) 40 if err != nil && !errors.Is(err, io.EOF) { 41 // Some weird other error 42 return nil, err 43 } 44 n2, err := reader.Read(frameLenBuf[n1:]) 45 if err != nil && !errors.Is(err, io.EOF) { 46 // Some weird other error 47 return nil, err 48 } 49 frameLen, lenLen := binary.Uvarint(frameLenBuf[:n1+n2]) 50 if lenLen == 0 { 51 // reader could not supply enough bytes to decode the Uvarint 52 // Store the already read bytes in the remainBuf for next iteration 53 _, _ = remainBuf.Write(frameLenBuf[:n1+n2]) 54 return frames, nil 55 } 56 if lenLen < 0 { 57 return nil, fmt.Errorf("messagepack frame length to large") 58 } 59 // Still wondering why this happens, but it happens! 60 if frameLen == 0 { 61 // Store the overread bytes for the next iteration 62 _, _ = remainBuf.Write(frameLenBuf[lenLen:]) 63 continue 64 } 65 // Try getting data until at least one frame is available 66 readBuf := make([]byte, frameLen) 67 frameBuf := &bytes.Buffer{} 68 // Did we read too many bytes when detecting the frameLen? 69 _, _ = frameBuf.Write(frameLenBuf[lenLen:]) 70 // Read the rest of the bytes from the last iteration 71 _, _ = frameBuf.ReadFrom(remainBuf) 72 for { 73 n, err := reader.Read(readBuf) 74 if errors.Is(err, io.EOF) { 75 // Less than frameLen. Let the caller parse the already read frames and come here again later 76 _, _ = remainBuf.ReadFrom(frameBuf) 77 return frames, nil 78 } 79 if err != nil { 80 return nil, err 81 } 82 _, _ = frameBuf.Write(readBuf[:n]) 83 if frameBuf.Len() == int(frameLen) { 84 // Frame completely read. Return it to the caller 85 frames = append(frames, frameBuf.Next(int(frameLen))) 86 return frames, nil 87 } 88 if frameBuf.Len() > int(frameLen) { 89 // More than frameLen. Append the current frame to the result and start reading the next frame 90 frames = append(frames, frameBuf.Next(int(frameLen))) 91 _, _ = remainBuf.ReadFrom(frameBuf) 92 break 93 } 94 } 95 } 96 } 97 98 func (m *messagePackHubProtocol) parseMessage(buf *bytes.Buffer) (interface{}, error) { 99 decoder := msgpack.NewDecoder(buf) 100 // Default map decoding expects all maps to have string keys 101 decoder.SetMapDecoder(func(decoder *msgpack.Decoder) (interface{}, error) { 102 return decoder.DecodeUntypedMap() 103 }) 104 msgLen, err := decoder.DecodeArrayLen() 105 if err != nil { 106 return nil, err 107 } 108 msgType, err := decoder.DecodeInt() 109 if err != nil { 110 return nil, err 111 } 112 // Ignore Header for all messages, except ping message that has no header 113 // see message spec at https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#message-headers 114 if msgType != 6 { 115 _, err = decoder.DecodeMap() 116 if err != nil { 117 return nil, err 118 } 119 } 120 switch msgType { 121 case 1, 4: 122 if msgLen < 5 { 123 return nil, fmt.Errorf("invalid invocationMessage length %v", msgLen) 124 } 125 invocationID, err := m.decodeInvocationID(decoder) 126 if err != nil { 127 return nil, err 128 } 129 invocationMessage := invocationMessage{ 130 Type: msgType, 131 InvocationID: invocationID, 132 } 133 invocationMessage.Target, err = decoder.DecodeString() 134 if err != nil { 135 return nil, err 136 } 137 argLen, err := decoder.DecodeArrayLen() 138 if err != nil { 139 return nil, err 140 } 141 for i := 0; i < argLen; i++ { 142 argument, err := decoder.DecodeRaw() 143 if err != nil { 144 return nil, err 145 } 146 invocationMessage.Arguments = append(invocationMessage.Arguments, argument) 147 } 148 // StreamIds seem to be optional 149 if msgLen == 6 { 150 streamIDLen, err := decoder.DecodeArrayLen() 151 if err != nil { 152 return nil, err 153 } 154 for i := 0; i < streamIDLen; i++ { 155 streamID, err := decoder.DecodeString() 156 if err != nil { 157 return nil, err 158 } 159 invocationMessage.StreamIds = append(invocationMessage.StreamIds, streamID) 160 } 161 } 162 return invocationMessage, nil 163 case 2: 164 if msgLen != 4 { 165 return nil, fmt.Errorf("invalid streamItemMessage length %v", msgLen) 166 } 167 streamItemMessage := streamItemMessage{Type: 2} 168 streamItemMessage.InvocationID, err = decoder.DecodeString() 169 if err != nil { 170 return nil, err 171 } 172 streamItemMessage.Item, err = decoder.DecodeRaw() 173 if err != nil { 174 return nil, err 175 } 176 return streamItemMessage, nil 177 case 3: 178 if msgLen < 4 { 179 return nil, fmt.Errorf("invalid completionMessage length %v", msgLen) 180 } 181 completionMessage := completionMessage{Type: 3} 182 completionMessage.InvocationID, err = decoder.DecodeString() 183 if err != nil { 184 return nil, err 185 } 186 resultKind, err := decoder.DecodeInt8() 187 if err != nil { 188 return nil, err 189 } 190 switch resultKind { 191 case 1: // Error result 192 if msgLen < 5 { 193 return nil, fmt.Errorf("invalid completionMessage length %v", msgLen) 194 } 195 completionMessage.Error, err = decoder.DecodeString() 196 if err != nil { 197 return nil, err 198 } 199 case 2: // Void result 200 case 3: // Non-void result 201 if msgLen < 5 { 202 return nil, fmt.Errorf("invalid completionMessage length %v", msgLen) 203 } 204 completionMessage.Result, err = decoder.DecodeRaw() 205 if err != nil { 206 return nil, err 207 } 208 default: 209 return nil, fmt.Errorf("invalid resultKind %v", resultKind) 210 } 211 return completionMessage, nil 212 case 5: 213 if msgLen != 3 { 214 return nil, fmt.Errorf("invalid cancelInvocationMessage length %v", msgLen) 215 } 216 cancelInvocationMessage := cancelInvocationMessage{Type: 5} 217 cancelInvocationMessage.InvocationID, err = decoder.DecodeString() 218 if err != nil { 219 return nil, err 220 } 221 return cancelInvocationMessage, nil 222 case 6: 223 if msgLen != 1 { 224 return nil, fmt.Errorf("invalid pingMessage length %v", msgLen) 225 } 226 return hubMessage{Type: 6}, nil 227 case 7: 228 if msgLen < 2 { 229 return nil, fmt.Errorf("invalid closeMessage length %v", msgLen) 230 } 231 closeMessage := closeMessage{Type: 7} 232 closeMessage.Error, err = decoder.DecodeString() 233 if err != nil { 234 return nil, err 235 } 236 if msgLen > 2 { 237 closeMessage.AllowReconnect, err = decoder.DecodeBool() 238 if err != nil { 239 return nil, err 240 } 241 } 242 return closeMessage, nil 243 } 244 return msg, nil 245 } 246 247 func (m *messagePackHubProtocol) decodeInvocationID(decoder *msgpack.Decoder) (string, error) { 248 rawID, err := decoder.DecodeInterface() 249 if err != nil { 250 return "", err 251 } 252 // nil is ok 253 if rawID == nil { 254 return "", nil 255 } 256 // Otherwise, it must be string 257 invocationID, ok := rawID.(string) 258 if !ok { 259 return "", fmt.Errorf("invalid InvocationID %#v", rawID) 260 } 261 return invocationID, nil 262 } 263 264 func (m *messagePackHubProtocol) WriteMessage(message interface{}, writer io.Writer) error { 265 // Encode message body 266 buf := &bytes.Buffer{} 267 encoder := msgpack.NewEncoder(buf) 268 // Ensure uppercase/lowercase mapping for struct member names 269 encoder.SetCustomStructTag("json") 270 switch msg := message.(type) { 271 case invocationMessage: 272 if err := encodeMsgHeader(encoder, 6, msg.Type); err != nil { 273 return err 274 } 275 if msg.InvocationID == "" { 276 if err := encoder.EncodeNil(); err != nil { 277 return err 278 } 279 } else { 280 if err := encoder.EncodeString(msg.InvocationID); err != nil { 281 return err 282 } 283 } 284 if err := encoder.EncodeString(msg.Target); err != nil { 285 return err 286 } 287 if err := encoder.EncodeArrayLen(len(msg.Arguments)); err != nil { 288 return err 289 } 290 for _, arg := range msg.Arguments { 291 if err := encoder.Encode(arg); err != nil { 292 return err 293 } 294 } 295 if err := encoder.EncodeArrayLen(len(msg.StreamIds)); err != nil { 296 return err 297 } 298 for _, id := range msg.StreamIds { 299 if err := encoder.EncodeString(id); err != nil { 300 return err 301 } 302 } 303 case streamItemMessage: 304 if err := encodeMsgHeader(encoder, 4, msg.Type); err != nil { 305 return err 306 } 307 if err := encoder.EncodeString(msg.InvocationID); err != nil { 308 return err 309 } 310 if err := encoder.Encode(msg.Item); err != nil { 311 return err 312 } 313 case completionMessage: 314 msgLen := 4 315 if msg.Result != nil || msg.Error != "" { 316 msgLen = 5 317 } 318 if err := encodeMsgHeader(encoder, msgLen, msg.Type); err != nil { 319 return err 320 } 321 if err := encoder.EncodeString(msg.InvocationID); err != nil { 322 return err 323 } 324 var resultKind int8 = 2 325 if msg.Error != "" { 326 resultKind = 1 327 } else if msg.Result != nil { 328 resultKind = 3 329 } 330 if err := encoder.EncodeInt8(resultKind); err != nil { 331 return err 332 } 333 switch resultKind { 334 case 1: 335 if err := encoder.EncodeString(msg.Error); err != nil { 336 return err 337 } 338 case 3: 339 if err := encoder.Encode(msg.Result); err != nil { 340 return err 341 } 342 } 343 case cancelInvocationMessage: 344 if err := encodeMsgHeader(encoder, 3, msg.Type); err != nil { 345 return err 346 } 347 if err := encoder.EncodeString(msg.InvocationID); err != nil { 348 return err 349 } 350 case hubMessage: 351 if err := encoder.EncodeArrayLen(1); err != nil { 352 return err 353 } 354 if err := encoder.EncodeInt(6); err != nil { 355 return err 356 } 357 case closeMessage: 358 if err := encodeMsgHeader(encoder, 3, msg.Type); err != nil { 359 return err 360 } 361 if err := encoder.EncodeString(msg.Error); err != nil { 362 return err 363 } 364 if err := encoder.EncodeBool(msg.AllowReconnect); err != nil { 365 return err 366 } 367 } 368 // Build frame with length information 369 frameBuf := &bytes.Buffer{} 370 lenBuf := make([]byte, binary.MaxVarintLen32) 371 lenLen := binary.PutUvarint(lenBuf, uint64(buf.Len())) 372 if _, err := frameBuf.Write(lenBuf[:lenLen]); err != nil { 373 return err 374 } 375 _ = m.dbg.Log(evt, "Write", msg, fmt.Sprintf("%#v", message)) 376 _, _ = frameBuf.ReadFrom(buf) 377 _, err := frameBuf.WriteTo(writer) 378 return err 379 } 380 381 func encodeMsgHeader(e *msgpack.Encoder, msgLen int, msgType int) (err error) { 382 if err = e.EncodeArrayLen(msgLen); err != nil { 383 return err 384 } 385 if err = e.EncodeInt(int64(msgType)); err != nil { 386 return err 387 } 388 headers := make(map[string]interface{}) 389 if err = e.EncodeMap(headers); err != nil { 390 return err 391 } 392 return nil 393 } 394 395 func (m *messagePackHubProtocol) transferMode() TransferMode { 396 return BinaryTransferMode 397 } 398 399 func (m *messagePackHubProtocol) setDebugLogger(dbg StructuredLogger) { 400 m.dbg = log.WithPrefix(dbg, "ts", log.DefaultTimestampUTC, "protocol", "MSGP") 401 } 402 403 // UnmarshalArgument unmarshals raw bytes to a destination value. dst is the pointer to the destination value. 404 func (m *messagePackHubProtocol) UnmarshalArgument(src interface{}, dst interface{}) error { 405 rawSrc, ok := src.(msgpack.RawMessage) 406 if !ok { 407 return fmt.Errorf("invalid source %#v for UnmarshalArgument", src) 408 } 409 buf := bytes.NewBuffer(rawSrc) 410 decoder := msgpack.GetDecoder() 411 defer msgpack.PutDecoder(decoder) 412 decoder.Reset(buf) 413 // Default map decoding expects all maps to have string keys 414 decoder.SetMapDecoder(func(decoder *msgpack.Decoder) (interface{}, error) { 415 return decoder.DecodeUntypedMap() 416 }) 417 // Ensure uppercase/lowercase mapping for struct member names 418 decoder.SetCustomStructTag("json") 419 if err := decoder.Decode(dst); err != nil { 420 return err 421 } 422 return nil 423 }