//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//

package server

import (
	"context"
	"errors"
	"fmt"
	"os"
	"reflect"
	"strconv"
	"sync/atomic"
	"time"

	"trpc.group/trpc-go/trpc-go/codec"
	"trpc.group/trpc-go/trpc-go/errs"
	"trpc.group/trpc-go/trpc-go/filter"
	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
	"trpc.group/trpc-go/trpc-go/internal/report"
	"trpc.group/trpc-go/trpc-go/log"
	"trpc.group/trpc-go/trpc-go/naming/registry"
	"trpc.group/trpc-go/trpc-go/restful"
	"trpc.group/trpc-go/trpc-go/rpcz"
	"trpc.group/trpc-go/trpc-go/transport"
)

// MaxCloseWaitTime is the max waiting time for closing services.
// When a service closes, the configured CloseWaitTime sleep is capped at
// this value.
const MaxCloseWaitTime = 10 * time.Second

// Service is the interface that provides services.
type Service interface {
	// Register registers a proto service.
	// In this package's implementation, serviceDesc must be a *ServiceDesc
	// (an error is returned otherwise) and serviceImpl is the object whose
	// methods serve the RPCs it describes.
	Register(serviceDesc interface{}, serviceImpl interface{}) error
	// Serve starts serving.
	// In this package's implementation it blocks until the service is closed.
	Serve() error
	// Close stops serving.
	// In this package's implementation a value is sent on the channel once
	// closing has finished.
	Close(chan struct{}) error
}

// FilterFunc decompresses and unmarshals the request body into reqBody, then
// returns the filter.ServerChain the server stub should run for the request.
type FilterFunc func(reqBody interface{}) (filter.ServerChain, error)

// Method provides the information of an RPC Method.
type Method struct {
	// Name is the RPC name; the service's handler table is keyed by it and
	// looked up with the incoming message's server RPC name.
	Name string
	// Func invokes the RPC on svr (the registered service implementation).
	// f decodes the request body and supplies the filter chain.
	Func func(svr interface{}, ctx context.Context, f FilterFunc) (rspBody interface{}, err error)
	// Bindings are RESTful bindings for this RPC. When any are present,
	// registration creates or updates a restful.Router for the service.
	Bindings []*restful.Binding
}

// ServiceDesc describes a proto service.
type ServiceDesc struct {
	// ServiceName is the name of the proto service.
	ServiceName string
	// HandlerType is expected to be a pointer to the interface type that the
	// service implementation must satisfy: Register takes Elem() of its
	// reflect type and checks serviceImpl against it with Implements.
	HandlerType interface{}
	// Methods are the unary RPCs; Register stores each one in the service's
	// handler table keyed by Method.Name.
	Methods []Method
	// Streams are the streaming RPCs; Register stores each one keyed by
	// StreamDesc.StreamName.
	Streams []StreamDesc
	// StreamHandle, when non-nil, is installed by Register as the service's
	// stream handler (switching to the stream transport if one is configured).
	StreamHandle StreamHandle
}

// StreamDesc describes a server stream.
type StreamDesc struct {
	// StreamName is the name of stream. It is also the key under which the
	// stream handler and its metadata are registered.
	StreamName string
	// Handler is a stream handler.
	Handler StreamHandlerWapper
	// ServerStreams indicates whether it's server streaming.
	ServerStreams bool
	// ClientStreams indicates whether it's client streaming.
	ClientStreams bool
}

// Handler is the per-RPC handler stored by a service. f decodes the request
// body and returns the server filter chain to run.
type Handler func(ctx context.Context, f FilterFunc) (rspBody interface{}, err error)

// StreamHandlerWapper is server stream handler wrapper.
// The input param srv should be an implementation of server stream proto service.
// The input param stream is used by srv.
// The misspelling of "Wrapper" is part of the exported API and is kept for
// compatibility.
type StreamHandlerWapper func(srv interface{}, stream Stream) error

// StreamHandler is a server stream handler already bound to a service
// implementation.
type StreamHandler func(stream Stream) error

// Stream is the interface that defines server stream api.
type Stream interface {
	// Context is context of server stream.
	Context() context.Context
	// SendMsg sends streaming data.
	SendMsg(m interface{}) error
	// RecvMsg receives streaming data.
	RecvMsg(m interface{}) error
}

// service is an implementation of Service.
type service struct {
	// activeCount is the number of in-flight requests, tracked for graceful
	// close when MaxCloseWaitTime is set. It is accessed via sync/atomic; keep
	// it the first field so it stays 64-bit aligned on 32-bit platforms.
	activeCount    int64
	ctx            context.Context              // context of this service; canceled by Close
	cancel         context.CancelFunc           // function that cancels this service
	opts           *Options                     // options of this service
	handlers       map[string]Handler           // rpcname => handler ("*" is the wildcard)
	streamHandlers map[string]StreamHandler     // stream name => stream handler
	streamInfo     map[string]*StreamServerInfo // stream name => stream metadata
	stopListening  chan<- struct{}              // NOTE(review): not referenced in this file
}

// New creates a service.
116 // It will use transport.DefaultServerTransport unless Option WithTransport() 117 // is called to replace its transport.ServerTransport plugin. 118 var New = func(opts ...Option) Service { 119 o := defaultOptions() 120 s := &service{ 121 opts: o, 122 handlers: make(map[string]Handler), 123 streamHandlers: make(map[string]StreamHandler), 124 streamInfo: make(map[string]*StreamServerInfo), 125 } 126 for _, o := range opts { 127 o(s.opts) 128 } 129 o.Transport = attemptSwitchingTransport(o) 130 if !s.opts.handlerSet { 131 // if handler is not set, pass the service (which implements Handler interface) 132 // as handler of transport plugin. 133 s.opts.ServeOptions = append(s.opts.ServeOptions, transport.WithHandler(s)) 134 } 135 s.ctx, s.cancel = context.WithCancel(context.Background()) 136 return s 137 } 138 139 // Serve implements Service, starting serving. 140 func (s *service) Serve() error { 141 pid := os.Getpid() 142 143 // make sure ListenAndServe succeeds before Naming Service Registry. 144 if err := s.opts.Transport.ListenAndServe(s.ctx, s.opts.ServeOptions...); err != nil { 145 log.Errorf("process:%d service:%s ListenAndServe fail:%v", pid, s.opts.ServiceName, err) 146 return err 147 } 148 149 if s.opts.Registry != nil { 150 opts := []registry.Option{ 151 registry.WithAddress(s.opts.Address), 152 } 153 if isGraceful, isParental := checkProcessStatus(); isGraceful && !isParental { 154 // If current process is the child process forked for graceful restart, 155 // service should notify the registry plugin of graceful restart event. 156 // The registry plugin might handle registry results according to this event. 157 // For example, repeat registry cause error but the plugin would consider it's ok 158 // according to this event. 
159 opts = append(opts, registry.WithEvent(registry.GracefulRestart)) 160 } 161 if err := s.opts.Registry.Register(s.opts.ServiceName, opts...); err != nil { 162 // if registry fails, service needs to be closed and error should be returned. 163 log.Errorf("process:%d, service:%s register fail:%v", pid, s.opts.ServiceName, err) 164 return err 165 } 166 } 167 168 log.Infof("process:%d, %s service:%s launch success, %s:%s, serving ...", 169 pid, s.opts.protocol, s.opts.ServiceName, s.opts.network, s.opts.Address) 170 171 report.ServiceStart.Incr() 172 <-s.ctx.Done() 173 return nil 174 } 175 176 // Handle implements transport.Handler. 177 // service itself is passed to its transport plugin as a transport handler. 178 // This is like a callback function that would be called by service's transport plugin. 179 func (s *service) Handle(ctx context.Context, reqBuf []byte) (rspBuf []byte, err error) { 180 if s.opts.MaxCloseWaitTime > s.opts.CloseWaitTime || s.opts.MaxCloseWaitTime > MaxCloseWaitTime { 181 atomic.AddInt64(&s.activeCount, 1) 182 defer atomic.AddInt64(&s.activeCount, -1) 183 } 184 185 // if server codec is empty, simply returns error. 186 if s.opts.Codec == nil { 187 log.ErrorContextf(ctx, "server codec empty") 188 report.ServerCodecEmpty.Incr() 189 return nil, errors.New("server codec empty") 190 } 191 192 msg := codec.Message(ctx) 193 span := rpcz.SpanFromContext(ctx) 194 span.SetAttribute(rpcz.TRPCAttributeFilterNames, s.opts.FilterNames) 195 196 _, end := span.NewChild("DecodeProtocolHead") 197 reqBodyBuf, err := s.decode(ctx, msg, reqBuf) 198 end.End() 199 200 if err != nil { 201 return s.encode(ctx, msg, nil, err) 202 } 203 // ServerRspErr is already set, 204 // since RequestID is acquired, just respond to client. 
205 if err := msg.ServerRspErr(); err != nil { 206 return s.encode(ctx, msg, nil, err) 207 } 208 209 rspbody, err := s.handle(ctx, msg, reqBodyBuf) 210 if err != nil { 211 // no response 212 if err == errs.ErrServerNoResponse { 213 return nil, err 214 } 215 // failed to handle, should respond to client with error code, 216 // ignore rspBody. 217 report.ServiceHandleFail.Incr() 218 return s.encode(ctx, msg, nil, err) 219 } 220 return s.handleResponse(ctx, msg, rspbody) 221 } 222 223 // HandleClose is called when conn is closed. 224 // Currently, only used for server stream. 225 func (s *service) HandleClose(ctx context.Context) error { 226 if codec.Message(ctx).ServerRspErr() != nil && s.opts.StreamHandle != nil { 227 _, err := s.opts.StreamHandle.StreamHandleFunc(ctx, nil, nil, nil) 228 return err 229 } 230 return nil 231 } 232 233 func (s *service) encode(ctx context.Context, msg codec.Msg, rspBodyBuf []byte, e error) (rspBuf []byte, err error) { 234 if e != nil { 235 log.DebugContextf( 236 ctx, 237 "service: %s handle err (if caused by health checking, this error can be ignored): %+v", 238 s.opts.ServiceName, e) 239 msg.WithServerRspErr(e) 240 } 241 242 rspBuf, err = s.opts.Codec.Encode(msg, rspBodyBuf) 243 if err != nil { 244 report.ServiceCodecEncodeFail.Incr() 245 log.ErrorContextf(ctx, "service:%s encode fail:%v", s.opts.ServiceName, err) 246 return nil, err 247 } 248 return rspBuf, nil 249 } 250 251 // handleStream handles server stream. 
252 func (s *service) handleStream(ctx context.Context, msg codec.Msg, reqBuf []byte, sh StreamHandler, 253 opts *Options) (resbody interface{}, err error) { 254 if s.opts.StreamHandle != nil { 255 si := s.streamInfo[msg.ServerRPCName()] 256 return s.opts.StreamHandle.StreamHandleFunc(ctx, sh, si, reqBuf) 257 } 258 return nil, errs.NewFrameError(errs.RetServerNoService, "Stream method no Handle") 259 } 260 261 func (s *service) decode(ctx context.Context, msg codec.Msg, reqBuf []byte) ([]byte, error) { 262 s.setOpt(msg) 263 reqBodyBuf, err := s.opts.Codec.Decode(msg, reqBuf) 264 if err != nil { 265 report.ServiceCodecDecodeFail.Incr() 266 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Decode: "+err.Error()) 267 } 268 269 // call setOpt again to avoid some msg infos (namespace, env name, etc.) 270 // being modified by request decoding. 271 s.setOpt(msg) 272 return reqBodyBuf, nil 273 } 274 275 func (s *service) setOpt(msg codec.Msg) { 276 msg.WithNamespace(s.opts.Namespace) // service namespace 277 msg.WithEnvName(s.opts.EnvName) // service environment 278 msg.WithSetName(s.opts.SetName) // service "Set" 279 msg.WithCalleeServiceName(s.opts.ServiceName) // from perspective of the service, callee refers to itself 280 } 281 282 func (s *service) handle(ctx context.Context, msg codec.Msg, reqBodyBuf []byte) (interface{}, error) { 283 // whether is server streaming RPC 284 streamHandler, ok := s.streamHandlers[msg.ServerRPCName()] 285 if ok { 286 return s.handleStream(ctx, msg, reqBodyBuf, streamHandler, s.opts) 287 } 288 handler, ok := s.handlers[msg.ServerRPCName()] 289 if !ok { 290 handler, ok = s.handlers["*"] // wildcard 291 if !ok { 292 report.ServiceHandleRPCNameInvalid.Incr() 293 return nil, errs.NewFrameError(errs.RetServerNoFunc, 294 fmt.Sprintf("service handle: rpc name %s invalid, current service:%s", 295 msg.ServerRPCName(), msg.CalleeServiceName())) 296 } 297 } 298 299 var fixTimeout filter.ServerFilter 300 if s.opts.Timeout > 0 { 
301 fixTimeout = mayConvert2NormalTimeout 302 } 303 timeout := s.opts.Timeout 304 if msg.RequestTimeout() > 0 && !s.opts.DisableRequestTimeout { // 可以配置禁用 305 if msg.RequestTimeout() < timeout || timeout == 0 { // 取最小值 306 fixTimeout = mayConvert2FullLinkTimeout 307 timeout = msg.RequestTimeout() 308 } 309 } 310 if timeout > 0 { 311 var cancel context.CancelFunc 312 ctx, cancel = context.WithTimeout(ctx, timeout) 313 defer cancel() 314 } 315 newFilterFunc := s.filterFunc(ctx, msg, reqBodyBuf, fixTimeout) 316 rspBody, err := handler(ctx, newFilterFunc) 317 if err != nil { 318 if e, ok := err.(*errs.Error); ok && 319 e.Type == errs.ErrorTypeFramework && 320 e.Code == errs.RetServerFullLinkTimeout { 321 err = errs.ErrServerNoResponse 322 } 323 return nil, err 324 } 325 if msg.CallType() == codec.SendOnly { 326 return nil, errs.ErrServerNoResponse 327 } 328 return rspBody, nil 329 } 330 331 // handleResponse handles response. 332 // serialization type is set to msg.SerializationType() by default, 333 // if serialization type Option is called, serialization type is set by the Option. 334 // compress type's setting is similar to it. 335 func (s *service) handleResponse(ctx context.Context, msg codec.Msg, rspBody interface{}) ([]byte, error) { 336 // marshal response body 337 338 serializationType := msg.SerializationType() 339 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 340 serializationType = s.opts.CurrentSerializationType 341 } 342 span := rpcz.SpanFromContext(ctx) 343 344 _, end := span.NewChild("Marshal") 345 rspBodyBuf, err := codec.Marshal(serializationType, rspBody) 346 end.End() 347 348 if err != nil { 349 report.ServiceCodecMarshalFail.Incr() 350 err = errs.NewFrameError(errs.RetServerEncodeFail, "service codec Marshal: "+err.Error()) 351 // rspBodyBuf will be nil if marshalling fails, respond only error code to client. 
352 return s.encode(ctx, msg, rspBodyBuf, err) 353 } 354 355 // compress response body 356 compressType := msg.CompressType() 357 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 358 compressType = s.opts.CurrentCompressType 359 } 360 361 _, end = span.NewChild("Compress") 362 rspBodyBuf, err = codec.Compress(compressType, rspBodyBuf) 363 end.End() 364 365 if err != nil { 366 report.ServiceCodecCompressFail.Incr() 367 err = errs.NewFrameError(errs.RetServerEncodeFail, "service codec Compress: "+err.Error()) 368 // rspBodyBuf will be nil if compression fails, respond only error code to client. 369 return s.encode(ctx, msg, rspBodyBuf, err) 370 } 371 372 _, end = span.NewChild("EncodeProtocolHead") 373 rspBuf, err := s.encode(ctx, msg, rspBodyBuf, nil) 374 end.End() 375 376 return rspBuf, err 377 } 378 379 // filterFunc returns a FilterFunc, which would be passed to server stub to access pre/post filter handling. 380 func (s *service) filterFunc( 381 ctx context.Context, 382 msg codec.Msg, 383 reqBodyBuf []byte, 384 fixTimeout filter.ServerFilter, 385 ) FilterFunc { 386 // Decompression, serialization of request body are put into a closure. 387 // Both serialization type & compress type can be set. 388 // serialization type is set to msg.SerializationType() by default, 389 // if serialization type Option is called, serialization type is set by the Option. 390 // compress type's setting is similar to it. 
391 return func(reqBody interface{}) (filter.ServerChain, error) { 392 // decompress request body 393 compressType := msg.CompressType() 394 if icodec.IsValidCompressType(s.opts.CurrentCompressType) { 395 compressType = s.opts.CurrentCompressType 396 } 397 span := rpcz.SpanFromContext(ctx) 398 _, end := span.NewChild("Decompress") 399 reqBodyBuf, err := codec.Decompress(compressType, reqBodyBuf) 400 end.End() 401 if err != nil { 402 report.ServiceCodecDecompressFail.Incr() 403 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Decompress: "+err.Error()) 404 } 405 406 // unmarshal request body 407 serializationType := msg.SerializationType() 408 if icodec.IsValidSerializationType(s.opts.CurrentSerializationType) { 409 serializationType = s.opts.CurrentSerializationType 410 } 411 _, end = span.NewChild("Unmarshal") 412 err = codec.Unmarshal(serializationType, reqBodyBuf, reqBody) 413 end.End() 414 if err != nil { 415 report.ServiceCodecUnmarshalFail.Incr() 416 return nil, errs.NewFrameError(errs.RetServerDecodeFail, "service codec Unmarshal: "+err.Error()) 417 } 418 419 if fixTimeout != nil { 420 // this heap allocation cannot be avoided unless we change the generated xxx.trpc.go. 421 filters := make(filter.ServerChain, len(s.opts.Filters), len(s.opts.Filters)+1) 422 copy(filters, s.opts.Filters) 423 return append(filters, fixTimeout), nil 424 } 425 return s.opts.Filters, nil 426 } 427 } 428 429 // Register implements Service interface, registering a proto service impl for the service. 430 func (s *service) Register(serviceDesc interface{}, serviceImpl interface{}) error { 431 desc, ok := serviceDesc.(*ServiceDesc) 432 if !ok { 433 return errors.New("serviceDesc is not *ServiceDesc") 434 } 435 if desc.StreamHandle != nil { 436 s.opts.StreamHandle = desc.StreamHandle 437 if s.opts.StreamTransport != nil { 438 s.opts.Transport = s.opts.StreamTransport 439 } 440 // IdleTimeout is not used by server stream, set it to 0. 
441 s.opts.ServeOptions = append(s.opts.ServeOptions, transport.WithServerIdleTimeout(0)) 442 err := s.opts.StreamHandle.Init(s.opts) 443 if err != nil { 444 return err 445 } 446 } 447 448 if serviceImpl != nil { 449 ht := reflect.TypeOf(desc.HandlerType).Elem() 450 hi := reflect.TypeOf(serviceImpl) 451 if !hi.Implements(ht) { 452 return fmt.Errorf("%s not implements interface %s", hi.String(), ht.String()) 453 } 454 } 455 456 var bindings []*restful.Binding 457 for _, method := range desc.Methods { 458 n := method.Name 459 if _, ok := s.handlers[n]; ok { 460 return fmt.Errorf("duplicate method name: %s", n) 461 } 462 h := method.Func 463 s.handlers[n] = func(ctx context.Context, f FilterFunc) (rsp interface{}, err error) { 464 return h(serviceImpl, ctx, f) 465 } 466 bindings = append(bindings, method.Bindings...) 467 } 468 469 for _, stream := range desc.Streams { 470 n := stream.StreamName 471 if _, ok := s.streamHandlers[n]; ok { 472 return fmt.Errorf("duplicate stream name: %s", n) 473 } 474 h := stream.Handler 475 s.streamInfo[stream.StreamName] = &StreamServerInfo{ 476 FullMethod: stream.StreamName, 477 IsClientStream: stream.ClientStreams, 478 IsServerStream: stream.ServerStreams, 479 } 480 s.streamHandlers[stream.StreamName] = func(stream Stream) error { 481 return h(serviceImpl, stream) 482 } 483 } 484 return s.createOrUpdateRouter(bindings, serviceImpl) 485 } 486 487 func (s *service) createOrUpdateRouter(bindings []*restful.Binding, serviceImpl interface{}) error { 488 // If pb option (trpc.api.http) is set,creates a RESTful Router. 489 if len(bindings) == 0 { 490 return nil 491 } 492 handler := restful.GetRouter(s.opts.ServiceName) 493 if handler != nil { 494 if router, ok := handler.(*restful.Router); ok { // A router has already been registered. 495 for _, binding := range bindings { // Add binding with a specified service implementation. 
496 if err := router.AddImplBinding(binding, serviceImpl); err != nil { 497 return fmt.Errorf("add impl binding during service registration: %w", err) 498 } 499 } 500 return nil 501 } 502 } 503 // This is the first time of registering the service router, create a new one. 504 router := restful.NewRouter(append(s.opts.RESTOptions, 505 restful.WithNamespace(s.opts.Namespace), 506 restful.WithEnvironment(s.opts.EnvName), 507 restful.WithContainer(s.opts.container), 508 restful.WithSet(s.opts.SetName), 509 restful.WithServiceName(s.opts.ServiceName), 510 restful.WithTimeout(s.opts.Timeout), 511 restful.WithFilterFunc(func() filter.ServerChain { return s.opts.Filters }))...) 512 for _, binding := range bindings { 513 if err := router.AddImplBinding(binding, serviceImpl); err != nil { 514 return err 515 } 516 } 517 restful.RegisterRouter(s.opts.ServiceName, router) 518 return nil 519 } 520 521 // Close closes the service,registry.Deregister will be called. 522 func (s *service) Close(ch chan struct{}) error { 523 pid := os.Getpid() 524 if ch == nil { 525 ch = make(chan struct{}, 1) 526 } 527 log.Infof("process:%d, %s service:%s, closing ...", pid, s.opts.protocol, s.opts.ServiceName) 528 529 if s.opts.Registry != nil { 530 // When it comes to graceful restart, the parent process will not call registry Deregister(), 531 // while the child process would call registry Deregister(). 532 if isGraceful, isParental := checkProcessStatus(); !(isGraceful && isParental) { 533 if err := s.opts.Registry.Deregister(s.opts.ServiceName); err != nil { 534 log.Errorf("process:%d, deregister service:%s fail:%v", pid, s.opts.ServiceName, err) 535 } 536 } 537 } 538 if remains := s.waitBeforeClose(); remains > 0 { 539 log.Infof("process %d service %s remains %d requests before close", 540 os.Getpid(), s.opts.ServiceName, remains) 541 } 542 543 // this will cancel all children ctx. 
544 s.cancel() 545 546 timeout := time.Millisecond * 300 547 if s.opts.Timeout > timeout { // use the larger one 548 timeout = s.opts.Timeout 549 } 550 if remains := s.waitInactive(timeout); remains > 0 { 551 log.Infof("process %d service %s remains %d requests after close", 552 os.Getpid(), s.opts.ServiceName, remains) 553 } 554 log.Infof("process:%d, %s service:%s, closed", pid, s.opts.protocol, s.opts.ServiceName) 555 ch <- struct{}{} 556 return nil 557 } 558 559 func (s *service) waitBeforeClose() int64 { 560 closeWaitTime := s.opts.CloseWaitTime 561 if closeWaitTime > MaxCloseWaitTime { 562 closeWaitTime = MaxCloseWaitTime 563 } 564 if closeWaitTime > 0 { 565 // After registry.Deregister() is called, sleep a while to let Naming Service (like Polaris) finish 566 // updating instance ip list. 567 // Otherwise, client request would still arrive while the service had already been closed (Typically, it occurs 568 // when k8s updates pods). 569 log.Infof("process %d service %s remain %d requests wait %v time when closing service", 570 os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount), closeWaitTime) 571 time.Sleep(closeWaitTime) 572 } 573 return s.waitInactive(s.opts.MaxCloseWaitTime - closeWaitTime) 574 } 575 576 func (s *service) waitInactive(maxWaitTime time.Duration) int64 { 577 const sleepTime = 100 * time.Millisecond 578 for start := time.Now(); time.Since(start) < maxWaitTime; time.Sleep(sleepTime) { 579 if atomic.LoadInt64(&s.activeCount) <= 0 { 580 return 0 581 } 582 } 583 return atomic.LoadInt64(&s.activeCount) 584 } 585 586 func checkProcessStatus() (isGracefulRestart, isParentalProcess bool) { 587 v := os.Getenv(transport.EnvGraceRestartPPID) 588 if v == "" { 589 return false, true 590 } 591 592 ppid, err := strconv.Atoi(v) 593 if err != nil { 594 return false, false 595 } 596 return true, ppid == os.Getpid() 597 } 598 599 func defaultOptions() *Options { 600 const ( 601 invalidSerializationType = -1 602 invalidCompressType = -1 603 ) 604 
return &Options{ 605 protocol: "unknown-protocol", 606 ServiceName: "empty-name", 607 CurrentSerializationType: invalidSerializationType, 608 CurrentCompressType: invalidCompressType, 609 } 610 }