github.com/cloudwego/kitex@v0.9.0/pkg/utils/thrift.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package utils 18 19 import ( 20 "bytes" 21 "context" 22 "errors" 23 "fmt" 24 25 "github.com/apache/thrift/lib/go/thrift" 26 ) 27 28 // ThriftMessageCodec is used to codec thrift messages. 29 type ThriftMessageCodec struct { 30 tb *thrift.TMemoryBuffer 31 tProt thrift.TProtocol 32 } 33 34 // NewThriftMessageCodec creates a new ThriftMessageCodec. 35 func NewThriftMessageCodec() *ThriftMessageCodec { 36 transport := thrift.NewTMemoryBufferLen(1024) 37 tProt := thrift.NewTBinaryProtocol(transport, true, true) 38 39 return &ThriftMessageCodec{ 40 tb: transport, 41 tProt: tProt, 42 } 43 } 44 45 // Encode do thrift message encode. 46 // Notice! msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result 47 // Notice! seqID will be reset in kitex if the buffer is used for generic call in client side, set seqID=0 is suggested 48 // when you call this method as client. 49 func (t *ThriftMessageCodec) Encode(method string, msgType thrift.TMessageType, seqID int32, msg thrift.TStruct) (b []byte, err error) { 50 if method == "" { 51 return nil, errors.New("empty methodName in thrift RPCEncode") 52 } 53 t.tb.Reset() 54 if err = t.tProt.WriteMessageBegin(method, msgType, seqID); err != nil { 55 return 56 } 57 if err = msg.Write(t.tProt); err != nil { 58 return 59 } 60 if err = t.tProt.WriteMessageEnd(); err != nil { 61 return 62 } 63 b = append(b, t.tb.Bytes()...) 64 return 65 } 66 67 // Decode do thrift message decode, notice: msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result 68 func (t *ThriftMessageCodec) Decode(b []byte, msg thrift.TStruct) (method string, seqID int32, err error) { 69 t.tb.Reset() 70 if _, err = t.tb.Write(b); err != nil { 71 return 72 } 73 var msgType thrift.TMessageType 74 if method, msgType, seqID, err = t.tProt.ReadMessageBegin(); err != nil { 75 return 76 } 77 if msgType == thrift.EXCEPTION { 78 exception := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "") 79 if err = exception.Read(t.tProt); err != nil { 80 return 81 } 82 if err = t.tProt.ReadMessageEnd(); err != nil { 83 return 84 } 85 err = exception 86 return 87 } 88 if err = msg.Read(t.tProt); err != nil { 89 return 90 } 91 t.tProt.ReadMessageEnd() 92 return 93 } 94 95 // Serialize serialize message into bytes. This is normal thrift serialize func. 96 // Notice: Binary generic use Encode instead of Serialize. 97 func (t *ThriftMessageCodec) Serialize(msg thrift.TStruct) (b []byte, err error) { 98 t.tb.Reset() 99 100 if err = msg.Write(t.tProt); err != nil { 101 return 102 } 103 b = append(b, t.tb.Bytes()...) 104 return 105 } 106 107 // Deserialize deserialize bytes into message. This is normal thrift deserialize func. 108 // Notice: Binary generic use Decode instead of Deserialize. 109 func (t *ThriftMessageCodec) Deserialize(msg thrift.TStruct, b []byte) (err error) { 110 t.tb.Reset() 111 if _, err = t.tb.Write(b); err != nil { 112 return 113 } 114 if err = msg.Read(t.tProt); err != nil { 115 return 116 } 117 return nil 118 } 119 120 // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport. 121 func MarshalError(method string, err error) []byte { 122 e := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, err.Error()) 123 var buf bytes.Buffer 124 trans := thrift.NewStreamTransportRW(&buf) 125 proto := thrift.NewTBinaryProtocol(trans, true, true) 126 if err := proto.WriteMessageBegin(method, thrift.EXCEPTION, 0); err != nil { 127 return nil 128 } 129 if err := e.Write(proto); err != nil { 130 return nil 131 } 132 if err := proto.WriteMessageEnd(); err != nil { 133 return nil 134 } 135 if err := proto.Flush(context.Background()); err != nil { 136 return nil 137 } 138 return buf.Bytes() 139 } 140 141 // UnmarshalError decode binary and return error message 142 func UnmarshalError(b []byte) error { 143 trans := thrift.NewStreamTransportR(bytes.NewReader(b)) 144 proto := thrift.NewTBinaryProtocolTransport(trans) 145 if _, _, _, err := proto.ReadMessageBegin(); err != nil { 146 return fmt.Errorf("read message begin error: %w", err) 147 } 148 e := thrift.NewTApplicationException(0, "") 149 if err := e.Read(proto); err != nil { 150 return fmt.Errorf("read exception error: %w", err) 151 } 152 if err := proto.ReadMessageEnd(); err != nil { 153 return fmt.Errorf("read message end error: %w", err) 154 } 155 return e 156 }