github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/default_codec.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package codec 18 19 import ( 20 "context" 21 "encoding/binary" 22 "fmt" 23 "sync/atomic" 24 25 "github.com/cloudwego/kitex/pkg/kerrors" 26 "github.com/cloudwego/kitex/pkg/remote" 27 "github.com/cloudwego/kitex/pkg/remote/codec/perrors" 28 "github.com/cloudwego/kitex/pkg/retry" 29 "github.com/cloudwego/kitex/pkg/rpcinfo" 30 "github.com/cloudwego/kitex/pkg/serviceinfo" 31 "github.com/cloudwego/kitex/transport" 32 ) 33 34 // The byte count of 32 and 16 integer values. 35 const ( 36 Size32 = 4 37 Size16 = 2 38 ) 39 40 const ( 41 // ThriftV1Magic is the magic code for thrift.VERSION_1 42 ThriftV1Magic = 0x80010000 43 // ProtobufV1Magic is the magic code for kitex protobuf 44 ProtobufV1Magic = 0x90010000 45 46 // MagicMask is bit mask for checking version. 47 MagicMask = 0xffff0000 48 ) 49 50 var ( 51 ttHeaderCodec = ttHeader{} 52 meshHeaderCodec = meshHeader{} 53 54 _ remote.Codec = (*defaultCodec)(nil) 55 _ remote.MetaDecoder = (*defaultCodec)(nil) 56 ) 57 58 // NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf. 59 func NewDefaultCodec() remote.Codec { 60 // No size limit by default 61 return &defaultCodec{ 62 maxSize: 0, 63 } 64 } 65 66 // NewDefaultCodecWithSizeLimit creates the default protocol sniffing codec supporting thrift and protobuf but with size limit. 67 // maxSize is in bytes 68 func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec { 69 return &defaultCodec{ 70 maxSize: maxSize, 71 } 72 } 73 74 type defaultCodec struct { 75 // maxSize limits the max size of the payload 76 maxSize int 77 } 78 79 // EncodePayload encode payload 80 func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { 81 defer func() { 82 // notice: mallocLen() must exec before flush, or it will be reset 83 if ri := message.RPCInfo(); ri != nil { 84 if ms := rpcinfo.AsMutableRPCStats(ri.Stats()); ms != nil { 85 ms.SetSendSize(uint64(out.MallocLen())) 86 } 87 } 88 }() 89 var err error 90 var framedLenField []byte 91 headerLen := out.MallocLen() 92 tp := message.ProtocolInfo().TransProto 93 94 // 1. malloc framed field if needed 95 if tp&transport.Framed == transport.Framed { 96 if framedLenField, err = out.Malloc(Size32); err != nil { 97 return err 98 } 99 headerLen += Size32 100 } 101 102 // 2. encode payload 103 if err = c.encodePayload(ctx, message, out); err != nil { 104 return err 105 } 106 107 // 3. fill framed field if needed 108 var payloadLen int 109 if tp&transport.Framed == transport.Framed { 110 if framedLenField == nil { 111 return perrors.NewProtocolErrorWithMsg("no buffer allocated for the framed length field") 112 } 113 payloadLen = out.MallocLen() - headerLen 114 binary.BigEndian.PutUint32(framedLenField, uint32(payloadLen)) 115 } else if message.ProtocolInfo().CodecType == serviceinfo.Protobuf { 116 return perrors.NewProtocolErrorWithMsg("protobuf just support 'framed' trans proto") 117 } 118 if tp&transport.TTHeader == transport.TTHeader { 119 payloadLen = out.MallocLen() - Size32 120 } 121 err = checkPayloadSize(payloadLen, c.maxSize) 122 return err 123 } 124 125 // EncodeMetaAndPayload encode meta and payload 126 func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error { 127 var err error 128 var totalLenField []byte 129 tp := message.ProtocolInfo().TransProto 130 131 // 1. encode header and return totalLenField if needed 132 // totalLenField will be filled after payload encoded 133 if tp&transport.TTHeader == transport.TTHeader { 134 if totalLenField, err = ttHeaderCodec.encode(ctx, message, out); err != nil { 135 return err 136 } 137 } 138 // 2. encode payload 139 if err = me.EncodePayload(ctx, message, out); err != nil { 140 return err 141 } 142 // 3. fill totalLen field for header if needed 143 if tp&transport.TTHeader == transport.TTHeader { 144 if totalLenField == nil { 145 return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") 146 } 147 payloadLen := out.MallocLen() - Size32 148 binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) 149 } 150 return nil 151 } 152 153 // Encode implements the remote.Codec interface, it does complete message encode include header and payload. 154 func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) { 155 return c.EncodeMetaAndPayload(ctx, message, out, c) 156 } 157 158 // DecodeMeta decode header 159 func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { 160 var flagBuf []byte 161 if flagBuf, err = in.Peek(2 * Size32); err != nil { 162 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("default codec read failed: %s", err.Error())) 163 } 164 165 if err = checkRPCState(ctx, message); err != nil { 166 // there is one call has finished in retry task, it doesn't need to do decode for this call 167 return err 168 } 169 isTTHeader := IsTTHeader(flagBuf) 170 // 1. decode header 171 if isTTHeader { 172 // TTHeader 173 if err = ttHeaderCodec.decode(ctx, message, in); err != nil { 174 return err 175 } 176 if flagBuf, err = in.Peek(2 * Size32); err != nil { 177 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error())) 178 } 179 } else if isMeshHeader(flagBuf) { 180 message.Tags()[remote.MeshHeader] = true 181 // MeshHeader 182 if err = meshHeaderCodec.decode(ctx, message, in); err != nil { 183 return err 184 } 185 if flagBuf, err = in.Peek(2 * Size32); err != nil { 186 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("meshHeader read payload first 8 byte failed: %s", err.Error())) 187 } 188 } 189 return checkPayload(flagBuf, message, in, isTTHeader, c.maxSize) 190 } 191 192 // DecodePayload decode payload 193 func (c *defaultCodec) DecodePayload(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 194 defer func() { 195 if ri := message.RPCInfo(); ri != nil { 196 if ms := rpcinfo.AsMutableRPCStats(ri.Stats()); ms != nil { 197 ms.SetRecvSize(uint64(in.ReadLen())) 198 } 199 } 200 }() 201 202 hasRead := in.ReadLen() 203 pCodec, err := remote.GetPayloadCodec(message) 204 if err != nil { 205 return err 206 } 207 if err = pCodec.Unmarshal(ctx, message, in); err != nil { 208 return err 209 } 210 if message.PayloadLen() == 0 { 211 // if protocol is PurePayload, should set payload length after decoded 212 message.SetPayloadLen(in.ReadLen() - hasRead) 213 } 214 return nil 215 } 216 217 // Decode implements the remote.Codec interface, it does complete message decode include header and payload. 218 func (c *defaultCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) { 219 // 1. decode meta 220 if err = c.DecodeMeta(ctx, message, in); err != nil { 221 return err 222 } 223 224 // 2. decode payload 225 return c.DecodePayload(ctx, message, in) 226 } 227 228 func (c *defaultCodec) Name() string { 229 return "default" 230 } 231 232 // Select to use thrift or protobuf according to the protocol. 233 func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { 234 pCodec, err := remote.GetPayloadCodec(message) 235 if err != nil { 236 return err 237 } 238 return pCodec.Marshal(ctx, message, out) 239 } 240 241 /** 242 * +------------------------------------------------------------+ 243 * | 4Byte | 2Byte | 244 * +------------------------------------------------------------+ 245 * | Length | HEADER MAGIC | 246 * +------------------------------------------------------------+ 247 */ 248 func IsTTHeader(flagBuf []byte) bool { 249 return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == TTHeaderMagic 250 } 251 252 /** 253 * +----------------------------------------+ 254 * | 2Byte | 2Byte | 255 * +----------------------------------------+ 256 * | HEADER MAGIC | HEADER SIZE | 257 * +----------------------------------------+ 258 */ 259 func isMeshHeader(flagBuf []byte) bool { 260 return binary.BigEndian.Uint32(flagBuf[:Size32])&MagicMask == MeshHeaderMagic 261 } 262 263 /** 264 * Kitex protobuf has framed field 265 * +------------------------------------------------------------+ 266 * | 4Byte | 2Byte | 267 * +------------------------------------------------------------+ 268 * | Length | HEADER MAGIC | 269 * +------------------------------------------------------------+ 270 */ 271 func isProtobufKitex(flagBuf []byte) bool { 272 return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == ProtobufV1Magic 273 } 274 275 /** 276 * +-------------------+ 277 * | 2Byte | 278 * +-------------------+ 279 * | HEADER MAGIC | 280 * +------------------- 281 */ 282 func isThriftBinary(flagBuf []byte) bool { 283 return binary.BigEndian.Uint32(flagBuf[:Size32])&MagicMask == ThriftV1Magic 284 } 285 286 /** 287 * +------------------------------------------------------------+ 288 * | 4Byte | 2Byte | 289 * +------------------------------------------------------------+ 290 * | Length | HEADER MAGIC | 291 * +------------------------------------------------------------+ 292 */ 293 func isThriftFramedBinary(flagBuf []byte) bool { 294 return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == ThriftV1Magic 295 } 296 297 func checkRPCState(ctx context.Context, message remote.Message) error { 298 if message.RPCRole() == remote.Server { 299 return nil 300 } 301 if ctx.Err() == context.DeadlineExceeded || ctx.Err() == context.Canceled { 302 return kerrors.ErrRPCFinish 303 } 304 if respOp, ok := ctx.Value(retry.CtxRespOp).(*int32); ok { 305 if !atomic.CompareAndSwapInt32(respOp, retry.OpNo, retry.OpDoing) { 306 // previous call is being handling or done 307 // this flag is used to check request status in retry(backup request) scene 308 return kerrors.ErrRPCFinish 309 } 310 } 311 return nil 312 } 313 314 func checkPayload(flagBuf []byte, message remote.Message, in remote.ByteBuffer, isTTHeader bool, maxPayloadSize int) error { 315 var transProto transport.Protocol 316 var codecType serviceinfo.PayloadCodec 317 if isThriftBinary(flagBuf) { 318 codecType = serviceinfo.Thrift 319 if isTTHeader { 320 transProto = transport.TTHeader 321 } else { 322 transProto = transport.PurePayload 323 } 324 } else if isThriftFramedBinary(flagBuf) { 325 codecType = serviceinfo.Thrift 326 if isTTHeader { 327 transProto = transport.TTHeaderFramed 328 } else { 329 transProto = transport.Framed 330 } 331 payloadLen := binary.BigEndian.Uint32(flagBuf[:Size32]) 332 message.SetPayloadLen(int(payloadLen)) 333 if err := in.Skip(Size32); err != nil { 334 return err 335 } 336 } else if isProtobufKitex(flagBuf) { 337 codecType = serviceinfo.Protobuf 338 if isTTHeader { 339 transProto = transport.TTHeaderFramed 340 } else { 341 transProto = transport.Framed 342 } 343 payloadLen := binary.BigEndian.Uint32(flagBuf[:Size32]) 344 message.SetPayloadLen(int(payloadLen)) 345 if err := in.Skip(Size32); err != nil { 346 return err 347 } 348 } else { 349 first4Bytes := binary.BigEndian.Uint32(flagBuf[:Size32]) 350 second4Bytes := binary.BigEndian.Uint32(flagBuf[Size32:]) 351 // 0xfff4fffd is the interrupt message of telnet 352 err := perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid payload (first4Bytes=%#x, second4Bytes=%#x)", first4Bytes, second4Bytes)) 353 return err 354 } 355 if err := checkPayloadSize(message.PayloadLen(), maxPayloadSize); err != nil { 356 return err 357 } 358 message.SetProtocolInfo(remote.NewProtocolInfo(transProto, codecType)) 359 cfg := rpcinfo.AsMutableRPCConfig(message.RPCInfo().Config()) 360 if cfg != nil { 361 tp := message.ProtocolInfo().TransProto 362 cfg.SetTransportProtocol(tp) 363 } 364 return nil 365 } 366 367 func checkPayloadSize(payloadLen, maxSize int) error { 368 if maxSize > 0 && payloadLen > 0 && payloadLen > maxSize { 369 return perrors.NewProtocolErrorWithType( 370 perrors.InvalidData, 371 fmt.Sprintf("invalid data: payload size(%d) larger than the limit(%d)", payloadLen, maxSize), 372 ) 373 } 374 return nil 375 }