trpc.group/trpc-go/trpc-go@v1.0.3/restful/transcode.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net/url"
    23  	"strings"
    24  	"sync"
    25  
    26  	"google.golang.org/protobuf/proto"
    27  
    28  	"trpc.group/trpc-go/trpc-go/errs"
    29  	"trpc.group/trpc-go/trpc-go/internal/dat"
    30  )
    31  
    32  const (
    33  	// default size of http req body buffer
    34  	defaultBodyBufferSize = 4096
    35  )
    36  
    37  // transcoder is for tRPC/httpjson transcoding.
    38  type transcoder struct {
    39  	name                 string
    40  	input                func() ProtoMessage
    41  	output               func() ProtoMessage
    42  	handler              HandleFunc
    43  	httpMethod           string
    44  	pat                  *Pattern
    45  	body                 BodyLocator
    46  	respBody             ResponseBodyLocator
    47  	router               *Router
    48  	dat                  *dat.DoubleArrayTrie
    49  	discardUnknownParams bool
    50  	serviceImpl          interface{}
    51  }
    52  
    53  // transcodeParams are params required for transcoding.
    54  type transcodeParams struct {
    55  	reqCompressor  Compressor
    56  	respCompressor Compressor
    57  	reqSerializer  Serializer
    58  	respSerializer Serializer
    59  	body           io.Reader
    60  	fieldValues    map[string]string
    61  	form           url.Values
    62  }
    63  
    64  // paramsPool is the transcodeParams pool.
    65  var paramsPool = sync.Pool{
    66  	New: func() interface{} {
    67  		return &transcodeParams{}
    68  	},
    69  }
    70  
    71  // putBackParams puts transcodeParams back to pool.
    72  func putBackParams(params *transcodeParams) {
    73  	params.reqCompressor = nil
    74  	params.respCompressor = nil
    75  	params.reqSerializer = nil
    76  	params.respSerializer = nil
    77  	params.body = nil
    78  	params.fieldValues = nil
    79  	params.form = nil
    80  	paramsPool.Put(params)
    81  }
    82  
    83  // transcode transcodes tRPC/httpjson.
    84  func (tr *transcoder) transcode(
    85  	stubCtx context.Context,
    86  	params *transcodeParams,
    87  ) (proto.Message, []byte, error) {
    88  	// init tRPC request
    89  	protoReq := tr.input()
    90  
    91  	// transcode body
    92  	if err := tr.transcodeBody(protoReq, params.body, params.reqCompressor,
    93  		params.reqSerializer); err != nil {
    94  		return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error())
    95  	}
    96  
    97  	// transcode fieldValues from url path matching
    98  	if err := tr.transcodeFieldValues(protoReq, params.fieldValues); err != nil {
    99  		return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error())
   100  	}
   101  
   102  	// transcode query params
   103  	if err := tr.transcodeQueryParams(protoReq, params.form); err != nil {
   104  		return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error())
   105  	}
   106  
   107  	// tRPC Stub handling
   108  	rsp, err := tr.handle(stubCtx, protoReq)
   109  	if err != nil {
   110  		return nil, nil, err
   111  	}
   112  	var protoResp proto.Message
   113  	if rsp == nil {
   114  		protoResp = tr.output()
   115  	} else {
   116  		protoResp = rsp.(proto.Message)
   117  	}
   118  
   119  	// response
   120  	// HttpRule.response_body only specifies serialization of fields.
   121  	// So compression would be custom.
   122  	buf, err := tr.transcodeResp(protoResp, params.respSerializer)
   123  	if err != nil {
   124  		return nil, nil, errs.New(errs.RetServerEncodeFail, err.Error())
   125  	}
   126  	return protoResp, buf, nil
   127  }
   128  
   129  // bodyBufferPool is the pool of http request body buffer.
   130  var bodyBufferPool = sync.Pool{
   131  	New: func() interface{} {
   132  		return bytes.NewBuffer(make([]byte, defaultBodyBufferSize))
   133  	},
   134  }
   135  
   136  // transcodeBody transcodes tRPC/httpjson by http request body.
   137  func (tr *transcoder) transcodeBody(protoReq proto.Message, body io.Reader, c Compressor, s Serializer) error {
   138  	// HttpRule body is not specified
   139  	if tr.body == nil {
   140  		return nil
   141  	}
   142  
   143  	// decompress
   144  	var reader io.Reader
   145  	var err error
   146  	if c != nil {
   147  		if reader, err = c.Decompress(body); err != nil {
   148  			return fmt.Errorf("failed to decompress request body: %w", err)
   149  		}
   150  	} else {
   151  		reader = body
   152  	}
   153  
   154  	// read body
   155  	buffer := bodyBufferPool.Get().(*bytes.Buffer)
   156  	buffer.Reset()
   157  	defer bodyBufferPool.Put(buffer)
   158  	if _, err := io.Copy(buffer, reader); err != nil {
   159  		return fmt.Errorf("failed to read request body: %w", err)
   160  	}
   161  
   162  	// unmarshal
   163  	if err := s.Unmarshal(buffer.Bytes(), tr.body.Locate(protoReq)); err != nil {
   164  		return fmt.Errorf("failed to unmarshal req body: %w", err)
   165  	}
   166  
   167  	// field mask will be set for PATCH method.
   168  	if tr.httpMethod == "PATCH" && tr.body.Body() != "*" {
   169  		return setFieldMask(protoReq.ProtoReflect(), tr.body.Body())
   170  	}
   171  
   172  	return nil
   173  }
   174  
   175  // transcodeFieldValues transcodes tRPC/httpjson by fieldValues from url path matching.
   176  func (tr *transcoder) transcodeFieldValues(msg proto.Message, fieldValues map[string]string) error {
   177  	for fieldPath, value := range fieldValues {
   178  		if err := PopulateMessage(msg, strings.Split(fieldPath, "."), []string{value}); err != nil {
   179  			return err
   180  		}
   181  	}
   182  	return nil
   183  }
   184  
   185  // transcodeQueryParams transcodes tRPC/httpjson by query params.
   186  func (tr *transcoder) transcodeQueryParams(msg proto.Message, form url.Values) error {
   187  	// Query params will be ignored if HttpRule body is *.
   188  	if tr.body != nil && tr.body.Body() == "*" {
   189  		return nil
   190  	}
   191  
   192  	for key, values := range form {
   193  		// filter fields specified by HttpRule pattern and body
   194  		if tr.dat != nil && tr.dat.CommonPrefixSearch(strings.Split(key, ".")) {
   195  			continue
   196  		}
   197  		// populate proto message
   198  		if err := PopulateMessage(msg, strings.Split(key, "."), values); err != nil {
   199  			if !tr.discardUnknownParams || !errors.Is(err, ErrTraverseNotFound) {
   200  				return err
   201  			}
   202  		}
   203  	}
   204  
   205  	return nil
   206  }
   207  
   208  // handle does tRPC Stub handling.
   209  func (tr *transcoder) handle(ctx context.Context, reqBody interface{}) (interface{}, error) {
   210  	filters := tr.router.opts.FilterFunc()
   211  	serviceImpl := tr.serviceImpl
   212  	handleFunc := func(ctx context.Context, reqBody interface{}) (interface{}, error) {
   213  		return tr.handler(serviceImpl, ctx, reqBody)
   214  	}
   215  	return filters.Filter(ctx, reqBody, handleFunc)
   216  }
   217  
   218  // transcodeResp transcodes tRPC/httpjson by response.
   219  func (tr *transcoder) transcodeResp(protoResp proto.Message, s Serializer) ([]byte, error) {
   220  	// marshal
   221  	var obj interface{}
   222  	if tr.respBody == nil {
   223  		obj = protoResp
   224  	} else {
   225  		obj = tr.respBody.Locate(protoResp)
   226  	}
   227  	return s.Marshal(obj)
   228  }