github.com/cloudwego/kitex@v0.9.0/pkg/protocol/bthrift/binary.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package bthrift .
    18  package bthrift
    19  
    20  import (
    21  	"encoding/binary"
    22  	"errors"
    23  	"fmt"
    24  	"math"
    25  
    26  	"github.com/apache/thrift/lib/go/thrift"
    27  
    28  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    29  	"github.com/cloudwego/kitex/pkg/utils"
    30  )
    31  
    32  // Binary protocol for bthrift.
    33  var Binary binaryProtocol
    34  
    35  var _ BTProtocol = binaryProtocol{}
    36  
    37  type binaryProtocol struct{}
    38  
    39  func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID thrift.TMessageType, seqid int32) int {
    40  	offset := 0
    41  	version := uint32(thrift.VERSION_1) | uint32(typeID)
    42  	offset += Binary.WriteI32(buf, int32(version))
    43  	offset += Binary.WriteString(buf[offset:], name)
    44  	offset += Binary.WriteI32(buf[offset:], seqid)
    45  	return offset
    46  }
    47  
    48  func (binaryProtocol) WriteMessageEnd(buf []byte) int {
    49  	return 0
    50  }
    51  
    52  func (binaryProtocol) WriteStructBegin(buf []byte, name string) int {
    53  	return 0
    54  }
    55  
    56  func (binaryProtocol) WriteStructEnd(buf []byte) int {
    57  	return 0
    58  }
    59  
    60  func (binaryProtocol) WriteFieldBegin(buf []byte, name string, typeID thrift.TType, id int16) int {
    61  	return Binary.WriteByte(buf, int8(typeID)) + Binary.WriteI16(buf[1:], id)
    62  }
    63  
    64  func (binaryProtocol) WriteFieldEnd(buf []byte) int {
    65  	return 0
    66  }
    67  
    68  func (binaryProtocol) WriteFieldStop(buf []byte) int {
    69  	return Binary.WriteByte(buf, thrift.STOP)
    70  }
    71  
    72  func (binaryProtocol) WriteMapBegin(buf []byte, keyType, valueType thrift.TType, size int) int {
    73  	return Binary.WriteByte(buf, int8(keyType)) +
    74  		Binary.WriteByte(buf[1:], int8(valueType)) +
    75  		Binary.WriteI32(buf[2:], int32(size))
    76  }
    77  
    78  func (binaryProtocol) WriteMapEnd(buf []byte) int {
    79  	return 0
    80  }
    81  
    82  func (binaryProtocol) WriteListBegin(buf []byte, elemType thrift.TType, size int) int {
    83  	return Binary.WriteByte(buf, int8(elemType)) +
    84  		Binary.WriteI32(buf[1:], int32(size))
    85  }
    86  
    87  func (binaryProtocol) WriteListEnd(buf []byte) int {
    88  	return 0
    89  }
    90  
    91  func (binaryProtocol) WriteSetBegin(buf []byte, elemType thrift.TType, size int) int {
    92  	return Binary.WriteByte(buf, int8(elemType)) +
    93  		Binary.WriteI32(buf[1:], int32(size))
    94  }
    95  
    96  func (binaryProtocol) WriteSetEnd(buf []byte) int {
    97  	return 0
    98  }
    99  
   100  func (binaryProtocol) WriteBool(buf []byte, value bool) int {
   101  	if value {
   102  		return Binary.WriteByte(buf, 1)
   103  	}
   104  	return Binary.WriteByte(buf, 0)
   105  }
   106  
   107  func (binaryProtocol) WriteByte(buf []byte, value int8) int {
   108  	buf[0] = byte(value)
   109  	return 1
   110  }
   111  
   112  func (binaryProtocol) WriteI16(buf []byte, value int16) int {
   113  	binary.BigEndian.PutUint16(buf, uint16(value))
   114  	return 2
   115  }
   116  
   117  func (binaryProtocol) WriteI32(buf []byte, value int32) int {
   118  	binary.BigEndian.PutUint32(buf, uint32(value))
   119  	return 4
   120  }
   121  
   122  func (binaryProtocol) WriteI64(buf []byte, value int64) int {
   123  	binary.BigEndian.PutUint64(buf, uint64(value))
   124  	return 8
   125  }
   126  
   127  func (binaryProtocol) WriteDouble(buf []byte, value float64) int {
   128  	return Binary.WriteI64(buf, int64(math.Float64bits(value)))
   129  }
   130  
   131  func (binaryProtocol) WriteString(buf []byte, value string) int {
   132  	l := Binary.WriteI32(buf, int32(len(value)))
   133  	copy(buf[l:], value)
   134  	return l + len(value)
   135  }
   136  
   137  func (binaryProtocol) WriteBinary(buf, value []byte) int {
   138  	l := Binary.WriteI32(buf, int32(len(value)))
   139  	copy(buf[l:], value)
   140  	return l + len(value)
   141  }
   142  
   143  func (binaryProtocol) WriteStringNocopy(buf []byte, binaryWriter BinaryWriter, value string) int {
   144  	return Binary.WriteBinaryNocopy(buf, binaryWriter, utils.StringToSliceByte(value))
   145  }
   146  
   147  func (binaryProtocol) WriteBinaryNocopy(buf []byte, binaryWriter BinaryWriter, value []byte) int {
   148  	l := Binary.WriteI32(buf, int32(len(value)))
   149  	copy(buf[l:], value)
   150  	return l + len(value)
   151  }
   152  
   153  func (binaryProtocol) MessageBeginLength(name string, typeID thrift.TMessageType, seqid int32) int {
   154  	version := uint32(thrift.VERSION_1) | uint32(typeID)
   155  	return Binary.I32Length(int32(version)) + Binary.StringLength(name) + Binary.I32Length(seqid)
   156  }
   157  
   158  func (binaryProtocol) MessageEndLength() int {
   159  	return 0
   160  }
   161  
   162  func (binaryProtocol) StructBeginLength(name string) int {
   163  	return 0
   164  }
   165  
   166  func (binaryProtocol) StructEndLength() int {
   167  	return 0
   168  }
   169  
   170  func (binaryProtocol) FieldBeginLength(name string, typeID thrift.TType, id int16) int {
   171  	return Binary.ByteLength(int8(typeID)) + Binary.I16Length(id)
   172  }
   173  
   174  func (binaryProtocol) FieldEndLength() int {
   175  	return 0
   176  }
   177  
   178  func (binaryProtocol) FieldStopLength() int {
   179  	return Binary.ByteLength(thrift.STOP)
   180  }
   181  
   182  func (binaryProtocol) MapBeginLength(keyType, valueType thrift.TType, size int) int {
   183  	return Binary.ByteLength(int8(keyType)) +
   184  		Binary.ByteLength(int8(valueType)) +
   185  		Binary.I32Length(int32(size))
   186  }
   187  
   188  func (binaryProtocol) MapEndLength() int {
   189  	return 0
   190  }
   191  
   192  func (binaryProtocol) ListBeginLength(elemType thrift.TType, size int) int {
   193  	return Binary.ByteLength(int8(elemType)) +
   194  		Binary.I32Length(int32(size))
   195  }
   196  
   197  func (binaryProtocol) ListEndLength() int {
   198  	return 0
   199  }
   200  
   201  func (binaryProtocol) SetBeginLength(elemType thrift.TType, size int) int {
   202  	return Binary.ByteLength(int8(elemType)) +
   203  		Binary.I32Length(int32(size))
   204  }
   205  
   206  func (binaryProtocol) SetEndLength() int {
   207  	return 0
   208  }
   209  
   210  func (binaryProtocol) BoolLength(value bool) int {
   211  	if value {
   212  		return Binary.ByteLength(1)
   213  	}
   214  	return Binary.ByteLength(0)
   215  }
   216  
   217  func (binaryProtocol) ByteLength(value int8) int {
   218  	return 1
   219  }
   220  
   221  func (binaryProtocol) I16Length(value int16) int {
   222  	return 2
   223  }
   224  
   225  func (binaryProtocol) I32Length(value int32) int {
   226  	return 4
   227  }
   228  
   229  func (binaryProtocol) I64Length(value int64) int {
   230  	return 8
   231  }
   232  
   233  func (binaryProtocol) DoubleLength(value float64) int {
   234  	return Binary.I64Length(int64(math.Float64bits(value)))
   235  }
   236  
   237  func (binaryProtocol) StringLength(value string) int {
   238  	return Binary.I32Length(int32(len(value))) + len(value)
   239  }
   240  
   241  func (binaryProtocol) BinaryLength(value []byte) int {
   242  	return Binary.I32Length(int32(len(value))) + len(value)
   243  }
   244  
   245  func (binaryProtocol) StringLengthNocopy(value string) int {
   246  	return Binary.BinaryLengthNocopy(utils.StringToSliceByte(value))
   247  }
   248  
   249  func (binaryProtocol) BinaryLengthNocopy(value []byte) int {
   250  	l := Binary.I32Length(int32(len(value)))
   251  	return l + len(value)
   252  }
   253  
   254  func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID thrift.TMessageType, seqid int32, length int, err error) {
   255  	size, l, e := Binary.ReadI32(buf)
   256  	length += l
   257  	if e != nil {
   258  		err = perrors.NewProtocolError(e)
   259  		return
   260  	}
   261  	if size > 0 {
   262  		err = perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin")
   263  		return
   264  	}
   265  	typeID = thrift.TMessageType(size & 0x0ff)
   266  	version := int64(size) & thrift.VERSION_MASK
   267  	if version != thrift.VERSION_1 {
   268  		err = perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin")
   269  		return
   270  	}
   271  	name, l, e = Binary.ReadString(buf[length:])
   272  	length += l
   273  	if e != nil {
   274  		err = perrors.NewProtocolError(e)
   275  		return
   276  	}
   277  	seqid, l, e = Binary.ReadI32(buf[length:])
   278  	length += l
   279  	if e != nil {
   280  		err = perrors.NewProtocolError(e)
   281  		return
   282  	}
   283  	return
   284  }
   285  
   286  func (binaryProtocol) ReadMessageEnd(buf []byte) (int, error) {
   287  	return 0, nil
   288  }
   289  
   290  func (binaryProtocol) ReadStructBegin(buf []byte) (name string, length int, err error) {
   291  	return
   292  }
   293  
   294  func (binaryProtocol) ReadStructEnd(buf []byte) (int, error) {
   295  	return 0, nil
   296  }
   297  
   298  func (binaryProtocol) ReadFieldBegin(buf []byte) (name string, typeID thrift.TType, id int16, length int, err error) {
   299  	t, l, e := Binary.ReadByte(buf)
   300  	length += l
   301  	typeID = thrift.TType(t)
   302  	if e != nil {
   303  		err = e
   304  		return
   305  	}
   306  	if t != thrift.STOP {
   307  		id, l, err = Binary.ReadI16(buf[length:])
   308  		length += l
   309  	}
   310  	return
   311  }
   312  
   313  func (binaryProtocol) ReadFieldEnd(buf []byte) (int, error) {
   314  	return 0, nil
   315  }
   316  
   317  func (binaryProtocol) ReadMapBegin(buf []byte) (keyType, valueType thrift.TType, size, length int, err error) {
   318  	k, l, e := Binary.ReadByte(buf)
   319  	length += l
   320  	if e != nil {
   321  		err = perrors.NewProtocolError(e)
   322  		return
   323  	}
   324  	keyType = thrift.TType(k)
   325  	v, l, e := Binary.ReadByte(buf[length:])
   326  	length += l
   327  	if e != nil {
   328  		err = perrors.NewProtocolError(e)
   329  		return
   330  	}
   331  	valueType = thrift.TType(v)
   332  	size32, l, e := Binary.ReadI32(buf[length:])
   333  	length += l
   334  	if e != nil {
   335  		err = perrors.NewProtocolError(e)
   336  		return
   337  	}
   338  	if size32 < 0 {
   339  		err = perrors.InvalidDataLength
   340  		return
   341  	}
   342  	size = int(size32)
   343  	return
   344  }
   345  
   346  func (binaryProtocol) ReadMapEnd(buf []byte) (int, error) {
   347  	return 0, nil
   348  }
   349  
   350  func (binaryProtocol) ReadListBegin(buf []byte) (elemType thrift.TType, size, length int, err error) {
   351  	b, l, e := Binary.ReadByte(buf)
   352  	length += l
   353  	if e != nil {
   354  		err = perrors.NewProtocolError(e)
   355  		return
   356  	}
   357  	elemType = thrift.TType(b)
   358  	size32, l, e := Binary.ReadI32(buf[length:])
   359  	length += l
   360  	if e != nil {
   361  		err = perrors.NewProtocolError(e)
   362  		return
   363  	}
   364  	if size32 < 0 {
   365  		err = perrors.InvalidDataLength
   366  		return
   367  	}
   368  	size = int(size32)
   369  
   370  	return
   371  }
   372  
   373  func (binaryProtocol) ReadListEnd(buf []byte) (int, error) {
   374  	return 0, nil
   375  }
   376  
   377  func (binaryProtocol) ReadSetBegin(buf []byte) (elemType thrift.TType, size, length int, err error) {
   378  	b, l, e := Binary.ReadByte(buf)
   379  	length += l
   380  	if e != nil {
   381  		err = perrors.NewProtocolError(e)
   382  		return
   383  	}
   384  	elemType = thrift.TType(b)
   385  	size32, l, e := Binary.ReadI32(buf[length:])
   386  	length += l
   387  	if e != nil {
   388  		err = perrors.NewProtocolError(e)
   389  		return
   390  	}
   391  	if size32 < 0 {
   392  		err = perrors.InvalidDataLength
   393  		return
   394  	}
   395  	size = int(size32)
   396  	return
   397  }
   398  
   399  func (binaryProtocol) ReadSetEnd(buf []byte) (int, error) {
   400  	return 0, nil
   401  }
   402  
   403  func (binaryProtocol) ReadBool(buf []byte) (value bool, length int, err error) {
   404  	b, l, e := Binary.ReadByte(buf)
   405  	v := true
   406  	if b != 1 {
   407  		v = false
   408  	}
   409  	return v, l, e
   410  }
   411  
   412  func (binaryProtocol) ReadByte(buf []byte) (value int8, length int, err error) {
   413  	if len(buf) < 1 {
   414  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadByte] buf length less than 1")
   415  	}
   416  	return int8(buf[0]), 1, err
   417  }
   418  
   419  func (binaryProtocol) ReadI16(buf []byte) (value int16, length int, err error) {
   420  	if len(buf) < 2 {
   421  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI16] buf length less than 2")
   422  	}
   423  	value = int16(binary.BigEndian.Uint16(buf))
   424  	return value, 2, err
   425  }
   426  
   427  func (binaryProtocol) ReadI32(buf []byte) (value int32, length int, err error) {
   428  	if len(buf) < 4 {
   429  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI32] buf length less than 4")
   430  	}
   431  	value = int32(binary.BigEndian.Uint32(buf))
   432  	return value, 4, err
   433  }
   434  
   435  func (binaryProtocol) ReadI64(buf []byte) (value int64, length int, err error) {
   436  	if len(buf) < 8 {
   437  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadI64] buf length less than 8")
   438  	}
   439  	value = int64(binary.BigEndian.Uint64(buf))
   440  	return value, 8, err
   441  }
   442  
   443  func (binaryProtocol) ReadDouble(buf []byte) (value float64, length int, err error) {
   444  	if len(buf) < 8 {
   445  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadDouble] buf length less than 8")
   446  	}
   447  	value = math.Float64frombits(binary.BigEndian.Uint64(buf))
   448  	return value, 8, err
   449  }
   450  
   451  func (binaryProtocol) ReadString(buf []byte) (value string, length int, err error) {
   452  	size, l, e := Binary.ReadI32(buf)
   453  	length += l
   454  	if e != nil {
   455  		err = e
   456  		return
   457  	}
   458  	if size < 0 || int(size) > len(buf) {
   459  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadString] the string size greater than buf length")
   460  	}
   461  	value = string(buf[length : length+int(size)])
   462  	length += int(size)
   463  	return
   464  }
   465  
   466  func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err error) {
   467  	size, l, e := Binary.ReadI32(buf)
   468  	length += l
   469  	if e != nil {
   470  		err = e
   471  		return
   472  	}
   473  	if size < 0 || int(size) > len(buf) {
   474  		return value, length, perrors.NewProtocolErrorWithType(thrift.INVALID_DATA, "[ReadBinary] the binary size greater than buf length")
   475  	}
   476  	value = make([]byte, size)
   477  	copy(value, buf[length:length+int(size)])
   478  	length += int(size)
   479  	return
   480  }
   481  
   482  // Skip .
   483  func (binaryProtocol) Skip(buf []byte, fieldType thrift.TType) (length int, err error) {
   484  	return SkipDefaultDepth(buf, Binary, fieldType)
   485  }
   486  
   487  // SkipDefaultDepth skips over the next data element from the provided input TProtocol object.
   488  func SkipDefaultDepth(buf []byte, prot BTProtocol, typeID thrift.TType) (int, error) {
   489  	return Skip(buf, prot, typeID, thrift.DEFAULT_RECURSION_DEPTH)
   490  }
   491  
   492  // Skip skips over the next data element from the provided input TProtocol object.
   493  func Skip(buf []byte, self BTProtocol, fieldType thrift.TType, maxDepth int) (length int, err error) {
   494  	if maxDepth <= 0 {
   495  		return 0, thrift.NewTProtocolExceptionWithType(thrift.DEPTH_LIMIT, errors.New("depth limit exceeded"))
   496  	}
   497  
   498  	var l int
   499  	switch fieldType {
   500  	case thrift.BOOL:
   501  		length += 1
   502  		return
   503  	case thrift.BYTE:
   504  		length += 1
   505  		return
   506  	case thrift.I16:
   507  		length += 2
   508  		return
   509  	case thrift.I32:
   510  		length += 4
   511  		return
   512  	case thrift.I64:
   513  		length += 8
   514  		return
   515  	case thrift.DOUBLE:
   516  		length += 8
   517  		return
   518  	case thrift.STRING:
   519  		var sl int32
   520  		sl, l, err = self.ReadI32(buf)
   521  		length += l + int(sl)
   522  		return
   523  	case thrift.STRUCT:
   524  		_, l, err = self.ReadStructBegin(buf)
   525  		length += l
   526  		if err != nil {
   527  			return
   528  		}
   529  		for {
   530  			_, typeID, _, l, e := self.ReadFieldBegin(buf[length:])
   531  			length += l
   532  			if e != nil {
   533  				err = e
   534  				return
   535  			}
   536  			if typeID == thrift.STOP {
   537  				break
   538  			}
   539  			l, e = Skip(buf[length:], self, typeID, maxDepth-1)
   540  			length += l
   541  			if e != nil {
   542  				err = e
   543  				return
   544  			}
   545  			l, e = self.ReadFieldEnd(buf[length:])
   546  			length += l
   547  			if e != nil {
   548  				err = e
   549  				return
   550  			}
   551  		}
   552  		l, e := self.ReadStructEnd(buf[length:])
   553  		length += l
   554  		if e != nil {
   555  			err = e
   556  		}
   557  		return
   558  	case thrift.MAP:
   559  		keyType, valueType, size, l, e := self.ReadMapBegin(buf)
   560  		length += l
   561  		if e != nil {
   562  			err = e
   563  			return
   564  		}
   565  		for i := 0; i < size; i++ {
   566  			l, e := Skip(buf[length:], self, keyType, maxDepth-1)
   567  			length += l
   568  			if e != nil {
   569  				err = e
   570  				return
   571  			}
   572  			l, e = Skip(buf[length:], self, valueType, maxDepth-1)
   573  			length += l
   574  			if e != nil {
   575  				err = e
   576  				return
   577  			}
   578  		}
   579  		l, e = self.ReadMapEnd(buf[length:])
   580  		length += l
   581  		if e != nil {
   582  			err = e
   583  		}
   584  		return
   585  	case thrift.SET:
   586  		elemType, size, l, e := self.ReadSetBegin(buf)
   587  		length += l
   588  		if e != nil {
   589  			err = e
   590  			return
   591  		}
   592  		for i := 0; i < size; i++ {
   593  			l, e = Skip(buf[length:], self, elemType, maxDepth-1)
   594  			length += l
   595  			if e != nil {
   596  				err = e
   597  				return
   598  			}
   599  		}
   600  		l, e = self.ReadSetEnd(buf[length:])
   601  		length += l
   602  		if e != nil {
   603  			err = e
   604  		}
   605  		return
   606  	case thrift.LIST:
   607  		elemType, size, l, e := self.ReadListBegin(buf)
   608  		length += l
   609  		if e != nil {
   610  			err = e
   611  			return
   612  		}
   613  		for i := 0; i < size; i++ {
   614  			l, e = Skip(buf[length:], self, elemType, maxDepth-1)
   615  			length += l
   616  			if e != nil {
   617  				err = e
   618  				return
   619  			}
   620  		}
   621  		l, e = self.ReadListEnd(buf[length:])
   622  		length += l
   623  		if e != nil {
   624  			err = e
   625  		}
   626  		return
   627  	default:
   628  		return 0, thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, fmt.Errorf("unknown data type %d", fieldType))
   629  	}
   630  }