github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/thrift/decode.go (about) 1 package thrift 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "reflect" 10 "sync/atomic" 11 ) 12 13 // Unmarshal deserializes the thrift data from b to v using to the protocol p. 14 // 15 // The function errors if the data in b does not match the type of v. 16 // 17 // The function panics if v cannot be converted to a thrift representation. 18 // 19 // As an optimization, the value passed in v may be reused across multiple calls 20 // to Unmarshal, allowing the function to reuse objects referenced by pointer 21 // fields of struct values. When reusing objects, the application is responsible 22 // for resetting the state of v before calling Unmarshal again. 23 func Unmarshal(p Protocol, b []byte, v interface{}) error { 24 br := bytes.NewReader(b) 25 pr := p.NewReader(br) 26 27 if err := NewDecoder(pr).Decode(v); err != nil { 28 return err 29 } 30 31 if n := br.Len(); n != 0 { 32 return fmt.Errorf("unexpected trailing bytes at the end of thrift input: %d", n) 33 } 34 35 return nil 36 } 37 38 type Decoder struct { 39 r Reader 40 f flags 41 } 42 43 func NewDecoder(r Reader) *Decoder { 44 return &Decoder{r: r, f: decoderFlags(r)} 45 } 46 47 func (d *Decoder) Decode(v interface{}) error { 48 t := reflect.TypeOf(v) 49 p := reflect.ValueOf(v) 50 51 if t.Kind() != reflect.Ptr { 52 panic("thrift.(*Decoder).Decode: expected pointer type but got " + t.String()) 53 } 54 55 t = t.Elem() 56 p = p.Elem() 57 58 cache, _ := decoderCache.Load().(map[typeID]decodeFunc) 59 decode, _ := cache[makeTypeID(t)] 60 61 if decode == nil { 62 decode = decodeFuncOf(t, make(decodeFuncCache)) 63 64 newCache := make(map[typeID]decodeFunc, len(cache)+1) 65 newCache[makeTypeID(t)] = decode 66 for k, v := range cache { 67 newCache[k] = v 68 } 69 70 decoderCache.Store(newCache) 71 } 72 73 return decode(d.r, p, d.f) 74 } 75 76 func (d *Decoder) Reset(r Reader) { 77 d.r = r 78 d.f = d.f.without(protocolFlags).with(decoderFlags(r)) 79 } 80 81 func (d *Decoder) SetStrict(enabled bool) { 82 if enabled { 83 d.f = d.f.with(strict) 84 } else { 85 d.f = d.f.without(strict) 86 } 87 } 88 89 func decoderFlags(r Reader) flags { 90 return flags(r.Protocol().Features() << featuresBitOffset) 91 } 92 93 var decoderCache atomic.Value // map[typeID]decodeFunc 94 95 type decodeFunc func(Reader, reflect.Value, flags) error 96 97 type decodeFuncCache map[reflect.Type]decodeFunc 98 99 func decodeFuncOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 100 f := seen[t] 101 if f != nil { 102 return f 103 } 104 switch t.Kind() { 105 case reflect.Bool: 106 f = decodeBool 107 case reflect.Int8: 108 f = decodeInt8 109 case reflect.Int16: 110 f = decodeInt16 111 case reflect.Int32: 112 f = decodeInt32 113 case reflect.Int64, reflect.Int: 114 f = decodeInt64 115 case reflect.Float32, reflect.Float64: 116 f = decodeFloat64 117 case reflect.String: 118 f = decodeString 119 case reflect.Slice: 120 if t.Elem().Kind() == reflect.Uint8 { // []byte 121 f = decodeBytes 122 } else { 123 f = decodeFuncSliceOf(t, seen) 124 } 125 case reflect.Map: 126 f = decodeFuncMapOf(t, seen) 127 case reflect.Struct: 128 f = decodeFuncStructOf(t, seen) 129 case reflect.Ptr: 130 f = decodeFuncPtrOf(t, seen) 131 default: 132 panic("type cannot be decoded in thrift: " + t.String()) 133 } 134 seen[t] = f 135 return f 136 } 137 138 func decodeBool(r Reader, v reflect.Value, _ flags) error { 139 b, err := r.ReadBool() 140 if err != nil { 141 return err 142 } 143 v.SetBool(b) 144 return nil 145 } 146 147 func decodeInt8(r Reader, v reflect.Value, _ flags) error { 148 i, err := r.ReadInt8() 149 if err != nil { 150 return err 151 } 152 v.SetInt(int64(i)) 153 return nil 154 } 155 156 func decodeInt16(r Reader, v reflect.Value, _ flags) error { 157 i, err := r.ReadInt16() 158 if err != nil { 159 return err 160 } 161 v.SetInt(int64(i)) 162 return nil 163 } 164 165 func decodeInt32(r Reader, v reflect.Value, _ flags) error { 166 i, err := r.ReadInt32() 167 if err != nil { 168 return err 169 } 170 v.SetInt(int64(i)) 171 return nil 172 } 173 174 func decodeInt64(r Reader, v reflect.Value, _ flags) error { 175 i, err := r.ReadInt64() 176 if err != nil { 177 return err 178 } 179 v.SetInt(int64(i)) 180 return nil 181 } 182 183 func decodeFloat64(r Reader, v reflect.Value, _ flags) error { 184 f, err := r.ReadFloat64() 185 if err != nil { 186 return err 187 } 188 v.SetFloat(f) 189 return nil 190 } 191 192 func decodeString(r Reader, v reflect.Value, _ flags) error { 193 s, err := r.ReadString() 194 if err != nil { 195 return err 196 } 197 v.SetString(s) 198 return nil 199 } 200 201 func decodeBytes(r Reader, v reflect.Value, _ flags) error { 202 b, err := r.ReadBytes() 203 if err != nil { 204 return err 205 } 206 v.SetBytes(b) 207 return nil 208 } 209 210 func decodeFuncSliceOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 211 elem := t.Elem() 212 typ := TypeOf(elem) 213 dec := decodeFuncOf(elem, seen) 214 215 return func(r Reader, v reflect.Value, flags flags) error { 216 l, err := r.ReadList() 217 if err != nil { 218 return err 219 } 220 221 // Sometimes the list type is set to TRUE when the list contains only 222 // TRUE values. Thrift does not seem to optimize the encoding by 223 // omitting the boolean values that are known to all be TRUE, we still 224 // need to decode them. 225 switch l.Type { 226 case TRUE: 227 l.Type = BOOL 228 } 229 230 // TODO: implement type conversions? 231 if typ != l.Type { 232 if flags.have(strict) { 233 return &TypeMismatch{item: "list item", Expect: typ, Found: l.Type} 234 } 235 return nil 236 } 237 238 v.Set(reflect.MakeSlice(t, int(l.Size), int(l.Size))) 239 flags = flags.only(decodeFlags) 240 241 for i := 0; i < int(l.Size); i++ { 242 if err := dec(r, v.Index(i), flags); err != nil { 243 return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) 244 } 245 } 246 247 return nil 248 } 249 } 250 251 func decodeFuncMapOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 252 key, elem := t.Key(), t.Elem() 253 if elem.Size() == 0 { // map[?]struct{} 254 return decodeFuncMapAsSetOf(t, seen) 255 } 256 257 mapType := reflect.MapOf(key, elem) 258 keyZero := reflect.Zero(key) 259 elemZero := reflect.Zero(elem) 260 keyType := TypeOf(key) 261 elemType := TypeOf(elem) 262 decodeKey := decodeFuncOf(key, seen) 263 decodeElem := decodeFuncOf(elem, seen) 264 265 return func(r Reader, v reflect.Value, flags flags) error { 266 m, err := r.ReadMap() 267 if err != nil { 268 return err 269 } 270 271 v.Set(reflect.MakeMapWithSize(mapType, int(m.Size))) 272 273 if m.Size == 0 { // empty map 274 return nil 275 } 276 277 // TODO: implement type conversions? 278 if keyType != m.Key { 279 if flags.have(strict) { 280 return &TypeMismatch{item: "map key", Expect: keyType, Found: m.Key} 281 } 282 return nil 283 } 284 285 if elemType != m.Value { 286 if flags.have(strict) { 287 return &TypeMismatch{item: "map value", Expect: elemType, Found: m.Value} 288 } 289 return nil 290 } 291 292 tmpKey := reflect.New(key).Elem() 293 tmpElem := reflect.New(elem).Elem() 294 flags = flags.only(decodeFlags) 295 296 for i := 0; i < int(m.Size); i++ { 297 if err := decodeKey(r, tmpKey, flags); err != nil { 298 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 299 } 300 if err := decodeElem(r, tmpElem, flags); err != nil { 301 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 302 } 303 v.SetMapIndex(tmpKey, tmpElem) 304 tmpKey.Set(keyZero) 305 tmpElem.Set(elemZero) 306 } 307 308 return nil 309 } 310 } 311 312 func decodeFuncMapAsSetOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 313 key, elem := t.Key(), t.Elem() 314 keyZero := reflect.Zero(key) 315 elemZero := reflect.Zero(elem) 316 typ := TypeOf(key) 317 dec := decodeFuncOf(key, seen) 318 319 return func(r Reader, v reflect.Value, flags flags) error { 320 s, err := r.ReadSet() 321 if err != nil { 322 return err 323 } 324 325 // See decodeFuncSliceOf for details about why this type conversion 326 // needs to be done. 327 switch s.Type { 328 case TRUE: 329 s.Type = BOOL 330 } 331 332 v.Set(reflect.MakeMapWithSize(t, int(s.Size))) 333 334 if s.Size == 0 { 335 return nil 336 } 337 338 // TODO: implement type conversions? 339 if typ != s.Type { 340 if flags.have(strict) { 341 return &TypeMismatch{item: "list item", Expect: typ, Found: s.Type} 342 } 343 return nil 344 } 345 346 tmp := reflect.New(key).Elem() 347 flags = flags.only(decodeFlags) 348 349 for i := 0; i < int(s.Size); i++ { 350 if err := dec(r, tmp, flags); err != nil { 351 return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) 352 } 353 v.SetMapIndex(tmp, elemZero) 354 tmp.Set(keyZero) 355 } 356 357 return nil 358 } 359 } 360 361 type structDecoder struct { 362 fields []structDecoderField 363 union []int 364 minID int16 365 zero reflect.Value 366 required []uint64 367 } 368 369 func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error { 370 flags = flags.only(decodeFlags) 371 coalesceBoolFields := flags.have(coalesceBoolFields) 372 373 lastField := reflect.Value{} 374 union := len(dec.union) > 0 375 seen := make([]uint64, 1) 376 if len(dec.required) > len(seen) { 377 seen = make([]uint64, len(dec.required)) 378 } 379 380 err := readStruct(r, func(r Reader, f Field) error { 381 i := int(f.ID) - int(dec.minID) 382 if i < 0 || i >= len(dec.fields) || dec.fields[i].decode == nil { 383 return skipField(r, f) 384 } 385 field := &dec.fields[i] 386 seen[i/64] |= 1 << (i % 64) 387 388 // TODO: implement type conversions? 389 if f.Type != field.typ && !(f.Type == TRUE && field.typ == BOOL) { 390 if flags.have(strict) { 391 return &TypeMismatch{item: "field value", Expect: field.typ, Found: f.Type} 392 } 393 return nil 394 } 395 396 x := v 397 for _, i := range field.index { 398 if x.Kind() == reflect.Ptr { 399 x = x.Elem() 400 } 401 if x = x.Field(i); x.Kind() == reflect.Ptr { 402 if x.IsNil() { 403 x.Set(reflect.New(x.Type().Elem())) 404 } 405 } 406 } 407 408 if union { 409 v.Set(dec.zero) 410 } 411 412 lastField = x 413 414 if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) { 415 for x.Kind() == reflect.Ptr { 416 if x.IsNil() { 417 x.Set(reflect.New(x.Type().Elem())) 418 } 419 x = x.Elem() 420 } 421 x.SetBool(f.Type == TRUE) 422 return nil 423 } 424 425 return field.decode(r, x, flags.with(field.flags)) 426 }) 427 if err != nil { 428 return err 429 } 430 431 for i, required := range dec.required { 432 if mask := required & seen[i]; mask != required { 433 i *= 64 434 for (mask & 1) != 0 { 435 mask >>= 1 436 i++ 437 } 438 field := &dec.fields[i] 439 return &MissingField{Field: Field{ID: field.id, Type: field.typ}} 440 } 441 } 442 443 if union && lastField.IsValid() { 444 v.FieldByIndex(dec.union).Set(lastField.Addr()) 445 } 446 447 return nil 448 } 449 450 type structDecoderField struct { 451 index []int 452 id int16 453 flags flags 454 typ Type 455 decode decodeFunc 456 } 457 458 func decodeFuncStructOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 459 dec := &structDecoder{ 460 zero: reflect.Zero(t), 461 } 462 decode := dec.decode 463 seen[t] = decode 464 465 fields := make([]structDecoderField, 0, t.NumField()) 466 forEachStructField(t, nil, func(f structField) { 467 if f.flags.have(union) { 468 dec.union = f.index 469 } else { 470 fields = append(fields, structDecoderField{ 471 index: f.index, 472 id: f.id, 473 flags: f.flags, 474 typ: TypeOf(f.typ), 475 decode: decodeFuncStructFieldOf(f, seen), 476 }) 477 } 478 }) 479 480 minID := int16(0) 481 maxID := int16(0) 482 483 for _, f := range fields { 484 if f.id < minID || minID == 0 { 485 minID = f.id 486 } 487 if f.id > maxID { 488 maxID = f.id 489 } 490 } 491 492 dec.fields = make([]structDecoderField, (maxID-minID)+1) 493 dec.minID = minID 494 dec.required = make([]uint64, len(fields)/64+1) 495 496 for _, f := range fields { 497 i := f.id - minID 498 p := dec.fields[i] 499 if p.decode != nil { 500 panic(fmt.Errorf("thrift struct field id %d is present multiple times in %s with types %s and %s", f.id, t, p.typ, f.typ)) 501 } 502 dec.fields[i] = f 503 if f.flags.have(required) { 504 dec.required[i/64] |= 1 << (i % 64) 505 } 506 } 507 508 return decode 509 } 510 511 func decodeFuncStructFieldOf(f structField, seen decodeFuncCache) decodeFunc { 512 if f.flags.have(enum) { 513 switch f.typ.Kind() { 514 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 515 return decodeInt32 516 } 517 } 518 return decodeFuncOf(f.typ, seen) 519 } 520 521 func decodeFuncPtrOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 522 elem := t.Elem() 523 decode := decodeFuncOf(t.Elem(), seen) 524 return func(r Reader, v reflect.Value, f flags) error { 525 if v.IsNil() { 526 v.Set(reflect.New(elem)) 527 } 528 return decode(r, v.Elem(), f) 529 } 530 } 531 532 func readBinary(r Reader, f func(io.Reader) error) error { 533 n, err := r.ReadLength() 534 if err != nil { 535 return err 536 } 537 return dontExpectEOF(f(io.LimitReader(r.Reader(), int64(n)))) 538 } 539 540 func readList(r Reader, f func(Reader, Type) error) error { 541 l, err := r.ReadList() 542 if err != nil { 543 return err 544 } 545 546 for i := 0; i < int(l.Size); i++ { 547 if err := f(r, l.Type); err != nil { 548 return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) 549 } 550 } 551 552 return nil 553 } 554 555 func readSet(r Reader, f func(Reader, Type) error) error { 556 s, err := r.ReadSet() 557 if err != nil { 558 return err 559 } 560 561 for i := 0; i < int(s.Size); i++ { 562 if err := f(r, s.Type); err != nil { 563 return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) 564 } 565 } 566 567 return nil 568 } 569 570 func readMap(r Reader, f func(Reader, Type, Type) error) error { 571 m, err := r.ReadMap() 572 if err != nil { 573 return err 574 } 575 576 for i := 0; i < int(m.Size); i++ { 577 if err := f(r, m.Key, m.Value); err != nil { 578 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 579 } 580 } 581 582 return nil 583 } 584 585 func readStruct(r Reader, f func(Reader, Field) error) error { 586 lastFieldID := int16(0) 587 numFields := 0 588 589 for { 590 x, err := r.ReadField() 591 if err != nil { 592 if numFields > 0 { 593 err = dontExpectEOF(err) 594 } 595 return err 596 } 597 598 if x.Type == STOP { 599 return nil 600 } 601 602 if x.Delta { 603 x.ID += lastFieldID 604 x.Delta = false 605 } 606 607 if err := f(r, x); err != nil { 608 return with(dontExpectEOF(err), &decodeErrorField{cause: x}) 609 } 610 611 lastFieldID = x.ID 612 numFields++ 613 } 614 } 615 616 func skip(r Reader, t Type) error { 617 var err error 618 switch t { 619 case TRUE, FALSE: 620 _, err = r.ReadBool() 621 case I8: 622 _, err = r.ReadInt8() 623 case I16: 624 _, err = r.ReadInt16() 625 case I32: 626 _, err = r.ReadInt32() 627 case I64: 628 _, err = r.ReadInt64() 629 case DOUBLE: 630 _, err = r.ReadFloat64() 631 case BINARY: 632 err = skipBinary(r) 633 case LIST: 634 err = skipList(r) 635 case SET: 636 err = skipSet(r) 637 case MAP: 638 err = skipMap(r) 639 case STRUCT: 640 err = skipStruct(r) 641 default: 642 return fmt.Errorf("skipping unsupported thrift type %d", t) 643 } 644 return err 645 } 646 647 func skipBinary(r Reader) error { 648 n, err := r.ReadLength() 649 if err != nil { 650 return err 651 } 652 if n == 0 { 653 return nil 654 } 655 switch x := r.Reader().(type) { 656 case *bufio.Reader: 657 _, err = x.Discard(int(n)) 658 default: 659 _, err = io.CopyN(ioutil.Discard, x, int64(n)) 660 } 661 return dontExpectEOF(err) 662 } 663 664 func skipList(r Reader) error { 665 return readList(r, skip) 666 } 667 668 func skipSet(r Reader) error { 669 return readSet(r, skip) 670 } 671 672 func skipMap(r Reader) error { 673 return readMap(r, func(r Reader, k, v Type) error { 674 if err := skip(r, k); err != nil { 675 return dontExpectEOF(err) 676 } 677 if err := skip(r, v); err != nil { 678 return dontExpectEOF(err) 679 } 680 return nil 681 }) 682 } 683 684 func skipStruct(r Reader) error { 685 return readStruct(r, skipField) 686 } 687 688 func skipField(r Reader, f Field) error { 689 return skip(r, f.Type) 690 }