github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/grpc/grpc.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package grpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"errors"
    23  	"fmt"
    24  
    25  	"github.com/bytedance/gopkg/lang/mcache"
    26  	"github.com/cloudwego/fastpb"
    27  	"google.golang.org/protobuf/proto"
    28  
    29  	"github.com/cloudwego/kitex/pkg/remote"
    30  	"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
    31  	"github.com/cloudwego/kitex/pkg/remote/codec/thrift"
    32  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    33  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    34  )
    35  
    36  const dataFrameHeaderLen = 5
    37  
    38  var ErrInvalidPayload = errors.New("grpc invalid payload")
    39  
    40  // gogoproto generate
    41  type marshaler interface {
    42  	MarshalTo(data []byte) (n int, err error)
    43  	Size() int
    44  }
    45  
    46  type protobufV2MsgCodec interface {
    47  	XXX_Unmarshal(b []byte) error
    48  	XXX_Marshal(b []byte, deterministic bool) ([]byte, error)
    49  }
    50  
    51  type grpcCodec struct {
    52  	ThriftCodec remote.PayloadCodec
    53  }
    54  
    55  type CodecOption func(c *grpcCodec)
    56  
    57  func WithThriftCodec(t remote.PayloadCodec) CodecOption {
    58  	return func(c *grpcCodec) {
    59  		c.ThriftCodec = t
    60  	}
    61  }
    62  
    63  // NewGRPCCodec create grpc and protobuf codec
    64  func NewGRPCCodec(opts ...CodecOption) remote.Codec {
    65  	codec := &grpcCodec{}
    66  	for _, opt := range opts {
    67  		opt(codec)
    68  	}
    69  	if !thrift.IsThriftCodec(codec.ThriftCodec) {
    70  		codec.ThriftCodec = thrift.NewThriftCodec()
    71  	}
    72  	return codec
    73  }
    74  
    75  func mallocWithFirstByteZeroed(size int) []byte {
    76  	data := mcache.Malloc(size)
    77  	data[0] = 0 // compressed flag = false
    78  	return data
    79  }
    80  
    81  func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) {
    82  	var payload []byte
    83  	defer func() {
    84  		// record send size, even when err != nil (0 is recorded to the lastSendSize)
    85  		if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil {
    86  			rpcStats.IncrSendSize(uint64(len(payload)))
    87  		}
    88  	}()
    89  
    90  	writer, ok := out.(remote.FrameWrite)
    91  	if !ok {
    92  		return fmt.Errorf("output buffer must implement FrameWrite")
    93  	}
    94  	compressor, err := getSendCompressor(ctx)
    95  	if err != nil {
    96  		return err
    97  	}
    98  	isCompressed := compressor != nil
    99  
   100  	switch message.ProtocolInfo().CodecType {
   101  	case serviceinfo.Thrift:
   102  		payload, err = thrift.MarshalThriftData(ctx, c.ThriftCodec, message.Data())
   103  	case serviceinfo.Protobuf:
   104  		switch t := message.Data().(type) {
   105  		case fastpb.Writer:
   106  			size := t.Size()
   107  			if !isCompressed {
   108  				payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen)
   109  				t.FastWrite(payload[dataFrameHeaderLen:])
   110  				binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size))
   111  				return writer.WriteData(payload)
   112  			}
   113  			payload = mcache.Malloc(size)
   114  			t.FastWrite(payload)
   115  		case marshaler:
   116  			size := t.Size()
   117  			if !isCompressed {
   118  				payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen)
   119  				if _, err = t.MarshalTo(payload[dataFrameHeaderLen:]); err != nil {
   120  					return err
   121  				}
   122  				binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size))
   123  				return writer.WriteData(payload)
   124  			}
   125  			payload = mcache.Malloc(size)
   126  			if _, err = t.MarshalTo(payload); err != nil {
   127  				return err
   128  			}
   129  		case protobufV2MsgCodec:
   130  			payload, err = t.XXX_Marshal(nil, true)
   131  		case proto.Message:
   132  			payload, err = proto.Marshal(t)
   133  		case protobuf.ProtobufMsgCodec:
   134  			payload, err = t.Marshal(nil)
   135  		default:
   136  			return ErrInvalidPayload
   137  		}
   138  	default:
   139  		return ErrInvalidPayload
   140  	}
   141  
   142  	if err != nil {
   143  		return err
   144  	}
   145  	var header [dataFrameHeaderLen]byte
   146  	if isCompressed {
   147  		payload, err = compress(compressor, payload)
   148  		if err != nil {
   149  			return err
   150  		}
   151  		header[0] = 1
   152  	} else {
   153  		header[0] = 0
   154  	}
   155  	binary.BigEndian.PutUint32(header[1:dataFrameHeaderLen], uint32(len(payload)))
   156  	err = writer.WriteHeader(header[:])
   157  	if err != nil {
   158  		return err
   159  	}
   160  	return writer.WriteData(payload)
   161  	// TODO: recycle payload?
   162  }
   163  
   164  func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) {
   165  	d, err := decodeGRPCFrame(ctx, in)
   166  	if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil {
   167  		// record recv size, even when err != nil (0 is recorded to the lastRecvSize)
   168  		rpcStats.IncrRecvSize(uint64(len(d)))
   169  	}
   170  	if err != nil {
   171  		return err
   172  	}
   173  	message.SetPayloadLen(len(d))
   174  	data := message.Data()
   175  	switch message.ProtocolInfo().CodecType {
   176  	case serviceinfo.Thrift:
   177  		return thrift.UnmarshalThriftData(ctx, c.ThriftCodec, "", d, message.Data())
   178  	case serviceinfo.Protobuf:
   179  		if t, ok := data.(fastpb.Reader); ok {
   180  			if len(d) == 0 {
   181  				// if all fields of a struct is default value, data will be nil
   182  				// In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected.
   183  				// So, when data is nil, use default protobuf unmarshal method to decode the struct.
   184  				// todo: fix fastpb
   185  			} else {
   186  				_, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t)
   187  				return err
   188  			}
   189  		}
   190  		switch t := data.(type) {
   191  		case protobufV2MsgCodec:
   192  			return t.XXX_Unmarshal(d)
   193  		case proto.Message:
   194  			return proto.Unmarshal(d, t)
   195  		case protobuf.ProtobufMsgCodec:
   196  			return t.Unmarshal(d)
   197  		default:
   198  			return ErrInvalidPayload
   199  		}
   200  	default:
   201  		return ErrInvalidPayload
   202  	}
   203  }
   204  
   205  func (c *grpcCodec) Name() string {
   206  	return "grpc"
   207  }