github.com/segmentio/encoding@v0.4.0/thrift/decode.go (about)

     1  package thrift
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"sync/atomic"
    10  )
    11  
    12  // Unmarshal deserializes the thrift data from b to v using to the protocol p.
    13  //
    14  // The function errors if the data in b does not match the type of v.
    15  //
    16  // The function panics if v cannot be converted to a thrift representation.
    17  //
    18  // As an optimization, the value passed in v may be reused across multiple calls
    19  // to Unmarshal, allowing the function to reuse objects referenced by pointer
    20  // fields of struct values. When reusing objects, the application is responsible
    21  // for resetting the state of v before calling Unmarshal again.
    22  func Unmarshal(p Protocol, b []byte, v interface{}) error {
    23  	br := bytes.NewReader(b)
    24  	pr := p.NewReader(br)
    25  
    26  	if err := NewDecoder(pr).Decode(v); err != nil {
    27  		return err
    28  	}
    29  
    30  	if n := br.Len(); n != 0 {
    31  		return fmt.Errorf("unexpected trailing bytes at the end of thrift input: %d", n)
    32  	}
    33  
    34  	return nil
    35  }
    36  
// Decoder decodes thrift values read from a Reader.
type Decoder struct {
	// r is the reader that values are decoded from.
	r Reader
	// f holds the flags passed to every decode function: the feature bits
	// derived from the reader's protocol plus options such as strict.
	f flags
}
    41  
    42  func NewDecoder(r Reader) *Decoder {
    43  	return &Decoder{r: r, f: decoderFlags(r)}
    44  }
    45  
    46  func (d *Decoder) Decode(v interface{}) error {
    47  	t := reflect.TypeOf(v)
    48  	p := reflect.ValueOf(v)
    49  
    50  	if t.Kind() != reflect.Ptr {
    51  		panic("thrift.(*Decoder).Decode: expected pointer type but got " + t.String())
    52  	}
    53  
    54  	t = t.Elem()
    55  	p = p.Elem()
    56  
    57  	cache, _ := decoderCache.Load().(map[typeID]decodeFunc)
    58  	decode, _ := cache[makeTypeID(t)]
    59  
    60  	if decode == nil {
    61  		decode = decodeFuncOf(t, make(decodeFuncCache))
    62  
    63  		newCache := make(map[typeID]decodeFunc, len(cache)+1)
    64  		newCache[makeTypeID(t)] = decode
    65  		for k, v := range cache {
    66  			newCache[k] = v
    67  		}
    68  
    69  		decoderCache.Store(newCache)
    70  	}
    71  
    72  	return decode(d.r, p, d.f)
    73  }
    74  
    75  func (d *Decoder) Reset(r Reader) {
    76  	d.r = r
    77  	d.f = d.f.without(protocolFlags).with(decoderFlags(r))
    78  }
    79  
    80  func (d *Decoder) SetStrict(enabled bool) {
    81  	if enabled {
    82  		d.f = d.f.with(strict)
    83  	} else {
    84  		d.f = d.f.without(strict)
    85  	}
    86  }
    87  
    88  func decoderFlags(r Reader) flags {
    89  	return flags(r.Protocol().Features() << featuresBitOffset)
    90  }
    91  
// decoderCache holds the decode functions already built by Decoder.Decode,
// keyed by type ID. The stored map is replaced as a whole rather than
// modified in place.
var decoderCache atomic.Value // map[typeID]decodeFunc

// decodeFunc reads one value from the Reader and stores it in the
// reflect.Value, using the flags to adjust its behavior.
type decodeFunc func(Reader, reflect.Value, flags) error

// decodeFuncCache records the decode functions built so far during one
// construction pass, so that a type is only processed once and recursive
// types terminate.
type decodeFuncCache map[reflect.Type]decodeFunc
    97  
    98  func decodeFuncOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
    99  	f := seen[t]
   100  	if f != nil {
   101  		return f
   102  	}
   103  	switch t.Kind() {
   104  	case reflect.Bool:
   105  		f = decodeBool
   106  	case reflect.Int8:
   107  		f = decodeInt8
   108  	case reflect.Int16:
   109  		f = decodeInt16
   110  	case reflect.Int32:
   111  		f = decodeInt32
   112  	case reflect.Int64, reflect.Int:
   113  		f = decodeInt64
   114  	case reflect.Float32, reflect.Float64:
   115  		f = decodeFloat64
   116  	case reflect.String:
   117  		f = decodeString
   118  	case reflect.Slice:
   119  		if t.Elem().Kind() == reflect.Uint8 { // []byte
   120  			f = decodeBytes
   121  		} else {
   122  			f = decodeFuncSliceOf(t, seen)
   123  		}
   124  	case reflect.Map:
   125  		f = decodeFuncMapOf(t, seen)
   126  	case reflect.Struct:
   127  		f = decodeFuncStructOf(t, seen)
   128  	case reflect.Ptr:
   129  		f = decodeFuncPtrOf(t, seen)
   130  	default:
   131  		panic("type cannot be decoded in thrift: " + t.String())
   132  	}
   133  	seen[t] = f
   134  	return f
   135  }
   136  
   137  func decodeBool(r Reader, v reflect.Value, _ flags) error {
   138  	b, err := r.ReadBool()
   139  	if err != nil {
   140  		return err
   141  	}
   142  	v.SetBool(b)
   143  	return nil
   144  }
   145  
   146  func decodeInt8(r Reader, v reflect.Value, _ flags) error {
   147  	i, err := r.ReadInt8()
   148  	if err != nil {
   149  		return err
   150  	}
   151  	v.SetInt(int64(i))
   152  	return nil
   153  }
   154  
   155  func decodeInt16(r Reader, v reflect.Value, _ flags) error {
   156  	i, err := r.ReadInt16()
   157  	if err != nil {
   158  		return err
   159  	}
   160  	v.SetInt(int64(i))
   161  	return nil
   162  }
   163  
   164  func decodeInt32(r Reader, v reflect.Value, _ flags) error {
   165  	i, err := r.ReadInt32()
   166  	if err != nil {
   167  		return err
   168  	}
   169  	v.SetInt(int64(i))
   170  	return nil
   171  }
   172  
   173  func decodeInt64(r Reader, v reflect.Value, _ flags) error {
   174  	i, err := r.ReadInt64()
   175  	if err != nil {
   176  		return err
   177  	}
   178  	v.SetInt(int64(i))
   179  	return nil
   180  }
   181  
   182  func decodeFloat64(r Reader, v reflect.Value, _ flags) error {
   183  	f, err := r.ReadFloat64()
   184  	if err != nil {
   185  		return err
   186  	}
   187  	v.SetFloat(f)
   188  	return nil
   189  }
   190  
   191  func decodeString(r Reader, v reflect.Value, _ flags) error {
   192  	s, err := r.ReadString()
   193  	if err != nil {
   194  		return err
   195  	}
   196  	v.SetString(s)
   197  	return nil
   198  }
   199  
   200  func decodeBytes(r Reader, v reflect.Value, _ flags) error {
   201  	b, err := r.ReadBytes()
   202  	if err != nil {
   203  		return err
   204  	}
   205  	v.SetBytes(b)
   206  	return nil
   207  }
   208  
   209  func decodeFuncSliceOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   210  	elem := t.Elem()
   211  	typ := TypeOf(elem)
   212  	dec := decodeFuncOf(elem, seen)
   213  
   214  	return func(r Reader, v reflect.Value, flags flags) error {
   215  		l, err := r.ReadList()
   216  		if err != nil {
   217  			return err
   218  		}
   219  
   220  		// Sometimes the list type is set to TRUE when the list contains only
   221  		// TRUE values. Thrift does not seem to optimize the encoding by
   222  		// omitting the boolean values that are known to all be TRUE, we still
   223  		// need to decode them.
   224  		switch l.Type {
   225  		case TRUE:
   226  			l.Type = BOOL
   227  		}
   228  
   229  		// TODO: implement type conversions?
   230  		if typ != l.Type {
   231  			if flags.have(strict) {
   232  				return &TypeMismatch{item: "list item", Expect: typ, Found: l.Type}
   233  			}
   234  			return nil
   235  		}
   236  
   237  		v.Set(reflect.MakeSlice(t, int(l.Size), int(l.Size)))
   238  		flags = flags.only(decodeFlags)
   239  
   240  		for i := 0; i < int(l.Size); i++ {
   241  			if err := dec(r, v.Index(i), flags); err != nil {
   242  				return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i})
   243  			}
   244  		}
   245  
   246  		return nil
   247  	}
   248  }
   249  
   250  func decodeFuncMapOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   251  	key, elem := t.Key(), t.Elem()
   252  	if elem.Size() == 0 { // map[?]struct{}
   253  		return decodeFuncMapAsSetOf(t, seen)
   254  	}
   255  
   256  	mapType := reflect.MapOf(key, elem)
   257  	keyZero := reflect.Zero(key)
   258  	elemZero := reflect.Zero(elem)
   259  	keyType := TypeOf(key)
   260  	elemType := TypeOf(elem)
   261  	decodeKey := decodeFuncOf(key, seen)
   262  	decodeElem := decodeFuncOf(elem, seen)
   263  
   264  	return func(r Reader, v reflect.Value, flags flags) error {
   265  		m, err := r.ReadMap()
   266  		if err != nil {
   267  			return err
   268  		}
   269  
   270  		v.Set(reflect.MakeMapWithSize(mapType, int(m.Size)))
   271  
   272  		if m.Size == 0 { // empty map
   273  			return nil
   274  		}
   275  
   276  		// TODO: implement type conversions?
   277  		if keyType != m.Key {
   278  			if flags.have(strict) {
   279  				return &TypeMismatch{item: "map key", Expect: keyType, Found: m.Key}
   280  			}
   281  			return nil
   282  		}
   283  
   284  		if elemType != m.Value {
   285  			if flags.have(strict) {
   286  				return &TypeMismatch{item: "map value", Expect: elemType, Found: m.Value}
   287  			}
   288  			return nil
   289  		}
   290  
   291  		tmpKey := reflect.New(key).Elem()
   292  		tmpElem := reflect.New(elem).Elem()
   293  		flags = flags.only(decodeFlags)
   294  
   295  		for i := 0; i < int(m.Size); i++ {
   296  			if err := decodeKey(r, tmpKey, flags); err != nil {
   297  				return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   298  			}
   299  			if err := decodeElem(r, tmpElem, flags); err != nil {
   300  				return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   301  			}
   302  			v.SetMapIndex(tmpKey, tmpElem)
   303  			tmpKey.Set(keyZero)
   304  			tmpElem.Set(elemZero)
   305  		}
   306  
   307  		return nil
   308  	}
   309  }
   310  
   311  func decodeFuncMapAsSetOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   312  	key, elem := t.Key(), t.Elem()
   313  	keyZero := reflect.Zero(key)
   314  	elemZero := reflect.Zero(elem)
   315  	typ := TypeOf(key)
   316  	dec := decodeFuncOf(key, seen)
   317  
   318  	return func(r Reader, v reflect.Value, flags flags) error {
   319  		s, err := r.ReadSet()
   320  		if err != nil {
   321  			return err
   322  		}
   323  
   324  		// See decodeFuncSliceOf for details about why this type conversion
   325  		// needs to be done.
   326  		switch s.Type {
   327  		case TRUE:
   328  			s.Type = BOOL
   329  		}
   330  
   331  		v.Set(reflect.MakeMapWithSize(t, int(s.Size)))
   332  
   333  		if s.Size == 0 {
   334  			return nil
   335  		}
   336  
   337  		// TODO: implement type conversions?
   338  		if typ != s.Type {
   339  			if flags.have(strict) {
   340  				return &TypeMismatch{item: "list item", Expect: typ, Found: s.Type}
   341  			}
   342  			return nil
   343  		}
   344  
   345  		tmp := reflect.New(key).Elem()
   346  		flags = flags.only(decodeFlags)
   347  
   348  		for i := 0; i < int(s.Size); i++ {
   349  			if err := dec(r, tmp, flags); err != nil {
   350  				return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i})
   351  			}
   352  			v.SetMapIndex(tmp, elemZero)
   353  			tmp.Set(keyZero)
   354  		}
   355  
   356  		return nil
   357  	}
   358  }
   359  
// structDecoder holds the state needed to decode a thrift struct into a Go
// struct value.
type structDecoder struct {
	// fields is indexed by field id minus minID; slots with a nil decode
	// function correspond to ids that the struct does not declare.
	fields []structDecoderField
	// union is the index path of the struct field carrying the union flag,
	// empty when the struct has none.
	union []int
	// minID is the id mapped to index zero of fields.
	minID int16
	// zero is the zero value of the struct type.
	zero reflect.Value
	// required is a bitset with one bit per index of fields, set for the
	// fields carrying the required flag.
	required []uint64
}
   367  
   368  func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error {
   369  	flags = flags.only(decodeFlags)
   370  	coalesceBoolFields := flags.have(coalesceBoolFields)
   371  
   372  	lastField := reflect.Value{}
   373  	union := len(dec.union) > 0
   374  	seen := make([]uint64, 1)
   375  	if len(dec.required) > len(seen) {
   376  		seen = make([]uint64, len(dec.required))
   377  	}
   378  
   379  	err := readStruct(r, func(r Reader, f Field) error {
   380  		i := int(f.ID) - int(dec.minID)
   381  		if i < 0 || i >= len(dec.fields) || dec.fields[i].decode == nil {
   382  			return skipField(r, f)
   383  		}
   384  		field := &dec.fields[i]
   385  		seen[i/64] |= 1 << (i % 64)
   386  
   387  		// TODO: implement type conversions?
   388  		if f.Type != field.typ && !(f.Type == TRUE && field.typ == BOOL) {
   389  			if flags.have(strict) {
   390  				return &TypeMismatch{item: "field value", Expect: field.typ, Found: f.Type}
   391  			}
   392  			return nil
   393  		}
   394  
   395  		x := v
   396  		for _, i := range field.index {
   397  			if x.Kind() == reflect.Ptr {
   398  				x = x.Elem()
   399  			}
   400  			if x = x.Field(i); x.Kind() == reflect.Ptr {
   401  				if x.IsNil() {
   402  					x.Set(reflect.New(x.Type().Elem()))
   403  				}
   404  			}
   405  		}
   406  
   407  		if union {
   408  			v.Set(dec.zero)
   409  		}
   410  
   411  		lastField = x
   412  
   413  		if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) {
   414  			for x.Kind() == reflect.Ptr {
   415  				if x.IsNil() {
   416  					x.Set(reflect.New(x.Type().Elem()))
   417  				}
   418  				x = x.Elem()
   419  			}
   420  			x.SetBool(f.Type == TRUE)
   421  			return nil
   422  		}
   423  
   424  		return field.decode(r, x, flags.with(field.flags))
   425  	})
   426  	if err != nil {
   427  		return err
   428  	}
   429  
   430  	for i, required := range dec.required {
   431  		if mask := required & seen[i]; mask != required {
   432  			i *= 64
   433  			for (mask & 1) != 0 {
   434  				mask >>= 1
   435  				i++
   436  			}
   437  			field := &dec.fields[i]
   438  			return &MissingField{Field: Field{ID: field.id, Type: field.typ}}
   439  		}
   440  	}
   441  
   442  	if union && lastField.IsValid() {
   443  		v.FieldByIndex(dec.union).Set(lastField.Addr())
   444  	}
   445  
   446  	return nil
   447  }
   448  
// structDecoderField describes how to decode one field of a struct.
type structDecoderField struct {
	// index is the path of field indexes leading from the struct value to
	// the field.
	index []int
	// id is the thrift field id.
	id int16
	// flags are added to the decoder flags when decoding the field.
	flags flags
	// typ is the thrift type expected in the input for the field.
	typ Type
	// decode reads the field value; it is nil in slots of
	// structDecoder.fields that no field was assigned to.
	decode decodeFunc
}
   456  
   457  func decodeFuncStructOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   458  	dec := &structDecoder{
   459  		zero: reflect.Zero(t),
   460  	}
   461  	decode := dec.decode
   462  	seen[t] = decode
   463  
   464  	fields := make([]structDecoderField, 0, t.NumField())
   465  	forEachStructField(t, nil, func(f structField) {
   466  		if f.flags.have(union) {
   467  			dec.union = f.index
   468  		} else {
   469  			fields = append(fields, structDecoderField{
   470  				index:  f.index,
   471  				id:     f.id,
   472  				flags:  f.flags,
   473  				typ:    TypeOf(f.typ),
   474  				decode: decodeFuncStructFieldOf(f, seen),
   475  			})
   476  		}
   477  	})
   478  
   479  	minID := int16(0)
   480  	maxID := int16(0)
   481  
   482  	for _, f := range fields {
   483  		if f.id < minID || minID == 0 {
   484  			minID = f.id
   485  		}
   486  		if f.id > maxID {
   487  			maxID = f.id
   488  		}
   489  	}
   490  
   491  	dec.fields = make([]structDecoderField, (maxID-minID)+1)
   492  	dec.minID = minID
   493  	dec.required = make([]uint64, len(fields)/64+1)
   494  
   495  	for _, f := range fields {
   496  		i := f.id - minID
   497  		p := dec.fields[i]
   498  		if p.decode != nil {
   499  			panic(fmt.Errorf("thrift struct field id %d is present multiple times in %s with types %s and %s", f.id, t, p.typ, f.typ))
   500  		}
   501  		dec.fields[i] = f
   502  		if f.flags.have(required) {
   503  			dec.required[i/64] |= 1 << (i % 64)
   504  		}
   505  	}
   506  
   507  	return decode
   508  }
   509  
   510  func decodeFuncStructFieldOf(f structField, seen decodeFuncCache) decodeFunc {
   511  	if f.flags.have(enum) {
   512  		switch f.typ.Kind() {
   513  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   514  			return decodeInt32
   515  		}
   516  	}
   517  	return decodeFuncOf(f.typ, seen)
   518  }
   519  
   520  func decodeFuncPtrOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   521  	elem := t.Elem()
   522  	decode := decodeFuncOf(t.Elem(), seen)
   523  	return func(r Reader, v reflect.Value, f flags) error {
   524  		if v.IsNil() {
   525  			v.Set(reflect.New(elem))
   526  		}
   527  		return decode(r, v.Elem(), f)
   528  	}
   529  }
   530  
   531  func readBinary(r Reader, f func(io.Reader) error) error {
   532  	n, err := r.ReadLength()
   533  	if err != nil {
   534  		return err
   535  	}
   536  	return dontExpectEOF(f(io.LimitReader(r.Reader(), int64(n))))
   537  }
   538  
   539  func readList(r Reader, f func(Reader, Type) error) error {
   540  	l, err := r.ReadList()
   541  	if err != nil {
   542  		return err
   543  	}
   544  
   545  	for i := 0; i < int(l.Size); i++ {
   546  		if err := f(r, l.Type); err != nil {
   547  			return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i})
   548  		}
   549  	}
   550  
   551  	return nil
   552  }
   553  
   554  func readSet(r Reader, f func(Reader, Type) error) error {
   555  	s, err := r.ReadSet()
   556  	if err != nil {
   557  		return err
   558  	}
   559  
   560  	for i := 0; i < int(s.Size); i++ {
   561  		if err := f(r, s.Type); err != nil {
   562  			return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i})
   563  		}
   564  	}
   565  
   566  	return nil
   567  }
   568  
   569  func readMap(r Reader, f func(Reader, Type, Type) error) error {
   570  	m, err := r.ReadMap()
   571  	if err != nil {
   572  		return err
   573  	}
   574  
   575  	for i := 0; i < int(m.Size); i++ {
   576  		if err := f(r, m.Key, m.Value); err != nil {
   577  			return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   578  		}
   579  	}
   580  
   581  	return nil
   582  }
   583  
   584  func readStruct(r Reader, f func(Reader, Field) error) error {
   585  	lastFieldID := int16(0)
   586  	numFields := 0
   587  
   588  	for {
   589  		x, err := r.ReadField()
   590  		if err != nil {
   591  			if numFields > 0 {
   592  				err = dontExpectEOF(err)
   593  			}
   594  			return err
   595  		}
   596  
   597  		if x.Type == STOP {
   598  			return nil
   599  		}
   600  
   601  		if x.Delta {
   602  			x.ID += lastFieldID
   603  			x.Delta = false
   604  		}
   605  
   606  		if err := f(r, x); err != nil {
   607  			return with(dontExpectEOF(err), &decodeErrorField{cause: x})
   608  		}
   609  
   610  		lastFieldID = x.ID
   611  		numFields++
   612  	}
   613  }
   614  
   615  func skip(r Reader, t Type) error {
   616  	var err error
   617  	switch t {
   618  	case TRUE, FALSE:
   619  		_, err = r.ReadBool()
   620  	case I8:
   621  		_, err = r.ReadInt8()
   622  	case I16:
   623  		_, err = r.ReadInt16()
   624  	case I32:
   625  		_, err = r.ReadInt32()
   626  	case I64:
   627  		_, err = r.ReadInt64()
   628  	case DOUBLE:
   629  		_, err = r.ReadFloat64()
   630  	case BINARY:
   631  		err = skipBinary(r)
   632  	case LIST:
   633  		err = skipList(r)
   634  	case SET:
   635  		err = skipSet(r)
   636  	case MAP:
   637  		err = skipMap(r)
   638  	case STRUCT:
   639  		err = skipStruct(r)
   640  	default:
   641  		return fmt.Errorf("skipping unsupported thrift type %d", t)
   642  	}
   643  	return err
   644  }
   645  
   646  func skipBinary(r Reader) error {
   647  	n, err := r.ReadLength()
   648  	if err != nil {
   649  		return err
   650  	}
   651  	if n == 0 {
   652  		return nil
   653  	}
   654  	switch x := r.Reader().(type) {
   655  	case *bufio.Reader:
   656  		_, err = x.Discard(int(n))
   657  	default:
   658  		_, err = io.CopyN(io.Discard, x, int64(n))
   659  	}
   660  	return dontExpectEOF(err)
   661  }
   662  
   663  func skipList(r Reader) error {
   664  	return readList(r, skip)
   665  }
   666  
   667  func skipSet(r Reader) error {
   668  	return readSet(r, skip)
   669  }
   670  
   671  func skipMap(r Reader) error {
   672  	return readMap(r, func(r Reader, k, v Type) error {
   673  		if err := skip(r, k); err != nil {
   674  			return dontExpectEOF(err)
   675  		}
   676  		if err := skip(r, v); err != nil {
   677  			return dontExpectEOF(err)
   678  		}
   679  		return nil
   680  	})
   681  }
   682  
   683  func skipStruct(r Reader) error {
   684  	return readStruct(r, skipField)
   685  }
   686  
   687  func skipField(r Reader, f Field) error {
   688  	return skip(r, f.Type)
   689  }