github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/thrift.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package thrift 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 24 "github.com/apache/thrift/lib/go/thrift" 25 26 "github.com/cloudwego/kitex/pkg/protocol/bthrift" 27 "github.com/cloudwego/kitex/pkg/remote" 28 "github.com/cloudwego/kitex/pkg/remote/codec" 29 "github.com/cloudwego/kitex/pkg/remote/codec/perrors" 30 "github.com/cloudwego/kitex/pkg/rpcinfo" 31 "github.com/cloudwego/kitex/pkg/stats" 32 ) 33 34 // CodecType is config of the thrift codec. Priority: Frugal > FastMode > Normal 35 type CodecType int 36 37 const ( 38 // Basic can be used for disabling fastCodec and frugal 39 Basic CodecType = 0b0000 40 FastWrite CodecType = 0b0001 41 FastRead CodecType = 0b0010 42 43 FastReadWrite = FastRead | FastWrite 44 ) 45 46 var ( 47 defaultCodec = NewThriftCodec().(*thriftCodec) 48 49 errEncodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol, 50 "encode failed, codec msg type not match with thriftCodec") 51 errDecodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol, 52 "decode failed, codec msg type not match with thriftCodec") 53 ) 54 55 // NewThriftCodec creates the thrift binary codec. 56 func NewThriftCodec() remote.PayloadCodec { 57 return &thriftCodec{FastWrite | FastRead} 58 } 59 60 // IsThriftCodec checks if the codec is thriftCodec 61 func IsThriftCodec(c remote.PayloadCodec) bool { 62 _, ok := c.(*thriftCodec) 63 return ok 64 } 65 66 // NewThriftFrugalCodec creates the thrift binary codec powered by frugal. 67 // Eg: xxxservice.NewServer(handler, server.WithPayloadCodec(thrift.NewThriftCodecWithConfig(thrift.FastWrite | thrift.FastRead))) 68 func NewThriftCodecWithConfig(c CodecType) remote.PayloadCodec { 69 return &thriftCodec{c} 70 } 71 72 // NewThriftCodecDisableFastMode creates the thrift binary codec which can control if do fast codec. 73 // Eg: xxxservice.NewServer(handler, server.WithPayloadCodec(thrift.NewThriftCodecDisableFastMode(true, true))) 74 func NewThriftCodecDisableFastMode(disableFastWrite, disableFastRead bool) remote.PayloadCodec { 75 var c CodecType 76 if !disableFastRead { 77 c |= FastRead 78 } 79 if !disableFastWrite { 80 c |= FastWrite 81 } 82 return &thriftCodec{c} 83 } 84 85 // thriftCodec implements PayloadCodec 86 type thriftCodec struct { 87 CodecType 88 } 89 90 // Marshal implements the remote.PayloadCodec interface. 91 func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { 92 // prepare info 93 methodName := message.RPCInfo().Invocation().MethodName() 94 if methodName == "" { 95 return errors.New("empty methodName in thrift Marshal") 96 } 97 msgType := message.MessageType() 98 seqID := message.RPCInfo().Invocation().SeqID() 99 100 data, err := getValidData(methodName, message) 101 if err != nil { 102 return err 103 } 104 105 // encode with hyper codec 106 // NOTE: to ensure hyperMarshalEnabled is inlined so split the check logic, or it may cause performance loss 107 if c.hyperMarshalEnabled() && hyperMarshalAvailable(data) { 108 return c.hyperMarshal(out, methodName, msgType, seqID, data) 109 } 110 111 // encode with FastWrite 112 if c.CodecType&FastWrite != 0 { 113 if msg, ok := data.(ThriftMsgFastCodec); ok { 114 return encodeFastThrift(out, methodName, msgType, seqID, msg) 115 } 116 } 117 118 // fallback to old thrift way (slow) 119 if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data); err == nil || err != errEncodeMismatchMsgType { 120 return err 121 } 122 123 // Basic can be used for disabling frugal, we need to check it 124 if c.CodecType != Basic && hyperMarshalAvailable(data) { 125 // fallback to frugal when the generated code is using slim template 126 return c.hyperMarshal(out, methodName, msgType, seqID, data) 127 } 128 129 return errEncodeMismatchMsgType 130 } 131 132 // encodeFastThrift encode with the FastCodec way 133 func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.MessageType, seqID int32, msg ThriftMsgFastCodec) error { 134 // nocopy write is a special implementation of linked buffer, only bytebuffer implement NocopyWrite do FastWrite 135 msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, thrift.TMessageType(msgType), seqID) 136 msgEndLen := bthrift.Binary.MessageEndLength() 137 buf, err := out.Malloc(msgBeginLen + msg.BLength() + msgEndLen) 138 if err != nil { 139 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) 140 } 141 offset := bthrift.Binary.WriteMessageBegin(buf, methodName, thrift.TMessageType(msgType), seqID) 142 offset += msg.FastWriteNocopy(buf[offset:], nil) 143 bthrift.Binary.WriteMessageEnd(buf[offset:]) 144 return nil 145 } 146 147 // encodeBasicThrift encode with the old thrift way (slow) 148 func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error { 149 if err := verifyMarshalBasicThriftDataType(data); err != nil { 150 return err 151 } 152 tProt := NewBinaryProtocol(out) 153 if err := tProt.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil { 154 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error())) 155 } 156 if err := marshalBasicThriftData(ctx, tProt, data); err != nil { 157 return err 158 } 159 if err := tProt.WriteMessageEnd(); err != nil { 160 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageEnd failed: %s", err.Error())) 161 } 162 tProt.Recycle() 163 return nil 164 } 165 166 // Unmarshal implements the remote.PayloadCodec interface. 167 func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 168 tProt := NewBinaryProtocol(in) 169 methodName, msgType, seqID, err := tProt.ReadMessageBegin() 170 if err != nil { 171 return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift unmarshal, ReadMessageBegin failed: %s", err.Error())) 172 } 173 if err = codec.UpdateMsgType(uint32(msgType), message); err != nil { 174 return err 175 } 176 177 // exception message 178 if message.MessageType() == remote.Exception { 179 return UnmarshalThriftException(tProt) 180 } 181 182 if err = validateMessageBeforeDecode(message, seqID, methodName); err != nil { 183 return err 184 } 185 186 // decode thrift data 187 data := message.Data() 188 msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, msgType, seqID) 189 dataLen := message.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength() 190 191 ri := message.RPCInfo() 192 rpcinfo.Record(ctx, ri, stats.WaitReadStart, nil) 193 err = c.unmarshalThriftData(ctx, tProt, methodName, data, dataLen) 194 rpcinfo.Record(ctx, ri, stats.WaitReadFinish, err) 195 if err != nil { 196 return err 197 } 198 199 if err = tProt.ReadMessageEnd(); err != nil { 200 return remote.NewTransError(remote.ProtocolError, err) 201 } 202 tProt.Recycle() 203 return err 204 } 205 206 // validateMessageBeforeDecode validate message before decode 207 func validateMessageBeforeDecode(message remote.Message, seqID int32, methodName string) (err error) { 208 // For server side, the following error can be sent back and 'SetSeqID' should be executed first to ensure the seqID 209 // is right when return Exception back. 210 if err = codec.SetOrCheckSeqID(seqID, message); err != nil { 211 return err 212 } 213 214 if err = codec.SetOrCheckMethodName(methodName, message); err != nil { 215 return err 216 } 217 218 if err = codec.NewDataIfNeeded(methodName, message); err != nil { 219 return err 220 } 221 return nil 222 } 223 224 // Name implements the remote.PayloadCodec interface. 225 func (c thriftCodec) Name() string { 226 return "thrift" 227 } 228 229 // MessageWriterWithContext write to thrift.TProtocol 230 type MessageWriterWithContext interface { 231 Write(ctx context.Context, oprot thrift.TProtocol) error 232 } 233 234 // MessageWriter write to thrift.TProtocol 235 type MessageWriter interface { 236 Write(oprot thrift.TProtocol) error 237 } 238 239 // MessageReader read from thrift.TProtocol 240 type MessageReader interface { 241 Read(oprot thrift.TProtocol) error 242 } 243 244 // MessageReaderWithMethodWithContext read from thrift.TProtocol with method 245 type MessageReaderWithMethodWithContext interface { 246 Read(ctx context.Context, method string, oprot thrift.TProtocol) error 247 } 248 249 type ThriftMsgFastCodec interface { 250 BLength() int 251 FastWriteNocopy(buf []byte, binaryWriter bthrift.BinaryWriter) int 252 FastRead(buf []byte) (int, error) 253 } 254 255 func getValidData(methodName string, message remote.Message) (interface{}, error) { 256 if err := codec.NewDataIfNeeded(methodName, message); err != nil { 257 return nil, err 258 } 259 data := message.Data() 260 if message.MessageType() != remote.Exception { 261 return data, nil 262 } 263 transErr, isTransErr := data.(*remote.TransError) 264 if !isTransErr { 265 if err, isError := data.(error); isError { 266 encodeErr := thrift.NewTApplicationException(remote.InternalError, err.Error()) 267 return encodeErr, nil 268 } 269 return nil, errors.New("exception relay need error type data") 270 } 271 encodeErr := thrift.NewTApplicationException(transErr.TypeID(), transErr.Error()) 272 return encodeErr, nil 273 }