github.com/cloudwego/kitex@v0.9.0/server/server.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Package server . 18 package server 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "net" 25 "reflect" 26 "runtime/debug" 27 "sync" 28 "time" 29 30 "github.com/cloudwego/localsession/backup" 31 32 internal_server "github.com/cloudwego/kitex/internal/server" 33 "github.com/cloudwego/kitex/pkg/acl" 34 "github.com/cloudwego/kitex/pkg/diagnosis" 35 "github.com/cloudwego/kitex/pkg/discovery" 36 "github.com/cloudwego/kitex/pkg/endpoint" 37 "github.com/cloudwego/kitex/pkg/gofunc" 38 "github.com/cloudwego/kitex/pkg/kerrors" 39 "github.com/cloudwego/kitex/pkg/klog" 40 "github.com/cloudwego/kitex/pkg/limiter" 41 "github.com/cloudwego/kitex/pkg/registry" 42 "github.com/cloudwego/kitex/pkg/remote" 43 "github.com/cloudwego/kitex/pkg/remote/bound" 44 "github.com/cloudwego/kitex/pkg/remote/remotesvr" 45 "github.com/cloudwego/kitex/pkg/rpcinfo" 46 "github.com/cloudwego/kitex/pkg/serviceinfo" 47 "github.com/cloudwego/kitex/pkg/stats" 48 ) 49 50 // Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service 51 // registered to it. 52 type Server interface { 53 RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error 54 GetServiceInfos() map[string]*serviceinfo.ServiceInfo 55 Run() error 56 Stop() error 57 } 58 59 type server struct { 60 opt *internal_server.Options 61 svcs *services 62 targetSvcInfo *serviceinfo.ServiceInfo 63 64 // actual rpc service implement of biz 65 eps endpoint.Endpoint 66 mws []endpoint.Middleware 67 svr remotesvr.Server 68 stopped sync.Once 69 isRun bool 70 71 sync.Mutex 72 } 73 74 // NewServer creates a server with the given Options. 75 func NewServer(ops ...Option) Server { 76 s := &server{ 77 opt: internal_server.NewOptions(ops), 78 svcs: newServices(), 79 } 80 s.init() 81 return s 82 } 83 84 func (s *server) init() { 85 ctx := fillContext(s.opt) 86 if s.opt.EnableContextTimeout { 87 // prepend for adding timeout to the context for all middlewares and the handler 88 s.opt.MWBs = append([]endpoint.MiddlewareBuilder{serverTimeoutMW}, s.opt.MWBs...) 89 } 90 s.mws = richMWsWithBuilder(ctx, s.opt.MWBs, s) 91 s.mws = append(s.mws, acl.NewACLMiddleware(s.opt.ACLRules)) 92 s.initStreamMiddlewares(ctx) 93 if s.opt.ErrHandle != nil { 94 // errorHandleMW must be the last middleware, 95 // to ensure it only catches the server handler's error. 96 s.mws = append(s.mws, newErrorHandleMW(s.opt.ErrHandle)) 97 } 98 if ds := s.opt.DebugService; ds != nil { 99 ds.RegisterProbeFunc(diagnosis.OptionsKey, diagnosis.WrapAsProbeFunc(s.opt.DebugInfo)) 100 ds.RegisterProbeFunc(diagnosis.ChangeEventsKey, s.opt.Events.Dump) 101 } 102 backup.Init(s.opt.BackupOpt) 103 s.buildInvokeChain() 104 s.buildStreamInvokeChain() 105 } 106 107 func (s *server) Endpoints() endpoint.Endpoint { 108 return s.eps 109 } 110 111 func (s *server) SetEndpoints(e endpoint.Endpoint) { 112 s.eps = e 113 } 114 115 func (s *server) Option() *internal_server.Options { 116 return s.opt 117 } 118 119 func fillContext(opt *internal_server.Options) context.Context { 120 ctx := context.Background() 121 ctx = context.WithValue(ctx, endpoint.CtxEventBusKey, opt.Bus) 122 ctx = context.WithValue(ctx, endpoint.CtxEventQueueKey, opt.Events) 123 return ctx 124 } 125 126 func richMWsWithBuilder(ctx context.Context, mwBs []endpoint.MiddlewareBuilder, ks *server) []endpoint.Middleware { 127 for i := range mwBs { 128 ks.mws = append(ks.mws, mwBs[i](ctx)) 129 } 130 return ks.mws 131 } 132 133 // newErrorHandleMW provides a hook point for server error handling. 134 func newErrorHandleMW(errHandle func(context.Context, error) error) endpoint.Middleware { 135 return func(next endpoint.Endpoint) endpoint.Endpoint { 136 return func(ctx context.Context, request, response interface{}) error { 137 err := next(ctx, request, response) 138 if err == nil { 139 return nil 140 } 141 return errHandle(ctx, err) 142 } 143 } 144 } 145 146 func (s *server) initOrResetRPCInfoFunc() func(rpcinfo.RPCInfo, net.Addr) rpcinfo.RPCInfo { 147 return func(ri rpcinfo.RPCInfo, rAddr net.Addr) rpcinfo.RPCInfo { 148 // Reset existing rpcinfo to improve performance for long connections (PR #584). 149 if ri != nil && rpcinfo.PoolEnabled() { 150 fi := rpcinfo.AsMutableEndpointInfo(ri.From()) 151 fi.Reset() 152 fi.SetAddress(rAddr) 153 rpcinfo.AsMutableEndpointInfo(ri.To()).ResetFromBasicInfo(s.opt.Svr) 154 if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { 155 setter.Reset() 156 } 157 rpcinfo.AsMutableRPCConfig(ri.Config()).CopyFrom(s.opt.Configs) 158 rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) 159 rpcStats.Reset() 160 if s.opt.StatsLevel != nil { 161 rpcStats.SetLevel(*s.opt.StatsLevel) 162 } 163 return ri 164 } 165 166 // allocate a new rpcinfo if it's the connection's first request or rpcInfoPool is disabled 167 rpcStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats()) 168 if s.opt.StatsLevel != nil { 169 rpcStats.SetLevel(*s.opt.StatsLevel) 170 } 171 172 // Export read-only views to external users and keep a mapping for internal users. 173 ri = rpcinfo.NewRPCInfo( 174 rpcinfo.EmptyEndpointInfo(), 175 rpcinfo.FromBasicInfo(s.opt.Svr), 176 rpcinfo.NewServerInvocation(), 177 rpcinfo.AsMutableRPCConfig(s.opt.Configs).Clone().ImmutableView(), 178 rpcStats.ImmutableView(), 179 ) 180 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(rAddr) 181 return ri 182 } 183 } 184 185 func (s *server) buildInvokeChain() { 186 innerHandlerEp := s.invokeHandleEndpoint() 187 s.eps = endpoint.Chain(s.mws...)(innerHandlerEp) 188 } 189 190 // RegisterService should not be called by users directly. 191 func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error { 192 s.Lock() 193 defer s.Unlock() 194 if s.isRun { 195 panic("service cannot be registered while server is running") 196 } 197 if svcInfo == nil { 198 panic("svcInfo is nil. please specify non-nil svcInfo") 199 } 200 if handler == nil || reflect.ValueOf(handler).IsNil() { 201 panic("handler is nil. please specify non-nil handler") 202 } 203 if s.svcs.svcMap[svcInfo.ServiceName] != nil { 204 panic(fmt.Sprintf("Service[%s] is already defined", svcInfo.ServiceName)) 205 } 206 207 registerOpts := internal_server.NewRegisterOptions(opts) 208 if err := s.svcs.addService(svcInfo, handler, registerOpts); err != nil { 209 panic(err.Error()) 210 } 211 return nil 212 } 213 214 func (s *server) GetServiceInfos() map[string]*serviceinfo.ServiceInfo { 215 return s.svcs.getSvcInfoSearchMap() 216 } 217 218 // Run runs the server. 219 func (s *server) Run() (err error) { 220 s.Lock() 221 s.isRun = true 222 s.Unlock() 223 if err = s.check(); err != nil { 224 return err 225 } 226 s.findAndSetDefaultService() 227 diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.ServiceInfosKey, diagnosis.WrapAsProbeFunc(s.svcs.getSvcInfoMap())) 228 if s.svcs.fallbackSvc != nil { 229 diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.FallbackServiceKey, diagnosis.WrapAsProbeFunc(s.svcs.fallbackSvc.svcInfo.ServiceName)) 230 } 231 svrCfg := s.opt.RemoteOpt 232 addr := svrCfg.Address // should not be nil 233 if s.opt.Proxy != nil { 234 svrCfg.Address, err = s.opt.Proxy.Replace(addr) 235 if err != nil { 236 return 237 } 238 } 239 240 s.fillMoreServiceInfo(s.opt.RemoteOpt.Address) 241 s.richRemoteOption() 242 transHdlr, err := s.newSvrTransHandler() 243 if err != nil { 244 return err 245 } 246 s.Lock() 247 s.svr, err = remotesvr.NewServer(s.opt.RemoteOpt, s.eps, transHdlr) 248 s.Unlock() 249 if err != nil { 250 return err 251 } 252 253 // start profiler 254 if s.opt.RemoteOpt.Profiler != nil { 255 gofunc.GoFunc(context.Background(), func() { 256 klog.Info("KITEX: server starting profiler") 257 err := s.opt.RemoteOpt.Profiler.Run(context.Background()) 258 if err != nil { 259 klog.Errorf("KITEX: server started profiler error: error=%s", err.Error()) 260 } 261 }) 262 } 263 264 errCh := s.svr.Start() 265 select { 266 case err = <-errCh: 267 klog.Errorf("KITEX: server start error: error=%s", err.Error()) 268 return err 269 default: 270 } 271 muStartHooks.Lock() 272 for i := range onServerStart { 273 go onServerStart[i]() 274 } 275 muStartHooks.Unlock() 276 s.Lock() 277 s.buildRegistryInfo(s.svr.Address()) 278 s.Unlock() 279 280 if err = s.waitExit(errCh); err != nil { 281 klog.Errorf("KITEX: received error and exit: error=%s", err.Error()) 282 } 283 if e := s.Stop(); e != nil && err == nil { 284 err = e 285 klog.Errorf("KITEX: stop server error: error=%s", e.Error()) 286 } 287 return 288 } 289 290 // Stop stops the server gracefully. 291 func (s *server) Stop() (err error) { 292 s.stopped.Do(func() { 293 s.Lock() 294 defer s.Unlock() 295 296 muShutdownHooks.Lock() 297 for i := range onShutdown { 298 onShutdown[i]() 299 } 300 muShutdownHooks.Unlock() 301 302 if s.opt.RegistryInfo != nil { 303 err = s.opt.Registry.Deregister(s.opt.RegistryInfo) 304 s.opt.RegistryInfo = nil 305 } 306 if s.svr != nil { 307 if e := s.svr.Stop(); e != nil { 308 err = e 309 } 310 s.svr = nil 311 } 312 }) 313 return 314 } 315 316 func (s *server) invokeHandleEndpoint() endpoint.Endpoint { 317 return func(ctx context.Context, args, resp interface{}) (err error) { 318 ri := rpcinfo.GetRPCInfo(ctx) 319 methodName := ri.Invocation().MethodName() 320 serviceName := ri.Invocation().ServiceName() 321 svc := s.svcs.svcMap[serviceName] 322 svcInfo := svc.svcInfo 323 if methodName == "" && svcInfo.ServiceName != serviceinfo.GenericService { 324 return errors.New("method name is empty in rpcinfo, should not happen") 325 } 326 defer func() { 327 if handlerErr := recover(); handlerErr != nil { 328 err = kerrors.ErrPanic.WithCauseAndStack( 329 fmt.Errorf( 330 "[happened in biz handler, method=%s.%s, please check the panic at the server side] %s", 331 svcInfo.ServiceName, methodName, handlerErr), 332 string(debug.Stack())) 333 rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) 334 rpcStats.SetPanicked(err) 335 } 336 rpcinfo.Record(ctx, ri, stats.ServerHandleFinish, err) 337 // clear session 338 backup.ClearCtx() 339 }() 340 implHandlerFunc := svcInfo.MethodInfo(methodName).Handler() 341 rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil) 342 // set session 343 backup.BackupCtx(ctx) 344 err = implHandlerFunc(ctx, svc.handler, args, resp) 345 if err != nil { 346 if bizErr, ok := kerrors.FromBizStatusError(err); ok { 347 if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { 348 setter.SetBizStatusErr(bizErr) 349 return nil 350 } 351 } 352 err = kerrors.ErrBiz.WithCause(err) 353 } 354 return err 355 } 356 } 357 358 func (s *server) initBasicRemoteOption() { 359 remoteOpt := s.opt.RemoteOpt 360 remoteOpt.TargetSvcInfo = s.targetSvcInfo 361 remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap() 362 remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName 363 remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc() 364 remoteOpt.TracerCtl = s.opt.TracerCtl 365 remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout() 366 } 367 368 func (s *server) richRemoteOption() { 369 s.initBasicRemoteOption() 370 371 s.addBoundHandlers(s.opt.RemoteOpt) 372 } 373 374 func (s *server) addBoundHandlers(opt *remote.ServerOption) { 375 // add profiler meta handler, which should be exec after other MetaHandlers 376 if opt.Profiler != nil && opt.ProfilerMessageTagging != nil { 377 s.opt.MetaHandlers = append(s.opt.MetaHandlers, 378 remote.NewProfilerMetaHandler(opt.Profiler, opt.ProfilerMessageTagging), 379 ) 380 } 381 // for server trans info handler 382 if len(s.opt.MetaHandlers) > 0 { 383 transInfoHdlr := bound.NewTransMetaHandler(s.opt.MetaHandlers) 384 // meta handler exec before boundHandlers which add with option 385 doAddBoundHandlerToHead(transInfoHdlr, opt) 386 for _, h := range s.opt.MetaHandlers { 387 if shdlr, ok := h.(remote.StreamingMetaHandler); ok { 388 opt.StreamingMetaHandlers = append(opt.StreamingMetaHandlers, shdlr) 389 } 390 } 391 } 392 393 limitHdlr := s.buildLimiterWithOpt() 394 if limitHdlr != nil { 395 doAddBoundHandler(limitHdlr, opt) 396 } 397 } 398 399 /* 400 * There are two times when the rate limiter can take effect for a non-multiplexed server, 401 * which are the OnRead and OnMessage callback. OnRead is called before request decoded 402 * and OnMessage is called after. 403 * Therefore, the optimization point is that we can make rate limiter take effect in OnRead as 404 * possible to save computational cost of decoding. 405 * The implementation is that when using the default rate limiter to launching a non-multiplexed 406 * service, use the `serverLimiterOnReadHandler` whose rate limiting takes effect in the OnRead 407 * callback. 408 */ 409 func (s *server) buildLimiterWithOpt() (handler remote.InboundHandler) { 410 limits := s.opt.Limit.Limits 411 connLimit := s.opt.Limit.ConLimit 412 qpsLimit := s.opt.Limit.QPSLimit 413 if limits == nil && connLimit == nil && qpsLimit == nil { 414 return 415 } 416 417 if connLimit == nil { 418 if limits != nil { 419 connLimit = limiter.NewConnectionLimiter(limits.MaxConnections) 420 } else { 421 connLimit = &limiter.DummyConcurrencyLimiter{} 422 } 423 } 424 425 if qpsLimit == nil { 426 if limits != nil { 427 interval := time.Millisecond * 100 // FIXME: should not care this implementation-specific parameter 428 qpsLimit = limiter.NewQPSLimiter(interval, limits.MaxQPS) 429 } else { 430 qpsLimit = &limiter.DummyRateLimiter{} 431 } 432 } else { 433 s.opt.Limit.QPSLimitPostDecode = true 434 } 435 436 if limits != nil && limits.UpdateControl != nil { 437 updater := limiter.NewLimiterWrapper(connLimit, qpsLimit) 438 limits.UpdateControl(updater) 439 } 440 441 handler = bound.NewServerLimiterHandler(connLimit, qpsLimit, s.opt.Limit.LimitReporter, s.opt.Limit.QPSLimitPostDecode) 442 // TODO: gRPC limiter 443 return 444 } 445 446 func (s *server) check() error { 447 if len(s.svcs.svcMap) == 0 { 448 return errors.New("run: no service. Use RegisterService to set one") 449 } 450 return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName) 451 } 452 453 func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) { 454 add := false 455 if ih, ok := h.(remote.InboundHandler); ok { 456 handlers := []remote.InboundHandler{ih} 457 opt.Inbounds = append(handlers, opt.Inbounds...) 458 add = true 459 } 460 if oh, ok := h.(remote.OutboundHandler); ok { 461 handlers := []remote.OutboundHandler{oh} 462 opt.Outbounds = append(handlers, opt.Outbounds...) 463 add = true 464 } 465 if !add { 466 panic("invalid BoundHandler: must implement InboundHandler or OutboundHandler") 467 } 468 } 469 470 func doAddBoundHandler(h remote.BoundHandler, opt *remote.ServerOption) { 471 add := false 472 if ih, ok := h.(remote.InboundHandler); ok { 473 opt.Inbounds = append(opt.Inbounds, ih) 474 add = true 475 } 476 if oh, ok := h.(remote.OutboundHandler); ok { 477 opt.Outbounds = append(opt.Outbounds, oh) 478 add = true 479 } 480 if !add { 481 panic("invalid BoundHandler: must implement InboundHandler or OutboundHandler") 482 } 483 } 484 485 func (s *server) newSvrTransHandler() (handler remote.ServerTransHandler, err error) { 486 transHdlrFactory := s.opt.RemoteOpt.SvrHandlerFactory 487 transHdlr, err := transHdlrFactory.NewTransHandler(s.opt.RemoteOpt) 488 if err != nil { 489 return nil, err 490 } 491 if setter, ok := transHdlr.(remote.InvokeHandleFuncSetter); ok { 492 setter.SetInvokeHandleFunc(s.eps) 493 } 494 transPl := remote.NewTransPipeline(transHdlr) 495 496 for _, ib := range s.opt.RemoteOpt.Inbounds { 497 transPl.AddInboundHandler(ib) 498 } 499 for _, ob := range s.opt.RemoteOpt.Outbounds { 500 transPl.AddOutboundHandler(ob) 501 } 502 return transPl, nil 503 } 504 505 func (s *server) buildRegistryInfo(lAddr net.Addr) { 506 if s.opt.RegistryInfo == nil { 507 s.opt.RegistryInfo = ®istry.Info{} 508 } 509 info := s.opt.RegistryInfo 510 // notice: lAddr may be nil when listen failed 511 info.Addr = lAddr 512 if info.ServiceName == "" { 513 info.ServiceName = s.opt.Svr.ServiceName 514 } 515 if info.PayloadCodec == "" { 516 info.PayloadCodec = getDefaultSvcInfo(s.svcs).PayloadCodec.String() 517 } 518 if info.Weight == 0 { 519 info.Weight = discovery.DefaultWeight 520 } 521 if info.Tags == nil { 522 info.Tags = s.opt.Svr.Tags 523 } 524 } 525 526 func (s *server) fillMoreServiceInfo(lAddr net.Addr) { 527 for _, svc := range s.svcs.svcMap { 528 ni := *svc.svcInfo 529 si := &ni 530 extra := make(map[string]interface{}, len(si.Extra)+2) 531 for k, v := range si.Extra { 532 extra[k] = v 533 } 534 extra["address"] = lAddr 535 extra["transports"] = s.opt.SupportedTransportsFunc(*s.opt.RemoteOpt) 536 si.Extra = extra 537 svc.svcInfo = si 538 } 539 } 540 541 func (s *server) waitExit(errCh chan error) error { 542 exitSignal := s.opt.ExitSignal() 543 544 // service may not be available as soon as startup. 545 delayRegister := time.After(1 * time.Second) 546 for { 547 select { 548 case err := <-exitSignal: 549 return err 550 case err := <-errCh: 551 return err 552 case <-delayRegister: 553 s.Lock() 554 if err := s.opt.Registry.Register(s.opt.RegistryInfo); err != nil { 555 s.Unlock() 556 return err 557 } 558 s.Unlock() 559 } 560 } 561 } 562 563 func (s *server) findAndSetDefaultService() { 564 if len(s.svcs.svcMap) == 1 { 565 s.targetSvcInfo = getDefaultSvcInfo(s.svcs) 566 } 567 } 568 569 // getDefaultSvc is used to get one ServiceInfo from map 570 func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo { 571 if len(svcs.svcMap) > 1 && svcs.fallbackSvc != nil { 572 return svcs.fallbackSvc.svcInfo 573 } 574 for _, svc := range svcs.svcMap { 575 return svc.svcInfo 576 } 577 return nil 578 } 579 580 func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error { 581 if refuseTrafficWithoutServiceName { 582 return nil 583 } 584 for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap { 585 if !hasFallbackSvc { 586 return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name) 587 } 588 } 589 return nil 590 }