github.com/jhump/protoreflect@v1.16.0/codec/decode_fields.go (about)

     1  package codec
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"google.golang.org/protobuf/types/descriptorpb"
    11  
    12  	"github.com/jhump/protoreflect/desc"
    13  )
    14  
    15  var varintTypes = map[descriptorpb.FieldDescriptorProto_Type]bool{}
    16  var fixed32Types = map[descriptorpb.FieldDescriptorProto_Type]bool{}
    17  var fixed64Types = map[descriptorpb.FieldDescriptorProto_Type]bool{}
    18  
    19  func init() {
    20  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_BOOL] = true
    21  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_INT32] = true
    22  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_INT64] = true
    23  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_UINT32] = true
    24  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_UINT64] = true
    25  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_SINT32] = true
    26  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_SINT64] = true
    27  	varintTypes[descriptorpb.FieldDescriptorProto_TYPE_ENUM] = true
    28  
    29  	fixed32Types[descriptorpb.FieldDescriptorProto_TYPE_FIXED32] = true
    30  	fixed32Types[descriptorpb.FieldDescriptorProto_TYPE_SFIXED32] = true
    31  	fixed32Types[descriptorpb.FieldDescriptorProto_TYPE_FLOAT] = true
    32  
    33  	fixed64Types[descriptorpb.FieldDescriptorProto_TYPE_FIXED64] = true
    34  	fixed64Types[descriptorpb.FieldDescriptorProto_TYPE_SFIXED64] = true
    35  	fixed64Types[descriptorpb.FieldDescriptorProto_TYPE_DOUBLE] = true
    36  }
    37  
// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and
// wire-type it reads indicate an end-group marker. When DecodeFieldValue
// returns this error, the value it returns alongside is the tag number that
// was read from the marker.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
    41  
// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
// decode a message value.
//
// Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
// implements this interface.
type MessageFactory interface {
	// NewMessage returns a new message for the given message descriptor. The
	// decoding functions in this file unmarshal the wire bytes of a message or
	// group field into the returned value.
	NewMessage(md *desc.MessageDescriptor) proto.Message
}
    50  
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but was not a recognized field number. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// The tag number for the unrecognized field.
	Tag int32

	// Encoding indicates how the unknown field was encoded on the wire (its
	// wire type). If it is proto.WireBytes or proto.WireStartGroup then
	// Contents will be set to the raw bytes. If it is proto.WireFixed32 then
	// the data is in the least significant 32 bits of Value. Otherwise (for
	// proto.WireVarint and proto.WireFixed64), the data is in all 64 bits of
	// Value.
	Encoding int8
	Contents []byte
	Value    uint64
}
    68  
    69  // DecodeZigZag32 decodes a signed 32-bit integer from the given
    70  // zig-zag encoded value.
    71  func DecodeZigZag32(v uint64) int32 {
    72  	return int32((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31))
    73  }
    74  
    75  // DecodeZigZag64 decodes a signed 64-bit integer from the given
    76  // zig-zag encoded value.
    77  func DecodeZigZag64(v uint64) int64 {
    78  	return int64((v >> 1) ^ uint64((int64(v&1)<<63)>>63))
    79  }
    80  
    81  // DecodeFieldValue will read a field value from the buffer and return its
    82  // value and the corresponding field descriptor. The given function is used
    83  // to lookup a field descriptor by tag number. The given factory is used to
    84  // instantiate a message if the field value is (or contains) a message value.
    85  //
    86  // On error, the field descriptor and value are typically nil. However, if the
    87  // error returned is ErrWireTypeEndGroup, the returned value will indicate any
    88  // tag number encoded in the end-group marker.
    89  //
    90  // If the field descriptor returned is nil, that means that the given function
    91  // returned nil. This is expected to happen for unrecognized tag numbers. In
    92  // that case, no error is returned, and the value will be an UnknownField.
    93  func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
    94  	if cb.EOF() {
    95  		return nil, nil, io.EOF
    96  	}
    97  	tagNumber, wireType, err := cb.DecodeTagAndWireType()
    98  	if err != nil {
    99  		return nil, nil, err
   100  	}
   101  	if wireType == proto.WireEndGroup {
   102  		return nil, tagNumber, ErrWireTypeEndGroup
   103  	}
   104  	fd := fieldFinder(tagNumber)
   105  	if fd == nil {
   106  		val, err := cb.decodeUnknownField(tagNumber, wireType)
   107  		return nil, val, err
   108  	}
   109  	val, err := cb.decodeKnownField(fd, wireType, fact)
   110  	return fd, val, err
   111  }
   112  
   113  // DecodeScalarField extracts a properly-typed value from v. The returned value's
   114  // type depends on the given field descriptor type. It will be the same type as
   115  // generated structs use for the field descriptor's type. Enum types will return
   116  // an int32. If the given field type uses length-delimited encoding (nested
   117  // messages, bytes, and strings), an error is returned.
   118  func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
   119  	switch fd.GetType() {
   120  	case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   121  		return v != 0, nil
   122  	case descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   123  		descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
   124  		if v > math.MaxUint32 {
   125  			return nil, ErrOverflow
   126  		}
   127  		return uint32(v), nil
   128  
   129  	case descriptorpb.FieldDescriptorProto_TYPE_INT32,
   130  		descriptorpb.FieldDescriptorProto_TYPE_ENUM:
   131  		s := int64(v)
   132  		if s > math.MaxInt32 || s < math.MinInt32 {
   133  			return nil, ErrOverflow
   134  		}
   135  		return int32(s), nil
   136  
   137  	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
   138  		if v > math.MaxUint32 {
   139  			return nil, ErrOverflow
   140  		}
   141  		return int32(v), nil
   142  
   143  	case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
   144  		if v > math.MaxUint32 {
   145  			return nil, ErrOverflow
   146  		}
   147  		return DecodeZigZag32(v), nil
   148  
   149  	case descriptorpb.FieldDescriptorProto_TYPE_UINT64,
   150  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
   151  		return v, nil
   152  
   153  	case descriptorpb.FieldDescriptorProto_TYPE_INT64,
   154  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
   155  		return int64(v), nil
   156  
   157  	case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   158  		return DecodeZigZag64(v), nil
   159  
   160  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   161  		if v > math.MaxUint32 {
   162  			return nil, ErrOverflow
   163  		}
   164  		return math.Float32frombits(uint32(v)), nil
   165  
   166  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   167  		return math.Float64frombits(v), nil
   168  
   169  	default:
   170  		// bytes, string, message, and group cannot be represented as a simple numeric value
   171  		return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
   172  	}
   173  }
   174  
   175  // DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
   176  // returned value's type will usually be []byte, string, or, for nested messages,
   177  // the type returned from the given message factory. However, since repeated
   178  // scalar fields can be length-delimited, when they used packed encoding, it can
   179  // also return an []interface{}, where each element is a scalar value. Furthermore,
   180  // it could return a scalar type, not in a slice, if the given field descriptor is
   181  // not repeated. This is to support cases where a field is changed from optional
   182  // to repeated. New code may emit a packed repeated representation, but old code
   183  // still expects a single scalar value. In this case, if the actual data in bytes
   184  // contains multiple values, only the last value is returned.
   185  func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
   186  	switch {
   187  	case fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   188  		return bytes, nil
   189  
   190  	case fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_STRING:
   191  		return string(bytes), nil
   192  
   193  	case fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE ||
   194  		fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   195  		msg := mf.NewMessage(fd.GetMessageType())
   196  		err := proto.Unmarshal(bytes, msg)
   197  		if err != nil {
   198  			return nil, err
   199  		} else {
   200  			return msg, nil
   201  		}
   202  
   203  	default:
   204  		// even if the field is not repeated or not packed, we still parse it as such for
   205  		// backwards compatibility (e.g. message we are de-serializing could have been both
   206  		// repeated and packed at the time of serialization)
   207  		packedBuf := NewBuffer(bytes)
   208  		var slice []interface{}
   209  		var val interface{}
   210  		for !packedBuf.EOF() {
   211  			var v uint64
   212  			var err error
   213  			if varintTypes[fd.GetType()] {
   214  				v, err = packedBuf.DecodeVarint()
   215  			} else if fixed32Types[fd.GetType()] {
   216  				v, err = packedBuf.DecodeFixed32()
   217  			} else if fixed64Types[fd.GetType()] {
   218  				v, err = packedBuf.DecodeFixed64()
   219  			} else {
   220  				return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
   221  			}
   222  			if err != nil {
   223  				return nil, err
   224  			}
   225  			val, err = DecodeScalarField(fd, v)
   226  			if err != nil {
   227  				return nil, err
   228  			}
   229  			if fd.IsRepeated() {
   230  				slice = append(slice, val)
   231  			}
   232  		}
   233  		if fd.IsRepeated() {
   234  			return slice, nil
   235  		} else {
   236  			// if not a repeated field, last value wins
   237  			return val, nil
   238  		}
   239  	}
   240  }
   241  
   242  func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
   243  	var val interface{}
   244  	var err error
   245  	switch encoding {
   246  	case proto.WireFixed32:
   247  		var num uint64
   248  		num, err = b.DecodeFixed32()
   249  		if err == nil {
   250  			val, err = DecodeScalarField(fd, num)
   251  		}
   252  	case proto.WireFixed64:
   253  		var num uint64
   254  		num, err = b.DecodeFixed64()
   255  		if err == nil {
   256  			val, err = DecodeScalarField(fd, num)
   257  		}
   258  	case proto.WireVarint:
   259  		var num uint64
   260  		num, err = b.DecodeVarint()
   261  		if err == nil {
   262  			val, err = DecodeScalarField(fd, num)
   263  		}
   264  
   265  	case proto.WireBytes:
   266  		alloc := fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BYTES
   267  		var raw []byte
   268  		raw, err = b.DecodeRawBytes(alloc)
   269  		if err == nil {
   270  			val, err = DecodeLengthDelimitedField(fd, raw, fact)
   271  		}
   272  
   273  	case proto.WireStartGroup:
   274  		if fd.GetMessageType() == nil {
   275  			return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
   276  		}
   277  		msg := fact.NewMessage(fd.GetMessageType())
   278  		var data []byte
   279  		data, err = b.ReadGroup(false)
   280  		if err == nil {
   281  			err = proto.Unmarshal(data, msg)
   282  			if err == nil {
   283  				val = msg
   284  			}
   285  		}
   286  
   287  	default:
   288  		return nil, ErrBadWireType
   289  	}
   290  	if err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	return val, nil
   295  }
   296  
   297  func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
   298  	u := UnknownField{Tag: tagNumber, Encoding: encoding}
   299  	var err error
   300  	switch encoding {
   301  	case proto.WireFixed32:
   302  		u.Value, err = b.DecodeFixed32()
   303  	case proto.WireFixed64:
   304  		u.Value, err = b.DecodeFixed64()
   305  	case proto.WireVarint:
   306  		u.Value, err = b.DecodeVarint()
   307  	case proto.WireBytes:
   308  		u.Contents, err = b.DecodeRawBytes(true)
   309  	case proto.WireStartGroup:
   310  		u.Contents, err = b.ReadGroup(true)
   311  	default:
   312  		err = ErrBadWireType
   313  	}
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	return u, nil
   318  }