github.com/m3db/m3@v1.5.0/src/dbnode/encoding/proto/iterator.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package proto
    22  
    23  import (
    24  	"encoding/binary"
    25  	"fmt"
    26  	"io"
    27  	"math"
    28  
    29  	"github.com/m3db/m3/src/dbnode/encoding"
    30  	"github.com/m3db/m3/src/dbnode/encoding/m3tsz"
    31  	"github.com/m3db/m3/src/dbnode/namespace"
    32  	"github.com/m3db/m3/src/dbnode/ts"
    33  	"github.com/m3db/m3/src/dbnode/x/xio"
    34  	"github.com/m3db/m3/src/x/checked"
    35  	"github.com/m3db/m3/src/x/ident"
    36  	"github.com/m3db/m3/src/x/instrument"
    37  	xtime "github.com/m3db/m3/src/x/time"
    38  
    39  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    40  	"github.com/jhump/protoreflect/desc"
    41  )
    42  
const (
	// Maximum capacity of a checked.Bytes that will be retained between resets.
	// Close finalizes the unmarshal buffer instead of keeping it whenever the
	// buffer's capacity exceeds this value.
	maxCapacityUnmarshalBufferRetain = 1024
)

// itErrPrefix is prepended to the error messages built by the iterator, and
// errIteratorSchemaIsRequired is the error reported by Next when the iterator
// has no schema.
var (
	itErrPrefix                 = "proto iterator:"
	errIteratorSchemaIsRequired = fmt.Errorf("%s schema is required", itErrPrefix)
)
    52  
// iterator decodes a stream of timestamped protobuf messages. It is returned
// from NewIterator as an encoding.ReaderIterator.
type iterator struct {
	nsID ident.ID
	// opts supplies the timestamp iterator configuration, the bytes pool used
	// by newBuffer and the reader iterator pool used by Close.
	opts encoding.Options
	// err holds the error that stopped iteration; it is returned by Err.
	err error
	// schema and schemaDesc are assigned together by resetSchema; schema is
	// the message descriptor taken from schemaDesc.
	schema     *desc.MessageDescriptor
	schemaDesc namespace.SchemaDescr
	// stream is the bit stream that all reads go through.
	stream *encoding.IStream
	// marshaller accumulates the marshalled message returned by Current.
	marshaller customFieldMarshaller
	// byteFieldDictLRUSize is read from the stream header and bounds the
	// per-field bytes dictionaries.
	byteFieldDictLRUSize int
	// TODO(rartoul): Update these as we traverse the stream if we encounter
	// a mid-stream schema change: https://github.com/m3db/m3/issues/1471
	customFields    []customFieldState
	nonCustomFields []marshalledField

	// tsIterator decodes timestamps and time units from the stream.
	tsIterator m3tsz.TimestampIterator

	// Fields that are reused between function calls to
	// avoid allocations.
	varIntBuf         [8]byte
	bitsetValues      []int
	unmarshalProtoBuf checked.Bytes
	unmarshaller      customFieldUnmarshaller

	// consumedFirstMessage is set once Next has decoded a message; until then
	// Next reads the stream header first.
	consumedFirstMessage bool
	// done is set when the stream is exhausted or signals no more data.
	done bool
	// closed is consulted by hasNext and by Close.
	closed bool
}
    80  
    81  // NewIterator creates a new iterator.
    82  func NewIterator(
    83  	reader xio.Reader64,
    84  	descr namespace.SchemaDescr,
    85  	opts encoding.Options,
    86  ) encoding.ReaderIterator {
    87  	stream := encoding.NewIStream(reader)
    88  
    89  	i := &iterator{
    90  		opts:       opts,
    91  		stream:     stream,
    92  		marshaller: newCustomMarshaller(),
    93  		tsIterator: m3tsz.NewTimestampIterator(opts, true),
    94  	}
    95  	i.resetSchema(descr)
    96  	return i
    97  }
    98  
// Next decodes the next message from the stream and reports whether the
// iterator is still usable afterwards (no error, not done, not closed). On the
// first call it also consumes the stream header. Any failure is stored in
// it.err and surfaces through Err.
func (it *iterator) Next() bool {
	if it.schema == nil {
		// It is a programmatic error for the schema to be unset prior to iterating,
		// so it is recorded through instrument.InvariantErrorf.
		it.err = instrument.InvariantErrorf(errIteratorSchemaIsRequired.Error())
		return false
	}

	if !it.hasNext() {
		return false
	}

	// Start from an empty marshaller; the read* helpers below refill it.
	it.marshaller.reset()

	if !it.consumedFirstMessage {
		if err := it.readStreamHeader(); err != nil {
			it.err = fmt.Errorf(
				"%s error reading stream header: %v",
				itErrPrefix, err)
			return false
		}
	}

	moreDataControlBit, err := it.stream.ReadBit()
	if err == io.EOF {
		// Running out of input at a message boundary ends iteration without an error.
		it.done = true
		return false
	}
	if err != nil {
		it.err = fmt.Errorf(
			"%s error reading more data control bit: %v",
			itErrPrefix, err)
		return false
	}

	if moreDataControlBit == opCodeNoMoreDataOrTimeUnitChangeAndOrSchemaChange {
		// The next bit will tell us whether we've reached the end of the stream
		// or that the time unit and/or schema has changed.
		noMoreDataControlBit, err := it.stream.ReadBit()
		if err == io.EOF {
			it.done = true
			return false
		}
		if err != nil {
			it.err = fmt.Errorf(
				"%s error reading no more data control bit: %v",
				itErrPrefix, err)
			return false
		}

		if noMoreDataControlBit == opCodeNoMoreData {
			it.done = true
			return false
		}

		// The next bit will tell us whether the time unit has changed.
		timeUnitHasChangedControlBit, err := it.stream.ReadBit()
		if err != nil {
			it.err = fmt.Errorf(
				"%s error reading time unit change has changed control bit: %v",
				itErrPrefix, err)
			return false
		}

		// The next bit will tell us whether the schema has changed.
		schemaHasChangedControlBit, err := it.stream.ReadBit()
		if err != nil {
			it.err = fmt.Errorf(
				"%s error reading schema has changed control bit: %v",
				itErrPrefix, err)
			return false
		}

		// Both control bits are read before either payload; the time unit payload
		// (if any) precedes the custom fields schema payload (if any).
		if timeUnitHasChangedControlBit == opCodeTimeUnitChange {
			if err := it.tsIterator.ReadTimeUnit(it.stream); err != nil {
				it.err = fmt.Errorf("%s error reading new time unit: %v", itErrPrefix, err)
				return false
			}
		}

		if schemaHasChangedControlBit == opCodeSchemaChange {
			if err := it.readCustomFieldsSchema(); err != nil {
				it.err = fmt.Errorf("%s error reading custom fields schema: %v", itErrPrefix, err)
				return false
			}

			// When the encoder changes its schema it will reset all of its nonCustomFields state
			// which means that the iterator needs to do the same to keep them synchronized at
			// each point in the stream.
			for i := range it.nonCustomFields {
				// Reslice instead of setting to nil to reuse existing capacity if possible.
				it.nonCustomFields[i].marshalled = it.nonCustomFields[i].marshalled[:0]
			}
		}
	}

	// The decoded timestamp itself is discarded here; Current reads it back
	// from it.tsIterator.PrevTime.
	_, done, err := it.tsIterator.ReadTimestamp(it.stream)
	if err != nil {
		it.err = fmt.Errorf("%s error reading timestamp: %v", itErrPrefix, err)
		return false
	}
	if done {
		// This should never happen since we never encode the EndOfStream marker.
		it.err = fmt.Errorf("%s unexpected end of timestamp stream", itErrPrefix)
		return false
	}

	if err := it.readCustomValues(); err != nil {
		it.err = err
		return false
	}

	if err := it.readNonCustomValues(); err != nil {
		it.err = err
		return false
	}

	// Update the marshaller bytes (which will be returned by Current()) with the latest value
	// for every non-custom field.
	for _, marshalledField := range it.nonCustomFields {
		it.marshaller.encPartialProto(marshalledField.marshalled)
	}

	it.consumedFirstMessage = true
	return it.hasNext()
}
   224  
   225  func (it *iterator) Current() (ts.Datapoint, xtime.Unit, ts.Annotation) {
   226  	var (
   227  		dp = ts.Datapoint{
   228  			TimestampNanos: it.tsIterator.PrevTime,
   229  		}
   230  		unit = it.tsIterator.TimeUnit
   231  	)
   232  
   233  	return dp, unit, it.marshaller.bytes()
   234  }
   235  
// Err returns the error recorded by Next, or nil if none has been recorded
// since the iterator was created or last Reset.
func (it *iterator) Err() error {
	return it.err
}
   239  
   240  func (it *iterator) Reset(reader xio.Reader64, descr namespace.SchemaDescr) {
   241  	it.resetSchema(descr)
   242  	it.stream.Reset(reader)
   243  	it.tsIterator = m3tsz.NewTimestampIterator(it.opts, true)
   244  
   245  	it.err = nil
   246  	it.consumedFirstMessage = false
   247  	it.done = false
   248  	it.closed = false
   249  	it.byteFieldDictLRUSize = 0
   250  }
   251  
   252  // setSchema sets the schema for the iterator.
   253  func (it *iterator) resetSchema(schemaDesc namespace.SchemaDescr) {
   254  	if schemaDesc == nil {
   255  		it.schemaDesc = nil
   256  		it.schema = nil
   257  
   258  		// Clear but don't set to nil so they don't need to be reallocated
   259  		// next time.
   260  		customFields := it.customFields
   261  		for i := range customFields {
   262  			customFields[i] = customFieldState{}
   263  		}
   264  		it.customFields = customFields[:0]
   265  
   266  		nonCustomFields := it.nonCustomFields
   267  		for i := range nonCustomFields {
   268  			nonCustomFields[i] = marshalledField{}
   269  		}
   270  		it.nonCustomFields = nonCustomFields[:0]
   271  		return
   272  	}
   273  
   274  	it.schemaDesc = schemaDesc
   275  	it.schema = schemaDesc.Get().MessageDescriptor
   276  	it.customFields, it.nonCustomFields = customAndNonCustomFields(it.customFields, nil, it.schema)
   277  }
   278  
   279  func (it *iterator) Close() {
   280  	if it.closed {
   281  		return
   282  	}
   283  
   284  	it.closed = true
   285  	it.Reset(nil, nil)
   286  	it.stream.Reset(nil)
   287  
   288  	if it.unmarshalProtoBuf != nil && it.unmarshalProtoBuf.Cap() > maxCapacityUnmarshalBufferRetain {
   289  		// Only finalize the buffer if its grown too large to prevent pooled
   290  		// iterators from growing excessively large.
   291  		it.unmarshalProtoBuf.DecRef()
   292  		it.unmarshalProtoBuf.Finalize()
   293  		it.unmarshalProtoBuf = nil
   294  	}
   295  
   296  	if pool := it.opts.ReaderIteratorPool(); pool != nil {
   297  		pool.Put(it)
   298  	}
   299  }
   300  
   301  func (it *iterator) readStreamHeader() error {
   302  	// Can ignore the version number for now because we only have one.
   303  	_, err := it.readVarInt()
   304  	if err != nil {
   305  		return err
   306  	}
   307  
   308  	byteFieldDictLRUSize, err := it.readVarInt()
   309  	if err != nil {
   310  		return err
   311  	}
   312  
   313  	it.byteFieldDictLRUSize = int(byteFieldDictLRUSize)
   314  	return nil
   315  }
   316  
   317  func (it *iterator) readCustomFieldsSchema() error {
   318  	numCustomFields, err := it.readVarInt()
   319  	if err != nil {
   320  		return err
   321  	}
   322  
   323  	if numCustomFields > maxCustomFieldNum {
   324  		return fmt.Errorf(
   325  			"num custom fields in header is %d but maximum allowed is %d",
   326  			numCustomFields, maxCustomFieldNum)
   327  	}
   328  
   329  	if it.customFields != nil {
   330  		for i := range it.customFields {
   331  			it.customFields[i] = customFieldState{}
   332  		}
   333  		it.customFields = it.customFields[:0]
   334  	} else {
   335  		it.customFields = make([]customFieldState, 0, numCustomFields)
   336  	}
   337  
   338  	for i := 1; i <= int(numCustomFields); i++ {
   339  		fieldTypeBits, err := it.stream.ReadBits(uint8(numBitsToEncodeCustomType))
   340  		if err != nil {
   341  			return err
   342  		}
   343  
   344  		fieldType := customFieldType(fieldTypeBits)
   345  		if fieldType == notCustomEncodedField {
   346  			continue
   347  		}
   348  
   349  		var (
   350  			fieldDesc      = it.schema.FindFieldByNumber(int32(i))
   351  			protoFieldType = protoFieldTypeNotFound
   352  		)
   353  		if fieldDesc != nil {
   354  			protoFieldType = fieldDesc.GetType()
   355  		}
   356  
   357  		customFieldState := newCustomFieldState(i, protoFieldType, fieldType)
   358  		it.customFields = append(it.customFields, customFieldState)
   359  	}
   360  
   361  	return nil
   362  }
   363  
   364  func (it *iterator) readCustomValues() error {
   365  	for i, customField := range it.customFields {
   366  		switch {
   367  		case isCustomFloatEncodedField(customField.fieldType):
   368  			if err := it.readFloatValue(i); err != nil {
   369  				return err
   370  			}
   371  		case isCustomIntEncodedField(customField.fieldType):
   372  			if err := it.readIntValue(i); err != nil {
   373  				return err
   374  			}
   375  		case customField.fieldType == bytesField:
   376  			if err := it.readBytesValue(i, customField); err != nil {
   377  				return err
   378  			}
   379  		case customField.fieldType == boolField:
   380  			if err := it.readBoolValue(i); err != nil {
   381  				return err
   382  			}
   383  		default:
   384  			return fmt.Errorf(
   385  				"%s: unhandled custom field type: %v", itErrPrefix, customField.fieldType)
   386  		}
   387  	}
   388  
   389  	return nil
   390  }
   391  
   392  func (it *iterator) readNonCustomValues() error {
   393  	protoChangesControlBit, err := it.stream.ReadBit()
   394  	if err != nil {
   395  		return fmt.Errorf("%s err reading proto changes control bit: %v", itErrPrefix, err)
   396  	}
   397  
   398  	if protoChangesControlBit == opCodeNoChange {
   399  		// No changes since previous message.
   400  		return nil
   401  	}
   402  
   403  	fieldsSetToDefaultControlBit, err := it.stream.ReadBit()
   404  	if err != nil {
   405  		return fmt.Errorf("%s err reading field set to default control bit: %v", itErrPrefix, err)
   406  	}
   407  
   408  	if fieldsSetToDefaultControlBit == opCodeFieldsSetToDefaultProtoMarshal {
   409  		// Some fields set to default value, need to read bitset.
   410  		err = it.readBitset()
   411  		if err != nil {
   412  			return fmt.Errorf(
   413  				"error readining changed proto field numbers bitset: %v", err)
   414  		}
   415  	}
   416  
   417  	it.skipToNextByte()
   418  	marshalLen, err := it.readVarInt()
   419  	if err != nil {
   420  		return fmt.Errorf("%s err reading proto length varint: %v", itErrPrefix, err)
   421  	}
   422  
   423  	if marshalLen > maxMarshalledProtoMessageSize {
   424  		return fmt.Errorf(
   425  			"%s marshalled protobuf size was %d which is larger than the maximum of %d",
   426  			itErrPrefix, marshalLen, maxMarshalledProtoMessageSize)
   427  	}
   428  
   429  	it.resetUnmarshalProtoBuffer(int(marshalLen))
   430  	unmarshalBytes := it.unmarshalProtoBuf.Bytes()
   431  	n, err := it.stream.Read(unmarshalBytes)
   432  	if err != nil {
   433  		return fmt.Errorf("%s: error reading marshalled proto bytes: %v", itErrPrefix, err)
   434  	}
   435  	if n != int(marshalLen) {
   436  		return fmt.Errorf(
   437  			"%s tried to read %d marshalled proto bytes but only read %d",
   438  			itErrPrefix, int(marshalLen), n)
   439  	}
   440  
   441  	if it.unmarshaller == nil {
   442  		// Lazy init.
   443  		it.unmarshaller = newCustomFieldUnmarshaller(customUnmarshallerOptions{
   444  			// Skip over unknown fields when unmarshalling because its possible that the stream was
   445  			// encoded with a newer schema.
   446  			skipUnknownFields: true,
   447  		})
   448  	}
   449  
   450  	if err := it.unmarshaller.resetAndUnmarshal(it.schema, unmarshalBytes); err != nil {
   451  		return fmt.Errorf(
   452  			"%s error unmarshalling message: %v", itErrPrefix, err)
   453  	}
   454  	customFieldValues := it.unmarshaller.sortedCustomFieldValues()
   455  	if len(customFieldValues) > 0 {
   456  		// If the proto portion of the message has any fields that could  have been custom
   457  		// encoded then something went wrong on the encoding side.
   458  		return fmt.Errorf(
   459  			"%s encoded protobuf portion of message had custom fields", itErrPrefix)
   460  	}
   461  
   462  	// Update any non custom fields that have explicitly changed (they were explicitly included
   463  	// in the marshalled stream).
   464  	var (
   465  		unmarshalledNonCustomFields = it.unmarshaller.sortedNonCustomFieldValues()
   466  		// Matching entries in two sorted lists in which every element in each list is unique so keep
   467  		// track of the last index at which a match was found so that subsequent inner loops can start
   468  		// at the next index.
   469  		lastMatchIdx = -1
   470  	)
   471  	for _, nonCustomField := range unmarshalledNonCustomFields {
   472  		for i := lastMatchIdx + 1; i < len(it.nonCustomFields); i++ {
   473  			existingNonCustomField := it.nonCustomFields[i]
   474  			if nonCustomField.fieldNum != existingNonCustomField.fieldNum {
   475  				continue
   476  			}
   477  
   478  			// Copy because the underlying bytes get reused between reads. Also try and reuse the existing
   479  			// capacity to prevent an allocation if possible.
   480  			it.nonCustomFields[i].marshalled = append(
   481  				it.nonCustomFields[i].marshalled[:0],
   482  				nonCustomField.marshalled...)
   483  
   484  			lastMatchIdx = i
   485  			break
   486  		}
   487  	}
   488  
   489  	// Update any non custom fields that have been explicitly set to their default value as determined
   490  	// by the bitset.
   491  	if fieldsSetToDefaultControlBit == opCodeFieldsSetToDefaultProtoMarshal {
   492  		// Same comment as above about matching entries in two sorted lists.
   493  		lastMatchIdx := -1
   494  		for _, fieldNum := range it.bitsetValues {
   495  			for i := lastMatchIdx + 1; i < len(it.nonCustomFields); i++ {
   496  				nonCustomField := it.nonCustomFields[i]
   497  				if fieldNum != int(nonCustomField.fieldNum) {
   498  					continue
   499  				}
   500  
   501  				// Resize slice to zero so that the existing capacity can be reused later if required.
   502  				it.nonCustomFields[i].marshalled = it.nonCustomFields[i].marshalled[:0]
   503  				lastMatchIdx = i
   504  				break
   505  			}
   506  		}
   507  	}
   508  
   509  	return nil
   510  }
   511  
   512  func (it *iterator) readFloatValue(i int) error {
   513  	if err := it.customFields[i].floatEncAndIter.ReadFloat(it.stream); err != nil {
   514  		return err
   515  	}
   516  
   517  	updateArg := updateLastIterArg{i: i}
   518  	return it.updateMarshallerWithCustomValues(updateArg)
   519  }
   520  
   521  func (it *iterator) readBytesValue(i int, customField customFieldState) error {
   522  	bytesChangedControlBit, err := it.stream.ReadBit()
   523  	if err != nil {
   524  		return fmt.Errorf(
   525  			"%s: error trying to read bytes changed control bit: %v",
   526  			itErrPrefix, err)
   527  	}
   528  
   529  	if bytesChangedControlBit == opCodeNoChange {
   530  		// No changes to the bytes value.
   531  		lastValueBytesDict, err := it.lastValueBytesDict(i)
   532  		if err != nil {
   533  			return err
   534  		}
   535  		updateArg := updateLastIterArg{i: i, bytesFieldBuf: lastValueBytesDict}
   536  		return it.updateMarshallerWithCustomValues(updateArg)
   537  	}
   538  
   539  	// Bytes have changed since the previous value.
   540  	valueInDictControlBit, err := it.stream.ReadBit()
   541  	if err != nil {
   542  		return fmt.Errorf(
   543  			"%s error trying to read bytes changed control bit: %v",
   544  			itErrPrefix, err)
   545  	}
   546  
   547  	if valueInDictControlBit == opCodeInterpretSubsequentBitsAsLRUIndex {
   548  		dictIdxBits, err := it.stream.ReadBits(
   549  			uint8(numBitsRequiredForNumUpToN(it.byteFieldDictLRUSize)))
   550  		if err != nil {
   551  			return fmt.Errorf(
   552  				"%s error trying to read bytes dict idx: %v",
   553  				itErrPrefix, err)
   554  		}
   555  
   556  		dictIdx := int(dictIdxBits)
   557  		if dictIdx >= len(customField.iteratorBytesFieldDict) || dictIdx < 0 {
   558  			return fmt.Errorf(
   559  				"%s read bytes field dictionary index: %d, but dictionary is size: %d",
   560  				itErrPrefix, dictIdx, len(customField.iteratorBytesFieldDict))
   561  		}
   562  
   563  		bytesVal := customField.iteratorBytesFieldDict[dictIdx]
   564  		it.moveToEndOfBytesDict(i, dictIdx)
   565  
   566  		updateArg := updateLastIterArg{i: i, bytesFieldBuf: bytesVal}
   567  		return it.updateMarshallerWithCustomValues(updateArg)
   568  	}
   569  
   570  	// New value that was not in the dict already.
   571  	bytesLen, err := it.readVarInt()
   572  	if err != nil {
   573  		return fmt.Errorf(
   574  			"%s error trying to read bytes length: %v", itErrPrefix, err)
   575  	}
   576  
   577  	if err := it.skipToNextByte(); err != nil {
   578  		return fmt.Errorf(
   579  			"%s error trying to skip bytes value bit padding: %v",
   580  			itErrPrefix, err)
   581  	}
   582  
   583  	// Reuse the byte slice that is about to be evicted (if any) to read into instead of
   584  	// allocating if possible.
   585  	buf := it.nextToBeEvicted(i)
   586  	if cap(buf) < int(bytesLen) {
   587  		buf = make([]byte, bytesLen)
   588  	}
   589  	buf = buf[:bytesLen]
   590  
   591  	n, err := it.stream.Read(buf)
   592  	if err != nil {
   593  		return fmt.Errorf(
   594  			"%s error trying to read byte in readBytes: %v",
   595  			itErrPrefix, err)
   596  	}
   597  	if bytesLen != uint64(n) {
   598  		return fmt.Errorf(
   599  			"%s tried to read %d bytes but only read: %d", itErrPrefix, bytesLen, n)
   600  	}
   601  
   602  	it.addToBytesDict(i, buf)
   603  
   604  	updateArg := updateLastIterArg{i: i, bytesFieldBuf: buf}
   605  	return it.updateMarshallerWithCustomValues(updateArg)
   606  }
   607  
   608  func (it *iterator) readIntValue(i int) error {
   609  	if err := it.customFields[i].intEncAndIter.readIntValue(it.stream); err != nil {
   610  		return err
   611  	}
   612  
   613  	updateArg := updateLastIterArg{i: i}
   614  	return it.updateMarshallerWithCustomValues(updateArg)
   615  }
   616  
   617  func (it *iterator) readBoolValue(i int) error {
   618  	boolOpCode, err := it.stream.ReadBit()
   619  	if err != nil {
   620  		return fmt.Errorf(
   621  			"%s: error trying to read bool value: %v",
   622  			itErrPrefix, err)
   623  	}
   624  
   625  	boolVal := boolOpCode == opCodeBoolTrue
   626  	updateArg := updateLastIterArg{i: i, boolVal: boolVal}
   627  	return it.updateMarshallerWithCustomValues(updateArg)
   628  }
   629  
// updateLastIterArg bundles the arguments of updateMarshallerWithCustomValues.
type updateLastIterArg struct {
	// i is the index of the field in iterator.customFields.
	i int
	// bytesFieldBuf is the value to encode when the field is a bytes field.
	bytesFieldBuf []byte
	// boolVal is the value to encode when the field is a bool field.
	boolVal bool
}
   635  
   636  // updateMarshallerWithCustomValues updates the marshalled stream with the current
   637  // value of the custom field at index i. This ensures that marshalled protobuf stream
   638  // returned by Current() contains the most recent value for all of the custom fields.
   639  func (it *iterator) updateMarshallerWithCustomValues(arg updateLastIterArg) error {
   640  	var (
   641  		fieldNum       = int32(it.customFields[arg.i].fieldNum)
   642  		fieldType      = it.customFields[arg.i].fieldType
   643  		protoFieldType = it.customFields[arg.i].protoFieldType
   644  	)
   645  
   646  	if protoFieldType == protoFieldTypeNotFound {
   647  		// This can happen when the field being decoded does not exist (or is reserved)
   648  		// in the current schema, but the message was encoded with a schema in which the
   649  		// field number did exist.
   650  		return nil
   651  	}
   652  
   653  	switch {
   654  	case isCustomFloatEncodedField(fieldType):
   655  		var (
   656  			val = math.Float64frombits(it.customFields[arg.i].floatEncAndIter.PrevFloatBits)
   657  			err error
   658  		)
   659  		if fieldType == float64Field {
   660  			it.marshaller.encFloat64(fieldNum, val)
   661  		} else {
   662  			it.marshaller.encFloat32(fieldNum, float32(val))
   663  		}
   664  		return err
   665  
   666  	case isCustomIntEncodedField(fieldType):
   667  		switch fieldType {
   668  		case signedInt64Field:
   669  			val := int64(it.customFields[arg.i].intEncAndIter.prevIntBits)
   670  			if protoFieldType == dpb.FieldDescriptorProto_TYPE_SINT64 {
   671  				// The encoding / compression schema in this package treats Protobuf int32 and sint32 the same,
   672  				// however, Protobuf unmarshallers assume that fields of type sint are zigzag encoded. As a result,
   673  				// the iterator needs to check the fields protobuf type so that it can perform the correct encoding.
   674  				it.marshaller.encSInt64(fieldNum, val)
   675  			} else if protoFieldType == dpb.FieldDescriptorProto_TYPE_SFIXED64 {
   676  				it.marshaller.encSFixedInt64(fieldNum, val)
   677  			} else {
   678  				it.marshaller.encInt64(fieldNum, val)
   679  			}
   680  			return nil
   681  
   682  		case unsignedInt64Field:
   683  			val := it.customFields[arg.i].intEncAndIter.prevIntBits
   684  			it.marshaller.encUInt64(fieldNum, val)
   685  			return nil
   686  
   687  		case signedInt32Field:
   688  			var (
   689  				val   = int32(it.customFields[arg.i].intEncAndIter.prevIntBits)
   690  				field = it.schema.FindFieldByNumber(fieldNum)
   691  			)
   692  			if field == nil {
   693  				return fmt.Errorf(
   694  					"updating last iterated with value, could not find field number %d in schema", fieldNum)
   695  			}
   696  
   697  			fieldType := field.GetType()
   698  			if fieldType == dpb.FieldDescriptorProto_TYPE_SINT32 {
   699  				// The encoding / compression schema in this package treats Protobuf int32 and sint32 the same,
   700  				// however, Protobuf unmarshallers assume that fields of type sint are zigzag encoded. As a result,
   701  				// the iterator needs to check the fields protobuf type so that it can perform the correct encoding.
   702  				it.marshaller.encSInt32(fieldNum, val)
   703  			} else if fieldType == dpb.FieldDescriptorProto_TYPE_SFIXED32 {
   704  				it.marshaller.encSFixedInt32(fieldNum, val)
   705  			} else {
   706  				it.marshaller.encInt32(fieldNum, val)
   707  			}
   708  			return nil
   709  
   710  		case unsignedInt32Field:
   711  			val := uint32(it.customFields[arg.i].intEncAndIter.prevIntBits)
   712  			it.marshaller.encUInt32(fieldNum, val)
   713  			return nil
   714  
   715  		default:
   716  			return fmt.Errorf(
   717  				"%s expected custom int encoded field but field type was: %v",
   718  				itErrPrefix, fieldType)
   719  		}
   720  
   721  	case fieldType == bytesField:
   722  		it.marshaller.encBytes(fieldNum, arg.bytesFieldBuf)
   723  		return nil
   724  
   725  	case fieldType == boolField:
   726  		it.marshaller.encBool(fieldNum, arg.boolVal)
   727  		return nil
   728  
   729  	default:
   730  		return fmt.Errorf(
   731  			"%s unhandled fieldType: %v", itErrPrefix, fieldType)
   732  	}
   733  }
   734  
   735  // readBitset does the inverse of encodeBitset on the encoder struct.
   736  func (it *iterator) readBitset() error {
   737  	it.bitsetValues = it.bitsetValues[:0]
   738  	bitsetLengthBits, err := it.readVarInt()
   739  	if err != nil {
   740  		return err
   741  	}
   742  
   743  	for i := uint64(0); i < bitsetLengthBits; i++ {
   744  		bit, err := it.stream.ReadBit()
   745  		if err != nil {
   746  			return fmt.Errorf("%s error reading bitset: %v", itErrPrefix, err)
   747  		}
   748  
   749  		if bit == opCodeBitsetValueIsSet {
   750  			// Add 1 because protobuf fields are 1-indexed not 0-indexed.
   751  			it.bitsetValues = append(it.bitsetValues, int(i)+1)
   752  		}
   753  	}
   754  
   755  	return nil
   756  }
   757  
   758  func (it *iterator) readVarInt() (uint64, error) {
   759  	var (
   760  		// Convert array to slice and reset size to zero so
   761  		// we can reuse the buffer.
   762  		buf      = it.varIntBuf[:0]
   763  		numBytes = 0
   764  	)
   765  	for {
   766  		b, err := it.stream.ReadByte()
   767  		if err != nil {
   768  			return 0, fmt.Errorf("%s error reading var int: %v", itErrPrefix, err)
   769  		}
   770  
   771  		buf = append(buf, b)
   772  		numBytes++
   773  
   774  		if b>>7 == 0 {
   775  			break
   776  		}
   777  	}
   778  
   779  	buf = buf[:numBytes]
   780  	varInt, _ := binary.Uvarint(buf)
   781  	return varInt, nil
   782  }
   783  
   784  // skipToNextByte will skip over any remaining bits in the current byte
   785  // to reach the next byte. This is used in situations where the stream
   786  // has padding bits to keep portions of data aligned at the byte boundary.
   787  func (it *iterator) skipToNextByte() error {
   788  	remainingBitsInByte := it.stream.RemainingBitsInCurrentByte()
   789  	for remainingBitsInByte > 0 {
   790  		_, err := it.stream.ReadBit()
   791  		if err != nil {
   792  			return err
   793  		}
   794  		remainingBitsInByte--
   795  	}
   796  
   797  	return nil
   798  }
   799  
   800  func (it *iterator) moveToEndOfBytesDict(fieldIdx, i int) {
   801  	existing := it.customFields[fieldIdx].iteratorBytesFieldDict
   802  	for j := i; j < len(existing); j++ {
   803  		nextIdx := j + 1
   804  		if nextIdx >= len(existing) {
   805  			break
   806  		}
   807  
   808  		currVal := existing[j]
   809  		nextVal := existing[nextIdx]
   810  		existing[j] = nextVal
   811  		existing[nextIdx] = currVal
   812  	}
   813  }
   814  
   815  func (it *iterator) addToBytesDict(fieldIdx int, b []byte) {
   816  	existing := it.customFields[fieldIdx].iteratorBytesFieldDict
   817  	if len(existing) < it.byteFieldDictLRUSize {
   818  		it.customFields[fieldIdx].iteratorBytesFieldDict = append(existing, b)
   819  		return
   820  	}
   821  
   822  	// Shift everything down 1 and replace the last value to evict the
   823  	// least recently used entry and add the newest one.
   824  	//     [1,2,3]
   825  	// becomes
   826  	//     [2,3,3]
   827  	// after shift, and then becomes
   828  	//     [2,3,4]
   829  	// after replacing the last value.
   830  	for i := range existing {
   831  		nextIdx := i + 1
   832  		if nextIdx >= len(existing) {
   833  			break
   834  		}
   835  
   836  		existing[i] = existing[nextIdx]
   837  	}
   838  
   839  	existing[len(existing)-1] = b
   840  }
   841  
   842  func (it *iterator) lastValueBytesDict(fieldIdx int) ([]byte, error) {
   843  	dict := it.customFields[fieldIdx].iteratorBytesFieldDict
   844  	if len(dict) == 0 {
   845  		return nil, fmt.Errorf("tried to read last value of bytes dictionary for empty dictionary")
   846  	}
   847  	return dict[len(dict)-1], nil
   848  }
   849  
   850  func (it *iterator) nextToBeEvicted(fieldIdx int) []byte {
   851  	dict := it.customFields[fieldIdx].iteratorBytesFieldDict
   852  	if len(dict) == 0 {
   853  		return nil
   854  	}
   855  
   856  	if len(dict) < it.byteFieldDictLRUSize {
   857  		// Next add won't trigger an eviction.
   858  		return nil
   859  	}
   860  
   861  	return dict[0]
   862  }
   863  
   864  func (it *iterator) resetUnmarshalProtoBuffer(n int) {
   865  	if it.unmarshalProtoBuf != nil && it.unmarshalProtoBuf.Cap() >= n {
   866  		// If the existing one is big enough, just resize it.
   867  		it.unmarshalProtoBuf.Resize(n)
   868  		return
   869  	}
   870  
   871  	if it.unmarshalProtoBuf != nil {
   872  		// If one exists, but its too small, return it to the pool.
   873  		it.unmarshalProtoBuf.DecRef()
   874  		it.unmarshalProtoBuf.Finalize()
   875  	}
   876  
   877  	// If none exists (or one existed but it was too small) get a new one
   878  	// and IncRef(). DecRef() will never be called unless this one is
   879  	// replaced by a new one later.
   880  	it.unmarshalProtoBuf = it.newBuffer(n)
   881  	it.unmarshalProtoBuf.IncRef()
   882  	it.unmarshalProtoBuf.Resize(n)
   883  }
   884  
   885  func (it *iterator) hasNext() bool {
   886  	return !it.hasError() && !it.isDone() && !it.isClosed()
   887  }
   888  
// hasError reports whether an error has been recorded on the iterator.
func (it *iterator) hasError() bool {
	return it.err != nil
}
   892  
// isDone reports whether the iterator's done flag is set.
func (it *iterator) isDone() bool {
	return it.done
}
   896  
// isClosed reports whether the iterator's closed flag is set.
func (it *iterator) isClosed() bool {
	return it.closed
}
   900  
   901  func (it *iterator) newBuffer(capacity int) checked.Bytes {
   902  	if bytesPool := it.opts.BytesPool(); bytesPool != nil {
   903  		return bytesPool.Get(capacity)
   904  	}
   905  	return checked.NewBytes(make([]byte, 0, capacity), nil)
   906  }