github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/thrift_data.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  
    23  	"github.com/apache/thrift/lib/go/thrift"
    24  	"github.com/bytedance/gopkg/lang/mcache"
    25  
    26  	"github.com/cloudwego/kitex/pkg/protocol/bthrift"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    29  )
    30  
    31  const marshalThriftBufferSize = 1024
    32  
    33  // MarshalThriftData only encodes the data (without the prepending methodName, msgType, seqId)
    34  // It will allocate a new buffer and encode to it
    35  func MarshalThriftData(ctx context.Context, codec remote.PayloadCodec, data interface{}) ([]byte, error) {
    36  	c, ok := codec.(*thriftCodec)
    37  	if !ok {
    38  		c = defaultCodec
    39  	}
    40  	return c.marshalThriftData(ctx, data)
    41  }
    42  
    43  // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId)
    44  // It will allocate a new buffer and encode to it
    45  func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([]byte, error) {
    46  	// encode with hyper codec
    47  	// NOTE: to ensure hyperMarshalEnabled is inlined so split the check logic, or it may cause performance loss
    48  	if c.hyperMarshalEnabled() && hyperMarshalAvailable(data) {
    49  		return c.hyperMarshalBody(data)
    50  	}
    51  
    52  	// encode with FastWrite
    53  	if c.CodecType&FastWrite != 0 {
    54  		if msg, ok := data.(ThriftMsgFastCodec); ok {
    55  			payloadSize := msg.BLength()
    56  			payload := mcache.Malloc(payloadSize)
    57  			msg.FastWriteNocopy(payload, nil)
    58  			return payload, nil
    59  		}
    60  	}
    61  
    62  	if err := verifyMarshalBasicThriftDataType(data); err != nil {
    63  		// Basic can be used for disabling frugal, we need to check it
    64  		if c.CodecType != Basic && hyperMarshalAvailable(data) {
    65  			// fallback to frugal when the generated code is using slim template
    66  			return c.hyperMarshalBody(data)
    67  		}
    68  		return nil, err
    69  	}
    70  
    71  	// fallback to old thrift way (slow)
    72  	transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize)
    73  	tProt := thrift.NewTBinaryProtocol(transport, true, true)
    74  	if err := marshalBasicThriftData(ctx, tProt, data); err != nil {
    75  		return nil, err
    76  	}
    77  	return transport.Bytes(), nil
    78  }
    79  
    80  // verifyMarshalBasicThriftDataType verifies whether data could be marshaled by old thrift way
    81  func verifyMarshalBasicThriftDataType(data interface{}) error {
    82  	switch data.(type) {
    83  	case MessageWriter:
    84  	case MessageWriterWithContext:
    85  	default:
    86  		return errEncodeMismatchMsgType
    87  	}
    88  	return nil
    89  }
    90  
    91  // marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId)
    92  // It uses the old thrift way which is much slower than FastCodec and Frugal
    93  func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}) error {
    94  	switch msg := data.(type) {
    95  	case MessageWriter:
    96  		if err := msg.Write(tProt); err != nil {
    97  			return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
    98  		}
    99  	case MessageWriterWithContext:
   100  		if err := msg.Write(ctx, tProt); err != nil {
   101  			return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
   102  		}
   103  	default:
   104  		return errEncodeMismatchMsgType
   105  	}
   106  	return nil
   107  }
   108  
   109  // UnmarshalThriftException decode thrift exception from tProt
   110  // If your input is []byte, you can wrap it with `NewBinaryProtocol(remote.NewReaderBuffer(buf))`
   111  func UnmarshalThriftException(tProt thrift.TProtocol) error {
   112  	exception := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "")
   113  	if err := exception.Read(tProt); err != nil {
   114  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal Exception failed: %s", err.Error()))
   115  	}
   116  	if err := tProt.ReadMessageEnd(); err != nil {
   117  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal, ReadMessageEnd failed: %s", err.Error()))
   118  	}
   119  	return remote.NewTransError(exception.TypeId(), exception)
   120  }
   121  
   122  // UnmarshalThriftData only decodes the data (after methodName, msgType and seqId)
   123  // It will decode from the given buffer.
   124  // Note:
   125  // 1. `method` is only used for generic calls
   126  // 2. if the buf contains an exception, you should call UnmarshalThriftException instead.
   127  func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method string, buf []byte, data interface{}) error {
   128  	c, ok := codec.(*thriftCodec)
   129  	if !ok {
   130  		c = defaultCodec
   131  	}
   132  	tProt := NewBinaryProtocol(remote.NewReaderBuffer(buf))
   133  	err := c.unmarshalThriftData(ctx, tProt, method, data, len(buf))
   134  	if err == nil {
   135  		tProt.Recycle()
   136  	}
   137  	return err
   138  }
   139  
   140  // unmarshalThriftData only decodes the data (after methodName, msgType and seqId)
   141  // method is only used for generic calls
   142  func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, dataLen int) error {
   143  	// decode with hyper unmarshal
   144  	if c.hyperMessageUnmarshalEnabled() && hyperMessageUnmarshalAvailable(data, dataLen) {
   145  		return c.hyperUnmarshal(tProt, data, dataLen)
   146  	}
   147  
   148  	// decode with FastRead
   149  	if c.CodecType&FastRead != 0 {
   150  		if msg, ok := data.(ThriftMsgFastCodec); ok && dataLen > 0 {
   151  			buf, err := tProt.next(dataLen)
   152  			if err != nil {
   153  				return remote.NewTransError(remote.ProtocolError, err)
   154  			}
   155  			_, err = msg.FastRead(buf)
   156  			return err
   157  		}
   158  	}
   159  
   160  	if err := verifyUnmarshalBasicThriftDataType(data); err != nil {
   161  		// Basic can be used for disabling frugal, we need to check it
   162  		if c.CodecType != Basic && hyperMessageUnmarshalAvailable(data, dataLen) {
   163  			// fallback to frugal when the generated code is using slim template
   164  			return c.hyperUnmarshal(tProt, data, dataLen)
   165  		}
   166  		return err
   167  	}
   168  
   169  	// fallback to old thrift way (slow)
   170  	return decodeBasicThriftData(ctx, tProt, method, data)
   171  }
   172  
   173  func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error {
   174  	buf, err := tProt.next(dataLen - bthrift.Binary.MessageEndLength())
   175  	if err != nil {
   176  		return remote.NewTransError(remote.ProtocolError, err)
   177  	}
   178  	return c.hyperMessageUnmarshal(buf, data)
   179  }
   180  
   181  // verifyUnmarshalBasicThriftDataType verifies whether data could be unmarshal by old thrift way
   182  func verifyUnmarshalBasicThriftDataType(data interface{}) error {
   183  	switch data.(type) {
   184  	case MessageReader:
   185  	case MessageReaderWithMethodWithContext:
   186  	default:
   187  		return errDecodeMismatchMsgType
   188  	}
   189  	return nil
   190  }
   191  
   192  // decodeBasicThriftData decode thrift body the old way (slow)
   193  func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method string, data interface{}) error {
   194  	var err error
   195  	switch t := data.(type) {
   196  	case MessageReader:
   197  		if err = t.Read(tProt); err != nil {
   198  			return remote.NewTransError(remote.ProtocolError, err)
   199  		}
   200  	case MessageReaderWithMethodWithContext:
   201  		// methodName is necessary for generic calls to methodInfo from serviceInfo
   202  		if err = t.Read(ctx, method, tProt); err != nil {
   203  			return remote.NewTransError(remote.ProtocolError, err)
   204  		}
   205  	default:
   206  		return errDecodeMismatchMsgType
   207  	}
   208  	return nil
   209  }