github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/json.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strconv"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  	"github.com/cloudwego/dynamicgo/conv"
    26  	"github.com/cloudwego/dynamicgo/conv/t2j"
    27  	dthrift "github.com/cloudwego/dynamicgo/thrift"
    28  	jsoniter "github.com/json-iterator/go"
    29  	"github.com/tidwall/gjson"
    30  
    31  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    32  	"github.com/cloudwego/kitex/pkg/protocol/bthrift"
    33  	"github.com/cloudwego/kitex/pkg/remote"
    34  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    35  	cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift"
    36  	"github.com/cloudwego/kitex/pkg/utils"
    37  )
    38  
    39  // NewWriteJSON build WriteJSON according to ServiceDescriptor
    40  func NewWriteJSON(svc *descriptor.ServiceDescriptor, method string, isClient bool) (*WriteJSON, error) {
    41  	fnDsc, err := svc.LookupFunctionByMethod(method)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	ty := fnDsc.Request
    46  	if !isClient {
    47  		ty = fnDsc.Response
    48  	}
    49  	ws := &WriteJSON{
    50  		typeDsc:          ty,
    51  		hasRequestBase:   fnDsc.HasRequestBase && isClient,
    52  		base64Binary:     true,
    53  		isClient:         isClient,
    54  		dynamicgoEnabled: false,
    55  	}
    56  	return ws, nil
    57  }
    58  
    59  const voidWholeLen = 5
    60  
    61  var _ = wrapJSONWriter
    62  
    63  // WriteJSON implement of MessageWriter
    64  type WriteJSON struct {
    65  	typeDsc                *descriptor.TypeDescriptor
    66  	dynamicgoTypeDsc       *dthrift.TypeDescriptor
    67  	hasRequestBase         bool
    68  	base64Binary           bool
    69  	isClient               bool
    70  	convOpts               conv.Options // used for dynamicgo conversion
    71  	convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
    72  	dynamicgoEnabled       bool
    73  }
    74  
    75  var _ MessageWriter = (*WriteJSON)(nil)
    76  
    77  // SetBase64Binary enable/disable Base64 decoding for binary.
    78  // Note that this method is not concurrent-safe.
    79  func (m *WriteJSON) SetBase64Binary(enable bool) {
    80  	m.base64Binary = enable
    81  }
    82  
    83  // SetDynamicGo ...
    84  func (m *WriteJSON) SetDynamicGo(svc *descriptor.ServiceDescriptor, method string, convOpts, convOptsWithThriftBase *conv.Options) error {
    85  	fnDsc := svc.DynamicGoDsc.Functions()[method]
    86  	if fnDsc == nil {
    87  		return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, svc.DynamicGoDsc.Name())
    88  	}
    89  	if m.isClient {
    90  		m.dynamicgoTypeDsc = fnDsc.Request()
    91  	} else {
    92  		m.dynamicgoTypeDsc = fnDsc.Response()
    93  	}
    94  	m.convOpts = *convOpts
    95  	m.convOptsWithThriftBase = *convOptsWithThriftBase
    96  	m.dynamicgoEnabled = true
    97  	return nil
    98  }
    99  
   100  func (m *WriteJSON) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error {
   101  	if !m.hasRequestBase {
   102  		requestBase = nil
   103  	}
   104  
   105  	// msg is void or nil
   106  	if _, ok := msg.(descriptor.Void); ok || msg == nil {
   107  		return wrapStructWriter(ctx, msg, out, m.typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary})
   108  	}
   109  
   110  	// msg is string
   111  	s, ok := msg.(string)
   112  	if !ok {
   113  		return perrors.NewProtocolErrorWithType(perrors.InvalidData, "decode msg failed, is not string")
   114  	}
   115  
   116  	body := gjson.Parse(s)
   117  	if body.Type == gjson.Null {
   118  		body = gjson.Result{
   119  			Type:  gjson.String,
   120  			Raw:   s,
   121  			Str:   s,
   122  			Num:   0,
   123  			Index: 0,
   124  		}
   125  	}
   126  	return wrapJSONWriter(ctx, &body, out, m.typeDsc, &writerOption{requestBase: requestBase, binaryWithBase64: m.base64Binary})
   127  }
   128  
   129  // NewReadJSON build ReadJSON according to ServiceDescriptor
   130  func NewReadJSON(svc *descriptor.ServiceDescriptor, isClient bool) *ReadJSON {
   131  	return &ReadJSON{
   132  		svc:              svc,
   133  		isClient:         isClient,
   134  		binaryWithBase64: true,
   135  		dynamicgoEnabled: false,
   136  	}
   137  }
   138  
   139  // ReadJSON implement of MessageReaderWithMethod
   140  type ReadJSON struct {
   141  	svc              *descriptor.ServiceDescriptor
   142  	isClient         bool
   143  	binaryWithBase64 bool
   144  	msg              remote.Message
   145  	t2jBinaryConv    t2j.BinaryConv // used for dynamicgo thrift to json conversion
   146  	dynamicgoEnabled bool
   147  }
   148  
   149  var _ MessageReader = (*ReadJSON)(nil)
   150  
   151  // SetBinaryWithBase64 enable/disable Base64 encoding for binary.
   152  // Note that this method is not concurrent-safe.
   153  func (m *ReadJSON) SetBinaryWithBase64(enable bool) {
   154  	m.binaryWithBase64 = enable
   155  }
   156  
   157  // SetDynamicGo ...
   158  func (m *ReadJSON) SetDynamicGo(convOpts, convOptsWithException *conv.Options, msg remote.Message) {
   159  	m.msg = msg
   160  	m.dynamicgoEnabled = true
   161  	if m.isClient {
   162  		// set binary conv to handle an exception field
   163  		m.t2jBinaryConv = t2j.NewBinaryConv(*convOptsWithException)
   164  	} else {
   165  		m.t2jBinaryConv = t2j.NewBinaryConv(*convOpts)
   166  	}
   167  }
   168  
   169  // Read read data from in thrift.TProtocol and convert to json string
   170  func (m *ReadJSON) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
   171  	// fallback logic
   172  	if !m.dynamicgoEnabled {
   173  		return m.originalRead(ctx, method, in)
   174  	}
   175  
   176  	// dynamicgo logic
   177  	tProt, ok := in.(*cthrift.BinaryProtocol)
   178  	if !ok {
   179  		return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol")
   180  	}
   181  
   182  	fnDsc := m.svc.DynamicGoDsc.Functions()[method]
   183  	if fnDsc == nil {
   184  		return nil, fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, m.svc.DynamicGoDsc.Name())
   185  	}
   186  	var tyDsc *dthrift.TypeDescriptor
   187  	if m.msg.MessageType() == remote.Reply {
   188  		tyDsc = fnDsc.Response()
   189  	} else {
   190  		tyDsc = fnDsc.Request()
   191  	}
   192  
   193  	var resp interface{}
   194  	if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.VOID {
   195  		if _, err := tProt.ByteBuffer().ReadBinary(voidWholeLen); err != nil {
   196  			return nil, err
   197  		}
   198  		resp = descriptor.Void{}
   199  	} else {
   200  		msgBeginLen := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(m.msg.MessageType()), m.msg.RPCInfo().Invocation().SeqID())
   201  		transBuff, err := tProt.ByteBuffer().ReadBinary(m.msg.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength())
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  
   206  		// json size is usually 2 times larger than equivalent thrift data
   207  		buf := make([]byte, 0, len(transBuff)*2)
   208  		// thrift []byte to json []byte
   209  		if err := m.t2jBinaryConv.DoInto(ctx, tyDsc, transBuff, &buf); err != nil {
   210  			return nil, err
   211  		}
   212  		buf = removePrefixAndSuffix(buf)
   213  		resp = utils.SliceByteToString(buf)
   214  		if tyDsc.Struct().Fields()[0].Type().Type() == dthrift.STRING {
   215  			strresp := resp.(string)
   216  			resp, err = strconv.Unquote(strresp)
   217  			if err != nil {
   218  				return nil, err
   219  			}
   220  		}
   221  	}
   222  
   223  	return resp, nil
   224  }
   225  
   226  func (m *ReadJSON) originalRead(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
   227  	fnDsc, err := m.svc.LookupFunctionByMethod(method)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	fDsc := fnDsc.Response
   232  	if !m.isClient {
   233  		fDsc = fnDsc.Request
   234  	}
   235  	resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, throwException: true, binaryWithBase64: m.binaryWithBase64})
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  
   240  	// resp is void
   241  	if _, ok := resp.(descriptor.Void); ok {
   242  		return resp, nil
   243  	}
   244  
   245  	// resp is string
   246  	if _, ok := resp.(string); ok {
   247  		return resp, nil
   248  	}
   249  
   250  	// resp is map
   251  	respNode, err := jsoniter.Marshal(resp)
   252  	if err != nil {
   253  		return nil, perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("response marshal failed. err:%#v", err))
   254  	}
   255  
   256  	return string(respNode), nil
   257  }
   258  
   259  // removePrefixAndSuffix removes json []byte from prefix `{"":` and suffix `}`
   260  func removePrefixAndSuffix(buf []byte) []byte {
   261  	if len(buf) > structWrapLen {
   262  		return buf[structWrapLen : len(buf)-1]
   263  	}
   264  	return buf
   265  }