github.com/cloudwego/kitex@v0.9.0/pkg/generic/httppbthrift_codec.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package generic
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"strings"
    27  	"sync/atomic"
    28  
    29  	"github.com/jhump/protoreflect/desc"
    30  
    31  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    32  	"github.com/cloudwego/kitex/pkg/generic/proto"
    33  	"github.com/cloudwego/kitex/pkg/generic/thrift"
    34  	"github.com/cloudwego/kitex/pkg/remote"
    35  	"github.com/cloudwego/kitex/pkg/remote/codec"
    36  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    37  )
    38  
    39  type httpPbThriftCodec struct {
    40  	svcDsc     atomic.Value // *idl
    41  	pbSvcDsc   atomic.Value // *pbIdl
    42  	provider   DescriptorProvider
    43  	pbProvider PbDescriptorProvider
    44  	codec      remote.PayloadCodec
    45  }
    46  
    47  func newHTTPPbThriftCodec(p DescriptorProvider, pbp PbDescriptorProvider, codec remote.PayloadCodec) (*httpPbThriftCodec, error) {
    48  	svc := <-p.Provide()
    49  	pbSvc := <-pbp.Provide()
    50  	c := &httpPbThriftCodec{codec: codec, provider: p, pbProvider: pbp}
    51  	c.svcDsc.Store(svc)
    52  	c.pbSvcDsc.Store(pbSvc)
    53  	go c.update()
    54  	return c, nil
    55  }
    56  
    57  func (c *httpPbThriftCodec) update() {
    58  	for {
    59  		svc, ok := <-c.provider.Provide()
    60  		if !ok {
    61  			return
    62  		}
    63  
    64  		pbSvc, ok := <-c.pbProvider.Provide()
    65  		if !ok {
    66  			return
    67  		}
    68  
    69  		c.svcDsc.Store(svc)
    70  		c.pbSvcDsc.Store(pbSvc)
    71  	}
    72  }
    73  
    74  func (c *httpPbThriftCodec) getMethod(req interface{}) (*Method, error) {
    75  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
    76  	if !ok {
    77  		return nil, errors.New("get method name failed, no ServiceDescriptor")
    78  	}
    79  	r, ok := req.(*HTTPRequest)
    80  	if !ok {
    81  		return nil, errors.New("req is invalid, need descriptor.HTTPRequest")
    82  	}
    83  	function, err := svcDsc.Router.Lookup(r)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	return &Method{function.Name, function.Oneway}, nil
    88  }
    89  
    90  func (c *httpPbThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
    91  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
    92  	if !ok {
    93  		return fmt.Errorf("get parser ServiceDescriptor failed")
    94  	}
    95  	pbSvcDsc, ok := c.pbSvcDsc.Load().(*desc.ServiceDescriptor)
    96  	if !ok {
    97  		return fmt.Errorf("get parser PbServiceDescriptor failed")
    98  	}
    99  
   100  	inner := thrift.NewWriteHTTPPbRequest(svcDsc, pbSvcDsc)
   101  	msg.Data().(WithCodec).SetCodec(inner)
   102  	return c.codec.Marshal(ctx, msg, out)
   103  }
   104  
   105  func (c *httpPbThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   106  	if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil {
   107  		return err
   108  	}
   109  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
   110  	if !ok {
   111  		return fmt.Errorf("get parser ServiceDescriptor failed")
   112  	}
   113  	pbSvcDsc, ok := c.pbSvcDsc.Load().(proto.ServiceDescriptor)
   114  	if !ok {
   115  		return fmt.Errorf("get parser PbServiceDescriptor failed")
   116  	}
   117  
   118  	inner := thrift.NewReadHTTPPbResponse(svcDsc, pbSvcDsc)
   119  	msg.Data().(WithCodec).SetCodec(inner)
   120  	return c.codec.Unmarshal(ctx, msg, in)
   121  }
   122  
   123  func (c *httpPbThriftCodec) Name() string {
   124  	return "HttpPbThrift"
   125  }
   126  
   127  func (c *httpPbThriftCodec) Close() error {
   128  	var errs []string
   129  	if err := c.provider.Close(); err != nil {
   130  		errs = append(errs, err.Error())
   131  	}
   132  	if err := c.pbProvider.Close(); err != nil {
   133  		errs = append(errs, err.Error())
   134  	}
   135  
   136  	if len(errs) == 0 {
   137  		return nil
   138  	} else {
   139  		return errors.New(strings.Join(errs, ";"))
   140  	}
   141  }
   142  
   143  // FromHTTPPbRequest parse  HTTPRequest from http.Request
   144  func FromHTTPPbRequest(req *http.Request) (*HTTPRequest, error) {
   145  	customReq := &HTTPRequest{
   146  		Request:     req,
   147  		ContentType: descriptor.MIMEApplicationProtobuf,
   148  	}
   149  	var b io.ReadCloser
   150  	var err error
   151  	if req.GetBody != nil {
   152  		// req from ServerHTTP or create by http.NewRequest
   153  		if b, err = req.GetBody(); err != nil {
   154  			return nil, err
   155  		}
   156  	} else {
   157  		b = req.Body
   158  	}
   159  	if b == nil {
   160  		// body == nil if from Get request
   161  		return customReq, nil
   162  	}
   163  	if customReq.RawBody, err = ioutil.ReadAll(b); err != nil {
   164  		return nil, err
   165  	}
   166  	if len(customReq.RawBody) == 0 {
   167  		return customReq, nil
   168  	}
   169  
   170  	return customReq, nil
   171  }