github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/read.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package thrift 18 19 import ( 20 "context" 21 "encoding/base64" 22 "fmt" 23 24 "github.com/apache/thrift/lib/go/thrift" 25 "github.com/jhump/protoreflect/desc" 26 27 "github.com/cloudwego/kitex/pkg/generic/descriptor" 28 "github.com/cloudwego/kitex/pkg/generic/proto" 29 ) 30 31 var emptyPbDsc = &desc.MessageDescriptor{} 32 33 type readerOption struct { 34 // result will be encode to json, so map[interface{}]interface{} will not be valid 35 // need use map[string]interface{} instead 36 forJSON bool 37 // return exception as error 38 throwException bool 39 // read http response 40 http bool 41 binaryWithBase64 bool 42 binaryWithByteSlice bool 43 // describe struct of current level 44 pbDsc proto.MessageDescriptor 45 } 46 47 type reader func(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) 48 49 type fieldSetter func(field *descriptor.FieldDescriptor, val interface{}) error 50 51 func getMapFieldSetter(st map[string]interface{}) fieldSetter { 52 return func(field *descriptor.FieldDescriptor, val interface{}) error { 53 st[field.FieldName()] = val 54 return nil 55 } 56 } 57 58 func getPbFieldSetter(st proto.Message) fieldSetter { 59 return func(field *descriptor.FieldDescriptor, val interface{}) error { 60 return st.TrySetFieldByNumber(int(field.ID), val) 61 } 62 } 63 64 func nextReader(tt descriptor.Type, t *descriptor.TypeDescriptor, opt *readerOption) (reader, error) { 65 if err := assertType(tt, t.Type); err != nil { 66 return nil, err 67 } 68 switch tt { 69 case descriptor.BOOL: 70 return readBool, nil 71 case descriptor.BYTE: 72 return readByte, nil 73 case descriptor.I16: 74 return readInt16, nil 75 case descriptor.I32: 76 return readInt32, nil 77 case descriptor.I64: 78 return readInt64, nil 79 case descriptor.STRING: 80 if t.Name == "binary" { 81 if opt.binaryWithByteSlice { 82 return readBinary, nil 83 } else if opt.binaryWithBase64 { 84 return readBase64Binary, nil 85 } 86 } 87 return readString, nil 88 case descriptor.DOUBLE: 89 return readDouble, nil 90 case descriptor.LIST: 91 return readList, nil 92 case descriptor.SET: 93 return readList, nil 94 case descriptor.MAP: 95 return readMap, nil 96 case descriptor.STRUCT: 97 return readStruct, nil 98 case descriptor.VOID: 99 return readVoid, nil 100 case descriptor.JSON: 101 return readStruct, nil 102 default: 103 return nil, fmt.Errorf("unsupported type: %d", tt) 104 } 105 } 106 107 func skipStructReader(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 108 structName, err := in.ReadStructBegin() 109 if err != nil { 110 return nil, err 111 } 112 var v interface{} 113 for { 114 fieldName, fieldType, fieldID, err := in.ReadFieldBegin() 115 if err != nil { 116 return nil, err 117 } 118 if fieldType == thrift.STOP { 119 break 120 } 121 field, ok := t.Struct.FieldsByID[int32(fieldID)] 122 if !ok { 123 // just ignore the missing field, maybe server update its idls 124 if err := in.Skip(fieldType); err != nil { 125 return nil, err 126 } 127 } else { 128 _fieldType := descriptor.FromThriftTType(fieldType) 129 reader, err := nextReader(_fieldType, field.Type, opt) 130 if err != nil { 131 return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", structName, fieldName, fieldID, err) 132 } 133 if field.IsException && opt != nil && opt.throwException { 134 if v, err = reader(ctx, in, field.Type, opt); err != nil { 135 return nil, err 136 } 137 // return exception as error 138 return nil, fmt.Errorf("%#v", v) 139 } 140 if opt != nil && opt.http { 141 // use http response reader when http generic call 142 // only support struct response method, return error when use base type response 143 reader = readHTTPResponse 144 } 145 if v, err = reader(ctx, in, field.Type, opt); err != nil { 146 return nil, fmt.Errorf("reader of %s/%s/%d error %w", structName, fieldName, fieldID, err) 147 } 148 } 149 if err := in.ReadFieldEnd(); err != nil { 150 return nil, err 151 } 152 } 153 154 return v, in.ReadStructEnd() 155 } 156 157 func readVoid(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 158 _, err := readStruct(ctx, in, t, opt) 159 return descriptor.Void{}, err 160 } 161 162 func readDouble(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 163 return in.ReadDouble() 164 } 165 166 func readBool(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 167 return in.ReadBool() 168 } 169 170 func readByte(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 171 res, err := in.ReadByte() 172 if err != nil { 173 return nil, err 174 } 175 if opt.pbDsc != nil { 176 return int32(res), nil 177 } 178 return res, nil 179 } 180 181 func readInt16(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 182 res, err := in.ReadI16() 183 if err != nil { 184 return nil, err 185 } 186 if opt.pbDsc != nil { 187 return int32(res), nil 188 } 189 return res, nil 190 } 191 192 func readInt32(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 193 return in.ReadI32() 194 } 195 196 func readInt64(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 197 return in.ReadI64() 198 } 199 200 func readString(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 201 return in.ReadString() 202 } 203 204 func readBinary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 205 bytes, err := in.ReadBinary() 206 if err != nil { 207 return "", err 208 } 209 return bytes, nil 210 } 211 212 func readBase64Binary(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 213 bytes, err := in.ReadBinary() 214 if err != nil { 215 return "", err 216 } 217 return base64.StdEncoding.EncodeToString(bytes), nil 218 } 219 220 func readList(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 221 elemType, length, err := in.ReadListBegin() 222 if err != nil { 223 return nil, err 224 } 225 _elemType := descriptor.FromThriftTType(elemType) 226 reader, err := nextReader(_elemType, t.Elem, opt) 227 if err != nil { 228 return nil, err 229 } 230 l := make([]interface{}, 0, length) 231 for i := 0; i < length; i++ { 232 item, err := reader(ctx, in, t.Elem, opt) 233 if err != nil { 234 return nil, err 235 } 236 l = append(l, item) 237 } 238 return l, in.ReadListEnd() 239 } 240 241 func readMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 242 if opt != nil && opt.forJSON { 243 return readStringMap(ctx, in, t, opt) 244 } 245 return readInterfaceMap(ctx, in, t, opt) 246 } 247 248 func readInterfaceMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 249 keyType, elemType, length, err := in.ReadMapBegin() 250 if err != nil { 251 return nil, err 252 } 253 m := make(map[interface{}]interface{}, length) 254 if length == 0 { 255 return m, nil 256 } 257 _keyType := descriptor.FromThriftTType(keyType) 258 keyReader, err := nextReader(_keyType, t.Key, opt) 259 if err != nil { 260 return nil, err 261 } 262 _elemType := descriptor.FromThriftTType(elemType) 263 elemReader, err := nextReader(_elemType, t.Elem, opt) 264 if err != nil { 265 return nil, err 266 } 267 for i := 0; i < length; i++ { 268 nest := unnestPb(opt, 1) 269 key, err := keyReader(ctx, in, t.Key, opt) 270 if err != nil { 271 return nil, err 272 } 273 nest() 274 nest = unnestPb(opt, 2) 275 elem, err := elemReader(ctx, in, t.Elem, opt) 276 if err != nil { 277 return nil, err 278 } 279 nest() 280 m[key] = elem 281 } 282 return m, in.ReadMapEnd() 283 } 284 285 func readStringMap(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 286 keyType, elemType, length, err := in.ReadMapBegin() 287 if err != nil { 288 return nil, err 289 } 290 m := make(map[string]interface{}, length) 291 if length == 0 { 292 return m, nil 293 } 294 _keyType := descriptor.FromThriftTType(keyType) 295 keyReader, err := nextReader(_keyType, t.Key, opt) 296 if err != nil { 297 return nil, err 298 } 299 _elemType := descriptor.FromThriftTType(elemType) 300 elemReader, err := nextReader(_elemType, t.Elem, opt) 301 if err != nil { 302 return nil, err 303 } 304 for i := 0; i < length; i++ { 305 key, err := keyReader(ctx, in, t.Key, opt) 306 if err != nil { 307 return nil, err 308 } 309 elem, err := elemReader(ctx, in, t.Elem, opt) 310 if err != nil { 311 return nil, err 312 } 313 m[buildinTypeIntoString(key)] = elem 314 } 315 return m, in.ReadMapEnd() 316 } 317 318 func readStruct(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 319 var fs fieldSetter 320 var st interface{} 321 if opt == nil || opt.pbDsc == nil { 322 if opt == nil { 323 opt = &readerOption{} 324 } 325 holder := map[string]interface{}{} 326 fs = getMapFieldSetter(holder) 327 st = holder 328 } else { 329 holder := proto.NewMessage(opt.pbDsc) 330 fs = getPbFieldSetter(holder) 331 st = holder 332 } 333 334 var err error 335 // set default value 336 // void is nil struct 337 // default value with struct NOT SUPPORT pb. 338 if t.Struct != nil { 339 for _, field := range t.Struct.DefaultFields { 340 val := field.DefaultValue 341 if field.ValueMapping != nil { 342 if val, err = field.ValueMapping.Response(ctx, val, field); err != nil { 343 return nil, err 344 } 345 } 346 if err := fs(field, val); err != nil { 347 return nil, err 348 } 349 } 350 } 351 _, err = in.ReadStructBegin() 352 if err != nil { 353 return nil, err 354 } 355 readFields := map[int32]struct{}{} 356 for { 357 _, fieldType, fieldID, err := in.ReadFieldBegin() 358 if err != nil { 359 return nil, err 360 } 361 if fieldType == thrift.STOP { 362 if err := in.ReadFieldEnd(); err != nil { 363 return nil, err 364 } 365 // check required 366 // void is nil struct 367 if t.Struct != nil { 368 if err := t.Struct.CheckRequired(readFields); err != nil { 369 return nil, err 370 } 371 } 372 return st, in.ReadStructEnd() 373 } 374 field, ok := t.Struct.FieldsByID[int32(fieldID)] 375 if !ok { 376 // just ignore the missing field, maybe server update its idls 377 if err := in.Skip(fieldType); err != nil { 378 return nil, err 379 } 380 } else { 381 nest := unnestPb(opt, field.ID) 382 _fieldType := descriptor.FromThriftTType(fieldType) 383 reader, err := nextReader(_fieldType, field.Type, opt) 384 if err != nil { 385 return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err) 386 } 387 val, err := reader(ctx, in, field.Type, opt) 388 if err != nil { 389 return nil, fmt.Errorf("reader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err) 390 } 391 if field.ValueMapping != nil { 392 if val, err = field.ValueMapping.Response(ctx, val, field); err != nil { 393 return nil, err 394 } 395 } 396 nest() 397 398 if err := fs(field, val); err != nil { 399 return nil, err 400 } 401 } 402 if err := in.ReadFieldEnd(); err != nil { 403 return nil, err 404 } 405 readFields[int32(fieldID)] = struct{}{} 406 } 407 } 408 409 func readHTTPResponse(ctx context.Context, in thrift.TProtocol, t *descriptor.TypeDescriptor, opt *readerOption) (interface{}, error) { 410 var resp *descriptor.HTTPResponse 411 if opt == nil || opt.pbDsc == nil { 412 if opt == nil { 413 opt = &readerOption{} 414 } 415 resp = descriptor.NewHTTPResponse() 416 } else { 417 resp = descriptor.NewHTTPPbResponse(proto.NewMessage(opt.pbDsc)) 418 } 419 420 var err error 421 // set default value 422 // default value with struct NOT SUPPORT pb. 423 for _, field := range t.Struct.DefaultFields { 424 val := field.DefaultValue 425 if field.ValueMapping != nil { 426 if val, err = field.ValueMapping.Response(ctx, val, field); err != nil { 427 return nil, err 428 } 429 } 430 if err = field.HTTPMapping.Response(ctx, resp, field, val); err != nil { 431 return nil, err 432 } 433 } 434 _, err = in.ReadStructBegin() 435 if err != nil { 436 return nil, err 437 } 438 readFields := map[int32]struct{}{} 439 for { 440 _, fieldType, fieldID, err := in.ReadFieldBegin() 441 if err != nil { 442 return nil, err 443 } 444 if fieldType == thrift.STOP { 445 if err := in.ReadFieldEnd(); err != nil { 446 return nil, err 447 } 448 // check required 449 if err := t.Struct.CheckRequired(readFields); err != nil { 450 return nil, err 451 } 452 return resp, in.ReadStructEnd() 453 } 454 field, ok := t.Struct.FieldsByID[int32(fieldID)] 455 if !ok { 456 // just ignore the missing field, maybe server update its idls 457 if err := in.Skip(fieldType); err != nil { 458 return nil, err 459 } 460 } else { 461 // Replace pb descriptor with field type 462 nest := unnestPb(opt, field.ID) 463 464 // check required 465 _fieldType := descriptor.FromThriftTType(fieldType) 466 reader, err := nextReader(_fieldType, field.Type, opt) 467 if err != nil { 468 return nil, fmt.Errorf("nextReader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err) 469 } 470 val, err := reader(ctx, in, field.Type, opt) 471 if err != nil { 472 return nil, fmt.Errorf("reader of %s/%s/%d error %w", t.Name, field.Name, fieldID, err) 473 } 474 if field.ValueMapping != nil { 475 if val, err = field.ValueMapping.Response(ctx, val, field); err != nil { 476 return nil, err 477 } 478 } 479 nest() 480 if err = field.HTTPMapping.Response(ctx, resp, field, val); err != nil { 481 return nil, err 482 } 483 } 484 if err := in.ReadFieldEnd(); err != nil { 485 return nil, err 486 } 487 readFields[int32(fieldID)] = struct{}{} 488 } 489 } 490 491 func unnestPb(opt *readerOption, fieldId int32) func() { 492 pbDsc := opt.pbDsc 493 if pbDsc != nil { 494 fd := opt.pbDsc.FindFieldByNumber(fieldId) 495 if fd != nil && fd.GetMessageType() != nil { 496 opt.pbDsc = fd.GetMessageType() 497 } else { 498 opt.pbDsc = emptyPbDsc 499 } 500 } 501 return func() { 502 opt.pbDsc = pbDsc 503 } 504 }