github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/binary/iterator.go (about)

     1  package binary
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/batchcorp/thrift-iterator/protocol"
     6  	"github.com/batchcorp/thrift-iterator/spi"
     7  	"io"
     8  	"math"
     9  )
    10  
    11  type Iterator struct {
    12  	spi.ValDecoderProvider
    13  	reader  io.Reader
    14  	tmp     []byte
    15  	preread []byte
    16  	skipped []byte
    17  	err     error
    18  }
    19  
    20  func NewIterator(provider spi.ValDecoderProvider, reader io.Reader, buf []byte) *Iterator {
    21  	return &Iterator{
    22  		ValDecoderProvider: provider,
    23  		reader:             reader,
    24  		tmp:                make([]byte, 8),
    25  		preread:            buf,
    26  	}
    27  }
    28  
    29  func (iter *Iterator) readByte() byte {
    30  	tmp := iter.tmp[:1]
    31  	if len(iter.preread) > 0 {
    32  		tmp[0] = iter.preread[0]
    33  		iter.preread = iter.preread[1:]
    34  	} else {
    35  		_, err := iter.reader.Read(tmp)
    36  		if err != nil {
    37  			iter.ReportError("read", err.Error())
    38  			return 0
    39  		}
    40  	}
    41  	if iter.skipped != nil {
    42  		iter.skipped = append(iter.skipped, tmp[0])
    43  	}
    44  	return tmp[0]
    45  }
    46  
    47  func (iter *Iterator) readSmall(nBytes int) []byte {
    48  	tmp := iter.tmp[:nBytes]
    49  	wantBytes := nBytes
    50  	if len(iter.preread) > 0 {
    51  		if len(iter.preread) > nBytes {
    52  			copy(tmp, iter.preread[:nBytes])
    53  			iter.preread = iter.preread[nBytes:]
    54  			wantBytes = 0
    55  		} else {
    56  			prelength := len(iter.preread)
    57  			copy(tmp[:prelength], iter.preread)
    58  			wantBytes -= prelength
    59  			iter.preread = nil
    60  		}
    61  	}
    62  	if wantBytes > 0 {
    63  		_, err := io.ReadFull(iter.reader, tmp[nBytes-wantBytes:nBytes])
    64  		if err != nil {
    65  			for i := 0; i < len(tmp); i++ {
    66  				tmp[i] = 0
    67  			}
    68  			iter.ReportError("read", err.Error())
    69  			return tmp
    70  		}
    71  	}
    72  	if iter.skipped != nil {
    73  		iter.skipped = append(iter.skipped, tmp...)
    74  	}
    75  	return tmp
    76  }
    77  
    78  func (iter *Iterator) readLarge(nBytes int) []byte {
    79  	// allocate new buffer if not enough
    80  	if len(iter.tmp) < nBytes {
    81  		iter.tmp = make([]byte, nBytes)
    82  	}
    83  	return iter.readSmall(nBytes)
    84  }
    85  
    86  func (iter *Iterator) Spawn() spi.Iterator {
    87  	return NewIterator(iter.ValDecoderProvider, nil, nil)
    88  }
    89  
    90  func (iter *Iterator) Error() error {
    91  	return iter.err
    92  }
    93  
    94  func (iter *Iterator) ReportError(operation string, err string) {
    95  	if iter.err == nil {
    96  		iter.err = fmt.Errorf("%s: %s", operation, err)
    97  	}
    98  }
    99  
   100  func (iter *Iterator) Reset(reader io.Reader, buf []byte) {
   101  	iter.reader = reader
   102  	iter.preread = buf
   103  	iter.err = nil
   104  }
   105  
   106  func (iter *Iterator) ReadMessageHeader() protocol.MessageHeader {
   107  	versionAndMessageType := iter.ReadInt32()
   108  	messageType := protocol.TMessageType(versionAndMessageType & 0x0ff)
   109  	version := int64(int64(versionAndMessageType) & 0xffff0000)
   110  	if version != protocol.BINARY_VERSION_1 {
   111  		iter.ReportError("ReadMessageHeader", "unexpected version")
   112  		return protocol.MessageHeader{}
   113  	}
   114  	messageName := iter.ReadString()
   115  	seqId := protocol.SeqId(iter.ReadInt32())
   116  	return protocol.MessageHeader{
   117  		MessageName: messageName,
   118  		MessageType: messageType,
   119  		SeqId:       seqId,
   120  	}
   121  }
   122  
   123  func (iter *Iterator) ReadStructHeader() {
   124  	// noop
   125  }
   126  
   127  func (iter *Iterator) ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) {
   128  	firstByte := iter.readByte()
   129  	fieldType = protocol.TType(firstByte)
   130  	if fieldType == protocol.TypeStop {
   131  		return protocol.TypeStop, 0
   132  	}
   133  	fieldId = protocol.FieldId(iter.ReadUint16())
   134  	return fieldType, fieldId
   135  }
   136  
   137  func (iter *Iterator) ReadListHeader() (elemType protocol.TType, size int) {
   138  	b := iter.readSmall(5)
   139  	elemType = protocol.TType(b[0])
   140  	size = int(uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24)
   141  	return elemType, size
   142  }
   143  
   144  func (iter *Iterator) ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) {
   145  	b := iter.readSmall(6)
   146  	keyType = protocol.TType(b[0])
   147  	elemType = protocol.TType(b[1])
   148  	size = int(uint32(b[5]) | uint32(b[4])<<8 | uint32(b[3])<<16 | uint32(b[2])<<24)
   149  	return keyType, elemType, size
   150  }
   151  
   152  func (iter *Iterator) ReadBool() bool {
   153  	return iter.ReadUint8() == 1
   154  }
   155  
   156  func (iter *Iterator) ReadInt() int {
   157  	return int(iter.ReadInt64())
   158  }
   159  
   160  func (iter *Iterator) ReadUint() uint {
   161  	return uint(iter.ReadUint64())
   162  }
   163  
   164  func (iter *Iterator) ReadInt8() int8 {
   165  	return int8(iter.ReadUint8())
   166  }
   167  
   168  func (iter *Iterator) ReadUint8() uint8 {
   169  	return iter.readByte()
   170  }
   171  
   172  func (iter *Iterator) ReadInt16() int16 {
   173  	return int16(iter.ReadUint16())
   174  }
   175  
   176  func (iter *Iterator) ReadUint16() uint16 {
   177  	b := iter.readSmall(2)
   178  	return uint16(b[1]) | uint16(b[0])<<8
   179  }
   180  
   181  func (iter *Iterator) ReadInt32() int32 {
   182  	return int32(iter.ReadUint32())
   183  }
   184  
   185  func (iter *Iterator) ReadUint32() uint32 {
   186  	b := iter.readSmall(4)
   187  	return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
   188  }
   189  
   190  func (iter *Iterator) ReadInt64() int64 {
   191  	return int64(iter.ReadUint64())
   192  }
   193  
   194  func (iter *Iterator) ReadUint64() uint64 {
   195  	b := iter.readSmall(8)
   196  	return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
   197  		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
   198  }
   199  
   200  func (iter *Iterator) ReadFloat64() float64 {
   201  	return math.Float64frombits(iter.ReadUint64())
   202  }
   203  
   204  func (iter *Iterator) ReadString() string {
   205  	length := iter.ReadUint32()
   206  	return string(iter.readLarge(int(length)))
   207  }
   208  
   209  func (iter *Iterator) ReadBinary() []byte {
   210  	length := iter.ReadUint32()
   211  	tmp := make([]byte, length)
   212  	copy(tmp, iter.readLarge(int(length)))
   213  	return tmp
   214  }