trpc.group/trpc-go/trpc-go@v1.0.3/restful/router.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "context" 18 "fmt" 19 "io" 20 "net/http" 21 "strconv" 22 "strings" 23 "sync" 24 25 "google.golang.org/protobuf/proto" 26 "google.golang.org/protobuf/types/known/emptypb" 27 28 "trpc.group/trpc-go/trpc-go/codec" 29 "trpc.group/trpc-go/trpc-go/errs" 30 "trpc.group/trpc-go/trpc-go/filter" 31 "trpc.group/trpc-go/trpc-go/internal/dat" 32 ) 33 34 // Router is restful router. 35 type Router struct { 36 opts *Options 37 transcoders map[string][]*transcoder 38 } 39 40 // NewRouter creates a Router. 41 func NewRouter(opts ...Option) *Router { 42 o := Options{ 43 ErrorHandler: DefaultErrorHandler, 44 HeaderMatcher: DefaultHeaderMatcher, 45 ResponseHandler: DefaultResponseHandler, 46 FastHTTPErrHandler: DefaultFastHTTPErrorHandler, 47 FastHTTPHeaderMatcher: DefaultFastHTTPHeaderMatcher, 48 FastHTTPRespHandler: DefaultFastHTTPRespHandler, 49 } 50 for _, opt := range opts { 51 opt(&o) 52 } 53 o.rebuildHeaderMatcher() 54 55 return &Router{ 56 opts: &o, 57 transcoders: make(map[string][]*transcoder), 58 } 59 } 60 61 var ( 62 routers = make(map[string]http.Handler) // tRPC service name -> Router 63 routerLock sync.RWMutex 64 ) 65 66 // RegisterRouter registers a Router which corresponds to a tRPC Service. 67 func RegisterRouter(name string, router http.Handler) { 68 routerLock.Lock() 69 routers[name] = router 70 routerLock.Unlock() 71 } 72 73 // GetRouter returns a Router which corresponds to a tRPC Service. 74 func GetRouter(name string) http.Handler { 75 routerLock.RLock() 76 router := routers[name] 77 routerLock.RUnlock() 78 return router 79 } 80 81 // ProtoMessage is alias of proto.Message. 82 type ProtoMessage proto.Message 83 84 // Initializer initializes a ProtoMessage. 85 type Initializer func() ProtoMessage 86 87 // BodyLocator locates which fields of the proto message would be 88 // populated according to HttpRule body. 89 type BodyLocator interface { 90 Body() string 91 Locate(ProtoMessage) interface{} 92 } 93 94 // ResponseBodyLocator locates which fields of the proto message would be marshaled 95 // according to HttpRule response_body. 96 type ResponseBodyLocator interface { 97 ResponseBody() string 98 Locate(ProtoMessage) interface{} 99 } 100 101 // HandleFunc is tRPC method handle function. 102 type HandleFunc func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error) 103 104 // ExtractFilterFunc extracts tRPC service filter chain. 105 type ExtractFilterFunc func() filter.ServerChain 106 107 // Binding is the binding of tRPC method and HttpRule. 108 type Binding struct { 109 Name string 110 Input Initializer 111 Output Initializer 112 Filter HandleFunc 113 HTTPMethod string 114 Pattern *Pattern 115 Body BodyLocator 116 ResponseBody ResponseBodyLocator 117 } 118 119 // AddImplBinding creates a new binding with a specified service implementation. 120 func (r *Router) AddImplBinding(binding *Binding, serviceImpl interface{}) error { 121 tr, err := r.newTranscoder(binding, serviceImpl) 122 if err != nil { 123 return fmt.Errorf("new transcoder during add impl binding: %w", err) 124 } 125 // add transcoder 126 r.transcoders[binding.HTTPMethod] = append(r.transcoders[binding.HTTPMethod], tr) 127 return nil 128 } 129 130 func (r *Router) newTranscoder(binding *Binding, serviceImpl interface{}) (*transcoder, error) { 131 if binding.Output == nil { 132 binding.Output = func() ProtoMessage { return &emptypb.Empty{} } 133 } 134 135 // create a transcoder 136 tr := &transcoder{ 137 name: binding.Name, 138 input: binding.Input, 139 output: binding.Output, 140 handler: binding.Filter, 141 httpMethod: binding.HTTPMethod, 142 pat: binding.Pattern, 143 body: binding.Body, 144 respBody: binding.ResponseBody, 145 router: r, 146 discardUnknownParams: r.opts.DiscardUnknownParams, 147 serviceImpl: serviceImpl, 148 } 149 150 // create a dat, filter all fields specified in HttpRule 151 var fps [][]string 152 if fromPat := binding.Pattern.FieldPaths(); fromPat != nil { 153 fps = append(fps, fromPat...) 154 } 155 if binding.Body != nil { 156 if fromBody := binding.Body.Body(); fromBody != "" && fromBody != "*" { 157 fps = append(fps, strings.Split(fromBody, ".")) 158 } 159 } 160 if len(fps) > 0 { 161 doubleArrayTrie, err := dat.Build(fps) 162 if err != nil { 163 return nil, fmt.Errorf("failed to build dat: %w", err) 164 } 165 tr.dat = doubleArrayTrie 166 } 167 return tr, nil 168 } 169 170 // ctxForCompatibility is used only for compatibility with thttp. 171 var ctxForCompatibility func(context.Context, http.ResponseWriter, *http.Request) context.Context 172 173 // SetCtxForCompatibility is used only for compatibility with thttp. 174 func SetCtxForCompatibility(f func(context.Context, http.ResponseWriter, *http.Request) context.Context) { 175 ctxForCompatibility = f 176 } 177 178 // HeaderMatcher matches http request header to tRPC Stub Context. 179 type HeaderMatcher func( 180 ctx context.Context, 181 w http.ResponseWriter, 182 r *http.Request, 183 serviceName, methodName string, 184 ) (context.Context, error) 185 186 // DefaultHeaderMatcher is the default HeaderMatcher. 187 var DefaultHeaderMatcher = func( 188 ctx context.Context, 189 w http.ResponseWriter, 190 req *http.Request, 191 serviceName, methodName string, 192 ) (context.Context, error) { 193 // Noted: it's better to do the same thing as withNewMessage. 194 return withNewMessage(ctx, serviceName, methodName), nil 195 } 196 197 // withNewMessage create a new codec.Msg, put it into ctx, 198 // and set target service name and method name. 199 func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context { 200 ctx, msg := codec.WithNewMessage(ctx) 201 msg.WithServerRPCName(methodName) 202 msg.WithCalleeServiceName(serviceName) 203 msg.WithSerializationType(codec.SerializationTypePB) 204 return ctx 205 } 206 207 // CustomResponseHandler is the custom response handler. 208 type CustomResponseHandler func( 209 ctx context.Context, 210 w http.ResponseWriter, 211 r *http.Request, 212 resp proto.Message, 213 body []byte, 214 ) error 215 216 var httpStatusKey = "t-http-status" 217 218 // SetStatusCodeOnSucceed sets status code on succeed, should be 2XX. 219 // It's not supposed to call this function but use WithStatusCode in restful/errors.go 220 // to set status code on error. 221 func SetStatusCodeOnSucceed(ctx context.Context, code int) { 222 msg := codec.Message(ctx) 223 metadata := msg.ServerMetaData() 224 if metadata == nil { 225 metadata = codec.MetaData{} 226 } 227 metadata[httpStatusKey] = []byte(strconv.Itoa(code)) 228 msg.WithServerMetaData(metadata) 229 } 230 231 // GetStatusCodeOnSucceed returns status code on succeed. 232 // SetStatusCodeOnSucceed must be called first in tRPC method. 233 func GetStatusCodeOnSucceed(ctx context.Context) int { 234 if metadata := codec.Message(ctx).ServerMetaData(); metadata != nil { 235 if buf, ok := metadata[httpStatusKey]; ok { 236 if code, err := strconv.Atoi(bytes2str(buf)); err == nil { 237 return code 238 } 239 } 240 } 241 return http.StatusOK 242 } 243 244 // DefaultResponseHandler is the default CustomResponseHandler. 245 var DefaultResponseHandler = func( 246 ctx context.Context, 247 w http.ResponseWriter, 248 r *http.Request, 249 resp proto.Message, 250 body []byte, 251 ) error { 252 // compress 253 var writer io.Writer = w 254 _, c := compressorForTranscoding(r.Header[headerContentEncoding], 255 r.Header[headerAcceptEncoding]) 256 if c != nil { 257 writeCloser, err := c.Compress(w) 258 if err != nil { 259 return fmt.Errorf("failed to compress resp body: %w", err) 260 } 261 defer writeCloser.Close() 262 w.Header().Set(headerContentEncoding, c.ContentEncoding()) 263 writer = writeCloser 264 } 265 266 // set response content-type 267 _, s := serializerForTranscoding(r.Header[headerContentType], 268 r.Header[headerAccept]) 269 w.Header().Set(headerContentType, s.ContentType()) 270 271 // set status code 272 statusCode := GetStatusCodeOnSucceed(ctx) 273 w.WriteHeader(statusCode) 274 275 // response body 276 if statusCode != http.StatusNoContent && statusCode != http.StatusNotModified { 277 writer.Write(body) 278 } 279 280 return nil 281 } 282 283 // putBackCtxMessage calls codec.PutBackMessage to put a codec.Msg back to pool, 284 // if the codec.Msg has been put into ctx. 285 func putBackCtxMessage(ctx context.Context) { 286 if msg, ok := ctx.Value(codec.ContextKeyMessage).(codec.Msg); ok { 287 codec.PutBackMessage(msg) 288 } 289 } 290 291 // ServeHTTP implements http.Handler. 292 // TODO: better routing handling. 293 func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { 294 ctx := ctxForCompatibility(req.Context(), w, req) 295 for _, tr := range r.transcoders[req.Method] { 296 fieldValues, err := tr.pat.Match(req.URL.Path) 297 if err == nil { 298 r.handle(ctx, w, req, tr, fieldValues) 299 return 300 } 301 } 302 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerNoFunc, "failed to match any pattern")) 303 } 304 305 func (r *Router) handle( 306 ctx context.Context, 307 w http.ResponseWriter, 308 req *http.Request, 309 tr *transcoder, 310 fieldValues map[string]string, 311 ) { 312 modifiedCtx, err := r.opts.HeaderMatcher(ctx, w, req, r.opts.ServiceName, tr.name) 313 if err != nil { 314 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerDecodeFail, err.Error())) 315 return 316 } 317 ctx = modifiedCtx 318 defer putBackCtxMessage(ctx) 319 320 timeout := r.opts.Timeout 321 requestTimeout := codec.Message(ctx).RequestTimeout() 322 if requestTimeout > 0 && (requestTimeout < timeout || timeout == 0) { 323 timeout = requestTimeout 324 } 325 if timeout > 0 { 326 var cancel context.CancelFunc 327 ctx, cancel = context.WithTimeout(ctx, timeout) 328 defer cancel() 329 } 330 331 // get inbound/outbound Compressor and Serializer 332 reqCompressor, respCompressor := compressorForTranscoding(req.Header[headerContentEncoding], 333 req.Header[headerAcceptEncoding]) 334 reqSerializer, respSerializer := serializerForTranscoding(req.Header[headerContentType], 335 req.Header[headerAccept]) 336 337 // set transcoder params 338 params, _ := paramsPool.Get().(*transcodeParams) 339 params.reqCompressor = reqCompressor 340 params.respCompressor = respCompressor 341 params.reqSerializer = reqSerializer 342 params.respSerializer = respSerializer 343 params.body = req.Body 344 params.fieldValues = fieldValues 345 params.form = req.URL.Query() 346 defer putBackParams(params) 347 348 // transcode 349 resp, body, err := tr.transcode(ctx, params) 350 if err != nil { 351 r.opts.ErrorHandler(ctx, w, req, err) 352 return 353 } 354 355 // custom response handling 356 if err := r.opts.ResponseHandler(ctx, w, req, resp, body); err != nil { 357 r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerEncodeFail, err.Error())) 358 } 359 }