trpc.group/trpc-go/trpc-go@v1.0.2/codec_stream.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc 15 16 import ( 17 "errors" 18 "fmt" 19 "os" 20 "path" 21 "sync" 22 23 "trpc.group/trpc-go/trpc-go/codec" 24 "trpc.group/trpc-go/trpc-go/errs" 25 "trpc.group/trpc-go/trpc-go/internal/addrutil" 26 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 27 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 28 29 "google.golang.org/protobuf/proto" 30 ) 31 32 var ( 33 // error for unknown stream frame type 34 errUnknownFrameType error = errors.New("unknown stream frame type") 35 // error for invalid total length of client decoding 36 errClientDecodeTotalLength error = errors.New("client decode total length invalid") 37 // error for failing to encode Close frame 38 errEncodeCloseFrame error = errors.New("encode close frame error") 39 // error for failing to encode Feedback frame 40 errEncodeFeedbackFrame error = errors.New("encode feedback error") 41 // error for init metadata not found 42 errUninitializedMeta error = errors.New("uninitialized meta") 43 // error for invalid trpc framehead 44 errFrameHeadTypeInvalid error = errors.New("framehead type invalid") 45 ) 46 47 // NewServerStreamCodec initializes and returns a ServerStreamCodec. 48 func NewServerStreamCodec() *ServerStreamCodec { 49 return &ServerStreamCodec{initMetas: make(map[string]map[uint32]*trpcpb.TrpcStreamInitMeta), m: &sync.RWMutex{}} 50 } 51 52 // NewClientStreamCodec initializes and returns a ClientStreamCodec. 53 func NewClientStreamCodec() *ClientStreamCodec { 54 return &ClientStreamCodec{} 55 } 56 57 // ServerStreamCodec is an implementation of codec.Codec. 58 // Used for trpc server streaming codec. 59 type ServerStreamCodec struct { 60 m *sync.RWMutex 61 initMetas map[string]map[uint32]*trpcpb.TrpcStreamInitMeta // addr->streamID->TrpcStreamInitMeta 62 } 63 64 // ClientStreamCodec is an implementation of codec.Codec. 65 // Used for trpc client streaming codec. 66 type ClientStreamCodec struct { 67 } 68 69 // Encode implements codec.Codec. 70 func (c *ClientStreamCodec) Encode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 71 frameHead, ok := msg.FrameHead().(*FrameHead) 72 if !ok || !frameHead.isStream() { 73 return nil, errUnknownFrameType 74 } 75 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 76 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 77 return c.encodeInitFrame(frameHead, msg, reqBuf) 78 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 79 return c.encodeDataFrame(frameHead, msg, reqBuf) 80 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 81 return c.encodeCloseFrame(frameHead, msg, reqBuf) 82 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 83 return c.encodeFeedbackFrame(frameHead, msg, reqBuf) 84 default: 85 return nil, errUnknownFrameType 86 } 87 } 88 89 // Decode implements codec.Codec. 90 func (c *ClientStreamCodec) Decode(msg codec.Msg, rspBuf []byte) ([]byte, error) { 91 frameHead, ok := msg.FrameHead().(*FrameHead) 92 if !ok || !frameHead.isStream() { 93 return nil, errUnknownFrameType 94 } 95 96 msg.WithStreamID(frameHead.StreamID) 97 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 98 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 99 return c.decodeInitFrame(msg, rspBuf) 100 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 101 return c.decodeDataFrame(msg, rspBuf) 102 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 103 return c.decodeCloseFrame(msg, rspBuf) 104 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 105 return c.decodeFeedbackFrame(msg, rspBuf) 106 default: 107 return nil, errUnknownFrameType 108 } 109 } 110 111 // decodeCloseFrame decodes the Close frame. 112 func (c *ClientStreamCodec) decodeCloseFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 113 // unmarshal Close frame 114 close := &trpcpb.TrpcStreamCloseMeta{} 115 if err := proto.Unmarshal(rspBuf[frameHeadLen:], close); err != nil { 116 return nil, err 117 } 118 119 // It is considered an exception and an error should be returned to the client if: 120 // 1. the CloseType is Reset 121 // 2. ret code != 0 122 if close.GetCloseType() == int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET) || close.GetRet() != 0 { 123 e := &errs.Error{ 124 Type: errs.ErrorTypeCalleeFramework, 125 Code: trpcpb.TrpcRetCode(close.GetRet()), 126 Desc: "trpc", 127 Msg: string(close.GetMsg()), 128 } 129 msg.WithClientRspErr(e) 130 } 131 msg.WithStreamFrame(close) 132 return nil, nil 133 } 134 135 // decodeFeedbackFrame decodes the Feedback frame. 136 func (c *ClientStreamCodec) decodeFeedbackFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 137 feedback := &trpcpb.TrpcStreamFeedBackMeta{} 138 if err := proto.Unmarshal(rspBuf[frameHeadLen:], feedback); err != nil { 139 return nil, err 140 } 141 msg.WithStreamFrame(feedback) 142 return nil, nil 143 } 144 145 // decodeInitFrame decodes the Init frame. 146 func (c *ClientStreamCodec) decodeInitFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 147 initMeta := &trpcpb.TrpcStreamInitMeta{} 148 if err := proto.Unmarshal(rspBuf[frameHeadLen:], initMeta); err != nil { 149 return nil, err 150 } 151 152 msg.WithCompressType(int(initMeta.GetContentEncoding())) 153 msg.WithSerializationType(int(initMeta.GetContentType())) 154 155 // if ret code is not 0, an error should be set and returned 156 if initMeta.GetResponseMeta().GetRet() != 0 { 157 e := &errs.Error{ 158 Type: errs.ErrorTypeCalleeFramework, 159 Code: trpcpb.TrpcRetCode(initMeta.GetResponseMeta().GetRet()), 160 Desc: "trpc", 161 Msg: string(initMeta.GetResponseMeta().GetErrorMsg()), 162 } 163 msg.WithClientRspErr(e) 164 } 165 msg.WithStreamFrame(initMeta) 166 return nil, nil 167 168 } 169 170 // decodeDataFrame decodes the Data frame. 171 func (c *ClientStreamCodec) decodeDataFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 172 // decoding Data frame is straightforward, 173 // as it just returns all data following the frame head 174 return rspBuf[frameHeadLen:], nil 175 } 176 177 // encodeInitFrame encodes the Init frame. 178 func (c *ClientStreamCodec) encodeInitFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 179 initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 180 if !ok { 181 initMeta = &trpcpb.TrpcStreamInitMeta{} 182 initMeta.RequestMeta = &trpcpb.TrpcStreamInitRequestMeta{} 183 } 184 req := initMeta.RequestMeta 185 // set caller service name 186 // if nil, use the name of the process 187 if msg.CallerServiceName() == "" { 188 msg.WithCallerServiceName(fmt.Sprintf("trpc.app.%s.service", path.Base(os.Args[0]))) 189 } 190 req.Caller = []byte(msg.CallerServiceName()) 191 // set callee service name 192 req.Callee = []byte(msg.CalleeServiceName()) 193 // set backend rpc name, ClientRPCName already set by upper layer of client stub 194 req.Func = []byte(msg.ClientRPCName()) 195 // set backend serialization type 196 initMeta.ContentType = uint32(msg.SerializationType()) 197 // set backend compression type 198 initMeta.ContentEncoding = uint32(msg.CompressType()) 199 // set dyeing info 200 if msg.Dyeing() { 201 req.MessageType = req.MessageType | uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE) 202 } 203 // set client transinfo 204 req.TransInfo = setClientTransInfo(msg, req.TransInfo) 205 streamBuf, err := proto.Marshal(initMeta) 206 if err != nil { 207 return nil, err 208 } 209 return frameWrite(frameHead, streamBuf) 210 } 211 212 // encodeDataFrame encodes the Data frame. 213 func (c *ClientStreamCodec) encodeDataFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 214 return frameWrite(frameHead, reqBuf) 215 } 216 217 // encodeCloseFrame encodes the Close frame. 218 func (c *ClientStreamCodec) encodeCloseFrame(frameHead *FrameHead, msg codec.Msg, 219 reqBuf []byte) (rspbuf []byte, err error) { 220 closeFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamCloseMeta) 221 if !ok { 222 return nil, errEncodeCloseFrame 223 } 224 streamBuf, err := proto.Marshal(closeFrame) 225 if err != nil { 226 return nil, err 227 } 228 return frameWrite(frameHead, streamBuf) 229 } 230 231 // encodeFeedbackFrame encodes the Feedback frame. 232 func (c *ClientStreamCodec) encodeFeedbackFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 233 feedbackFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 234 if !ok { 235 return nil, errEncodeFeedbackFrame 236 } 237 streamBuf, err := proto.Marshal(feedbackFrame) 238 if err != nil { 239 return nil, err 240 } 241 return frameWrite(frameHead, streamBuf) 242 } 243 244 // frameWrite converts FrameHead to binary frame. 245 func frameWrite(frameHead *FrameHead, streamBuf []byte) ([]byte, error) { 246 // no pb header for streaming rpc 247 return frameHead.construct(nil, streamBuf, nil) 248 } 249 250 // encodeCloseFrame encodes the Close frame. 251 func (s *ServerStreamCodec) encodeCloseFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 252 defer s.deleteInitMeta(msg) 253 closeFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamCloseMeta) 254 if !ok { 255 return nil, errEncodeCloseFrame 256 } 257 msg.WithStreamID(frameHead.StreamID) 258 streamBuf, err := proto.Marshal(closeFrame) 259 if err != nil { 260 return nil, err 261 } 262 return frameWrite(frameHead, streamBuf) 263 } 264 265 // encodeDataFrame encodes the Data frame. 266 func (s *ServerStreamCodec) encodeDataFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 267 // If there is an error when processing the Data frame, 268 // then return the Close frame and close the current stream. 269 if err := msg.ServerRspErr(); err != nil { 270 s.buildResetFrame(msg, frameHead, err) 271 return s.encodeCloseFrame(frameHead, msg, reqBuf) 272 } 273 return frameWrite(frameHead, reqBuf) 274 } 275 276 // encodeInitFrame encodes the Init frame. 277 func (s *ServerStreamCodec) encodeInitFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 278 rsp := getStreamInitMeta(msg) 279 rsp.ContentType = uint32(msg.SerializationType()) 280 rsp.ContentEncoding = uint32(msg.CompressType()) 281 rspMeta := &trpcpb.TrpcStreamInitResponseMeta{} 282 if e := msg.ServerRspErr(); e != nil { 283 rspMeta.Ret = int32(e.Code) 284 rspMeta.ErrorMsg = []byte(e.Msg) 285 } 286 rsp.ResponseMeta = rspMeta 287 streamBuf, err := proto.Marshal(rsp) 288 if err != nil { 289 return nil, err 290 } 291 return frameWrite(frameHead, streamBuf) 292 } 293 294 // encodeFeedbackFrame encodes the Feedback frame. 295 func (s *ServerStreamCodec) encodeFeedbackFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 296 feedback, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 297 if !ok { 298 return nil, errEncodeFeedbackFrame 299 } 300 streamBuf, err := proto.Marshal(feedback) 301 if err != nil { 302 return nil, err 303 } 304 return frameWrite(frameHead, streamBuf) 305 } 306 307 // getStreamInitMeta returns TrpcStreamInitMeta from msg. 308 // If not found, a new TrpcStreamInitMeta will be created and returned. 309 func getStreamInitMeta(msg codec.Msg) *trpcpb.TrpcStreamInitMeta { 310 rsp, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 311 if !ok { 312 rsp = &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}} 313 } 314 return rsp 315 } 316 317 // Encode implements codec.Codec. 318 func (s *ServerStreamCodec) Encode(msg codec.Msg, reqBuf []byte) (rspbuf []byte, err error) { 319 frameHead, ok := msg.FrameHead().(*FrameHead) 320 if !ok || !frameHead.isStream() { 321 return nil, errUnknownFrameType 322 } 323 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 324 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 325 return s.encodeInitFrame(frameHead, msg, reqBuf) 326 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 327 return s.encodeDataFrame(frameHead, msg, reqBuf) 328 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 329 return s.encodeCloseFrame(frameHead, msg, reqBuf) 330 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 331 return s.encodeFeedbackFrame(frameHead, msg, reqBuf) 332 default: 333 return nil, errUnknownFrameType 334 } 335 } 336 337 // Decode implements codec.Codec. 338 // It decodes the head and the stream frame data. 339 func (s *ServerStreamCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 340 frameHead, ok := msg.FrameHead().(*FrameHead) 341 if !ok || !frameHead.isStream() { 342 return nil, errUnknownFrameType 343 } 344 msg.WithStreamID(frameHead.StreamID) 345 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 346 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 347 return s.decodeInitFrame(msg, reqBuf) 348 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 349 return s.decodeDataFrame(msg, reqBuf) 350 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 351 return s.decodeCloseFrame(msg, reqBuf) 352 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 353 return s.decodeFeedbackFrame(msg, reqBuf) 354 default: 355 return nil, errUnknownFrameType 356 } 357 } 358 359 // decodeFeedbackFrame decodes the Feedback frame. 360 func (s *ServerStreamCodec) decodeFeedbackFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 361 if err := s.setInitMeta(msg); err != nil { 362 return nil, err 363 } 364 feedback := &trpcpb.TrpcStreamFeedBackMeta{} 365 if err := proto.Unmarshal(reqBuf[frameHeadLen:], feedback); err != nil { 366 return nil, err 367 } 368 msg.WithStreamFrame(feedback) 369 return nil, nil 370 } 371 372 // setInitMeta finds the InitMeta and sets the ServerRPCName by the server handler in the InitMeta. 373 func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error { 374 streamID := msg.StreamID() 375 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 376 s.m.RLock() 377 defer s.m.RUnlock() 378 if streamIDToInitMeta, ok := s.initMetas[addr]; ok { 379 if initMeta, ok := streamIDToInitMeta[streamID]; ok { 380 rpcName := string(initMeta.GetRequestMeta().GetFunc()) 381 msg.WithServerRPCName(rpcName) 382 msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) 383 return nil 384 } 385 } 386 return errUninitializedMeta 387 } 388 389 // deleteInitMeta deletes the cached info by msg. 390 func (s *ServerStreamCodec) deleteInitMeta(msg codec.Msg) { 391 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 392 streamID := msg.StreamID() 393 s.m.Lock() 394 defer s.m.Unlock() 395 delete(s.initMetas[addr], streamID) 396 if len(s.initMetas[addr]) == 0 { 397 delete(s.initMetas, addr) 398 } 399 } 400 401 // decodeCloseFrame decodes the Close frame. 402 func (s *ServerStreamCodec) decodeCloseFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 403 if err := s.setInitMeta(msg); err != nil { 404 return nil, err 405 } 406 close := &trpcpb.TrpcStreamCloseMeta{} 407 if err := proto.Unmarshal(rspBuf[frameHeadLen:], close); err != nil { 408 return nil, err 409 } 410 // It is considered an exception and an error should be returned to the client if: 411 // 1. the CloseType is Reset 412 // 2. ret code != 0 413 if close.GetCloseType() == int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET) || close.GetRet() != 0 { 414 e := &errs.Error{ 415 Type: errs.ErrorTypeCalleeFramework, 416 Code: trpcpb.TrpcRetCode(close.GetRet()), 417 Desc: "trpc", 418 Msg: string(close.GetMsg()), 419 } 420 msg.WithServerRspErr(e) 421 } 422 msg.WithStreamFrame(close) 423 return nil, nil 424 } 425 426 // decodeDataFrame decodes the Data frame. 427 func (s *ServerStreamCodec) decodeDataFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 428 if err := s.setInitMeta(msg); err != nil { 429 return nil, err 430 } 431 reqBody := reqBuf[frameHeadLen:] 432 return reqBody, nil 433 } 434 435 // decodeInitFrame decodes the Init frame. 436 func (s *ServerStreamCodec) decodeInitFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 437 initMeta := &trpcpb.TrpcStreamInitMeta{} 438 if err := proto.Unmarshal(reqBuf[frameHeadLen:], initMeta); err != nil { 439 return nil, err 440 } 441 s.updateMsg(msg, initMeta) 442 s.storeInitMeta(msg, initMeta) 443 msg.WithStreamFrame(initMeta) 444 return nil, nil 445 } 446 447 // storeInitMeta stores the InitMeta every time when a new frame is received. 448 func (s *ServerStreamCodec) storeInitMeta(msg codec.Msg, initMeta *trpcpb.TrpcStreamInitMeta) { 449 streamID := msg.StreamID() 450 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 451 s.m.Lock() 452 defer s.m.Unlock() 453 if _, ok := s.initMetas[addr]; ok { 454 s.initMetas[addr][streamID] = initMeta 455 } else { 456 t := make(map[uint32]*trpcpb.TrpcStreamInitMeta) 457 t[streamID] = initMeta 458 s.initMetas[addr] = t 459 } 460 } 461 462 // updateMsg updates the Msg by InitMeta. 463 func (s *ServerStreamCodec) updateMsg(msg codec.Msg, initMeta *trpcpb.TrpcStreamInitMeta) { 464 // get request meta 465 req := initMeta.GetRequestMeta() 466 467 // set caller service name 468 msg.WithCallerServiceName(string(req.GetCaller())) 469 msg.WithCalleeServiceName(string(req.GetCallee())) 470 // set server handler method name 471 rpcName := string(req.GetFunc()) 472 msg.WithServerRPCName(rpcName) 473 msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) 474 // set body serialization type 475 msg.WithSerializationType(int(initMeta.GetContentType())) 476 // set body compression type 477 msg.WithCompressType(int(initMeta.GetContentEncoding())) 478 msg.WithDyeing((req.GetMessageType() & uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0) 479 480 if len(req.TransInfo) > 0 { 481 msg.WithServerMetaData(req.GetTransInfo()) 482 // set dyeing key 483 if bs, ok := req.TransInfo[DyeingKey]; ok { 484 msg.WithDyeingKey(string(bs)) 485 } 486 // set environment message for transfer 487 if envs, ok := req.TransInfo[EnvTransfer]; ok { 488 msg.WithEnvTransfer(string(envs)) 489 } 490 } 491 } 492 493 func (s *ServerStreamCodec) buildResetFrame(msg codec.Msg, frameHead *FrameHead, err error) { 494 frameHead.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 495 closeMeta := &trpcpb.TrpcStreamCloseMeta{ 496 CloseType: int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 497 Ret: int32(errs.Code(err)), 498 Msg: []byte(errs.Msg(err)), 499 } 500 msg.WithStreamFrame(closeMeta) 501 }