trpc.group/trpc-go/trpc-go@v1.0.3/restful/fasthttp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"unsafe"
    20  
    21  	"github.com/valyala/fasthttp"
    22  	"google.golang.org/protobuf/proto"
    23  	"trpc.group/trpc-go/trpc-go/errs"
    24  )
    25  
    26  // FastHTTPHeaderMatcher matches fasthttp request header to tRPC Stub Context.
    27  type FastHTTPHeaderMatcher func(
    28  	ctx context.Context,
    29  	requestCtx *fasthttp.RequestCtx,
    30  	serviceName, methodName string,
    31  ) (context.Context, error)
    32  
    33  // DefaultFastHTTPHeaderMatcher is the default FastHTTPHeaderMatcher.
    34  var DefaultFastHTTPHeaderMatcher = func(
    35  	ctx context.Context,
    36  	requestCtx *fasthttp.RequestCtx,
    37  	serviceName, methodName string,
    38  ) (context.Context, error) {
    39  	return withNewMessage(ctx, serviceName, methodName), nil
    40  }
    41  
    42  // FastHTTPRespHandler is the custom response handler when fasthttp is used.
    43  type FastHTTPRespHandler func(
    44  	ctx context.Context,
    45  	requestCtx *fasthttp.RequestCtx,
    46  	resp proto.Message,
    47  	body []byte,
    48  ) error
    49  
    50  // DefaultFastHTTPRespHandler is the default FastHTTPRespHandler.
    51  func DefaultFastHTTPRespHandler(stubCtx context.Context, requestCtx *fasthttp.RequestCtx,
    52  	protoResp proto.Message, body []byte) error {
    53  	// compress
    54  	writer := requestCtx.Response.BodyWriter()
    55  	// fasthttp doesn't support getting multiple values of one key from http headers.
    56  	// ctx.Request.Header.Peek is equivalent to req.Header.Get from Go net/http.
    57  	_, c := compressorForTranscoding(
    58  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerContentEncoding))},
    59  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerAcceptEncoding))},
    60  	)
    61  	if c != nil {
    62  		writeCloser, err := c.Compress(writer)
    63  		if err != nil {
    64  			return err
    65  		}
    66  		defer writeCloser.Close()
    67  		requestCtx.Response.Header.Set(headerContentEncoding, c.ContentEncoding())
    68  		writer = writeCloser
    69  	}
    70  
    71  	// set response content-type
    72  	_, s := serializerForTranscoding(
    73  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerContentType))},
    74  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerAccept))},
    75  	)
    76  	requestCtx.Response.Header.Set(headerContentType, s.ContentType())
    77  
    78  	// set status code
    79  	statusCode := GetStatusCodeOnSucceed(stubCtx)
    80  	requestCtx.SetStatusCode(statusCode)
    81  
    82  	// write body
    83  	if statusCode != fasthttp.StatusNoContent && statusCode != fasthttp.StatusNotModified {
    84  		writer.Write(body)
    85  	}
    86  
    87  	return nil
    88  }
    89  
    90  // bytes2str is the high-performance way of converting []byte to string.
    91  func bytes2str(b []byte) string {
    92  	return *(*string)(unsafe.Pointer(&b))
    93  }
    94  
    95  // HandleRequestCtx fasthttp handler
    96  func (r *Router) HandleRequestCtx(ctx *fasthttp.RequestCtx) {
    97  	newCtx := context.Background()
    98  	for _, tr := range r.transcoders[bytes2str(ctx.Method())] {
    99  		fieldValues, err := tr.pat.Match(bytes2str(ctx.Path()))
   100  		if err == nil {
   101  			// header matching
   102  			stubCtx, err := r.opts.FastHTTPHeaderMatcher(newCtx, ctx,
   103  				r.opts.ServiceName, tr.name)
   104  			if err != nil {
   105  				r.opts.FastHTTPErrHandler(stubCtx, ctx, errs.New(errs.RetServerDecodeFail, err.Error()))
   106  				return
   107  			}
   108  
   109  			// get inbound/outbound Compressor & Serializer
   110  			reqCompressor, respCompressor := compressorForTranscoding(
   111  				[]string{bytes2str(ctx.Request.Header.Peek(headerContentEncoding))},
   112  				[]string{bytes2str(ctx.Request.Header.Peek(headerAcceptEncoding))},
   113  			)
   114  			reqSerializer, respSerializer := serializerForTranscoding(
   115  				[]string{bytes2str(ctx.Request.Header.Peek(headerContentType))},
   116  				[]string{bytes2str(ctx.Request.Header.Peek(headerAccept))},
   117  			)
   118  
   119  			// get query params
   120  			form := make(map[string][]string)
   121  			ctx.QueryArgs().VisitAll(func(key []byte, value []byte) {
   122  				form[bytes2str(key)] = append(form[bytes2str(key)], bytes2str(value))
   123  			})
   124  
   125  			// set transcoding params
   126  			params := paramsPool.Get().(*transcodeParams)
   127  			params.reqCompressor = reqCompressor
   128  			params.respCompressor = respCompressor
   129  			params.reqSerializer = reqSerializer
   130  			params.respSerializer = respSerializer
   131  			params.body = bytes.NewBuffer(ctx.PostBody())
   132  			params.fieldValues = fieldValues
   133  			params.form = form
   134  
   135  			// transcode
   136  			resp, body, err := tr.transcode(stubCtx, params)
   137  			if err != nil {
   138  				r.opts.FastHTTPErrHandler(stubCtx, ctx, err)
   139  				putBackCtxMessage(stubCtx)
   140  				putBackParams(params)
   141  				return
   142  			}
   143  
   144  			// response
   145  			if err := r.opts.FastHTTPRespHandler(stubCtx, ctx, resp, body); err != nil {
   146  				r.opts.FastHTTPErrHandler(stubCtx, ctx, errs.New(errs.RetServerEncodeFail, err.Error()))
   147  			}
   148  			putBackCtxMessage(stubCtx)
   149  			putBackParams(params)
   150  			return
   151  		}
   152  	}
   153  	r.opts.FastHTTPErrHandler(newCtx, ctx, errs.New(errs.RetServerNoFunc, "failed to match any pattern"))
   154  }