github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/test/level_2/message_test.go (about)

     1  package test
     2  
     3  import (
     4  	"github.com/apache/thrift/lib/go/thrift"
     5  	"github.com/batchcorp/thrift-iterator/general"
     6  	"github.com/batchcorp/thrift-iterator/protocol"
     7  	"github.com/batchcorp/thrift-iterator/test"
     8  	"github.com/stretchr/testify/require"
     9  	"testing"
    10  )
    11  
    12  func Test_skip_message(t *testing.T) {
    13  	should := require.New(t)
    14  	for _, c := range test.Combinations {
    15  		buf, proto := c.CreateProtocol()
    16  		proto.WriteMessageBegin("hello", thrift.CALL, 17)
    17  		proto.WriteStructBegin("args")
    18  		proto.WriteFieldBegin("field1", thrift.I64, 1)
    19  		proto.WriteI64(1)
    20  		proto.WriteFieldBegin("field2", thrift.I64, 2)
    21  		proto.WriteI64(2)
    22  		proto.WriteFieldEnd()
    23  		proto.WriteFieldStop()
    24  		proto.WriteStructEnd()
    25  		proto.WriteMessageEnd()
    26  		iter := c.CreateIterator(buf.Bytes())
    27  		should.Equal(buf.Bytes(), iter.SkipStruct(iter.SkipMessageHeader(nil)))
    28  	}
    29  }
    30  
    31  func Test_unmarshal_message(t *testing.T) {
    32  	should := require.New(t)
    33  	for _, c := range test.Combinations {
    34  		buf, proto := c.CreateProtocol()
    35  		proto.WriteMessageBegin("hello", thrift.CALL, 17)
    36  		proto.WriteStructBegin("args")
    37  		proto.WriteFieldBegin("field1", thrift.I64, 1)
    38  		proto.WriteI64(1)
    39  		proto.WriteFieldBegin("field2", thrift.I64, 2)
    40  		proto.WriteI64(2)
    41  		proto.WriteFieldEnd()
    42  		proto.WriteFieldStop()
    43  		proto.WriteStructEnd()
    44  		proto.WriteMessageEnd()
    45  		var msg general.Message
    46  		should.NoError(c.Unmarshal(buf.Bytes(), &msg))
    47  		should.Equal("hello", msg.MessageName)
    48  		should.Equal(protocol.MessageTypeCall, msg.MessageType)
    49  		should.Equal(protocol.SeqId(17), msg.SeqId)
    50  		should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)])
    51  		should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)])
    52  	}
    53  }
    54  
    55  func Test_marshal_message(t *testing.T) {
    56  	should := require.New(t)
    57  	for _, c := range test.Combinations {
    58  		output, err := c.Marshal(general.Message{
    59  			MessageHeader: protocol.MessageHeader{
    60  				MessageType: protocol.MessageTypeCall,
    61  				MessageName: "hello",
    62  				SeqId:       protocol.SeqId(17),
    63  			},
    64  			Arguments: general.Struct{
    65  				protocol.FieldId(1): int64(1),
    66  				protocol.FieldId(2): int64(2),
    67  			},
    68  		})
    69  		should.NoError(err)
    70  		var msg general.Message
    71  		should.NoError(c.Unmarshal(output, &msg))
    72  		should.Equal("hello", msg.MessageName)
    73  		should.Equal(protocol.MessageTypeCall, msg.MessageType)
    74  		should.Equal(protocol.SeqId(17), msg.SeqId)
    75  		should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)])
    76  		should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)])
    77  	}
    78  }