github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/dbnode/encoding/proto/custom_unmarshal.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package proto
    22  
import (
	"errors"
	"fmt"
	"io"
	"math"
	"sort"

	"github.com/golang/protobuf/proto"
	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	"github.com/jhump/protoreflect/desc"
)
    33  
    34  var (
    35  	// Groups in the Protobuf wire format are deprecated, so simplify the code significantly by
    36  	// not supporting them.
    37  	errGroupsAreNotSupported = errors.New("use of groups in proto wire format is not supported")
    38  	zeroValue                unmarshalValue
    39  )
    40  
    41  type customFieldUnmarshaller interface {
    42  	sortedCustomFieldValues() sortedCustomFieldValues
    43  	sortedNonCustomFieldValues() sortedMarshalledFields
    44  	numNonCustomValues() int
    45  	resetAndUnmarshal(schema *desc.MessageDescriptor, buf []byte) error
    46  }
    47  
    48  type customUnmarshallerOptions struct {
    49  	skipUnknownFields bool
    50  }
    51  
    52  type customUnmarshaller struct {
    53  	schema       *desc.MessageDescriptor
    54  	decodeBuf    *buffer
    55  	customValues sortedCustomFieldValues
    56  
    57  	nonCustomValues sortedMarshalledFields
    58  	numNonCustom    int
    59  
    60  	opts customUnmarshallerOptions
    61  }
    62  
    63  func newCustomFieldUnmarshaller(opts customUnmarshallerOptions) customFieldUnmarshaller {
    64  	return &customUnmarshaller{
    65  		decodeBuf: newCodedBuffer(nil),
    66  		opts:      opts,
    67  	}
    68  }
    69  
    70  func (u *customUnmarshaller) sortedCustomFieldValues() sortedCustomFieldValues {
    71  	return u.customValues
    72  }
    73  
    74  func (u *customUnmarshaller) numNonCustomValues() int {
    75  	return u.numNonCustom
    76  }
    77  
    78  func (u *customUnmarshaller) sortedNonCustomFieldValues() sortedMarshalledFields {
    79  	return u.nonCustomValues
    80  }
    81  
    82  func (u *customUnmarshaller) unmarshal() error {
    83  	u.resetCustomAndNonCustomValues()
    84  
    85  	var (
    86  		areCustomValuesSorted    = true
    87  		areNonCustomValuesSorted = true
    88  	)
    89  	for !u.decodeBuf.eof() {
    90  		tagAndWireTypeStartOffset := u.decodeBuf.index
    91  		fieldNum, wireType, err := u.decodeBuf.decodeTagAndWireType()
    92  		if err != nil {
    93  			return err
    94  		}
    95  
    96  		fd := u.schema.FindFieldByNumber(fieldNum)
    97  		if fd == nil {
    98  			if !u.opts.skipUnknownFields {
    99  				return fmt.Errorf("encountered unknown field with field number: %d", fieldNum)
   100  			}
   101  
   102  			if _, err := u.skip(wireType); err != nil {
   103  				return err
   104  			}
   105  			continue
   106  		}
   107  
   108  		if !u.isCustomField(fd) {
   109  			_, err = u.skip(wireType)
   110  			if err != nil {
   111  				return err
   112  			}
   113  
   114  			var (
   115  				startIdx   = tagAndWireTypeStartOffset
   116  				endIdx     = u.decodeBuf.index
   117  				marshalled = u.decodeBuf.buf[startIdx:endIdx]
   118  			)
   119  			// A marshalled Protobuf message consists of a stream of <fieldNumber, wireType, value>
   120  			// tuples, all of which are optional, with no additional header or footer information.
   121  			// This means that each tuple within the stream can be thought of as its own complete
   122  			// marshalled message and as a result we can build up the []marshalledField one field at
   123  			// a time.
   124  			updatedExisting := false
   125  			if fd.IsRepeated() {
   126  				// If the fd is a repeated type and not using `packed` encoding then their could be multiple
   127  				// entries in the stream with the same field number so their marshalled bytes needs to be all
   128  				// concatenated together.
   129  				//
   130  				// NB(rartoul): This will have an adverse impact on the compression of map types because the
   131  				// key/val pairs can be encoded in any order. This means that its possible for two equivalent
   132  				// maps to have different byte streams which will force the encoder to re-encode the field into
   133  				// the stream even though it hasn't changed. This naive solution should be good enough for now,
   134  				// but if it proves problematic in the future the issue could be resolved by accumulating the
   135  				// marshalled tuples into a slice and then sorting by field number to produce a deterministic
   136  				// result such that equivalent maps always result in equivalent marshalled bytes slices.
   137  				for i, val := range u.nonCustomValues {
   138  					if fieldNum == val.fieldNum {
   139  						u.nonCustomValues[i].marshalled = append(u.nonCustomValues[i].marshalled, marshalled...)
   140  						updatedExisting = true
   141  						break
   142  					}
   143  				}
   144  			}
   145  			if !updatedExisting {
   146  				u.nonCustomValues = append(u.nonCustomValues, marshalledField{
   147  					fieldNum:   fieldNum,
   148  					marshalled: marshalled,
   149  				})
   150  			}
   151  
   152  			if areNonCustomValuesSorted && len(u.nonCustomValues) > 1 {
   153  				// Check if the slice is sorted as it's built to avoid resorting
   154  				// unnecessarily at the end.
   155  				lastFieldNum := u.nonCustomValues[len(u.nonCustomValues)-1].fieldNum
   156  				if fieldNum < lastFieldNum {
   157  					areNonCustomValuesSorted = false
   158  				}
   159  			}
   160  
   161  			u.numNonCustom++
   162  			continue
   163  		}
   164  
   165  		value, err := u.unmarshalCustomField(fd, wireType)
   166  		if err != nil {
   167  			return err
   168  		}
   169  
   170  		if areCustomValuesSorted && len(u.customValues) > 1 {
   171  			// Check if the slice is sorted as it's built to avoid resorting
   172  			// unnecessarily at the end.
   173  			lastFieldNum := u.customValues[len(u.customValues)-1].fieldNumber
   174  			if fieldNum < lastFieldNum {
   175  				areCustomValuesSorted = false
   176  			}
   177  		}
   178  
   179  		u.customValues = append(u.customValues, value)
   180  	}
   181  
   182  	u.decodeBuf.reset(u.decodeBuf.buf)
   183  
   184  	// Avoid resorting if possible.
   185  	if !areCustomValuesSorted {
   186  		sort.Sort(u.customValues)
   187  	}
   188  	if !areNonCustomValuesSorted {
   189  		sort.Sort(u.nonCustomValues)
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  // isCustomField checks whether the encoder would have custom encoded this field or left
   196  // it up to the `jhump/dynamic` package to handle the encoding. This is important because
   197  // it allows us to use the efficient unmarshal path only for fields that the encoder can
   198  // actually take advantage of.
   199  func (u *customUnmarshaller) isCustomField(fd *desc.FieldDescriptor) bool {
   200  	if fd.IsRepeated() || fd.IsMap() {
   201  		// Map should always be repeated but include the guard just in case.
   202  		return false
   203  	}
   204  
   205  	if fd.GetMessageType() != nil {
   206  		// Skip nested messages.
   207  		return false
   208  	}
   209  
   210  	return true
   211  }
   212  
   213  // skip will skip over the next value in the encoded stream (given that the tag and
   214  // wiretype have already been decoded).
   215  func (u *customUnmarshaller) skip(wireType int8) (int, error) {
   216  	switch wireType {
   217  	case proto.WireFixed32:
   218  		bytesSkipped := 4
   219  		u.decodeBuf.index += bytesSkipped
   220  		return bytesSkipped, nil
   221  
   222  	case proto.WireFixed64:
   223  		bytesSkipped := 8
   224  		u.decodeBuf.index += bytesSkipped
   225  		return bytesSkipped, nil
   226  
   227  	case proto.WireVarint:
   228  		var (
   229  			bytesSkipped             = 0
   230  			offsetBeforeDecodeVarInt = u.decodeBuf.index
   231  		)
   232  		_, err := u.decodeBuf.decodeVarint()
   233  		if err != nil {
   234  			return 0, err
   235  		}
   236  		bytesSkipped += u.decodeBuf.index - offsetBeforeDecodeVarInt
   237  		return bytesSkipped, nil
   238  
   239  	case proto.WireBytes:
   240  		var (
   241  			bytesSkipped               = 0
   242  			offsetBeforeDecodeRawBytes = u.decodeBuf.index
   243  		)
   244  		// Bytes aren't copied because they're just being skipped over so
   245  		// copying would be wasteful.
   246  		_, err := u.decodeBuf.decodeRawBytes(false)
   247  		if err != nil {
   248  			return 0, err
   249  		}
   250  		bytesSkipped += u.decodeBuf.index - offsetBeforeDecodeRawBytes
   251  		return bytesSkipped, nil
   252  
   253  	case proto.WireStartGroup:
   254  		return 0, errGroupsAreNotSupported
   255  
   256  	case proto.WireEndGroup:
   257  		return 0, errGroupsAreNotSupported
   258  
   259  	default:
   260  		return 0, proto.ErrInternalBadWireType
   261  	}
   262  }
   263  
   264  func (u *customUnmarshaller) unmarshalCustomField(fd *desc.FieldDescriptor, wireType int8) (unmarshalValue, error) {
   265  	switch wireType {
   266  	case proto.WireFixed32:
   267  		num, err := u.decodeBuf.decodeFixed32()
   268  		if err != nil {
   269  			return zeroValue, err
   270  		}
   271  		return unmarshalSimpleField(fd, num)
   272  
   273  	case proto.WireFixed64:
   274  		num, err := u.decodeBuf.decodeFixed64()
   275  		if err != nil {
   276  			return zeroValue, err
   277  		}
   278  		return unmarshalSimpleField(fd, num)
   279  
   280  	case proto.WireVarint:
   281  		num, err := u.decodeBuf.decodeVarint()
   282  		if err != nil {
   283  			return zeroValue, err
   284  		}
   285  		return unmarshalSimpleField(fd, num)
   286  
   287  	case proto.WireBytes:
   288  		if t := fd.GetType(); t != dpb.FieldDescriptorProto_TYPE_BYTES &&
   289  			t != dpb.FieldDescriptorProto_TYPE_STRING {
   290  			// This should never happen since it means the skipping logic is not working
   291  			// correctly or the message is malformed since proto.WireBytes should only be
   292  			// used for fields of type bytes, string, group, or message. Groups/messages
   293  			// should be handled by the skipping logic (for now).
   294  			return zeroValue, fmt.Errorf(
   295  				"tried to unmarshal field with wire type: bytes and proto field type: %s",
   296  				fd.GetType().String())
   297  		}
   298  
   299  		// Don't bother copying the bytes now because the encoder has exclusive ownership
   300  		// of them until the call to Encode() completes and they will get "copied" anyways
   301  		// once they're written into the OStream.
   302  		raw, err := u.decodeBuf.decodeRawBytes(false)
   303  		if err != nil {
   304  			return zeroValue, err
   305  		}
   306  
   307  		val := unmarshalValue{fieldNumber: fd.GetNumber(), bytes: raw}
   308  		return val, nil
   309  
   310  	case proto.WireStartGroup:
   311  		return zeroValue, errGroupsAreNotSupported
   312  
   313  	default:
   314  		return zeroValue, proto.ErrInternalBadWireType
   315  	}
   316  }
   317  
   318  func unmarshalSimpleField(fd *desc.FieldDescriptor, v uint64) (unmarshalValue, error) {
   319  	fieldNum := fd.GetNumber()
   320  	val := unmarshalValue{fieldNumber: fieldNum, v: v}
   321  	switch fd.GetType() {
   322  	case dpb.FieldDescriptorProto_TYPE_BOOL,
   323  		dpb.FieldDescriptorProto_TYPE_UINT64,
   324  		dpb.FieldDescriptorProto_TYPE_FIXED64,
   325  		dpb.FieldDescriptorProto_TYPE_INT64,
   326  		dpb.FieldDescriptorProto_TYPE_SFIXED64,
   327  		dpb.FieldDescriptorProto_TYPE_DOUBLE:
   328  		return val, nil
   329  
   330  	case dpb.FieldDescriptorProto_TYPE_UINT32,
   331  		dpb.FieldDescriptorProto_TYPE_FIXED32:
   332  		if v > math.MaxUint32 {
   333  			return zeroValue, fmt.Errorf("%d (field num %d) overflows uint32", v, fieldNum)
   334  		}
   335  		return val, nil
   336  
   337  	case dpb.FieldDescriptorProto_TYPE_INT32,
   338  		dpb.FieldDescriptorProto_TYPE_ENUM:
   339  		s := int64(v)
   340  		if s > math.MaxInt32 {
   341  			return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum)
   342  		}
   343  		if s < math.MinInt32 {
   344  			return zeroValue, fmt.Errorf("%d (field num %d) underflows int32", v, fieldNum)
   345  		}
   346  		return val, nil
   347  
   348  	case dpb.FieldDescriptorProto_TYPE_SFIXED32:
   349  		if v > math.MaxUint32 {
   350  			return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum)
   351  		}
   352  		return val, nil
   353  
   354  	case dpb.FieldDescriptorProto_TYPE_SINT32:
   355  		if v > math.MaxUint32 {
   356  			return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum)
   357  		}
   358  		val.v = uint64(decodeZigZag32(v))
   359  		return val, nil
   360  
   361  	case dpb.FieldDescriptorProto_TYPE_SINT64:
   362  		val.v = uint64(decodeZigZag64(v))
   363  		return val, nil
   364  
   365  	case dpb.FieldDescriptorProto_TYPE_FLOAT:
   366  		if v > math.MaxUint32 {
   367  			return zeroValue, fmt.Errorf("%d (field num %d) overflows uint32", v, fieldNum)
   368  		}
   369  		float32Val := math.Float32frombits(uint32(v))
   370  		float64Bits := math.Float64bits(float64(float32Val))
   371  		val.v = float64Bits
   372  		return val, nil
   373  
   374  	default:
   375  		// bytes, string, message, and group cannot be represented as a simple numeric value.
   376  		return zeroValue, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
   377  	}
   378  }
   379  
   380  func (u *customUnmarshaller) resetAndUnmarshal(schema *desc.MessageDescriptor, buf []byte) error {
   381  	u.schema = schema
   382  	u.numNonCustom = 0
   383  	u.resetCustomAndNonCustomValues()
   384  	u.decodeBuf.reset(buf)
   385  	return u.unmarshal()
   386  }
   387  
   388  func (u *customUnmarshaller) resetCustomAndNonCustomValues() {
   389  	for i := range u.customValues {
   390  		u.customValues[i] = unmarshalValue{}
   391  	}
   392  	u.customValues = u.customValues[:0]
   393  
   394  	for i := range u.nonCustomValues {
   395  		u.nonCustomValues[i] = marshalledField{}
   396  	}
   397  	u.nonCustomValues = u.nonCustomValues[:0]
   398  }
   399  
   400  type sortedCustomFieldValues []unmarshalValue
   401  
   402  func (s sortedCustomFieldValues) Len() int {
   403  	return len(s)
   404  }
   405  
   406  func (s sortedCustomFieldValues) Less(i, j int) bool {
   407  	return s[i].fieldNumber < s[j].fieldNumber
   408  }
   409  
   410  func (s sortedCustomFieldValues) Swap(i, j int) {
   411  	s[i], s[j] = s[j], s[i]
   412  }
   413  
   414  type unmarshalValue struct {
   415  	fieldNumber int32
   416  	v           uint64
   417  	bytes       []byte
   418  }
   419  
   420  func (v *unmarshalValue) asBool() bool {
   421  	return v.v != 0
   422  }
   423  
   424  func (v *unmarshalValue) asUint64() uint64 {
   425  	return v.v
   426  }
   427  
   428  func (v *unmarshalValue) asInt64() int64 {
   429  	return int64(v.v)
   430  }
   431  
   432  func (v *unmarshalValue) asFloat64() float64 {
   433  	return math.Float64frombits(v.v)
   434  }
   435  
   436  func (v *unmarshalValue) asBytes() []byte {
   437  	return v.bytes
   438  }