github.com/cloudwego/kitex@v0.9.0/pkg/generic/mapthrift_codec.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package generic 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "sync/atomic" 24 25 "github.com/cloudwego/kitex/pkg/generic/descriptor" 26 "github.com/cloudwego/kitex/pkg/generic/thrift" 27 "github.com/cloudwego/kitex/pkg/remote" 28 "github.com/cloudwego/kitex/pkg/remote/codec" 29 "github.com/cloudwego/kitex/pkg/serviceinfo" 30 ) 31 32 var ( 33 _ remote.PayloadCodec = &mapThriftCodec{} 34 _ Closer = &mapThriftCodec{} 35 ) 36 37 type mapThriftCodec struct { 38 svcDsc atomic.Value // *idl 39 provider DescriptorProvider 40 codec remote.PayloadCodec 41 forJSON bool 42 binaryWithBase64 bool 43 binaryWithByteSlice bool 44 } 45 46 func newMapThriftCodec(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) { 47 svc := <-p.Provide() 48 c := &mapThriftCodec{ 49 codec: codec, 50 provider: p, 51 binaryWithBase64: false, 52 binaryWithByteSlice: false, 53 } 54 c.svcDsc.Store(svc) 55 go c.update() 56 return c, nil 57 } 58 59 func newMapThriftCodecForJSON(p DescriptorProvider, codec remote.PayloadCodec) (*mapThriftCodec, error) { 60 c, err := newMapThriftCodec(p, codec) 61 if err != nil { 62 return nil, err 63 } 64 c.forJSON = true 65 return c, nil 66 } 67 68 func (c *mapThriftCodec) update() { 69 for { 70 svc, ok := <-c.provider.Provide() 71 if !ok { 72 return 73 } 74 c.svcDsc.Store(svc) 75 } 76 } 77 78 func (c *mapThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 79 method := msg.RPCInfo().Invocation().MethodName() 80 if method == "" { 81 return errors.New("empty methodName in thrift Marshal") 82 } 83 if msg.MessageType() == remote.Exception { 84 return c.codec.Marshal(ctx, msg, out) 85 } 86 svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) 87 if !ok { 88 return fmt.Errorf("get parser ServiceDescriptor failed") 89 } 90 wm, err := thrift.NewWriteStruct(svcDsc, method, msg.RPCRole() == remote.Client) 91 if err != nil { 92 return err 93 } 94 wm.SetBinaryWithBase64(c.binaryWithBase64) 95 msg.Data().(WithCodec).SetCodec(wm) 96 return c.codec.Marshal(ctx, msg, out) 97 } 98 99 func (c *mapThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 100 if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { 101 return err 102 } 103 svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) 104 if !ok { 105 return fmt.Errorf("get parser ServiceDescriptor failed") 106 } 107 var rm *thrift.ReadStruct 108 if c.forJSON { 109 rm = thrift.NewReadStructForJSON(svcDsc, msg.RPCRole() == remote.Client) 110 } else { 111 rm = thrift.NewReadStruct(svcDsc, msg.RPCRole() == remote.Client) 112 } 113 rm.SetBinaryOption(c.binaryWithBase64, c.binaryWithByteSlice) 114 msg.Data().(WithCodec).SetCodec(rm) 115 return c.codec.Unmarshal(ctx, msg, in) 116 } 117 118 func (c *mapThriftCodec) getMethod(req interface{}, method string) (*Method, error) { 119 fnSvc, err := c.svcDsc.Load().(*descriptor.ServiceDescriptor).LookupFunctionByMethod(method) 120 if err != nil { 121 return nil, err 122 } 123 return &Method{method, fnSvc.Oneway}, nil 124 } 125 126 func (c *mapThriftCodec) Name() string { 127 return "MapThrift" 128 } 129 130 func (c *mapThriftCodec) Close() error { 131 return c.provider.Close() 132 }