github.com/cloudwego/kitex@v0.9.0/pkg/utils/thrift.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package utils
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  
    25  	"github.com/apache/thrift/lib/go/thrift"
    26  )
    27  
    28  // ThriftMessageCodec is used to codec thrift messages.
    29  type ThriftMessageCodec struct {
    30  	tb    *thrift.TMemoryBuffer
    31  	tProt thrift.TProtocol
    32  }
    33  
    34  // NewThriftMessageCodec creates a new ThriftMessageCodec.
    35  func NewThriftMessageCodec() *ThriftMessageCodec {
    36  	transport := thrift.NewTMemoryBufferLen(1024)
    37  	tProt := thrift.NewTBinaryProtocol(transport, true, true)
    38  
    39  	return &ThriftMessageCodec{
    40  		tb:    transport,
    41  		tProt: tProt,
    42  	}
    43  }
    44  
    45  // Encode do thrift message encode.
    46  // Notice! msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result
    47  // Notice! seqID will be reset in kitex if the buffer is used for generic call in client side, set seqID=0 is suggested
    48  // when you call this method as client.
    49  func (t *ThriftMessageCodec) Encode(method string, msgType thrift.TMessageType, seqID int32, msg thrift.TStruct) (b []byte, err error) {
    50  	if method == "" {
    51  		return nil, errors.New("empty methodName in thrift RPCEncode")
    52  	}
    53  	t.tb.Reset()
    54  	if err = t.tProt.WriteMessageBegin(method, msgType, seqID); err != nil {
    55  		return
    56  	}
    57  	if err = msg.Write(t.tProt); err != nil {
    58  		return
    59  	}
    60  	if err = t.tProt.WriteMessageEnd(); err != nil {
    61  		return
    62  	}
    63  	b = append(b, t.tb.Bytes()...)
    64  	return
    65  }
    66  
    67  // Decode do thrift message decode, notice: msg must be XXXArgs/XXXResult that the wrap struct for args and result, not the actual args or result
    68  func (t *ThriftMessageCodec) Decode(b []byte, msg thrift.TStruct) (method string, seqID int32, err error) {
    69  	t.tb.Reset()
    70  	if _, err = t.tb.Write(b); err != nil {
    71  		return
    72  	}
    73  	var msgType thrift.TMessageType
    74  	if method, msgType, seqID, err = t.tProt.ReadMessageBegin(); err != nil {
    75  		return
    76  	}
    77  	if msgType == thrift.EXCEPTION {
    78  		exception := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "")
    79  		if err = exception.Read(t.tProt); err != nil {
    80  			return
    81  		}
    82  		if err = t.tProt.ReadMessageEnd(); err != nil {
    83  			return
    84  		}
    85  		err = exception
    86  		return
    87  	}
    88  	if err = msg.Read(t.tProt); err != nil {
    89  		return
    90  	}
    91  	t.tProt.ReadMessageEnd()
    92  	return
    93  }
    94  
    95  // Serialize serialize message into bytes. This is normal thrift serialize func.
    96  // Notice: Binary generic use Encode instead of Serialize.
    97  func (t *ThriftMessageCodec) Serialize(msg thrift.TStruct) (b []byte, err error) {
    98  	t.tb.Reset()
    99  
   100  	if err = msg.Write(t.tProt); err != nil {
   101  		return
   102  	}
   103  	b = append(b, t.tb.Bytes()...)
   104  	return
   105  }
   106  
   107  // Deserialize deserialize bytes into message. This is normal thrift deserialize func.
   108  // Notice: Binary generic use Decode instead of Deserialize.
   109  func (t *ThriftMessageCodec) Deserialize(msg thrift.TStruct, b []byte) (err error) {
   110  	t.tb.Reset()
   111  	if _, err = t.tb.Write(b); err != nil {
   112  		return
   113  	}
   114  	if err = msg.Read(t.tProt); err != nil {
   115  		return
   116  	}
   117  	return nil
   118  }
   119  
   120  // MarshalError convert go error to thrift exception, and encode exception over buffered binary transport.
   121  func MarshalError(method string, err error) []byte {
   122  	e := thrift.NewTApplicationException(thrift.INTERNAL_ERROR, err.Error())
   123  	var buf bytes.Buffer
   124  	trans := thrift.NewStreamTransportRW(&buf)
   125  	proto := thrift.NewTBinaryProtocol(trans, true, true)
   126  	if err := proto.WriteMessageBegin(method, thrift.EXCEPTION, 0); err != nil {
   127  		return nil
   128  	}
   129  	if err := e.Write(proto); err != nil {
   130  		return nil
   131  	}
   132  	if err := proto.WriteMessageEnd(); err != nil {
   133  		return nil
   134  	}
   135  	if err := proto.Flush(context.Background()); err != nil {
   136  		return nil
   137  	}
   138  	return buf.Bytes()
   139  }
   140  
   141  // UnmarshalError decode binary and return error message
   142  func UnmarshalError(b []byte) error {
   143  	trans := thrift.NewStreamTransportR(bytes.NewReader(b))
   144  	proto := thrift.NewTBinaryProtocolTransport(trans)
   145  	if _, _, _, err := proto.ReadMessageBegin(); err != nil {
   146  		return fmt.Errorf("read message begin error: %w", err)
   147  	}
   148  	e := thrift.NewTApplicationException(0, "")
   149  	if err := e.Read(proto); err != nil {
   150  		return fmt.Errorf("read exception error: %w", err)
   151  	}
   152  	if err := proto.ReadMessageEnd(); err != nil {
   153  		return fmt.Errorf("read message end error: %w", err)
   154  	}
   155  	return e
   156  }