github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/edge.go (about)

     1  package kit
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"hash/crc32"
     7  	"io"
     8  	"os"
     9  	"os/exec"
    10  	"os/signal"
    11  	"reflect"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/clubpay/ronykit/kit/errors"
    18  	"github.com/clubpay/ronykit/kit/utils"
    19  	"github.com/jedib0t/go-pretty/v6/table"
    20  	"github.com/jedib0t/go-pretty/v6/text"
    21  )
    22  
    23  // EdgeServer is the main component of the kit. It glues all other components of the
    24  // app to each other.
    25  type EdgeServer struct {
    26  	sb        *southBridge
    27  	nb        []*northBridge
    28  	gh        []HandlerFunc
    29  	svc       []Service
    30  	cd        ConnDelegate
    31  	contracts map[string]Contract
    32  	eh        ErrHandlerFunc
    33  	l         Logger
    34  	wg        sync.WaitGroup
    35  
    36  	// trace tools
    37  	t Tracer
    38  
    39  	// configs
    40  	prefork         bool
    41  	shutdownTimeout time.Duration
    42  
    43  	// local store
    44  	ls localStore
    45  }
    46  
    47  func NewServer(opts ...Option) *EdgeServer {
    48  	s := &EdgeServer{
    49  		contracts: map[string]Contract{},
    50  		ls: localStore{
    51  			kv: map[string]any{},
    52  		},
    53  	}
    54  	cfg := &edgeConfig{
    55  		logger:     NOPLogger{},
    56  		errHandler: func(ctx *Context, err error) {},
    57  	}
    58  	for _, opt := range opts {
    59  		opt(cfg)
    60  	}
    61  
    62  	s.l = cfg.logger
    63  	s.prefork = cfg.prefork
    64  	s.eh = cfg.errHandler
    65  	s.gh = cfg.globalHandlers
    66  	s.cd = cfg.connDelegate
    67  	if cfg.tracer != nil {
    68  		s.t = cfg.tracer
    69  	}
    70  
    71  	if cfg.cluster != nil {
    72  		s.registerCluster(utils.RandomID(32), cfg.cluster)
    73  	}
    74  	for _, gw := range cfg.gateways {
    75  		s.registerGateway(gw)
    76  	}
    77  	for _, svc := range cfg.services {
    78  		s.registerService(svc)
    79  	}
    80  
    81  	return s
    82  }
    83  
    84  // RegisterGateway registers a Gateway to our server.
    85  func (s *EdgeServer) registerGateway(gw Gateway) *EdgeServer {
    86  	var th HandlerFunc
    87  
    88  	// if tracer is set we inject it to our context pool as the first handler
    89  	if s.t != nil {
    90  		th = s.t.Handler()
    91  	}
    92  
    93  	nb := &northBridge{
    94  		ctxPool: ctxPool{
    95  			ls: &s.ls,
    96  			th: th,
    97  		},
    98  		cd: s.cd,
    99  		wg: &s.wg,
   100  		eh: s.eh,
   101  		c:  s.contracts,
   102  		gw: gw,
   103  		sb: s.sb,
   104  	}
   105  	s.nb = append(s.nb, nb)
   106  
   107  	// Subscribe the northBridge, which is a GatewayDelegate, to connect northBridge with the Gateway
   108  	gw.Subscribe(nb)
   109  
   110  	return s
   111  }
   112  
   113  // RegisterCluster registers a Cluster to our server.
   114  func (s *EdgeServer) registerCluster(id string, cb Cluster) *EdgeServer {
   115  	var th HandlerFunc
   116  	if s.t != nil {
   117  		th = s.t.Handler()
   118  	}
   119  
   120  	s.sb = &southBridge{
   121  		ctxPool: ctxPool{
   122  			ls: &s.ls,
   123  			th: th,
   124  		},
   125  		id:            id,
   126  		wg:            &s.wg,
   127  		eh:            s.eh,
   128  		c:             s.contracts,
   129  		cb:            cb,
   130  		tp:            s.t,
   131  		inProgressMtx: utils.SpinLock{},
   132  		inProgress:    map[string]chan *envelopeCarrier{},
   133  		msgFactories:  map[string]MessageFactoryFunc{},
   134  		l:             s.l,
   135  	}
   136  
   137  	// Subscribe the southBridge, which is a ClusterDelegate, to connect southBridge with the Cluster
   138  	cb.Subscribe(id, s.sb)
   139  
   140  	return s
   141  }
   142  
   143  // RegisterService registers a Service to our server. We need to define the appropriate
   144  // RouteSelector in each desc.Contract.
   145  func (s *EdgeServer) registerService(svc Service) *EdgeServer {
   146  	if _, ok := s.contracts[svc.Name()]; ok {
   147  		panic(errors.New("service already registered: %s", svc.Name()))
   148  	}
   149  
   150  	s.svc = append(s.svc, svc)
   151  	for _, c := range svc.Contracts() {
   152  		s.contracts[c.ID()] = WrapContract(
   153  			c,
   154  			ContractWrapperFunc(s.wrapWithGlobalHandlers),
   155  			ContractWrapperFunc(s.sb.wrapWithCoordinator),
   156  		)
   157  	}
   158  
   159  	return s
   160  }
   161  
   162  func (s *EdgeServer) wrapWithGlobalHandlers(c Contract) Contract {
   163  	if len(s.gh) == 0 {
   164  		return c
   165  	}
   166  
   167  	cw := &contractWrap{
   168  		Contract: c,
   169  		h:        s.gh,
   170  	}
   171  
   172  	return cw
   173  }
   174  
   175  // Start registers services in the registered bundles and start the bundles.
   176  func (s *EdgeServer) Start(ctx context.Context) *EdgeServer {
   177  	if ctx == nil {
   178  		ctx = context.Background()
   179  	}
   180  
   181  	s.l.Debugf("server started.")
   182  
   183  	if s.prefork {
   184  		if childID() > 0 {
   185  			s.startChild(ctx)
   186  		} else {
   187  			s.startParent(ctx)
   188  		}
   189  
   190  		return s
   191  	}
   192  
   193  	s.startup(ctx)
   194  
   195  	return s
   196  }
   197  
   198  func (s *EdgeServer) startChild(ctx context.Context) {
   199  	s.l.Debugf("child process [%d] with parent [%d] started. ", os.Getpid(), os.Getppid())
   200  
   201  	// we are in child process
   202  	// use 1 cpu core per child process
   203  	runtime.GOMAXPROCS(1)
   204  
   205  	// kill current child proc when master exits
   206  	go s.watchParent()
   207  
   208  	s.startup(ctx)
   209  }
   210  
   211  // watchParent watches the parent process
   212  func (s *EdgeServer) watchParent() {
   213  	if runtime.GOOS == "windows" {
   214  		// finds parent process,
   215  		// and waits for it to exit
   216  		p, err := os.FindProcess(os.Getppid())
   217  		if err == nil {
   218  			_, _ = p.Wait()
   219  		}
   220  
   221  		s.shutdown(context.Background())
   222  
   223  		return
   224  	}
   225  	// if it is equal to 1 (init process ID),
   226  	// it indicates that the master process has exited
   227  	for range time.NewTicker(time.Millisecond * 500).C {
   228  		if os.Getppid() == 1 {
   229  			s.shutdown(context.Background())
   230  
   231  			return
   232  		}
   233  	}
   234  }
   235  
   236  func (s *EdgeServer) startParent(_ context.Context) {
   237  	// create variables
   238  	maxProc := runtime.GOMAXPROCS(0)
   239  
   240  	children := make(map[int]*exec.Cmd)
   241  	childChan := make(chan child, maxProc)
   242  
   243  	// launch child processes
   244  	for i := 0; i < maxProc; i++ {
   245  		/* #nosec G204 */
   246  		cmd := exec.Command(os.Args[0], os.Args[1:]...)
   247  		cmd.Stdout = os.Stdout
   248  		cmd.Stderr = os.Stderr
   249  
   250  		// add child flag into child proc env
   251  		cmd.Env = append(
   252  			os.Environ(),
   253  			fmt.Sprintf("%s=%d", envForkChildKey, i+1),
   254  		)
   255  		if err := cmd.Start(); err != nil {
   256  			panic(fmt.Errorf("failed to start a child prefork process, error: %w", err))
   257  		}
   258  
   259  		// store child process
   260  		pid := cmd.Process.Pid
   261  		children[pid] = cmd
   262  
   263  		// notify master if child crashes
   264  		go func() {
   265  			childChan <- child{pid, cmd.Wait()}
   266  		}()
   267  	}
   268  
   269  	ch := <-childChan
   270  	s.l.Debugf("detect child's exit. pid=%d, err=%v", ch.pid, ch.err)
   271  
   272  	// if any child exited then we terminate all children and exit program.
   273  	for _, proc := range children {
   274  		_ = proc.Process.Kill()
   275  	}
   276  
   277  	os.Exit(0)
   278  }
   279  
   280  func (s *EdgeServer) startup(ctx context.Context) {
   281  	for idx := range s.nb {
   282  		for _, svc := range s.svc {
   283  			for _, c := range svc.Contracts() {
   284  				s.nb[idx].gw.Register(svc.Name(), c.ID(), c.Encoding(), c.RouteSelector(), c.Input())
   285  			}
   286  		}
   287  
   288  		err := s.nb[idx].gw.Start(
   289  			ctx,
   290  			GatewayStartConfig{
   291  				ReusePort: s.prefork,
   292  			},
   293  		)
   294  		if err != nil {
   295  			s.l.Errorf("got error on starting gateway: %v", err)
   296  			panic(err)
   297  		}
   298  	}
   299  
   300  	if s.sb != nil {
   301  		for _, svc := range s.svc {
   302  			for _, c := range svc.Contracts() {
   303  				s.sb.registerContract(c.Input(), c.Output())
   304  			}
   305  		}
   306  		err := s.sb.Start(ctx)
   307  		if err != nil {
   308  			s.l.Errorf("got error on starting cluster: %v", err)
   309  			panic(err)
   310  		}
   311  	}
   312  }
   313  
   314  // Shutdown stops the server. If there is no signal input, then it shut down the server immediately.
   315  // However, if there is one or more signals added in the input argument, then it waits for any of them to
   316  // trigger the shutdown process.
   317  // Since this is a graceful shutdown, it waits for all flying requests to complete. However, you can set
   318  // the maximum time that it waits before forcefully shutting down the server, by WithShutdownTimeout
   319  // option. The Default value is 1 minute.
   320  func (s *EdgeServer) Shutdown(ctx context.Context, signals ...os.Signal) {
   321  	if ctx == nil {
   322  		ctx = context.Background()
   323  	}
   324  
   325  	if len(signals) > 0 {
   326  		// Create a signal channel and bind it to all the os signals in the arg
   327  		shutdownChan := make(chan os.Signal, 1)
   328  		signal.Notify(shutdownChan, signals...)
   329  
   330  		// Wait for the shutdown signal
   331  		<-shutdownChan
   332  	}
   333  
   334  	if s.prefork && childID() == 0 {
   335  		return
   336  	}
   337  
   338  	s.shutdown(ctx)
   339  }
   340  
   341  func (s *EdgeServer) shutdown(ctx context.Context) {
   342  	// Shutdown all the registered gateways
   343  	for idx := range s.nb {
   344  		err := s.nb[idx].gw.Shutdown(ctx)
   345  		if err != nil {
   346  			s.l.Errorf("got error on shutdown gateway: %v", err)
   347  		}
   348  	}
   349  
   350  	if s.sb != nil {
   351  		err := s.sb.Shutdown(ctx)
   352  		if err != nil {
   353  			s.l.Errorf("got error on shutdown cluster: %v", err)
   354  		}
   355  	}
   356  
   357  	if s.shutdownTimeout == 0 {
   358  		s.shutdownTimeout = time.Minute
   359  	}
   360  
   361  	waitCh := make(chan struct{}, 1)
   362  	go func() {
   363  		s.wg.Wait()
   364  		waitCh <- struct{}{}
   365  	}()
   366  
   367  	select {
   368  	case <-waitCh:
   369  	case <-time.After(s.shutdownTimeout):
   370  	}
   371  }
   372  
   373  func (s *EdgeServer) PrintRoutes(w io.Writer) *EdgeServer {
   374  	if s.prefork && childID() > 1 {
   375  		return s
   376  	}
   377  
   378  	tw := table.NewWriter()
   379  	tw.SuppressEmptyColumns()
   380  	tw.SetStyle(table.StyleRounded)
   381  	style := tw.Style()
   382  	style.Title = table.TitleOptions{
   383  		Align:  text.AlignLeft,
   384  		Colors: text.Colors{text.FgBlack, text.BgWhite},
   385  		Format: text.FormatTitle,
   386  	}
   387  	style.Options.SeparateRows = true
   388  	style.Color.Header = text.Colors{text.Bold, text.FgWhite}
   389  
   390  	tw.SetColumnConfigs([]table.ColumnConfig{
   391  		{
   392  			Number:    1,
   393  			AutoMerge: true,
   394  			VAlign:    text.VAlignTop,
   395  			Align:     text.AlignLeft,
   396  			WidthMax:  24,
   397  		},
   398  		{
   399  			Number:    2,
   400  			AutoMerge: true,
   401  			VAlign:    text.VAlignTop,
   402  			Align:     text.AlignLeft,
   403  			WidthMax:  36,
   404  		},
   405  		{
   406  			Number:   3,
   407  			Align:    text.AlignLeft,
   408  			WidthMax: 12,
   409  		},
   410  		{
   411  			Number:           4,
   412  			Align:            text.AlignLeft,
   413  			WidthMax:         52,
   414  			WidthMaxEnforcer: text.WrapSoft,
   415  		},
   416  		{
   417  			Number:           5,
   418  			AutoMerge:        true,
   419  			VAlign:           text.VAlignTop,
   420  			Align:            text.AlignLeft,
   421  			WidthMax:         84,
   422  			WidthMaxEnforcer: text.WrapText,
   423  		},
   424  	})
   425  
   426  	tw.AppendHeader(
   427  		table.Row{
   428  			"ContractID",
   429  			"Bundle",
   430  			"API",
   431  			"Route | Predicate",
   432  			"Handlers",
   433  		},
   434  	)
   435  
   436  	for _, svc := range s.svc {
   437  		for _, c := range svc.Contracts() {
   438  			if route := rpcRoute(c.RouteSelector()); route != "" {
   439  				tw.AppendRow(
   440  					table.Row{
   441  						c.ID(),
   442  						reflect.TypeOf(c.RouteSelector()).String(),
   443  						text.FgBlue.Sprint("RPC"),
   444  						route,
   445  						getHandlers(c.Handlers()...),
   446  					},
   447  				)
   448  			}
   449  			if route := restRoute(c.RouteSelector()); route != "" {
   450  				tw.AppendRow(
   451  					table.Row{
   452  						c.ID(),
   453  						reflect.TypeOf(c.RouteSelector()).String(),
   454  						text.FgGreen.Sprint("REST"),
   455  						route,
   456  						getHandlers(c.Handlers()...),
   457  					},
   458  				)
   459  			}
   460  		}
   461  
   462  		tw.AppendSeparator()
   463  	}
   464  	_, _ = w.Write(utils.S2B(tw.Render()))
   465  	_, _ = w.Write(utils.S2B("\n"))
   466  
   467  	if x, ok := w.(interface{ Sync() error }); ok {
   468  		_ = x.Sync()
   469  	} else if x, ok := w.(interface{ Flush() error }); ok {
   470  		_ = x.Flush()
   471  	}
   472  
   473  	return s
   474  }
   475  
   476  func (s *EdgeServer) PrintRoutesCompact(w io.Writer) *EdgeServer {
   477  	if s.prefork && childID() > 1 {
   478  		return s
   479  	}
   480  
   481  	tw := table.NewWriter()
   482  	tw.SuppressEmptyColumns()
   483  	tw.SetStyle(table.StyleRounded)
   484  	style := tw.Style()
   485  	style.Title = table.TitleOptions{
   486  		Align:  text.AlignLeft,
   487  		Colors: text.Colors{text.FgBlack, text.BgWhite},
   488  		Format: text.FormatTitle,
   489  	}
   490  	style.Color.Header = text.Colors{text.Bold, text.FgWhite}
   491  
   492  	tw.SetColumnConfigs([]table.ColumnConfig{
   493  		{
   494  			Number:    1,
   495  			AutoMerge: true,
   496  			VAlign:    text.VAlignTop,
   497  			Align:     text.AlignLeft,
   498  			WidthMax:  32,
   499  		},
   500  		{
   501  			Number:   2,
   502  			Align:    text.AlignLeft,
   503  			WidthMax: 12,
   504  		},
   505  		{
   506  			Number:           3,
   507  			Align:            text.AlignLeft,
   508  			WidthMax:         120,
   509  			WidthMaxEnforcer: text.WrapSoft,
   510  		},
   511  	})
   512  
   513  	tw.AppendHeader(
   514  		table.Row{
   515  			"Service",
   516  			"API",
   517  			"Route | Predicate",
   518  		},
   519  	)
   520  
   521  	for _, svc := range s.svc {
   522  		for _, c := range svc.Contracts() {
   523  			if route := rpcRoute(c.RouteSelector()); route != "" {
   524  				tw.AppendRow(
   525  					table.Row{
   526  						svc.Name(),
   527  						text.FgBlue.Sprint("RPC"),
   528  						route,
   529  					},
   530  				)
   531  			}
   532  			if route := restRoute(c.RouteSelector()); route != "" {
   533  				tw.AppendRow(
   534  					table.Row{
   535  						svc.Name(),
   536  						text.FgGreen.Sprint("REST"),
   537  						route,
   538  					},
   539  				)
   540  			}
   541  		}
   542  
   543  		tw.AppendSeparator()
   544  	}
   545  	_, _ = w.Write(utils.S2B(tw.Render()))
   546  	_, _ = w.Write(utils.S2B("\n"))
   547  
   548  	if x, ok := w.(interface{ Sync() error }); ok {
   549  		_ = x.Sync()
   550  	} else if x, ok := w.(interface{ Flush() error }); ok {
   551  		_ = x.Flush()
   552  	}
   553  
   554  	return s
   555  }
   556  
   557  func getFuncName(f HandlerFunc) string {
   558  	name := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
   559  	parts := strings.Split(name, "/")
   560  
   561  	return getColor(parts[len(parts)-1]).Sprint(parts[len(parts)-1])
   562  }
   563  
   564  func getColor(s string) text.Color {
   565  	c := text.Color(crc32.ChecksumIEEE(utils.S2B(s)) % 7)
   566  	c += text.FgBlack + 1
   567  
   568  	return c
   569  }
   570  
   571  func getHandlers(handlers ...HandlerFunc) string {
   572  	sb := strings.Builder{}
   573  	for idx, h := range handlers {
   574  		if idx != 0 {
   575  			sb.WriteString(", ")
   576  		}
   577  		sb.WriteString(getFuncName(h))
   578  	}
   579  
   580  	return text.WrapSoft(sb.String(), 32)
   581  }
   582  
   583  func rpcRoute(rs RouteSelector) string {
   584  	rpc, ok := rs.(RPCRouteSelector)
   585  	if !ok || rpc.GetPredicate() == "" {
   586  		return ""
   587  	}
   588  
   589  	return text.Colors{
   590  		text.Bold, text.FgBlue,
   591  	}.Sprint(rpc.GetPredicate())
   592  }
   593  
   594  func restRoute(rs RouteSelector) string {
   595  	rest, ok := rs.(RESTRouteSelector)
   596  	if !ok || rest.GetMethod() == "" || rest.GetPath() == "" {
   597  		return ""
   598  	}
   599  
   600  	return fmt.Sprintf("%s %s",
   601  		text.Colors{
   602  			getColor(rest.GetMethod()),
   603  			text.Bold,
   604  		}.Sprint(rest.GetMethod()),
   605  		text.Colors{
   606  			text.BgWhite, text.FgBlack,
   607  		}.Sprint(rest.GetPath()),
   608  	)
   609  }
   610  
   611  type localStore struct {
   612  	kvl sync.RWMutex
   613  	kv  map[string]any
   614  }
   615  
   616  var _ Store = (*localStore)(nil)
   617  
   618  func (ls *localStore) Get(key string) any {
   619  	ls.kvl.RLock()
   620  	v := ls.kv[key]
   621  	ls.kvl.RUnlock()
   622  
   623  	return v
   624  }
   625  
   626  func (ls *localStore) Set(key string, value any) {
   627  	ls.kvl.Lock()
   628  	ls.kv[key] = value
   629  	ls.kvl.Unlock()
   630  }
   631  
   632  func (ls *localStore) Delete(key string) {
   633  	ls.kvl.Lock()
   634  	delete(ls.kv, key)
   635  	ls.kvl.Unlock()
   636  }
   637  
   638  func (ls *localStore) Exists(key string) bool {
   639  	ls.kvl.RLock()
   640  	_, v := ls.kv[key]
   641  	ls.kvl.RUnlock()
   642  
   643  	return v
   644  }
   645  
   646  func (ls *localStore) Scan(prefix string, cb func(key string) bool) {
   647  	ls.kvl.RLock()
   648  	defer ls.kvl.RUnlock()
   649  
   650  	for k := range ls.kv {
   651  		if strings.HasPrefix(k, prefix) {
   652  			if cb(k) {
   653  				return
   654  			}
   655  		}
   656  	}
   657  }