github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/thrift/decode.go (about)

     1  package thrift
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"reflect"
    10  	"sync/atomic"
    11  )
    12  
    13  // Unmarshal deserializes the thrift data from b to v using to the protocol p.
    14  //
    15  // The function errors if the data in b does not match the type of v.
    16  //
    17  // The function panics if v cannot be converted to a thrift representation.
    18  //
    19  // As an optimization, the value passed in v may be reused across multiple calls
    20  // to Unmarshal, allowing the function to reuse objects referenced by pointer
    21  // fields of struct values. When reusing objects, the application is responsible
    22  // for resetting the state of v before calling Unmarshal again.
    23  func Unmarshal(p Protocol, b []byte, v interface{}) error {
    24  	br := bytes.NewReader(b)
    25  	pr := p.NewReader(br)
    26  
    27  	if err := NewDecoder(pr).Decode(v); err != nil {
    28  		return err
    29  	}
    30  
    31  	if n := br.Len(); n != 0 {
    32  		return fmt.Errorf("unexpected trailing bytes at the end of thrift input: %d", n)
    33  	}
    34  
    35  	return nil
    36  }
    37  
    38  type Decoder struct {
    39  	r Reader
    40  	f flags
    41  }
    42  
    43  func NewDecoder(r Reader) *Decoder {
    44  	return &Decoder{r: r, f: decoderFlags(r)}
    45  }
    46  
    47  func (d *Decoder) Decode(v interface{}) error {
    48  	t := reflect.TypeOf(v)
    49  	p := reflect.ValueOf(v)
    50  
    51  	if t.Kind() != reflect.Ptr {
    52  		panic("thrift.(*Decoder).Decode: expected pointer type but got " + t.String())
    53  	}
    54  
    55  	t = t.Elem()
    56  	p = p.Elem()
    57  
    58  	cache, _ := decoderCache.Load().(map[typeID]decodeFunc)
    59  	decode, _ := cache[makeTypeID(t)]
    60  
    61  	if decode == nil {
    62  		decode = decodeFuncOf(t, make(decodeFuncCache))
    63  
    64  		newCache := make(map[typeID]decodeFunc, len(cache)+1)
    65  		newCache[makeTypeID(t)] = decode
    66  		for k, v := range cache {
    67  			newCache[k] = v
    68  		}
    69  
    70  		decoderCache.Store(newCache)
    71  	}
    72  
    73  	return decode(d.r, p, d.f)
    74  }
    75  
    76  func (d *Decoder) Reset(r Reader) {
    77  	d.r = r
    78  	d.f = d.f.without(protocolFlags).with(decoderFlags(r))
    79  }
    80  
    81  func (d *Decoder) SetStrict(enabled bool) {
    82  	if enabled {
    83  		d.f = d.f.with(strict)
    84  	} else {
    85  		d.f = d.f.without(strict)
    86  	}
    87  }
    88  
    89  func decoderFlags(r Reader) flags {
    90  	return flags(r.Protocol().Features() << featuresBitOffset)
    91  }
    92  
    93  var decoderCache atomic.Value // map[typeID]decodeFunc
    94  
    95  type decodeFunc func(Reader, reflect.Value, flags) error
    96  
    97  type decodeFuncCache map[reflect.Type]decodeFunc
    98  
    99  func decodeFuncOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   100  	f := seen[t]
   101  	if f != nil {
   102  		return f
   103  	}
   104  	switch t.Kind() {
   105  	case reflect.Bool:
   106  		f = decodeBool
   107  	case reflect.Int8:
   108  		f = decodeInt8
   109  	case reflect.Int16:
   110  		f = decodeInt16
   111  	case reflect.Int32:
   112  		f = decodeInt32
   113  	case reflect.Int64, reflect.Int:
   114  		f = decodeInt64
   115  	case reflect.Float32, reflect.Float64:
   116  		f = decodeFloat64
   117  	case reflect.String:
   118  		f = decodeString
   119  	case reflect.Slice:
   120  		if t.Elem().Kind() == reflect.Uint8 { // []byte
   121  			f = decodeBytes
   122  		} else {
   123  			f = decodeFuncSliceOf(t, seen)
   124  		}
   125  	case reflect.Map:
   126  		f = decodeFuncMapOf(t, seen)
   127  	case reflect.Struct:
   128  		f = decodeFuncStructOf(t, seen)
   129  	case reflect.Ptr:
   130  		f = decodeFuncPtrOf(t, seen)
   131  	default:
   132  		panic("type cannot be decoded in thrift: " + t.String())
   133  	}
   134  	seen[t] = f
   135  	return f
   136  }
   137  
   138  func decodeBool(r Reader, v reflect.Value, _ flags) error {
   139  	b, err := r.ReadBool()
   140  	if err != nil {
   141  		return err
   142  	}
   143  	v.SetBool(b)
   144  	return nil
   145  }
   146  
   147  func decodeInt8(r Reader, v reflect.Value, _ flags) error {
   148  	i, err := r.ReadInt8()
   149  	if err != nil {
   150  		return err
   151  	}
   152  	v.SetInt(int64(i))
   153  	return nil
   154  }
   155  
   156  func decodeInt16(r Reader, v reflect.Value, _ flags) error {
   157  	i, err := r.ReadInt16()
   158  	if err != nil {
   159  		return err
   160  	}
   161  	v.SetInt(int64(i))
   162  	return nil
   163  }
   164  
   165  func decodeInt32(r Reader, v reflect.Value, _ flags) error {
   166  	i, err := r.ReadInt32()
   167  	if err != nil {
   168  		return err
   169  	}
   170  	v.SetInt(int64(i))
   171  	return nil
   172  }
   173  
   174  func decodeInt64(r Reader, v reflect.Value, _ flags) error {
   175  	i, err := r.ReadInt64()
   176  	if err != nil {
   177  		return err
   178  	}
   179  	v.SetInt(int64(i))
   180  	return nil
   181  }
   182  
   183  func decodeFloat64(r Reader, v reflect.Value, _ flags) error {
   184  	f, err := r.ReadFloat64()
   185  	if err != nil {
   186  		return err
   187  	}
   188  	v.SetFloat(f)
   189  	return nil
   190  }
   191  
   192  func decodeString(r Reader, v reflect.Value, _ flags) error {
   193  	s, err := r.ReadString()
   194  	if err != nil {
   195  		return err
   196  	}
   197  	v.SetString(s)
   198  	return nil
   199  }
   200  
   201  func decodeBytes(r Reader, v reflect.Value, _ flags) error {
   202  	b, err := r.ReadBytes()
   203  	if err != nil {
   204  		return err
   205  	}
   206  	v.SetBytes(b)
   207  	return nil
   208  }
   209  
   210  func decodeFuncSliceOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   211  	elem := t.Elem()
   212  	typ := TypeOf(elem)
   213  	dec := decodeFuncOf(elem, seen)
   214  
   215  	return func(r Reader, v reflect.Value, flags flags) error {
   216  		l, err := r.ReadList()
   217  		if err != nil {
   218  			return err
   219  		}
   220  
   221  		// Sometimes the list type is set to TRUE when the list contains only
   222  		// TRUE values. Thrift does not seem to optimize the encoding by
   223  		// omitting the boolean values that are known to all be TRUE, we still
   224  		// need to decode them.
   225  		switch l.Type {
   226  		case TRUE:
   227  			l.Type = BOOL
   228  		}
   229  
   230  		// TODO: implement type conversions?
   231  		if typ != l.Type {
   232  			if flags.have(strict) {
   233  				return &TypeMismatch{item: "list item", Expect: typ, Found: l.Type}
   234  			}
   235  			return nil
   236  		}
   237  
   238  		v.Set(reflect.MakeSlice(t, int(l.Size), int(l.Size)))
   239  		flags = flags.only(decodeFlags)
   240  
   241  		for i := 0; i < int(l.Size); i++ {
   242  			if err := dec(r, v.Index(i), flags); err != nil {
   243  				return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i})
   244  			}
   245  		}
   246  
   247  		return nil
   248  	}
   249  }
   250  
   251  func decodeFuncMapOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   252  	key, elem := t.Key(), t.Elem()
   253  	if elem.Size() == 0 { // map[?]struct{}
   254  		return decodeFuncMapAsSetOf(t, seen)
   255  	}
   256  
   257  	mapType := reflect.MapOf(key, elem)
   258  	keyZero := reflect.Zero(key)
   259  	elemZero := reflect.Zero(elem)
   260  	keyType := TypeOf(key)
   261  	elemType := TypeOf(elem)
   262  	decodeKey := decodeFuncOf(key, seen)
   263  	decodeElem := decodeFuncOf(elem, seen)
   264  
   265  	return func(r Reader, v reflect.Value, flags flags) error {
   266  		m, err := r.ReadMap()
   267  		if err != nil {
   268  			return err
   269  		}
   270  
   271  		v.Set(reflect.MakeMapWithSize(mapType, int(m.Size)))
   272  
   273  		if m.Size == 0 { // empty map
   274  			return nil
   275  		}
   276  
   277  		// TODO: implement type conversions?
   278  		if keyType != m.Key {
   279  			if flags.have(strict) {
   280  				return &TypeMismatch{item: "map key", Expect: keyType, Found: m.Key}
   281  			}
   282  			return nil
   283  		}
   284  
   285  		if elemType != m.Value {
   286  			if flags.have(strict) {
   287  				return &TypeMismatch{item: "map value", Expect: elemType, Found: m.Value}
   288  			}
   289  			return nil
   290  		}
   291  
   292  		tmpKey := reflect.New(key).Elem()
   293  		tmpElem := reflect.New(elem).Elem()
   294  		flags = flags.only(decodeFlags)
   295  
   296  		for i := 0; i < int(m.Size); i++ {
   297  			if err := decodeKey(r, tmpKey, flags); err != nil {
   298  				return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   299  			}
   300  			if err := decodeElem(r, tmpElem, flags); err != nil {
   301  				return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   302  			}
   303  			v.SetMapIndex(tmpKey, tmpElem)
   304  			tmpKey.Set(keyZero)
   305  			tmpElem.Set(elemZero)
   306  		}
   307  
   308  		return nil
   309  	}
   310  }
   311  
   312  func decodeFuncMapAsSetOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   313  	key, elem := t.Key(), t.Elem()
   314  	keyZero := reflect.Zero(key)
   315  	elemZero := reflect.Zero(elem)
   316  	typ := TypeOf(key)
   317  	dec := decodeFuncOf(key, seen)
   318  
   319  	return func(r Reader, v reflect.Value, flags flags) error {
   320  		s, err := r.ReadSet()
   321  		if err != nil {
   322  			return err
   323  		}
   324  
   325  		// See decodeFuncSliceOf for details about why this type conversion
   326  		// needs to be done.
   327  		switch s.Type {
   328  		case TRUE:
   329  			s.Type = BOOL
   330  		}
   331  
   332  		v.Set(reflect.MakeMapWithSize(t, int(s.Size)))
   333  
   334  		if s.Size == 0 {
   335  			return nil
   336  		}
   337  
   338  		// TODO: implement type conversions?
   339  		if typ != s.Type {
   340  			if flags.have(strict) {
   341  				return &TypeMismatch{item: "list item", Expect: typ, Found: s.Type}
   342  			}
   343  			return nil
   344  		}
   345  
   346  		tmp := reflect.New(key).Elem()
   347  		flags = flags.only(decodeFlags)
   348  
   349  		for i := 0; i < int(s.Size); i++ {
   350  			if err := dec(r, tmp, flags); err != nil {
   351  				return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i})
   352  			}
   353  			v.SetMapIndex(tmp, elemZero)
   354  			tmp.Set(keyZero)
   355  		}
   356  
   357  		return nil
   358  	}
   359  }
   360  
   361  type structDecoder struct {
   362  	fields   []structDecoderField
   363  	union    []int
   364  	minID    int16
   365  	zero     reflect.Value
   366  	required []uint64
   367  }
   368  
   369  func (dec *structDecoder) decode(r Reader, v reflect.Value, flags flags) error {
   370  	flags = flags.only(decodeFlags)
   371  	coalesceBoolFields := flags.have(coalesceBoolFields)
   372  
   373  	lastField := reflect.Value{}
   374  	union := len(dec.union) > 0
   375  	seen := make([]uint64, 1)
   376  	if len(dec.required) > len(seen) {
   377  		seen = make([]uint64, len(dec.required))
   378  	}
   379  
   380  	err := readStruct(r, func(r Reader, f Field) error {
   381  		i := int(f.ID) - int(dec.minID)
   382  		if i < 0 || i >= len(dec.fields) || dec.fields[i].decode == nil {
   383  			return skipField(r, f)
   384  		}
   385  		field := &dec.fields[i]
   386  		seen[i/64] |= 1 << (i % 64)
   387  
   388  		// TODO: implement type conversions?
   389  		if f.Type != field.typ && !(f.Type == TRUE && field.typ == BOOL) {
   390  			if flags.have(strict) {
   391  				return &TypeMismatch{item: "field value", Expect: field.typ, Found: f.Type}
   392  			}
   393  			return nil
   394  		}
   395  
   396  		x := v
   397  		for _, i := range field.index {
   398  			if x.Kind() == reflect.Ptr {
   399  				x = x.Elem()
   400  			}
   401  			if x = x.Field(i); x.Kind() == reflect.Ptr {
   402  				if x.IsNil() {
   403  					x.Set(reflect.New(x.Type().Elem()))
   404  				}
   405  			}
   406  		}
   407  
   408  		if union {
   409  			v.Set(dec.zero)
   410  		}
   411  
   412  		lastField = x
   413  
   414  		if coalesceBoolFields && (f.Type == TRUE || f.Type == FALSE) {
   415  			for x.Kind() == reflect.Ptr {
   416  				if x.IsNil() {
   417  					x.Set(reflect.New(x.Type().Elem()))
   418  				}
   419  				x = x.Elem()
   420  			}
   421  			x.SetBool(f.Type == TRUE)
   422  			return nil
   423  		}
   424  
   425  		return field.decode(r, x, flags.with(field.flags))
   426  	})
   427  	if err != nil {
   428  		return err
   429  	}
   430  
   431  	for i, required := range dec.required {
   432  		if mask := required & seen[i]; mask != required {
   433  			i *= 64
   434  			for (mask & 1) != 0 {
   435  				mask >>= 1
   436  				i++
   437  			}
   438  			field := &dec.fields[i]
   439  			return &MissingField{Field: Field{ID: field.id, Type: field.typ}}
   440  		}
   441  	}
   442  
   443  	if union && lastField.IsValid() {
   444  		v.FieldByIndex(dec.union).Set(lastField.Addr())
   445  	}
   446  
   447  	return nil
   448  }
   449  
   450  type structDecoderField struct {
   451  	index  []int
   452  	id     int16
   453  	flags  flags
   454  	typ    Type
   455  	decode decodeFunc
   456  }
   457  
   458  func decodeFuncStructOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   459  	dec := &structDecoder{
   460  		zero: reflect.Zero(t),
   461  	}
   462  	decode := dec.decode
   463  	seen[t] = decode
   464  
   465  	fields := make([]structDecoderField, 0, t.NumField())
   466  	forEachStructField(t, nil, func(f structField) {
   467  		if f.flags.have(union) {
   468  			dec.union = f.index
   469  		} else {
   470  			fields = append(fields, structDecoderField{
   471  				index:  f.index,
   472  				id:     f.id,
   473  				flags:  f.flags,
   474  				typ:    TypeOf(f.typ),
   475  				decode: decodeFuncStructFieldOf(f, seen),
   476  			})
   477  		}
   478  	})
   479  
   480  	minID := int16(0)
   481  	maxID := int16(0)
   482  
   483  	for _, f := range fields {
   484  		if f.id < minID || minID == 0 {
   485  			minID = f.id
   486  		}
   487  		if f.id > maxID {
   488  			maxID = f.id
   489  		}
   490  	}
   491  
   492  	dec.fields = make([]structDecoderField, (maxID-minID)+1)
   493  	dec.minID = minID
   494  	dec.required = make([]uint64, len(fields)/64+1)
   495  
   496  	for _, f := range fields {
   497  		i := f.id - minID
   498  		p := dec.fields[i]
   499  		if p.decode != nil {
   500  			panic(fmt.Errorf("thrift struct field id %d is present multiple times in %s with types %s and %s", f.id, t, p.typ, f.typ))
   501  		}
   502  		dec.fields[i] = f
   503  		if f.flags.have(required) {
   504  			dec.required[i/64] |= 1 << (i % 64)
   505  		}
   506  	}
   507  
   508  	return decode
   509  }
   510  
   511  func decodeFuncStructFieldOf(f structField, seen decodeFuncCache) decodeFunc {
   512  	if f.flags.have(enum) {
   513  		switch f.typ.Kind() {
   514  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   515  			return decodeInt32
   516  		}
   517  	}
   518  	return decodeFuncOf(f.typ, seen)
   519  }
   520  
   521  func decodeFuncPtrOf(t reflect.Type, seen decodeFuncCache) decodeFunc {
   522  	elem := t.Elem()
   523  	decode := decodeFuncOf(t.Elem(), seen)
   524  	return func(r Reader, v reflect.Value, f flags) error {
   525  		if v.IsNil() {
   526  			v.Set(reflect.New(elem))
   527  		}
   528  		return decode(r, v.Elem(), f)
   529  	}
   530  }
   531  
   532  func readBinary(r Reader, f func(io.Reader) error) error {
   533  	n, err := r.ReadLength()
   534  	if err != nil {
   535  		return err
   536  	}
   537  	return dontExpectEOF(f(io.LimitReader(r.Reader(), int64(n))))
   538  }
   539  
   540  func readList(r Reader, f func(Reader, Type) error) error {
   541  	l, err := r.ReadList()
   542  	if err != nil {
   543  		return err
   544  	}
   545  
   546  	for i := 0; i < int(l.Size); i++ {
   547  		if err := f(r, l.Type); err != nil {
   548  			return with(dontExpectEOF(err), &decodeErrorList{cause: l, index: i})
   549  		}
   550  	}
   551  
   552  	return nil
   553  }
   554  
   555  func readSet(r Reader, f func(Reader, Type) error) error {
   556  	s, err := r.ReadSet()
   557  	if err != nil {
   558  		return err
   559  	}
   560  
   561  	for i := 0; i < int(s.Size); i++ {
   562  		if err := f(r, s.Type); err != nil {
   563  			return with(dontExpectEOF(err), &decodeErrorSet{cause: s, index: i})
   564  		}
   565  	}
   566  
   567  	return nil
   568  }
   569  
   570  func readMap(r Reader, f func(Reader, Type, Type) error) error {
   571  	m, err := r.ReadMap()
   572  	if err != nil {
   573  		return err
   574  	}
   575  
   576  	for i := 0; i < int(m.Size); i++ {
   577  		if err := f(r, m.Key, m.Value); err != nil {
   578  			return with(dontExpectEOF(err), &decodeErrorMap{cause: m, index: i})
   579  		}
   580  	}
   581  
   582  	return nil
   583  }
   584  
   585  func readStruct(r Reader, f func(Reader, Field) error) error {
   586  	lastFieldID := int16(0)
   587  	numFields := 0
   588  
   589  	for {
   590  		x, err := r.ReadField()
   591  		if err != nil {
   592  			if numFields > 0 {
   593  				err = dontExpectEOF(err)
   594  			}
   595  			return err
   596  		}
   597  
   598  		if x.Type == STOP {
   599  			return nil
   600  		}
   601  
   602  		if x.Delta {
   603  			x.ID += lastFieldID
   604  			x.Delta = false
   605  		}
   606  
   607  		if err := f(r, x); err != nil {
   608  			return with(dontExpectEOF(err), &decodeErrorField{cause: x})
   609  		}
   610  
   611  		lastFieldID = x.ID
   612  		numFields++
   613  	}
   614  }
   615  
   616  func skip(r Reader, t Type) error {
   617  	var err error
   618  	switch t {
   619  	case TRUE, FALSE:
   620  		_, err = r.ReadBool()
   621  	case I8:
   622  		_, err = r.ReadInt8()
   623  	case I16:
   624  		_, err = r.ReadInt16()
   625  	case I32:
   626  		_, err = r.ReadInt32()
   627  	case I64:
   628  		_, err = r.ReadInt64()
   629  	case DOUBLE:
   630  		_, err = r.ReadFloat64()
   631  	case BINARY:
   632  		err = skipBinary(r)
   633  	case LIST:
   634  		err = skipList(r)
   635  	case SET:
   636  		err = skipSet(r)
   637  	case MAP:
   638  		err = skipMap(r)
   639  	case STRUCT:
   640  		err = skipStruct(r)
   641  	default:
   642  		return fmt.Errorf("skipping unsupported thrift type %d", t)
   643  	}
   644  	return err
   645  }
   646  
   647  func skipBinary(r Reader) error {
   648  	n, err := r.ReadLength()
   649  	if err != nil {
   650  		return err
   651  	}
   652  	if n == 0 {
   653  		return nil
   654  	}
   655  	switch x := r.Reader().(type) {
   656  	case *bufio.Reader:
   657  		_, err = x.Discard(int(n))
   658  	default:
   659  		_, err = io.CopyN(ioutil.Discard, x, int64(n))
   660  	}
   661  	return dontExpectEOF(err)
   662  }
   663  
   664  func skipList(r Reader) error {
   665  	return readList(r, skip)
   666  }
   667  
   668  func skipSet(r Reader) error {
   669  	return readSet(r, skip)
   670  }
   671  
   672  func skipMap(r Reader) error {
   673  	return readMap(r, func(r Reader, k, v Type) error {
   674  		if err := skip(r, k); err != nil {
   675  			return dontExpectEOF(err)
   676  		}
   677  		if err := skip(r, v); err != nil {
   678  			return dontExpectEOF(err)
   679  		}
   680  		return nil
   681  	})
   682  }
   683  
   684  func skipStruct(r Reader) error {
   685  	return readStruct(r, skipField)
   686  }
   687  
   688  func skipField(r Reader, f Field) error {
   689  	return skip(r, f.Type)
   690  }