github.com/philippseith/signalr@v0.6.3/messagepackhubprotocol.go (about)

     1  package signalr
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/go-kit/log"
    11  	"github.com/vmihailenco/msgpack/v5"
    12  )
    13  
// messagePackHubProtocol is the hub protocol implementation that encodes
// messages with MessagePack: ParseMessages decodes length-prefixed frames,
// WriteMessage produces them.
type messagePackHubProtocol struct {
	// dbg receives debug log entries; it is assigned by setDebugLogger.
	dbg log.Logger
}
    17  
    18  func (m *messagePackHubProtocol) ParseMessages(reader io.Reader, remainBuf *bytes.Buffer) ([]interface{}, error) {
    19  	frames, err := m.readFrames(reader, remainBuf)
    20  	if err != nil {
    21  		return nil, err
    22  	}
    23  	messages := make([]interface{}, 0)
    24  	for _, frame := range frames {
    25  		message, err := m.parseMessage(bytes.NewBuffer(frame))
    26  		if err != nil {
    27  			return nil, err
    28  		}
    29  		messages = append(messages, message)
    30  	}
    31  	return messages, nil
    32  }
    33  
    34  func (m *messagePackHubProtocol) readFrames(reader io.Reader, remainBuf *bytes.Buffer) ([][]byte, error) {
    35  	frames := make([][]byte, 0)
    36  	for {
    37  		// Try to get the frame length
    38  		frameLenBuf := make([]byte, binary.MaxVarintLen32)
    39  		n1, err := remainBuf.Read(frameLenBuf)
    40  		if err != nil && !errors.Is(err, io.EOF) {
    41  			// Some weird other error
    42  			return nil, err
    43  		}
    44  		n2, err := reader.Read(frameLenBuf[n1:])
    45  		if err != nil && !errors.Is(err, io.EOF) {
    46  			// Some weird other error
    47  			return nil, err
    48  		}
    49  		frameLen, lenLen := binary.Uvarint(frameLenBuf[:n1+n2])
    50  		if lenLen == 0 {
    51  			// reader could not supply enough bytes to decode the Uvarint
    52  			// Store the already read bytes in the remainBuf for next iteration
    53  			_, _ = remainBuf.Write(frameLenBuf[:n1+n2])
    54  			return frames, nil
    55  		}
    56  		if lenLen < 0 {
    57  			return nil, fmt.Errorf("messagepack frame length to large")
    58  		}
    59  		// Still wondering why this happens, but it happens!
    60  		if frameLen == 0 {
    61  			// Store the overread bytes for the next iteration
    62  			_, _ = remainBuf.Write(frameLenBuf[lenLen:])
    63  			continue
    64  		}
    65  		// Try getting data until at least one frame is available
    66  		readBuf := make([]byte, frameLen)
    67  		frameBuf := &bytes.Buffer{}
    68  		// Did we read too many bytes when detecting the frameLen?
    69  		_, _ = frameBuf.Write(frameLenBuf[lenLen:])
    70  		// Read the rest of the bytes from the last iteration
    71  		_, _ = frameBuf.ReadFrom(remainBuf)
    72  		for {
    73  			n, err := reader.Read(readBuf)
    74  			if errors.Is(err, io.EOF) {
    75  				// Less than frameLen. Let the caller parse the already read frames and come here again later
    76  				_, _ = remainBuf.ReadFrom(frameBuf)
    77  				return frames, nil
    78  			}
    79  			if err != nil {
    80  				return nil, err
    81  			}
    82  			_, _ = frameBuf.Write(readBuf[:n])
    83  			if frameBuf.Len() == int(frameLen) {
    84  				// Frame completely read. Return it to the caller
    85  				frames = append(frames, frameBuf.Next(int(frameLen)))
    86  				return frames, nil
    87  			}
    88  			if frameBuf.Len() > int(frameLen) {
    89  				// More than frameLen. Append the current frame to the result and start reading the next frame
    90  				frames = append(frames, frameBuf.Next(int(frameLen)))
    91  				_, _ = remainBuf.ReadFrom(frameBuf)
    92  				break
    93  			}
    94  		}
    95  	}
    96  }
    97  
    98  func (m *messagePackHubProtocol) parseMessage(buf *bytes.Buffer) (interface{}, error) {
    99  	decoder := msgpack.NewDecoder(buf)
   100  	// Default map decoding expects all maps to have string keys
   101  	decoder.SetMapDecoder(func(decoder *msgpack.Decoder) (interface{}, error) {
   102  		return decoder.DecodeUntypedMap()
   103  	})
   104  	msgLen, err := decoder.DecodeArrayLen()
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	msgType, err := decoder.DecodeInt()
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	// Ignore Header for all messages, except ping message that has no header
   113  	// see message spec at https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#message-headers
   114  	if msgType != 6 {
   115  		_, err = decoder.DecodeMap()
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  	}
   120  	switch msgType {
   121  	case 1, 4:
   122  		if msgLen < 5 {
   123  			return nil, fmt.Errorf("invalid invocationMessage length %v", msgLen)
   124  		}
   125  		invocationID, err := m.decodeInvocationID(decoder)
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  		invocationMessage := invocationMessage{
   130  			Type:         msgType,
   131  			InvocationID: invocationID,
   132  		}
   133  		invocationMessage.Target, err = decoder.DecodeString()
   134  		if err != nil {
   135  			return nil, err
   136  		}
   137  		argLen, err := decoder.DecodeArrayLen()
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		for i := 0; i < argLen; i++ {
   142  			argument, err := decoder.DecodeRaw()
   143  			if err != nil {
   144  				return nil, err
   145  			}
   146  			invocationMessage.Arguments = append(invocationMessage.Arguments, argument)
   147  		}
   148  		// StreamIds seem to be optional
   149  		if msgLen == 6 {
   150  			streamIDLen, err := decoder.DecodeArrayLen()
   151  			if err != nil {
   152  				return nil, err
   153  			}
   154  			for i := 0; i < streamIDLen; i++ {
   155  				streamID, err := decoder.DecodeString()
   156  				if err != nil {
   157  					return nil, err
   158  				}
   159  				invocationMessage.StreamIds = append(invocationMessage.StreamIds, streamID)
   160  			}
   161  		}
   162  		return invocationMessage, nil
   163  	case 2:
   164  		if msgLen != 4 {
   165  			return nil, fmt.Errorf("invalid streamItemMessage length %v", msgLen)
   166  		}
   167  		streamItemMessage := streamItemMessage{Type: 2}
   168  		streamItemMessage.InvocationID, err = decoder.DecodeString()
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  		streamItemMessage.Item, err = decoder.DecodeRaw()
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  		return streamItemMessage, nil
   177  	case 3:
   178  		if msgLen < 4 {
   179  			return nil, fmt.Errorf("invalid completionMessage length %v", msgLen)
   180  		}
   181  		completionMessage := completionMessage{Type: 3}
   182  		completionMessage.InvocationID, err = decoder.DecodeString()
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		resultKind, err := decoder.DecodeInt8()
   187  		if err != nil {
   188  			return nil, err
   189  		}
   190  		switch resultKind {
   191  		case 1: // Error result
   192  			if msgLen < 5 {
   193  				return nil, fmt.Errorf("invalid completionMessage length %v", msgLen)
   194  			}
   195  			completionMessage.Error, err = decoder.DecodeString()
   196  			if err != nil {
   197  				return nil, err
   198  			}
   199  		case 2: // Void result
   200  		case 3: // Non-void result
   201  			if msgLen < 5 {
   202  				return nil, fmt.Errorf("invalid completionMessage length %v", msgLen)
   203  			}
   204  			completionMessage.Result, err = decoder.DecodeRaw()
   205  			if err != nil {
   206  				return nil, err
   207  			}
   208  		default:
   209  			return nil, fmt.Errorf("invalid resultKind %v", resultKind)
   210  		}
   211  		return completionMessage, nil
   212  	case 5:
   213  		if msgLen != 3 {
   214  			return nil, fmt.Errorf("invalid cancelInvocationMessage length %v", msgLen)
   215  		}
   216  		cancelInvocationMessage := cancelInvocationMessage{Type: 5}
   217  		cancelInvocationMessage.InvocationID, err = decoder.DecodeString()
   218  		if err != nil {
   219  			return nil, err
   220  		}
   221  		return cancelInvocationMessage, nil
   222  	case 6:
   223  		if msgLen != 1 {
   224  			return nil, fmt.Errorf("invalid pingMessage length %v", msgLen)
   225  		}
   226  		return hubMessage{Type: 6}, nil
   227  	case 7:
   228  		if msgLen < 2 {
   229  			return nil, fmt.Errorf("invalid closeMessage length %v", msgLen)
   230  		}
   231  		closeMessage := closeMessage{Type: 7}
   232  		closeMessage.Error, err = decoder.DecodeString()
   233  		if err != nil {
   234  			return nil, err
   235  		}
   236  		if msgLen > 2 {
   237  			closeMessage.AllowReconnect, err = decoder.DecodeBool()
   238  			if err != nil {
   239  				return nil, err
   240  			}
   241  		}
   242  		return closeMessage, nil
   243  	}
   244  	return msg, nil
   245  }
   246  
   247  func (m *messagePackHubProtocol) decodeInvocationID(decoder *msgpack.Decoder) (string, error) {
   248  	rawID, err := decoder.DecodeInterface()
   249  	if err != nil {
   250  		return "", err
   251  	}
   252  	// nil is ok
   253  	if rawID == nil {
   254  		return "", nil
   255  	}
   256  	// Otherwise, it must be string
   257  	invocationID, ok := rawID.(string)
   258  	if !ok {
   259  		return "", fmt.Errorf("invalid InvocationID %#v", rawID)
   260  	}
   261  	return invocationID, nil
   262  }
   263  
   264  func (m *messagePackHubProtocol) WriteMessage(message interface{}, writer io.Writer) error {
   265  	// Encode message body
   266  	buf := &bytes.Buffer{}
   267  	encoder := msgpack.NewEncoder(buf)
   268  	// Ensure uppercase/lowercase mapping for struct member names
   269  	encoder.SetCustomStructTag("json")
   270  	switch msg := message.(type) {
   271  	case invocationMessage:
   272  		if err := encodeMsgHeader(encoder, 6, msg.Type); err != nil {
   273  			return err
   274  		}
   275  		if msg.InvocationID == "" {
   276  			if err := encoder.EncodeNil(); err != nil {
   277  				return err
   278  			}
   279  		} else {
   280  			if err := encoder.EncodeString(msg.InvocationID); err != nil {
   281  				return err
   282  			}
   283  		}
   284  		if err := encoder.EncodeString(msg.Target); err != nil {
   285  			return err
   286  		}
   287  		if err := encoder.EncodeArrayLen(len(msg.Arguments)); err != nil {
   288  			return err
   289  		}
   290  		for _, arg := range msg.Arguments {
   291  			if err := encoder.Encode(arg); err != nil {
   292  				return err
   293  			}
   294  		}
   295  		if err := encoder.EncodeArrayLen(len(msg.StreamIds)); err != nil {
   296  			return err
   297  		}
   298  		for _, id := range msg.StreamIds {
   299  			if err := encoder.EncodeString(id); err != nil {
   300  				return err
   301  			}
   302  		}
   303  	case streamItemMessage:
   304  		if err := encodeMsgHeader(encoder, 4, msg.Type); err != nil {
   305  			return err
   306  		}
   307  		if err := encoder.EncodeString(msg.InvocationID); err != nil {
   308  			return err
   309  		}
   310  		if err := encoder.Encode(msg.Item); err != nil {
   311  			return err
   312  		}
   313  	case completionMessage:
   314  		msgLen := 4
   315  		if msg.Result != nil || msg.Error != "" {
   316  			msgLen = 5
   317  		}
   318  		if err := encodeMsgHeader(encoder, msgLen, msg.Type); err != nil {
   319  			return err
   320  		}
   321  		if err := encoder.EncodeString(msg.InvocationID); err != nil {
   322  			return err
   323  		}
   324  		var resultKind int8 = 2
   325  		if msg.Error != "" {
   326  			resultKind = 1
   327  		} else if msg.Result != nil {
   328  			resultKind = 3
   329  		}
   330  		if err := encoder.EncodeInt8(resultKind); err != nil {
   331  			return err
   332  		}
   333  		switch resultKind {
   334  		case 1:
   335  			if err := encoder.EncodeString(msg.Error); err != nil {
   336  				return err
   337  			}
   338  		case 3:
   339  			if err := encoder.Encode(msg.Result); err != nil {
   340  				return err
   341  			}
   342  		}
   343  	case cancelInvocationMessage:
   344  		if err := encodeMsgHeader(encoder, 3, msg.Type); err != nil {
   345  			return err
   346  		}
   347  		if err := encoder.EncodeString(msg.InvocationID); err != nil {
   348  			return err
   349  		}
   350  	case hubMessage:
   351  		if err := encoder.EncodeArrayLen(1); err != nil {
   352  			return err
   353  		}
   354  		if err := encoder.EncodeInt(6); err != nil {
   355  			return err
   356  		}
   357  	case closeMessage:
   358  		if err := encodeMsgHeader(encoder, 3, msg.Type); err != nil {
   359  			return err
   360  		}
   361  		if err := encoder.EncodeString(msg.Error); err != nil {
   362  			return err
   363  		}
   364  		if err := encoder.EncodeBool(msg.AllowReconnect); err != nil {
   365  			return err
   366  		}
   367  	}
   368  	// Build frame with length information
   369  	frameBuf := &bytes.Buffer{}
   370  	lenBuf := make([]byte, binary.MaxVarintLen32)
   371  	lenLen := binary.PutUvarint(lenBuf, uint64(buf.Len()))
   372  	if _, err := frameBuf.Write(lenBuf[:lenLen]); err != nil {
   373  		return err
   374  	}
   375  	_ = m.dbg.Log(evt, "Write", msg, fmt.Sprintf("%#v", message))
   376  	_, _ = frameBuf.ReadFrom(buf)
   377  	_, err := frameBuf.WriteTo(writer)
   378  	return err
   379  }
   380  
   381  func encodeMsgHeader(e *msgpack.Encoder, msgLen int, msgType int) (err error) {
   382  	if err = e.EncodeArrayLen(msgLen); err != nil {
   383  		return err
   384  	}
   385  	if err = e.EncodeInt(int64(msgType)); err != nil {
   386  		return err
   387  	}
   388  	headers := make(map[string]interface{})
   389  	if err = e.EncodeMap(headers); err != nil {
   390  		return err
   391  	}
   392  	return nil
   393  }
   394  
// transferMode returns BinaryTransferMode, because this protocol writes
// binary MessagePack frames.
func (m *messagePackHubProtocol) transferMode() TransferMode {
	return BinaryTransferMode
}
   398  
   399  func (m *messagePackHubProtocol) setDebugLogger(dbg StructuredLogger) {
   400  	m.dbg = log.WithPrefix(dbg, "ts", log.DefaultTimestampUTC, "protocol", "MSGP")
   401  }
   402  
   403  // UnmarshalArgument unmarshals raw bytes to a destination value. dst is the pointer to the destination value.
   404  func (m *messagePackHubProtocol) UnmarshalArgument(src interface{}, dst interface{}) error {
   405  	rawSrc, ok := src.(msgpack.RawMessage)
   406  	if !ok {
   407  		return fmt.Errorf("invalid source %#v for UnmarshalArgument", src)
   408  	}
   409  	buf := bytes.NewBuffer(rawSrc)
   410  	decoder := msgpack.GetDecoder()
   411  	defer msgpack.PutDecoder(decoder)
   412  	decoder.Reset(buf)
   413  	// Default map decoding expects all maps to have string keys
   414  	decoder.SetMapDecoder(func(decoder *msgpack.Decoder) (interface{}, error) {
   415  		return decoder.DecodeUntypedMap()
   416  	})
   417  	// Ensure uppercase/lowercase mapping for struct member names
   418  	decoder.SetCustomStructTag("json")
   419  	if err := decoder.Decode(dst); err != nil {
   420  		return err
   421  	}
   422  	return nil
   423  }