github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/thrift/encode.go (about)

     1  package thrift
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math"
     7  	"reflect"
     8  	"sort"
     9  	"sync/atomic"
    10  )
    11  
    12  // Marshal serializes v into a thrift representation according to the the
    13  // protocol p.
    14  //
    15  // The function panics if v cannot be converted to a thrift representation.
    16  func Marshal(p Protocol, v interface{}) ([]byte, error) {
    17  	buf := new(bytes.Buffer)
    18  	enc := NewEncoder(p.NewWriter(buf))
    19  	err := enc.Encode(v)
    20  	return buf.Bytes(), err
    21  }
    22  
// Encoder writes Go values in thrift form to a Writer.
type Encoder struct {
	w Writer // destination of the encoded values
	f flags  // encoding flags; the protocol-related bits come from encoderFlags(w)
}
    27  
    28  func NewEncoder(w Writer) *Encoder {
    29  	return &Encoder{w: w, f: encoderFlags(w)}
    30  }
    31  
    32  func (e *Encoder) Encode(v interface{}) error {
    33  	t := reflect.TypeOf(v)
    34  	cache, _ := encoderCache.Load().(map[typeID]encodeFunc)
    35  	encode, _ := cache[makeTypeID(t)]
    36  
    37  	if encode == nil {
    38  		encode = encodeFuncOf(t, make(encodeFuncCache))
    39  
    40  		newCache := make(map[typeID]encodeFunc, len(cache)+1)
    41  		newCache[makeTypeID(t)] = encode
    42  		for k, v := range cache {
    43  			newCache[k] = v
    44  		}
    45  
    46  		encoderCache.Store(newCache)
    47  	}
    48  
    49  	return encode(e.w, reflect.ValueOf(v), e.f)
    50  }
    51  
    52  func (e *Encoder) Reset(w Writer) {
    53  	e.w = w
    54  	e.f = e.f.without(protocolFlags).with(encoderFlags(w))
    55  }
    56  
    57  func encoderFlags(w Writer) flags {
    58  	return flags(w.Protocol().Features() << featuresBitOffset)
    59  }
    60  
    61  var encoderCache atomic.Value // map[typeID]encodeFunc
    62  
    63  type encodeFunc func(Writer, reflect.Value, flags) error
    64  
    65  type encodeFuncCache map[reflect.Type]encodeFunc
    66  
    67  func encodeFuncOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
    68  	f := seen[t]
    69  	if f != nil {
    70  		return f
    71  	}
    72  	switch t.Kind() {
    73  	case reflect.Bool:
    74  		f = encodeBool
    75  	case reflect.Int8:
    76  		f = encodeInt8
    77  	case reflect.Int16:
    78  		f = encodeInt16
    79  	case reflect.Int32:
    80  		f = encodeInt32
    81  	case reflect.Int64, reflect.Int:
    82  		f = encodeInt64
    83  	case reflect.Float32, reflect.Float64:
    84  		f = encodeFloat64
    85  	case reflect.String:
    86  		f = encodeString
    87  	case reflect.Slice:
    88  		if t.Elem().Kind() == reflect.Uint8 {
    89  			f = encodeBytes
    90  		} else {
    91  			f = encodeFuncSliceOf(t, seen)
    92  		}
    93  	case reflect.Map:
    94  		f = encodeFuncMapOf(t, seen)
    95  	case reflect.Struct:
    96  		f = encodeFuncStructOf(t, seen)
    97  	case reflect.Ptr:
    98  		f = encodeFuncPtrOf(t, seen)
    99  	default:
   100  		panic("type cannot be encoded in thrift: " + t.String())
   101  	}
   102  	seen[t] = f
   103  	return f
   104  }
   105  
   106  func encodeBool(w Writer, v reflect.Value, _ flags) error {
   107  	return w.WriteBool(v.Bool())
   108  }
   109  
   110  func encodeInt8(w Writer, v reflect.Value, _ flags) error {
   111  	return w.WriteInt8(int8(v.Int()))
   112  }
   113  
   114  func encodeInt16(w Writer, v reflect.Value, _ flags) error {
   115  	return w.WriteInt16(int16(v.Int()))
   116  }
   117  
   118  func encodeInt32(w Writer, v reflect.Value, _ flags) error {
   119  	return w.WriteInt32(int32(v.Int()))
   120  }
   121  
   122  func encodeInt64(w Writer, v reflect.Value, _ flags) error {
   123  	return w.WriteInt64(v.Int())
   124  }
   125  
   126  func encodeFloat64(w Writer, v reflect.Value, _ flags) error {
   127  	return w.WriteFloat64(v.Float())
   128  }
   129  
   130  func encodeString(w Writer, v reflect.Value, _ flags) error {
   131  	return w.WriteString(v.String())
   132  }
   133  
   134  func encodeBytes(w Writer, v reflect.Value, _ flags) error {
   135  	return w.WriteBytes(v.Bytes())
   136  }
   137  
   138  func encodeFuncSliceOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
   139  	elem := t.Elem()
   140  	typ := TypeOf(elem)
   141  	enc := encodeFuncOf(elem, seen)
   142  
   143  	return func(w Writer, v reflect.Value, flags flags) error {
   144  		n := v.Len()
   145  		if n > math.MaxInt32 {
   146  			return fmt.Errorf("slice length is too large to be represented in thrift: %d > max(int32)", n)
   147  		}
   148  
   149  		err := w.WriteList(List{
   150  			Size: int32(n),
   151  			Type: typ,
   152  		})
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		for i := 0; i < n; i++ {
   158  			if err := enc(w, v.Index(i), flags); err != nil {
   159  				return err
   160  			}
   161  		}
   162  
   163  		return nil
   164  	}
   165  }
   166  
   167  func encodeFuncMapOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
   168  	key, elem := t.Key(), t.Elem()
   169  	if elem.Size() == 0 { // map[?]struct{}
   170  		return encodeFuncMapAsSetOf(t, seen)
   171  	}
   172  
   173  	keyType := TypeOf(key)
   174  	elemType := TypeOf(elem)
   175  	encodeKey := encodeFuncOf(key, seen)
   176  	encodeElem := encodeFuncOf(elem, seen)
   177  
   178  	return func(w Writer, v reflect.Value, flags flags) error {
   179  		n := v.Len()
   180  		if n > math.MaxInt32 {
   181  			return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n)
   182  		}
   183  
   184  		err := w.WriteMap(Map{
   185  			Size:  int32(n),
   186  			Key:   keyType,
   187  			Value: elemType,
   188  		})
   189  		if err != nil {
   190  			return err
   191  		}
   192  		if n == 0 { // empty map
   193  			return nil
   194  		}
   195  
   196  		for i, iter := 0, v.MapRange(); iter.Next(); i++ {
   197  			if err := encodeKey(w, iter.Key(), flags); err != nil {
   198  				return err
   199  			}
   200  			if err := encodeElem(w, iter.Value(), flags); err != nil {
   201  				return err
   202  			}
   203  		}
   204  
   205  		return nil
   206  	}
   207  }
   208  
   209  func encodeFuncMapAsSetOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
   210  	key := t.Key()
   211  	typ := TypeOf(key)
   212  	enc := encodeFuncOf(key, seen)
   213  
   214  	return func(w Writer, v reflect.Value, flags flags) error {
   215  		n := v.Len()
   216  		if n > math.MaxInt32 {
   217  			return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n)
   218  		}
   219  
   220  		err := w.WriteSet(Set{
   221  			Size: int32(n),
   222  			Type: typ,
   223  		})
   224  		if err != nil {
   225  			return err
   226  		}
   227  		if n == 0 { // empty map
   228  			return nil
   229  		}
   230  
   231  		for i, iter := 0, v.MapRange(); iter.Next(); i++ {
   232  			if err := enc(w, iter.Key(), flags); err != nil {
   233  				return err
   234  			}
   235  		}
   236  
   237  		return nil
   238  	}
   239  }
   240  
// structEncoder encodes Go struct values as thrift structs or unions.
type structEncoder struct {
	fields []structEncoderField // encodable fields, sorted by ascending thrift field id
	union  bool                 // set when one of the Go struct's fields carries the union flag
}
   245  
   246  func dereference(v reflect.Value) reflect.Value {
   247  	for v.Kind() == reflect.Ptr {
   248  		if v.IsNil() {
   249  			return v
   250  		}
   251  		v = v.Elem()
   252  	}
   253  	return v
   254  }
   255  
   256  func isTrue(v reflect.Value) bool {
   257  	v = dereference(v)
   258  	return v.IsValid() && v.Kind() == reflect.Bool && v.Bool()
   259  }
   260  
   261  func (enc *structEncoder) encode(w Writer, v reflect.Value, flags flags) error {
   262  	useDeltaEncoding := flags.have(useDeltaEncoding)
   263  	coalesceBoolFields := flags.have(coalesceBoolFields)
   264  	numFields := int16(0)
   265  	lastFieldID := int16(0)
   266  
   267  encodeFields:
   268  	for _, f := range enc.fields {
   269  		x := v
   270  		for _, i := range f.index {
   271  			if x.Kind() == reflect.Ptr {
   272  				x = x.Elem()
   273  			}
   274  			if x = x.Field(i); x.Kind() == reflect.Ptr {
   275  				if x.IsNil() {
   276  					continue encodeFields
   277  				}
   278  			}
   279  		}
   280  
   281  		if !f.flags.have(required) && x.IsZero() {
   282  			continue encodeFields
   283  		}
   284  
   285  		field := Field{
   286  			ID:   f.id,
   287  			Type: f.typ,
   288  		}
   289  
   290  		if useDeltaEncoding {
   291  			if delta := field.ID - lastFieldID; delta <= 15 {
   292  				field.ID = delta
   293  				field.Delta = true
   294  			}
   295  		}
   296  
   297  		skipValue := coalesceBoolFields && field.Type == BOOL
   298  		if skipValue && isTrue(x) == true {
   299  			field.Type = TRUE
   300  		}
   301  
   302  		if err := w.WriteField(field); err != nil {
   303  			return err
   304  		}
   305  
   306  		if !skipValue {
   307  			if err := f.encode(w, x, flags); err != nil {
   308  				return err
   309  			}
   310  		}
   311  
   312  		numFields++
   313  		lastFieldID = f.id
   314  	}
   315  
   316  	if err := w.WriteField(Field{Type: STOP}); err != nil {
   317  		return err
   318  	}
   319  
   320  	if numFields > 1 && enc.union {
   321  		return fmt.Errorf("thrift union had more than one field with a non-zero value (%d)", numFields)
   322  	}
   323  
   324  	return nil
   325  }
   326  
   327  func (enc *structEncoder) String() string {
   328  	if enc.union {
   329  		return "union"
   330  	}
   331  	return "struct"
   332  }
   333  
// structEncoderField describes how one Go struct field is written as a thrift
// field.
type structEncoderField struct {
	index  []int      // path of field indices walked from the outer struct value to this field
	id     int16      // thrift field id
	flags  flags      // per-field flags (e.g. required)
	typ    Type       // thrift type written in the field header
	encode encodeFunc // encoder for the field's value
}
   341  
   342  func encodeFuncStructOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
   343  	enc := &structEncoder{
   344  		fields: make([]structEncoderField, 0, t.NumField()),
   345  	}
   346  	encode := enc.encode
   347  	seen[t] = encode
   348  
   349  	forEachStructField(t, nil, func(f structField) {
   350  		if f.flags.have(union) {
   351  			enc.union = true
   352  		} else {
   353  			enc.fields = append(enc.fields, structEncoderField{
   354  				index:  f.index,
   355  				id:     f.id,
   356  				flags:  f.flags,
   357  				typ:    TypeOf(f.typ),
   358  				encode: encodeFuncStructFieldOf(f, seen),
   359  			})
   360  		}
   361  	})
   362  
   363  	sort.SliceStable(enc.fields, func(i, j int) bool {
   364  		return enc.fields[i].id < enc.fields[j].id
   365  	})
   366  
   367  	for i := len(enc.fields) - 1; i > 0; i-- {
   368  		if enc.fields[i-1].id == enc.fields[i].id {
   369  			panic(fmt.Errorf("thrift struct field id %d is present multiple times", enc.fields[i].id))
   370  		}
   371  	}
   372  
   373  	return encode
   374  }
   375  
   376  func encodeFuncStructFieldOf(f structField, seen encodeFuncCache) encodeFunc {
   377  	if f.flags.have(enum) {
   378  		switch f.typ.Kind() {
   379  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   380  			return encodeInt32
   381  		}
   382  	}
   383  	return encodeFuncOf(f.typ, seen)
   384  }
   385  
   386  func encodeFuncPtrOf(t reflect.Type, seen encodeFuncCache) encodeFunc {
   387  	typ := t.Elem()
   388  	enc := encodeFuncOf(typ, seen)
   389  	zero := reflect.Zero(typ)
   390  
   391  	return func(w Writer, v reflect.Value, f flags) error {
   392  		if v.IsNil() {
   393  			v = zero
   394  		} else {
   395  			v = v.Elem()
   396  		}
   397  		return enc(w, v, f)
   398  	}
   399  }