
     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
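// Package server implements the Kitex RPC server: service registration, middleware and invoke-chain construction, and the Run/Stop lifecycle.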
    18  package server
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"reflect"
    26  	"runtime/debug"
    27  	"sync"
    28  	"time"
    30  	""
    32  	internal_server ""
    33  	""
    34  	""
    35  	""
    36  	""
    37  	""
    38  	""
    39  	""
    40  	""
    41  	""
    42  	""
    43  	""
    44  	""
    45  	""
    46  	""
    47  	""
    48  )
// Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service
// registered to it.
type Server interface {
	// RegisterService registers svcInfo together with the handler that implements it.
	// The built-in implementation only accepts registrations before Run is called.
	RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error
	// GetServiceInfos returns the ServiceInfos registered on this server.
	GetServiceInfos() map[string]*serviceinfo.ServiceInfo
	// Run starts serving. The built-in implementation blocks until an exit
	// signal or a server error is received.
	Run() error
	// Stop stops the server gracefully.
	Stop() error
}
// server is the built-in implementation of Server.
type server struct {
	// opt holds all server options.
	opt *internal_server.Options
	// svcs holds the registered services.
	svcs *services
	// targetSvcInfo is set by findAndSetDefaultService when exactly one service is registered.
	targetSvcInfo *serviceinfo.ServiceInfo

	// actual rpc service implement of biz
	// eps is the invoke chain: all middlewares wrapped around the biz handler endpoint.
	eps endpoint.Endpoint
	// mws is the ordered list of unary middlewares used to build eps.
	mws []endpoint.Middleware
	// svr is the underlying remote server, created in Run and cleared in Stop.
	svr remotesvr.Server
	// stopped ensures the body of Stop runs at most once.
	stopped sync.Once
	// isRun is set by Run; RegisterService panics once it is true.
	isRun bool

	// Mutex guards isRun, svr and the registry info.
	sync.Mutex
}
    74  // NewServer creates a server with the given Options.
    75  func NewServer(ops ...Option) Server {
    76  	s := &server{
    77  		opt:  internal_server.NewOptions(ops),
    78  		svcs: newServices(),
    79  	}
    80  	s.init()
    81  	return s
    82  }
// init wires up a freshly constructed server.
//
// It assembles the unary middleware list in this order:
//  1. the optional context-timeout middleware (prepended to s.opt.MWBs);
//  2. the middlewares produced by s.opt.MWBs;
//  3. the ACL middleware;
//  4. the optional error-handling middleware.
//
// It then initializes the stream middlewares and registers debug probes when
// a debug service is configured. After that it initializes the backup
// package. Finally it builds the unary and streaming invoke chains.
func (s *server) init() {
	// ctx carries the event bus and event queue to the middleware builders.
	ctx := fillContext(s.opt)
	if s.opt.EnableContextTimeout {
		// prepend for adding timeout to the context for all middlewares and the handler
		s.opt.MWBs = append([]endpoint.MiddlewareBuilder{serverTimeoutMW}, s.opt.MWBs...)
	}
	s.mws = richMWsWithBuilder(ctx, s.opt.MWBs, s)
	s.mws = append(s.mws, acl.NewACLMiddleware(s.opt.ACLRules))
	s.initStreamMiddlewares(ctx)
	if s.opt.ErrHandle != nil {
		// errorHandleMW must be the last middleware,
		// to ensure it only catches the server handler's error.
		s.mws = append(s.mws, newErrorHandleMW(s.opt.ErrHandle))
	}
	if ds := s.opt.DebugService; ds != nil {
		// Expose the options and the change-event log through the debug service.
		ds.RegisterProbeFunc(diagnosis.OptionsKey, diagnosis.WrapAsProbeFunc(s.opt.DebugInfo))
		ds.RegisterProbeFunc(diagnosis.ChangeEventsKey, s.opt.Events.Dump)
	}
	backup.Init(s.opt.BackupOpt)
	// The chains must be built last, after s.mws is complete.
	s.buildInvokeChain()
	s.buildStreamInvokeChain()
}
   107  func (s *server) Endpoints() endpoint.Endpoint {
   108  	return s.eps
   109  }
   111  func (s *server) SetEndpoints(e endpoint.Endpoint) {
   112  	s.eps = e
   113  }
   115  func (s *server) Option() *internal_server.Options {
   116  	return s.opt
   117  }
   119  func fillContext(opt *internal_server.Options) context.Context {
   120  	ctx := context.Background()
   121  	ctx = context.WithValue(ctx, endpoint.CtxEventBusKey, opt.Bus)
   122  	ctx = context.WithValue(ctx, endpoint.CtxEventQueueKey, opt.Events)
   123  	return ctx
   124  }
   126  func richMWsWithBuilder(ctx context.Context, mwBs []endpoint.MiddlewareBuilder, ks *server) []endpoint.Middleware {
   127  	for i := range mwBs {
   128  		ks.mws = append(ks.mws, mwBs[i](ctx))
   129  	}
   130  	return ks.mws
   131  }
   133  // newErrorHandleMW provides a hook point for server error handling.
   134  func newErrorHandleMW(errHandle func(context.Context, error) error) endpoint.Middleware {
   135  	return func(next endpoint.Endpoint) endpoint.Endpoint {
   136  		return func(ctx context.Context, request, response interface{}) error {
   137  			err := next(ctx, request, response)
   138  			if err == nil {
   139  				return nil
   140  			}
   141  			return errHandle(ctx, err)
   142  		}
   143  	}
   144  }
// initOrResetRPCInfoFunc returns the function used to obtain an RPCInfo for an
// incoming request whose remote address is rAddr.
//
// Reuse path: when ri is non-nil and the rpcinfo pool is enabled, ri is reset
// in place and returned. The reset covers the caller endpoint, the callee
// endpoint (re-filled from s.opt.Svr), the invocation (when it implements
// InvocationSetter), the config (copied from s.opt.Configs) and the stats.
//
// Allocation path: otherwise a fresh RPCInfo is allocated.
//
// In both paths the configured stats level, if any, is applied and the
// caller address is set to rAddr.
func (s *server) initOrResetRPCInfoFunc() func(rpcinfo.RPCInfo, net.Addr) rpcinfo.RPCInfo {
	return func(ri rpcinfo.RPCInfo, rAddr net.Addr) rpcinfo.RPCInfo {
		// Reset existing rpcinfo to improve performance for long connections (PR #584).
		if ri != nil && rpcinfo.PoolEnabled() {
			fi := rpcinfo.AsMutableEndpointInfo(ri.From())
			fi.Reset()
			fi.SetAddress(rAddr)
			rpcinfo.AsMutableEndpointInfo(ri.To()).ResetFromBasicInfo(s.opt.Svr)
			if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
				setter.Reset()
			}
			rpcinfo.AsMutableRPCConfig(ri.Config()).CopyFrom(s.opt.Configs)
			rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
			rpcStats.Reset()
			if s.opt.StatsLevel != nil {
				rpcStats.SetLevel(*s.opt.StatsLevel)
			}
			return ri
		}

		// allocate a new rpcinfo if it's the connection's first request or rpcInfoPool is disabled
		rpcStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats())
		if s.opt.StatsLevel != nil {
			rpcStats.SetLevel(*s.opt.StatsLevel)
		}

		// Export read-only views to external users and keep a mapping for internal users.
		// The config is cloned so per-request changes do not leak into s.opt.Configs.
		ri = rpcinfo.NewRPCInfo(
			rpcinfo.EmptyEndpointInfo(),
			rpcinfo.FromBasicInfo(s.opt.Svr),
			rpcinfo.NewServerInvocation(),
			rpcinfo.AsMutableRPCConfig(s.opt.Configs).Clone().ImmutableView(),
			rpcStats.ImmutableView(),
		)
		rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(rAddr)
		return ri
	}
}
   185  func (s *server) buildInvokeChain() {
   186  	innerHandlerEp := s.invokeHandleEndpoint()
   187  	s.eps = endpoint.Chain(s.mws...)(innerHandlerEp)
   188  }
   190  // RegisterService should not be called by users directly.
   191  func (s *server) RegisterService(svcInfo *serviceinfo.ServiceInfo, handler interface{}, opts ...RegisterOption) error {
   192  	s.Lock()
   193  	defer s.Unlock()
   194  	if s.isRun {
   195  		panic("service cannot be registered while server is running")
   196  	}
   197  	if svcInfo == nil {
   198  		panic("svcInfo is nil. please specify non-nil svcInfo")
   199  	}
   200  	if handler == nil || reflect.ValueOf(handler).IsNil() {
   201  		panic("handler is nil. please specify non-nil handler")
   202  	}
   203  	if s.svcs.svcMap[svcInfo.ServiceName] != nil {
   204  		panic(fmt.Sprintf("Service[%s] is already defined", svcInfo.ServiceName))
   205  	}
   207  	registerOpts := internal_server.NewRegisterOptions(opts)
   208  	if err := s.svcs.addService(svcInfo, handler, registerOpts); err != nil {
   209  		panic(err.Error())
   210  	}
   211  	return nil
   212  }
   214  func (s *server) GetServiceInfos() map[string]*serviceinfo.ServiceInfo {
   215  	return s.svcs.getSvcInfoSearchMap()
   216  }
   218  // Run runs the server.
   219  func (s *server) Run() (err error) {
   220  	s.Lock()
   221  	s.isRun = true
   222  	s.Unlock()
   223  	if err = s.check(); err != nil {
   224  		return err
   225  	}
   226  	s.findAndSetDefaultService()
   227  	diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.ServiceInfosKey, diagnosis.WrapAsProbeFunc(s.svcs.getSvcInfoMap()))
   228  	if s.svcs.fallbackSvc != nil {
   229  		diagnosis.RegisterProbeFunc(s.opt.DebugService, diagnosis.FallbackServiceKey, diagnosis.WrapAsProbeFunc(s.svcs.fallbackSvc.svcInfo.ServiceName))
   230  	}
   231  	svrCfg := s.opt.RemoteOpt
   232  	addr := svrCfg.Address // should not be nil
   233  	if s.opt.Proxy != nil {
   234  		svrCfg.Address, err = s.opt.Proxy.Replace(addr)
   235  		if err != nil {
   236  			return
   237  		}
   238  	}
   240  	s.fillMoreServiceInfo(s.opt.RemoteOpt.Address)
   241  	s.richRemoteOption()
   242  	transHdlr, err := s.newSvrTransHandler()
   243  	if err != nil {
   244  		return err
   245  	}
   246  	s.Lock()
   247  	s.svr, err = remotesvr.NewServer(s.opt.RemoteOpt, s.eps, transHdlr)
   248  	s.Unlock()
   249  	if err != nil {
   250  		return err
   251  	}
   253  	// start profiler
   254  	if s.opt.RemoteOpt.Profiler != nil {
   255  		gofunc.GoFunc(context.Background(), func() {
   256  			klog.Info("KITEX: server starting profiler")
   257  			err := s.opt.RemoteOpt.Profiler.Run(context.Background())
   258  			if err != nil {
   259  				klog.Errorf("KITEX: server started profiler error: error=%s", err.Error())
   260  			}
   261  		})
   262  	}
   264  	errCh := s.svr.Start()
   265  	select {
   266  	case err = <-errCh:
   267  		klog.Errorf("KITEX: server start error: error=%s", err.Error())
   268  		return err
   269  	default:
   270  	}
   271  	muStartHooks.Lock()
   272  	for i := range onServerStart {
   273  		go onServerStart[i]()
   274  	}
   275  	muStartHooks.Unlock()
   276  	s.Lock()
   277  	s.buildRegistryInfo(s.svr.Address())
   278  	s.Unlock()
   280  	if err = s.waitExit(errCh); err != nil {
   281  		klog.Errorf("KITEX: received error and exit: error=%s", err.Error())
   282  	}
   283  	if e := s.Stop(); e != nil && err == nil {
   284  		err = e
   285  		klog.Errorf("KITEX: stop server error: error=%s", e.Error())
   286  	}
   287  	return
   288  }
   290  // Stop stops the server gracefully.
   291  func (s *server) Stop() (err error) {
   292  	s.stopped.Do(func() {
   293  		s.Lock()
   294  		defer s.Unlock()
   296  		muShutdownHooks.Lock()
   297  		for i := range onShutdown {
   298  			onShutdown[i]()
   299  		}
   300  		muShutdownHooks.Unlock()
   302  		if s.opt.RegistryInfo != nil {
   303  			err = s.opt.Registry.Deregister(s.opt.RegistryInfo)
   304  			s.opt.RegistryInfo = nil
   305  		}
   306  		if s.svr != nil {
   307  			if e := s.svr.Stop(); e != nil {
   308  				err = e
   309  			}
   310  			s.svr = nil
   311  		}
   312  	})
   313  	return
   314  }
   316  func (s *server) invokeHandleEndpoint() endpoint.Endpoint {
   317  	return func(ctx context.Context, args, resp interface{}) (err error) {
   318  		ri := rpcinfo.GetRPCInfo(ctx)
   319  		methodName := ri.Invocation().MethodName()
   320  		serviceName := ri.Invocation().ServiceName()
   321  		svc := s.svcs.svcMap[serviceName]
   322  		svcInfo := svc.svcInfo
   323  		if methodName == "" && svcInfo.ServiceName != serviceinfo.GenericService {
   324  			return errors.New("method name is empty in rpcinfo, should not happen")
   325  		}
   326  		defer func() {
   327  			if handlerErr := recover(); handlerErr != nil {
   328  				err = kerrors.ErrPanic.WithCauseAndStack(
   329  					fmt.Errorf(
   330  						"[happened in biz handler, method=%s.%s, please check the panic at the server side] %s",
   331  						svcInfo.ServiceName, methodName, handlerErr),
   332  					string(debug.Stack()))
   333  				rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
   334  				rpcStats.SetPanicked(err)
   335  			}
   336  			rpcinfo.Record(ctx, ri, stats.ServerHandleFinish, err)
   337  			// clear session
   338  			backup.ClearCtx()
   339  		}()
   340  		implHandlerFunc := svcInfo.MethodInfo(methodName).Handler()
   341  		rpcinfo.Record(ctx, ri, stats.ServerHandleStart, nil)
   342  		// set session
   343  		backup.BackupCtx(ctx)
   344  		err = implHandlerFunc(ctx, svc.handler, args, resp)
   345  		if err != nil {
   346  			if bizErr, ok := kerrors.FromBizStatusError(err); ok {
   347  				if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
   348  					setter.SetBizStatusErr(bizErr)
   349  					return nil
   350  				}
   351  			}
   352  			err = kerrors.ErrBiz.WithCause(err)
   353  		}
   354  		return err
   355  	}
   356  }
   358  func (s *server) initBasicRemoteOption() {
   359  	remoteOpt := s.opt.RemoteOpt
   360  	remoteOpt.TargetSvcInfo = s.targetSvcInfo
   361  	remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap()
   362  	remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName
   363  	remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc()
   364  	remoteOpt.TracerCtl = s.opt.TracerCtl
   365  	remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout()
   366  }
   368  func (s *server) richRemoteOption() {
   369  	s.initBasicRemoteOption()
   371  	s.addBoundHandlers(s.opt.RemoteOpt)
   372  }
   374  func (s *server) addBoundHandlers(opt *remote.ServerOption) {
   375  	// add profiler meta handler, which should be exec after other MetaHandlers
   376  	if opt.Profiler != nil && opt.ProfilerMessageTagging != nil {
   377  		s.opt.MetaHandlers = append(s.opt.MetaHandlers,
   378  			remote.NewProfilerMetaHandler(opt.Profiler, opt.ProfilerMessageTagging),
   379  		)
   380  	}
   381  	// for server trans info handler
   382  	if len(s.opt.MetaHandlers) > 0 {
   383  		transInfoHdlr := bound.NewTransMetaHandler(s.opt.MetaHandlers)
   384  		// meta handler exec before boundHandlers which add with option
   385  		doAddBoundHandlerToHead(transInfoHdlr, opt)
   386  		for _, h := range s.opt.MetaHandlers {
   387  			if shdlr, ok := h.(remote.StreamingMetaHandler); ok {
   388  				opt.StreamingMetaHandlers = append(opt.StreamingMetaHandlers, shdlr)
   389  			}
   390  		}
   391  	}
   393  	limitHdlr := s.buildLimiterWithOpt()
   394  	if limitHdlr != nil {
   395  		doAddBoundHandler(limitHdlr, opt)
   396  	}
   397  }
   399  /*
   400   * There are two times when the rate limiter can take effect for a non-multiplexed server,
   401   * which are the OnRead and OnMessage callback. OnRead is called before request decoded
   402   * and OnMessage is called after.
   403   * Therefore, the optimization point is that we can make rate limiter take effect in OnRead as
   404   * possible to save computational cost of decoding.
   405   * The implementation is that when using the default rate limiter to launching a non-multiplexed
   406   * service, use the `serverLimiterOnReadHandler` whose rate limiting takes effect in the OnRead
   407   * callback.
   408   */
   409  func (s *server) buildLimiterWithOpt() (handler remote.InboundHandler) {
   410  	limits := s.opt.Limit.Limits
   411  	connLimit := s.opt.Limit.ConLimit
   412  	qpsLimit := s.opt.Limit.QPSLimit
   413  	if limits == nil && connLimit == nil && qpsLimit == nil {
   414  		return
   415  	}
   417  	if connLimit == nil {
   418  		if limits != nil {
   419  			connLimit = limiter.NewConnectionLimiter(limits.MaxConnections)
   420  		} else {
   421  			connLimit = &limiter.DummyConcurrencyLimiter{}
   422  		}
   423  	}
   425  	if qpsLimit == nil {
   426  		if limits != nil {
   427  			interval := time.Millisecond * 100 // FIXME: should not care this implementation-specific parameter
   428  			qpsLimit = limiter.NewQPSLimiter(interval, limits.MaxQPS)
   429  		} else {
   430  			qpsLimit = &limiter.DummyRateLimiter{}
   431  		}
   432  	} else {
   433  		s.opt.Limit.QPSLimitPostDecode = true
   434  	}
   436  	if limits != nil && limits.UpdateControl != nil {
   437  		updater := limiter.NewLimiterWrapper(connLimit, qpsLimit)
   438  		limits.UpdateControl(updater)
   439  	}
   441  	handler = bound.NewServerLimiterHandler(connLimit, qpsLimit, s.opt.Limit.LimitReporter, s.opt.Limit.QPSLimitPostDecode)
   442  	// TODO: gRPC limiter
   443  	return
   444  }
   446  func (s *server) check() error {
   447  	if len(s.svcs.svcMap) == 0 {
   448  		return errors.New("run: no service. Use RegisterService to set one")
   449  	}
   450  	return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName)
   451  }
   453  func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) {
   454  	add := false
   455  	if ih, ok := h.(remote.InboundHandler); ok {
   456  		handlers := []remote.InboundHandler{ih}
   457  		opt.Inbounds = append(handlers, opt.Inbounds...)
   458  		add = true
   459  	}
   460  	if oh, ok := h.(remote.OutboundHandler); ok {
   461  		handlers := []remote.OutboundHandler{oh}
   462  		opt.Outbounds = append(handlers, opt.Outbounds...)
   463  		add = true
   464  	}
   465  	if !add {
   466  		panic("invalid BoundHandler: must implement InboundHandler or OutboundHandler")
   467  	}
   468  }
   470  func doAddBoundHandler(h remote.BoundHandler, opt *remote.ServerOption) {
   471  	add := false
   472  	if ih, ok := h.(remote.InboundHandler); ok {
   473  		opt.Inbounds = append(opt.Inbounds, ih)
   474  		add = true
   475  	}
   476  	if oh, ok := h.(remote.OutboundHandler); ok {
   477  		opt.Outbounds = append(opt.Outbounds, oh)
   478  		add = true
   479  	}
   480  	if !add {
   481  		panic("invalid BoundHandler: must implement InboundHandler or OutboundHandler")
   482  	}
   483  }
   485  func (s *server) newSvrTransHandler() (handler remote.ServerTransHandler, err error) {
   486  	transHdlrFactory := s.opt.RemoteOpt.SvrHandlerFactory
   487  	transHdlr, err := transHdlrFactory.NewTransHandler(s.opt.RemoteOpt)
   488  	if err != nil {
   489  		return nil, err
   490  	}
   491  	if setter, ok := transHdlr.(remote.InvokeHandleFuncSetter); ok {
   492  		setter.SetInvokeHandleFunc(s.eps)
   493  	}
   494  	transPl := remote.NewTransPipeline(transHdlr)
   496  	for _, ib := range s.opt.RemoteOpt.Inbounds {
   497  		transPl.AddInboundHandler(ib)
   498  	}
   499  	for _, ob := range s.opt.RemoteOpt.Outbounds {
   500  		transPl.AddOutboundHandler(ob)
   501  	}
   502  	return transPl, nil
   503  }
   505  func (s *server) buildRegistryInfo(lAddr net.Addr) {
   506  	if s.opt.RegistryInfo == nil {
   507  		s.opt.RegistryInfo = &registry.Info{}
   508  	}
   509  	info := s.opt.RegistryInfo
   510  	// notice: lAddr may be nil when listen failed
   511  	info.Addr = lAddr
   512  	if info.ServiceName == "" {
   513  		info.ServiceName = s.opt.Svr.ServiceName
   514  	}
   515  	if info.PayloadCodec == "" {
   516  		info.PayloadCodec = getDefaultSvcInfo(s.svcs).PayloadCodec.String()
   517  	}
   518  	if info.Weight == 0 {
   519  		info.Weight = discovery.DefaultWeight
   520  	}
   521  	if info.Tags == nil {
   522  		info.Tags = s.opt.Svr.Tags
   523  	}
   524  }
   526  func (s *server) fillMoreServiceInfo(lAddr net.Addr) {
   527  	for _, svc := range s.svcs.svcMap {
   528  		ni := *svc.svcInfo
   529  		si := &ni
   530  		extra := make(map[string]interface{}, len(si.Extra)+2)
   531  		for k, v := range si.Extra {
   532  			extra[k] = v
   533  		}
   534  		extra["address"] = lAddr
   535  		extra["transports"] = s.opt.SupportedTransportsFunc(*s.opt.RemoteOpt)
   536  		si.Extra = extra
   537  		svc.svcInfo = si
   538  	}
   539  }
   541  func (s *server) waitExit(errCh chan error) error {
   542  	exitSignal := s.opt.ExitSignal()
   544  	// service may not be available as soon as startup.
   545  	delayRegister := time.After(1 * time.Second)
   546  	for {
   547  		select {
   548  		case err := <-exitSignal:
   549  			return err
   550  		case err := <-errCh:
   551  			return err
   552  		case <-delayRegister:
   553  			s.Lock()
   554  			if err := s.opt.Registry.Register(s.opt.RegistryInfo); err != nil {
   555  				s.Unlock()
   556  				return err
   557  			}
   558  			s.Unlock()
   559  		}
   560  	}
   561  }
   563  func (s *server) findAndSetDefaultService() {
   564  	if len(s.svcs.svcMap) == 1 {
   565  		s.targetSvcInfo = getDefaultSvcInfo(s.svcs)
   566  	}
   567  }
   569  // getDefaultSvc is used to get one ServiceInfo from map
   570  func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo {
   571  	if len(svcs.svcMap) > 1 && svcs.fallbackSvc != nil {
   572  		return svcs.fallbackSvc.svcInfo
   573  	}
   574  	for _, svc := range svcs.svcMap {
   575  		return svc.svcInfo
   576  	}
   577  	return nil
   578  }
   580  func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error {
   581  	if refuseTrafficWithoutServiceName {
   582  		return nil
   583  	}
   584  	for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap {
   585  		if !hasFallbackSvc {
   586  			return fmt.Errorf("method name [%s] is conflicted between services but no fallback service is specified", name)
   587  		}
   588  	}
   589  	return nil
   590  }