github.com/batchcorp/thrift-iterator@v0.0.0-20220918180557-4c4a158fc6e9/protocol/compact/stream.go (about)

     1  package compact
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/batchcorp/thrift-iterator/protocol"
     6  	"github.com/batchcorp/thrift-iterator/spi"
     7  	"io"
     8  	"math"
     9  )
    10  
    11  type Stream struct {
    12  	spi.ValEncoderProvider
    13  	writer           io.Writer
    14  	buf              []byte
    15  	err              error
    16  	fieldIdStack     []protocol.FieldId
    17  	lastFieldId      protocol.FieldId
    18  	pendingBoolField protocol.FieldId
    19  }
    20  
    21  func NewStream(provider spi.ValEncoderProvider, writer io.Writer, buf []byte) *Stream {
    22  	return &Stream{
    23  		ValEncoderProvider: provider,
    24  		writer:             writer,
    25  		buf:                buf,
    26  		pendingBoolField:   -1,
    27  	}
    28  }
    29  
    30  func (stream *Stream) Spawn() spi.Stream {
    31  	return &Stream{
    32  		ValEncoderProvider: stream.ValEncoderProvider,
    33  	}
    34  }
    35  
    36  func (stream *Stream) Error() error {
    37  	return stream.err
    38  }
    39  
    40  func (stream *Stream) ReportError(operation string, err string) {
    41  	if stream.err == nil {
    42  		stream.err = fmt.Errorf("%s: %s", operation, err)
    43  	}
    44  }
    45  
    46  func (stream *Stream) Buffer() []byte {
    47  	return stream.buf
    48  }
    49  
    50  func (stream *Stream) Reset(writer io.Writer) {
    51  	stream.writer = writer
    52  	stream.err = nil
    53  	stream.buf = stream.buf[:0]
    54  }
    55  
    56  func (stream *Stream) Flush() {
    57  	if stream.writer == nil {
    58  		return
    59  	}
    60  	_, err := stream.writer.Write(stream.buf)
    61  	if err != nil {
    62  		stream.ReportError("Flush", err.Error())
    63  		return
    64  	}
    65  	if f, ok := stream.writer.(protocol.Flusher); ok {
    66  		if err = f.Flush(); err != nil {
    67  			stream.ReportError("Flush", err.Error())
    68  		}
    69  	}
    70  	stream.buf = stream.buf[:0]
    71  }
    72  
    73  func (stream *Stream) Write(buf []byte) error {
    74  	stream.buf = append(stream.buf, buf...)
    75  	stream.Flush()
    76  	return stream.Error()
    77  }
    78  
    79  func (stream *Stream) WriteMessageHeader(header protocol.MessageHeader) {
    80  	stream.buf = append(stream.buf, protocol.COMPACT_PROTOCOL_ID)
    81  	stream.buf = append(stream.buf, (protocol.COMPACT_VERSION&protocol.COMPACT_VERSION_MASK)|((byte(header.MessageType)<<5)&0x0E0))
    82  	stream.writeVarInt32(int32(header.SeqId))
    83  	stream.WriteString(header.MessageName)
    84  }
    85  
    86  func (stream *Stream) WriteListHeader(elemType protocol.TType, length int) {
    87  	if length <= 14 {
    88  		stream.WriteUint8(uint8(int32(length<<4) | int32(compactTypes[elemType])))
    89  		return
    90  	}
    91  	stream.WriteUint8(0xf0 | uint8(compactTypes[elemType]))
    92  	stream.writeVarInt32(int32(length))
    93  }
    94  
    95  func (stream *Stream) WriteStructHeader() {
    96  	stream.fieldIdStack = append(stream.fieldIdStack, stream.lastFieldId)
    97  	stream.lastFieldId = 0
    98  }
    99  
   100  func (stream *Stream) WriteStructField(fieldType protocol.TType, fieldId protocol.FieldId) {
   101  	if fieldType == protocol.TypeBool {
   102  		stream.pendingBoolField = fieldId
   103  		return
   104  	}
   105  	compactType := uint8(compactTypes[fieldType])
   106  	// check if we can use delta encoding for the field id
   107  	if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 {
   108  		stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | compactType)
   109  	} else {
   110  		stream.WriteUint8(compactType)
   111  		stream.WriteInt16(int16(fieldId))
   112  	}
   113  	stream.lastFieldId = fieldId
   114  }
   115  
   116  func (stream *Stream) WriteStructFieldStop() {
   117  	stream.buf = append(stream.buf, byte(TypeStop))
   118  	stream.lastFieldId = stream.fieldIdStack[len(stream.fieldIdStack)-1]
   119  	stream.fieldIdStack = stream.fieldIdStack[:len(stream.fieldIdStack)-1]
   120  	stream.pendingBoolField = -1
   121  }
   122  
   123  func (stream *Stream) WriteMapHeader(keyType protocol.TType, elemType protocol.TType, length int) {
   124  	if length == 0 {
   125  		stream.WriteUint8(0)
   126  		return
   127  	}
   128  	stream.writeVarInt32(int32(length))
   129  	stream.WriteUint8(uint8(compactTypes[keyType]<<4 | TCompactType(compactTypes[elemType])))
   130  }
   131  
   132  func (stream *Stream) WriteBool(val bool) {
   133  	if stream.pendingBoolField == -1 {
   134  		if val {
   135  			stream.WriteUint8(1)
   136  		} else {
   137  			stream.WriteUint8(0)
   138  		}
   139  		return
   140  	}
   141  	var compactType TCompactType
   142  	if val {
   143  		compactType = TypeBooleanTrue
   144  	} else {
   145  		compactType = TypeBooleanFalse
   146  	}
   147  	fieldId := stream.pendingBoolField
   148  	// check if we can use delta encoding for the field id
   149  	if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 {
   150  		stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | uint8(compactType))
   151  	} else {
   152  		stream.WriteUint8(uint8(compactType))
   153  		stream.WriteInt16(int16(fieldId))
   154  	}
   155  	stream.lastFieldId = fieldId
   156  	stream.pendingBoolField = -1
   157  }
   158  
   159  func (stream *Stream) WriteInt8(val int8) {
   160  	stream.WriteUint8(uint8(val))
   161  }
   162  
   163  func (stream *Stream) WriteUint8(val uint8) {
   164  	stream.buf = append(stream.buf, byte(val))
   165  }
   166  
   167  func (stream *Stream) WriteInt16(val int16) {
   168  	stream.WriteInt32(int32(val))
   169  }
   170  
   171  func (stream *Stream) WriteUint16(val uint16) {
   172  	stream.WriteInt32(int32(val))
   173  }
   174  
   175  func (stream *Stream) WriteInt32(val int32) {
   176  	stream.writeVarInt32((val << 1) ^ (val >> 31))
   177  }
   178  
   179  func (stream *Stream) WriteUint32(val uint32) {
   180  	stream.WriteInt32(int32(val))
   181  }
   182  
   183  // Write an i32 as a varint. Results in 1-5 bytes on the wire.
   184  func (stream *Stream) writeVarInt32(n int32) {
   185  	for {
   186  		if (n & ^0x7F) == 0 {
   187  			stream.buf = append(stream.buf, byte(n))
   188  			break
   189  		} else {
   190  			stream.buf = append(stream.buf, byte((n&0x7F)|0x80))
   191  			u := uint64(n)
   192  			n = int32(u >> 7)
   193  		}
   194  	}
   195  }
   196  
   197  func (stream *Stream) WriteInt64(val int64) {
   198  	stream.writeVarInt64((val << 1) ^ (val >> 63))
   199  }
   200  
   201  // Write an i64 as a varint. Results in 1-10 bytes on the wire.
   202  func (stream *Stream) writeVarInt64(n int64) {
   203  	for {
   204  		if (n & ^0x7F) == 0 {
   205  			stream.buf = append(stream.buf, byte(n))
   206  			break
   207  		} else {
   208  			stream.buf = append(stream.buf, byte((n&0x7F)|0x80))
   209  			u := uint64(n)
   210  			n = int64(u >> 7)
   211  		}
   212  	}
   213  }
   214  
   215  func (stream *Stream) WriteUint64(val uint64) {
   216  	stream.WriteInt64(int64(val))
   217  }
   218  
   219  func (stream *Stream) WriteInt(val int) {
   220  	stream.WriteInt64(int64(val))
   221  }
   222  
   223  func (stream *Stream) WriteUint(val uint) {
   224  	stream.WriteUint64(uint64(val))
   225  }
   226  
   227  func (stream *Stream) WriteFloat64(val float64) {
   228  	bits := math.Float64bits(val)
   229  	stream.buf = append(stream.buf,
   230  		byte(bits),
   231  		byte(bits>>8),
   232  		byte(bits>>16),
   233  		byte(bits>>24),
   234  		byte(bits>>32),
   235  		byte(bits>>40),
   236  		byte(bits>>48),
   237  		byte(bits>>56),
   238  	)
   239  }
   240  
   241  func (stream *Stream) WriteBinary(val []byte) {
   242  	stream.writeVarInt32(int32(len(val)))
   243  	stream.buf = append(stream.buf, val...)
   244  }
   245  
   246  func (stream *Stream) WriteString(val string) {
   247  	stream.writeVarInt32(int32(len(val)))
   248  	stream.buf = append(stream.buf, val...)
   249  }