github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/struct.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  
    22  	"github.com/apache/thrift/lib/go/thrift"
    23  
    24  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    25  )
    26  
    27  // NewWriteStruct ...
    28  func NewWriteStruct(svc *descriptor.ServiceDescriptor, method string, isClient bool) (*WriteStruct, error) {
    29  	fnDsc, err := svc.LookupFunctionByMethod(method)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	ty := fnDsc.Request
    34  	if !isClient {
    35  		ty = fnDsc.Response
    36  	}
    37  	ws := &WriteStruct{
    38  		ty:             ty,
    39  		hasRequestBase: fnDsc.HasRequestBase && isClient,
    40  	}
    41  	return ws, nil
    42  }
    43  
    44  // WriteStruct implement of MessageWriter
    45  type WriteStruct struct {
    46  	ty               *descriptor.TypeDescriptor
    47  	hasRequestBase   bool
    48  	binaryWithBase64 bool
    49  }
    50  
    51  var _ MessageWriter = (*WriteStruct)(nil)
    52  
    53  // SetBinaryWithBase64 enable/disable Base64 decoding for binary.
    54  // Note that this method is not concurrent-safe.
    55  func (m *WriteStruct) SetBinaryWithBase64(enable bool) {
    56  	m.binaryWithBase64 = enable
    57  }
    58  
    59  // Write ...
    60  func (m *WriteStruct) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error {
    61  	if !m.hasRequestBase {
    62  		requestBase = nil
    63  	}
    64  	return wrapStructWriter(ctx, msg, out, m.ty, &writerOption{requestBase: requestBase, binaryWithBase64: m.binaryWithBase64})
    65  }
    66  
    67  // NewReadStruct ...
    68  func NewReadStruct(svc *descriptor.ServiceDescriptor, isClient bool) *ReadStruct {
    69  	return &ReadStruct{
    70  		svc:      svc,
    71  		isClient: isClient,
    72  	}
    73  }
    74  
    75  func NewReadStructForJSON(svc *descriptor.ServiceDescriptor, isClient bool) *ReadStruct {
    76  	return &ReadStruct{
    77  		svc:      svc,
    78  		isClient: isClient,
    79  		forJSON:  true,
    80  	}
    81  }
    82  
    83  // ReadStruct implement of MessageReaderWithMethod
    84  type ReadStruct struct {
    85  	svc                 *descriptor.ServiceDescriptor
    86  	isClient            bool
    87  	forJSON             bool
    88  	binaryWithBase64    bool
    89  	binaryWithByteSlice bool
    90  }
    91  
    92  var _ MessageReader = (*ReadStruct)(nil)
    93  
    94  // SetBinaryOption enable/disable Base64 encoding or returning []byte for binary.
    95  // Note that this method is not concurrent-safe.
    96  func (m *ReadStruct) SetBinaryOption(base64, byteSlice bool) {
    97  	m.binaryWithBase64 = base64
    98  	m.binaryWithByteSlice = byteSlice
    99  }
   100  
   101  // Read ...
   102  func (m *ReadStruct) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
   103  	fnDsc, err := m.svc.LookupFunctionByMethod(method)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	fDsc := fnDsc.Response
   108  	if !m.isClient {
   109  		fDsc = fnDsc.Request
   110  	}
   111  	return skipStructReader(ctx, in, fDsc, &readerOption{throwException: true, forJSON: m.forJSON, binaryWithBase64: m.binaryWithBase64, binaryWithByteSlice: m.binaryWithByteSlice})
   112  }