github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/server.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gain
    16  
    17  import (
    18  	"fmt"
    19  	"log"
    20  	"os"
    21  	"os/signal"
    22  	"sync"
    23  	"sync/atomic"
    24  	"syscall"
    25  
    26  	"github.com/pawelgaczynski/gain/logger"
    27  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    28  	gainNet "github.com/pawelgaczynski/gain/pkg/net"
    29  	"github.com/pawelgaczynski/giouring"
    30  	"github.com/rs/zerolog"
    31  )
    32  
// Lifecycle states of the engine, stored in engine.state.
const (
	// inactive is the zero value: the engine has not been started, or its
	// closer has already run.
	inactive uint32 = iota
	// starting is set by start before any worker is launched.
	starting
	// running is set once every launched worker has reported that it started.
	running
	// closing is set by Shutdown and AsyncShutdown before they request the close.
	closing
)
    39  
// Server is the public API of a gain server instance created with NewServer.
type Server interface {
	// StartAsMainProcess starts a server that will listen on the specified address and
	// waits for SIGTERM or SIGINT signals to close the server.
	StartAsMainProcess(address string) error
	// Start starts a server that will listen on the specified address.
	Start(address string) error
	// Shutdown closes all connections and shuts down server. It's blocking until the server shuts down.
	Shutdown()
	// AsyncShutdown closes all connections and shuts down server in asynchronous manner.
	// It does not wait for the server shutdown to complete.
	AsyncShutdown()
	// ActiveConnections returns the number of active connections.
	ActiveConnections() int
	// IsRunning returns true if server is running and handling requests.
	IsRunning() bool
}
    56  
// engine is the Server implementation returned by NewServer.
type engine struct {
	config Config
	// network and address are the two values returned by parseProtoAddr for
	// the address string passed to start.
	network string
	address string

	// closer shuts the workers down. It is installed by startReactor or
	// startSocketSharding and called by the goroutine spawned in start.
	closer func() error
	// closeChan carries the shutdown request (and who issued it) to the
	// goroutine spawned in start.
	closeChan chan closeSignal
	// closeBackChan receives a value from that goroutine after closer has
	// returned; Shutdown blocks on it.
	closeBackChan chan bool
	// state holds one of inactive, starting, running or closing.
	state  atomic.Uint32
	logger zerolog.Logger

	eventHandler EventHandler

	// readWriteWorkers maps an int worker index to the worker stored by
	// startReactor (*consumerWorker) or startSocketSharding (*shardWorker).
	readWriteWorkers sync.Map
}
    72  
// closeSignal identifies who requested the shutdown sent over engine.closeChan.
type closeSignal int

const (
	// system is sent when SIGINT or SIGTERM is received by a server started
	// as the main process; start calls os.Exit(0) after handling it.
	system closeSignal = iota
	// user is sent by Shutdown and AsyncShutdown.
	user
)
    79  
    80  func (e *engine) startConsumers(startedWg, doneWg *sync.WaitGroup) {
    81  	numberOfWorkers := int32(e.config.Workers)
    82  	e.readWriteWorkers.Range(func(key any, value any) bool {
    83  		var index int
    84  		var worker *consumerWorker
    85  		var ok bool
    86  		if index, ok = key.(int); !ok {
    87  			return false
    88  		}
    89  		if worker, ok = value.(*consumerWorker); !ok {
    90  			return false
    91  		}
    92  		startedWg.Add(1)
    93  		doneWg.Add(1)
    94  		go func(cWorker *consumerWorker) {
    95  			err := cWorker.loop(0)
    96  			if err != nil {
    97  				e.handleWorkerStop(index, &numberOfWorkers, err, startedWg, *cWorker.readWriteWorkerImpl)
    98  			}
    99  			doneWg.Done()
   100  		}(worker)
   101  		<-worker.startedChan
   102  
   103  		return true
   104  	})
   105  }
   106  
   107  func (e *engine) handleWorkerStop(
   108  	index int, numberOfWorkers *int32, err error, startedWg *sync.WaitGroup, worker readWriteWorkerImpl,
   109  ) {
   110  	atomic.AddInt32(numberOfWorkers, -1)
   111  	livingWorkers := atomic.LoadInt32(numberOfWorkers)
   112  	e.logger.Error().Err(err).Int32("living workers", livingWorkers).Int("worker index", index).Msg("Worker died...")
   113  
   114  	if !worker.started() {
   115  		startedWg.Done()
   116  	}
   117  
   118  	e.readWriteWorkers.Delete(index)
   119  }
   120  
   121  func (e *engine) startReactor(listener *listener, features supportedFeatures) error {
   122  	var doneWg, startedWg sync.WaitGroup
   123  
   124  	lb, err := createLoadBalancer(e.config.LoadBalancing)
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	acceptor, err := newAcceptorWorker(acceptorWorkerConfig{
   130  		workerConfig: workerConfig{
   131  			cpuAffinity:     e.config.CPUAffinity,
   132  			processPriority: e.config.ProcessPriority,
   133  			maxCQEvents:     int(e.config.MaxCQEvents),
   134  			loggerLevel:     e.config.LoggerLevel,
   135  			prettyLogger:    e.config.PrettyLogger,
   136  			maxSQEntries:    e.config.MaxSQEntries,
   137  		},
   138  		tcpKeepAlive: e.config.TCPKeepAlive,
   139  	}, lb, e.eventHandler, features)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	acceptor.startListener = func() {
   145  		startedWg.Done()
   146  	}
   147  
   148  	for i := 0; i < e.config.Workers; i++ {
   149  		var consumer *consumerWorker
   150  
   151  		consumer, err = newConsumerWorker(i+1, listener.addr, consumerConfig{
   152  			readWriteWorkerConfig: readWriteWorkerConfig{
   153  				workerConfig: workerConfig{
   154  					cpuAffinity:  e.config.CPUAffinity,
   155  					maxCQEvents:  int(e.config.MaxCQEvents),
   156  					loggerLevel:  e.config.LoggerLevel,
   157  					prettyLogger: e.config.PrettyLogger,
   158  					maxSQEntries: e.config.MaxSQEntries,
   159  				},
   160  				asyncHandler:  e.config.AsyncHandler,
   161  				goroutinePool: e.config.GoroutinePool,
   162  			},
   163  		}, e.eventHandler, features)
   164  		if err != nil {
   165  			return err
   166  		}
   167  		consumer.startListener = func() {
   168  			startedWg.Done()
   169  		}
   170  		acceptor.registerConsumer(consumer)
   171  		e.readWriteWorkers.Store(i, consumer)
   172  	}
   173  	e.closer = func() error {
   174  		acceptor.shutdown()
   175  
   176  		return nil
   177  	}
   178  
   179  	e.startConsumers(&startedWg, &doneWg)
   180  
   181  	startedWg.Add(1)
   182  	doneWg.Add(1)
   183  
   184  	go func(acceptor *acceptorWorker) {
   185  		err = acceptor.loop(listener.fd)
   186  		if err != nil {
   187  			log.Panic(err)
   188  		}
   189  
   190  		doneWg.Done()
   191  	}(acceptor)
   192  	startedWg.Wait()
   193  	e.state.Store(running)
   194  
   195  	e.eventHandler.OnStart(e)
   196  	doneWg.Wait()
   197  
   198  	return nil
   199  }
   200  
   201  func (e *engine) startSocketSharding(listeners []*listener, protocol string) error {
   202  	var (
   203  		doneWg    sync.WaitGroup
   204  		startedWg sync.WaitGroup
   205  	)
   206  
   207  	for i := 0; i < e.config.Workers; i++ {
   208  		shardWorker, err := newShardWorker(i, listeners[i].addr, shardWorkerConfig{
   209  			readWriteWorkerConfig: readWriteWorkerConfig{
   210  				workerConfig: workerConfig{
   211  					cpuAffinity:  e.config.CPUAffinity,
   212  					maxCQEvents:  int(e.config.MaxCQEvents),
   213  					loggerLevel:  e.config.LoggerLevel,
   214  					prettyLogger: e.config.PrettyLogger,
   215  					maxSQEntries: e.config.MaxSQEntries,
   216  				},
   217  				asyncHandler:  e.config.AsyncHandler,
   218  				goroutinePool: e.config.GoroutinePool,
   219  				sendRecvMsg:   protocol == gainNet.UDP,
   220  			},
   221  			tcpKeepAlive: e.config.TCPKeepAlive,
   222  		}, e.eventHandler)
   223  		if err != nil {
   224  			return err
   225  		}
   226  		shardWorker.startListener = func() {
   227  			startedWg.Done()
   228  		}
   229  		e.readWriteWorkers.Store(i, shardWorker)
   230  	}
   231  	e.closer = func() error {
   232  		e.readWriteWorkers.Range(func(key any, value any) bool {
   233  			var worker *shardWorker
   234  			var ok bool
   235  			if worker, ok = value.(*shardWorker); !ok {
   236  				return false
   237  			}
   238  			worker.shutdown()
   239  			e.logger.Warn().Msgf("Worker %d closed", worker.index())
   240  
   241  			return true
   242  		})
   243  
   244  		return nil
   245  	}
   246  	numberOfWorkers := int32(e.config.Workers)
   247  	e.readWriteWorkers.Range(func(key any, value any) bool {
   248  		var index int
   249  		var worker *shardWorker
   250  		var ok bool
   251  		if index, ok = key.(int); !ok {
   252  			return false
   253  		}
   254  		if worker, ok = value.(*shardWorker); !ok {
   255  			return false
   256  		}
   257  		startedWg.Add(1)
   258  		doneWg.Add(1)
   259  		go func(sWorker *shardWorker) {
   260  			err := sWorker.loop(listeners[index].fd)
   261  			if err != nil {
   262  				e.handleWorkerStop(index, &numberOfWorkers, err, &startedWg, *sWorker.readWriteWorkerImpl)
   263  			}
   264  			doneWg.Done()
   265  		}(worker)
   266  		<-worker.startedChan
   267  
   268  		return true
   269  	})
   270  	startedWg.Wait()
   271  	e.state.Store(running)
   272  
   273  	e.eventHandler.OnStart(e)
   274  	doneWg.Wait()
   275  
   276  	return nil
   277  }
   278  
   279  func (e *engine) start(mainProcess bool, address string) error {
   280  	if state := e.state.Load(); state != inactive {
   281  		return gainErrors.ErrInvalidState
   282  	}
   283  
   284  	e.state.Store(starting)
   285  
   286  	var (
   287  		err      error
   288  		features = supportedFeatures{}
   289  	)
   290  
   291  	probe, err := giouring.GetProbe()
   292  	if err != nil {
   293  		return fmt.Errorf("getProbe err: %w", err)
   294  	}
   295  
   296  	features.ringsMessaging = probe.IsSupported(giouring.OpMsgRing)
   297  
   298  	if mainProcess {
   299  		sigs := make(chan os.Signal, 1)
   300  		signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
   301  
   302  		go func() {
   303  			<-sigs
   304  			e.closeChan <- system
   305  		}()
   306  	}
   307  
   308  	go func() {
   309  		closeSig := <-e.closeChan
   310  
   311  		closeErr := e.closer()
   312  		if closeErr != nil {
   313  			e.logger.Error().Err(closeErr).Msg("Closing server error")
   314  		}
   315  
   316  		e.state.Store(inactive)
   317  
   318  		e.closeBackChan <- true
   319  
   320  		if closeSig == system {
   321  			os.Exit(0)
   322  		}
   323  	}()
   324  
   325  	e.network, e.address = parseProtoAddr(address)
   326  
   327  	if e.config.Architecture == SocketSharding || e.network == gainNet.UDP {
   328  		listeners := make([]*listener, e.config.Workers)
   329  		for i := 0; i < len(listeners); i++ {
   330  			var listener *listener
   331  
   332  			listener, err = initListener(e.network, e.address, e.config)
   333  			if err != nil {
   334  				return err
   335  			}
   336  			listeners[i] = listener
   337  		}
   338  
   339  		return e.startSocketSharding(listeners, e.network)
   340  	}
   341  
   342  	listener, err := initListener(e.network, e.address, e.config)
   343  	if err != nil {
   344  		return err
   345  	}
   346  
   347  	return e.startReactor(listener, features)
   348  }
   349  
   350  func (e *engine) StartAsMainProcess(address string) error {
   351  	if e.IsRunning() {
   352  		return gainErrors.ErrServerAlreadyRunning
   353  	}
   354  
   355  	return e.start(true, address)
   356  }
   357  
   358  func (e *engine) Start(address string) error {
   359  	if e.IsRunning() {
   360  		return gainErrors.ErrServerAlreadyRunning
   361  	}
   362  
   363  	return e.start(false, address)
   364  }
   365  
   366  func (e *engine) Shutdown() {
   367  	if state := e.state.Load(); state == running {
   368  		e.state.Store(closing)
   369  		e.closeChan <- user
   370  		<-e.closeBackChan
   371  	}
   372  }
   373  
   374  func (e *engine) AsyncShutdown() {
   375  	if state := e.state.Load(); state == running {
   376  		e.state.Store(closing)
   377  		e.closeChan <- user
   378  	}
   379  }
   380  
   381  func (e *engine) ActiveConnections() int {
   382  	connections := 0
   383  
   384  	e.readWriteWorkers.Range(func(key any, value any) bool {
   385  		if worker, ok := value.(readWriteWorker); ok {
   386  			connections += worker.activeConnections()
   387  
   388  			return true
   389  		}
   390  
   391  		return false
   392  	})
   393  
   394  	return connections
   395  }
   396  
   397  func (e *engine) IsRunning() bool {
   398  	return e.state.Load() == running
   399  }
   400  
   401  func NewServer(eventHandler EventHandler, config Config) Server {
   402  	return &engine{
   403  		config:        config,
   404  		logger:        logger.NewLogger("server", config.LoggerLevel, config.PrettyLogger),
   405  		eventHandler:  eventHandler,
   406  		closeChan:     make(chan closeSignal),
   407  		closeBackChan: make(chan bool),
   408  	}
   409  }