github.com/segmentio/encoding@v0.4.0/thrift/encode.go (about) 1 package thrift 2 3 import ( 4 "bytes" 5 "fmt" 6 "math" 7 "reflect" 8 "sort" 9 "sync/atomic" 10 ) 11 12 // Marshal serializes v into a thrift representation according to the the 13 // protocol p. 14 // 15 // The function panics if v cannot be converted to a thrift representation. 16 func Marshal(p Protocol, v interface{}) ([]byte, error) { 17 buf := new(bytes.Buffer) 18 enc := NewEncoder(p.NewWriter(buf)) 19 err := enc.Encode(v) 20 return buf.Bytes(), err 21 } 22 23 type Encoder struct { 24 w Writer 25 f flags 26 } 27 28 func NewEncoder(w Writer) *Encoder { 29 return &Encoder{w: w, f: encoderFlags(w)} 30 } 31 32 func (e *Encoder) Encode(v interface{}) error { 33 t := reflect.TypeOf(v) 34 cache, _ := encoderCache.Load().(map[typeID]encodeFunc) 35 encode, _ := cache[makeTypeID(t)] 36 37 if encode == nil { 38 encode = encodeFuncOf(t, make(encodeFuncCache)) 39 40 newCache := make(map[typeID]encodeFunc, len(cache)+1) 41 newCache[makeTypeID(t)] = encode 42 for k, v := range cache { 43 newCache[k] = v 44 } 45 46 encoderCache.Store(newCache) 47 } 48 49 return encode(e.w, reflect.ValueOf(v), e.f) 50 } 51 52 func (e *Encoder) Reset(w Writer) { 53 e.w = w 54 e.f = e.f.without(protocolFlags).with(encoderFlags(w)) 55 } 56 57 func encoderFlags(w Writer) flags { 58 return flags(w.Protocol().Features() << featuresBitOffset) 59 } 60 61 var encoderCache atomic.Value // map[typeID]encodeFunc 62 63 type encodeFunc func(Writer, reflect.Value, flags) error 64 65 type encodeFuncCache map[reflect.Type]encodeFunc 66 67 func encodeFuncOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 68 f := seen[t] 69 if f != nil { 70 return f 71 } 72 switch t.Kind() { 73 case reflect.Bool: 74 f = encodeBool 75 case reflect.Int8: 76 f = encodeInt8 77 case reflect.Int16: 78 f = encodeInt16 79 case reflect.Int32: 80 f = encodeInt32 81 case reflect.Int64, reflect.Int: 82 f = encodeInt64 83 case reflect.Float32, reflect.Float64: 84 f = encodeFloat64 85 case reflect.String: 86 f = encodeString 87 case reflect.Slice: 88 if t.Elem().Kind() == reflect.Uint8 { 89 f = encodeBytes 90 } else { 91 f = encodeFuncSliceOf(t, seen) 92 } 93 case reflect.Map: 94 f = encodeFuncMapOf(t, seen) 95 case reflect.Struct: 96 f = encodeFuncStructOf(t, seen) 97 case reflect.Ptr: 98 f = encodeFuncPtrOf(t, seen) 99 default: 100 panic("type cannot be encoded in thrift: " + t.String()) 101 } 102 seen[t] = f 103 return f 104 } 105 106 func encodeBool(w Writer, v reflect.Value, _ flags) error { 107 return w.WriteBool(v.Bool()) 108 } 109 110 func encodeInt8(w Writer, v reflect.Value, _ flags) error { 111 return w.WriteInt8(int8(v.Int())) 112 } 113 114 func encodeInt16(w Writer, v reflect.Value, _ flags) error { 115 return w.WriteInt16(int16(v.Int())) 116 } 117 118 func encodeInt32(w Writer, v reflect.Value, _ flags) error { 119 return w.WriteInt32(int32(v.Int())) 120 } 121 122 func encodeInt64(w Writer, v reflect.Value, _ flags) error { 123 return w.WriteInt64(v.Int()) 124 } 125 126 func encodeFloat64(w Writer, v reflect.Value, _ flags) error { 127 return w.WriteFloat64(v.Float()) 128 } 129 130 func encodeString(w Writer, v reflect.Value, _ flags) error { 131 return w.WriteString(v.String()) 132 } 133 134 func encodeBytes(w Writer, v reflect.Value, _ flags) error { 135 return w.WriteBytes(v.Bytes()) 136 } 137 138 func encodeFuncSliceOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 139 elem := t.Elem() 140 typ := TypeOf(elem) 141 enc := encodeFuncOf(elem, seen) 142 143 return func(w Writer, v reflect.Value, flags flags) error { 144 n := v.Len() 145 if n > math.MaxInt32 { 146 return fmt.Errorf("slice length is too large to be represented in thrift: %d > max(int32)", n) 147 } 148 149 err := w.WriteList(List{ 150 Size: int32(n), 151 Type: typ, 152 }) 153 if err != nil { 154 return err 155 } 156 157 for i := 0; i < n; i++ { 158 if err := enc(w, v.Index(i), flags); err != nil { 159 return err 160 } 161 } 162 163 return nil 164 } 165 } 166 167 func encodeFuncMapOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 168 key, elem := t.Key(), t.Elem() 169 if elem.Size() == 0 { // map[?]struct{} 170 return encodeFuncMapAsSetOf(t, seen) 171 } 172 173 keyType := TypeOf(key) 174 elemType := TypeOf(elem) 175 encodeKey := encodeFuncOf(key, seen) 176 encodeElem := encodeFuncOf(elem, seen) 177 178 return func(w Writer, v reflect.Value, flags flags) error { 179 n := v.Len() 180 if n > math.MaxInt32 { 181 return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n) 182 } 183 184 err := w.WriteMap(Map{ 185 Size: int32(n), 186 Key: keyType, 187 Value: elemType, 188 }) 189 if err != nil { 190 return err 191 } 192 if n == 0 { // empty map 193 return nil 194 } 195 196 for i, iter := 0, v.MapRange(); iter.Next(); i++ { 197 if err := encodeKey(w, iter.Key(), flags); err != nil { 198 return err 199 } 200 if err := encodeElem(w, iter.Value(), flags); err != nil { 201 return err 202 } 203 } 204 205 return nil 206 } 207 } 208 209 func encodeFuncMapAsSetOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 210 key := t.Key() 211 typ := TypeOf(key) 212 enc := encodeFuncOf(key, seen) 213 214 return func(w Writer, v reflect.Value, flags flags) error { 215 n := v.Len() 216 if n > math.MaxInt32 { 217 return fmt.Errorf("map length is too large to be represented in thrift: %d > max(int32)", n) 218 } 219 220 err := w.WriteSet(Set{ 221 Size: int32(n), 222 Type: typ, 223 }) 224 if err != nil { 225 return err 226 } 227 if n == 0 { // empty map 228 return nil 229 } 230 231 for i, iter := 0, v.MapRange(); iter.Next(); i++ { 232 if err := enc(w, iter.Key(), flags); err != nil { 233 return err 234 } 235 } 236 237 return nil 238 } 239 } 240 241 type structEncoder struct { 242 fields []structEncoderField 243 union bool 244 } 245 246 func dereference(v reflect.Value) reflect.Value { 247 for v.Kind() == reflect.Ptr { 248 if v.IsNil() { 249 return v 250 } 251 v = v.Elem() 252 } 253 return v 254 } 255 256 func isTrue(v reflect.Value) bool { 257 v = dereference(v) 258 return v.IsValid() && v.Kind() == reflect.Bool && v.Bool() 259 } 260 261 func (enc *structEncoder) encode(w Writer, v reflect.Value, flags flags) error { 262 useDeltaEncoding := flags.have(useDeltaEncoding) 263 coalesceBoolFields := flags.have(coalesceBoolFields) 264 numFields := int16(0) 265 lastFieldID := int16(0) 266 267 encodeFields: 268 for _, f := range enc.fields { 269 x := v 270 for _, i := range f.index { 271 if x.Kind() == reflect.Ptr { 272 x = x.Elem() 273 } 274 if x = x.Field(i); x.Kind() == reflect.Ptr { 275 if x.IsNil() { 276 continue encodeFields 277 } 278 } 279 } 280 281 if !f.flags.have(required) && x.IsZero() { 282 continue encodeFields 283 } 284 285 field := Field{ 286 ID: f.id, 287 Type: f.typ, 288 } 289 290 if useDeltaEncoding { 291 if delta := field.ID - lastFieldID; delta <= 15 { 292 field.ID = delta 293 field.Delta = true 294 } 295 } 296 297 skipValue := coalesceBoolFields && field.Type == BOOL 298 if skipValue && isTrue(x) == true { 299 field.Type = TRUE 300 } 301 302 if err := w.WriteField(field); err != nil { 303 return err 304 } 305 306 if !skipValue { 307 if err := f.encode(w, x, flags); err != nil { 308 return err 309 } 310 } 311 312 numFields++ 313 lastFieldID = f.id 314 } 315 316 if err := w.WriteField(Field{Type: STOP}); err != nil { 317 return err 318 } 319 320 if numFields > 1 && enc.union { 321 return fmt.Errorf("thrift union had more than one field with a non-zero value (%d)", numFields) 322 } 323 324 return nil 325 } 326 327 func (enc *structEncoder) String() string { 328 if enc.union { 329 return "union" 330 } 331 return "struct" 332 } 333 334 type structEncoderField struct { 335 index []int 336 id int16 337 flags flags 338 typ Type 339 encode encodeFunc 340 } 341 342 func encodeFuncStructOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 343 enc := &structEncoder{ 344 fields: make([]structEncoderField, 0, t.NumField()), 345 } 346 encode := enc.encode 347 seen[t] = encode 348 349 forEachStructField(t, nil, func(f structField) { 350 if f.flags.have(union) { 351 enc.union = true 352 } else { 353 enc.fields = append(enc.fields, structEncoderField{ 354 index: f.index, 355 id: f.id, 356 flags: f.flags, 357 typ: TypeOf(f.typ), 358 encode: encodeFuncStructFieldOf(f, seen), 359 }) 360 } 361 }) 362 363 sort.SliceStable(enc.fields, func(i, j int) bool { 364 return enc.fields[i].id < enc.fields[j].id 365 }) 366 367 for i := len(enc.fields) - 1; i > 0; i-- { 368 if enc.fields[i-1].id == enc.fields[i].id { 369 panic(fmt.Errorf("thrift struct field id %d is present multiple times", enc.fields[i].id)) 370 } 371 } 372 373 return encode 374 } 375 376 func encodeFuncStructFieldOf(f structField, seen encodeFuncCache) encodeFunc { 377 if f.flags.have(enum) { 378 switch f.typ.Kind() { 379 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 380 return encodeInt32 381 } 382 } 383 return encodeFuncOf(f.typ, seen) 384 } 385 386 func encodeFuncPtrOf(t reflect.Type, seen encodeFuncCache) encodeFunc { 387 typ := t.Elem() 388 enc := encodeFuncOf(typ, seen) 389 zero := reflect.Zero(typ) 390 391 return func(w Writer, v reflect.Value, f flags) error { 392 if v.IsNil() { 393 v = zero 394 } else { 395 v = v.Elem() 396 } 397 return enc(w, v, f) 398 } 399 }