github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/test/level_1/struct_test.go (about)

     1  package test
     2  
     3  import (
     4  	"github.com/apache/thrift/lib/go/thrift"
     5  	"github.com/batchcorp/thrift-iterator/general"
     6  	"github.com/batchcorp/thrift-iterator/protocol"
     7  	"github.com/batchcorp/thrift-iterator/raw"
     8  	"github.com/batchcorp/thrift-iterator/test"
     9  	"github.com/batchcorp/thrift-iterator/test/level_1/struct_test"
    10  	"github.com/stretchr/testify/require"
    11  	"testing"
    12  )
    13  
    14  func Test_decode_struct_by_iterator(t *testing.T) {
    15  	should := require.New(t)
    16  	for _, c := range test.Combinations {
    17  		buf, proto := c.CreateProtocol()
    18  		proto.WriteStructBegin("hello")
    19  		proto.WriteFieldBegin("field1", thrift.I64, 1)
    20  		proto.WriteI64(1024)
    21  		proto.WriteFieldEnd()
    22  		proto.WriteFieldStop()
    23  		proto.WriteStructEnd()
    24  		iter := c.CreateIterator(buf.Bytes())
    25  		called := false
    26  		iter.ReadStructHeader()
    27  		for {
    28  			fieldType, fieldId := iter.ReadStructField()
    29  			if fieldType == protocol.TypeStop {
    30  				break
    31  			}
    32  			should.False(called)
    33  			called = true
    34  			should.Equal(protocol.TypeI64, fieldType)
    35  			should.Equal(protocol.FieldId(1), fieldId)
    36  			should.Equal(int64(1024), iter.ReadInt64())
    37  		}
    38  		should.NoError(iter.Error())
    39  		should.True(called)
    40  	}
    41  }
    42  
    43  func Test_decode_struct_with_bool_by_iterator(t *testing.T) {
    44  	should := require.New(t)
    45  	for _, c := range test.Combinations {
    46  		buf, proto := c.CreateProtocol()
    47  		proto.WriteStructBegin("hello")
    48  		proto.WriteFieldBegin("field1", thrift.BOOL, 1)
    49  		proto.WriteBool(true)
    50  		proto.WriteFieldEnd()
    51  		proto.WriteFieldStop()
    52  		proto.WriteStructEnd()
    53  		iter := c.CreateIterator(buf.Bytes())
    54  		called := false
    55  		iter.ReadStructHeader()
    56  		for {
    57  			fieldType, fieldId := iter.ReadStructField()
    58  			if fieldType == protocol.TypeStop {
    59  				break
    60  			}
    61  			should.False(called)
    62  			called = true
    63  			should.Equal(protocol.TypeBool, fieldType)
    64  			should.Equal(protocol.FieldId(1), fieldId)
    65  			should.Equal(true, iter.ReadBool())
    66  		}
    67  		should.True(called)
    68  	}
    69  }
    70  
    71  func Test_encode_struct_by_stream(t *testing.T) {
    72  	should := require.New(t)
    73  	for _, c := range test.Combinations {
    74  		stream := c.CreateStream()
    75  		stream.WriteStructHeader()
    76  		stream.WriteStructField(protocol.TypeI64, protocol.FieldId(1))
    77  		stream.WriteInt64(1024)
    78  		stream.WriteStructFieldStop()
    79  		iter := c.CreateIterator(stream.Buffer())
    80  		called := false
    81  		iter.ReadStructHeader()
    82  		for {
    83  			fieldType, fieldId := iter.ReadStructField()
    84  			if fieldType == protocol.TypeStop {
    85  				break
    86  			}
    87  			should.False(called)
    88  			called = true
    89  			should.Equal(protocol.TypeI64, fieldType)
    90  			should.Equal(protocol.FieldId(1), fieldId)
    91  			should.Equal(int64(1024), iter.ReadInt64())
    92  		}
    93  	}
    94  }
    95  
    96  func Test_encode_struct_with_bool_by_stream(t *testing.T) {
    97  	should := require.New(t)
    98  	for _, c := range test.Combinations {
    99  		stream := c.CreateStream()
   100  		stream.WriteStructHeader()
   101  		stream.WriteStructField(protocol.TypeBool, protocol.FieldId(1))
   102  		stream.WriteBool(true)
   103  		stream.WriteStructFieldStop()
   104  		iter := c.CreateIterator(stream.Buffer())
   105  		called := false
   106  		iter.ReadStructHeader()
   107  		for {
   108  			fieldType, fieldId := iter.ReadStructField()
   109  			if fieldType == protocol.TypeStop {
   110  				break
   111  			}
   112  			should.False(called)
   113  			called = true
   114  			should.Equal(protocol.TypeBool, fieldType)
   115  			should.Equal(protocol.FieldId(1), fieldId)
   116  			should.Equal(true, iter.ReadBool())
   117  		}
   118  		should.True(called)
   119  	}
   120  }
   121  
   122  func Test_skip_struct(t *testing.T) {
   123  	should := require.New(t)
   124  	for _, c := range test.Combinations {
   125  		buf, proto := c.CreateProtocol()
   126  		proto.WriteStructBegin("hello")
   127  		proto.WriteFieldBegin("field1", thrift.I64, 1)
   128  		proto.WriteI64(1024)
   129  		proto.WriteFieldEnd()
   130  		proto.WriteFieldStop()
   131  		proto.WriteStructEnd()
   132  		iter := c.CreateIterator(buf.Bytes())
   133  		should.Equal(buf.Bytes(), iter.SkipStruct(nil))
   134  	}
   135  }
   136  
   137  func Test_unmarshal_general_struct(t *testing.T) {
   138  	should := require.New(t)
   139  	for _, c := range test.Combinations {
   140  		buf, proto := c.CreateProtocol()
   141  		proto.WriteStructBegin("hello")
   142  		proto.WriteFieldBegin("field1", thrift.I64, 1)
   143  		proto.WriteI64(1024)
   144  		proto.WriteFieldEnd()
   145  		proto.WriteFieldStop()
   146  		proto.WriteStructEnd()
   147  		var val general.Struct
   148  		should.NoError(c.Unmarshal(buf.Bytes(), &val))
   149  		should.Equal(general.Struct{
   150  			protocol.FieldId(1): int64(1024),
   151  		}, val)
   152  	}
   153  }
   154  
   155  func Test_unmarshal_raw_struct(t *testing.T) {
   156  	should := require.New(t)
   157  	for _, c := range test.Combinations {
   158  		buf, proto := c.CreateProtocol()
   159  		proto.WriteStructBegin("hello")
   160  		proto.WriteFieldBegin("field1", thrift.I64, 1)
   161  		proto.WriteI64(1024)
   162  		proto.WriteFieldEnd()
   163  		proto.WriteFieldStop()
   164  		proto.WriteStructEnd()
   165  		var val raw.Struct
   166  		should.NoError(c.Unmarshal(buf.Bytes(), &val))
   167  		should.Equal(1, len(val))
   168  		should.Equal(protocol.TypeI64, val[protocol.FieldId(1)].Type)
   169  	}
   170  }
   171  
   172  func Test_unmarshal_struct(t *testing.T) {
   173  	should := require.New(t)
   174  	for _, c := range test.UnmarshalCombinations {
   175  		buf, proto := c.CreateProtocol()
   176  		proto.WriteStructBegin("hello")
   177  		proto.WriteFieldBegin("field1", thrift.I64, 1)
   178  		proto.WriteI64(1024)
   179  		proto.WriteFieldEnd()
   180  		proto.WriteFieldStop()
   181  		proto.WriteStructEnd()
   182  		var val struct_test.TestObject
   183  		should.NoError(c.Unmarshal(buf.Bytes(), &val))
   184  		should.Equal(struct_test.TestObject{1024}, val)
   185  	}
   186  }
   187  
   188  func Test_marshal_general_struct(t *testing.T) {
   189  	should := require.New(t)
   190  	for _, c := range test.Combinations {
   191  		output, err := c.Marshal(general.Struct{
   192  			protocol.FieldId(1): int64(1024),
   193  		})
   194  		should.NoError(err)
   195  		var val general.Struct
   196  		should.NoError(c.Unmarshal(output, &val))
   197  		should.Equal(general.Struct{
   198  			protocol.FieldId(1): int64(1024),
   199  		}, val)
   200  	}
   201  }
   202  
   203  func Test_marshal_raw_struct(t *testing.T) {
   204  	should := require.New(t)
   205  	for _, c := range test.Combinations {
   206  		buf, proto := c.CreateProtocol()
   207  		proto.WriteStructBegin("hello")
   208  		proto.WriteFieldBegin("field1", thrift.I64, 1)
   209  		proto.WriteI64(1024)
   210  		proto.WriteFieldEnd()
   211  		proto.WriteFieldStop()
   212  		proto.WriteStructEnd()
   213  		var val raw.Struct
   214  		should.NoError(c.Unmarshal(buf.Bytes(), &val))
   215  		output, err := c.Marshal(val)
   216  		should.NoError(err)
   217  		var generalVal general.Struct
   218  		should.NoError(c.Unmarshal(output, &generalVal))
   219  		should.Equal(general.Struct{
   220  			protocol.FieldId(1): int64(1024),
   221  		}, generalVal)
   222  	}
   223  }
   224  
   225  func Test_marshal_struct(t *testing.T) {
   226  	should := require.New(t)
   227  	for _, c := range test.MarshalCombinations {
   228  		output, err := c.Marshal(struct_test.TestObject{1024})
   229  		should.NoError(err)
   230  		iter := c.CreateIterator(output)
   231  		called := false
   232  		iter.ReadStructHeader()
   233  		for {
   234  			fieldType, fieldId := iter.ReadStructField()
   235  			if fieldType == protocol.TypeStop {
   236  				break
   237  			}
   238  			should.False(called)
   239  			called = true
   240  			should.Equal(protocol.TypeI64, fieldType)
   241  			should.Equal(protocol.FieldId(1), fieldId)
   242  			should.Equal(int64(1024), iter.ReadInt64())
   243  		}
   244  		should.True(called)
   245  	}
   246  }