github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/test/level_1/struct_test.go (about) 1 package test 2 3 import ( 4 "github.com/apache/thrift/lib/go/thrift" 5 "github.com/batchcorp/thrift-iterator/general" 6 "github.com/batchcorp/thrift-iterator/protocol" 7 "github.com/batchcorp/thrift-iterator/raw" 8 "github.com/batchcorp/thrift-iterator/test" 9 "github.com/batchcorp/thrift-iterator/test/level_1/struct_test" 10 "github.com/stretchr/testify/require" 11 "testing" 12 ) 13 14 func Test_decode_struct_by_iterator(t *testing.T) { 15 should := require.New(t) 16 for _, c := range test.Combinations { 17 buf, proto := c.CreateProtocol() 18 proto.WriteStructBegin("hello") 19 proto.WriteFieldBegin("field1", thrift.I64, 1) 20 proto.WriteI64(1024) 21 proto.WriteFieldEnd() 22 proto.WriteFieldStop() 23 proto.WriteStructEnd() 24 iter := c.CreateIterator(buf.Bytes()) 25 called := false 26 iter.ReadStructHeader() 27 for { 28 fieldType, fieldId := iter.ReadStructField() 29 if fieldType == protocol.TypeStop { 30 break 31 } 32 should.False(called) 33 called = true 34 should.Equal(protocol.TypeI64, fieldType) 35 should.Equal(protocol.FieldId(1), fieldId) 36 should.Equal(int64(1024), iter.ReadInt64()) 37 } 38 should.NoError(iter.Error()) 39 should.True(called) 40 } 41 } 42 43 func Test_decode_struct_with_bool_by_iterator(t *testing.T) { 44 should := require.New(t) 45 for _, c := range test.Combinations { 46 buf, proto := c.CreateProtocol() 47 proto.WriteStructBegin("hello") 48 proto.WriteFieldBegin("field1", thrift.BOOL, 1) 49 proto.WriteBool(true) 50 proto.WriteFieldEnd() 51 proto.WriteFieldStop() 52 proto.WriteStructEnd() 53 iter := c.CreateIterator(buf.Bytes()) 54 called := false 55 iter.ReadStructHeader() 56 for { 57 fieldType, fieldId := iter.ReadStructField() 58 if fieldType == protocol.TypeStop { 59 break 60 } 61 should.False(called) 62 called = true 63 should.Equal(protocol.TypeBool, fieldType) 64 should.Equal(protocol.FieldId(1), fieldId) 65 should.Equal(true, iter.ReadBool()) 66 } 67 should.True(called) 68 } 69 } 70 71 func Test_encode_struct_by_stream(t *testing.T) { 72 should := require.New(t) 73 for _, c := range test.Combinations { 74 stream := c.CreateStream() 75 stream.WriteStructHeader() 76 stream.WriteStructField(protocol.TypeI64, protocol.FieldId(1)) 77 stream.WriteInt64(1024) 78 stream.WriteStructFieldStop() 79 iter := c.CreateIterator(stream.Buffer()) 80 called := false 81 iter.ReadStructHeader() 82 for { 83 fieldType, fieldId := iter.ReadStructField() 84 if fieldType == protocol.TypeStop { 85 break 86 } 87 should.False(called) 88 called = true 89 should.Equal(protocol.TypeI64, fieldType) 90 should.Equal(protocol.FieldId(1), fieldId) 91 should.Equal(int64(1024), iter.ReadInt64()) 92 } 93 } 94 } 95 96 func Test_encode_struct_with_bool_by_stream(t *testing.T) { 97 should := require.New(t) 98 for _, c := range test.Combinations { 99 stream := c.CreateStream() 100 stream.WriteStructHeader() 101 stream.WriteStructField(protocol.TypeBool, protocol.FieldId(1)) 102 stream.WriteBool(true) 103 stream.WriteStructFieldStop() 104 iter := c.CreateIterator(stream.Buffer()) 105 called := false 106 iter.ReadStructHeader() 107 for { 108 fieldType, fieldId := iter.ReadStructField() 109 if fieldType == protocol.TypeStop { 110 break 111 } 112 should.False(called) 113 called = true 114 should.Equal(protocol.TypeBool, fieldType) 115 should.Equal(protocol.FieldId(1), fieldId) 116 should.Equal(true, iter.ReadBool()) 117 } 118 should.True(called) 119 } 120 } 121 122 func Test_skip_struct(t *testing.T) { 123 should := require.New(t) 124 for _, c := range test.Combinations { 125 buf, proto := c.CreateProtocol() 126 proto.WriteStructBegin("hello") 127 proto.WriteFieldBegin("field1", thrift.I64, 1) 128 proto.WriteI64(1024) 129 proto.WriteFieldEnd() 130 proto.WriteFieldStop() 131 proto.WriteStructEnd() 132 iter := c.CreateIterator(buf.Bytes()) 133 should.Equal(buf.Bytes(), iter.SkipStruct(nil)) 134 } 135 } 136 137 func Test_unmarshal_general_struct(t *testing.T) { 138 should := require.New(t) 139 for _, c := range test.Combinations { 140 buf, proto := c.CreateProtocol() 141 proto.WriteStructBegin("hello") 142 proto.WriteFieldBegin("field1", thrift.I64, 1) 143 proto.WriteI64(1024) 144 proto.WriteFieldEnd() 145 proto.WriteFieldStop() 146 proto.WriteStructEnd() 147 var val general.Struct 148 should.NoError(c.Unmarshal(buf.Bytes(), &val)) 149 should.Equal(general.Struct{ 150 protocol.FieldId(1): int64(1024), 151 }, val) 152 } 153 } 154 155 func Test_unmarshal_raw_struct(t *testing.T) { 156 should := require.New(t) 157 for _, c := range test.Combinations { 158 buf, proto := c.CreateProtocol() 159 proto.WriteStructBegin("hello") 160 proto.WriteFieldBegin("field1", thrift.I64, 1) 161 proto.WriteI64(1024) 162 proto.WriteFieldEnd() 163 proto.WriteFieldStop() 164 proto.WriteStructEnd() 165 var val raw.Struct 166 should.NoError(c.Unmarshal(buf.Bytes(), &val)) 167 should.Equal(1, len(val)) 168 should.Equal(protocol.TypeI64, val[protocol.FieldId(1)].Type) 169 } 170 } 171 172 func Test_unmarshal_struct(t *testing.T) { 173 should := require.New(t) 174 for _, c := range test.UnmarshalCombinations { 175 buf, proto := c.CreateProtocol() 176 proto.WriteStructBegin("hello") 177 proto.WriteFieldBegin("field1", thrift.I64, 1) 178 proto.WriteI64(1024) 179 proto.WriteFieldEnd() 180 proto.WriteFieldStop() 181 proto.WriteStructEnd() 182 var val struct_test.TestObject 183 should.NoError(c.Unmarshal(buf.Bytes(), &val)) 184 should.Equal(struct_test.TestObject{1024}, val) 185 } 186 } 187 188 func Test_marshal_general_struct(t *testing.T) { 189 should := require.New(t) 190 for _, c := range test.Combinations { 191 output, err := c.Marshal(general.Struct{ 192 protocol.FieldId(1): int64(1024), 193 }) 194 should.NoError(err) 195 var val general.Struct 196 should.NoError(c.Unmarshal(output, &val)) 197 should.Equal(general.Struct{ 198 protocol.FieldId(1): int64(1024), 199 }, val) 200 } 201 } 202 203 func Test_marshal_raw_struct(t *testing.T) { 204 should := require.New(t) 205 for _, c := range test.Combinations { 206 buf, proto := c.CreateProtocol() 207 proto.WriteStructBegin("hello") 208 proto.WriteFieldBegin("field1", thrift.I64, 1) 209 proto.WriteI64(1024) 210 proto.WriteFieldEnd() 211 proto.WriteFieldStop() 212 proto.WriteStructEnd() 213 var val raw.Struct 214 should.NoError(c.Unmarshal(buf.Bytes(), &val)) 215 output, err := c.Marshal(val) 216 should.NoError(err) 217 var generalVal general.Struct 218 should.NoError(c.Unmarshal(output, &generalVal)) 219 should.Equal(general.Struct{ 220 protocol.FieldId(1): int64(1024), 221 }, generalVal) 222 } 223 } 224 225 func Test_marshal_struct(t *testing.T) { 226 should := require.New(t) 227 for _, c := range test.MarshalCombinations { 228 output, err := c.Marshal(struct_test.TestObject{1024}) 229 should.NoError(err) 230 iter := c.CreateIterator(output) 231 called := false 232 iter.ReadStructHeader() 233 for { 234 fieldType, fieldId := iter.ReadStructField() 235 if fieldType == protocol.TypeStop { 236 break 237 } 238 should.False(called) 239 called = true 240 should.Equal(protocol.TypeI64, fieldType) 241 should.Equal(protocol.FieldId(1), fieldId) 242 should.Equal(int64(1024), iter.ReadInt64()) 243 } 244 should.True(called) 245 } 246 }