trpc.group/trpc-go/trpc-go@v1.0.2/codec.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc 15 16 import ( 17 "encoding/binary" 18 "errors" 19 "fmt" 20 "io" 21 "math" 22 "os" 23 "path" 24 "sync/atomic" 25 "time" 26 27 "trpc.group/trpc-go/trpc-go/codec" 28 "trpc.group/trpc-go/trpc-go/errs" 29 "trpc.group/trpc-go/trpc-go/internal/attachment" 30 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 31 "trpc.group/trpc-go/trpc-go/transport" 32 33 "google.golang.org/protobuf/proto" 34 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 35 ) 36 37 func init() { 38 codec.Register(ProtocolName, DefaultServerCodec, DefaultClientCodec) 39 transport.RegisterFramerBuilder(ProtocolName, DefaultFramerBuilder) 40 } 41 42 // default codec 43 var ( 44 DefaultServerCodec = &ServerCodec{ 45 streamCodec: NewServerStreamCodec(), 46 } 47 DefaultClientCodec = &ClientCodec{ 48 streamCodec: NewClientStreamCodec(), 49 defaultCaller: fmt.Sprintf("trpc.client.%s.service", path.Base(os.Args[0])), 50 } 51 DefaultFramerBuilder = &FramerBuilder{} 52 53 // DefaultMaxFrameSize is the default max size of frame including attachment, 54 // which can be modified if size of the packet is bigger than this. 55 DefaultMaxFrameSize = 10 * 1024 * 1024 56 ) 57 58 var ( 59 errHeadOverflowsUint16 = errors.New("head len overflows uint16") 60 errHeadOverflowsUint32 = errors.New("total len overflows uint32") 61 errAttachmentOverflowsUint32 = errors.New("attachment len overflows uint32") 62 ) 63 64 type errFrameTooLarge struct { 65 maxFrameSize int 66 } 67 68 // Error implements the error interface and returns the description of the errFrameTooLarge. 69 func (e *errFrameTooLarge) Error() string { 70 return fmt.Sprintf("frame len is larger than MaxFrameSize(%d)", e.maxFrameSize) 71 } 72 73 // frequently used const variables 74 const ( 75 DyeingKey = "trpc-dyeing-key" // dyeing key 76 UserIP = "trpc-user-ip" // user ip 77 EnvTransfer = "trpc-env" // env info 78 79 ProtocolName = "trpc" // protocol name 80 ) 81 82 // trpc protocol codec 83 const ( 84 // frame head format: 85 // v0: 86 // 2 bytes magic + 1 byte frame type + 1 byte stream frame type + 4 bytes total len 87 // + 2 bytes pb header len + 4 bytes stream id + 2 bytes reserved 88 // v1: 89 // 2 bytes magic + 1 byte frame type + 1 byte stream frame type + 4 bytes total len 90 // + 2 bytes pb header len + 4 bytes stream id + 1 byte protocol version + 1 byte reserved 91 frameHeadLen = uint16(16) // total length of frame head: 16 bytes 92 protocolVersion0 = uint8(0) // v0 93 protocolVersion1 = uint8(1) // v1 94 curProtocolVersion = protocolVersion1 // current protocol version 95 ) 96 97 // FrameHead is head of the trpc frame. 98 type FrameHead struct { 99 FrameType uint8 // type of the frame 100 StreamFrameType uint8 // type of the stream frame 101 TotalLen uint32 // total length 102 HeaderLen uint16 // header's length 103 StreamID uint32 // stream id for streaming rpc, request id for unary rpc 104 ProtocolVersion uint8 // version of protocol 105 FrameReserved uint8 // reserved bits for further development 106 } 107 108 func newDefaultUnaryFrameHead() *FrameHead { 109 return &FrameHead{ 110 FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_UNARY_FRAME), // default unary 111 ProtocolVersion: curProtocolVersion, 112 } 113 } 114 115 // extract extracts field values of the FrameHead from the buffer. 116 func (h *FrameHead) extract(buf []byte) { 117 h.FrameType = buf[2] 118 h.StreamFrameType = buf[3] 119 h.TotalLen = binary.BigEndian.Uint32(buf[4:8]) 120 h.HeaderLen = binary.BigEndian.Uint16(buf[8:10]) 121 h.StreamID = binary.BigEndian.Uint32(buf[10:14]) 122 h.ProtocolVersion = buf[14] 123 h.FrameReserved = buf[15] 124 } 125 126 // construct constructs bytes data for the whole frame. 127 func (h *FrameHead) construct(header, body, attachment []byte) ([]byte, error) { 128 headerLen := len(header) 129 if headerLen > math.MaxUint16 { 130 return nil, errHeadOverflowsUint16 131 } 132 attachmentLen := int64(len(attachment)) 133 if attachmentLen > math.MaxUint32 { 134 return nil, errAttachmentOverflowsUint32 135 } 136 totalLen := int64(frameHeadLen) + int64(headerLen) + int64(len(body)) + attachmentLen 137 if totalLen > int64(DefaultMaxFrameSize) { 138 return nil, &errFrameTooLarge{maxFrameSize: DefaultMaxFrameSize} 139 } 140 if totalLen > math.MaxUint32 { 141 return nil, errHeadOverflowsUint32 142 } 143 144 // construct the buffer 145 buf := make([]byte, totalLen) 146 binary.BigEndian.PutUint16(buf[:2], uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE)) 147 buf[2] = h.FrameType 148 buf[3] = h.StreamFrameType 149 binary.BigEndian.PutUint32(buf[4:8], uint32(totalLen)) 150 binary.BigEndian.PutUint16(buf[8:10], uint16(headerLen)) 151 binary.BigEndian.PutUint32(buf[10:14], h.StreamID) 152 buf[14] = h.ProtocolVersion 153 buf[15] = h.FrameReserved 154 155 frameHeadLen := int(frameHeadLen) 156 copy(buf[frameHeadLen:frameHeadLen+headerLen], header) 157 copy(buf[frameHeadLen+headerLen:frameHeadLen+headerLen+len(body)], body) 158 copy(buf[frameHeadLen+headerLen+len(body):], attachment) 159 return buf, nil 160 } 161 162 func (h *FrameHead) isStream() bool { 163 return trpcpb.TrpcDataFrameType(h.FrameType) == trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME 164 } 165 166 func (h *FrameHead) isUnary() bool { 167 return trpcpb.TrpcDataFrameType(h.FrameType) == trpcpb.TrpcDataFrameType_TRPC_UNARY_FRAME 168 } 169 170 // upgradeProtocol upgrades protocol and sets stream id and request id. 171 // For compatibility, server should respond the same protocol version as that of the request. 172 // and client should always send request with the latest protocol version. 173 func (h *FrameHead) upgradeProtocol(protocolVersion uint8, requestID uint32) { 174 h.ProtocolVersion = protocolVersion 175 h.StreamID = requestID 176 } 177 178 // FramerBuilder is an implementation of codec.FramerBuilder. 179 // Used for trpc protocol. 180 type FramerBuilder struct{} 181 182 // New implements codec.FramerBuilder. 183 func (fb *FramerBuilder) New(reader io.Reader) codec.Framer { 184 return &framer{ 185 reader: reader, 186 } 187 } 188 189 // Parse implement multiplexed.FrameParser interface. 190 func (fb *FramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 191 buf, err = fb.New(rc).ReadFrame() 192 if err != nil { 193 return 0, nil, err 194 } 195 return binary.BigEndian.Uint32(buf[10:14]), buf, nil 196 } 197 198 // framer is an implementation of codec.Framer. 199 // Used for trpc protocol. 200 type framer struct { 201 reader io.Reader 202 header [frameHeadLen]byte 203 } 204 205 // ReadFrame implements codec.Framer. 206 func (f *framer) ReadFrame() ([]byte, error) { 207 num, err := io.ReadFull(f.reader, f.header[:]) 208 if err != nil { 209 return nil, err 210 } 211 if num != int(frameHeadLen) { 212 return nil, fmt.Errorf("trpc framer: read frame header num %d != %d, invalid", num, int(frameHeadLen)) 213 } 214 magic := binary.BigEndian.Uint16(f.header[:2]) 215 if magic != uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE) { 216 return nil, fmt.Errorf( 217 "trpc framer: read framer head magic %d != %d, not match", magic, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE)) 218 } 219 totalLen := binary.BigEndian.Uint32(f.header[4:8]) 220 if totalLen < uint32(frameHeadLen) { 221 return nil, fmt.Errorf( 222 "trpc framer: read frame header total len %d < %d, invalid", totalLen, uint32(frameHeadLen)) 223 } 224 225 if totalLen > uint32(DefaultMaxFrameSize) { 226 return nil, fmt.Errorf( 227 "trpc framer: read frame header total len %d > %d, too large", totalLen, uint32(DefaultMaxFrameSize)) 228 } 229 230 msg := make([]byte, totalLen) 231 num, err = io.ReadFull(f.reader, msg[frameHeadLen:totalLen]) 232 if err != nil { 233 return nil, err 234 } 235 if num != int(totalLen-uint32(frameHeadLen)) { 236 return nil, fmt.Errorf( 237 "trpc framer: read frame total num %d != %d, invalid", num, int(totalLen-uint32(frameHeadLen))) 238 } 239 copy(msg, f.header[:]) 240 return msg, nil 241 } 242 243 // IsSafe implements codec.SafeFramer. 244 // Used for compatibility. 245 func (f *framer) IsSafe() bool { 246 return true 247 } 248 249 // ServerCodec is an implementation of codec.Codec. 250 // Used for trpc serverside codec. 251 type ServerCodec struct { 252 streamCodec *ServerStreamCodec 253 } 254 255 // Decode implements codec.Codec. 256 // It decodes the reqBuf and updates the msg that already initialized by service handler. 257 func (s *ServerCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 258 if len(reqBuf) < int(frameHeadLen) { 259 return nil, errors.New("server decode req buf len invalid") 260 } 261 frameHead := newDefaultUnaryFrameHead() 262 frameHead.extract(reqBuf) 263 msg.WithFrameHead(frameHead) 264 if frameHead.TotalLen != uint32(len(reqBuf)) { 265 return nil, fmt.Errorf("total len %d is not actual buf len %d", frameHead.TotalLen, len(reqBuf)) 266 } 267 if frameHead.FrameType != uint8(trpcpb.TrpcDataFrameType_TRPC_UNARY_FRAME) { // streaming rpc has its own decoding 268 rspBody, err := s.streamCodec.Decode(msg, reqBuf) 269 if err != nil { 270 // if decoding fails, the Close frame with Reset type will be returned to the client 271 err := errs.NewFrameError(errs.RetServerDecodeFail, err.Error()) 272 s.streamCodec.buildResetFrame(msg, frameHead, err) 273 return nil, err 274 } 275 return rspBody, nil 276 } 277 if frameHead.HeaderLen == 0 { // header not allowed to be empty for unary rpc 278 return nil, errors.New("server decode pb head len empty") 279 } 280 281 requestProtocolBegin := uint32(frameHeadLen) 282 requestProtocolEnd := requestProtocolBegin + uint32(frameHead.HeaderLen) 283 if requestProtocolEnd > uint32(len(reqBuf)) { 284 return nil, errors.New("server decode pb head len invalid") 285 } 286 req := &trpcpb.RequestProtocol{} 287 if err := proto.Unmarshal(reqBuf[requestProtocolBegin:requestProtocolEnd], req); err != nil { 288 return nil, err 289 } 290 291 attachmentBegin := frameHead.TotalLen - req.AttachmentSize 292 if s := uint32(len(reqBuf)) - attachmentBegin; s != req.AttachmentSize { 293 return nil, fmt.Errorf("decoding attachment: len of attachment(%d) "+ 294 "isn't equal to expected AttachmentSize(%d) ", s, req.AttachmentSize) 295 } 296 297 msgWithRequestProtocol(msg, req, reqBuf[attachmentBegin:]) 298 299 requestBodyBegin, requestBodyEnd := requestProtocolEnd, attachmentBegin 300 return reqBuf[requestBodyBegin:requestBodyEnd], nil 301 } 302 303 func msgWithRequestProtocol(msg codec.Msg, req *trpcpb.RequestProtocol, attm []byte) { 304 // set server request head 305 msg.WithServerReqHead(req) 306 // construct response protocol head in advance 307 rsp := newResponseProtocol(req) 308 msg.WithServerRspHead(rsp) 309 310 // ---------the following code is to set the essential info-----------// 311 // set upstream timeout 312 msg.WithRequestTimeout(time.Millisecond * time.Duration(req.GetTimeout())) 313 // set upstream service name 314 msg.WithCallerServiceName(string(req.GetCaller())) 315 msg.WithCalleeServiceName(string(req.GetCallee())) 316 // set server handler method name 317 rpcName := string(req.GetFunc()) 318 msg.WithServerRPCName(rpcName) 319 msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) 320 // set body serialization type 321 msg.WithSerializationType(int(req.GetContentType())) 322 // set body compression type 323 msg.WithCompressType(int(req.GetContentEncoding())) 324 // set dyeing mark 325 msg.WithDyeing((req.GetMessageType() & uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0) 326 // parse tracing MetaData, set MetaData into msg 327 if len(req.TransInfo) > 0 { 328 msg.WithServerMetaData(req.GetTransInfo()) 329 // mark with dyeing key 330 if bs, ok := req.TransInfo[DyeingKey]; ok { 331 msg.WithDyeingKey(string(bs)) 332 } 333 // transmit env info 334 if envs, ok := req.TransInfo[EnvTransfer]; ok { 335 msg.WithEnvTransfer(string(envs)) 336 } 337 } 338 // set call type 339 msg.WithCallType(codec.RequestType(req.GetCallType())) 340 if len(attm) != 0 { 341 attachment.SetServerRequestAttachment(msg, attm) 342 } 343 } 344 345 // Encode implements codec.Codec. 346 // It encodes the rspBody to binary data and returns it to client. 347 func (s *ServerCodec) Encode(msg codec.Msg, rspBody []byte) ([]byte, error) { 348 frameHead := loadOrStoreDefaultUnaryFrameHead(msg) 349 if frameHead.isStream() { 350 return s.streamCodec.Encode(msg, rspBody) 351 } 352 if !frameHead.isUnary() { 353 return nil, errUnknownFrameType 354 } 355 356 rspProtocol := getAndInitResponseProtocol(msg) 357 358 var attm []byte 359 if a, ok := attachment.ServerResponseAttachment(msg); ok { 360 var err error 361 if attm, err = io.ReadAll(a); err != nil { 362 return nil, fmt.Errorf("encoding attachment: %w", err) 363 } 364 } 365 rspProtocol.AttachmentSize = uint32(len(attm)) 366 367 rspHead, err := proto.Marshal(rspProtocol) 368 if err != nil { 369 return nil, err 370 } 371 372 rspBuf, err := frameHead.construct(rspHead, rspBody, attm) 373 if errors.Is(err, errHeadOverflowsUint16) { 374 return handleEncodeErr(rspProtocol, frameHead, rspBody, err) 375 } 376 var frameTooLargeErr *errFrameTooLarge 377 if errors.As(err, &frameTooLargeErr) || errors.Is(err, errHeadOverflowsUint32) { 378 // If frame len is larger than DefaultMaxFrameSize or overflows uint32, set rspBody nil. 379 return handleEncodeErr(rspProtocol, frameHead, nil, err) 380 } 381 return rspBuf, err 382 } 383 384 // getAndInitResponseProtocol returns rsp head from msg and initialize the rsp with msg. 385 // If rsp head is not found from msg, a new rsp head will be created and initialized. 386 func getAndInitResponseProtocol(msg codec.Msg) *trpcpb.ResponseProtocol { 387 rsp, ok := msg.ServerRspHead().(*trpcpb.ResponseProtocol) 388 if !ok { 389 if req, ok := msg.ServerReqHead().(*trpcpb.RequestProtocol); ok { 390 rsp = newResponseProtocol(req) 391 } else { 392 rsp = &trpcpb.ResponseProtocol{} 393 } 394 } 395 396 // update serialization and compression type 397 rsp.ContentType = uint32(msg.SerializationType()) 398 rsp.ContentEncoding = uint32(msg.CompressType()) 399 400 // convert error returned by server handler to ret code in response protocol head 401 if err := msg.ServerRspErr(); err != nil { 402 rsp.ErrorMsg = []byte(err.Msg) 403 if err.Type == errs.ErrorTypeFramework { 404 rsp.Ret = int32(err.Code) 405 } else { 406 rsp.FuncRet = int32(err.Code) 407 } 408 } 409 410 if len(msg.ServerMetaData()) > 0 { 411 if rsp.TransInfo == nil { 412 rsp.TransInfo = make(map[string][]byte) 413 } 414 for k, v := range msg.ServerMetaData() { 415 rsp.TransInfo[k] = v 416 } 417 } 418 419 return rsp 420 } 421 422 func newResponseProtocol(req *trpcpb.RequestProtocol) *trpcpb.ResponseProtocol { 423 return &trpcpb.ResponseProtocol{ 424 Version: uint32(trpcpb.TrpcProtoVersion_TRPC_PROTO_V1), 425 CallType: req.CallType, 426 RequestId: req.RequestId, 427 MessageType: req.MessageType, 428 ContentType: req.ContentType, 429 ContentEncoding: req.ContentEncoding, 430 } 431 } 432 433 // handleEncodeErr handles encode err and returns RetServerEncodeFail. 434 func handleEncodeErr( 435 rsp *trpcpb.ResponseProtocol, 436 frameHead *FrameHead, 437 rspBody []byte, 438 encodeErr error, 439 ) ([]byte, error) { 440 // discard all TransInfo and return RetServerEncodeFail 441 // cover the original no matter what 442 rsp.TransInfo = nil 443 rsp.Ret = int32(errs.RetServerEncodeFail) 444 rsp.ErrorMsg = []byte(encodeErr.Error()) 445 rspHead, err := proto.Marshal(rsp) 446 if err != nil { 447 return nil, err 448 } 449 // if error still occurs, response will be discarded. 450 // client will be notified as conn closed 451 return frameHead.construct(rspHead, rspBody, nil) 452 } 453 454 // ClientCodec is an implementation of codec.Codec. 455 // Used for trpc clientside codec. 456 type ClientCodec struct { 457 streamCodec *ClientStreamCodec 458 defaultCaller string // trpc.app.server.service 459 requestID uint32 // global unique request id 460 } 461 462 // Encode implements codec.Codec. 463 // It encodes reqBody into binary data. New msg will be cloned by client stub. 464 func (c *ClientCodec) Encode(msg codec.Msg, reqBody []byte) (reqBuf []byte, err error) { 465 frameHead := loadOrStoreDefaultUnaryFrameHead(msg) 466 if frameHead.isStream() { 467 return c.streamCodec.Encode(msg, reqBody) 468 } 469 if !frameHead.isUnary() { 470 return nil, errUnknownFrameType 471 } 472 473 // create a new framehead without modifying the original one 474 // to avoid overwriting the requestID of the original framehead. 475 frameHead = newDefaultUnaryFrameHead() 476 req, err := loadOrStoreDefaultRequestProtocol(msg) 477 if err != nil { 478 return nil, err 479 } 480 481 // request id atomically increases by 1, ensuring that each request id is unique. 482 requestID := atomic.AddUint32(&c.requestID, 1) 483 frameHead.upgradeProtocol(curProtocolVersion, requestID) 484 msg.WithRequestID(requestID) 485 486 var attm []byte 487 if a, ok := attachment.ClientRequestAttachment(msg); ok { 488 if attm, err = io.ReadAll(a); err != nil { 489 return nil, fmt.Errorf("encoding attachment: %w", err) 490 } 491 } 492 req.AttachmentSize = uint32(len(attm)) 493 494 updateRequestProtocol(req, updateCallerServiceName(msg, c.defaultCaller)) 495 496 reqHead, err := proto.Marshal(req) 497 if err != nil { 498 return nil, err 499 } 500 return frameHead.construct(reqHead, reqBody, attm) 501 } 502 503 // loadOrStoreDefaultRequestProtocol loads the existing RequestProtocol from msg if present. 504 // Otherwise, it stores default UnaryRequestProtocol created to msg and returns the default RequestProtocol. 505 func loadOrStoreDefaultRequestProtocol(msg codec.Msg) (*trpcpb.RequestProtocol, error) { 506 if req := msg.ClientReqHead(); req != nil { 507 // client req head not being nil means it's created on purpose and set to 508 // record request protocol head 509 req, ok := req.(*trpcpb.RequestProtocol) 510 if !ok { 511 return nil, errors.New("client encode req head type invalid, must be trpc request protocol head") 512 } 513 return req, nil 514 } 515 516 req := newDefaultUnaryRequestProtocol() 517 msg.WithClientReqHead(req) 518 return req, nil 519 } 520 521 func newDefaultUnaryRequestProtocol() *trpcpb.RequestProtocol { 522 return &trpcpb.RequestProtocol{ 523 Version: uint32(trpcpb.TrpcProtoVersion_TRPC_PROTO_V1), 524 CallType: uint32(trpcpb.TrpcCallType_TRPC_UNARY_CALL), 525 } 526 } 527 528 // update updates CallerServiceName of msg with name 529 func updateCallerServiceName(msg codec.Msg, name string) codec.Msg { 530 if msg.CallerServiceName() == "" { 531 msg.WithCallerServiceName(name) 532 } 533 return msg 534 } 535 536 // update updates req with requestID and msg. 537 func updateRequestProtocol(req *trpcpb.RequestProtocol, msg codec.Msg) { 538 req.RequestId = msg.RequestID() 539 req.Caller = []byte(msg.CallerServiceName()) 540 // set callee service name 541 req.Callee = []byte(msg.CalleeServiceName()) 542 // set backend rpc name 543 req.Func = []byte(msg.ClientRPCName()) 544 // set backend serialization type 545 req.ContentType = uint32(msg.SerializationType()) 546 // set backend compression type 547 req.ContentEncoding = uint32(msg.CompressType()) 548 // set rest timeout for downstream 549 req.Timeout = uint32(msg.RequestTimeout() / time.Millisecond) 550 // set dyeing info 551 if msg.Dyeing() { 552 req.MessageType = req.MessageType | uint32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE) 553 } 554 // set client transinfo 555 req.TransInfo = setClientTransInfo(msg, req.TransInfo) 556 // set call type 557 req.CallType = uint32(msg.CallType()) 558 } 559 560 // setClientTransInfo sets client TransInfo. 561 func setClientTransInfo(msg codec.Msg, trans map[string][]byte) map[string][]byte { 562 // set MetaData 563 if len(msg.ClientMetaData()) > 0 { 564 if trans == nil { 565 trans = make(map[string][]byte) 566 } 567 for k, v := range msg.ClientMetaData() { 568 trans[k] = v 569 } 570 } 571 if len(msg.DyeingKey()) > 0 { 572 if trans == nil { 573 trans = make(map[string][]byte) 574 } 575 trans[DyeingKey] = []byte(msg.DyeingKey()) 576 } 577 if len(msg.EnvTransfer()) > 0 { 578 if trans == nil { 579 trans = make(map[string][]byte) 580 } 581 trans[EnvTransfer] = []byte(msg.EnvTransfer()) 582 } else { 583 // if msg.EnvTransfer() empty, transmitted env info in req.TransInfo should be cleared 584 if _, ok := trans[EnvTransfer]; ok { 585 trans[EnvTransfer] = nil 586 } 587 } 588 return trans 589 } 590 591 // Decode implements codec.Codec. 592 // It decodes rspBuf into rspBody. 593 func (c *ClientCodec) Decode(msg codec.Msg, rspBuf []byte) (rspBody []byte, err error) { 594 if len(rspBuf) < int(frameHeadLen) { 595 return nil, errors.New("client decode rsp buf len invalid") 596 } 597 frameHead := newDefaultUnaryFrameHead() 598 frameHead.extract(rspBuf) 599 msg.WithFrameHead(frameHead) 600 if frameHead.TotalLen != uint32(len(rspBuf)) { 601 return nil, fmt.Errorf("total len %d is not actual buf len %d", frameHead.TotalLen, len(rspBuf)) 602 } 603 if trpcpb.TrpcDataFrameType(frameHead.FrameType) != trpcpb.TrpcDataFrameType_TRPC_UNARY_FRAME { 604 return c.streamCodec.Decode(msg, rspBuf) 605 } 606 if frameHead.HeaderLen == 0 { 607 return nil, errors.New("client decode pb head len empty") 608 } 609 610 responseProtocolBegin := uint32(frameHeadLen) 611 responseProtocolEnd := responseProtocolBegin + uint32(frameHead.HeaderLen) 612 if responseProtocolEnd > uint32(len(rspBuf)) { 613 return nil, errors.New("client decode pb head len invalid") 614 } 615 rsp, err := loadOrStoreResponseHead(msg) 616 if err != nil { 617 return nil, err 618 } 619 if err := proto.Unmarshal(rspBuf[responseProtocolBegin:responseProtocolEnd], rsp); err != nil { 620 return nil, err 621 } 622 623 attachmentBegin := frameHead.TotalLen - rsp.AttachmentSize 624 if s := uint32(len(rspBuf)) - attachmentBegin; rsp.AttachmentSize != s { 625 return nil, fmt.Errorf("decoding attachment:(%d) len of attachment"+ 626 "isn't equal to expected AttachmentSize(%d)", s, rsp.AttachmentSize) 627 } 628 if err := updateMsg(msg, frameHead, rsp, rspBuf[attachmentBegin:]); err != nil { 629 return nil, err 630 } 631 632 bodyBegin, bodyEnd := responseProtocolEnd, attachmentBegin 633 return rspBuf[bodyBegin:bodyEnd], nil 634 } 635 636 func loadOrStoreResponseHead(msg codec.Msg) (*trpcpb.ResponseProtocol, error) { 637 // client rsp head being nil means no need to record backend response protocol head 638 // most of the time, response head is not set and should be created here. 639 rsp := msg.ClientRspHead() 640 if rsp == nil { 641 rsp := &trpcpb.ResponseProtocol{} 642 msg.WithClientRspHead(rsp) 643 return rsp, nil 644 } 645 646 // client rsp head not being nil means it's created on purpose and set to 647 // record response protocol head 648 { 649 rsp, ok := rsp.(*trpcpb.ResponseProtocol) 650 if !ok { 651 return nil, errors.New("client decode rsp head type invalid, must be trpc response protocol head") 652 } 653 return rsp, nil 654 } 655 } 656 657 // loadOrStoreDefaultUnaryFrameHead loads the existing frameHead from msg if present. 658 // Otherwise, it stores default Unary FrameHead to msg, and returns the default Unary FrameHead. 659 func loadOrStoreDefaultUnaryFrameHead(msg codec.Msg) *FrameHead { 660 frameHead, ok := msg.FrameHead().(*FrameHead) 661 if !ok { 662 frameHead = newDefaultUnaryFrameHead() 663 msg.WithFrameHead(frameHead) 664 } 665 return frameHead 666 } 667 668 func copyRspHead(dst, src *trpcpb.ResponseProtocol) { 669 dst.Version = src.Version 670 dst.CallType = src.CallType 671 dst.RequestId = src.RequestId 672 dst.Ret = src.Ret 673 dst.FuncRet = src.FuncRet 674 dst.ErrorMsg = src.ErrorMsg 675 dst.MessageType = src.MessageType 676 dst.TransInfo = src.TransInfo 677 dst.ContentType = src.ContentType 678 dst.ContentEncoding = src.ContentEncoding 679 } 680 681 func updateMsg(msg codec.Msg, frameHead *FrameHead, rsp *trpcpb.ResponseProtocol, attm []byte) error { 682 msg.WithFrameHead(frameHead) 683 msg.WithCompressType(int(rsp.GetContentEncoding())) 684 msg.WithSerializationType(int(rsp.GetContentType())) 685 686 // reset client metadata if new transinfo is returned with response 687 if len(rsp.TransInfo) > 0 { 688 md := msg.ClientMetaData() 689 if len(md) == 0 { 690 md = codec.MetaData{} 691 } 692 for k, v := range rsp.TransInfo { 693 md[k] = v 694 } 695 msg.WithClientMetaData(md) 696 } 697 698 // if retcode is not 0, a converted error should be returned 699 if rsp.GetRet() != 0 { 700 err := &errs.Error{ 701 Type: errs.ErrorTypeCalleeFramework, 702 Code: trpcpb.TrpcRetCode(rsp.GetRet()), 703 Desc: ProtocolName, 704 Msg: string(rsp.GetErrorMsg()), 705 } 706 msg.WithClientRspErr(err) 707 } else if rsp.GetFuncRet() != 0 { 708 msg.WithClientRspErr(errs.New(int(rsp.GetFuncRet()), string(rsp.GetErrorMsg()))) 709 } 710 711 // error should be returned immediately for request id mismatch 712 req, err := loadOrStoreDefaultRequestProtocol(msg) 713 if err == nil && rsp.RequestId != req.RequestId { 714 return fmt.Errorf("rsp request_id %d different from req request_id %d", rsp.RequestId, req.RequestId) 715 } 716 717 // handle protocol upgrading 718 frameHead.upgradeProtocol(curProtocolVersion, rsp.RequestId) 719 msg.WithRequestID(rsp.RequestId) 720 721 if len(attm) != 0 { 722 attachment.SetClientResponseAttachment(msg, attm) 723 } 724 return nil 725 }