trpc.group/trpc-go/trpc-go@v1.0.2/restful/router.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "context" 18 "fmt" 19 "io" 20 "net/http" 21 "strconv" 22 "strings" 23 "sync" 24 25 "google.golang.org/protobuf/proto" 26 "google.golang.org/protobuf/types/known/emptypb" 27 28 "trpc.group/trpc-go/trpc-go/codec" 29 "trpc.group/trpc-go/trpc-go/errs" 30 "trpc.group/trpc-go/trpc-go/filter" 31 "trpc.group/trpc-go/trpc-go/internal/dat" 32 ) 33 34 // Router is restful router. 35 type Router struct { 36 opts *Options 37 transcoders map[string][]*transcoder 38 } 39 40 // NewRouter creates a Router. 41 func NewRouter(opts ...Option) *Router { 42 o := Options{ 43 ErrorHandler: DefaultErrorHandler, 44 HeaderMatcher: DefaultHeaderMatcher, 45 ResponseHandler: DefaultResponseHandler, 46 FastHTTPErrHandler: DefaultFastHTTPErrorHandler, 47 FastHTTPHeaderMatcher: DefaultFastHTTPHeaderMatcher, 48 FastHTTPRespHandler: DefaultFastHTTPRespHandler, 49 } 50 for _, opt := range opts { 51 opt(&o) 52 } 53 o.rebuildHeaderMatcher() 54 55 return &Router{ 56 opts: &o, 57 transcoders: make(map[string][]*transcoder), 58 } 59 } 60 61 var ( 62 routers = make(map[string]http.Handler) // tRPC service name -> Router 63 routerLock sync.RWMutex 64 ) 65 66 // RegisterRouter registers a Router which corresponds to a tRPC Service. 67 func RegisterRouter(name string, router http.Handler) { 68 routerLock.Lock() 69 routers[name] = router 70 routerLock.Unlock() 71 } 72 73 // GetRouter returns a Router which corresponds to a tRPC Service. 74 func GetRouter(name string) http.Handler { 75 routerLock.RLock() 76 router := routers[name] 77 routerLock.RUnlock() 78 return router 79 } 80 81 // ProtoMessage is alias of proto.Message. 82 type ProtoMessage proto.Message 83 84 // Initializer initializes a ProtoMessage. 85 type Initializer func() ProtoMessage 86 87 // BodyLocator locates which fields of the proto message would be 88 // populated according to HttpRule body. 89 type BodyLocator interface { 90 Body() string 91 Locate(ProtoMessage) interface{} 92 } 93 94 // ResponseBodyLocator locates which fields of the proto message would be marshaled 95 // according to HttpRule response_body. 96 type ResponseBodyLocator interface { 97 ResponseBody() string 98 Locate(ProtoMessage) interface{} 99 } 100 101 // HandleFunc is tRPC method handle function. 102 type HandleFunc func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error) 103 104 // ExtractFilterFunc extracts tRPC service filter chain. 105 type ExtractFilterFunc func() filter.ServerChain 106 107 // Binding is the binding of tRPC method and HttpRule. 108 type Binding struct { 109 Name string 110 Input Initializer 111 Output Initializer 112 Filter HandleFunc 113 HTTPMethod string 114 Pattern *Pattern 115 Body BodyLocator 116 ResponseBody ResponseBodyLocator 117 } 118 119 // AddImplBinding creates a new binding with a specified service implementation. 120 func (r *Router) AddImplBinding(binding *Binding, serviceImpl interface{}) error { 121 tr, err := r.newTranscoder(binding, serviceImpl) 122 if err != nil { 123 return fmt.Errorf("new transcoder during add impl binding: %w", err) 124 } 125 // add transcoder 126 r.transcoders[binding.HTTPMethod] = append(r.transcoders[binding.HTTPMethod], tr) 127 return nil 128 } 129 130 func (r *Router) newTranscoder(binding *Binding, serviceImpl interface{}) (*transcoder, error) { 131 if binding.Output == nil { 132 binding.Output = func() ProtoMessage { return &emptypb.Empty{} } 133 } 134 135 // create a transcoder 136 tr := &transcoder{ 137 name: binding.Name, 138 input: binding.Input, 139 output: binding.Output, 140 handler: binding.Filter, 141 httpMethod: binding.HTTPMethod, 142 pat: binding.Pattern, 143 body: binding.Body, 144 respBody: binding.ResponseBody, 145 router: r, 146 discardUnknownParams: r.opts.DiscardUnknownParams, 147 serviceImpl: serviceImpl, 148 } 149 150 // create a dat, filter all fields specified in HttpRule 151 var fps [][]string 152 if fromPat := binding.Pattern.FieldPaths(); fromPat != nil { 153 fps = append(fps, fromPat...) 154 } 155 if binding.Body != nil { 156 if fromBody := binding.Body.Body(); fromBody != "" && fromBody != "*" { 157 fps = append(fps, strings.Split(fromBody, ".")) 158 } 159 } 160 if len(fps) > 0 { 161 doubleArrayTrie, err := dat.Build(fps) 162 if err != nil { 163 return nil, fmt.Errorf("failed to build dat: %w", err) 164 } 165 tr.dat = doubleArrayTrie 166 } 167 return tr, nil 168 } 169 170 // ctxForCompatibility is used only for compatibility with thttp. 171 var ctxForCompatibility func(context.Context, http.ResponseWriter, *http.Request) context.Context 172 173 // SetCtxForCompatibility is used only for compatibility with thttp. 174 func SetCtxForCompatibility(f func(context.Context, http.ResponseWriter, *http.Request) context.Context) { 175 ctxForCompatibility = f 176 } 177 178 // HeaderMatcher matches http request header to tRPC Stub Context. 179 type HeaderMatcher func( 180 ctx context.Context, 181 w http.ResponseWriter, 182 r *http.Request, 183 serviceName, methodName string, 184 ) (context.Context, error) 185 186 // DefaultHeaderMatcher is the default HeaderMatcher. 187 var DefaultHeaderMatcher = func( 188 ctx context.Context, 189 w http.ResponseWriter, 190 req *http.Request, 191 serviceName, methodName string, 192 ) (context.Context, error) { 193 // Noted: it's better to do the same thing as withNewMessage. 194 return withNewMessage(ctx, serviceName, methodName), nil 195 } 196 197 // withNewMessage create a new codec.Msg, put it into ctx, 198 // and set target service name and method name. 199 func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context { 200 ctx, msg := codec.WithNewMessage(ctx) 201 msg.WithServerRPCName(methodName) 202 msg.WithCalleeMethod(methodName) 203 msg.WithCalleeServiceName(serviceName) 204 msg.WithSerializationType(codec.SerializationTypePB) 205 return ctx 206 } 207 208 // CustomResponseHandler is the custom response handler. 209 type CustomResponseHandler func( 210 ctx context.Context, 211 w http.ResponseWriter, 212 r *http.Request, 213 resp proto.Message, 214 body []byte, 215 ) error 216 217 var httpStatusKey = "t-http-status" 218 219 // SetStatusCodeOnSucceed sets status code on succeed, should be 2XX. 220 // It's not supposed to call this function but use WithStatusCode in restful/errors.go 221 // to set status code on error. 222 func SetStatusCodeOnSucceed(ctx context.Context, code int) { 223 msg := codec.Message(ctx) 224 metadata := msg.ServerMetaData() 225 if metadata == nil { 226 metadata = codec.MetaData{} 227 } 228 metadata[httpStatusKey] = []byte(strconv.Itoa(code)) 229 msg.WithServerMetaData(metadata) 230 } 231 232 // GetStatusCodeOnSucceed returns status code on succeed. 233 // SetStatusCodeOnSucceed must be called first in tRPC method. 234 func GetStatusCodeOnSucceed(ctx context.Context) int { 235 if metadata := codec.Message(ctx).ServerMetaData(); metadata != nil { 236 if buf, ok := metadata[httpStatusKey]; ok { 237 if code, err := strconv.Atoi(bytes2str(buf)); err == nil { 238 return code 239 } 240 } 241 } 242 return http.StatusOK 243 } 244 245 // DefaultResponseHandler is the default CustomResponseHandler. 246 var DefaultResponseHandler = func( 247 ctx context.Context, 248 w http.ResponseWriter, 249 r *http.Request, 250 resp proto.Message, 251 body []byte, 252 ) error { 253 // compress 254 var writer io.Writer = w 255 _, c := compressorForTranscoding(r.Header[headerContentEncoding], 256 r.Header[headerAcceptEncoding]) 257 if c != nil { 258 writeCloser, err := c.Compress(w) 259 if err != nil { 260 return fmt.Errorf("failed to compress resp body: %w", err) 261 } 262 defer writeCloser.Close() 263 w.Header().Set(headerContentEncoding, c.ContentEncoding()) 264 writer = writeCloser 265 } 266 267 // set response content-type 268 _, s := serializerForTranscoding(r.Header[headerContentType], 269 r.Header[headerAccept]) 270 w.Header().Set(headerContentType, s.ContentType()) 271 272 // set status code 273 statusCode := GetStatusCodeOnSucceed(ctx) 274 w.WriteHeader(statusCode) 275 276 // response body 277 if statusCode != http.StatusNoContent && statusCode != http.StatusNotModified { 278 writer.Write(body) 279 } 280 281 return nil 282 } 283 284 // putBackCtxMessage calls codec.PutBackMessage to put a codec.Msg back to pool, 285 // if the codec.Msg has been put into ctx. 286 func putBackCtxMessage(ctx context.Context) { 287 if msg, ok := ctx.Value(codec.ContextKeyMessage).(codec.Msg); ok { 288 codec.PutBackMessage(msg) 289 } 290 } 291 292 // ServeHTTP implements http.Handler. 293 // TODO: better routing handling. 294 func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { 295 ctx := ctxForCompatibility(req.Context(), w, req) 296 for _, tr := range r.transcoders[req.Method] { 297 fieldValues, err := tr.pat.Match(req.URL.Path) 298 if err == nil { 299 r.handle(ctx, w, req, tr, fieldValues) 300 return 301 } 302 } 303 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerNoFunc, "failed to match any pattern")) 304 } 305 306 func (r *Router) handle( 307 ctx context.Context, 308 w http.ResponseWriter, 309 req *http.Request, 310 tr *transcoder, 311 fieldValues map[string]string, 312 ) { 313 modifiedCtx, err := r.opts.HeaderMatcher(ctx, w, req, r.opts.ServiceName, tr.name) 314 if err != nil { 315 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerDecodeFail, err.Error())) 316 return 317 } 318 ctx = modifiedCtx 319 defer putBackCtxMessage(ctx) 320 321 timeout := r.opts.Timeout 322 requestTimeout := codec.Message(ctx).RequestTimeout() 323 if requestTimeout > 0 && (requestTimeout < timeout || timeout == 0) { 324 timeout = requestTimeout 325 } 326 if timeout > 0 { 327 var cancel context.CancelFunc 328 ctx, cancel = context.WithTimeout(ctx, timeout) 329 defer cancel() 330 } 331 332 // get inbound/outbound Compressor and Serializer 333 reqCompressor, respCompressor := compressorForTranscoding(req.Header[headerContentEncoding], 334 req.Header[headerAcceptEncoding]) 335 reqSerializer, respSerializer := serializerForTranscoding(req.Header[headerContentType], 336 req.Header[headerAccept]) 337 338 // set transcoder params 339 params, _ := paramsPool.Get().(*transcodeParams) 340 params.reqCompressor = reqCompressor 341 params.respCompressor = respCompressor 342 params.reqSerializer = reqSerializer 343 params.respSerializer = respSerializer 344 params.body = req.Body 345 params.fieldValues = fieldValues 346 params.form = req.URL.Query() 347 defer putBackParams(params) 348 349 // transcode 350 resp, body, err := tr.transcode(ctx, params) 351 if err != nil { 352 r.opts.ErrorHandler(ctx, w, req, err) 353 return 354 } 355 356 // custom response handling 357 if err := r.opts.ResponseHandler(ctx, w, req, resp, body); err != nil { 358 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerEncodeFail, err.Error())) 359 } 360 }