github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/thrift/compact.go (about)

     1  package thrift
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  )
    11  
    12  // CompactProtocol is a Protocol implementation for the compact thrift protocol.
    13  //
    14  // https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md#integer-encoding
    15  type CompactProtocol struct{}
    16  
    17  func (p *CompactProtocol) NewReader(r io.Reader) Reader {
    18  	return &compactReader{protocol: p, binary: binaryReader{r: r}}
    19  }
    20  
    21  func (p *CompactProtocol) NewWriter(w io.Writer) Writer {
    22  	return &compactWriter{protocol: p, binary: binaryWriter{w: w}}
    23  }
    24  
    25  func (p *CompactProtocol) Features() Features {
    26  	return UseDeltaEncoding | CoalesceBoolFields
    27  }
    28  
    29  type compactReader struct {
    30  	protocol *CompactProtocol
    31  	binary   binaryReader
    32  }
    33  
    34  func (r *compactReader) Protocol() Protocol {
    35  	return r.protocol
    36  }
    37  
    38  func (r *compactReader) Reader() io.Reader {
    39  	return r.binary.Reader()
    40  }
    41  
    42  func (r *compactReader) ReadBool() (bool, error) {
    43  	return r.binary.ReadBool()
    44  }
    45  
    46  func (r *compactReader) ReadInt8() (int8, error) {
    47  	return r.binary.ReadInt8()
    48  }
    49  
    50  func (r *compactReader) ReadInt16() (int16, error) {
    51  	v, err := r.readVarint("int16", math.MinInt16, math.MaxInt16)
    52  	return int16(v), err
    53  }
    54  
    55  func (r *compactReader) ReadInt32() (int32, error) {
    56  	v, err := r.readVarint("int32", math.MinInt32, math.MaxInt32)
    57  	return int32(v), err
    58  }
    59  
    60  func (r *compactReader) ReadInt64() (int64, error) {
    61  	return r.readVarint("int64", math.MinInt64, math.MaxInt64)
    62  }
    63  
    64  func (r *compactReader) ReadFloat64() (float64, error) {
    65  	return r.binary.ReadFloat64()
    66  }
    67  
    68  func (r *compactReader) ReadBytes() ([]byte, error) {
    69  	n, err := r.ReadLength()
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	b := make([]byte, n)
    74  	_, err = io.ReadFull(r.Reader(), b)
    75  	return b, err
    76  }
    77  
    78  func (r *compactReader) ReadString() (string, error) {
    79  	b, err := r.ReadBytes()
    80  	return unsafeBytesToString(b), err
    81  }
    82  
    83  func (r *compactReader) ReadLength() (int, error) {
    84  	n, err := r.readUvarint("length", math.MaxInt32)
    85  	return int(n), err
    86  }
    87  
    88  func (r *compactReader) ReadMessage() (Message, error) {
    89  	m := Message{}
    90  
    91  	b0, err := r.ReadByte()
    92  	if err != nil {
    93  		return m, err
    94  	}
    95  	if b0 != 0x82 {
    96  		return m, fmt.Errorf("invalid protocol id found when reading thrift message: %#x", b0)
    97  	}
    98  
    99  	b1, err := r.ReadByte()
   100  	if err != nil {
   101  		return m, dontExpectEOF(err)
   102  	}
   103  
   104  	seqID, err := r.readUvarint("seq id", math.MaxInt32)
   105  	if err != nil {
   106  		return m, dontExpectEOF(err)
   107  	}
   108  
   109  	m.Type = MessageType(b1) & 0x7
   110  	m.SeqID = int32(seqID)
   111  	m.Name, err = r.ReadString()
   112  	return m, dontExpectEOF(err)
   113  }
   114  
   115  func (r *compactReader) ReadField() (Field, error) {
   116  	f := Field{}
   117  
   118  	b, err := r.ReadByte()
   119  	if err != nil {
   120  		return f, err
   121  	}
   122  
   123  	if Type(b) == STOP {
   124  		return f, nil
   125  	}
   126  
   127  	if (b >> 4) != 0 {
   128  		f = Field{ID: int16(b >> 4), Type: Type(b & 0xF), Delta: true}
   129  	} else {
   130  		i, err := r.ReadInt16()
   131  		if err != nil {
   132  			return f, dontExpectEOF(err)
   133  		}
   134  		f = Field{ID: i, Type: Type(b)}
   135  	}
   136  
   137  	return f, nil
   138  }
   139  
   140  func (r *compactReader) ReadList() (List, error) {
   141  	b, err := r.ReadByte()
   142  	if err != nil {
   143  		return List{}, err
   144  	}
   145  	if (b >> 4) != 0xF {
   146  		return List{Size: int32(b >> 4), Type: Type(b & 0xF)}, nil
   147  	}
   148  	n, err := r.readUvarint("list size", math.MaxInt32)
   149  	if err != nil {
   150  		return List{}, dontExpectEOF(err)
   151  	}
   152  	return List{Size: int32(n), Type: Type(b & 0xF)}, nil
   153  }
   154  
   155  func (r *compactReader) ReadSet() (Set, error) {
   156  	l, err := r.ReadList()
   157  	return Set(l), err
   158  }
   159  
   160  func (r *compactReader) ReadMap() (Map, error) {
   161  	n, err := r.readUvarint("map size", math.MaxInt32)
   162  	if err != nil {
   163  		return Map{}, err
   164  	}
   165  	if n == 0 { // empty map
   166  		return Map{}, nil
   167  	}
   168  	b, err := r.ReadByte()
   169  	if err != nil {
   170  		return Map{}, dontExpectEOF(err)
   171  	}
   172  	return Map{Size: int32(n), Key: Type(b >> 4), Value: Type(b & 0xF)}, nil
   173  }
   174  
   175  func (r *compactReader) ReadByte() (byte, error) {
   176  	return r.binary.ReadByte()
   177  }
   178  
   179  func (r *compactReader) readUvarint(typ string, max uint64) (uint64, error) {
   180  	var br io.ByteReader
   181  
   182  	switch x := r.Reader().(type) {
   183  	case *bytes.Buffer:
   184  		br = x
   185  	case *bytes.Reader:
   186  		br = x
   187  	case *bufio.Reader:
   188  		br = x
   189  	case io.ByteReader:
   190  		br = x
   191  	default:
   192  		br = &r.binary
   193  	}
   194  
   195  	u, err := binary.ReadUvarint(br)
   196  	if err == nil {
   197  		if u > max {
   198  			err = fmt.Errorf("%s varint out of range: %d > %d", typ, u, max)
   199  		}
   200  	}
   201  	return u, err
   202  }
   203  
   204  func (r *compactReader) readVarint(typ string, min, max int64) (int64, error) {
   205  	var br io.ByteReader
   206  
   207  	switch x := r.Reader().(type) {
   208  	case *bytes.Buffer:
   209  		br = x
   210  	case *bytes.Reader:
   211  		br = x
   212  	case *bufio.Reader:
   213  		br = x
   214  	case io.ByteReader:
   215  		br = x
   216  	default:
   217  		br = &r.binary
   218  	}
   219  
   220  	v, err := binary.ReadVarint(br)
   221  	if err == nil {
   222  		if v < min || v > max {
   223  			err = fmt.Errorf("%s varint out of range: %d not in [%d;%d]", typ, v, min, max)
   224  		}
   225  	}
   226  	return v, err
   227  }
   228  
   229  type compactWriter struct {
   230  	protocol *CompactProtocol
   231  	binary   binaryWriter
   232  	varint   [binary.MaxVarintLen64]byte
   233  }
   234  
   235  func (w *compactWriter) Protocol() Protocol {
   236  	return w.protocol
   237  }
   238  
   239  func (w *compactWriter) Writer() io.Writer {
   240  	return w.binary.Writer()
   241  }
   242  
   243  func (w *compactWriter) WriteBool(v bool) error {
   244  	return w.binary.WriteBool(v)
   245  }
   246  
   247  func (w *compactWriter) WriteInt8(v int8) error {
   248  	return w.binary.WriteInt8(v)
   249  }
   250  
   251  func (w *compactWriter) WriteInt16(v int16) error {
   252  	return w.writeVarint(int64(v))
   253  }
   254  
   255  func (w *compactWriter) WriteInt32(v int32) error {
   256  	return w.writeVarint(int64(v))
   257  }
   258  
   259  func (w *compactWriter) WriteInt64(v int64) error {
   260  	return w.writeVarint(v)
   261  }
   262  
   263  func (w *compactWriter) WriteFloat64(v float64) error {
   264  	return w.binary.WriteFloat64(v)
   265  }
   266  
   267  func (w *compactWriter) WriteBytes(v []byte) error {
   268  	if err := w.WriteLength(len(v)); err != nil {
   269  		return err
   270  	}
   271  	return w.binary.write(v)
   272  }
   273  
   274  func (w *compactWriter) WriteString(v string) error {
   275  	if err := w.WriteLength(len(v)); err != nil {
   276  		return err
   277  	}
   278  	return w.binary.writeString(v)
   279  }
   280  
   281  func (w *compactWriter) WriteLength(n int) error {
   282  	if n < 0 {
   283  		return fmt.Errorf("negative length cannot be encoded in thrift: %d", n)
   284  	}
   285  	if n > math.MaxInt32 {
   286  		return fmt.Errorf("length is too large to be encoded in thrift: %d", n)
   287  	}
   288  	return w.writeUvarint(uint64(n))
   289  }
   290  
   291  func (w *compactWriter) WriteMessage(m Message) error {
   292  	if err := w.binary.writeByte(0x82); err != nil {
   293  		return err
   294  	}
   295  	if err := w.binary.writeByte(byte(m.Type)); err != nil {
   296  		return err
   297  	}
   298  	if err := w.writeUvarint(uint64(m.SeqID)); err != nil {
   299  		return err
   300  	}
   301  	return w.WriteString(m.Name)
   302  }
   303  
   304  func (w *compactWriter) WriteField(f Field) error {
   305  	if f.Type == STOP {
   306  		return w.binary.writeByte(0)
   307  	}
   308  	if f.ID <= 15 {
   309  		return w.binary.writeByte(byte(f.ID<<4) | byte(f.Type))
   310  	}
   311  	if err := w.binary.writeByte(byte(f.Type)); err != nil {
   312  		return err
   313  	}
   314  	return w.WriteInt16(f.ID)
   315  }
   316  
   317  func (w *compactWriter) WriteList(l List) error {
   318  	if l.Size <= 14 {
   319  		return w.binary.writeByte(byte(l.Size<<4) | byte(l.Type))
   320  	}
   321  	if err := w.binary.writeByte(0xF0 | byte(l.Type)); err != nil {
   322  		return err
   323  	}
   324  	return w.writeUvarint(uint64(l.Size))
   325  }
   326  
   327  func (w *compactWriter) WriteSet(s Set) error {
   328  	return w.WriteList(List(s))
   329  }
   330  
   331  func (w *compactWriter) WriteMap(m Map) error {
   332  	if err := w.writeUvarint(uint64(m.Size)); err != nil || m.Size == 0 {
   333  		return err
   334  	}
   335  	return w.binary.writeByte((byte(m.Key) << 4) | byte(m.Value))
   336  }
   337  
   338  func (w *compactWriter) writeUvarint(v uint64) error {
   339  	n := binary.PutUvarint(w.varint[:], v)
   340  	return w.binary.write(w.varint[:n])
   341  }
   342  
   343  func (w *compactWriter) writeVarint(v int64) error {
   344  	n := binary.PutVarint(w.varint[:], v)
   345  	return w.binary.write(w.varint[:n])
   346  }