go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/supervisor/supervisor.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package supervisor
     9  
    10  import (
    11  	"context"
    12  	"sync"
    13  	"sync/atomic"
    14  
    15  	"go.charczuk.com/sdk/errutil"
    16  	"go.charczuk.com/sdk/graceful"
    17  )
    18  
    19  var _ graceful.Service = (*Supervisor)(nil)
    20  
    21  // Supervisor is a collection of services that should be started / restarted.
    22  type Supervisor struct {
    23  	Services []*Service
    24  
    25  	status  int32
    26  	crashed chan error
    27  	waits   sync.WaitGroup
    28  }
    29  
    30  // StatusTypes
    31  const (
    32  	StatusStopped  int32 = iota
    33  	StatusStarting int32 = iota
    34  	StatusRunning  int32 = iota
    35  	StatusStopping int32 = iota
    36  )
    37  
    38  // Start starts the services and blocks.
    39  func (s *Supervisor) Start(ctx context.Context) error {
    40  	if err := s.StartAsync(ctx); err != nil {
    41  		return err
    42  	}
    43  	return s.Wait(ctx)
    44  }
    45  
    46  // Wait blocks until the services exit.
    47  func (s *Supervisor) Wait(ctx context.Context) error {
    48  	if atomic.LoadInt32(&s.status) != StatusRunning {
    49  		return nil
    50  	}
    51  	defer func() {
    52  		atomic.StoreInt32(&s.status, StatusStopped)
    53  	}()
    54  
    55  	done := make(chan struct{})
    56  	go func() {
    57  		s.waits.Wait()
    58  		close(done)
    59  	}()
    60  
    61  	select {
    62  	case <-ctx.Done():
    63  		return nil
    64  	case err := <-s.crashed:
    65  		s.status = StatusStopping
    66  		for _, service := range s.Services {
    67  			_ = service.Stop()
    68  		}
    69  		return err
    70  	case <-done:
    71  		return nil
    72  	}
    73  }
    74  
    75  // StartAsync starts the supervisor and does not block.
    76  func (s *Supervisor) StartAsync(ctx context.Context) (err error) {
    77  	if !atomic.CompareAndSwapInt32(&s.status, StatusStopped, StatusStarting) {
    78  		return
    79  	}
    80  	defer func() {
    81  		if err != nil {
    82  			atomic.StoreInt32(&s.status, StatusStopped)
    83  		} else {
    84  			atomic.StoreInt32(&s.status, StatusRunning)
    85  		}
    86  	}()
    87  
    88  	s.waits = sync.WaitGroup{}
    89  	s.crashed = make(chan error, len(s.Services))
    90  	for x := 0; x < len(s.Services); x++ {
    91  		s.Services[x].crashed = func(err error) {
    92  			s.crashed <- err
    93  		}
    94  		s.Services[x].finalizer = func() {
    95  			s.waits.Done()
    96  		}
    97  		if err = s.Services[x].Start(ctx); err != nil {
    98  			for y := 0; y < x; y++ {
    99  				_ = s.Services[y].Stop()
   100  			}
   101  			return
   102  		}
   103  		s.waits.Add(1)
   104  	}
   105  	return
   106  }
   107  
   108  // Restart restarts all the services.
   109  //
   110  // If there are errors on restart those errors are returned but
   111  // no recovery or coordinated shutdown attempt is made.
   112  func (s *Supervisor) Restart(_ context.Context) (err error) {
   113  	if !atomic.CompareAndSwapInt32(&s.status, StatusRunning, StatusStopping) {
   114  		return
   115  	}
   116  	defer func() {
   117  		atomic.StoreInt32(&s.status, StatusRunning)
   118  	}()
   119  	for _, service := range s.Services {
   120  		if serviceErr := service.Restart(); serviceErr != nil {
   121  			err = errutil.AppendFlat(err, serviceErr)
   122  		}
   123  	}
   124  	return
   125  }
   126  
   127  // Stop stops the supervisor, implementing graceful.
   128  func (s *Supervisor) Stop(_ context.Context) (err error) {
   129  	if !atomic.CompareAndSwapInt32(&s.status, StatusRunning, StatusStopping) {
   130  		return
   131  	}
   132  	for _, service := range s.Services {
   133  		if serviceErr := service.Stop(); serviceErr != nil {
   134  			err = errutil.AppendFlat(err, serviceErr)
   135  		}
   136  	}
   137  	return
   138  }