github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/protobuf/protobuf.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package protobuf 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 24 "github.com/cloudwego/fastpb" 25 26 "github.com/cloudwego/kitex/pkg/remote" 27 "github.com/cloudwego/kitex/pkg/remote/codec" 28 "github.com/cloudwego/kitex/pkg/remote/codec/perrors" 29 ) 30 31 /** 32 * Kitex Protobuf Protocol 33 * |----------Len--------|--------------------------------MetaInfo--------------------------------| 34 * |---------4Byte-------|----2Byte----|----2Byte----|---------String-------|---------4Byte-------| 35 * +----------------------------------------------------------------------------------------------+ 36 * | PayloadLen | Magic | MsgType | MethodName | SeqID | 37 * +----------------------------------------------------------------------------------------------+ 38 * | | 39 * | Protobuf Argument/Result/Error | 40 * | | 41 * +----------------------------------------------------------------------------------------------+ 42 */ 43 44 const ( 45 metaInfoFixLen = 8 46 ) 47 48 // NewProtobufCodec ... 49 func NewProtobufCodec() remote.PayloadCodec { 50 return &protobufCodec{} 51 } 52 53 // protobufCodec implements PayloadMarshaler 54 type protobufCodec struct{} 55 56 // Len encode outside not here 57 func (c protobufCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { 58 // 1. prepare info 59 methodName := message.RPCInfo().Invocation().MethodName() 60 if methodName == "" { 61 return errors.New("empty methodName in protobuf Marshal") 62 } 63 data, err := getValidData(methodName, message) 64 if err != nil { 65 return err 66 } 67 68 // 3. encode metainfo 69 // 3.1 magic && msgType 70 if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(message.MessageType()), out); err != nil { 71 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write meta info failed: %s", err.Error())) 72 } 73 // 3.2 methodName 74 if _, err := codec.WriteString(methodName, out); err != nil { 75 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write method name failed: %s", err.Error())) 76 } 77 // 3.3 seqID 78 if err := codec.WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out); err != nil { 79 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write seqID failed: %s", err.Error())) 80 } 81 82 // 4. write actual message buf 83 msg, ok := data.(ProtobufMsgCodec) 84 if !ok { 85 // If Using Generics 86 // if data is a MessageWriterWithContext 87 // Do msg.WritePb(ctx context.Context, out remote.ByteBuffer) 88 genmsg, isgen := data.(MessageWriterWithContext) 89 if isgen { 90 actualMsg, err := genmsg.WritePb(ctx) 91 if err != nil { 92 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error())) 93 } 94 actualMsgBuf, ok := actualMsg.([]byte) 95 if !ok { 96 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error())) 97 } 98 _, err = out.WriteBinary(actualMsgBuf) 99 if err != nil { 100 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write message buffer failed: %s", err.Error())) 101 } 102 return nil 103 } 104 // return error otherwise 105 return remote.NewTransErrorWithMsg(remote.InvalidProtocol, "encode failed, codec msg type not match with protobufCodec") 106 } 107 108 // 2. encode pb struct 109 // fast write 110 if msg, ok := data.(fastpb.Writer); ok { 111 msgsize := msg.Size() 112 actualMsgBuf, err := out.Malloc(msgsize) 113 if err != nil { 114 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf malloc size %d failed: %s", msgsize, err.Error())) 115 } 116 msg.FastWrite(actualMsgBuf) 117 return nil 118 } 119 120 var actualMsgBuf []byte 121 if actualMsgBuf, err = msg.Marshal(actualMsgBuf); err != nil { 122 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error())) 123 } 124 if _, err = out.WriteBinary(actualMsgBuf); err != nil { 125 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write message buffer failed: %s", err.Error())) 126 } 127 return nil 128 } 129 130 func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 131 payloadLen := message.PayloadLen() 132 magicAndMsgType, err := codec.ReadUint32(in) 133 if err != nil { 134 return err 135 } 136 if magicAndMsgType&codec.MagicMask != codec.ProtobufV1Magic { 137 return perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in protobuf Unmarshal") 138 } 139 msgType := magicAndMsgType & codec.FrontMask 140 if err := codec.UpdateMsgType(msgType, message); err != nil { 141 return err 142 } 143 144 methodName, methodFieldLen, err := codec.ReadString(in) 145 if err != nil { 146 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read method name failed: %s", err.Error())) 147 } 148 if err = codec.SetOrCheckMethodName(methodName, message); err != nil && msgType != uint32(remote.Exception) { 149 return err 150 } 151 seqID, err := codec.ReadUint32(in) 152 if err != nil { 153 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read seqID failed: %s", err.Error())) 154 } 155 if err = codec.SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) { 156 return err 157 } 158 actualMsgLen := payloadLen - metaInfoFixLen - methodFieldLen 159 actualMsgBuf, err := in.Next(actualMsgLen) 160 if err != nil { 161 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read message buffer failed: %s", err.Error())) 162 } 163 // exception message 164 if message.MessageType() == remote.Exception { 165 var exception pbError 166 if err := exception.Unmarshal(actualMsgBuf); err != nil { 167 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf unmarshal Exception failed: %s", err.Error())) 168 } 169 return remote.NewTransError(exception.TypeID(), &exception) 170 } 171 172 if err = codec.NewDataIfNeeded(methodName, message); err != nil { 173 return err 174 } 175 data := message.Data() 176 177 // fast read 178 if msg, ok := data.(fastpb.Reader); ok { 179 if len(actualMsgBuf) == 0 { 180 // if all fields of a struct is default value, actualMsgLen will be zero and actualMsgBuf will be nil 181 // In the implementation of fastpb, if actualMsgBuf is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. 182 // So, when actualMsgBuf is nil, use default protobuf unmarshal method to decode the struct. 183 // todo: fix fastpb 184 } else { 185 _, err := fastpb.ReadMessage(actualMsgBuf, fastpb.SkipTypeCheck, msg) 186 if err != nil { 187 return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error()) 188 } 189 return nil 190 } 191 } 192 193 // JSONPB Generic Case 194 if msg, ok := data.(MessageReaderWithMethodWithContext); ok { 195 err := msg.ReadPb(ctx, methodName, actualMsgBuf) 196 if err != nil { 197 return err 198 } 199 return nil 200 } 201 202 msg, ok := data.(ProtobufMsgCodec) 203 if !ok { 204 return remote.NewTransErrorWithMsg(remote.InvalidProtocol, "decode failed, codec msg type not match with protobufCodec") 205 } 206 if err = msg.Unmarshal(actualMsgBuf); err != nil { 207 return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error()) 208 } 209 return err 210 } 211 212 func (c protobufCodec) Name() string { 213 return "protobuf" 214 } 215 216 // MessageWriterWithContext writes to output bytebuffer 217 type MessageWriterWithContext interface { 218 WritePb(ctx context.Context) (interface{}, error) 219 } 220 221 // MessageReaderWithMethodWithContext read from ActualMsgBuf with method 222 type MessageReaderWithMethodWithContext interface { 223 ReadPb(ctx context.Context, method string, in []byte) error 224 } 225 226 type ProtobufMsgCodec interface { 227 Marshal(out []byte) ([]byte, error) 228 Unmarshal(in []byte) error 229 } 230 231 func getValidData(methodName string, message remote.Message) (interface{}, error) { 232 if err := codec.NewDataIfNeeded(methodName, message); err != nil { 233 return nil, err 234 } 235 data := message.Data() 236 if message.MessageType() != remote.Exception { 237 return data, nil 238 } 239 transErr, isTransErr := data.(*remote.TransError) 240 if !isTransErr { 241 if err, isError := data.(error); isError { 242 encodeErr := NewPbError(remote.InternalError, err.Error()) 243 return encodeErr, nil 244 } 245 return nil, errors.New("exception relay need error type data") 246 } 247 encodeErr := NewPbError(transErr.TypeID(), transErr.Error()) 248 return encodeErr, nil 249 }