github.com/cloudwego/kitex@v0.9.0/pkg/protocol/bthrift/test/unknown_test.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package test
    18  
    19  import (
    20  	"bytes"
    21  	"reflect"
    22  	"testing"
    23  
    24  	tt "github.com/cloudwego/kitex/internal/test"
    25  	"github.com/cloudwego/kitex/pkg/protocol/bthrift"
    26  	"github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift"
    29  )
    30  
    31  var fullReq *test.FullStruct
    32  
    33  func init() {
    34  	desc := "aa"
    35  	status := test.HTTPStatus_NOT_FOUND
    36  	byte1 := int8(1)
    37  	double1 := 1.3
    38  	fullReq = &test.FullStruct{
    39  		Left:  32,
    40  		Right: 45,
    41  		Dummy: []byte("test"),
    42  		InnerReq: &test.Inner{
    43  			Num:          6,
    44  			Desc:         &desc,
    45  			MapOfList:    map[int64][]int64{42: {1, 2}},
    46  			MapOfEnumKey: map[test.AEnum]int64{test.AEnum_A: 1, test.AEnum_B: 2},
    47  			Byte1:        &byte1,
    48  			Double1:      &double1,
    49  		},
    50  		Status:   test.HTTPStatus_OK,
    51  		Str:      "str",
    52  		EnumList: []test.HTTPStatus{test.HTTPStatus_NOT_FOUND, test.HTTPStatus_OK},
    53  		Strmap: map[int32]string{
    54  			10: "aa",
    55  			11: "bb",
    56  		},
    57  		Int64:     5,
    58  		IntList:   []int32{11, 22, 33},
    59  		LocalList: []*test.Local{{L: 33}, nil},
    60  		StrLocalMap: map[string]*test.Local{
    61  			"bbb": {
    62  				L: 22,
    63  			},
    64  			"ccc": {
    65  				L: 11,
    66  			},
    67  			"ddd": nil,
    68  		},
    69  		NestList: [][]int32{{3, 4}, {5, 6}},
    70  		RequiredIns: &test.Local{
    71  			L: 55,
    72  		},
    73  		NestMap:  map[string][]string{"aa": {"cc", "bb"}, "bb": {"xx", "yy"}},
    74  		NestMap2: []map[string]test.HTTPStatus{{"ok": test.HTTPStatus_OK}},
    75  		EnumMap: map[int32]test.HTTPStatus{
    76  			0: test.HTTPStatus_NOT_FOUND,
    77  			1: test.HTTPStatus_OK,
    78  		},
    79  		Strlist:   []string{"mm", "nn"},
    80  		OptStatus: &status,
    81  		Complex: map[test.HTTPStatus][]map[string]*test.Local{
    82  			test.HTTPStatus_OK: {
    83  				{"": &test.Local{L: 3}},
    84  				{"c": nil, "d": &test.Local{L: 42}},
    85  				nil,
    86  			},
    87  			test.HTTPStatus_NOT_FOUND: nil,
    88  		},
    89  		I64Set: []int64{1, 2, 3},
    90  		Int16:  98,
    91  		IsSet:  true,
    92  	}
    93  }
    94  
    95  func TestOnlyUnknownField(t *testing.T) {
    96  	l := fullReq.BLength()
    97  	buf := make([]byte, l)
    98  	ll := fullReq.FastWriteNocopy(buf, nil)
    99  	tt.Assert(t, ll == l)
   100  
   101  	unknown := &test.EmptyStruct{}
   102  	ll, err := unknown.FastRead(buf)
   103  	tt.Assert(t, err == nil)
   104  	tt.Assert(t, ll == l)
   105  	unknownL := unknown.BLength()
   106  	tt.Assert(t, unknownL == l)
   107  	unknownBuf := make([]byte, unknownL)
   108  	writeL := unknown.FastWriteNocopy(unknownBuf, nil)
   109  	tt.Assert(t, writeL == l)
   110  	tt.Assert(t, bytes.Equal(buf, unknownBuf))
   111  
   112  	// thrift read/write without fast api
   113  	trans := remote.NewReaderWriterBuffer(-1)
   114  	prot := codecThrift.NewBinaryProtocol(trans)
   115  	err = fullReq.Write(prot)
   116  	tt.Assert(t, err == nil)
   117  	unknown1 := &test.EmptyStruct{}
   118  	err = unknown1.Read(prot)
   119  	tt.Assert(t, err == nil)
   120  	tt.Assert(t, unknown.BLength() == unknown1.BLength())
   121  	trans = remote.NewReaderWriterBuffer(-1)
   122  	prot = codecThrift.NewBinaryProtocol(trans)
   123  	err = unknown1.Write(prot)
   124  	tt.Assert(t, err == nil)
   125  	unknown1 = &test.EmptyStruct{}
   126  	err = unknown1.Read(prot)
   127  	tt.Assert(t, err == nil)
   128  	tt.Assert(t, unknown.BLength() == unknown1.BLength())
   129  
   130  	// test get unknown fields
   131  	fields, err := bthrift.GetUnknownFields(unknown)
   132  	tt.Assert(t, err == nil)
   133  	l, err = bthrift.UnknownFieldsLength(fields)
   134  	tt.Assert(t, err == nil)
   135  	buf = make([]byte, l)
   136  	_, err = bthrift.WriteUnknownFields(buf, fields)
   137  	tt.Assert(t, err == nil)
   138  	tt.Assert(t, bytes.Equal(buf, reflect.ValueOf(unknown).Elem().FieldByName("_unknownFields").Bytes()))
   139  }
   140  
   141  func TestPartialUnknownField(t *testing.T) {
   142  	l := fullReq.BLength()
   143  	buf := make([]byte, l)
   144  	ll := fullReq.FastWriteNocopy(buf, nil)
   145  	tt.Assert(t, ll == l)
   146  	compare := &test.FullStruct{}
   147  	ll, err := compare.FastRead(buf)
   148  	tt.Assert(t, err == nil)
   149  	tt.Assert(t, ll == l)
   150  
   151  	unknown := &test.MixedStruct{}
   152  	ll, err = unknown.FastRead(buf)
   153  	tt.Assert(t, err == nil)
   154  	tt.Assert(t, ll == l)
   155  	unknownL := unknown.BLength()
   156  	unknownBuf := make([]byte, unknownL)
   157  	writeL := unknown.FastWriteNocopy(unknownBuf, nil)
   158  	tt.Assert(t, writeL == unknownL)
   159  	compare1 := &test.FullStruct{}
   160  	ll, err = compare1.FastRead(unknownBuf)
   161  	tt.Assert(t, err == nil)
   162  	tt.Assert(t, ll == unknownL)
   163  	tt.Assert(t, compare1.DeepEqual(compare))
   164  
   165  	// thrift read/write without fast api
   166  	trans := remote.NewReaderWriterBuffer(-1)
   167  	prot := codecThrift.NewBinaryProtocol(trans)
   168  	err = fullReq.Write(prot)
   169  	tt.Assert(t, err == nil)
   170  	unknown1 := &test.MixedStruct{}
   171  	err = unknown1.Read(prot)
   172  	tt.Assert(t, err == nil)
   173  	tt.Assert(t, unknown.BLength() == unknown1.BLength())
   174  	trans = remote.NewReaderWriterBuffer(-1)
   175  	prot = codecThrift.NewBinaryProtocol(trans)
   176  	err = unknown1.Write(prot)
   177  	tt.Assert(t, err == nil)
   178  	unknown1 = &test.MixedStruct{}
   179  	err = unknown1.Read(prot)
   180  	tt.Assert(t, err == nil)
   181  	tt.Assert(t, unknown.BLength() == unknown1.BLength())
   182  }
   183  
   184  func TestNoUnknownField(t *testing.T) {
   185  	l := fullReq.BLength()
   186  	buf := make([]byte, l)
   187  	ll := fullReq.FastWriteNocopy(buf, nil)
   188  	tt.Assert(t, ll == l)
   189  
   190  	ori := &test.FullStruct{}
   191  	ll, err := ori.FastRead(buf)
   192  	tt.Assert(t, err == nil)
   193  	tt.Assert(t, ll == l)
   194  
   195  	// required fields
   196  	tt.Assert(t, ori.Field11DeepEqual([]*test.Local{{L: 33}, test.NewLocal()}))
   197  	tt.Assert(t, ori.Field12DeepEqual(map[string]*test.Local{
   198  		"bbb": {L: 22}, "ccc": {L: 11}, "ddd": {},
   199  	}))
   200  	tt.Assert(t, ori.Field21DeepEqual(test.NewInner()))
   201  	tt.Assert(t, ori.Field28DeepEqual(map[test.HTTPStatus][]map[string]*test.Local{
   202  		test.HTTPStatus_OK: {
   203  			{"": &test.Local{L: 3}},
   204  			{"c": {}, "d": &test.Local{L: 42}},
   205  			nil,
   206  		},
   207  		test.HTTPStatus_NOT_FOUND: nil,
   208  	}))
   209  	ori.LocalList[1] = nil
   210  	ori.StrLocalMap["ddd"] = nil
   211  	ori.AnotherInner = nil
   212  	ori.Complex[test.HTTPStatus_OK][1]["c"] = nil
   213  
   214  	tt.Assert(t, ori.Field1DeepEqual(fullReq.Left))
   215  	tt.Assert(t, ori.Field2DeepEqual(fullReq.Right))
   216  	tt.Assert(t, ori.Field3DeepEqual(fullReq.Dummy))
   217  	tt.Assert(t, ori.Field4DeepEqual(fullReq.InnerReq))
   218  	tt.Assert(t, ori.Field5DeepEqual(fullReq.Status))
   219  	tt.Assert(t, ori.Field6DeepEqual(fullReq.Str))
   220  	tt.Assert(t, ori.Field7DeepEqual(fullReq.EnumList))
   221  	tt.Assert(t, ori.Field8DeepEqual(fullReq.Strmap))
   222  	tt.Assert(t, ori.Field9DeepEqual(fullReq.Int64))
   223  	tt.Assert(t, ori.Field10DeepEqual(fullReq.IntList))
   224  	tt.Assert(t, ori.Field11DeepEqual(fullReq.LocalList))
   225  	tt.Assert(t, ori.Field12DeepEqual(fullReq.StrLocalMap))
   226  	tt.Assert(t, ori.Field13DeepEqual(fullReq.NestList))
   227  	tt.Assert(t, ori.Field14DeepEqual(fullReq.RequiredIns))
   228  	tt.Assert(t, ori.Field16DeepEqual(fullReq.NestMap))
   229  	tt.Assert(t, ori.Field17DeepEqual(fullReq.NestMap2))
   230  	tt.Assert(t, ori.Field18DeepEqual(fullReq.EnumMap))
   231  	tt.Assert(t, ori.Field19DeepEqual(fullReq.Strlist))
   232  	tt.Assert(t, ori.Field20DeepEqual(fullReq.OptionalIns))
   233  	tt.Assert(t, ori.Field21DeepEqual(fullReq.AnotherInner))
   234  	tt.Assert(t, ori.Field22DeepEqual(fullReq.OptNilList))
   235  	tt.Assert(t, ori.Field23DeepEqual(fullReq.NilList))
   236  	tt.Assert(t, ori.Field24DeepEqual(fullReq.OptNilInsList))
   237  	tt.Assert(t, ori.Field25DeepEqual(fullReq.NilInsList))
   238  	tt.Assert(t, ori.Field26DeepEqual(fullReq.OptStatus))
   239  	tt.Assert(t, ori.Field27DeepEqual(fullReq.EnumKeyMap))
   240  	tt.Assert(t, ori.Field28DeepEqual(fullReq.Complex))
   241  }
   242  
   243  func BenchmarkOnlyUnknownField(b *testing.B) {
   244  	l := fullReq.BLength()
   245  	buf := make([]byte, l)
   246  	ll := fullReq.FastWriteNocopy(buf, nil)
   247  	tt.Assert(b, ll == l)
   248  
   249  	unknownBuf := make([]byte, l)
   250  	for i := 0; i < b.N; i++ {
   251  		unknown := &test.EmptyStruct{}
   252  		_, _ = unknown.FastRead(buf)
   253  		unknown.FastWriteNocopy(unknownBuf, nil)
   254  	}
   255  }
   256  
   257  //func TestCorruptWrite(t *testing.T) {
   258  //	local := &test.Local{L: 3}
   259  //	ufs := unknown.Fields{&unknown.Field{Type: 1000}}
   260  //	local.SetUnknown(ufs)
   261  //
   262  //	defer func() {
   263  //		e := recover()
   264  //		if strings.Contains(e.(error).Error(), "unknown data type 1000") {
   265  //			return
   266  //		}
   267  //		tt.Assert(t, false, e)
   268  //	}()
   269  //	_ = local.BLength()
   270  //	tt.Assert(t, false)
   271  //}
   272  //
   273  //func TestCorruptRead(t *testing.T) {
   274  //	local := &test.Local{L: 3}
   275  //	ufs := unknown.Fields{&unknown.Field{Name: "test", Type: unknown.TString, Value: "str"}}
   276  //	local.SetUnknown(ufs)
   277  //	l := local.BLength()
   278  //	buf := make([]byte, l)
   279  //	ll := local.FastWriteNocopy(buf, nil)
   280  //	tt.Assert(t, ll == l)
   281  //	buf[7] = 200
   282  //
   283  //	var local2 test.Local
   284  //	_, err := local2.FastRead(buf)
   285  //	tt.Assert(t, err != nil)
   286  //	tt.Assert(t, strings.Contains(err.Error(), "unknown data type 200"))
   287  //}