github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/compact/iterator.go (about)

     1  package compact
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"github.com/batchcorp/thrift-iterator/protocol"
     7  	"github.com/batchcorp/thrift-iterator/spi"
     8  	"io"
     9  	"math"
    10  )
    11  
    12  type Iterator struct {
    13  	spi.ValDecoderProvider
    14  	reader  io.Reader
    15  	tmp     []byte
    16  	preread []byte
    17  	skipped []byte
    18  
    19  	err              error
    20  	fieldIdStack     []protocol.FieldId
    21  	lastFieldId      protocol.FieldId
    22  	pendingBoolField uint8
    23  }
    24  
    25  func NewIterator(provider spi.ValDecoderProvider, reader io.Reader, buf []byte) *Iterator {
    26  	return &Iterator{
    27  		ValDecoderProvider: provider,
    28  		reader:             reader,
    29  		tmp:                make([]byte, 8),
    30  		preread:            buf,
    31  	}
    32  }
    33  
    34  func (iter *Iterator) readByte() byte {
    35  	tmp := iter.tmp[:1]
    36  	if len(iter.preread) > 0 {
    37  		tmp[0] = iter.preread[0]
    38  		iter.preread = iter.preread[1:]
    39  	} else {
    40  		_, err := iter.reader.Read(tmp)
    41  		if err != nil {
    42  			iter.ReportError("read", err.Error())
    43  			return 0
    44  		}
    45  	}
    46  	if iter.skipped != nil {
    47  		iter.skipped = append(iter.skipped, tmp[0])
    48  	}
    49  	return tmp[0]
    50  }
    51  
    52  func (iter *Iterator) readSmall(nBytes int) []byte {
    53  	tmp := iter.tmp[:nBytes]
    54  	wantBytes := nBytes
    55  	if len(iter.preread) > 0 {
    56  		if len(iter.preread) > nBytes {
    57  			copy(tmp, iter.preread[:nBytes])
    58  			iter.preread = iter.preread[nBytes:]
    59  			wantBytes = 0
    60  		} else {
    61  			prelength := len(iter.preread)
    62  			copy(tmp[:prelength], iter.preread)
    63  			wantBytes -= prelength
    64  			iter.preread = nil
    65  		}
    66  	}
    67  	if wantBytes > 0 {
    68  		_, err := io.ReadFull(iter.reader, tmp[nBytes-wantBytes:nBytes])
    69  		if err != nil {
    70  			for i := 0; i < len(tmp); i++ {
    71  				tmp[i] = 0
    72  			}
    73  			iter.ReportError("read", err.Error())
    74  			return tmp
    75  		}
    76  	}
    77  	if iter.skipped != nil {
    78  		iter.skipped = append(iter.skipped, tmp...)
    79  	}
    80  	return tmp
    81  }
    82  
    83  func (iter *Iterator) readLarge(nBytes int) []byte {
    84  	// allocate new buffer if not enough
    85  	if len(iter.tmp) < nBytes {
    86  		iter.tmp = make([]byte, nBytes)
    87  	}
    88  	return iter.readSmall(nBytes)
    89  }
    90  
    91  func (iter *Iterator) readVarInt32() int32 {
    92  	return int32(iter.readVarInt64())
    93  }
    94  
    95  func (iter *Iterator) readVarInt64() int64 {
    96  	shift := uint(0)
    97  	result := int64(0)
    98  	for {
    99  		b := iter.readByte()
   100  		if iter.err != nil {
   101  			return 0
   102  		}
   103  		result |= int64(b&0x7f) << shift
   104  		if (b & 0x80) != 0x80 {
   105  			break
   106  		}
   107  		shift += 7
   108  	}
   109  	return result
   110  }
   111  
   112  func (iter *Iterator) Spawn() spi.Iterator {
   113  	return NewIterator(iter.ValDecoderProvider, nil, nil)
   114  }
   115  
   116  func (iter *Iterator) Error() error {
   117  	return iter.err
   118  }
   119  
   120  func (iter *Iterator) ReportError(operation string, err string) {
   121  	if iter.err == nil {
   122  		iter.err = fmt.Errorf("%s: %s", operation, err)
   123  	}
   124  }
   125  
   126  func (iter *Iterator) Reset(reader io.Reader, buf []byte) {
   127  	iter.reader = reader
   128  	iter.preread = buf
   129  	iter.err = nil
   130  }
   131  
   132  func (iter *Iterator) ReadMessageHeader() protocol.MessageHeader {
   133  	protocolId := iter.readByte()
   134  	if protocolId != protocol.COMPACT_PROTOCOL_ID {
   135  		iter.ReportError("ReadMessageHeader", "invalid protocol")
   136  		return protocol.MessageHeader{}
   137  	}
   138  	versionAndType := iter.readByte()
   139  	version := versionAndType & protocol.COMPACT_VERSION_MASK
   140  	messageType := protocol.TMessageType((versionAndType >> 5) & 0x07)
   141  	if version != protocol.COMPACT_VERSION {
   142  		iter.ReportError("ReadMessageHeader", fmt.Sprintf("expected version %02x but got %02x", protocol.COMPACT_VERSION, version))
   143  		return protocol.MessageHeader{}
   144  	}
   145  	seqId := protocol.SeqId(iter.readVarInt32())
   146  	messageName := iter.ReadString()
   147  	return protocol.MessageHeader{
   148  		MessageName: messageName,
   149  		MessageType: messageType,
   150  		SeqId:       seqId,
   151  	}
   152  }
   153  
   154  func (iter *Iterator) ReadStructHeader() {
   155  	iter.fieldIdStack = append(iter.fieldIdStack, iter.lastFieldId)
   156  	iter.lastFieldId = 0
   157  }
   158  
   159  func (iter *Iterator) ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) {
   160  	firstByte := iter.readByte()
   161  	if firstByte == 0 {
   162  		if iter.Error() != nil {
   163  			return protocol.TypeStop, 0
   164  		}
   165  		iter.lastFieldId = iter.fieldIdStack[len(iter.fieldIdStack)-1]
   166  		iter.fieldIdStack = iter.fieldIdStack[:len(iter.fieldIdStack)-1]
   167  		iter.pendingBoolField = 0
   168  		return protocol.TType(firstByte), 0
   169  	}
   170  	// mask off the 4 MSB of the type header. it could contain a field id delta.
   171  	modifier := int16((firstByte & 0xf0) >> 4)
   172  	if modifier == 0 {
   173  		// not a delta, look ahead for the zigzag varint field id.
   174  		fieldId = protocol.FieldId(iter.ReadInt16())
   175  	} else {
   176  		// has a delta. add the delta to the last read field id.
   177  		fieldId = iter.lastFieldId + protocol.FieldId(modifier)
   178  	}
   179  	switch tType := TCompactType(firstByte & 0x0f); tType {
   180  	case TypeBooleanTrue:
   181  		fieldType = protocol.TypeBool
   182  		iter.pendingBoolField = 1
   183  	case TypeBooleanFalse:
   184  		fieldType = protocol.TypeBool
   185  		iter.pendingBoolField = 2
   186  	default:
   187  		fieldType = tType.ToTType()
   188  		iter.pendingBoolField = 0
   189  	}
   190  
   191  	// push the new field onto the field stack so we can keep the deltas going.
   192  	iter.lastFieldId = fieldId
   193  	return fieldType, fieldId
   194  }
   195  
   196  func (iter *Iterator) ReadListHeader() (elemType protocol.TType, size int) {
   197  	lenAndType := iter.readByte()
   198  	length := int((lenAndType >> 4) & 0x0f)
   199  	if length == 15 {
   200  		length2 := iter.readVarInt32()
   201  		if length2 < 0 {
   202  			iter.ReportError("ReadListHeader", "invalid length")
   203  			return protocol.TypeStop, 0
   204  		}
   205  		length = int(length2)
   206  	}
   207  	elemType = TCompactType(lenAndType).ToTType()
   208  	return elemType, length
   209  }
   210  
   211  func (iter *Iterator) ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) {
   212  	length := int(iter.readVarInt32())
   213  	if length == 0 {
   214  		return protocol.TypeStop, protocol.TypeStop, length
   215  	}
   216  	keyAndElemType := iter.readByte()
   217  	keyType = TCompactType(keyAndElemType >> 4).ToTType()
   218  	elemType = TCompactType(keyAndElemType & 0xf).ToTType()
   219  	return keyType, elemType, length
   220  }
   221  
   222  func (iter *Iterator) ReadBool() bool {
   223  	if iter.pendingBoolField == 0 {
   224  		return iter.ReadUint8() == 1
   225  	}
   226  	return iter.pendingBoolField == 1
   227  }
   228  
   229  func (iter *Iterator) ReadInt() int {
   230  	return int(iter.ReadInt64())
   231  }
   232  
   233  func (iter *Iterator) ReadUint() uint {
   234  	return uint(iter.ReadInt64())
   235  }
   236  
   237  func (iter *Iterator) ReadInt8() int8 {
   238  	return int8(iter.ReadUint8())
   239  }
   240  
   241  func (iter *Iterator) ReadUint8() uint8 {
   242  	return iter.readByte()
   243  }
   244  
   245  func (iter *Iterator) ReadInt16() int16 {
   246  	return int16(iter.ReadInt32())
   247  }
   248  
   249  func (iter *Iterator) ReadUint16() uint16 {
   250  	return uint16(iter.ReadUint32())
   251  }
   252  
   253  func (iter *Iterator) ReadInt32() int32 {
   254  	result := iter.readVarInt32()
   255  	u := uint32(result)
   256  	return int32(u>>1) ^ -(result & 1)
   257  }
   258  
   259  func (iter *Iterator) ReadUint32() uint32 {
   260  	return uint32(iter.ReadInt32())
   261  }
   262  
   263  func (iter *Iterator) ReadInt64() int64 {
   264  	result := iter.readVarInt64()
   265  	u := uint64(result)
   266  	return int64(u>>1) ^ -(result & 1)
   267  }
   268  
   269  func (iter *Iterator) ReadUint64() uint64 {
   270  	return uint64(iter.ReadInt64())
   271  }
   272  
   273  func (iter *Iterator) ReadFloat64() float64 {
   274  	tmp := iter.readSmall(8)
   275  	return math.Float64frombits(binary.LittleEndian.Uint64(tmp))
   276  }
   277  
   278  func (iter *Iterator) ReadString() string {
   279  	length := iter.readVarInt32()
   280  	return string(iter.readLarge(int(length)))
   281  }
   282  
   283  func (iter *Iterator) ReadBinary() []byte {
   284  	length := iter.readVarInt32()
   285  	tmp := make([]byte, length)
   286  	copy(tmp, iter.readLarge(int(length)))
   287  	return tmp
   288  }