trpc.group/trpc-go/trpc-go@v1.0.3/codec_stream.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc 15 16 import ( 17 "errors" 18 "fmt" 19 "os" 20 "path" 21 "sync" 22 23 "trpc.group/trpc-go/trpc-go/codec" 24 "trpc.group/trpc-go/trpc-go/errs" 25 "trpc.group/trpc-go/trpc-go/internal/addrutil" 26 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 27 28 "google.golang.org/protobuf/proto" 29 ) 30 31 var ( 32 // error for unknown stream frame type 33 errUnknownFrameType error = errors.New("unknown stream frame type") 34 // error for invalid total length of client decoding 35 errClientDecodeTotalLength error = errors.New("client decode total length invalid") 36 // error for failing to encode Close frame 37 errEncodeCloseFrame error = errors.New("encode close frame error") 38 // error for failing to encode Feedback frame 39 errEncodeFeedbackFrame error = errors.New("encode feedback error") 40 // error for init metadata not found 41 errUninitializedMeta error = errors.New("uninitialized meta") 42 // error for invalid trpc framehead 43 errFrameHeadTypeInvalid error = errors.New("framehead type invalid") 44 ) 45 46 // NewServerStreamCodec initializes and returns a ServerStreamCodec. 47 func NewServerStreamCodec() *ServerStreamCodec { 48 return &ServerStreamCodec{initMetas: make(map[string]map[uint32]*trpcpb.TrpcStreamInitMeta), m: &sync.RWMutex{}} 49 } 50 51 // NewClientStreamCodec initializes and returns a ClientStreamCodec. 52 func NewClientStreamCodec() *ClientStreamCodec { 53 return &ClientStreamCodec{} 54 } 55 56 // ServerStreamCodec is an implementation of codec.Codec. 57 // Used for trpc server streaming codec. 58 type ServerStreamCodec struct { 59 m *sync.RWMutex 60 initMetas map[string]map[uint32]*trpcpb.TrpcStreamInitMeta // addr->streamID->TrpcStreamInitMeta 61 } 62 63 // ClientStreamCodec is an implementation of codec.Codec. 64 // Used for trpc client streaming codec. 65 type ClientStreamCodec struct { 66 } 67 68 // Encode implements codec.Codec. 69 func (c *ClientStreamCodec) Encode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 70 frameHead, ok := msg.FrameHead().(*FrameHead) 71 if !ok || !frameHead.isStream() { 72 return nil, errUnknownFrameType 73 } 74 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 75 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 76 return c.encodeInitFrame(frameHead, msg, reqBuf) 77 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 78 return c.encodeDataFrame(frameHead, msg, reqBuf) 79 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 80 return c.encodeCloseFrame(frameHead, msg, reqBuf) 81 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 82 return c.encodeFeedbackFrame(frameHead, msg, reqBuf) 83 default: 84 return nil, errUnknownFrameType 85 } 86 } 87 88 // Decode implements codec.Codec. 89 func (c *ClientStreamCodec) Decode(msg codec.Msg, rspBuf []byte) ([]byte, error) { 90 frameHead, ok := msg.FrameHead().(*FrameHead) 91 if !ok || !frameHead.isStream() { 92 return nil, errUnknownFrameType 93 } 94 95 msg.WithStreamID(frameHead.StreamID) 96 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 97 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 98 return c.decodeInitFrame(msg, rspBuf) 99 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 100 return c.decodeDataFrame(msg, rspBuf) 101 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 102 return c.decodeCloseFrame(msg, rspBuf) 103 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 104 return c.decodeFeedbackFrame(msg, rspBuf) 105 default: 106 return nil, errUnknownFrameType 107 } 108 } 109 110 // decodeCloseFrame decodes the Close frame. 111 func (c *ClientStreamCodec) decodeCloseFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 112 // unmarshal Close frame 113 close := &trpcpb.TrpcStreamCloseMeta{} 114 if err := proto.Unmarshal(rspBuf[frameHeadLen:], close); err != nil { 115 return nil, err 116 } 117 118 // It is considered an exception and an error should be returned to the client if: 119 // 1. the CloseType is Reset 120 // 2. ret code != 0 121 if close.GetCloseType() == int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET) || close.GetRet() != 0 { 122 e := &errs.Error{ 123 Type: errs.ErrorTypeCalleeFramework, 124 Code: trpcpb.TrpcRetCode(close.GetRet()), 125 Desc: "trpc", 126 Msg: string(close.GetMsg()), 127 } 128 msg.WithClientRspErr(e) 129 } 130 msg.WithStreamFrame(close) 131 return nil, nil 132 } 133 134 // decodeFeedbackFrame decodes the Feedback frame. 135 func (c *ClientStreamCodec) decodeFeedbackFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 136 feedback := &trpcpb.TrpcStreamFeedBackMeta{} 137 if err := proto.Unmarshal(rspBuf[frameHeadLen:], feedback); err != nil { 138 return nil, err 139 } 140 msg.WithStreamFrame(feedback) 141 return nil, nil 142 } 143 144 // decodeInitFrame decodes the Init frame. 145 func (c *ClientStreamCodec) decodeInitFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 146 initMeta := &trpcpb.TrpcStreamInitMeta{} 147 if err := proto.Unmarshal(rspBuf[frameHeadLen:], initMeta); err != nil { 148 return nil, err 149 } 150 151 msg.WithCompressType(int(initMeta.GetContentEncoding())) 152 msg.WithSerializationType(int(initMeta.GetContentType())) 153 154 // if ret code is not 0, an error should be set and returned 155 if initMeta.GetResponseMeta().GetRet() != 0 { 156 e := &errs.Error{ 157 Type: errs.ErrorTypeCalleeFramework, 158 Code: trpcpb.TrpcRetCode(initMeta.GetResponseMeta().GetRet()), 159 Desc: "trpc", 160 Msg: string(initMeta.GetResponseMeta().GetErrorMsg()), 161 } 162 msg.WithClientRspErr(e) 163 } 164 msg.WithStreamFrame(initMeta) 165 return nil, nil 166 167 } 168 169 // decodeDataFrame decodes the Data frame. 170 func (c *ClientStreamCodec) decodeDataFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 171 // decoding Data frame is straightforward, 172 // as it just returns all data following the frame head 173 return rspBuf[frameHeadLen:], nil 174 } 175 176 // encodeInitFrame encodes the Init frame. 177 func (c *ClientStreamCodec) encodeInitFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 178 initMeta, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 179 if !ok { 180 initMeta = &trpcpb.TrpcStreamInitMeta{} 181 initMeta.RequestMeta = &trpcpb.TrpcStreamInitRequestMeta{} 182 } 183 req := initMeta.RequestMeta 184 // set caller service name 185 // if nil, use the name of the process 186 if msg.CallerServiceName() == "" { 187 msg.WithCallerServiceName(fmt.Sprintf("trpc.app.%s.service", path.Base(os.Args[0]))) 188 } 189 req.Caller = []byte(msg.CallerServiceName()) 190 // set callee service name 191 req.Callee = []byte(msg.CalleeServiceName()) 192 // set backend rpc name, ClientRPCName already set by upper layer of client stub 193 req.Func = []byte(msg.ClientRPCName()) 194 // set backend serialization type 195 initMeta.ContentType = uint32(msg.SerializationType()) 196 // set backend compression type 197 initMeta.ContentEncoding = uint32(msg.CompressType()) 198 // set dyeing info 199 if msg.Dyeing() { 200 req.MessageType = req.MessageType | uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE) 201 } 202 // set client transinfo 203 req.TransInfo = setClientTransInfo(msg, req.TransInfo) 204 streamBuf, err := proto.Marshal(initMeta) 205 if err != nil { 206 return nil, err 207 } 208 return frameWrite(frameHead, streamBuf) 209 } 210 211 // encodeDataFrame encodes the Data frame. 212 func (c *ClientStreamCodec) encodeDataFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 213 return frameWrite(frameHead, reqBuf) 214 } 215 216 // encodeCloseFrame encodes the Close frame. 217 func (c *ClientStreamCodec) encodeCloseFrame(frameHead *FrameHead, msg codec.Msg, 218 reqBuf []byte) (rspbuf []byte, err error) { 219 closeFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamCloseMeta) 220 if !ok { 221 return nil, errEncodeCloseFrame 222 } 223 streamBuf, err := proto.Marshal(closeFrame) 224 if err != nil { 225 return nil, err 226 } 227 return frameWrite(frameHead, streamBuf) 228 } 229 230 // encodeFeedbackFrame encodes the Feedback frame. 231 func (c *ClientStreamCodec) encodeFeedbackFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 232 feedbackFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 233 if !ok { 234 return nil, errEncodeFeedbackFrame 235 } 236 streamBuf, err := proto.Marshal(feedbackFrame) 237 if err != nil { 238 return nil, err 239 } 240 return frameWrite(frameHead, streamBuf) 241 } 242 243 // frameWrite converts FrameHead to binary frame. 244 func frameWrite(frameHead *FrameHead, streamBuf []byte) ([]byte, error) { 245 // no pb header for streaming rpc 246 return frameHead.construct(nil, streamBuf, nil) 247 } 248 249 // encodeCloseFrame encodes the Close frame. 250 func (s *ServerStreamCodec) encodeCloseFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 251 defer s.deleteInitMeta(msg) 252 closeFrame, ok := msg.StreamFrame().(*trpcpb.TrpcStreamCloseMeta) 253 if !ok { 254 return nil, errEncodeCloseFrame 255 } 256 msg.WithStreamID(frameHead.StreamID) 257 streamBuf, err := proto.Marshal(closeFrame) 258 if err != nil { 259 return nil, err 260 } 261 return frameWrite(frameHead, streamBuf) 262 } 263 264 // encodeDataFrame encodes the Data frame. 265 func (s *ServerStreamCodec) encodeDataFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 266 // If there is an error when processing the Data frame, 267 // then return the Close frame and close the current stream. 268 if err := msg.ServerRspErr(); err != nil { 269 s.buildResetFrame(msg, frameHead, err) 270 return s.encodeCloseFrame(frameHead, msg, reqBuf) 271 } 272 return frameWrite(frameHead, reqBuf) 273 } 274 275 // encodeInitFrame encodes the Init frame. 276 func (s *ServerStreamCodec) encodeInitFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 277 rsp := getStreamInitMeta(msg) 278 rsp.ContentType = uint32(msg.SerializationType()) 279 rsp.ContentEncoding = uint32(msg.CompressType()) 280 rspMeta := &trpcpb.TrpcStreamInitResponseMeta{} 281 if e := msg.ServerRspErr(); e != nil { 282 rspMeta.Ret = int32(e.Code) 283 rspMeta.ErrorMsg = []byte(e.Msg) 284 } 285 rsp.ResponseMeta = rspMeta 286 streamBuf, err := proto.Marshal(rsp) 287 if err != nil { 288 return nil, err 289 } 290 return frameWrite(frameHead, streamBuf) 291 } 292 293 // encodeFeedbackFrame encodes the Feedback frame. 294 func (s *ServerStreamCodec) encodeFeedbackFrame(frameHead *FrameHead, msg codec.Msg, reqBuf []byte) ([]byte, error) { 295 feedback, ok := msg.StreamFrame().(*trpcpb.TrpcStreamFeedBackMeta) 296 if !ok { 297 return nil, errEncodeFeedbackFrame 298 } 299 streamBuf, err := proto.Marshal(feedback) 300 if err != nil { 301 return nil, err 302 } 303 return frameWrite(frameHead, streamBuf) 304 } 305 306 // getStreamInitMeta returns TrpcStreamInitMeta from msg. 307 // If not found, a new TrpcStreamInitMeta will be created and returned. 308 func getStreamInitMeta(msg codec.Msg) *trpcpb.TrpcStreamInitMeta { 309 rsp, ok := msg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) 310 if !ok { 311 rsp = &trpcpb.TrpcStreamInitMeta{ResponseMeta: &trpcpb.TrpcStreamInitResponseMeta{}} 312 } 313 return rsp 314 } 315 316 // Encode implements codec.Codec. 317 func (s *ServerStreamCodec) Encode(msg codec.Msg, reqBuf []byte) (rspbuf []byte, err error) { 318 frameHead, ok := msg.FrameHead().(*FrameHead) 319 if !ok || !frameHead.isStream() { 320 return nil, errUnknownFrameType 321 } 322 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 323 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 324 return s.encodeInitFrame(frameHead, msg, reqBuf) 325 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 326 return s.encodeDataFrame(frameHead, msg, reqBuf) 327 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 328 return s.encodeCloseFrame(frameHead, msg, reqBuf) 329 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 330 return s.encodeFeedbackFrame(frameHead, msg, reqBuf) 331 default: 332 return nil, errUnknownFrameType 333 } 334 } 335 336 // Decode implements codec.Codec. 337 // It decodes the head and the stream frame data. 338 func (s *ServerStreamCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 339 frameHead, ok := msg.FrameHead().(*FrameHead) 340 if !ok || !frameHead.isStream() { 341 return nil, errUnknownFrameType 342 } 343 msg.WithStreamID(frameHead.StreamID) 344 switch trpcpb.TrpcStreamFrameType(frameHead.StreamFrameType) { 345 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT: 346 return s.decodeInitFrame(msg, reqBuf) 347 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA: 348 return s.decodeDataFrame(msg, reqBuf) 349 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE: 350 return s.decodeCloseFrame(msg, reqBuf) 351 case trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK: 352 return s.decodeFeedbackFrame(msg, reqBuf) 353 default: 354 return nil, errUnknownFrameType 355 } 356 } 357 358 // decodeFeedbackFrame decodes the Feedback frame. 359 func (s *ServerStreamCodec) decodeFeedbackFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 360 if err := s.setInitMeta(msg); err != nil { 361 return nil, err 362 } 363 feedback := &trpcpb.TrpcStreamFeedBackMeta{} 364 if err := proto.Unmarshal(reqBuf[frameHeadLen:], feedback); err != nil { 365 return nil, err 366 } 367 msg.WithStreamFrame(feedback) 368 return nil, nil 369 } 370 371 // setInitMeta finds the InitMeta and sets the ServerRPCName by the server handler in the InitMeta. 372 func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error { 373 streamID := msg.StreamID() 374 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 375 s.m.RLock() 376 defer s.m.RUnlock() 377 if streamIDToInitMeta, ok := s.initMetas[addr]; ok { 378 if initMeta, ok := streamIDToInitMeta[streamID]; ok { 379 msg.WithServerRPCName(string(initMeta.GetRequestMeta().GetFunc())) 380 return nil 381 } 382 } 383 return errUninitializedMeta 384 } 385 386 // deleteInitMeta deletes the cached info by msg. 387 func (s *ServerStreamCodec) deleteInitMeta(msg codec.Msg) { 388 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 389 streamID := msg.StreamID() 390 s.m.Lock() 391 defer s.m.Unlock() 392 delete(s.initMetas[addr], streamID) 393 if len(s.initMetas[addr]) == 0 { 394 delete(s.initMetas, addr) 395 } 396 } 397 398 // decodeCloseFrame decodes the Close frame. 399 func (s *ServerStreamCodec) decodeCloseFrame(msg codec.Msg, rspBuf []byte) ([]byte, error) { 400 if err := s.setInitMeta(msg); err != nil { 401 return nil, err 402 } 403 close := &trpcpb.TrpcStreamCloseMeta{} 404 if err := proto.Unmarshal(rspBuf[frameHeadLen:], close); err != nil { 405 return nil, err 406 } 407 // It is considered an exception and an error should be returned to the client if: 408 // 1. the CloseType is Reset 409 // 2. ret code != 0 410 if close.GetCloseType() == int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET) || close.GetRet() != 0 { 411 e := &errs.Error{ 412 Type: errs.ErrorTypeCalleeFramework, 413 Code: trpcpb.TrpcRetCode(close.GetRet()), 414 Desc: "trpc", 415 Msg: string(close.GetMsg()), 416 } 417 msg.WithServerRspErr(e) 418 } 419 msg.WithStreamFrame(close) 420 return nil, nil 421 } 422 423 // decodeDataFrame decodes the Data frame. 424 func (s *ServerStreamCodec) decodeDataFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 425 if err := s.setInitMeta(msg); err != nil { 426 return nil, err 427 } 428 reqBody := reqBuf[frameHeadLen:] 429 return reqBody, nil 430 } 431 432 // decodeInitFrame decodes the Init frame. 433 func (s *ServerStreamCodec) decodeInitFrame(msg codec.Msg, reqBuf []byte) ([]byte, error) { 434 initMeta := &trpcpb.TrpcStreamInitMeta{} 435 if err := proto.Unmarshal(reqBuf[frameHeadLen:], initMeta); err != nil { 436 return nil, err 437 } 438 s.updateMsg(msg, initMeta) 439 s.storeInitMeta(msg, initMeta) 440 msg.WithStreamFrame(initMeta) 441 return nil, nil 442 } 443 444 // storeInitMeta stores the InitMeta every time when a new frame is received. 445 func (s *ServerStreamCodec) storeInitMeta(msg codec.Msg, initMeta *trpcpb.TrpcStreamInitMeta) { 446 streamID := msg.StreamID() 447 addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) 448 s.m.Lock() 449 defer s.m.Unlock() 450 if _, ok := s.initMetas[addr]; ok { 451 s.initMetas[addr][streamID] = initMeta 452 } else { 453 t := make(map[uint32]*trpcpb.TrpcStreamInitMeta) 454 t[streamID] = initMeta 455 s.initMetas[addr] = t 456 } 457 } 458 459 // updateMsg updates the Msg by InitMeta. 460 func (s *ServerStreamCodec) updateMsg(msg codec.Msg, initMeta *trpcpb.TrpcStreamInitMeta) { 461 // get request meta 462 req := initMeta.GetRequestMeta() 463 464 // set caller service name 465 msg.WithCallerServiceName(string(req.GetCaller())) 466 msg.WithCalleeServiceName(string(req.GetCallee())) 467 // set server handler method name 468 msg.WithServerRPCName(string(req.GetFunc())) 469 // set body serialization type 470 msg.WithSerializationType(int(initMeta.GetContentType())) 471 // set body compression type 472 msg.WithCompressType(int(initMeta.GetContentEncoding())) 473 msg.WithDyeing((req.GetMessageType() & uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0) 474 475 if len(req.TransInfo) > 0 { 476 msg.WithServerMetaData(req.GetTransInfo()) 477 // set dyeing key 478 if bs, ok := req.TransInfo[DyeingKey]; ok { 479 msg.WithDyeingKey(string(bs)) 480 } 481 // set environment message for transfer 482 if envs, ok := req.TransInfo[EnvTransfer]; ok { 483 msg.WithEnvTransfer(string(envs)) 484 } 485 } 486 } 487 488 func (s *ServerStreamCodec) buildResetFrame(msg codec.Msg, frameHead *FrameHead, err error) { 489 frameHead.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 490 closeMeta := &trpcpb.TrpcStreamCloseMeta{ 491 CloseType: int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_RESET), 492 Ret: int32(errs.Code(err)), 493 Msg: []byte(errs.Msg(err)), 494 } 495 msg.WithStreamFrame(closeMeta) 496 }