trpc.group/trpc-go/trpc-go@v1.0.2/server/service.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package server 15 16 import ( 17 "context" 18 "errors" 19 "fmt" 20 "os" 21 "reflect" 22 "strconv" 23 "sync/atomic" 24 "time" 25 26 "trpc.group/trpc-go/trpc-go/codec" 27 "trpc.group/trpc-go/trpc-go/errs" 28 "trpc.group/trpc-go/trpc-go/filter" 29 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 30 "trpc.group/trpc-go/trpc-go/internal/report" 31 "trpc.group/trpc-go/trpc-go/log" 32 "trpc.group/trpc-go/trpc-go/naming/registry" 33 "trpc.group/trpc-go/trpc-go/restful" 34 "trpc.group/trpc-go/trpc-go/rpcz" 35 "trpc.group/trpc-go/trpc-go/transport" 36 ) 37 38 // MaxCloseWaitTime is the max waiting time for closing services. 39 const MaxCloseWaitTime = 10 * time.Second 40 41 // Service is the interface that provides services. 42 type Service interface { 43 // Register registers a proto service. 44 Register(serviceDesc interface{}, serviceImpl interface{}) error 45 // Serve starts serving. 46 Serve() error 47 // Close stops serving. 48 Close(chan struct{}) error 49 } 50 51 // FilterFunc reads reqBody, parses it, and returns a filter.Chain for server stub. 52 type FilterFunc func(reqBody interface{}) (filter.ServerChain, error) 53 54 // Method provides the information of an RPC Method. 55 type Method struct { 56 Name string 57 Func func(svr interface{}, ctx context.Context, f FilterFunc) (rspBody interface{}, err error) 58 Bindings []*restful.Binding 59 } 60 61 // ServiceDesc describes a proto service. 62 type ServiceDesc struct { 63 ServiceName string 64 HandlerType interface{} 65 Methods []Method 66 Streams []StreamDesc 67 StreamHandle StreamHandle 68 } 69 70 // StreamDesc describes a server stream. 71 type StreamDesc struct { 72 // StreamName is the name of stream. 73 StreamName string 74 // Handler is a stream handler. 75 Handler StreamHandlerWapper 76 // ServerStreams indicates whether it's server streaming. 77 ServerStreams bool 78 // ClientStreams indicates whether it's client streaming. 79 ClientStreams bool 80 } 81 82 // Handler is the default handler. 83 type Handler func(ctx context.Context, f FilterFunc) (rspBody interface{}, err error) 84 85 // StreamHandlerWapper is server stream handler wrapper. 86 // The input param srv should be an implementation of server stream proto service. 87 // The input param stream is used by srv. 88 type StreamHandlerWapper func(srv interface{}, stream Stream) error 89 90 // StreamHandler is server stream handler. 91 type StreamHandler func(stream Stream) error 92 93 // Stream is the interface that defines server stream api. 94 type Stream interface { 95 // Context is context of server stream. 96 Context() context.Context 97 // SendMsg sends streaming data. 98 SendMsg(m interface{}) error 99 // RecvMsg receives streaming data. 100 RecvMsg(m interface{}) error 101 } 102 103 // service is an implementation of Service 104 type service struct { 105 ctx context.Context // context of this service 106 cancel context.CancelFunc // function that cancels this service 107 opts *Options // options of this service 108 handlers map[string]Handler // rpcname => handler 109 streamHandlers map[string]StreamHandler 110 streamInfo map[string]*StreamServerInfo 111 activeCount int64 // active requests count for graceful close if set MaxCloseWaitTime 112 } 113 114 // New creates a service. 115 // It will use transport.DefaultServerTransport unless Option WithTransport() 116 // is called to replace its transport.ServerTransport plugin. 117 var New = func(opts ...Option) Service { 118 o := defaultOptions() 119 s := &service{ 120 opts: o, 121 handlers: make(map[string]Handler), 122 streamHandlers: make(map[string]StreamHandler), 123 streamInfo: make(map[string]*StreamServerInfo), 124 } 125 for _, o := range opts { 126 o(s.opts) 127 } 128 o.Transport = attemptSwitchingTransport(o) 129 if !s.opts.handlerSet { 130 // if handler is not set, pass the service (which implements Handler interface) 131 // as handler of transport plugin. 132 s.opts.ServeOptions = append(s.opts.ServeOptions, transport.WithHandler(s)) 133 } 134 s.ctx, s.cancel = context.WithCancel(context.Background()) 135 return s 136 } 137 138 // Serve implements Service, starting serving. 139 func (s *service) Serve() error { 140 pid := os.Getpid() 141 142 // make sure ListenAndServe succeeds before Naming Service Registry. 143 if err := s.opts.Transport.ListenAndServe(s.ctx, s.opts.ServeOptions...); err != nil { 144 log.Errorf("process:%d service:%s ListenAndServe fail:%v", pid, s.opts.ServiceName, err) 145 return err 146 } 147 148 if s.opts.Registry != nil { 149 opts := []registry.Option{ 150 registry.WithAddress(s.opts.Address), 151 } 152 if isGraceful, isParental := checkProcessStatus(); isGraceful && !isParental { 153 // If current process is the child process forked for graceful restart, 154 // service should notify the registry plugin of graceful restart event. 155 // The registry plugin might handle registry results according to this event. 156 // For example, repeat registry cause error but the plugin would consider it's ok 157 // according to this event. 158 opts = append(opts, registry.WithEvent(registry.GracefulRestart)) 159 } 160 if err := s.opts.Registry.Register(s.opts.ServiceName, opts...); err != nil { 161 // if registry fails, service needs to be closed and error should be returned. 162 log.Errorf("process:%d, service:%s register fail:%v", pid, s.opts.ServiceName, err) 163 return err 164 } 165 } 166 167 log.Infof("process:%d, %s service:%s launch success, %s:%s, serving ...", 168 pid, s.opts.protocol, s.opts.ServiceName, s.opts.network, s.opts.Address) 169 170 report.ServiceStart.Incr() 171 <-s.ctx.Done() 172 return nil 173 } 174 175 // Handle implements transport.Handler. 176 // service itself is passed to its transport plugin as a transport handler. 177 // This is like a callback function that would be called by service's transport plugin. 178 func (s *service) Handle(ctx context.Context, reqBuf []byte) (rspBuf []byte, err error) { 179 if s.opts.MaxCloseWaitTime > s.opts.CloseWaitTime || s.opts.MaxCloseWaitTime > MaxCloseWaitTime { 180 atomic.AddInt64(&s.activeCount, 1) 181 defer atomic.AddInt64(&s.activeCount, -1) 182 } 183 184 // if server codec is empty, simply returns error. 185 if s.opts.Codec == nil { 186 log.ErrorContextf(ctx, "server codec empty") 187 report.ServerCodecEmpty.Incr() 188 return nil, errors.New("server codec empty") 189 } 190 191 msg := codec.Message(ctx) 192 span := rpcz.SpanFromContext(ctx) 193 span.SetAttribute(rpcz.TRPCAttributeFilterNames, s.opts.FilterNames) 194 195 _, end := span.NewChild("DecodeProtocolHead") 196 reqBodyBuf, err := s.decode(ctx, msg, reqBuf) 197 end.End() 198 199 if err != nil { 200 return s.encode(ctx, msg, nil, err) 201 } 202 // ServerRspErr is already set, 203 // since RequestID is acquired, just respond to client. 204 if err := msg.ServerRspErr(); err != nil { 205 return s.encode(ctx, msg, nil, err) 206 } 207 208 rspbody, err := s.handle(ctx, msg, reqBodyBuf) 209 if err != nil { 210 // no response 211 if err == errs.ErrServerNoResponse { 212 return nil, err 213 } 214 // failed to handle, should respond to client with error code, 215 // ignore rspBody. 216 report.ServiceHandleFail.Incr() 217 return s.encode(ctx, msg, nil, err) 218 } 219 return s.handleResponse(ctx, msg, rspbody) 220 } 221 222 // HandleClose is called when conn is closed. 223 // Currently, only used for server stream. 224 func (s *service) HandleClose(ctx context.Context) error { 225 if codec.Message(ctx).ServerRspErr() != nil && s.opts.StreamHandle != nil { 226 _, err := s.opts.StreamHandle.StreamHandleFunc(ctx, nil, nil, nil) 227 return err 228 } 229 return nil 230 } 231 232 func (s *service) encode(ctx context.Context, msg codec.Msg, rspBodyBuf []byte, e error) (rspBuf []byte, err error) { 233 if e != nil { 234 log.DebugContextf( 235 ctx, 236 "service: %s handle err (if caused by health checking, this error can be ignored): %+v", 237 s.opts.ServiceName, e) 238 msg.WithServerRspErr(e) 239 } 240 241 rspBuf, err = s.opts.Codec.Encode(msg, rspBodyBuf) 242 if err != nil { 243 report.ServiceCodecEncodeFail.Incr() 244 log.ErrorContextf(ctx, "service:%s encode fail:%v", s.opts.ServiceName, err) 245 return nil, err 246 } 247 return rspBuf, nil 248 } 249 250 // handleStream handles server stream. 251 func (s *service) handleStream(ctx context.Context, msg codec.Msg, reqBuf []byte, sh StreamHandler, 252 opts *Options) (resbody interface{}, err error) { 253 if s.opts.StreamHandle != nil { 254 si := s.streamInfo[msg.ServerRPCName()] 255 return s.opts.StreamHandle.StreamHandleFunc(ctx, sh, si, reqBuf) 256 } 257 return nil, errs.NewFrameError(errs.RetServerNoService, "Stream method no Handle") 258 } 259 260 func (s *service) decode(ctx context.Context, msg codec.Msg, reqBuf []byte) ([]byte, error) { 261 s.setOpt(msg) 262 reqBodyBuf, err := s.opts.Codec.Decode(msg, reqBuf) 263 if err != nil { 264 report.ServiceCodecDecodeFail.Incr() 265 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Decode: "+err.Error()) 266 } 267 268 // call setOpt again to avoid some msg infos (namespace, env name, etc.) 269 // being modified by request decoding. 270 s.setOpt(msg) 271 return reqBodyBuf, nil 272 } 273 274 func (s *service) setOpt(msg codec.Msg) { 275 msg.WithNamespace(s.opts.Namespace) // service namespace 276 msg.WithEnvName(s.opts.EnvName) // service environment 277 msg.WithSetName(s.opts.SetName) // service "Set" 278 msg.WithCalleeServiceName(s.opts.ServiceName) // from perspective of the service, callee refers to itself 279 } 280 281 func (s *service) handle(ctx context.Context, msg codec.Msg, reqBodyBuf []byte) (interface{}, error) { 282 // whether is server streaming RPC 283 streamHandler, ok := s.streamHandlers[msg.ServerRPCName()] 284 if ok { 285 return s.handleStream(ctx, msg, reqBodyBuf, streamHandler, s.opts) 286 } 287 handler, ok := s.handlers[msg.ServerRPCName()] 288 if !ok { 289 handler, ok = s.handlers["*"] // wildcard 290 if !ok { 291 report.ServiceHandleRPCNameInvalid.Incr() 292 return nil, errs.NewFrameError(errs.RetServerNoFunc, 293 fmt.Sprintf("service handle: rpc name %s invalid, current service:%s", 294 msg.ServerRPCName(), msg.CalleeServiceName())) 295 } 296 } 297 298 var fixTimeout filter.ServerFilter 299 if s.opts.Timeout > 0 { 300 fixTimeout = mayConvert2NormalTimeout 301 } 302 timeout := s.opts.Timeout 303 if msg.RequestTimeout() > 0 && !s.opts.DisableRequestTimeout { // 可以配置禁用 304 if msg.RequestTimeout() < timeout || timeout == 0 { // 取最小值 305 fixTimeout = mayConvert2FullLinkTimeout 306 timeout = msg.RequestTimeout() 307 } 308 } 309 if timeout > 0 { 310 var cancel context.CancelFunc 311 ctx, cancel = context.WithTimeout(ctx, timeout) 312 defer cancel() 313 } 314 newFilterFunc := s.filterFunc(ctx, msg, reqBodyBuf, fixTimeout) 315 rspBody, err := handler(ctx, newFilterFunc) 316 if err != nil { 317 if e, ok := err.(*errs.Error); ok && 318 e.Type == errs.ErrorTypeFramework && 319 e.Code == errs.RetServerFullLinkTimeout { 320 err = errs.ErrServerNoResponse 321 } 322 return nil, err 323 } 324 if msg.CallType() == codec.SendOnly { 325 return nil, errs.ErrServerNoResponse 326 } 327 return rspBody, nil 328 } 329 330 // handleResponse handles response. 331 // serialization type is set to msg.SerializationType() by default, 332 // if serialization type Option is called, serialization type is set by the Option. 333 // compress type's setting is similar to it. 334 func (s *service) handleResponse(ctx context.Context, msg codec.Msg, rspBody interface{}) ([]byte, error) { 335 // marshal response body 336 337 serializationType := msg.SerializationType() 338 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 339 serializationType = s.opts.CurrentSerializationType 340 } 341 span := rpcz.SpanFromContext(ctx) 342 343 _, end := span.NewChild("Marshal") 344 rspBodyBuf, err := codec.Marshal(serializationType, rspBody) 345 end.End() 346 347 if err != nil { 348 report.ServiceCodecMarshalFail.Incr() 349 err = errs.NewFrameError(errs.RetServerEncodeFail, "service codec Marshal: "+err.Error()) 350 // rspBodyBuf will be nil if marshalling fails, respond only error code to client. 351 return s.encode(ctx, msg, rspBodyBuf, err) 352 } 353 354 // compress response body 355 compressType := msg.CompressType() 356 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 357 compressType = s.opts.CurrentCompressType 358 } 359 360 _, end = span.NewChild("Compress") 361 rspBodyBuf, err = codec.Compress(compressType, rspBodyBuf) 362 end.End() 363 364 if err != nil { 365 report.ServiceCodecCompressFail.Incr() 366 err = errs.NewFrameError(errs.RetServerEncodeFail, "service codec Compress: "+err.Error()) 367 // rspBodyBuf will be nil if compression fails, respond only error code to client. 368 return s.encode(ctx, msg, rspBodyBuf, err) 369 } 370 371 _, end = span.NewChild("EncodeProtocolHead") 372 rspBuf, err := s.encode(ctx, msg, rspBodyBuf, nil) 373 end.End() 374 375 return rspBuf, err 376 } 377 378 // filterFunc returns a FilterFunc, which would be passed to server stub to access pre/post filter handling. 379 func (s *service) filterFunc( 380 ctx context.Context, 381 msg codec.Msg, 382 reqBodyBuf []byte, 383 fixTimeout filter.ServerFilter, 384 ) FilterFunc { 385 // Decompression, serialization of request body are put into a closure. 386 // Both serialization type & compress type can be set. 387 // serialization type is set to msg.SerializationType() by default, 388 // if serialization type Option is called, serialization type is set by the Option. 389 // compress type's setting is similar to it. 390 return func(reqBody interface{}) (filter.ServerChain, error) { 391 // decompress request body 392 compressType := msg.CompressType() 393 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 394 compressType = s.opts.CurrentCompressType 395 } 396 span := rpcz.SpanFromContext(ctx) 397 _, end := span.NewChild("Decompress") 398 reqBodyBuf, err := codec.Decompress(compressType, reqBodyBuf) 399 end.End() 400 if err != nil { 401 report.ServiceCodecDecompressFail.Incr() 402 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Decompress: "+err.Error()) 403 } 404 405 // unmarshal request body 406 serializationType := msg.SerializationType() 407 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 408 serializationType = s.opts.CurrentSerializationType 409 } 410 _, end = span.NewChild("Unmarshal") 411 err = codec.Unmarshal(serializationType, reqBodyBuf, reqBody) 412 end.End() 413 if err != nil { 414 report.ServiceCodecUnmarshalFail.Incr() 415 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Unmarshal: "+err.Error()) 416 } 417 418 if fixTimeout != nil { 419 // this heap allocation cannot be avoided unless we change the generated xxx.trpc.go. 420 filters := make(filter.ServerChain, len(s.opts.Filters), len(s.opts.Filters)+1) 421 copy(filters, s.opts.Filters) 422 return append(filters, fixTimeout), nil 423 } 424 return s.opts.Filters, nil 425 } 426 } 427 428 // Register implements Service interface, registering a proto service impl for the service. 429 func (s *service) Register(serviceDesc interface{}, serviceImpl interface{}) error { 430 desc, ok := serviceDesc.(*ServiceDesc) 431 if !ok { 432 return errors.New("serviceDesc is not *ServiceDesc") 433 } 434 if desc.StreamHandle != nil { 435 s.opts.StreamHandle = desc.StreamHandle 436 if s.opts.StreamTransport != nil { 437 s.opts.Transport = s.opts.StreamTransport 438 } 439 // IdleTimeout is not used by server stream, set it to 0. 440 s.opts.ServeOptions = append(s.opts.ServeOptions, transport.WithServerIdleTimeout(0)) 441 err := s.opts.StreamHandle.Init(s.opts) 442 if err != nil { 443 return err 444 } 445 } 446 447 if serviceImpl != nil { 448 ht := reflect.TypeOf(desc.HandlerType).Elem() 449 hi := reflect.TypeOf(serviceImpl) 450 if !hi.Implements(ht) { 451 return fmt.Errorf("%s not implements interface %s", hi.String(), ht.String()) 452 } 453 } 454 455 var bindings []*restful.Binding 456 for _, method := range desc.Methods { 457 n := method.Name 458 if _, ok := s.handlers[n]; ok { 459 return fmt.Errorf("duplicate method name: %s", n) 460 } 461 h := method.Func 462 s.handlers[n] = func(ctx context.Context, f FilterFunc) (rsp interface{}, err error) { 463 return h(serviceImpl, ctx, f) 464 } 465 bindings = append(bindings, method.Bindings...) 466 } 467 468 for _, stream := range desc.Streams { 469 n := stream.StreamName 470 if _, ok := s.streamHandlers[n]; ok { 471 return fmt.Errorf("duplicate stream name: %s", n) 472 } 473 h := stream.Handler 474 s.streamInfo[stream.StreamName] = &StreamServerInfo{ 475 FullMethod: stream.StreamName, 476 IsClientStream: stream.ClientStreams, 477 IsServerStream: stream.ServerStreams, 478 } 479 s.streamHandlers[stream.StreamName] = func(stream Stream) error { 480 return h(serviceImpl, stream) 481 } 482 } 483 return s.createOrUpdateRouter(bindings, serviceImpl) 484 } 485 486 func (s *service) createOrUpdateRouter(bindings []*restful.Binding, serviceImpl interface{}) error { 487 // If pb option (trpc.api.http) is set,creates a RESTful Router. 488 if len(bindings) == 0 { 489 return nil 490 } 491 handler := restful.GetRouter(s.opts.ServiceName) 492 if handler != nil { 493 if router, ok := handler.(*restful.Router); ok { // A router has already been registered. 494 for _, binding := range bindings { // Add binding with a specified service implementation. 495 if err := router.AddImplBinding(binding, serviceImpl); err != nil { 496 return fmt.Errorf("add impl binding during service registration: %w", err) 497 } 498 } 499 return nil 500 } 501 } 502 // This is the first time of registering the service router, create a new one. 503 router := restful.NewRouter(append(s.opts.RESTOptions, 504 restful.WithNamespace(s.opts.Namespace), 505 restful.WithEnvironment(s.opts.EnvName), 506 restful.WithContainer(s.opts.container), 507 restful.WithSet(s.opts.SetName), 508 restful.WithServiceName(s.opts.ServiceName), 509 restful.WithTimeout(s.opts.Timeout), 510 restful.WithFilterFunc(func() filter.ServerChain { return s.opts.Filters }))...) 511 for _, binding := range bindings { 512 if err := router.AddImplBinding(binding, serviceImpl); err != nil { 513 return err 514 } 515 } 516 restful.RegisterRouter(s.opts.ServiceName, router) 517 return nil 518 } 519 520 // Close closes the service,registry.Deregister will be called. 521 func (s *service) Close(ch chan struct{}) error { 522 pid := os.Getpid() 523 if ch == nil { 524 ch = make(chan struct{}, 1) 525 } 526 log.Infof("process:%d, %s service:%s, closing ...", pid, s.opts.protocol, s.opts.ServiceName) 527 528 if s.opts.Registry != nil { 529 // When it comes to graceful restart, the parent process will not call registry Deregister(), 530 // while the child process would call registry Deregister(). 531 if isGraceful, isParental := checkProcessStatus(); !(isGraceful && isParental) { 532 if err := s.opts.Registry.Deregister(s.opts.ServiceName); err != nil { 533 log.Errorf("process:%d, deregister service:%s fail:%v", pid, s.opts.ServiceName, err) 534 } 535 } 536 } 537 s.waitBeforeClose() 538 539 // this will cancel all children ctx. 540 s.cancel() 541 timeout := time.Millisecond * 300 542 if s.opts.Timeout > timeout { // use the larger one 543 timeout = s.opts.Timeout 544 } 545 time.Sleep(timeout) 546 log.Infof("process:%d, %s service:%s, closed", pid, s.opts.protocol, s.opts.ServiceName) 547 ch <- struct{}{} 548 return nil 549 } 550 551 func (s *service) waitBeforeClose() { 552 closeWaitTime := s.opts.CloseWaitTime 553 if closeWaitTime > MaxCloseWaitTime { 554 closeWaitTime = MaxCloseWaitTime 555 } 556 if closeWaitTime > 0 { 557 // After registry.Deregister() is called, sleep a while to let Naming Service (like Polaris) finish 558 // updating instance ip list. 559 // Otherwise, client request would still arrive while the service had already been closed (Typically, it occurs 560 // when k8s updates pods). 561 log.Infof("process %d service %s remain %d requests wait %v time when closing service", 562 os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount), closeWaitTime) 563 time.Sleep(closeWaitTime) 564 } 565 const sleepTime = 100 * time.Millisecond 566 if s.opts.MaxCloseWaitTime > closeWaitTime { 567 spinCount := int((s.opts.MaxCloseWaitTime - closeWaitTime) / sleepTime) 568 for i := 0; i < spinCount; i++ { 569 if atomic.LoadInt64(&s.activeCount) <= 0 { 570 break 571 } 572 time.Sleep(sleepTime) 573 } 574 log.Infof("process %d service %s remain %d requests when closing service", 575 os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount)) 576 } 577 } 578 579 func checkProcessStatus() (isGracefulRestart, isParentalProcess bool) { 580 v := os.Getenv(transport.EnvGraceRestartPPID) 581 if v == "" { 582 return false, true 583 } 584 585 ppid, err := strconv.Atoi(v) 586 if err != nil { 587 return false, false 588 } 589 return true, ppid == os.Getpid() 590 } 591 592 func defaultOptions() *Options { 593 const ( 594 invalidSerializationType = -1 595 invalidCompressType = -1 596 ) 597 return &Options{ 598 protocol: "unknown-protocol", 599 ServiceName: "empty-name", 600 CurrentSerializationType: invalidSerializationType, 601 CurrentCompressType: invalidCompressType, 602 } 603 }