github.com/cloudwego/kitex@v0.9.0/pkg/generic/httppbthrift_codec.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package generic 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "io/ioutil" 25 "net/http" 26 "strings" 27 "sync/atomic" 28 29 "github.com/jhump/protoreflect/desc" 30 31 "github.com/cloudwego/kitex/pkg/generic/descriptor" 32 "github.com/cloudwego/kitex/pkg/generic/proto" 33 "github.com/cloudwego/kitex/pkg/generic/thrift" 34 "github.com/cloudwego/kitex/pkg/remote" 35 "github.com/cloudwego/kitex/pkg/remote/codec" 36 "github.com/cloudwego/kitex/pkg/serviceinfo" 37 ) 38 39 type httpPbThriftCodec struct { 40 svcDsc atomic.Value // *idl 41 pbSvcDsc atomic.Value // *pbIdl 42 provider DescriptorProvider 43 pbProvider PbDescriptorProvider 44 codec remote.PayloadCodec 45 } 46 47 func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider, codec remote.PayloadCodec) (*httpPbThriftCodec, error) { 48 svc := <-p.Provide() 49 pbSvc := <-pbp.Provide() 50 c := &httpPbThriftCodec{codec: codec, provider: p, pbProvider: pbp} 51 c.svcDsc.Store(svc) 52 c.pbSvcDsc.Store(pbSvc) 53 go c.update() 54 return c, nil 55 } 56 57 func (c *httpPbThriftCodec) update() { 58 for { 59 svc, ok := <-c.provider.Provide() 60 if !ok { 61 return 62 } 63 64 pbSvc, ok := <-c.pbProvider.Provide() 65 if !ok { 66 return 67 } 68 69 c.svcDsc.Store(svc) 70 c.pbSvcDsc.Store(pbSvc) 71 } 72 } 73 74 func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) { 75 svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) 76 if !ok { 77 return nil, errors.New("get method name failed, no ServiceDescriptor") 78 } 79 r, ok := req.(*HTTPRequest) 80 if !ok { 81 return nil, errors.New("req is invalid, need descriptor.HTTPRequest") 82 } 83 function, err := svcDsc.Router.Lookup(r) 84 if err != nil { 85 return nil, err 86 } 87 return &Method{function.Name, function.Oneway}, nil 88 } 89 90 func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 91 svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) 92 if !ok { 93 return fmt.Errorf("get parser ServiceDescriptor failed") 94 } 95 pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor) 96 if !ok { 97 return fmt.Errorf("get parser PbServiceDescriptor failed") 98 } 99 100 inner := thrift.NewWriteHTTPPbRequest(svcDsc, pbSvcDsc) 101 msg.Data().(WithCodec).SetCodec(inner) 102 return c.codec.Marshal(ctx, msg, out) 103 } 104 105 func (c *httpPbThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 106 if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil { 107 return err 108 } 109 svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor) 110 if !ok { 111 return fmt.Errorf("get parser ServiceDescriptor failed") 112 } 113 pbSvcDsc, ok := c.pbSvcDsc.Load().(proto.ServiceDescriptor) 114 if !ok { 115 return fmt.Errorf("get parser PbServiceDescriptor failed") 116 } 117 118 inner := thrift.NewReadHTTPPbResponse(svcDsc, pbSvcDsc) 119 msg.Data().(WithCodec).SetCodec(inner) 120 return c.codec.Unmarshal(ctx, msg, in) 121 } 122 123 func (c *httpPbThriftCodec) Name() string { 124 return "HttpPbThrift" 125 } 126 127 func (c *httpPbThriftCodec) Close() error { 128 var errs []string 129 if err := c.provider.Close(); err != nil { 130 errs = append(errs, err.Error()) 131 } 132 if err := c.pbProvider.Close(); err != nil { 133 errs = append(errs, err.Error()) 134 } 135 136 if len(errs) == 0 { 137 return nil 138 } else { 139 return errors.New(strings.Join(errs, ";")) 140 } 141 } 142 143 // FromHTTPPbRequest parse HTTPRequest from http.Request 144 func FromHTTPPbRequest(req *http.Request) (*HTTPRequest, error) { 145 customReq := &HTTPRequest{ 146 Request: req, 147 ContentType: descriptor.MIMEApplicationProtobuf, 148 } 149 var b io.ReadCloser 150 var err error 151 if req.GetBody != nil { 152 // req from ServerHTTP or create by http.NewRequest 153 if b, err = req.GetBody(); err != nil { 154 return nil, err 155 } 156 } else { 157 b = req.Body 158 } 159 if b == nil { 160 // body == nil if from Get request 161 return customReq, nil 162 } 163 if customReq.RawBody, err = ioutil.ReadAll(b); err != nil { 164 return nil, err 165 } 166 if len(customReq.RawBody) == 0 { 167 return customReq, nil 168 } 169 170 return customReq, nil 171 }