github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/binary/iterator.go (about) 1 package binary 2 3 import ( 4 "fmt" 5 "github.com/batchcorp/thrift-iterator/protocol" 6 "github.com/batchcorp/thrift-iterator/spi" 7 "io" 8 "math" 9 ) 10 11 type Iterator struct { 12 spi.ValDecoderProvider 13 reader io.Reader 14 tmp []byte 15 preread []byte 16 skipped []byte 17 err error 18 } 19 20 func NewIterator(provider spi.ValDecoderProvider, reader io.Reader, buf []byte) *Iterator { 21 return &Iterator{ 22 ValDecoderProvider: provider, 23 reader: reader, 24 tmp: make([]byte, 8), 25 preread: buf, 26 } 27 } 28 29 func (iter *Iterator) readByte() byte { 30 tmp := iter.tmp[:1] 31 if len(iter.preread) > 0 { 32 tmp[0] = iter.preread[0] 33 iter.preread = iter.preread[1:] 34 } else { 35 _, err := iter.reader.Read(tmp) 36 if err != nil { 37 iter.ReportError("read", err.Error()) 38 return 0 39 } 40 } 41 if iter.skipped != nil { 42 iter.skipped = append(iter.skipped, tmp[0]) 43 } 44 return tmp[0] 45 } 46 47 func (iter *Iterator) readSmall(nBytes int) []byte { 48 tmp := iter.tmp[:nBytes] 49 wantBytes := nBytes 50 if len(iter.preread) > 0 { 51 if len(iter.preread) > nBytes { 52 copy(tmp, iter.preread[:nBytes]) 53 iter.preread = iter.preread[nBytes:] 54 wantBytes = 0 55 } else { 56 prelength := len(iter.preread) 57 copy(tmp[:prelength], iter.preread) 58 wantBytes -= prelength 59 iter.preread = nil 60 } 61 } 62 if wantBytes > 0 { 63 _, err := io.ReadFull(iter.reader, tmp[nBytes-wantBytes:nBytes]) 64 if err != nil { 65 for i := 0; i < len(tmp); i++ { 66 tmp[i] = 0 67 } 68 iter.ReportError("read", err.Error()) 69 return tmp 70 } 71 } 72 if iter.skipped != nil { 73 iter.skipped = append(iter.skipped, tmp...) 74 } 75 return tmp 76 } 77 78 func (iter *Iterator) readLarge(nBytes int) []byte { 79 // allocate new buffer if not enough 80 if len(iter.tmp) < nBytes { 81 iter.tmp = make([]byte, nBytes) 82 } 83 return iter.readSmall(nBytes) 84 } 85 86 func (iter *Iterator) Spawn() spi.Iterator { 87 return NewIterator(iter.ValDecoderProvider, nil, nil) 88 } 89 90 func (iter *Iterator) Error() error { 91 return iter.err 92 } 93 94 func (iter *Iterator) ReportError(operation string, err string) { 95 if iter.err == nil { 96 iter.err = fmt.Errorf("%s: %s", operation, err) 97 } 98 } 99 100 func (iter *Iterator) Reset(reader io.Reader, buf []byte) { 101 iter.reader = reader 102 iter.preread = buf 103 iter.err = nil 104 } 105 106 func (iter *Iterator) ReadMessageHeader() protocol.MessageHeader { 107 versionAndMessageType := iter.ReadInt32() 108 messageType := protocol.TMessageType(versionAndMessageType & 0x0ff) 109 version := int64(int64(versionAndMessageType) & 0xffff0000) 110 if version != protocol.BINARY_VERSION_1 { 111 iter.ReportError("ReadMessageHeader", "unexpected version") 112 return protocol.MessageHeader{} 113 } 114 messageName := iter.ReadString() 115 seqId := protocol.SeqId(iter.ReadInt32()) 116 return protocol.MessageHeader{ 117 MessageName: messageName, 118 MessageType: messageType, 119 SeqId: seqId, 120 } 121 } 122 123 func (iter *Iterator) ReadStructHeader() { 124 // noop 125 } 126 127 func (iter *Iterator) ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) { 128 firstByte := iter.readByte() 129 fieldType = protocol.TType(firstByte) 130 if fieldType == protocol.TypeStop { 131 return protocol.TypeStop, 0 132 } 133 fieldId = protocol.FieldId(iter.ReadUint16()) 134 return fieldType, fieldId 135 } 136 137 func (iter *Iterator) ReadListHeader() (elemType protocol.TType, size int) { 138 b := iter.readSmall(5) 139 elemType = protocol.TType(b[0]) 140 size = int(uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24) 141 return elemType, size 142 } 143 144 func (iter *Iterator) ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) { 145 b := iter.readSmall(6) 146 keyType = protocol.TType(b[0]) 147 elemType = protocol.TType(b[1]) 148 size = int(uint32(b[5]) | uint32(b[4])<<8 | uint32(b[3])<<16 | uint32(b[2])<<24) 149 return keyType, elemType, size 150 } 151 152 func (iter *Iterator) ReadBool() bool { 153 return iter.ReadUint8() == 1 154 } 155 156 func (iter *Iterator) ReadInt() int { 157 return int(iter.ReadInt64()) 158 } 159 160 func (iter *Iterator) ReadUint() uint { 161 return uint(iter.ReadUint64()) 162 } 163 164 func (iter *Iterator) ReadInt8() int8 { 165 return int8(iter.ReadUint8()) 166 } 167 168 func (iter *Iterator) ReadUint8() uint8 { 169 return iter.readByte() 170 } 171 172 func (iter *Iterator) ReadInt16() int16 { 173 return int16(iter.ReadUint16()) 174 } 175 176 func (iter *Iterator) ReadUint16() uint16 { 177 b := iter.readSmall(2) 178 return uint16(b[1]) | uint16(b[0])<<8 179 } 180 181 func (iter *Iterator) ReadInt32() int32 { 182 return int32(iter.ReadUint32()) 183 } 184 185 func (iter *Iterator) ReadUint32() uint32 { 186 b := iter.readSmall(4) 187 return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 188 } 189 190 func (iter *Iterator) ReadInt64() int64 { 191 return int64(iter.ReadUint64()) 192 } 193 194 func (iter *Iterator) ReadUint64() uint64 { 195 b := iter.readSmall(8) 196 return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | 197 uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 198 } 199 200 func (iter *Iterator) ReadFloat64() float64 { 201 return math.Float64frombits(iter.ReadUint64()) 202 } 203 204 func (iter *Iterator) ReadString() string { 205 length := iter.ReadUint32() 206 return string(iter.readLarge(int(length))) 207 } 208 209 func (iter *Iterator) ReadBinary() []byte { 210 length := iter.ReadUint32() 211 tmp := make([]byte, length) 212 copy(tmp, iter.readLarge(int(length))) 213 return tmp 214 }