github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/compact/iterator.go (about) 1 package compact 2 3 import ( 4 "encoding/binary" 5 "fmt" 6 "github.com/batchcorp/thrift-iterator/protocol" 7 "github.com/batchcorp/thrift-iterator/spi" 8 "io" 9 "math" 10 ) 11 12 type Iterator struct { 13 spi.ValDecoderProvider 14 reader io.Reader 15 tmp []byte 16 preread []byte 17 skipped []byte 18 19 err error 20 fieldIdStack []protocol.FieldId 21 lastFieldId protocol.FieldId 22 pendingBoolField uint8 23 } 24 25 func NewIterator(provider spi.ValDecoderProvider, reader io.Reader, buf []byte) *Iterator { 26 return &Iterator{ 27 ValDecoderProvider: provider, 28 reader: reader, 29 tmp: make([]byte, 8), 30 preread: buf, 31 } 32 } 33 34 func (iter *Iterator) readByte() byte { 35 tmp := iter.tmp[:1] 36 if len(iter.preread) > 0 { 37 tmp[0] = iter.preread[0] 38 iter.preread = iter.preread[1:] 39 } else { 40 _, err := iter.reader.Read(tmp) 41 if err != nil { 42 iter.ReportError("read", err.Error()) 43 return 0 44 } 45 } 46 if iter.skipped != nil { 47 iter.skipped = append(iter.skipped, tmp[0]) 48 } 49 return tmp[0] 50 } 51 52 func (iter *Iterator) readSmall(nBytes int) []byte { 53 tmp := iter.tmp[:nBytes] 54 wantBytes := nBytes 55 if len(iter.preread) > 0 { 56 if len(iter.preread) > nBytes { 57 copy(tmp, iter.preread[:nBytes]) 58 iter.preread = iter.preread[nBytes:] 59 wantBytes = 0 60 } else { 61 prelength := len(iter.preread) 62 copy(tmp[:prelength], iter.preread) 63 wantBytes -= prelength 64 iter.preread = nil 65 } 66 } 67 if wantBytes > 0 { 68 _, err := io.ReadFull(iter.reader, tmp[nBytes-wantBytes:nBytes]) 69 if err != nil { 70 for i := 0; i < len(tmp); i++ { 71 tmp[i] = 0 72 } 73 iter.ReportError("read", err.Error()) 74 return tmp 75 } 76 } 77 if iter.skipped != nil { 78 iter.skipped = append(iter.skipped, tmp...) 79 } 80 return tmp 81 } 82 83 func (iter *Iterator) readLarge(nBytes int) []byte { 84 // allocate new buffer if not enough 85 if len(iter.tmp) < nBytes { 86 iter.tmp = make([]byte, nBytes) 87 } 88 return iter.readSmall(nBytes) 89 } 90 91 func (iter *Iterator) readVarInt32() int32 { 92 return int32(iter.readVarInt64()) 93 } 94 95 func (iter *Iterator) readVarInt64() int64 { 96 shift := uint(0) 97 result := int64(0) 98 for { 99 b := iter.readByte() 100 if iter.err != nil { 101 return 0 102 } 103 result |= int64(b&0x7f) << shift 104 if (b & 0x80) != 0x80 { 105 break 106 } 107 shift += 7 108 } 109 return result 110 } 111 112 func (iter *Iterator) Spawn() spi.Iterator { 113 return NewIterator(iter.ValDecoderProvider, nil, nil) 114 } 115 116 func (iter *Iterator) Error() error { 117 return iter.err 118 } 119 120 func (iter *Iterator) ReportError(operation string, err string) { 121 if iter.err == nil { 122 iter.err = fmt.Errorf("%s: %s", operation, err) 123 } 124 } 125 126 func (iter *Iterator) Reset(reader io.Reader, buf []byte) { 127 iter.reader = reader 128 iter.preread = buf 129 iter.err = nil 130 } 131 132 func (iter *Iterator) ReadMessageHeader() protocol.MessageHeader { 133 protocolId := iter.readByte() 134 if protocolId != protocol.COMPACT_PROTOCOL_ID { 135 iter.ReportError("ReadMessageHeader", "invalid protocol") 136 return protocol.MessageHeader{} 137 } 138 versionAndType := iter.readByte() 139 version := versionAndType & protocol.COMPACT_VERSION_MASK 140 messageType := protocol.TMessageType((versionAndType >> 5) & 0x07) 141 if version != protocol.COMPACT_VERSION { 142 iter.ReportError("ReadMessageHeader", fmt.Sprintf("expected version %02x but got %02x", protocol.COMPACT_VERSION, version)) 143 return protocol.MessageHeader{} 144 } 145 seqId := protocol.SeqId(iter.readVarInt32()) 146 messageName := iter.ReadString() 147 return protocol.MessageHeader{ 148 MessageName: messageName, 149 MessageType: messageType, 150 SeqId: seqId, 151 } 152 } 153 154 func (iter *Iterator) ReadStructHeader() { 155 iter.fieldIdStack = append(iter.fieldIdStack, iter.lastFieldId) 156 iter.lastFieldId = 0 157 } 158 159 func (iter *Iterator) ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) { 160 firstByte := iter.readByte() 161 if firstByte == 0 { 162 if iter.Error() != nil { 163 return protocol.TypeStop, 0 164 } 165 iter.lastFieldId = iter.fieldIdStack[len(iter.fieldIdStack)-1] 166 iter.fieldIdStack = iter.fieldIdStack[:len(iter.fieldIdStack)-1] 167 iter.pendingBoolField = 0 168 return protocol.TType(firstByte), 0 169 } 170 // mask off the 4 MSB of the type header. it could contain a field id delta. 171 modifier := int16((firstByte & 0xf0) >> 4) 172 if modifier == 0 { 173 // not a delta, look ahead for the zigzag varint field id. 174 fieldId = protocol.FieldId(iter.ReadInt16()) 175 } else { 176 // has a delta. add the delta to the last read field id. 177 fieldId = iter.lastFieldId + protocol.FieldId(modifier) 178 } 179 switch tType := TCompactType(firstByte & 0x0f); tType { 180 case TypeBooleanTrue: 181 fieldType = protocol.TypeBool 182 iter.pendingBoolField = 1 183 case TypeBooleanFalse: 184 fieldType = protocol.TypeBool 185 iter.pendingBoolField = 2 186 default: 187 fieldType = tType.ToTType() 188 iter.pendingBoolField = 0 189 } 190 191 // push the new field onto the field stack so we can keep the deltas going. 192 iter.lastFieldId = fieldId 193 return fieldType, fieldId 194 } 195 196 func (iter *Iterator) ReadListHeader() (elemType protocol.TType, size int) { 197 lenAndType := iter.readByte() 198 length := int((lenAndType >> 4) & 0x0f) 199 if length == 15 { 200 length2 := iter.readVarInt32() 201 if length2 < 0 { 202 iter.ReportError("ReadListHeader", "invalid length") 203 return protocol.TypeStop, 0 204 } 205 length = int(length2) 206 } 207 elemType = TCompactType(lenAndType).ToTType() 208 return elemType, length 209 } 210 211 func (iter *Iterator) ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) { 212 length := int(iter.readVarInt32()) 213 if length == 0 { 214 return protocol.TypeStop, protocol.TypeStop, length 215 } 216 keyAndElemType := iter.readByte() 217 keyType = TCompactType(keyAndElemType >> 4).ToTType() 218 elemType = TCompactType(keyAndElemType & 0xf).ToTType() 219 return keyType, elemType, length 220 } 221 222 func (iter *Iterator) ReadBool() bool { 223 if iter.pendingBoolField == 0 { 224 return iter.ReadUint8() == 1 225 } 226 return iter.pendingBoolField == 1 227 } 228 229 func (iter *Iterator) ReadInt() int { 230 return int(iter.ReadInt64()) 231 } 232 233 func (iter *Iterator) ReadUint() uint { 234 return uint(iter.ReadInt64()) 235 } 236 237 func (iter *Iterator) ReadInt8() int8 { 238 return int8(iter.ReadUint8()) 239 } 240 241 func (iter *Iterator) ReadUint8() uint8 { 242 return iter.readByte() 243 } 244 245 func (iter *Iterator) ReadInt16() int16 { 246 return int16(iter.ReadInt32()) 247 } 248 249 func (iter *Iterator) ReadUint16() uint16 { 250 return uint16(iter.ReadUint32()) 251 } 252 253 func (iter *Iterator) ReadInt32() int32 { 254 result := iter.readVarInt32() 255 u := uint32(result) 256 return int32(u>>1) ^ -(result & 1) 257 } 258 259 func (iter *Iterator) ReadUint32() uint32 { 260 return uint32(iter.ReadInt32()) 261 } 262 263 func (iter *Iterator) ReadInt64() int64 { 264 result := iter.readVarInt64() 265 u := uint64(result) 266 return int64(u>>1) ^ -(result & 1) 267 } 268 269 func (iter *Iterator) ReadUint64() uint64 { 270 return uint64(iter.ReadInt64()) 271 } 272 273 func (iter *Iterator) ReadFloat64() float64 { 274 tmp := iter.readSmall(8) 275 return math.Float64frombits(binary.LittleEndian.Uint64(tmp)) 276 } 277 278 func (iter *Iterator) ReadString() string { 279 length := iter.readVarInt32() 280 return string(iter.readLarge(int(length))) 281 } 282 283 func (iter *Iterator) ReadBinary() []byte { 284 length := iter.readVarInt32() 285 tmp := make([]byte, length) 286 copy(tmp, iter.readLarge(int(length))) 287 return tmp 288 }