github.com/segmentio/encoding@v0.4.0/proto/struct.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"reflect"
     7  	"unsafe"
     8  )
     9  
    10  const (
    11  	embedded = 1 << 0
    12  	repeated = 1 << 1
    13  	zigzag   = 1 << 2
    14  )
    15  
// structField describes how one exported field of a Go struct maps onto a
// protobuf field. Values are built by structCodecOf and then consumed by the
// size, encode and decode functions generated for the struct.
type structField struct {
	// number is the protobuf field number: assigned sequentially by default,
	// or taken from the `protobuf` struct tag when it parses.
	number  uint16
	// tagsize caches the result of sizeOfTag for this field's number and
	// wire type.
	tagsize uint8
	// flags is a bit set of the embedded, repeated and zigzag constants.
	flags   uint8
	// offset is the byte offset of the field within the Go struct, as
	// reported by reflect.StructField.Offset.
	offset  uint32
	// codec sizes, encodes and decodes the field value; its wire type is
	// also the field's wire type.
	codec   *codec
}
    23  
    24  func (f *structField) String() string {
    25  	return fmt.Sprintf("[%d,%s]", f.fieldNumber(), f.wireType())
    26  }
    27  
    28  func (f *structField) fieldNumber() fieldNumber {
    29  	return fieldNumber(f.number)
    30  }
    31  
    32  func (f *structField) wireType() wireType {
    33  	return f.codec.wire
    34  }
    35  
    36  func (f *structField) embedded() bool {
    37  	return (f.flags & embedded) != 0
    38  }
    39  
    40  func (f *structField) repeated() bool {
    41  	return (f.flags & repeated) != 0
    42  }
    43  
// pointer returns the address of this field inside the struct that p points
// to, computed by adding the field's byte offset to p. p is not checked for
// nil here.
//
// Keep the arithmetic inside a single conversion expression: the
// unsafe.Pointer rules only allow pointer arithmetic in that form, so the
// intermediate uintptr must not be stored in a variable.
func (f *structField) pointer(p unsafe.Pointer) unsafe.Pointer {
	return unsafe.Pointer(uintptr(p) + uintptr(f.offset))
}
    47  
    48  func (f *structField) makeFlags(base flags) flags {
    49  	return base | flags(f.flags&zigzag)
    50  }
    51  
    52  func structCodecOf(t reflect.Type, seen map[reflect.Type]*codec) *codec {
    53  	c := &codec{wire: varlen}
    54  	seen[t] = c
    55  
    56  	numField := t.NumField()
    57  	number := fieldNumber(1)
    58  	fields := make([]structField, 0, numField)
    59  
    60  	for i := 0; i < numField; i++ {
    61  		f := t.Field(i)
    62  
    63  		if f.PkgPath != "" {
    64  			continue // unexported
    65  		}
    66  
    67  		field := structField{
    68  			number: uint16(number),
    69  			offset: uint32(f.Offset),
    70  		}
    71  
    72  		if tag, ok := f.Tag.Lookup("protobuf"); ok {
    73  			t, err := parseStructTag(tag)
    74  			if err == nil {
    75  				field.number = uint16(t.fieldNumber)
    76  				if t.repeated {
    77  					field.flags |= repeated
    78  				}
    79  				if t.zigzag {
    80  					field.flags |= zigzag
    81  				}
    82  				switch t.wireType {
    83  				case Fixed32:
    84  					switch baseKindOf(f.Type) {
    85  					case reflect.Uint32:
    86  						field.codec = &fixed32Codec
    87  					case reflect.Float32:
    88  						field.codec = &float32Codec
    89  					}
    90  				case Fixed64:
    91  					switch baseKindOf(f.Type) {
    92  					case reflect.Uint64:
    93  						field.codec = &fixed64Codec
    94  					case reflect.Float64:
    95  						field.codec = &float64Codec
    96  					}
    97  				}
    98  			}
    99  		}
   100  
   101  		if field.codec == nil {
   102  			switch baseKindOf(f.Type) {
   103  			case reflect.Struct:
   104  				field.flags |= embedded
   105  				field.codec = codecOf(f.Type, seen)
   106  
   107  			case reflect.Slice:
   108  				elem := f.Type.Elem()
   109  
   110  				if elem.Kind() == reflect.Uint8 { // []byte
   111  					field.codec = codecOf(f.Type, seen)
   112  				} else {
   113  					if baseKindOf(elem) == reflect.Struct {
   114  						field.flags |= embedded
   115  					}
   116  					field.flags |= repeated
   117  					field.codec = codecOf(elem, seen)
   118  					field.codec = sliceCodecOf(f.Type, field, seen)
   119  				}
   120  
   121  			case reflect.Map:
   122  				key, val := f.Type.Key(), f.Type.Elem()
   123  				k := codecOf(key, seen)
   124  				v := codecOf(val, seen)
   125  				m := &mapField{
   126  					number:   field.number,
   127  					keyCodec: k,
   128  					valCodec: v,
   129  				}
   130  				if baseKindOf(key) == reflect.Struct {
   131  					m.keyFlags |= embedded
   132  				}
   133  				if baseKindOf(val) == reflect.Struct {
   134  					m.valFlags |= embedded
   135  				}
   136  				field.flags |= embedded | repeated
   137  				field.codec = mapCodecOf(f.Type, m, seen)
   138  
   139  			default:
   140  				field.codec = codecOf(f.Type, seen)
   141  			}
   142  		}
   143  
   144  		field.tagsize = uint8(sizeOfTag(fieldNumber(field.number), wireType(field.codec.wire)))
   145  		fields = append(fields, field)
   146  		number++
   147  	}
   148  
   149  	c.size = structSizeFuncOf(t, fields)
   150  	c.encode = structEncodeFuncOf(t, fields)
   151  	c.decode = structDecodeFuncOf(t, fields)
   152  	return c
   153  }
   154  
   155  func baseKindOf(t reflect.Type) reflect.Kind {
   156  	return baseTypeOf(t).Kind()
   157  }
   158  
   159  func baseTypeOf(t reflect.Type) reflect.Type {
   160  	for t.Kind() == reflect.Ptr {
   161  		t = t.Elem()
   162  	}
   163  	return t
   164  }
   165  
   166  func structSizeFuncOf(t reflect.Type, fields []structField) sizeFunc {
   167  	var inlined = inlined(t)
   168  	var unique, repeated []*structField
   169  
   170  	for i := range fields {
   171  		f := &fields[i]
   172  		if f.repeated() {
   173  			repeated = append(repeated, f)
   174  		} else {
   175  			unique = append(unique, f)
   176  		}
   177  	}
   178  
   179  	return func(p unsafe.Pointer, flags flags) int {
   180  		if p == nil {
   181  			return 0
   182  		}
   183  
   184  		if !inlined {
   185  			flags = flags.without(inline | toplevel)
   186  		} else {
   187  			flags = flags.without(toplevel)
   188  		}
   189  		n := 0
   190  
   191  		for _, f := range unique {
   192  			size := f.codec.size(f.pointer(p), f.makeFlags(flags))
   193  			if size > 0 {
   194  				n += int(f.tagsize) + size
   195  				if f.embedded() {
   196  					n += sizeOfVarint(uint64(size))
   197  				}
   198  				flags = flags.without(wantzero)
   199  			}
   200  		}
   201  
   202  		for _, f := range repeated {
   203  			size := f.codec.size(f.pointer(p), f.makeFlags(flags))
   204  			if size > 0 {
   205  				n += size
   206  				flags = flags.without(wantzero)
   207  			}
   208  		}
   209  
   210  		return n
   211  	}
   212  }
   213  
   214  func structEncodeFuncOf(t reflect.Type, fields []structField) encodeFunc {
   215  	var inlined = inlined(t)
   216  	var unique, repeated []*structField
   217  
   218  	for i := range fields {
   219  		f := &fields[i]
   220  		if f.repeated() {
   221  			repeated = append(repeated, f)
   222  		} else {
   223  			unique = append(unique, f)
   224  		}
   225  	}
   226  
   227  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
   228  		if p == nil {
   229  			return 0, nil
   230  		}
   231  
   232  		if !inlined {
   233  			flags = flags.without(inline | toplevel)
   234  		} else {
   235  			flags = flags.without(toplevel)
   236  		}
   237  		offset := 0
   238  
   239  		for _, f := range unique {
   240  			fieldFlags := f.makeFlags(flags)
   241  			elem := f.pointer(p)
   242  			size := f.codec.size(elem, fieldFlags)
   243  
   244  			if size > 0 {
   245  				n, err := encodeTag(b[offset:], f.fieldNumber(), f.wireType())
   246  				offset += n
   247  				if err != nil {
   248  					return offset, err
   249  				}
   250  
   251  				if f.embedded() {
   252  					n, err := encodeVarint(b[offset:], uint64(size))
   253  					offset += n
   254  					if err != nil {
   255  						return offset, err
   256  					}
   257  				}
   258  
   259  				if (len(b) - offset) < size {
   260  					return len(b), io.ErrShortBuffer
   261  				}
   262  
   263  				n, err = f.codec.encode(b[offset:offset+size], elem, fieldFlags)
   264  				offset += n
   265  				if err != nil {
   266  					return offset, err
   267  				}
   268  
   269  				flags = flags.without(wantzero)
   270  			}
   271  		}
   272  
   273  		for _, f := range repeated {
   274  			n, err := f.codec.encode(b[offset:], f.pointer(p), f.makeFlags(flags))
   275  			offset += n
   276  			if err != nil {
   277  				return offset, err
   278  			}
   279  			if n > 0 {
   280  				flags = flags.without(wantzero)
   281  			}
   282  		}
   283  
   284  		return offset, nil
   285  	}
   286  }
   287  
   288  func structDecodeFuncOf(t reflect.Type, fields []structField) decodeFunc {
   289  	maxFieldNumber := fieldNumber(0)
   290  
   291  	for _, f := range fields {
   292  		if n := f.fieldNumber(); n > maxFieldNumber {
   293  			maxFieldNumber = n
   294  		}
   295  	}
   296  
   297  	fieldIndex := make([]*structField, maxFieldNumber+1)
   298  
   299  	for i := range fields {
   300  		f := &fields[i]
   301  		fieldIndex[f.fieldNumber()] = f
   302  	}
   303  
   304  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
   305  		flags = flags.without(toplevel)
   306  		offset := 0
   307  
   308  		for offset < len(b) {
   309  			fieldNumber, wireType, n, err := decodeTag(b[offset:])
   310  			offset += n
   311  			if err != nil {
   312  				return offset, err
   313  			}
   314  
   315  			i := int(fieldNumber)
   316  			f := (*structField)(nil)
   317  
   318  			if i >= 0 && i < len(fieldIndex) {
   319  				f = fieldIndex[i]
   320  			}
   321  
   322  			if f == nil {
   323  				skip := 0
   324  				size := uint64(0)
   325  				switch wireType {
   326  				case varint:
   327  					_, skip, err = decodeVarint(b[offset:])
   328  				case varlen:
   329  					size, skip, err = decodeVarint(b[offset:])
   330  					if err == nil {
   331  						if size > uint64(len(b)-skip) {
   332  							err = io.ErrUnexpectedEOF
   333  						} else {
   334  							skip += int(size)
   335  						}
   336  					}
   337  				case fixed32:
   338  					_, skip, err = decodeLE32(b[offset:])
   339  				case fixed64:
   340  					_, skip, err = decodeLE64(b[offset:])
   341  				default:
   342  					err = ErrWireTypeUnknown
   343  				}
   344  				if (offset + skip) <= len(b) {
   345  					offset += skip
   346  				} else {
   347  					offset, err = len(b), io.ErrUnexpectedEOF
   348  				}
   349  				if err != nil {
   350  					return offset, fieldError(fieldNumber, wireType, err)
   351  				}
   352  				continue
   353  			}
   354  
   355  			if wireType != f.wireType() {
   356  				return offset, fieldError(fieldNumber, wireType, fmt.Errorf("expected wire type %d", f.wireType()))
   357  			}
   358  
   359  			// `data` will only contain the section of the input buffer where
   360  			// the data for the next field is available. This is necessary to
   361  			// limit how many bytes will be consumed by embedded messages.
   362  			var data []byte
   363  			switch wireType {
   364  			case varint:
   365  				_, n, err := decodeVarint(b[offset:])
   366  				if err != nil {
   367  					return offset, fieldError(fieldNumber, wireType, err)
   368  				}
   369  				data = b[offset : offset+n]
   370  
   371  			case varlen:
   372  				l, n, err := decodeVarint(b[offset:])
   373  				if err != nil {
   374  					return offset + n, fieldError(fieldNumber, wireType, err)
   375  				}
   376  				if l > uint64(len(b)-(offset+n)) {
   377  					return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF)
   378  				}
   379  				if f.embedded() {
   380  					offset += n
   381  					data = b[offset : offset+int(l)]
   382  				} else {
   383  					data = b[offset : offset+n+int(l)]
   384  				}
   385  
   386  			case fixed32:
   387  				if (offset + 4) > len(b) {
   388  					return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF)
   389  				}
   390  				data = b[offset : offset+4]
   391  
   392  			case fixed64:
   393  				if (offset + 8) > len(b) {
   394  					return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF)
   395  				}
   396  				data = b[offset : offset+8]
   397  
   398  			default:
   399  				return offset, fieldError(fieldNumber, wireType, ErrWireTypeUnknown)
   400  			}
   401  
   402  			n, err = f.codec.decode(data, f.pointer(p), f.makeFlags(flags))
   403  			offset += n
   404  			if err != nil {
   405  				return offset, fieldError(fieldNumber, wireType, err)
   406  			}
   407  		}
   408  
   409  		return offset, nil
   410  	}
   411  }