github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/grpc/grpc.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package grpc 18 19 import ( 20 "context" 21 "encoding/binary" 22 "errors" 23 "fmt" 24 25 "github.com/bytedance/gopkg/lang/mcache" 26 "github.com/cloudwego/fastpb" 27 "google.golang.org/protobuf/proto" 28 29 "github.com/cloudwego/kitex/pkg/remote" 30 "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" 31 "github.com/cloudwego/kitex/pkg/remote/codec/thrift" 32 "github.com/cloudwego/kitex/pkg/rpcinfo" 33 "github.com/cloudwego/kitex/pkg/serviceinfo" 34 ) 35 36 const dataFrameHeaderLen = 5 37 38 var ErrInvalidPayload = errors.New("grpc invalid payload") 39 40 // gogoproto generate 41 type marshaler interface { 42 MarshalTo(data []byte) (n int, err error) 43 Size() int 44 } 45 46 type protobufV2MsgCodec interface { 47 XXX_Unmarshal(b []byte) error 48 XXX_Marshal(b []byte, deterministic bool) ([]byte, error) 49 } 50 51 type grpcCodec struct { 52 ThriftCodec remote.PayloadCodec 53 } 54 55 type CodecOption func(c *grpcCodec) 56 57 func WithThriftCodec(t remote.PayloadCodec) CodecOption { 58 return func(c *grpcCodec) { 59 c.ThriftCodec = t 60 } 61 } 62 63 // NewGRPCCodec create grpc and protobuf codec 64 func NewGRPCCodec(opts ...CodecOption) remote.Codec { 65 codec := &grpcCodec{} 66 for _, opt := range opts { 67 opt(codec) 68 } 69 if !thrift.IsThriftCodec(codec.ThriftCodec) { 70 codec.ThriftCodec = thrift.NewThriftCodec() 71 } 72 return codec 73 } 74 75 func mallocWithFirstByteZeroed(size int) []byte { 76 data := mcache.Malloc(size) 77 data[0] = 0 // compressed flag = false 78 return data 79 } 80 81 func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { 82 var payload []byte 83 defer func() { 84 // record send size, even when err != nil (0 is recorded to the lastSendSize) 85 if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil { 86 rpcStats.IncrSendSize(uint64(len(payload))) 87 } 88 }() 89 90 writer, ok := out.(remote.FrameWrite) 91 if !ok { 92 return fmt.Errorf("output buffer must implement FrameWrite") 93 } 94 compressor, err := getSendCompressor(ctx) 95 if err != nil { 96 return err 97 } 98 isCompressed := compressor != nil 99 100 switch message.ProtocolInfo().CodecType { 101 case serviceinfo.Thrift: 102 payload, err = thrift.MarshalThriftData(ctx, c.ThriftCodec, message.Data()) 103 case serviceinfo.Protobuf: 104 switch t := message.Data().(type) { 105 case fastpb.Writer: 106 size := t.Size() 107 if !isCompressed { 108 payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) 109 t.FastWrite(payload[dataFrameHeaderLen:]) 110 binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) 111 return writer.WriteData(payload) 112 } 113 payload = mcache.Malloc(size) 114 t.FastWrite(payload) 115 case marshaler: 116 size := t.Size() 117 if !isCompressed { 118 payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) 119 if _, err = t.MarshalTo(payload[dataFrameHeaderLen:]); err != nil { 120 return err 121 } 122 binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) 123 return writer.WriteData(payload) 124 } 125 payload = mcache.Malloc(size) 126 if _, err = t.MarshalTo(payload); err != nil { 127 return err 128 } 129 case protobufV2MsgCodec: 130 payload, err = t.XXX_Marshal(nil, true) 131 case proto.Message: 132 payload, err = proto.Marshal(t) 133 case protobuf.ProtobufMsgCodec: 134 payload, err = t.Marshal(nil) 135 default: 136 return ErrInvalidPayload 137 } 138 default: 139 return ErrInvalidPayload 140 } 141 142 if err != nil { 143 return err 144 } 145 var header [dataFrameHeaderLen]byte 146 if isCompressed { 147 payload, err = compress(compressor, payload) 148 if err != nil { 149 return err 150 } 151 header[0] = 1 152 } else { 153 header[0] = 0 154 } 155 binary.BigEndian.PutUint32(header[1:dataFrameHeaderLen], uint32(len(payload))) 156 err = writer.WriteHeader(header[:]) 157 if err != nil { 158 return err 159 } 160 return writer.WriteData(payload) 161 // TODO: recycle payload? 162 } 163 164 func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { 165 d, err := decodeGRPCFrame(ctx, in) 166 if rpcStats := rpcinfo.AsMutableRPCStats(message.RPCInfo().Stats()); rpcStats != nil { 167 // record recv size, even when err != nil (0 is recorded to the lastRecvSize) 168 rpcStats.IncrRecvSize(uint64(len(d))) 169 } 170 if err != nil { 171 return err 172 } 173 message.SetPayloadLen(len(d)) 174 data := message.Data() 175 switch message.ProtocolInfo().CodecType { 176 case serviceinfo.Thrift: 177 return thrift.UnmarshalThriftData(ctx, c.ThriftCodec, "", d, message.Data()) 178 case serviceinfo.Protobuf: 179 if t, ok := data.(fastpb.Reader); ok { 180 if len(d) == 0 { 181 // if all fields of a struct is default value, data will be nil 182 // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. 183 // So, when data is nil, use default protobuf unmarshal method to decode the struct. 184 // todo: fix fastpb 185 } else { 186 _, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t) 187 return err 188 } 189 } 190 switch t := data.(type) { 191 case protobufV2MsgCodec: 192 return t.XXX_Unmarshal(d) 193 case proto.Message: 194 return proto.Unmarshal(d, t) 195 case protobuf.ProtobufMsgCodec: 196 return t.Unmarshal(d) 197 default: 198 return ErrInvalidPayload 199 } 200 default: 201 return ErrInvalidPayload 202 } 203 } 204 205 func (c *grpcCodec) Name() string { 206 return "grpc" 207 }