github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/http.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  
    23  	"github.com/apache/thrift/lib/go/thrift"
    24  	"github.com/cloudwego/dynamicgo/conv"
    25  	"github.com/cloudwego/dynamicgo/conv/t2j"
    26  	dthrift "github.com/cloudwego/dynamicgo/thrift"
    27  	jsoniter "github.com/json-iterator/go"
    28  
    29  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    30  	"github.com/cloudwego/kitex/pkg/protocol/bthrift"
    31  	"github.com/cloudwego/kitex/pkg/remote"
    32  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    33  	cthrift "github.com/cloudwego/kitex/pkg/remote/codec/thrift"
    34  )
    35  
// WriteHTTPRequest implements MessageWriter for HTTP-style generic calls.
// The non-dynamicgo path (originalWrite) encodes a *descriptor.HTTPRequest
// into the thrift request struct of the function chosen by the service Router.
// SetDynamicGo switches on the dynamicgo configuration kept in the fields below
// (the Write method that consumes it is not in this file — NOTE(review): confirm).
type WriteHTTPRequest struct {
	// svc provides the Router used by originalWrite and the dynamicgo
	// descriptor consulted by SetDynamicGo.
	svc                    *descriptor.ServiceDescriptor
	// dynamicgoTypeDsc is the request type descriptor of the method passed to SetDynamicGo.
	dynamicgoTypeDsc       *dthrift.TypeDescriptor
	// binaryWithBase64 controls Base64 decoding of binary fields (true by default).
	binaryWithBase64       bool
	convOpts               conv.Options // used for dynamicgo conversion
	convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
	// hasRequestBase mirrors HasRequestBase() of the dynamicgo function descriptor.
	hasRequestBase         bool
	// dynamicgoEnabled is set to true by SetDynamicGo.
	dynamicgoEnabled       bool
}
    46  
    47  var (
    48  	_          MessageWriter = (*WriteHTTPRequest)(nil)
    49  	customJson               = jsoniter.Config{
    50  		EscapeHTML: true,
    51  		UseNumber:  true,
    52  	}.Froze()
    53  )
    54  
    55  // NewWriteHTTPRequest ...
    56  // Base64 decoding for binary is enabled by default.
    57  func NewWriteHTTPRequest(svc *descriptor.ServiceDescriptor) *WriteHTTPRequest {
    58  	return &WriteHTTPRequest{svc: svc, binaryWithBase64: true, dynamicgoEnabled: false}
    59  }
    60  
    61  // SetBinaryWithBase64 enable/disable Base64 decoding for binary.
    62  // Note that this method is not concurrent-safe.
    63  func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) {
    64  	w.binaryWithBase64 = enable
    65  }
    66  
    67  // SetDynamicGo ...
    68  func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options, method string) error {
    69  	w.convOpts = *convOpts
    70  	w.convOptsWithThriftBase = *convOptsWithThriftBase
    71  	w.dynamicgoEnabled = true
    72  	fnDsc := w.svc.DynamicGoDsc.Functions()[method]
    73  	if fnDsc == nil {
    74  		return fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, w.svc.DynamicGoDsc.Name())
    75  	}
    76  	w.hasRequestBase = fnDsc.HasRequestBase()
    77  	w.dynamicgoTypeDsc = fnDsc.Request()
    78  	return nil
    79  }
    80  
    81  // originalWrite ...
    82  func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error {
    83  	req := msg.(*descriptor.HTTPRequest)
    84  	if req.Body == nil && len(req.RawBody) != 0 {
    85  		if err := customJson.Unmarshal(req.RawBody, &req.Body); err != nil {
    86  			return err
    87  		}
    88  	}
    89  	fn, err := w.svc.Router.Lookup(req)
    90  	if err != nil {
    91  		return err
    92  	}
    93  	if !fn.HasRequestBase {
    94  		requestBase = nil
    95  	}
    96  	return wrapStructWriter(ctx, req, out, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64})
    97  }
    98  
// ReadHTTPResponse implements MessageReader for HTTP-style generic calls.
// Read decodes the thrift response of a method either with the
// descriptor-based reader (originalRead) or, after SetDynamicGo, with dynamicgo.
type ReadHTTPResponse struct {
	// svc is used by both read paths to look up the method's descriptors.
	svc                   *descriptor.ServiceDescriptor
	// base64Binary is passed as binaryWithBase64 to the non-dynamicgo reader (true by default).
	base64Binary          bool
	// msg is retained by SetDynamicGo; Read uses its payload length, message type and seq ID.
	msg                   remote.Message
	// dynamicgoEnabled selects the dynamicgo path in Read; set by SetDynamicGo.
	dynamicgoEnabled      bool
	// useRawBodyForHTTPResp makes originalRead also fill HTTPResponse.RawBody from Body.
	useRawBodyForHTTPResp bool
	t2jBinaryConv         t2j.BinaryConv // used for dynamicgo thrift to json conversion
}

// Compile-time check that *ReadHTTPResponse satisfies MessageReader.
var _ MessageReader = (*ReadHTTPResponse)(nil)
   110  
   111  // NewReadHTTPResponse ...
   112  // Base64 encoding for binary is enabled by default.
   113  func NewReadHTTPResponse(svc *descriptor.ServiceDescriptor) *ReadHTTPResponse {
   114  	return &ReadHTTPResponse{svc: svc, base64Binary: true, dynamicgoEnabled: false, useRawBodyForHTTPResp: false}
   115  }
   116  
   117  // SetBase64Binary enable/disable Base64 encoding for binary.
   118  // Note that this method is not concurrent-safe.
   119  func (r *ReadHTTPResponse) SetBase64Binary(enable bool) {
   120  	r.base64Binary = enable
   121  }
   122  
   123  // SetUseRawBodyForHTTPResp ...
   124  func (r *ReadHTTPResponse) SetUseRawBodyForHTTPResp(useRawBodyForHTTPResp bool) {
   125  	r.useRawBodyForHTTPResp = useRawBodyForHTTPResp
   126  }
   127  
   128  // SetDynamicGo ...
   129  func (r *ReadHTTPResponse) SetDynamicGo(convOpts *conv.Options, msg remote.Message) {
   130  	r.t2jBinaryConv = t2j.NewBinaryConv(*convOpts)
   131  	r.msg = msg
   132  	r.dynamicgoEnabled = true
   133  }
   134  
   135  // Read ...
   136  func (r *ReadHTTPResponse) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
   137  	// fallback logic
   138  	if !r.dynamicgoEnabled {
   139  		return r.originalRead(ctx, method, in)
   140  	}
   141  
   142  	// dynamicgo logic
   143  	if r.msg.PayloadLen() == 0 {
   144  		return nil, perrors.NewProtocolErrorWithMsg("msg.PayloadLen should always be greater than zero")
   145  	}
   146  
   147  	tProt, ok := in.(*cthrift.BinaryProtocol)
   148  	if !ok {
   149  		return nil, perrors.NewProtocolErrorWithMsg("TProtocol should be BinaryProtocol")
   150  	}
   151  	mBeginLen := bthrift.Binary.MessageBeginLength(method, thrift.TMessageType(r.msg.MessageType()), r.msg.RPCInfo().Invocation().SeqID())
   152  	sName, err := in.ReadStructBegin()
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	sBeginLen := bthrift.Binary.StructBeginLength(sName)
   157  	// TODO: support exception field
   158  	fName, typeId, id, err := in.ReadFieldBegin()
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	fBeginLen := bthrift.Binary.FieldBeginLength(fName, typeId, id)
   163  	transBuf, err := tProt.ByteBuffer().ReadBinary(r.msg.PayloadLen() - mBeginLen - sBeginLen - fBeginLen - bthrift.Binary.MessageEndLength())
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  	fid := dthrift.FieldID(id)
   168  
   169  	resp := descriptor.NewHTTPResponse()
   170  	ctx = context.WithValue(ctx, conv.CtxKeyHTTPResponse, resp)
   171  	fnDsc := r.svc.DynamicGoDsc.Functions()[method]
   172  	if fnDsc == nil {
   173  		return nil, fmt.Errorf("missing method: %s in service: %s in dynamicgo", method, r.svc.DynamicGoDsc.Name())
   174  	}
   175  	tyDsc := fnDsc.Response()
   176  	// json size is usually 2 times larger than equivalent thrift data
   177  	buf := make([]byte, 0, len(transBuf)*2)
   178  
   179  	for _, field := range tyDsc.Struct().Fields() {
   180  		if fid == field.ID() {
   181  			// decode with dynamicgo
   182  			// thrift []byte to json []byte
   183  			if err = r.t2jBinaryConv.DoInto(ctx, field.Type(), transBuf, &buf); err != nil {
   184  				return nil, err
   185  			}
   186  			break
   187  		}
   188  	}
   189  	resp.RawBody = buf
   190  	return resp, nil
   191  }
   192  
   193  func (r *ReadHTTPResponse) originalRead(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
   194  	fnDsc, err := r.svc.LookupFunctionByMethod(method)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	fDsc := fnDsc.Response
   199  	resp, err := skipStructReader(ctx, in, fDsc, &readerOption{forJSON: true, http: true, binaryWithBase64: r.base64Binary})
   200  	if r.useRawBodyForHTTPResp {
   201  		if httpResp, ok := resp.(*descriptor.HTTPResponse); ok && httpResp.Body != nil {
   202  			rawBody, err := customJson.Marshal(httpResp.Body)
   203  			if err != nil {
   204  				return nil, err
   205  			}
   206  			httpResp.RawBody = rawBody
   207  		}
   208  	}
   209  	return resp, err
   210  }