trpc.group/trpc-go/trpc-go@v1.0.3/restful/transcode.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "bytes" 18 "context" 19 "errors" 20 "fmt" 21 "io" 22 "net/url" 23 "strings" 24 "sync" 25 26 "google.golang.org/protobuf/proto" 27 28 "trpc.group/trpc-go/trpc-go/errs" 29 "trpc.group/trpc-go/trpc-go/internal/dat" 30 ) 31 32 const ( 33 // default size of http req body buffer 34 defaultBodyBufferSize = 4096 35 ) 36 37 // transcoder is for tRPC/httpjson transcoding. 38 type transcoder struct { 39 name string 40 input func() ProtoMessage 41 output func() ProtoMessage 42 handler HandleFunc 43 httpMethod string 44 pat *Pattern 45 body BodyLocator 46 respBody ResponseBodyLocator 47 router *Router 48 dat *dat.DoubleArrayTrie 49 discardUnknownParams bool 50 serviceImpl interface{} 51 } 52 53 // transcodeParams are params required for transcoding. 54 type transcodeParams struct { 55 reqCompressor Compressor 56 respCompressor Compressor 57 reqSerializer Serializer 58 respSerializer Serializer 59 body io.Reader 60 fieldValues map[string]string 61 form url.Values 62 } 63 64 // paramsPool is the transcodeParams pool. 65 var paramsPool = sync.Pool{ 66 New: func() interface{} { 67 return &transcodeParams{} 68 }, 69 } 70 71 // putBackParams puts transcodeParams back to pool. 72 func putBackParams(params *transcodeParams) { 73 params.reqCompressor = nil 74 params.respCompressor = nil 75 params.reqSerializer = nil 76 params.respSerializer = nil 77 params.body = nil 78 params.fieldValues = nil 79 params.form = nil 80 paramsPool.Put(params) 81 } 82 83 // transcode transcodes tRPC/httpjson. 84 func (tr *transcoder) transcode( 85 stubCtx context.Context, 86 params *transcodeParams, 87 ) (proto.Message, []byte, error) { 88 // init tRPC request 89 protoReq := tr.input() 90 91 // transcode body 92 if err := tr.transcodeBody(protoReq, params.body, params.reqCompressor, 93 params.reqSerializer); err != nil { 94 return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error()) 95 } 96 97 // transcode fieldValues from url path matching 98 if err := tr.transcodeFieldValues(protoReq, params.fieldValues); err != nil { 99 return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error()) 100 } 101 102 // transcode query params 103 if err := tr.transcodeQueryParams(protoReq, params.form); err != nil { 104 return nil, nil, errs.New(errs.RetServerDecodeFail, err.Error()) 105 } 106 107 // tRPC Stub handling 108 rsp, err := tr.handle(stubCtx, protoReq) 109 if err != nil { 110 return nil, nil, err 111 } 112 var protoResp proto.Message 113 if rsp == nil { 114 protoResp = tr.output() 115 } else { 116 protoResp = rsp.(proto.Message) 117 } 118 119 // response 120 // HttpRule.response_body only specifies serialization of fields. 121 // So compression would be custom. 122 buf, err := tr.transcodeResp(protoResp, params.respSerializer) 123 if err != nil { 124 return nil, nil, errs.New(errs.RetServerEncodeFail, err.Error()) 125 } 126 return protoResp, buf, nil 127 } 128 129 // bodyBufferPool is the pool of http request body buffer. 130 var bodyBufferPool = sync.Pool{ 131 New: func() interface{} { 132 return bytes.NewBuffer(make([]byte, defaultBodyBufferSize)) 133 }, 134 } 135 136 // transcodeBody transcodes tRPC/httpjson by http request body. 137 func (tr *transcoder) transcodeBody(protoReq proto.Message, body io.Reader, c Compressor, s Serializer) error { 138 // HttpRule body is not specified 139 if tr.body == nil { 140 return nil 141 } 142 143 // decompress 144 var reader io.Reader 145 var err error 146 if c != nil { 147 if reader, err = c.Decompress(body); err != nil { 148 return fmt.Errorf("failed to decompress request body: %w", err) 149 } 150 } else { 151 reader = body 152 } 153 154 // read body 155 buffer := bodyBufferPool.Get().(*bytes.Buffer) 156 buffer.Reset() 157 defer bodyBufferPool.Put(buffer) 158 if _, err := io.Copy(buffer, reader); err != nil { 159 return fmt.Errorf("failed to read request body: %w", err) 160 } 161 162 // unmarshal 163 if err := s.Unmarshal(buffer.Bytes(), tr.body.Locate(protoReq)); err != nil { 164 return fmt.Errorf("failed to unmarshal req body: %w", err) 165 } 166 167 // field mask will be set for PATCH method. 168 if tr.httpMethod == "PATCH" && tr.body.Body() != "*" { 169 return setFieldMask(protoReq.ProtoReflect(), tr.body.Body()) 170 } 171 172 return nil 173 } 174 175 // transcodeFieldValues transcodes tRPC/httpjson by fieldValues from url path matching. 176 func (tr *transcoder) transcodeFieldValues(msg proto.Message, fieldValues map[string]string) error { 177 for fieldPath, value := range fieldValues { 178 if err := PopulateMessage(msg, strings.Split(fieldPath, "."), []string{value}); err != nil { 179 return err 180 } 181 } 182 return nil 183 } 184 185 // transcodeQueryParams transcodes tRPC/httpjson by query params. 186 func (tr *transcoder) transcodeQueryParams(msg proto.Message, form url.Values) error { 187 // Query params will be ignored if HttpRule body is *. 188 if tr.body != nil && tr.body.Body() == "*" { 189 return nil 190 } 191 192 for key, values := range form { 193 // filter fields specified by HttpRule pattern and body 194 if tr.dat != nil && tr.dat.CommonPrefixSearch(strings.Split(key, ".")) { 195 continue 196 } 197 // populate proto message 198 if err := PopulateMessage(msg, strings.Split(key, "."), values); err != nil { 199 if !tr.discardUnknownParams || !errors.Is(err, ErrTraverseNotFound) { 200 return err 201 } 202 } 203 } 204 205 return nil 206 } 207 208 // handle does tRPC Stub handling. 209 func (tr *transcoder) handle(ctx context.Context, reqBody interface{}) (interface{}, error) { 210 filters := tr.router.opts.FilterFunc() 211 serviceImpl := tr.serviceImpl 212 handleFunc := func(ctx context.Context, reqBody interface{}) (interface{}, error) { 213 return tr.handler(serviceImpl, ctx, reqBody) 214 } 215 return filters.Filter(ctx, reqBody, handleFunc) 216 } 217 218 // transcodeResp transcodes tRPC/httpjson by response. 219 func (tr *transcoder) transcodeResp(protoResp proto.Message, s Serializer) ([]byte, error) { 220 // marshal 221 var obj interface{} 222 if tr.respBody == nil { 223 obj = protoResp 224 } else { 225 obj = tr.respBody.Locate(protoResp) 226 } 227 return s.Marshal(obj) 228 }