github.com/cloudwego/kitex@v0.9.0/pkg/generic/mapthrift_codec.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package generic
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"sync/atomic"
    24  
    25  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    26  	"github.com/cloudwego/kitex/pkg/generic/thrift"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/remote/codec"
    29  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    30  )
    31  
    32  var (
    33  	_ remote.PayloadCodec = &mapThriftCodec{}
    34  	_ Closer              = &mapThriftCodec{}
    35  )
    36  
    37  type mapThriftCodec struct {
    38  	svcDsc              atomic.Value // *idl
    39  	provider            DescriptorProvider
    40  	codec               remote.PayloadCodec
    41  	forJSON             bool
    42  	binaryWithBase64    bool
    43  	binaryWithByteSlice bool
    44  }
    45  
    46  func newMapThriftCodec(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) {
    47  	svc := <-p.Provide()
    48  	c := &mapThriftCodec{
    49  		codec:               codec,
    50  		provider:            p,
    51  		binaryWithBase64:    false,
    52  		binaryWithByteSlice: false,
    53  	}
    54  	c.svcDsc.Store(svc)
    55  	go c.update()
    56  	return c, nil
    57  }
    58  
    59  func newMapThriftCodecForJSON(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) {
    60  	c, err := newMapThriftCodec(p, codec)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	c.forJSON = true
    65  	return c, nil
    66  }
    67  
    68  func (c *mapThriftCodec) update() {
    69  	for {
    70  		svc, ok := <-c.provider.Provide()
    71  		if !ok {
    72  			return
    73  		}
    74  		c.svcDsc.Store(svc)
    75  	}
    76  }
    77  
    78  func (c *mapThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
    79  	method := msg.RPCInfo().Invocation().MethodName()
    80  	if method == "" {
    81  		return errors.New("empty methodName in thrift Marshal")
    82  	}
    83  	if msg.MessageType() == remote.Exception {
    84  		return c.codec.Marshal(ctx, msg, out)
    85  	}
    86  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
    87  	if !ok {
    88  		return fmt.Errorf("get parser ServiceDescriptor failed")
    89  	}
    90  	wm, err := thrift.NewWriteStruct(svcDsc, method, msg.RPCRole() == remote.Client)
    91  	if err != nil {
    92  		return err
    93  	}
    94  	wm.SetBinaryWithBase64(c.binaryWithBase64)
    95  	msg.Data().(WithCodec).SetCodec(wm)
    96  	return c.codec.Marshal(ctx, msg, out)
    97  }
    98  
    99  func (c *mapThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   100  	if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil {
   101  		return err
   102  	}
   103  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
   104  	if !ok {
   105  		return fmt.Errorf("get parser ServiceDescriptor failed")
   106  	}
   107  	var rm *thrift.ReadStruct
   108  	if c.forJSON {
   109  		rm = thrift.NewReadStructForJSON(svcDsc, msg.RPCRole() == remote.Client)
   110  	} else {
   111  		rm = thrift.NewReadStruct(svcDsc, msg.RPCRole() == remote.Client)
   112  	}
   113  	rm.SetBinaryOption(c.binaryWithBase64, c.binaryWithByteSlice)
   114  	msg.Data().(WithCodec).SetCodec(rm)
   115  	return c.codec.Unmarshal(ctx, msg, in)
   116  }
   117  
   118  func (c *mapThriftCodec) getMethod(req interface{}, method string) (*Method, error) {
   119  	fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	return &Method{method, fnSvc.Oneway}, nil
   124  }
   125  
   126  func (c *mapThriftCodec) Name() string {
   127  	return "MapThrift"
   128  }
   129  
   130  func (c *mapThriftCodec) Close() error {
   131  	return c.provider.Close()
   132  }