github.com/cloudwego/kitex@v0.9.0/pkg/generic/httpthrift_codec.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package generic
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"sync/atomic"
    27  
    28  	"github.com/cloudwego/dynamicgo/conv"
    29  
    30  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    31  	"github.com/cloudwego/kitex/pkg/generic/thrift"
    32  	"github.com/cloudwego/kitex/pkg/remote"
    33  	"github.com/cloudwego/kitex/pkg/remote/codec"
    34  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    35  )
    36  
    37  var (
    38  	_ remote.PayloadCodec = &httpThriftCodec{}
    39  	_ Closer              = &httpThriftCodec{}
    40  )
    41  
    42  // HTTPRequest alias of descriptor HTTPRequest
    43  type HTTPRequest = descriptor.HTTPRequest
    44  
    45  // HTTPResponse alias of descriptor HTTPResponse
    46  type HTTPResponse = descriptor.HTTPResponse
    47  
// httpThriftCodec is a remote.PayloadCodec that wraps another payload codec.
// Before delegating Marshal/Unmarshal to the wrapped codec it attaches a
// per-message HTTP request writer or HTTP response reader built from the
// currently stored service descriptor.
type httpThriftCodec struct {
	// svcDsc holds the most recent value received from provider.Provide().
	// Readers type-assert it to *descriptor.ServiceDescriptor.
	svcDsc atomic.Value

	// provider supplies descriptor updates and is closed by Close.
	provider DescriptorProvider

	// codec is the wrapped payload codec that performs the actual
	// Marshal/Unmarshal.
	codec remote.PayloadCodec

	// binaryWithBase64 is forwarded to the request writer and response reader.
	// The constructor initializes it to false.
	binaryWithBase64 bool

	// convOpts is used for dynamicgo conversion.
	convOpts conv.Options

	// convOptsWithThriftBase is used for dynamicgo conversion with
	// EnableThriftBase turned on.
	convOptsWithThriftBase conv.Options

	// dynamicgoEnabled is set by the constructor when the provider implements
	// GetProviderOption and reports DynamicGoEnabled.
	dynamicgoEnabled bool

	// useRawBodyForHTTPResp is copied from Options and forwarded to the
	// response reader; dynamicgo is only wired into Unmarshal when it is set.
	useRawBodyForHTTPResp bool
}
    58  
    59  func newHTTPThriftCodec(p DescriptorProvider, codec remote.PayloadCodec, opts *Options) (*httpThriftCodec, error) {
    60  	svc := <-p.Provide()
    61  	c := &httpThriftCodec{codec: codec, provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp}
    62  	if dp, ok := p.(GetProviderOption); ok && dp.Option().DynamicGoEnabled {
    63  		c.dynamicgoEnabled = true
    64  
    65  		convOpts := opts.dynamicgoConvOpts
    66  		c.convOpts = convOpts
    67  
    68  		convOpts.EnableThriftBase = true
    69  		c.convOptsWithThriftBase = convOpts
    70  	}
    71  	c.svcDsc.Store(svc)
    72  	go c.update()
    73  	return c, nil
    74  }
    75  
    76  func (c *httpThriftCodec) update() {
    77  	for {
    78  		svc, ok := <-c.provider.Provide()
    79  		if !ok {
    80  			return
    81  		}
    82  		c.svcDsc.Store(svc)
    83  	}
    84  }
    85  
    86  func (c *httpThriftCodec) getMethod(req interface{}) (*Method, error) {
    87  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
    88  	if !ok {
    89  		return nil, errors.New("get method name failed, no ServiceDescriptor")
    90  	}
    91  	r, ok := req.(*HTTPRequest)
    92  	if !ok {
    93  		return nil, errors.New("req is invalid, need descriptor.HTTPRequest")
    94  	}
    95  	function, err := svcDsc.Router.Lookup(r)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	return &Method{function.Name, function.Oneway}, nil
   100  }
   101  
   102  func (c *httpThriftCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   103  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
   104  	if !ok {
   105  		return fmt.Errorf("get parser ServiceDescriptor failed")
   106  	}
   107  
   108  	inner := thrift.NewWriteHTTPRequest(svcDsc)
   109  	inner.SetBinaryWithBase64(c.binaryWithBase64)
   110  	if c.dynamicgoEnabled {
   111  		if err := inner.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase, msg.RPCInfo().Invocation().MethodName()); err != nil {
   112  			return err
   113  		}
   114  	}
   115  
   116  	msg.Data().(WithCodec).SetCodec(inner)
   117  	return c.codec.Marshal(ctx, msg, out)
   118  }
   119  
   120  func (c *httpThriftCodec) Unmarshal(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   121  	if err := codec.NewDataIfNeeded(serviceinfo.GenericMethod, msg); err != nil {
   122  		return err
   123  	}
   124  	svcDsc, ok := c.svcDsc.Load().(*descriptor.ServiceDescriptor)
   125  	if !ok {
   126  		return fmt.Errorf("get parser ServiceDescriptor failed")
   127  	}
   128  
   129  	inner := thrift.NewReadHTTPResponse(svcDsc)
   130  	inner.SetBase64Binary(c.binaryWithBase64)
   131  	inner.SetUseRawBodyForHTTPResp(c.useRawBodyForHTTPResp)
   132  	if c.dynamicgoEnabled && c.useRawBodyForHTTPResp {
   133  		inner.SetDynamicGo(&c.convOpts, msg)
   134  	}
   135  
   136  	msg.Data().(WithCodec).SetCodec(inner)
   137  	return c.codec.Unmarshal(ctx, msg, in)
   138  }
   139  
   140  func (c *httpThriftCodec) Name() string {
   141  	return "HttpThrift"
   142  }
   143  
   144  func (c *httpThriftCodec) Close() error {
   145  	return c.provider.Close()
   146  }
   147  
   148  // FromHTTPRequest parse HTTPRequest from http.Request
   149  func FromHTTPRequest(req *http.Request) (*HTTPRequest, error) {
   150  	customReq := &HTTPRequest{
   151  		Request:     req,
   152  		ContentType: descriptor.MIMEApplicationJson,
   153  	}
   154  	var b io.ReadCloser
   155  	var err error
   156  	if req.GetBody != nil {
   157  		// req from ServerHTTP or create by http.NewRequest
   158  		if b, err = req.GetBody(); err != nil {
   159  			return nil, err
   160  		}
   161  	} else {
   162  		b = req.Body
   163  	}
   164  	if b == nil {
   165  		// body == nil if from Get request
   166  		return customReq, nil
   167  	}
   168  	if customReq.RawBody, err = ioutil.ReadAll(b); err != nil {
   169  		return nil, err
   170  	}
   171  	return customReq, nil
   172  }