github.com/segmentio/encoding@v0.4.0/thrift/decode.go (about) 1 package thrift 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "reflect" 9 "sync/atomic" 10 ) 11 12 // Unmarshal deserializes the thrift data from b to v using to the protocol p. 13 // 14 // The function errors if the data in b does not match the type of v. 15 // 16 // The function panics if v cannot be converted to a thrift representation. 17 // 18 // As an optimization, the value passed in v may be reused across multiple calls 19 // to Unmarshal, allowing the function to reuse objects referenced by pointer 20 // fields of struct values. When reusing objects, the application is responsible 21 // for resetting the state of v before calling Unmarshal again. 22 func Unmarshal(p Protocol, b []byte, v interface{}) error { 23 br := bytes.NewReader(b) 24 pr := p.NewReader(br) 25 26 if err := NewDecoder(pr).Decode(v); err != nil { 27 return err 28 } 29 30 if n := br.Len(); n != 0 { 31 return fmt.Errorf("unexpected trailing bytes at the end of thrift input: %d", n) 32 } 33 34 return nil 35 } 36 37 type Decoder struct { 38 r Reader 39 f flags 40 } 41 42 func NewDecoder(r Reader) *Decoder { 43 return &Decoder{r: r, f: decoderFlags(r)} 44 } 45 46 func (d *Decoder) Decode(v interface{}) error { 47 t := reflect.TypeOf(v) 48 p := reflect.ValueOf(v) 49 50 if t.Kind() != reflect.Ptr { 51 panic("thrift.(*Decoder).Decode: expected pointer type but got " + t.String()) 52 } 53 54 t = t.Elem() 55 p = p.Elem() 56 57 cache, _ := decoderCache.Load().(map[typeID]decodeFunc) 58 decode, _ := cache[makeTypeID(t)] 59 60 if decode == nil { 61 decode = decodeFuncOf(t, make(decodeFuncCache)) 62 63 newCache := make(map[typeID]decodeFunc, len(cache)+1) 64 newCache[makeTypeID(t)] = decode 65 for k, v := range cache { 66 newCache[k] = v 67 } 68 69 decoderCache.Store(newCache) 70 } 71 72 return decode(d.r, p, d.f) 73 } 74 75 func (d *Decoder) Reset(r Reader) { 76 d.r = r 77 d.f = d.f.without(protocolFlags).with(decoderFlags(r)) 78 } 79 80 func (d *Decoder) SetStrict(enabled bool) { 81 if enabled { 82 d.f = d.f.with(strict) 83 } else { 84 d.f = d.f.without(strict) 85 } 86 } 87 88 func decoderFlags(r Reader) flags { 89 return flags(r.Protocol().Features() << featuresBitOffset) 90 } 91 92 var decoderCache atomic.Value // map[typeID]decodeFunc 93 94 type decodeFunc func(Reader, reflect.Value, flags) error 95 96 type decodeFuncCache map[reflect.Type]decodeFunc 97 98 func decodeFuncOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 99 f := seen[t] 100 if f != nil { 101 return f 102 } 103 switch t.Kind() { 104 case reflect.Bool: 105 f = decodeBool 106 case reflect.Int8: 107 f = decodeInt8 108 case reflect.Int16: 109 f = decodeInt16 110 case reflect.Int32: 111 f = decodeInt32 112 case reflect.Int64, reflect.Int: 113 f = decodeInt64 114 case reflect.Float32, reflect.Float64: 115 f = decodeFloat64 116 case reflect.String: 117 f = decodeString 118 case reflect.Slice: 119 if t.Elem().Kind() == reflect.Uint8 { // []byte 120 f = decodeBytes 121 } else { 122 f = decodeFuncSliceOf(t, seen) 123 } 124 case reflect.Map: 125 f = decodeFuncMapOf(t, seen) 126 case reflect.Struct: 127 f = decodeFuncStructOf(t, seen) 128 case reflect.Ptr: 129 f = decodeFuncPtrOf(t, seen) 130 default: 131 panic("type cannot be decoded in thrift: " + t.String()) 132 } 133 seen[t] = f 134 return f 135 } 136 137 func decodeBool(r Reader, v reflect.Value, _ flags) error { 138 b, err := r.ReadBool() 139 if err != nil { 140 return err 141 } 142 v.SetBool(b) 143 return nil 144 } 145 146 func decodeInt8(r Reader, v reflect.Value, _ flags) error { 147 i, err := r.ReadInt8() 148 if err != nil { 149 return err 150 } 151 v.SetInt(int64(i)) 152 return nil 153 } 154 155 func decodeInt16(r Reader, v reflect.Value, _ flags) error { 156 i, err := r.ReadInt16() 157 if err != nil { 158 return err 159 } 160 v.SetInt(int64(i)) 161 return nil 162 } 163 164 func decodeInt32(r Reader, v reflect.Value, _ flags) error { 165 i, err := r.ReadInt32() 166 if err != nil { 167 return err 168 } 169 v.SetInt(int64(i)) 170 return nil 171 } 172 173 func decodeInt64(r Reader, v reflect.Value, _ flags) error { 174 i, err := r.ReadInt64() 175 if err != nil { 176 return err 177 } 178 v.SetInt(int64(i)) 179 return nil 180 } 181 182 func decodeFloat64(r Reader, v reflect.Value, _ flags) error { 183 f, err := r.ReadFloat64() 184 if err != nil { 185 return err 186 } 187 v.SetFloat(f) 188 return nil 189 } 190 191 func decodeString(r Reader, v reflect.Value, _ flags) error { 192 s, err := r.ReadString() 193 if err != nil { 194 return err 195 } 196 v.SetString(s) 197 return nil 198 } 199 200 func decodeBytes(r Reader, v reflect.Value, _ flags) error { 201 b, err := r.ReadBytes() 202 if err != nil { 203 return err 204 } 205 v.SetBytes(b) 206 return nil 207 } 208 209 func decodeFuncSliceOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 210 elem := t.Elem() 211 typ := TypeOf(elem) 212 dec := decodeFuncOf(elem, seen) 213 214 return func(r Reader, v reflect.Value, flags flags) error { 215 l, err := r.ReadList() 216 if err != nil { 217 return err 218 } 219 220 // Sometimes the list type is set to TRUE when the list contains only 221 // TRUE values. Thrift does not seem to optimize the encoding by 222 // omitting the boolean values that are known to all be TRUE, we still 223 // need to decode them. 224 switch l.Type { 225 case TRUE: 226 l.Type = BOOL 227 } 228 229 // TODO: implement type conversions? 230 if typ != l.Type { 231 if flags.have(strict) { 232 return &TypeMismatch{item: "list item", Expect: typ, Found: l.Type} 233 } 234 return nil 235 } 236 237 v.Set(reflect.MakeSlice(t, int(l.Size), int(l.Size))) 238 flags = flags.only(decodeFlags) 239 240 for i := 0; i < int(l.Size); i++ { 241 if err := dec(r, v.Index(i), flags); err != nil { 242 return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) 243 } 244 } 245 246 return nil 247 } 248 } 249 250 func decodeFuncMapOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 251 key, elem := t.Key(), t.Elem() 252 if elem.Size() == 0 { // map[?]struct{} 253 return decodeFuncMapAsSetOf(t, seen) 254 } 255 256 mapType := reflect.MapOf(key, elem) 257 keyZero := reflect.Zero(key) 258 elemZero := reflect.Zero(elem) 259 keyType := TypeOf(key) 260 elemType := TypeOf(elem) 261 decodeKey := decodeFuncOf(key, seen) 262 decodeElem := decodeFuncOf(elem, seen) 263 264 return func(r Reader, v reflect.Value, flags flags) error { 265 m, err := r.ReadMap() 266 if err != nil { 267 return err 268 } 269 270 v.Set(reflect.MakeMapWithSize(mapType, int(m.Size))) 271 272 if m.Size == 0 { // empty map 273 return nil 274 } 275 276 // TODO: implement type conversions? 277 if keyType != m.Key { 278 if flags.have(strict) { 279 return &TypeMismatch{item: "map key", Expect: keyType, Found: m.Key} 280 } 281 return nil 282 } 283 284 if elemType != m.Value { 285 if flags.have(strict) { 286 return &TypeMismatch{item: "map value", Expect: elemType, Found: m.Value} 287 } 288 return nil 289 } 290 291 tmpKey := reflect.New(key).Elem() 292 tmpElem := reflect.New(elem).Elem() 293 flags = flags.only(decodeFlags) 294 295 for i := 0; i < int(m.Size); i++ { 296 if err := decodeKey(r, tmpKey, flags); err != nil { 297 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 298 } 299 if err := decodeElem(r, tmpElem, flags); err != nil { 300 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 301 } 302 v.SetMapIndex(tmpKey, tmpElem) 303 tmpKey.Set(keyZero) 304 tmpElem.Set(elemZero) 305 } 306 307 return nil 308 } 309 } 310 311 func decodeFuncMapAsSetOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 312 key, elem := t.Key(), t.Elem() 313 keyZero := reflect.Zero(key) 314 elemZero := reflect.Zero(elem) 315 typ := TypeOf(key) 316 dec := decodeFuncOf(key, seen) 317 318 return func(r Reader, v reflect.Value, flags flags) error { 319 s, err := r.ReadSet() 320 if err != nil { 321 return err 322 } 323 324 // See decodeFuncSliceOf for details about why this type conversion 325 // needs to be done. 326 switch s.Type { 327 case TRUE: 328 s.Type = BOOL 329 } 330 331 v.Set(reflect.MakeMapWithSize(t, int(s.Size))) 332 333 if s.Size == 0 { 334 return nil 335 } 336 337 // TODO: implement type conversions? 338 if typ != s.Type { 339 if flags.have(strict) { 340 return &TypeMismatch{item: "list item", Expect: typ, Found: s.Type} 341 } 342 return nil 343 } 344 345 tmp := reflect.New(key).Elem() 346 flags = flags.only(decodeFlags) 347 348 for i := 0; i < int(s.Size); i++ { 349 if err := dec(r, tmp, flags); err != nil { 350 return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) 351 } 352 v.SetMapIndex(tmp, elemZero) 353 tmp.Set(keyZero) 354 } 355 356 return nil 357 } 358 } 359 360 type structDecoder struct { 361 fields []structDecoderField 362 union []int 363 minID int16 364 zero reflect.Value 365 required []uint64 366 } 367 368 func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error { 369 flags = flags.only(decodeFlags) 370 coalesceBoolFields := flags.have(coalesceBoolFields) 371 372 lastField := reflect.Value{} 373 union := len(dec.union) > 0 374 seen := make([]uint64, 1) 375 if len(dec.required) > len(seen) { 376 seen = make([]uint64, len(dec.required)) 377 } 378 379 err := readStruct(r, func(r Reader, f Field) error { 380 i := int(f.ID) - int(dec.minID) 381 if i < 0 || i >= len(dec.fields) || dec.fields[i].decode == nil { 382 return skipField(r, f) 383 } 384 field := &dec.fields[i] 385 seen[i/64] |= 1 << (i % 64) 386 387 // TODO: implement type conversions? 388 if f.Type != field.typ && !(f.Type == TRUE && field.typ == BOOL) { 389 if flags.have(strict) { 390 return &TypeMismatch{item: "field value", Expect: field.typ, Found: f.Type} 391 } 392 return nil 393 } 394 395 x := v 396 for _, i := range field.index { 397 if x.Kind() == reflect.Ptr { 398 x = x.Elem() 399 } 400 if x = x.Field(i); x.Kind() == reflect.Ptr { 401 if x.IsNil() { 402 x.Set(reflect.New(x.Type().Elem())) 403 } 404 } 405 } 406 407 if union { 408 v.Set(dec.zero) 409 } 410 411 lastField = x 412 413 if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) { 414 for x.Kind() == reflect.Ptr { 415 if x.IsNil() { 416 x.Set(reflect.New(x.Type().Elem())) 417 } 418 x = x.Elem() 419 } 420 x.SetBool(f.Type == TRUE) 421 return nil 422 } 423 424 return field.decode(r, x, flags.with(field.flags)) 425 }) 426 if err != nil { 427 return err 428 } 429 430 for i, required := range dec.required { 431 if mask := required & seen[i]; mask != required { 432 i *= 64 433 for (mask & 1) != 0 { 434 mask >>= 1 435 i++ 436 } 437 field := &dec.fields[i] 438 return &MissingField{Field: Field{ID: field.id, Type: field.typ}} 439 } 440 } 441 442 if union && lastField.IsValid() { 443 v.FieldByIndex(dec.union).Set(lastField.Addr()) 444 } 445 446 return nil 447 } 448 449 type structDecoderField struct { 450 index []int 451 id int16 452 flags flags 453 typ Type 454 decode decodeFunc 455 } 456 457 func decodeFuncStructOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 458 dec := &structDecoder{ 459 zero: reflect.Zero(t), 460 } 461 decode := dec.decode 462 seen[t] = decode 463 464 fields := make([]structDecoderField, 0, t.NumField()) 465 forEachStructField(t, nil, func(f structField) { 466 if f.flags.have(union) { 467 dec.union = f.index 468 } else { 469 fields = append(fields, structDecoderField{ 470 index: f.index, 471 id: f.id, 472 flags: f.flags, 473 typ: TypeOf(f.typ), 474 decode: decodeFuncStructFieldOf(f, seen), 475 }) 476 } 477 }) 478 479 minID := int16(0) 480 maxID := int16(0) 481 482 for _, f := range fields { 483 if f.id < minID || minID == 0 { 484 minID = f.id 485 } 486 if f.id > maxID { 487 maxID = f.id 488 } 489 } 490 491 dec.fields = make([]structDecoderField, (maxID-minID)+1) 492 dec.minID = minID 493 dec.required = make([]uint64, len(fields)/64+1) 494 495 for _, f := range fields { 496 i := f.id - minID 497 p := dec.fields[i] 498 if p.decode != nil { 499 panic(fmt.Errorf("thrift struct field id %d is present multiple times in %s with types %s and %s", f.id, t, p.typ, f.typ)) 500 } 501 dec.fields[i] = f 502 if f.flags.have(required) { 503 dec.required[i/64] |= 1 << (i % 64) 504 } 505 } 506 507 return decode 508 } 509 510 func decodeFuncStructFieldOf(f structField, seen decodeFuncCache) decodeFunc { 511 if f.flags.have(enum) { 512 switch f.typ.Kind() { 513 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 514 return decodeInt32 515 } 516 } 517 return decodeFuncOf(f.typ, seen) 518 } 519 520 func decodeFuncPtrOf(t reflect.Type, seen decodeFuncCache) decodeFunc { 521 elem := t.Elem() 522 decode := decodeFuncOf(t.Elem(), seen) 523 return func(r Reader, v reflect.Value, f flags) error { 524 if v.IsNil() { 525 v.Set(reflect.New(elem)) 526 } 527 return decode(r, v.Elem(), f) 528 } 529 } 530 531 func readBinary(r Reader, f func(io.Reader) error) error { 532 n, err := r.ReadLength() 533 if err != nil { 534 return err 535 } 536 return dontExpectEOF(f(io.LimitReader(r.Reader(), int64(n)))) 537 } 538 539 func readList(r Reader, f func(Reader, Type) error) error { 540 l, err := r.ReadList() 541 if err != nil { 542 return err 543 } 544 545 for i := 0; i < int(l.Size); i++ { 546 if err := f(r, l.Type); err != nil { 547 return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i}) 548 } 549 } 550 551 return nil 552 } 553 554 func readSet(r Reader, f func(Reader, Type) error) error { 555 s, err := r.ReadSet() 556 if err != nil { 557 return err 558 } 559 560 for i := 0; i < int(s.Size); i++ { 561 if err := f(r, s.Type); err != nil { 562 return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i}) 563 } 564 } 565 566 return nil 567 } 568 569 func readMap(r Reader, f func(Reader, Type, Type) error) error { 570 m, err := r.ReadMap() 571 if err != nil { 572 return err 573 } 574 575 for i := 0; i < int(m.Size); i++ { 576 if err := f(r, m.Key, m.Value); err != nil { 577 return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i}) 578 } 579 } 580 581 return nil 582 } 583 584 func readStruct(r Reader, f func(Reader, Field) error) error { 585 lastFieldID := int16(0) 586 numFields := 0 587 588 for { 589 x, err := r.ReadField() 590 if err != nil { 591 if numFields > 0 { 592 err = dontExpectEOF(err) 593 } 594 return err 595 } 596 597 if x.Type == STOP { 598 return nil 599 } 600 601 if x.Delta { 602 x.ID += lastFieldID 603 x.Delta = false 604 } 605 606 if err := f(r, x); err != nil { 607 return with(dontExpectEOF(err), &decodeErrorField{cause: x}) 608 } 609 610 lastFieldID = x.ID 611 numFields++ 612 } 613 } 614 615 func skip(r Reader, t Type) error { 616 var err error 617 switch t { 618 case TRUE, FALSE: 619 _, err = r.ReadBool() 620 case I8: 621 _, err = r.ReadInt8() 622 case I16: 623 _, err = r.ReadInt16() 624 case I32: 625 _, err = r.ReadInt32() 626 case I64: 627 _, err = r.ReadInt64() 628 case DOUBLE: 629 _, err = r.ReadFloat64() 630 case BINARY: 631 err = skipBinary(r) 632 case LIST: 633 err = skipList(r) 634 case SET: 635 err = skipSet(r) 636 case MAP: 637 err = skipMap(r) 638 case STRUCT: 639 err = skipStruct(r) 640 default: 641 return fmt.Errorf("skipping unsupported thrift type %d", t) 642 } 643 return err 644 } 645 646 func skipBinary(r Reader) error { 647 n, err := r.ReadLength() 648 if err != nil { 649 return err 650 } 651 if n == 0 { 652 return nil 653 } 654 switch x := r.Reader().(type) { 655 case *bufio.Reader: 656 _, err = x.Discard(int(n)) 657 default: 658 _, err = io.CopyN(io.Discard, x, int64(n)) 659 } 660 return dontExpectEOF(err) 661 } 662 663 func skipList(r Reader) error { 664 return readList(r, skip) 665 } 666 667 func skipSet(r Reader) error { 668 return readSet(r, skip) 669 } 670 671 func skipMap(r Reader) error { 672 return readMap(r, func(r Reader, k, v Type) error { 673 if err := skip(r, k); err != nil { 674 return dontExpectEOF(err) 675 } 676 if err := skip(r, v); err != nil { 677 return dontExpectEOF(err) 678 } 679 return nil 680 }) 681 } 682 683 func skipStruct(r Reader) error { 684 return readStruct(r, skipField) 685 } 686 687 func skipField(r Reader, f Field) error { 688 return skip(r, f.Type) 689 }