github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/compact/stream.go (about) 1 package compact 2 3 import ( 4 "fmt" 5 "github.com/batchcorp/thrift-iterator/protocol" 6 "github.com/batchcorp/thrift-iterator/spi" 7 "io" 8 "math" 9 ) 10 11 type Stream struct { 12 spi.ValEncoderProvider 13 writer io.Writer 14 buf []byte 15 err error 16 fieldIdStack []protocol.FieldId 17 lastFieldId protocol.FieldId 18 pendingBoolField protocol.FieldId 19 } 20 21 func NewStream(provider spi.ValEncoderProvider, writer io.Writer, buf []byte) *Stream { 22 return &Stream{ 23 ValEncoderProvider: provider, 24 writer: writer, 25 buf: buf, 26 pendingBoolField: -1, 27 } 28 } 29 30 func (stream *Stream) Spawn() spi.Stream { 31 return &Stream{ 32 ValEncoderProvider: stream.ValEncoderProvider, 33 } 34 } 35 36 func (stream *Stream) Error() error { 37 return stream.err 38 } 39 40 func (stream *Stream) ReportError(operation string, err string) { 41 if stream.err == nil { 42 stream.err = fmt.Errorf("%s: %s", operation, err) 43 } 44 } 45 46 func (stream *Stream) Buffer() []byte { 47 return stream.buf 48 } 49 50 func (stream *Stream) Reset(writer io.Writer) { 51 stream.writer = writer 52 stream.err = nil 53 stream.buf = stream.buf[:0] 54 } 55 56 func (stream *Stream) Flush() { 57 if stream.writer == nil { 58 return 59 } 60 _, err := stream.writer.Write(stream.buf) 61 if err != nil { 62 stream.ReportError("Flush", err.Error()) 63 return 64 } 65 if f, ok := stream.writer.(protocol.Flusher); ok { 66 if err = f.Flush(); err != nil { 67 stream.ReportError("Flush", err.Error()) 68 } 69 } 70 stream.buf = stream.buf[:0] 71 } 72 73 func (stream *Stream) Write(buf []byte) error { 74 stream.buf = append(stream.buf, buf...) 75 stream.Flush() 76 return stream.Error() 77 } 78 79 func (stream *Stream) WriteMessageHeader(header protocol.MessageHeader) { 80 stream.buf = append(stream.buf, protocol.COMPACT_PROTOCOL_ID) 81 stream.buf = append(stream.buf, (protocol.COMPACT_VERSION&protocol.COMPACT_VERSION_MASK)|((byte(header.MessageType)<<5)&0x0E0)) 82 stream.writeVarInt32(int32(header.SeqId)) 83 stream.WriteString(header.MessageName) 84 } 85 86 func (stream *Stream) WriteListHeader(elemType protocol.TType, length int) { 87 if length <= 14 { 88 stream.WriteUint8(uint8(int32(length<<4) | int32(compactTypes[elemType]))) 89 return 90 } 91 stream.WriteUint8(0xf0 | uint8(compactTypes[elemType])) 92 stream.writeVarInt32(int32(length)) 93 } 94 95 func (stream *Stream) WriteStructHeader() { 96 stream.fieldIdStack = append(stream.fieldIdStack, stream.lastFieldId) 97 stream.lastFieldId = 0 98 } 99 100 func (stream *Stream) WriteStructField(fieldType protocol.TType, fieldId protocol.FieldId) { 101 if fieldType == protocol.TypeBool { 102 stream.pendingBoolField = fieldId 103 return 104 } 105 compactType := uint8(compactTypes[fieldType]) 106 // check if we can use delta encoding for the field id 107 if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 { 108 stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | compactType) 109 } else { 110 stream.WriteUint8(compactType) 111 stream.WriteInt16(int16(fieldId)) 112 } 113 stream.lastFieldId = fieldId 114 } 115 116 func (stream *Stream) WriteStructFieldStop() { 117 stream.buf = append(stream.buf, byte(TypeStop)) 118 stream.lastFieldId = stream.fieldIdStack[len(stream.fieldIdStack)-1] 119 stream.fieldIdStack = stream.fieldIdStack[:len(stream.fieldIdStack)-1] 120 stream.pendingBoolField = -1 121 } 122 123 func (stream *Stream) WriteMapHeader(keyType protocol.TType, elemType protocol.TType, length int) { 124 if length == 0 { 125 stream.WriteUint8(0) 126 return 127 } 128 stream.writeVarInt32(int32(length)) 129 stream.WriteUint8(uint8(compactTypes[keyType]<<4 | TCompactType(compactTypes[elemType]))) 130 } 131 132 func (stream *Stream) WriteBool(val bool) { 133 if stream.pendingBoolField == -1 { 134 if val { 135 stream.WriteUint8(1) 136 } else { 137 stream.WriteUint8(0) 138 } 139 return 140 } 141 var compactType TCompactType 142 if val { 143 compactType = TypeBooleanTrue 144 } else { 145 compactType = TypeBooleanFalse 146 } 147 fieldId := stream.pendingBoolField 148 // check if we can use delta encoding for the field id 149 if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 { 150 stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | uint8(compactType)) 151 } else { 152 stream.WriteUint8(uint8(compactType)) 153 stream.WriteInt16(int16(fieldId)) 154 } 155 stream.lastFieldId = fieldId 156 stream.pendingBoolField = -1 157 } 158 159 func (stream *Stream) WriteInt8(val int8) { 160 stream.WriteUint8(uint8(val)) 161 } 162 163 func (stream *Stream) WriteUint8(val uint8) { 164 stream.buf = append(stream.buf, byte(val)) 165 } 166 167 func (stream *Stream) WriteInt16(val int16) { 168 stream.WriteInt32(int32(val)) 169 } 170 171 func (stream *Stream) WriteUint16(val uint16) { 172 stream.WriteInt32(int32(val)) 173 } 174 175 func (stream *Stream) WriteInt32(val int32) { 176 stream.writeVarInt32((val << 1) ^ (val >> 31)) 177 } 178 179 func (stream *Stream) WriteUint32(val uint32) { 180 stream.WriteInt32(int32(val)) 181 } 182 183 // Write an i32 as a varint. Results in 1-5 bytes on the wire. 184 func (stream *Stream) writeVarInt32(n int32) { 185 for { 186 if (n & ^0x7F) == 0 { 187 stream.buf = append(stream.buf, byte(n)) 188 break 189 } else { 190 stream.buf = append(stream.buf, byte((n&0x7F)|0x80)) 191 u := uint64(n) 192 n = int32(u >> 7) 193 } 194 } 195 } 196 197 func (stream *Stream) WriteInt64(val int64) { 198 stream.writeVarInt64((val << 1) ^ (val >> 63)) 199 } 200 201 // Write an i64 as a varint. Results in 1-10 bytes on the wire. 202 func (stream *Stream) writeVarInt64(n int64) { 203 for { 204 if (n & ^0x7F) == 0 { 205 stream.buf = append(stream.buf, byte(n)) 206 break 207 } else { 208 stream.buf = append(stream.buf, byte((n&0x7F)|0x80)) 209 u := uint64(n) 210 n = int64(u >> 7) 211 } 212 } 213 } 214 215 func (stream *Stream) WriteUint64(val uint64) { 216 stream.WriteInt64(int64(val)) 217 } 218 219 func (stream *Stream) WriteInt(val int) { 220 stream.WriteInt64(int64(val)) 221 } 222 223 func (stream *Stream) WriteUint(val uint) { 224 stream.WriteUint64(uint64(val)) 225 } 226 227 func (stream *Stream) WriteFloat64(val float64) { 228 bits := math.Float64bits(val) 229 stream.buf = append(stream.buf, 230 byte(bits), 231 byte(bits>>8), 232 byte(bits>>16), 233 byte(bits>>24), 234 byte(bits>>32), 235 byte(bits>>40), 236 byte(bits>>48), 237 byte(bits>>56), 238 ) 239 } 240 241 func (stream *Stream) WriteBinary(val []byte) { 242 stream.writeVarInt32(int32(len(val))) 243 stream.buf = append(stream.buf, val...) 244 } 245 246 func (stream *Stream) WriteString(val string) { 247 stream.writeVarInt32(int32(len(val))) 248 stream.buf = append(stream.buf, val...) 249 }