github.com/segmentio/encoding@v0.4.0/thrift/binary.go (about)

     1  package thrift
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  )
    11  
    12  // BinaryProtocol is a Protocol implementation for the binary thrift protocol.
    13  //
    14  // https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md
    15  type BinaryProtocol struct {
    16  	NonStrict bool
    17  }
    18  
    19  func (p *BinaryProtocol) NewReader(r io.Reader) Reader {
    20  	return &binaryReader{p: p, r: r}
    21  }
    22  
    23  func (p *BinaryProtocol) NewWriter(w io.Writer) Writer {
    24  	return &binaryWriter{p: p, w: w}
    25  }
    26  
    27  func (p *BinaryProtocol) Features() Features {
    28  	return 0
    29  }
    30  
    31  type binaryReader struct {
    32  	p *BinaryProtocol
    33  	r io.Reader
    34  	b [8]byte
    35  }
    36  
    37  func (r *binaryReader) Protocol() Protocol {
    38  	return r.p
    39  }
    40  
    41  func (r *binaryReader) Reader() io.Reader {
    42  	return r.r
    43  }
    44  
    45  func (r *binaryReader) ReadBool() (bool, error) {
    46  	v, err := r.ReadByte()
    47  	return v != 0, err
    48  }
    49  
    50  func (r *binaryReader) ReadInt8() (int8, error) {
    51  	b, err := r.ReadByte()
    52  	return int8(b), err
    53  }
    54  
    55  func (r *binaryReader) ReadInt16() (int16, error) {
    56  	b, err := r.read(2)
    57  	if len(b) < 2 {
    58  		return 0, err
    59  	}
    60  	return int16(binary.BigEndian.Uint16(b)), nil
    61  }
    62  
    63  func (r *binaryReader) ReadInt32() (int32, error) {
    64  	b, err := r.read(4)
    65  	if len(b) < 4 {
    66  		return 0, err
    67  	}
    68  	return int32(binary.BigEndian.Uint32(b)), nil
    69  }
    70  
    71  func (r *binaryReader) ReadInt64() (int64, error) {
    72  	b, err := r.read(8)
    73  	if len(b) < 8 {
    74  		return 0, err
    75  	}
    76  	return int64(binary.BigEndian.Uint64(b)), nil
    77  }
    78  
    79  func (r *binaryReader) ReadFloat64() (float64, error) {
    80  	b, err := r.read(8)
    81  	if len(b) < 8 {
    82  		return 0, err
    83  	}
    84  	return math.Float64frombits(binary.BigEndian.Uint64(b)), nil
    85  }
    86  
    87  func (r *binaryReader) ReadBytes() ([]byte, error) {
    88  	n, err := r.ReadLength()
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	b := make([]byte, n)
    93  	_, err = io.ReadFull(r.r, b)
    94  	return b, err
    95  }
    96  
    97  func (r *binaryReader) ReadString() (string, error) {
    98  	b, err := r.ReadBytes()
    99  	return unsafeBytesToString(b), err
   100  }
   101  
   102  func (r *binaryReader) ReadLength() (int, error) {
   103  	b, err := r.read(4)
   104  	if len(b) < 4 {
   105  		return 0, err
   106  	}
   107  	n := binary.BigEndian.Uint32(b)
   108  	if n > math.MaxInt32 {
   109  		return 0, fmt.Errorf("length out of range: %d", n)
   110  	}
   111  	return int(n), nil
   112  }
   113  
   114  func (r *binaryReader) ReadMessage() (Message, error) {
   115  	m := Message{}
   116  
   117  	b, err := r.read(4)
   118  	if len(b) < 4 {
   119  		return m, err
   120  	}
   121  
   122  	if (b[0] >> 7) == 0 { // non-strict
   123  		n := int(binary.BigEndian.Uint32(b))
   124  		s := make([]byte, n)
   125  		_, err := io.ReadFull(r.r, s)
   126  		if err != nil {
   127  			return m, dontExpectEOF(err)
   128  		}
   129  		m.Name = unsafeBytesToString(s)
   130  
   131  		t, err := r.ReadInt8()
   132  		if err != nil {
   133  			return m, dontExpectEOF(err)
   134  		}
   135  
   136  		m.Type = MessageType(t & 0x7)
   137  	} else {
   138  		m.Type = MessageType(b[3] & 0x7)
   139  
   140  		if m.Name, err = r.ReadString(); err != nil {
   141  			return m, dontExpectEOF(err)
   142  		}
   143  	}
   144  
   145  	m.SeqID, err = r.ReadInt32()
   146  	return m, err
   147  }
   148  
   149  func (r *binaryReader) ReadField() (Field, error) {
   150  	t, err := r.ReadInt8()
   151  	if err != nil {
   152  		return Field{}, err
   153  	}
   154  	i, err := r.ReadInt16()
   155  	if err != nil {
   156  		return Field{}, err
   157  	}
   158  	return Field{ID: i, Type: Type(t)}, nil
   159  }
   160  
   161  func (r *binaryReader) ReadList() (List, error) {
   162  	t, err := r.ReadInt8()
   163  	if err != nil {
   164  		return List{}, err
   165  	}
   166  	n, err := r.ReadInt32()
   167  	if err != nil {
   168  		return List{}, dontExpectEOF(err)
   169  	}
   170  	return List{Size: n, Type: Type(t)}, nil
   171  }
   172  
   173  func (r *binaryReader) ReadSet() (Set, error) {
   174  	l, err := r.ReadList()
   175  	return Set(l), err
   176  }
   177  
   178  func (r *binaryReader) ReadMap() (Map, error) {
   179  	k, err := r.ReadByte()
   180  	if err != nil {
   181  		return Map{}, err
   182  	}
   183  	v, err := r.ReadByte()
   184  	if err != nil {
   185  		return Map{}, dontExpectEOF(err)
   186  	}
   187  	n, err := r.ReadInt32()
   188  	if err != nil {
   189  		return Map{}, dontExpectEOF(err)
   190  	}
   191  	return Map{Size: n, Key: Type(k), Value: Type(v)}, nil
   192  }
   193  
   194  func (r *binaryReader) ReadByte() (byte, error) {
   195  	switch x := r.r.(type) {
   196  	case *bytes.Buffer:
   197  		return x.ReadByte()
   198  	case *bytes.Reader:
   199  		return x.ReadByte()
   200  	case *bufio.Reader:
   201  		return x.ReadByte()
   202  	case io.ByteReader:
   203  		return x.ReadByte()
   204  	default:
   205  		b, err := r.read(1)
   206  		if err != nil {
   207  			return 0, err
   208  		}
   209  		return b[0], nil
   210  	}
   211  }
   212  
   213  func (r *binaryReader) read(n int) ([]byte, error) {
   214  	_, err := io.ReadFull(r.r, r.b[:n])
   215  	return r.b[:n], err
   216  }
   217  
   218  type binaryWriter struct {
   219  	p *BinaryProtocol
   220  	b [8]byte
   221  	w io.Writer
   222  }
   223  
   224  func (w *binaryWriter) Protocol() Protocol {
   225  	return w.p
   226  }
   227  
   228  func (w *binaryWriter) Writer() io.Writer {
   229  	return w.w
   230  }
   231  
   232  func (w *binaryWriter) WriteBool(v bool) error {
   233  	var b byte
   234  	if v {
   235  		b = 1
   236  	}
   237  	return w.writeByte(b)
   238  }
   239  
   240  func (w *binaryWriter) WriteInt8(v int8) error {
   241  	return w.writeByte(byte(v))
   242  }
   243  
   244  func (w *binaryWriter) WriteInt16(v int16) error {
   245  	binary.BigEndian.PutUint16(w.b[:2], uint16(v))
   246  	return w.write(w.b[:2])
   247  }
   248  
   249  func (w *binaryWriter) WriteInt32(v int32) error {
   250  	binary.BigEndian.PutUint32(w.b[:4], uint32(v))
   251  	return w.write(w.b[:4])
   252  }
   253  
   254  func (w *binaryWriter) WriteInt64(v int64) error {
   255  	binary.BigEndian.PutUint64(w.b[:8], uint64(v))
   256  	return w.write(w.b[:8])
   257  }
   258  
   259  func (w *binaryWriter) WriteFloat64(v float64) error {
   260  	binary.BigEndian.PutUint64(w.b[:8], math.Float64bits(v))
   261  	return w.write(w.b[:8])
   262  }
   263  
   264  func (w *binaryWriter) WriteBytes(v []byte) error {
   265  	if err := w.WriteLength(len(v)); err != nil {
   266  		return err
   267  	}
   268  	return w.write(v)
   269  }
   270  
   271  func (w *binaryWriter) WriteString(v string) error {
   272  	if err := w.WriteLength(len(v)); err != nil {
   273  		return err
   274  	}
   275  	return w.writeString(v)
   276  }
   277  
   278  func (w *binaryWriter) WriteLength(n int) error {
   279  	if n < 0 {
   280  		return fmt.Errorf("negative length cannot be encoded in thrift: %d", n)
   281  	}
   282  	if n > math.MaxInt32 {
   283  		return fmt.Errorf("length is too large to be encoded in thrift: %d", n)
   284  	}
   285  	return w.WriteInt32(int32(n))
   286  }
   287  
   288  func (w *binaryWriter) WriteMessage(m Message) error {
   289  	if w.p.NonStrict {
   290  		if err := w.WriteString(m.Name); err != nil {
   291  			return err
   292  		}
   293  		if err := w.writeByte(byte(m.Type)); err != nil {
   294  			return err
   295  		}
   296  	} else {
   297  		w.b[0] = 1 << 7
   298  		w.b[1] = 0
   299  		w.b[2] = 0
   300  		w.b[3] = byte(m.Type) & 0x7
   301  		binary.BigEndian.PutUint32(w.b[4:], uint32(len(m.Name)))
   302  
   303  		if err := w.write(w.b[:8]); err != nil {
   304  			return err
   305  		}
   306  		if err := w.writeString(m.Name); err != nil {
   307  			return err
   308  		}
   309  	}
   310  	return w.WriteInt32(m.SeqID)
   311  }
   312  
   313  func (w *binaryWriter) WriteField(f Field) error {
   314  	if err := w.writeByte(byte(f.Type)); err != nil {
   315  		return err
   316  	}
   317  	return w.WriteInt16(f.ID)
   318  }
   319  
   320  func (w *binaryWriter) WriteList(l List) error {
   321  	if err := w.writeByte(byte(l.Type)); err != nil {
   322  		return err
   323  	}
   324  	return w.WriteInt32(l.Size)
   325  }
   326  
   327  func (w *binaryWriter) WriteSet(s Set) error {
   328  	return w.WriteList(List(s))
   329  }
   330  
   331  func (w *binaryWriter) WriteMap(m Map) error {
   332  	if err := w.writeByte(byte(m.Key)); err != nil {
   333  		return err
   334  	}
   335  	if err := w.writeByte(byte(m.Value)); err != nil {
   336  		return err
   337  	}
   338  	return w.WriteInt32(m.Size)
   339  }
   340  
   341  func (w *binaryWriter) write(b []byte) error {
   342  	_, err := w.w.Write(b)
   343  	return err
   344  }
   345  
   346  func (w *binaryWriter) writeString(s string) error {
   347  	_, err := io.WriteString(w.w, s)
   348  	return err
   349  }
   350  
   351  func (w *binaryWriter) writeByte(b byte) error {
   352  	// The special cases are intended to reduce the runtime overheadof testing
   353  	// for the io.ByteWriter interface for common types. Type assertions on a
   354  	// concrete type is just a pointer comparison, instead of requiring a
   355  	// complex lookup in the type metadata.
   356  	switch x := w.w.(type) {
   357  	case *bytes.Buffer:
   358  		return x.WriteByte(b)
   359  	case *bufio.Writer:
   360  		return x.WriteByte(b)
   361  	case io.ByteWriter:
   362  		return x.WriteByte(b)
   363  	default:
   364  		w.b[0] = b
   365  		return w.write(w.b[:1])
   366  	}
   367  }