github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/binary_protocol.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"math"
    23  	"sync"
    24  
    25  	"github.com/apache/thrift/lib/go/thrift"
    26  
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    29  )
    30  
    31  // must be strict read & strict write
    32  var (
    33  	bpPool sync.Pool
    34  	_      thrift.TProtocol = (*BinaryProtocol)(nil)
    35  )
    36  
    37  func init() {
    38  	bpPool.New = newBP
    39  }
    40  
    41  func newBP() interface{} {
    42  	return &BinaryProtocol{}
    43  }
    44  
    45  // NewBinaryProtocol ...
    46  func NewBinaryProtocol(t remote.ByteBuffer) *BinaryProtocol {
    47  	bp := bpPool.Get().(*BinaryProtocol)
    48  	bp.trans = t
    49  	return bp
    50  }
    51  
    52  // BinaryProtocol ...
    53  type BinaryProtocol struct {
    54  	trans remote.ByteBuffer
    55  }
    56  
    57  // Recycle ...
    58  func (p *BinaryProtocol) Recycle() {
    59  	p.trans = nil
    60  	bpPool.Put(p)
    61  }
    62  
    63  /**
    64   * Writing Methods
    65   */
    66  
    67  // WriteMessageBegin ...
    68  func (p *BinaryProtocol) WriteMessageBegin(name string, typeID thrift.TMessageType, seqID int32) error {
    69  	version := uint32(thrift.VERSION_1) | uint32(typeID)
    70  	e := p.WriteI32(int32(version))
    71  	if e != nil {
    72  		return e
    73  	}
    74  	e = p.WriteString(name)
    75  	if e != nil {
    76  		return e
    77  	}
    78  	e = p.WriteI32(seqID)
    79  	return e
    80  }
    81  
    82  // WriteMessageEnd ...
    83  func (p *BinaryProtocol) WriteMessageEnd() error {
    84  	return nil
    85  }
    86  
    87  // WriteStructBegin ...
    88  func (p *BinaryProtocol) WriteStructBegin(name string) error {
    89  	return nil
    90  }
    91  
    92  // WriteStructEnd ...
    93  func (p *BinaryProtocol) WriteStructEnd() error {
    94  	return nil
    95  }
    96  
    97  // WriteFieldBegin ...
    98  func (p *BinaryProtocol) WriteFieldBegin(name string, typeID thrift.TType, id int16) error {
    99  	e := p.WriteByte(int8(typeID))
   100  	if e != nil {
   101  		return e
   102  	}
   103  	e = p.WriteI16(id)
   104  	return e
   105  }
   106  
   107  // WriteFieldEnd ...
   108  func (p *BinaryProtocol) WriteFieldEnd() error {
   109  	return nil
   110  }
   111  
   112  // WriteFieldStop ...
   113  func (p *BinaryProtocol) WriteFieldStop() error {
   114  	e := p.WriteByte(thrift.STOP)
   115  	return e
   116  }
   117  
   118  // WriteMapBegin ...
   119  func (p *BinaryProtocol) WriteMapBegin(keyType, valueType thrift.TType, size int) error {
   120  	e := p.WriteByte(int8(keyType))
   121  	if e != nil {
   122  		return e
   123  	}
   124  	e = p.WriteByte(int8(valueType))
   125  	if e != nil {
   126  		return e
   127  	}
   128  	e = p.WriteI32(int32(size))
   129  	return e
   130  }
   131  
   132  // WriteMapEnd ...
   133  func (p *BinaryProtocol) WriteMapEnd() error {
   134  	return nil
   135  }
   136  
   137  // WriteListBegin ...
   138  func (p *BinaryProtocol) WriteListBegin(elemType thrift.TType, size int) error {
   139  	e := p.WriteByte(int8(elemType))
   140  	if e != nil {
   141  		return e
   142  	}
   143  	e = p.WriteI32(int32(size))
   144  	return e
   145  }
   146  
   147  // WriteListEnd ...
   148  func (p *BinaryProtocol) WriteListEnd() error {
   149  	return nil
   150  }
   151  
   152  // WriteSetBegin ...
   153  func (p *BinaryProtocol) WriteSetBegin(elemType thrift.TType, size int) error {
   154  	e := p.WriteByte(int8(elemType))
   155  	if e != nil {
   156  		return e
   157  	}
   158  	e = p.WriteI32(int32(size))
   159  	return e
   160  }
   161  
   162  // WriteSetEnd ...
   163  func (p *BinaryProtocol) WriteSetEnd() error {
   164  	return nil
   165  }
   166  
   167  // WriteBool ...
   168  func (p *BinaryProtocol) WriteBool(value bool) error {
   169  	if value {
   170  		return p.WriteByte(1)
   171  	}
   172  	return p.WriteByte(0)
   173  }
   174  
   175  // WriteByte ...
   176  func (p *BinaryProtocol) WriteByte(value int8) error {
   177  	v, err := p.malloc(1)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	v[0] = byte(value)
   182  	return err
   183  }
   184  
   185  // WriteI16 ...
   186  func (p *BinaryProtocol) WriteI16(value int16) error {
   187  	v, err := p.malloc(2)
   188  	if err != nil {
   189  		return err
   190  	}
   191  	binary.BigEndian.PutUint16(v, uint16(value))
   192  	return err
   193  }
   194  
   195  // WriteI32 ...
   196  func (p *BinaryProtocol) WriteI32(value int32) error {
   197  	v, err := p.malloc(4)
   198  	if err != nil {
   199  		return err
   200  	}
   201  	binary.BigEndian.PutUint32(v, uint32(value))
   202  	return err
   203  }
   204  
   205  // WriteI64 ...
   206  func (p *BinaryProtocol) WriteI64(value int64) error {
   207  	v, err := p.malloc(8)
   208  	if err != nil {
   209  		return err
   210  	}
   211  	binary.BigEndian.PutUint64(v, uint64(value))
   212  	return err
   213  }
   214  
   215  // WriteDouble ...
   216  func (p *BinaryProtocol) WriteDouble(value float64) error {
   217  	return p.WriteI64(int64(math.Float64bits(value)))
   218  }
   219  
   220  // WriteString ...
   221  func (p *BinaryProtocol) WriteString(value string) error {
   222  	len := len(value)
   223  	e := p.WriteI32(int32(len))
   224  	if e != nil {
   225  		return e
   226  	}
   227  	_, e = p.trans.WriteString(value)
   228  	return e
   229  }
   230  
   231  // WriteBinary ...
   232  func (p *BinaryProtocol) WriteBinary(value []byte) error {
   233  	e := p.WriteI32(int32(len(value)))
   234  	if e != nil {
   235  		return e
   236  	}
   237  	_, e = p.trans.WriteBinary(value)
   238  	return e
   239  }
   240  
   241  // malloc ...
   242  func (p *BinaryProtocol) malloc(size int) ([]byte, error) {
   243  	buf, err := p.trans.Malloc(size)
   244  	if err != nil {
   245  		return buf, perrors.NewProtocolError(err)
   246  	}
   247  	return buf, nil
   248  }
   249  
   250  /**
   251   * Reading methods
   252   */
   253  
   254  // ReadMessageBegin ...
   255  func (p *BinaryProtocol) ReadMessageBegin() (name string, typeID thrift.TMessageType, seqID int32, err error) {
   256  	size, e := p.ReadI32()
   257  	if e != nil {
   258  		return "", typeID, 0, perrors.NewProtocolError(e)
   259  	}
   260  	if size > 0 {
   261  		return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin")
   262  	}
   263  	typeID = thrift.TMessageType(size & 0x0ff)
   264  	version := int64(int64(size) & thrift.VERSION_MASK)
   265  	if version != thrift.VERSION_1 {
   266  		return name, typeID, seqID, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin")
   267  	}
   268  	name, e = p.ReadString()
   269  	if e != nil {
   270  		return name, typeID, seqID, perrors.NewProtocolError(e)
   271  	}
   272  	seqID, e = p.ReadI32()
   273  	if e != nil {
   274  		return name, typeID, seqID, perrors.NewProtocolError(e)
   275  	}
   276  	return name, typeID, seqID, nil
   277  }
   278  
   279  // ReadMessageEnd ...
   280  func (p *BinaryProtocol) ReadMessageEnd() error {
   281  	return nil
   282  }
   283  
   284  // ReadStructBegin ...
   285  func (p *BinaryProtocol) ReadStructBegin() (name string, err error) {
   286  	return
   287  }
   288  
   289  // ReadStructEnd ...
   290  func (p *BinaryProtocol) ReadStructEnd() error {
   291  	return nil
   292  }
   293  
   294  // ReadFieldBegin ...
   295  func (p *BinaryProtocol) ReadFieldBegin() (name string, typeID thrift.TType, id int16, err error) {
   296  	t, err := p.ReadByte()
   297  	typeID = thrift.TType(t)
   298  	if err != nil {
   299  		return name, typeID, id, err
   300  	}
   301  	if t != thrift.STOP {
   302  		id, err = p.ReadI16()
   303  	}
   304  	return name, typeID, id, err
   305  }
   306  
   307  // ReadFieldEnd ...
   308  func (p *BinaryProtocol) ReadFieldEnd() error {
   309  	return nil
   310  }
   311  
   312  // ReadMapBegin ...
   313  func (p *BinaryProtocol) ReadMapBegin() (kType, vType thrift.TType, size int, err error) {
   314  	k, e := p.ReadByte()
   315  	if e != nil {
   316  		err = perrors.NewProtocolError(e)
   317  		return
   318  	}
   319  	kType = thrift.TType(k)
   320  	v, e := p.ReadByte()
   321  	if e != nil {
   322  		err = perrors.NewProtocolError(e)
   323  		return
   324  	}
   325  	vType = thrift.TType(v)
   326  	size32, e := p.ReadI32()
   327  	if e != nil {
   328  		err = perrors.NewProtocolError(e)
   329  		return
   330  	}
   331  	if size32 < 0 {
   332  		err = perrors.InvalidDataLength
   333  		return
   334  	}
   335  	size = int(size32)
   336  	return kType, vType, size, nil
   337  }
   338  
   339  // ReadMapEnd ...
   340  func (p *BinaryProtocol) ReadMapEnd() error {
   341  	return nil
   342  }
   343  
   344  // ReadListBegin ...
   345  func (p *BinaryProtocol) ReadListBegin() (elemType thrift.TType, size int, err error) {
   346  	b, e := p.ReadByte()
   347  	if e != nil {
   348  		err = perrors.NewProtocolError(e)
   349  		return
   350  	}
   351  	elemType = thrift.TType(b)
   352  	size32, e := p.ReadI32()
   353  	if e != nil {
   354  		err = perrors.NewProtocolError(e)
   355  		return
   356  	}
   357  	if size32 < 0 {
   358  		err = perrors.InvalidDataLength
   359  		return
   360  	}
   361  	size = int(size32)
   362  
   363  	return
   364  }
   365  
   366  // ReadListEnd ...
   367  func (p *BinaryProtocol) ReadListEnd() error {
   368  	return nil
   369  }
   370  
   371  // ReadSetBegin ...
   372  func (p *BinaryProtocol) ReadSetBegin() (elemType thrift.TType, size int, err error) {
   373  	b, e := p.ReadByte()
   374  	if e != nil {
   375  		err = perrors.NewProtocolError(e)
   376  		return
   377  	}
   378  	elemType = thrift.TType(b)
   379  	size32, e := p.ReadI32()
   380  	if e != nil {
   381  		err = perrors.NewProtocolError(e)
   382  		return
   383  	}
   384  	if size32 < 0 {
   385  		err = perrors.InvalidDataLength
   386  		return
   387  	}
   388  	size = int(size32)
   389  	return elemType, size, nil
   390  }
   391  
   392  // ReadSetEnd ...
   393  func (p *BinaryProtocol) ReadSetEnd() error {
   394  	return nil
   395  }
   396  
   397  // ReadBool ...
   398  func (p *BinaryProtocol) ReadBool() (bool, error) {
   399  	b, e := p.ReadByte()
   400  	v := true
   401  	if b != 1 {
   402  		v = false
   403  	}
   404  	return v, e
   405  }
   406  
   407  // ReadByte ...
   408  func (p *BinaryProtocol) ReadByte() (value int8, err error) {
   409  	buf, err := p.next(1)
   410  	if err != nil {
   411  		return value, err
   412  	}
   413  	return int8(buf[0]), err
   414  }
   415  
   416  // ReadI16 ...
   417  func (p *BinaryProtocol) ReadI16() (value int16, err error) {
   418  	buf, err := p.next(2)
   419  	if err != nil {
   420  		return value, err
   421  	}
   422  	value = int16(binary.BigEndian.Uint16(buf))
   423  	return value, err
   424  }
   425  
   426  // ReadI32 ...
   427  func (p *BinaryProtocol) ReadI32() (value int32, err error) {
   428  	buf, err := p.next(4)
   429  	if err != nil {
   430  		return value, err
   431  	}
   432  	value = int32(binary.BigEndian.Uint32(buf))
   433  	return value, err
   434  }
   435  
   436  // ReadI64 ...
   437  func (p *BinaryProtocol) ReadI64() (value int64, err error) {
   438  	buf, err := p.next(8)
   439  	if err != nil {
   440  		return value, err
   441  	}
   442  	value = int64(binary.BigEndian.Uint64(buf))
   443  	return value, err
   444  }
   445  
   446  // ReadDouble ...
   447  func (p *BinaryProtocol) ReadDouble() (value float64, err error) {
   448  	buf, err := p.next(8)
   449  	if err != nil {
   450  		return value, err
   451  	}
   452  	value = math.Float64frombits(binary.BigEndian.Uint64(buf))
   453  	return value, err
   454  }
   455  
   456  // ReadString ...
   457  func (p *BinaryProtocol) ReadString() (value string, err error) {
   458  	size, e := p.ReadI32()
   459  	if e != nil {
   460  		return "", e
   461  	}
   462  	if size < 0 {
   463  		err = perrors.InvalidDataLength
   464  		return
   465  	}
   466  	value, err = p.trans.ReadString(int(size))
   467  	if err != nil {
   468  		return value, perrors.NewProtocolError(err)
   469  	}
   470  	return value, nil
   471  }
   472  
   473  // ReadBinary ...
   474  func (p *BinaryProtocol) ReadBinary() ([]byte, error) {
   475  	size, e := p.ReadI32()
   476  	if e != nil {
   477  		return nil, e
   478  	}
   479  	if size < 0 {
   480  		return nil, perrors.InvalidDataLength
   481  	}
   482  	return p.trans.ReadBinary(int(size))
   483  }
   484  
   485  // Flush ...
   486  func (p *BinaryProtocol) Flush(ctx context.Context) (err error) {
   487  	err = p.trans.Flush()
   488  	if err != nil {
   489  		return perrors.NewProtocolError(err)
   490  	}
   491  	return nil
   492  }
   493  
   494  // Skip ...
   495  func (p *BinaryProtocol) Skip(fieldType thrift.TType) (err error) {
   496  	return thrift.SkipDefaultDepth(p, fieldType)
   497  }
   498  
   499  // Transport ...
   500  func (p *BinaryProtocol) Transport() thrift.TTransport {
   501  	// not support
   502  	return nil
   503  }
   504  
   505  // ByteBuffer ...
   506  func (p *BinaryProtocol) ByteBuffer() remote.ByteBuffer {
   507  	return p.trans
   508  }
   509  
   510  // next ...
   511  func (p *BinaryProtocol) next(size int) ([]byte, error) {
   512  	buf, err := p.trans.Next(size)
   513  	if err != nil {
   514  		return buf, perrors.NewProtocolError(err)
   515  	}
   516  	return buf, nil
   517  }