github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/http_pb.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  	"github.com/jhump/protoreflect/desc"
    26  	"github.com/jhump/protoreflect/dynamic"
    27  
    28  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    29  	"github.com/cloudwego/kitex/pkg/generic/proto"
    30  )
    31  
    32  // WriteHTTPPbRequest implement of MessageWriter
    33  type WriteHTTPPbRequest struct {
    34  	svc   *descriptor.ServiceDescriptor
    35  	pbSvc *desc.ServiceDescriptor
    36  }
    37  
    38  var _ MessageWriter = (*WriteHTTPPbRequest)(nil)
    39  
    40  // NewWriteHTTPPbRequest ...
    41  // Base64 decoding for binary is enabled by default.
    42  func NewWriteHTTPPbRequest(svc *descriptor.ServiceDescriptor, pbSvc *desc.ServiceDescriptor) *WriteHTTPPbRequest {
    43  	return &WriteHTTPPbRequest{svc, pbSvc}
    44  }
    45  
    46  // Write ...
    47  func (w *WriteHTTPPbRequest) Write(ctx context.Context, out thrift.TProtocol, msg interface{}, requestBase *Base) error {
    48  	req := msg.(*descriptor.HTTPRequest)
    49  	fn, err := w.svc.Router.Lookup(req)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	if !fn.HasRequestBase {
    54  		requestBase = nil
    55  	}
    56  
    57  	// unmarshal body bytes to pb message
    58  	mt := w.pbSvc.FindMethodByName(fn.Name)
    59  	if mt == nil {
    60  		return fmt.Errorf("method not found in pb descriptor: %v", fn.Name)
    61  	}
    62  	pbMsg := dynamic.NewMessage(mt.GetInputType())
    63  	err = pbMsg.Unmarshal(req.RawBody)
    64  	if err != nil {
    65  		return fmt.Errorf("unmarshal pb body error: %v", err)
    66  	}
    67  	req.GeneralBody = pbMsg
    68  
    69  	return wrapStructWriter(ctx, req, out, fn.Request, &writerOption{requestBase: requestBase})
    70  }
    71  
    72  // ReadHTTPResponse implement of MessageReaderWithMethod
    73  type ReadHTTPPbResponse struct {
    74  	svc   *descriptor.ServiceDescriptor
    75  	pbSvc proto.ServiceDescriptor
    76  }
    77  
    78  var _ MessageReader = (*ReadHTTPResponse)(nil)
    79  
    80  // NewReadHTTPResponse ...
    81  // Base64 encoding for binary is enabled by default.
    82  func NewReadHTTPPbResponse(svc *descriptor.ServiceDescriptor, pbSvc proto.ServiceDescriptor) *ReadHTTPPbResponse {
    83  	return &ReadHTTPPbResponse{svc, pbSvc}
    84  }
    85  
    86  // Read ...
    87  func (r *ReadHTTPPbResponse) Read(ctx context.Context, method string, in thrift.TProtocol) (interface{}, error) {
    88  	fnDsc, err := r.svc.LookupFunctionByMethod(method)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	fDsc := fnDsc.Response
    93  	mt := r.pbSvc.FindMethodByName(method)
    94  	if mt == nil {
    95  		return nil, errors.New("pb method not found")
    96  	}
    97  
    98  	return skipStructReader(ctx, in, fDsc, &readerOption{pbDsc: mt.GetOutputType(), http: true})
    99  }