github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/thrift.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  
    26  	"github.com/cloudwego/kitex/pkg/protocol/bthrift"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/remote/codec"
    29  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    31  	"github.com/cloudwego/kitex/pkg/stats"
    32  )
    33  
    34  // CodecType is config of the thrift codec. Priority: Frugal > FastMode > Normal
    35  type CodecType int
    36  
    37  const (
    38  	// Basic can be used for disabling fastCodec and frugal
    39  	Basic     CodecType = 0b0000
    40  	FastWrite CodecType = 0b0001
    41  	FastRead  CodecType = 0b0010
    42  
    43  	FastReadWrite = FastRead | FastWrite
    44  )
    45  
    46  var (
    47  	defaultCodec = NewThriftCodec().(*thriftCodec)
    48  
    49  	errEncodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol,
    50  		"encode failed, codec msg type not match with thriftCodec")
    51  	errDecodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol,
    52  		"decode failed, codec msg type not match with thriftCodec")
    53  )
    54  
    55  // NewThriftCodec creates the thrift binary codec.
    56  func NewThriftCodec() remote.PayloadCodec {
    57  	return &thriftCodec{FastWrite | FastRead}
    58  }
    59  
    60  // IsThriftCodec checks if the codec is thriftCodec
    61  func IsThriftCodec(c remote.PayloadCodec) bool {
    62  	_, ok := c.(*thriftCodec)
    63  	return ok
    64  }
    65  
    66  // NewThriftFrugalCodec creates the thrift binary codec powered by frugal.
    67  // Eg: xxxservice.NewServer(handler, server.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastWrite | thrift.FastRead)))
    68  func NewThriftCodecWithConfig(c CodecType) remote.PayloadCodec {
    69  	return &thriftCodec{c}
    70  }
    71  
    72  // NewThriftCodecDisableFastMode creates the thrift binary codec which can control if do fast codec.
    73  // Eg: xxxservice.NewServer(handler, server.WithPayloadCodec(thrift.NewThriftCodecDisableFastMode(true, true)))
    74  func NewThriftCodecDisableFastMode(disableFastWrite, disableFastRead bool) remote.PayloadCodec {
    75  	var c CodecType
    76  	if !disableFastRead {
    77  		c |= FastRead
    78  	}
    79  	if !disableFastWrite {
    80  		c |= FastWrite
    81  	}
    82  	return &thriftCodec{c}
    83  }
    84  
// thriftCodec implements remote.PayloadCodec for the thrift binary protocol.
// The embedded CodecType holds the flags (Basic, FastWrite, FastRead, ...)
// that decide which encoding paths Marshal may take. Presumably the same flags
// also drive the decoding paths in unmarshalThriftData; confirm there.
type thriftCodec struct {
	CodecType
}
    89  
    90  // Marshal implements the remote.PayloadCodec interface.
    91  func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
    92  	// prepare info
    93  	methodName := message.RPCInfo().Invocation().MethodName()
    94  	if methodName == "" {
    95  		return errors.New("empty methodName in thrift Marshal")
    96  	}
    97  	msgType := message.MessageType()
    98  	seqID := message.RPCInfo().Invocation().SeqID()
    99  
   100  	data, err := getValidData(methodName, message)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	// encode with hyper codec
   106  	// NOTE: to ensure hyperMarshalEnabled is inlined so split the check logic, or it may cause performance loss
   107  	if c.hyperMarshalEnabled() && hyperMarshalAvailable(data) {
   108  		return c.hyperMarshal(out, methodName, msgType, seqID, data)
   109  	}
   110  
   111  	// encode with FastWrite
   112  	if c.CodecType&FastWrite != 0 {
   113  		if msg, ok := data.(ThriftMsgFastCodec); ok {
   114  			return encodeFastThrift(out, methodName, msgType, seqID, msg)
   115  		}
   116  	}
   117  
   118  	// fallback to old thrift way (slow)
   119  	if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data); err == nil || err != errEncodeMismatchMsgType {
   120  		return err
   121  	}
   122  
   123  	// Basic can be used for disabling frugal, we need to check it
   124  	if c.CodecType != Basic && hyperMarshalAvailable(data) {
   125  		// fallback to frugal when the generated code is using slim template
   126  		return c.hyperMarshal(out, methodName, msgType, seqID, data)
   127  	}
   128  
   129  	return errEncodeMismatchMsgType
   130  }
   131  
   132  // encodeFastThrift encode with the FastCodec way
   133  func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg ThriftMsgFastCodec) error {
   134  	// nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite
   135  	msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID)
   136  	msgEndLen := bthrift.Binary.MessageEndLength()
   137  	buf, err := out.Malloc(msgBeginLen + msg.BLength() + msgEndLen)
   138  	if err != nil {
   139  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error()))
   140  	}
   141  	offset := bthrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID)
   142  	offset += msg.FastWriteNocopy(buf[offset:], nil)
   143  	bthrift.Binary.WriteMessageEnd(buf[offset:])
   144  	return nil
   145  }
   146  
   147  // encodeBasicThrift encode with the old thrift way (slow)
   148  func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error {
   149  	if err := verifyMarshalBasicThriftDataType(data); err != nil {
   150  		return err
   151  	}
   152  	tProt := NewBinaryProtocol(out)
   153  	if err := tProt.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil {
   154  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error()))
   155  	}
   156  	if err := marshalBasicThriftData(ctx, tProt, data); err != nil {
   157  		return err
   158  	}
   159  	if err := tProt.WriteMessageEnd(); err != nil {
   160  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageEnd failed: %s", err.Error()))
   161  	}
   162  	tProt.Recycle()
   163  	return nil
   164  }
   165  
   166  // Unmarshal implements the remote.PayloadCodec interface.
   167  func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   168  	tProt := NewBinaryProtocol(in)
   169  	methodName, msgType, seqID, err := tProt.ReadMessageBegin()
   170  	if err != nil {
   171  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal, ReadMessageBegin failed: %s", err.Error()))
   172  	}
   173  	if err = codec.UpdateMsgType(uint32(msgType), message); err != nil {
   174  		return err
   175  	}
   176  
   177  	// exception message
   178  	if message.MessageType() == remote.Exception {
   179  		return UnmarshalThriftException(tProt)
   180  	}
   181  
   182  	if err = validateMessageBeforeDecode(message, seqID, methodName); err != nil {
   183  		return err
   184  	}
   185  
   186  	// decode thrift data
   187  	data := message.Data()
   188  	msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, msgType, seqID)
   189  	dataLen := message.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength()
   190  
   191  	ri := message.RPCInfo()
   192  	rpcinfo.Record(ctx, ri, stats.WaitReadStart, nil)
   193  	err = c.unmarshalThriftData(ctx, tProt, methodName, data, dataLen)
   194  	rpcinfo.Record(ctx, ri, stats.WaitReadFinish, err)
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	if err = tProt.ReadMessageEnd(); err != nil {
   200  		return remote.NewTransError(remote.ProtocolError, err)
   201  	}
   202  	tProt.Recycle()
   203  	return err
   204  }
   205  
   206  // validateMessageBeforeDecode validate message before decode
   207  func validateMessageBeforeDecode(message remote.Message, seqID int32, methodName string) (err error) {
   208  	// For server side, the following error can be sent back and 'SetSeqID' should be executed first to ensure the seqID
   209  	// is right when return Exception back.
   210  	if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
   211  		return err
   212  	}
   213  
   214  	if err = codec.SetOrCheckMethodName(methodName, message); err != nil {
   215  		return err
   216  	}
   217  
   218  	if err = codec.NewDataIfNeeded(methodName, message); err != nil {
   219  		return err
   220  	}
   221  	return nil
   222  }
   223  
   224  // Name implements the remote.PayloadCodec interface.
   225  func (c thriftCodec) Name() string {
   226  	return "thrift"
   227  }
   228  
   229  // MessageWriterWithContext write to thrift.TProtocol
   230  type MessageWriterWithContext interface {
   231  	Write(ctx context.Context, oprot thrift.TProtocol) error
   232  }
   233  
   234  // MessageWriter write to thrift.TProtocol
   235  type MessageWriter interface {
   236  	Write(oprot thrift.TProtocol) error
   237  }
   238  
   239  // MessageReader read from thrift.TProtocol
   240  type MessageReader interface {
   241  	Read(oprot thrift.TProtocol) error
   242  }
   243  
   244  // MessageReaderWithMethodWithContext read from thrift.TProtocol with method
   245  type MessageReaderWithMethodWithContext interface {
   246  	Read(ctx context.Context, method string, oprot thrift.TProtocol) error
   247  }
   248  
   249  type ThriftMsgFastCodec interface {
   250  	BLength() int
   251  	FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int
   252  	FastRead(buf []byte) (int, error)
   253  }
   254  
   255  func getValidData(methodName string, message remote.Message) (interface{}, error) {
   256  	if err := codec.NewDataIfNeeded(methodName, message); err != nil {
   257  		return nil, err
   258  	}
   259  	data := message.Data()
   260  	if message.MessageType() != remote.Exception {
   261  		return data, nil
   262  	}
   263  	transErr, isTransErr := data.(*remote.TransError)
   264  	if !isTransErr {
   265  		if err, isError := data.(error); isError {
   266  			encodeErr := thrift.NewTApplicationException(remote.InternalError, err.Error())
   267  			return encodeErr, nil
   268  		}
   269  		return nil, errors.New("exception relay need error type data")
   270  	}
   271  	encodeErr := thrift.NewTApplicationException(transErr.TypeID(), transErr.Error())
   272  	return encodeErr, nil
   273  }