github.com/cloudwego/kitex@v0.9.0/pkg/protocol/bthrift/test/unknown_test.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package test 18 19 import ( 20 "bytes" 21 "reflect" 22 "testing" 23 24 tt "github.com/cloudwego/kitex/internal/test" 25 "github.com/cloudwego/kitex/pkg/protocol/bthrift" 26 "github.com/cloudwego/kitex/pkg/protocol/bthrift/test/kitex_gen/test" 27 "github.com/cloudwego/kitex/pkg/remote" 28 codecThrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift" 29 ) 30 31 var fullReq *test.FullStruct 32 33 func init() { 34 desc := "aa" 35 status := test.HTTPStatus_NOT_FOUND 36 byte1 := int8(1) 37 double1 := 1.3 38 fullReq = &test.FullStruct{ 39 Left: 32, 40 Right: 45, 41 Dummy: []byte("test"), 42 InnerReq: &test.Inner{ 43 Num: 6, 44 Desc: &desc, 45 MapOfList: map[int64][]int64{42: {1, 2}}, 46 MapOfEnumKey: map[test.AEnum]int64{test.AEnum_A: 1, test.AEnum_B: 2}, 47 Byte1: &byte1, 48 Double1: &double1, 49 }, 50 Status: test.HTTPStatus_OK, 51 Str: "str", 52 EnumList: []test.HTTPStatus{test.HTTPStatus_NOT_FOUND, test.HTTPStatus_OK}, 53 Strmap: map[int32]string{ 54 10: "aa", 55 11: "bb", 56 }, 57 Int64: 5, 58 IntList: []int32{11, 22, 33}, 59 LocalList: []*test.Local{{L: 33}, nil}, 60 StrLocalMap: map[string]*test.Local{ 61 "bbb": { 62 L: 22, 63 }, 64 "ccc": { 65 L: 11, 66 }, 67 "ddd": nil, 68 }, 69 NestList: [][]int32{{3, 4}, {5, 6}}, 70 RequiredIns: &test.Local{ 71 L: 55, 72 }, 73 NestMap: map[string][]string{"aa": {"cc", "bb"}, "bb": {"xx", "yy"}}, 74 NestMap2: []map[string]test.HTTPStatus{{"ok": test.HTTPStatus_OK}}, 75 EnumMap: map[int32]test.HTTPStatus{ 76 0: test.HTTPStatus_NOT_FOUND, 77 1: test.HTTPStatus_OK, 78 }, 79 Strlist: []string{"mm", "nn"}, 80 OptStatus: &status, 81 Complex: map[test.HTTPStatus][]map[string]*test.Local{ 82 test.HTTPStatus_OK: { 83 {"": &test.Local{L: 3}}, 84 {"c": nil, "d": &test.Local{L: 42}}, 85 nil, 86 }, 87 test.HTTPStatus_NOT_FOUND: nil, 88 }, 89 I64Set: []int64{1, 2, 3}, 90 Int16: 98, 91 IsSet: true, 92 } 93 } 94 95 func TestOnlyUnknownField(t *testing.T) { 96 l := fullReq.BLength() 97 buf := make([]byte, l) 98 ll := fullReq.FastWriteNocopy(buf, nil) 99 tt.Assert(t, ll == l) 100 101 unknown := &test.EmptyStruct{} 102 ll, err := unknown.FastRead(buf) 103 tt.Assert(t, err == nil) 104 tt.Assert(t, ll == l) 105 unknownL := unknown.BLength() 106 tt.Assert(t, unknownL == l) 107 unknownBuf := make([]byte, unknownL) 108 writeL := unknown.FastWriteNocopy(unknownBuf, nil) 109 tt.Assert(t, writeL == l) 110 tt.Assert(t, bytes.Equal(buf, unknownBuf)) 111 112 // thrift read/write without fast api 113 trans := remote.NewReaderWriterBuffer(-1) 114 prot := codecThrift.NewBinaryProtocol(trans) 115 err = fullReq.Write(prot) 116 tt.Assert(t, err == nil) 117 unknown1 := &test.EmptyStruct{} 118 err = unknown1.Read(prot) 119 tt.Assert(t, err == nil) 120 tt.Assert(t, unknown.BLength() == unknown1.BLength()) 121 trans = remote.NewReaderWriterBuffer(-1) 122 prot = codecThrift.NewBinaryProtocol(trans) 123 err = unknown1.Write(prot) 124 tt.Assert(t, err == nil) 125 unknown1 = &test.EmptyStruct{} 126 err = unknown1.Read(prot) 127 tt.Assert(t, err == nil) 128 tt.Assert(t, unknown.BLength() == unknown1.BLength()) 129 130 // test get unknown fields 131 fields, err := bthrift.GetUnknownFields(unknown) 132 tt.Assert(t, err == nil) 133 l, err = bthrift.UnknownFieldsLength(fields) 134 tt.Assert(t, err == nil) 135 buf = make([]byte, l) 136 _, err = bthrift.WriteUnknownFields(buf, fields) 137 tt.Assert(t, err == nil) 138 tt.Assert(t, bytes.Equal(buf, reflect.ValueOf(unknown).Elem().FieldByName("_unknownFields").Bytes())) 139 } 140 141 func TestPartialUnknownField(t *testing.T) { 142 l := fullReq.BLength() 143 buf := make([]byte, l) 144 ll := fullReq.FastWriteNocopy(buf, nil) 145 tt.Assert(t, ll == l) 146 compare := &test.FullStruct{} 147 ll, err := compare.FastRead(buf) 148 tt.Assert(t, err == nil) 149 tt.Assert(t, ll == l) 150 151 unknown := &test.MixedStruct{} 152 ll, err = unknown.FastRead(buf) 153 tt.Assert(t, err == nil) 154 tt.Assert(t, ll == l) 155 unknownL := unknown.BLength() 156 unknownBuf := make([]byte, unknownL) 157 writeL := unknown.FastWriteNocopy(unknownBuf, nil) 158 tt.Assert(t, writeL == unknownL) 159 compare1 := &test.FullStruct{} 160 ll, err = compare1.FastRead(unknownBuf) 161 tt.Assert(t, err == nil) 162 tt.Assert(t, ll == unknownL) 163 tt.Assert(t, compare1.DeepEqual(compare)) 164 165 // thrift read/write without fast api 166 trans := remote.NewReaderWriterBuffer(-1) 167 prot := codecThrift.NewBinaryProtocol(trans) 168 err = fullReq.Write(prot) 169 tt.Assert(t, err == nil) 170 unknown1 := &test.MixedStruct{} 171 err = unknown1.Read(prot) 172 tt.Assert(t, err == nil) 173 tt.Assert(t, unknown.BLength() == unknown1.BLength()) 174 trans = remote.NewReaderWriterBuffer(-1) 175 prot = codecThrift.NewBinaryProtocol(trans) 176 err = unknown1.Write(prot) 177 tt.Assert(t, err == nil) 178 unknown1 = &test.MixedStruct{} 179 err = unknown1.Read(prot) 180 tt.Assert(t, err == nil) 181 tt.Assert(t, unknown.BLength() == unknown1.BLength()) 182 } 183 184 func TestNoUnknownField(t *testing.T) { 185 l := fullReq.BLength() 186 buf := make([]byte, l) 187 ll := fullReq.FastWriteNocopy(buf, nil) 188 tt.Assert(t, ll == l) 189 190 ori := &test.FullStruct{} 191 ll, err := ori.FastRead(buf) 192 tt.Assert(t, err == nil) 193 tt.Assert(t, ll == l) 194 195 // required fields 196 tt.Assert(t, ori.Field11DeepEqual([]*test.Local{{L: 33}, test.NewLocal()})) 197 tt.Assert(t, ori.Field12DeepEqual(map[string]*test.Local{ 198 "bbb": {L: 22}, "ccc": {L: 11}, "ddd": {}, 199 })) 200 tt.Assert(t, ori.Field21DeepEqual(test.NewInner())) 201 tt.Assert(t, ori.Field28DeepEqual(map[test.HTTPStatus][]map[string]*test.Local{ 202 test.HTTPStatus_OK: { 203 {"": &test.Local{L: 3}}, 204 {"c": {}, "d": &test.Local{L: 42}}, 205 nil, 206 }, 207 test.HTTPStatus_NOT_FOUND: nil, 208 })) 209 ori.LocalList[1] = nil 210 ori.StrLocalMap["ddd"] = nil 211 ori.AnotherInner = nil 212 ori.Complex[test.HTTPStatus_OK][1]["c"] = nil 213 214 tt.Assert(t, ori.Field1DeepEqual(fullReq.Left)) 215 tt.Assert(t, ori.Field2DeepEqual(fullReq.Right)) 216 tt.Assert(t, ori.Field3DeepEqual(fullReq.Dummy)) 217 tt.Assert(t, ori.Field4DeepEqual(fullReq.InnerReq)) 218 tt.Assert(t, ori.Field5DeepEqual(fullReq.Status)) 219 tt.Assert(t, ori.Field6DeepEqual(fullReq.Str)) 220 tt.Assert(t, ori.Field7DeepEqual(fullReq.EnumList)) 221 tt.Assert(t, ori.Field8DeepEqual(fullReq.Strmap)) 222 tt.Assert(t, ori.Field9DeepEqual(fullReq.Int64)) 223 tt.Assert(t, ori.Field10DeepEqual(fullReq.IntList)) 224 tt.Assert(t, ori.Field11DeepEqual(fullReq.LocalList)) 225 tt.Assert(t, ori.Field12DeepEqual(fullReq.StrLocalMap)) 226 tt.Assert(t, ori.Field13DeepEqual(fullReq.NestList)) 227 tt.Assert(t, ori.Field14DeepEqual(fullReq.RequiredIns)) 228 tt.Assert(t, ori.Field16DeepEqual(fullReq.NestMap)) 229 tt.Assert(t, ori.Field17DeepEqual(fullReq.NestMap2)) 230 tt.Assert(t, ori.Field18DeepEqual(fullReq.EnumMap)) 231 tt.Assert(t, ori.Field19DeepEqual(fullReq.Strlist)) 232 tt.Assert(t, ori.Field20DeepEqual(fullReq.OptionalIns)) 233 tt.Assert(t, ori.Field21DeepEqual(fullReq.AnotherInner)) 234 tt.Assert(t, ori.Field22DeepEqual(fullReq.OptNilList)) 235 tt.Assert(t, ori.Field23DeepEqual(fullReq.NilList)) 236 tt.Assert(t, ori.Field24DeepEqual(fullReq.OptNilInsList)) 237 tt.Assert(t, ori.Field25DeepEqual(fullReq.NilInsList)) 238 tt.Assert(t, ori.Field26DeepEqual(fullReq.OptStatus)) 239 tt.Assert(t, ori.Field27DeepEqual(fullReq.EnumKeyMap)) 240 tt.Assert(t, ori.Field28DeepEqual(fullReq.Complex)) 241 } 242 243 func BenchmarkOnlyUnknownField(b *testing.B) { 244 l := fullReq.BLength() 245 buf := make([]byte, l) 246 ll := fullReq.FastWriteNocopy(buf, nil) 247 tt.Assert(b, ll == l) 248 249 unknownBuf := make([]byte, l) 250 for i := 0; i < b.N; i++ { 251 unknown := &test.EmptyStruct{} 252 _, _ = unknown.FastRead(buf) 253 unknown.FastWriteNocopy(unknownBuf, nil) 254 } 255 } 256 257 //func TestCorruptWrite(t *testing.T) { 258 // local := &test.Local{L: 3} 259 // ufs := unknown.Fields{&unknown.Field{Type: 1000}} 260 // local.SetUnknown(ufs) 261 // 262 // defer func() { 263 // e := recover() 264 // if strings.Contains(e.(error).Error(), "unknown data type 1000") { 265 // return 266 // } 267 // tt.Assert(t, false, e) 268 // }() 269 // _ = local.BLength() 270 // tt.Assert(t, false) 271 //} 272 // 273 //func TestCorruptRead(t *testing.T) { 274 // local := &test.Local{L: 3} 275 // ufs := unknown.Fields{&unknown.Field{Name: "test", Type: unknown.TString, Value: "str"}} 276 // local.SetUnknown(ufs) 277 // l := local.BLength() 278 // buf := make([]byte, l) 279 // ll := local.FastWriteNocopy(buf, nil) 280 // tt.Assert(t, ll == l) 281 // buf[7] = 200 282 // 283 // var local2 test.Local 284 // _, err := local2.FastRead(buf) 285 // tt.Assert(t, err != nil) 286 // tt.Assert(t, strings.Contains(err.Error(), "unknown data type 200")) 287 //}