github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/test/level_2/message_test.go (about) 1 package test 2 3 import ( 4 "github.com/apache/thrift/lib/go/thrift" 5 "github.com/batchcorp/thrift-iterator/general" 6 "github.com/batchcorp/thrift-iterator/protocol" 7 "github.com/batchcorp/thrift-iterator/test" 8 "github.com/stretchr/testify/require" 9 "testing" 10 ) 11 12 func Test_skip_message(t *testing.T) { 13 should := require.New(t) 14 for _, c := range test.Combinations { 15 buf, proto := c.CreateProtocol() 16 proto.WriteMessageBegin("hello", thrift.CALL, 17) 17 proto.WriteStructBegin("args") 18 proto.WriteFieldBegin("field1", thrift.I64, 1) 19 proto.WriteI64(1) 20 proto.WriteFieldBegin("field2", thrift.I64, 2) 21 proto.WriteI64(2) 22 proto.WriteFieldEnd() 23 proto.WriteFieldStop() 24 proto.WriteStructEnd() 25 proto.WriteMessageEnd() 26 iter := c.CreateIterator(buf.Bytes()) 27 should.Equal(buf.Bytes(), iter.SkipStruct(iter.SkipMessageHeader(nil))) 28 } 29 } 30 31 func Test_unmarshal_message(t *testing.T) { 32 should := require.New(t) 33 for _, c := range test.Combinations { 34 buf, proto := c.CreateProtocol() 35 proto.WriteMessageBegin("hello", thrift.CALL, 17) 36 proto.WriteStructBegin("args") 37 proto.WriteFieldBegin("field1", thrift.I64, 1) 38 proto.WriteI64(1) 39 proto.WriteFieldBegin("field2", thrift.I64, 2) 40 proto.WriteI64(2) 41 proto.WriteFieldEnd() 42 proto.WriteFieldStop() 43 proto.WriteStructEnd() 44 proto.WriteMessageEnd() 45 var msg general.Message 46 should.NoError(c.Unmarshal(buf.Bytes(), &msg)) 47 should.Equal("hello", msg.MessageName) 48 should.Equal(protocol.MessageTypeCall, msg.MessageType) 49 should.Equal(protocol.SeqId(17), msg.SeqId) 50 should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)]) 51 should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)]) 52 } 53 } 54 55 func Test_marshal_message(t *testing.T) { 56 should := require.New(t) 57 for _, c := range test.Combinations { 58 output, err := c.Marshal(general.Message{ 59 MessageHeader: protocol.MessageHeader{ 60 MessageType: protocol.MessageTypeCall, 61 MessageName: "hello", 62 SeqId: protocol.SeqId(17), 63 }, 64 Arguments: general.Struct{ 65 protocol.FieldId(1): int64(1), 66 protocol.FieldId(2): int64(2), 67 }, 68 }) 69 should.NoError(err) 70 var msg general.Message 71 should.NoError(c.Unmarshal(output, &msg)) 72 should.Equal("hello", msg.MessageName) 73 should.Equal(protocol.MessageTypeCall, msg.MessageType) 74 should.Equal(protocol.SeqId(17), msg.SeqId) 75 should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)]) 76 should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)]) 77 } 78 }