trpc.group/trpc-go/trpc-go@v1.0.3/restful/fasthttp.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "bytes" 18 "context" 19 "unsafe" 20 21 "github.com/valyala/fasthttp" 22 "google.golang.org/protobuf/proto" 23 "trpc.group/trpc-go/trpc-go/errs" 24 ) 25 26 // FastHTTPHeaderMatcher matches fasthttp request header to tRPC Stub Context. 27 type FastHTTPHeaderMatcher func( 28 ctx context.Context, 29 requestCtx *fasthttp.RequestCtx, 30 serviceName, methodName string, 31 ) (context.Context, error) 32 33 // DefaultFastHTTPHeaderMatcher is the default FastHTTPHeaderMatcher. 34 var DefaultFastHTTPHeaderMatcher = func( 35 ctx context.Context, 36 requestCtx *fasthttp.RequestCtx, 37 serviceName, methodName string, 38 ) (context.Context, error) { 39 return withNewMessage(ctx, serviceName, methodName), nil 40 } 41 42 // FastHTTPRespHandler is the custom response handler when fasthttp is used. 43 type FastHTTPRespHandler func( 44 ctx context.Context, 45 requestCtx *fasthttp.RequestCtx, 46 resp proto.Message, 47 body []byte, 48 ) error 49 50 // DefaultFastHTTPRespHandler is the default FastHTTPRespHandler. 51 func DefaultFastHTTPRespHandler(stubCtx context.Context, requestCtx *fasthttp.RequestCtx, 52 protoResp proto.Message, body []byte) error { 53 // compress 54 writer := requestCtx.Response.BodyWriter() 55 // fasthttp doesn't support getting multiple values of one key from http headers. 56 // ctx.Request.Header.Peek is equivalent to req.Header.Get from Go net/http. 57 _, c := compressorForTranscoding( 58 []string{bytes2str(requestCtx.Request.Header.Peek(headerContentEncoding))}, 59 []string{bytes2str(requestCtx.Request.Header.Peek(headerAcceptEncoding))}, 60 ) 61 if c != nil { 62 writeCloser, err := c.Compress(writer) 63 if err != nil { 64 return err 65 } 66 defer writeCloser.Close() 67 requestCtx.Response.Header.Set(headerContentEncoding, c.ContentEncoding()) 68 writer = writeCloser 69 } 70 71 // set response content-type 72 _, s := serializerForTranscoding( 73 []string{bytes2str(requestCtx.Request.Header.Peek(headerContentType))}, 74 []string{bytes2str(requestCtx.Request.Header.Peek(headerAccept))}, 75 ) 76 requestCtx.Response.Header.Set(headerContentType, s.ContentType()) 77 78 // set status code 79 statusCode := GetStatusCodeOnSucceed(stubCtx) 80 requestCtx.SetStatusCode(statusCode) 81 82 // write body 83 if statusCode != fasthttp.StatusNoContent && statusCode != fasthttp.StatusNotModified { 84 writer.Write(body) 85 } 86 87 return nil 88 } 89 90 // bytes2str is the high-performance way of converting []byte to string. 91 func bytes2str(b []byte) string { 92 return *(*string)(unsafe.Pointer(&b)) 93 } 94 95 // HandleRequestCtx fasthttp handler 96 func (r *Router) HandleRequestCtx(ctx *fasthttp.RequestCtx) { 97 newCtx := context.Background() 98 for _, tr := range r.transcoders[bytes2str(ctx.Method())] { 99 fieldValues, err := tr.pat.Match(bytes2str(ctx.Path())) 100 if err == nil { 101 // header matching 102 stubCtx, err := r.opts.FastHTTPHeaderMatcher(newCtx, ctx, 103 r.opts.ServiceName, tr.name) 104 if err != nil { 105 r.opts.FastHTTPErrHandler(stubCtx, ctx, errs.New(errs.RetServerDecodeFail, err.Error())) 106 return 107 } 108 109 // get inbound/outbound Compressor & Serializer 110 reqCompressor, respCompressor := compressorForTranscoding( 111 []string{bytes2str(ctx.Request.Header.Peek(headerContentEncoding))}, 112 []string{bytes2str(ctx.Request.Header.Peek(headerAcceptEncoding))}, 113 ) 114 reqSerializer, respSerializer := serializerForTranscoding( 115 []string{bytes2str(ctx.Request.Header.Peek(headerContentType))}, 116 []string{bytes2str(ctx.Request.Header.Peek(headerAccept))}, 117 ) 118 119 // get query params 120 form := make(map[string][]string) 121 ctx.QueryArgs().VisitAll(func(key []byte, value []byte) { 122 form[bytes2str(key)] = append(form[bytes2str(key)], bytes2str(value)) 123 }) 124 125 // set transcoding params 126 params := paramsPool.Get().(*transcodeParams) 127 params.reqCompressor = reqCompressor 128 params.respCompressor = respCompressor 129 params.reqSerializer = reqSerializer 130 params.respSerializer = respSerializer 131 params.body = bytes.NewBuffer(ctx.PostBody()) 132 params.fieldValues = fieldValues 133 params.form = form 134 135 // transcode 136 resp, body, err := tr.transcode(stubCtx, params) 137 if err != nil { 138 r.opts.FastHTTPErrHandler(stubCtx, ctx, err) 139 putBackCtxMessage(stubCtx) 140 putBackParams(params) 141 return 142 } 143 144 // response 145 if err := r.opts.FastHTTPRespHandler(stubCtx, ctx, resp, body); err != nil { 146 r.opts.FastHTTPErrHandler(stubCtx, ctx, errs.New(errs.RetServerEncodeFail, err.Error())) 147 } 148 putBackCtxMessage(stubCtx) 149 putBackParams(params) 150 return 151 } 152 } 153 r.opts.FastHTTPErrHandler(newCtx, ctx, errs.New(errs.RetServerNoFunc, "failed to match any pattern")) 154 }