github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/codec/decode_fields.go (about)

     1  package codec
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    11  
    12  	"github.com/hoveychen/protoreflect/desc"
    13  )
    14  
// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and wire-type
// it reads indicates an end-group marker. When DecodeFieldValue returns this
// error, the value it returns alongside is the tag number (an int32) that was
// read from the marker.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
    18  
// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
// decode a message value.
//
// Also see MessageFactory in "github.com/hoveychen/protoreflect/dynamic", which
// implements this interface.
type MessageFactory interface {
	// NewMessage is called with the message descriptor of the field being
	// decoded. The decoding functions in this file pass the returned message
	// to proto.Unmarshal to populate it from the field's raw bytes.
	NewMessage(md *desc.MessageDescriptor) proto.Message
}
    27  
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but was not a recognized field number. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// The tag number for the unrecognized field.
	Tag int32

	// Encoding indicates how the unknown field was encoded on the wire. If it
	// is proto.WireBytes or proto.WireStartGroup then Contents will be set to
	// the raw bytes. If it is proto.WireFixed32 then the data is in the least
	// significant 32 bits of Value. Otherwise, the data is in all 64 bits of
	// Value.
	Encoding int8
	// Contents holds the raw bytes of a length-delimited or group-encoded
	// field. It is left unset for the numeric encodings.
	Contents []byte
	// Value holds the decoded number for varint, fixed32, and fixed64
	// encodings. It is left as zero when Contents is used instead.
	Value uint64
}
    45  
    46  // DecodeFieldValue will read a field value from the buffer and return its
    47  // value and the corresponding field descriptor. The given function is used
    48  // to lookup a field descriptor by tag number. The given factory is used to
    49  // instantiate a message if the field value is (or contains) a message value.
    50  //
    51  // On error, the field descriptor and value are typically nil. However, if the
    52  // error returned is ErrWireTypeEndGroup, the returned value will indicate any
    53  // tag number encoded in the end-group marker.
    54  //
    55  // If the field descriptor returned is nil, that means that the given function
    56  // returned nil. This is expected to happen for unrecognized tag numbers. In
    57  // that case, no error is returned, and the value will be an UnknownField.
    58  func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
    59  	if cb.EOF() {
    60  		return nil, nil, io.EOF
    61  	}
    62  	tagNumber, wireType, err := cb.DecodeTagAndWireType()
    63  	if err != nil {
    64  		return nil, nil, err
    65  	}
    66  	if wireType == proto.WireEndGroup {
    67  		return nil, tagNumber, ErrWireTypeEndGroup
    68  	}
    69  	fd := fieldFinder(tagNumber)
    70  	if fd == nil {
    71  		val, err := cb.decodeUnknownField(tagNumber, wireType)
    72  		return nil, val, err
    73  	}
    74  	val, err := cb.decodeKnownField(fd, wireType, fact)
    75  	return fd, val, err
    76  }
    77  
    78  // DecodeScalarField extracts a properly-typed value from v. The returned value's
    79  // type depends on the given field descriptor type. It will be the same type as
    80  // generated structs use for the field descriptor's type. Enum types will return
    81  // an int32. If the given field type uses length-delimited encoding (nested
    82  // messages, bytes, and strings), an error is returned.
    83  func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
    84  	switch fd.GetType() {
    85  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
    86  		return v != 0, nil
    87  	case descriptor.FieldDescriptorProto_TYPE_UINT32,
    88  		descriptor.FieldDescriptorProto_TYPE_FIXED32:
    89  		if v > math.MaxUint32 {
    90  			return nil, ErrOverflow
    91  		}
    92  		return uint32(v), nil
    93  
    94  	case descriptor.FieldDescriptorProto_TYPE_INT32,
    95  		descriptor.FieldDescriptorProto_TYPE_ENUM:
    96  		s := int64(v)
    97  		if s > math.MaxInt32 || s < math.MinInt32 {
    98  			return nil, ErrOverflow
    99  		}
   100  		return int32(s), nil
   101  
   102  	case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
   103  		if v > math.MaxUint32 {
   104  			return nil, ErrOverflow
   105  		}
   106  		return int32(v), nil
   107  
   108  	case descriptor.FieldDescriptorProto_TYPE_SINT32:
   109  		if v > math.MaxUint32 {
   110  			return nil, ErrOverflow
   111  		}
   112  		return DecodeZigZag32(v), nil
   113  
   114  	case descriptor.FieldDescriptorProto_TYPE_UINT64,
   115  		descriptor.FieldDescriptorProto_TYPE_FIXED64:
   116  		return v, nil
   117  
   118  	case descriptor.FieldDescriptorProto_TYPE_INT64,
   119  		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
   120  		return int64(v), nil
   121  
   122  	case descriptor.FieldDescriptorProto_TYPE_SINT64:
   123  		return DecodeZigZag64(v), nil
   124  
   125  	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
   126  		if v > math.MaxUint32 {
   127  			return nil, ErrOverflow
   128  		}
   129  		return math.Float32frombits(uint32(v)), nil
   130  
   131  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
   132  		return math.Float64frombits(v), nil
   133  
   134  	default:
   135  		// bytes, string, message, and group cannot be represented as a simple numeric value
   136  		return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
   137  	}
   138  }
   139  
   140  // DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
   141  // returned value's type will usually be []byte, string, or, for nested messages,
   142  // the type returned from the given message factory. However, since repeated
   143  // scalar fields can be length-delimited, when they used packed encoding, it can
   144  // also return an []interface{}, where each element is a scalar value. Furthermore,
   145  // it could return a scalar type, not in a slice, if the given field descriptor is
   146  // not repeated. This is to support cases where a field is changed from optional
   147  // to repeated. New code may emit a packed repeated representation, but old code
   148  // still expects a single scalar value. In this case, if the actual data in bytes
   149  // contains multiple values, only the last value is returned.
   150  func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
   151  	switch {
   152  	case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
   153  		return bytes, nil
   154  
   155  	case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
   156  		return string(bytes), nil
   157  
   158  	case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
   159  		fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
   160  		msg := mf.NewMessage(fd.GetMessageType())
   161  		err := proto.Unmarshal(bytes, msg)
   162  		if err != nil {
   163  			return nil, err
   164  		} else {
   165  			return msg, nil
   166  		}
   167  
   168  	default:
   169  		// even if the field is not repeated or not packed, we still parse it as such for
   170  		// backwards compatibility (e.g. message we are de-serializing could have been both
   171  		// repeated and packed at the time of serialization)
   172  		packedBuf := NewBuffer(bytes)
   173  		var slice []interface{}
   174  		var val interface{}
   175  		for !packedBuf.EOF() {
   176  			var v uint64
   177  			var err error
   178  			if varintTypes[fd.GetType()] {
   179  				v, err = packedBuf.DecodeVarint()
   180  			} else if fixed32Types[fd.GetType()] {
   181  				v, err = packedBuf.DecodeFixed32()
   182  			} else if fixed64Types[fd.GetType()] {
   183  				v, err = packedBuf.DecodeFixed64()
   184  			} else {
   185  				return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
   186  			}
   187  			if err != nil {
   188  				return nil, err
   189  			}
   190  			val, err = DecodeScalarField(fd, v)
   191  			if err != nil {
   192  				return nil, err
   193  			}
   194  			if fd.IsRepeated() {
   195  				slice = append(slice, val)
   196  			}
   197  		}
   198  		if fd.IsRepeated() {
   199  			return slice, nil
   200  		} else {
   201  			// if not a repeated field, last value wins
   202  			return val, nil
   203  		}
   204  	}
   205  }
   206  
   207  func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
   208  	var val interface{}
   209  	var err error
   210  	switch encoding {
   211  	case proto.WireFixed32:
   212  		var num uint64
   213  		num, err = b.DecodeFixed32()
   214  		if err == nil {
   215  			val, err = DecodeScalarField(fd, num)
   216  		}
   217  	case proto.WireFixed64:
   218  		var num uint64
   219  		num, err = b.DecodeFixed64()
   220  		if err == nil {
   221  			val, err = DecodeScalarField(fd, num)
   222  		}
   223  	case proto.WireVarint:
   224  		var num uint64
   225  		num, err = b.DecodeVarint()
   226  		if err == nil {
   227  			val, err = DecodeScalarField(fd, num)
   228  		}
   229  
   230  	case proto.WireBytes:
   231  		alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
   232  		var raw []byte
   233  		raw, err = b.DecodeRawBytes(alloc)
   234  		if err == nil {
   235  			val, err = DecodeLengthDelimitedField(fd, raw, fact)
   236  		}
   237  
   238  	case proto.WireStartGroup:
   239  		if fd.GetMessageType() == nil {
   240  			return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
   241  		}
   242  		msg := fact.NewMessage(fd.GetMessageType())
   243  		var data []byte
   244  		data, err = b.ReadGroup(false)
   245  		if err == nil {
   246  			err = proto.Unmarshal(data, msg)
   247  			if err == nil {
   248  				val = msg
   249  			}
   250  		}
   251  
   252  	default:
   253  		return nil, ErrBadWireType
   254  	}
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  
   259  	return val, nil
   260  }
   261  
   262  func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
   263  	u := UnknownField{Tag: tagNumber, Encoding: encoding}
   264  	var err error
   265  	switch encoding {
   266  	case proto.WireFixed32:
   267  		u.Value, err = b.DecodeFixed32()
   268  	case proto.WireFixed64:
   269  		u.Value, err = b.DecodeFixed64()
   270  	case proto.WireVarint:
   271  		u.Value, err = b.DecodeVarint()
   272  	case proto.WireBytes:
   273  		u.Contents, err = b.DecodeRawBytes(true)
   274  	case proto.WireStartGroup:
   275  		u.Contents, err = b.ReadGroup(true)
   276  	default:
   277  		err = ErrBadWireType
   278  	}
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  	return u, nil
   283  }