github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/stop_services.go (about)

     1  package db_local
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"os"
     8  	"strings"
     9  	"syscall"
    10  	"time"
    11  
    12  	psutils "github.com/shirou/gopsutil/process"
    13  	"github.com/turbot/steampipe/pkg/constants"
    14  	"github.com/turbot/steampipe/pkg/constants/runtime"
    15  	"github.com/turbot/steampipe/pkg/db/db_common"
    16  	"github.com/turbot/steampipe/pkg/error_helpers"
    17  	"github.com/turbot/steampipe/pkg/filepaths"
    18  	"github.com/turbot/steampipe/pkg/pluginmanager"
    19  	"github.com/turbot/steampipe/pkg/statushooks"
    20  	"github.com/turbot/steampipe/pkg/utils"
    21  )
    22  
    23  // StopStatus is a pseudoEnum for service stop result
    24  type StopStatus int
    25  
    26  const (
    27  	// start from 1 to prevent confusion with int zero-value
    28  	ServiceStopped StopStatus = iota + 1
    29  	ServiceNotRunning
    30  	ServiceStopFailed
    31  	ServiceStopTimedOut
    32  )
    33  
    34  // ShutdownService stops the database instance if the given 'invoker' matches
    35  func ShutdownService(ctx context.Context, invoker constants.Invoker) {
    36  	utils.LogTime("db_local.ShutdownService start")
    37  	defer utils.LogTime("db_local.ShutdownService end")
    38  
    39  	if error_helpers.IsContextCanceled(ctx) {
    40  		ctx = context.Background()
    41  	}
    42  
    43  	status, _ := GetState()
    44  
    45  	// if the service is not running or it was invoked by 'steampipe service',
    46  	// then we don't shut it down
    47  	if status == nil || status.Invoker == constants.InvokerService {
    48  		return
    49  	}
    50  
    51  	// how many clients are connected
    52  	// under a fresh context
    53  	clientCounts, err := GetClientCount(context.Background())
    54  	// if there are other clients connected
    55  	// and if there's no error
    56  	if err == nil && clientCounts.SteampipeClients > 0 {
    57  		// there are other steampipe clients connected to the database
    58  		// we don't need to stop the service
    59  		// the last one to exit will shutdown the service
    60  		log.Printf("[INFO] ShutdownService not closing database service - %d steampipe %s connected", clientCounts.SteampipeClients, utils.Pluralize("client", clientCounts.SteampipeClients))
    61  		return
    62  	}
    63  
    64  	// we can shut down the database
    65  	stopStatus, err := StopServices(ctx, false, invoker)
    66  	if err != nil {
    67  		error_helpers.ShowError(ctx, err)
    68  	}
    69  	if stopStatus == ServiceStopped {
    70  		return
    71  	}
    72  
    73  	// shutdown failed - try to force stop
    74  	_, err = StopServices(ctx, true, invoker)
    75  	if err != nil {
    76  		error_helpers.ShowError(ctx, err)
    77  	}
    78  
    79  }
    80  
    81  type ClientCount struct {
    82  	SteampipeClients     int
    83  	PluginManagerClients int
    84  	TotalClients         int
    85  }
    86  
    87  // GetClientCount returns the number of connections to the service from anyone other than
    88  // _this_execution_ of steampipe
    89  //
    90  // We assume that any connections from this execution will eventually be closed
    91  // - if there are any other external connections, we cannot shut down the database
    92  //
    93  // this is to handle cases where either a third party tool is connected to the database,
    94  // or other Steampipe sessions are attached to an already running Steampipe service
    95  // - we do not want the db service being closed underneath them
    96  //
    97  // note: we need the PgClientAppName check to handle the case where there may be one or more open DB connections
    98  // from this instance at the time of shutdown - for example when a control run is cancelled
    99  // If we do not exclude connections from this execution, the DB will not be shut down after a cancellation
   100  func GetClientCount(ctx context.Context) (*ClientCount, error) {
   101  	utils.LogTime("db_local.GetClientCount start")
   102  	defer utils.LogTime(fmt.Sprintf("db_local.GetClientCount end"))
   103  
   104  	rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser})
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	defer rootClient.Close(ctx)
   109  
   110  	query := `
   111  SELECT 
   112    application_name,
   113    count(*)
   114  FROM 
   115    pg_stat_activity 
   116  WHERE
   117  	-- get only the network client processes
   118    client_port IS NOT NULL 
   119  	AND
   120  	-- which are client backends
   121    backend_type=$1 
   122  	AND
   123  	-- which are not connections from this application
   124    application_name!=$2
   125  GROUP BY application_name
   126  `
   127  
   128  	counts := &ClientCount{}
   129  
   130  	log.Println("[INFO] ClientConnectionAppName: ", runtime.ClientConnectionAppName)
   131  	rows, err := rootClient.Query(ctx, query, "client backend", runtime.ClientConnectionAppName)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	defer rows.Close()
   136  
   137  	for rows.Next() {
   138  		var appName string
   139  		var count int
   140  
   141  		if err := rows.Scan(&appName, &count); err != nil {
   142  			return nil, err
   143  		}
   144  		log.Printf("[INFO] appName: %s, count: %d", appName, count)
   145  
   146  		counts.TotalClients += count
   147  
   148  		if db_common.IsClientAppName(appName) {
   149  			counts.SteampipeClients += count
   150  		}
   151  
   152  		// plugin manager uses the service prefix
   153  		if db_common.IsServiceAppName(appName) {
   154  			counts.PluginManagerClients += count
   155  		}
   156  	}
   157  
   158  	return counts, nil
   159  }
   160  
   161  // StopServices searches for and stops the running instance. Does nothing if an instance was not found
   162  func StopServices(ctx context.Context, force bool, invoker constants.Invoker) (status StopStatus, e error) {
   163  	log.Printf("[TRACE] StopDB invoker %s, force %v", invoker, force)
   164  	utils.LogTime("db_local.StopDB start")
   165  
   166  	defer func() {
   167  		if e == nil {
   168  			os.Remove(filepaths.RunningInfoFilePath())
   169  		}
   170  		utils.LogTime("db_local.StopDB end")
   171  	}()
   172  
   173  	log.Println("[INFO] shutting down plugin manager")
   174  	// stop the plugin manager
   175  	// this means it may be stopped even if we fail to stop the service - that is ok - we will restart it if needed
   176  	pluginManagerStopError := pluginmanager.Stop()
   177  	log.Println("[INFO] shut down plugin manager")
   178  
   179  	// stop the DB Service
   180  	log.Println("[INFO] stopping DB Service")
   181  	stopResult, dbStopError := stopDBService(ctx, force)
   182  	log.Println("[INFO] stopped DB Service")
   183  
   184  	return stopResult, error_helpers.CombineErrors(dbStopError, pluginManagerStopError)
   185  }
   186  
   187  func stopDBService(ctx context.Context, force bool) (StopStatus, error) {
   188  	if force {
   189  		// check if we have a process from another install-dir
   190  		statushooks.SetStatus(ctx, "Checking for running instances…")
   191  		// do not use a context that can be cancelled
   192  		anyStopped := killInstanceIfAny(context.Background())
   193  		if anyStopped {
   194  			return ServiceStopped, nil
   195  		}
   196  		return ServiceNotRunning, nil
   197  	}
   198  
   199  	dbState, err := GetState()
   200  	if err != nil {
   201  		return ServiceStopFailed, err
   202  	}
   203  
   204  	if dbState == nil {
   205  		// we do not have a info file
   206  		// assume that the service is not running
   207  		return ServiceNotRunning, nil
   208  	}
   209  
   210  	// GetStatus has made sure that the process exists
   211  	process, err := psutils.NewProcess(int32(dbState.Pid))
   212  	if err != nil {
   213  		return ServiceStopFailed, err
   214  	}
   215  
   216  	err = doThreeStepPostgresExit(ctx, process)
   217  	if err != nil {
   218  		// we couldn't stop it still.
   219  		// timeout
   220  		return ServiceStopTimedOut, err
   221  	}
   222  
   223  	return ServiceStopped, nil
   224  }
   225  
   226  /*
   227  Postgres has three levels of shutdown:
   228  
   229    - SIGTERM   - Smart Shutdown	 :  Wait for children to end normally - exit self
   230    - SIGINT    - Fast Shutdown      :  SIGTERM children, causing them to abort current
   231      transations and exit - wait for children to exit -
   232      exit self
   233    - SIGQUIT   - Immediate Shutdown :  SIGQUIT children - wait at most 5 seconds,
   234      send SIGKILL to children - exit self immediately
   235  
   236  Postgres recommended shutdown is to send a SIGTERM - which initiates
   237  a Smart-Shutdown sequence.
   238  
   239  IMPORTANT:
   240  As per documentation, it is best not to use SIGKILL
   241  to shut down postgres. Doing so will prevent the server
   242  from releasing shared memory and semaphores.
   243  
   244  Reference:
   245  https://www.postgresql.org/docs/12/server-shutdown.html
   246  
   247  By the time we actually try to run this sequence, we will have
   248  checked that the service can indeed shutdown gracefully,
   249  the sequence is there only as a backup.
   250  */
   251  func doThreeStepPostgresExit(ctx context.Context, process *psutils.Process) error {
   252  	utils.LogTime("db_local.doThreeStepPostgresExit start")
   253  	defer utils.LogTime("db_local.doThreeStepPostgresExit end")
   254  
   255  	var err error
   256  	var exitSuccessful bool
   257  
   258  	// send a SIGTERM
   259  	err = process.SendSignal(syscall.SIGTERM)
   260  	if err != nil {
   261  		return err
   262  	}
   263  	exitSuccessful = waitForProcessExit(process, 2*time.Second)
   264  	if !exitSuccessful {
   265  		// process didn't quit
   266  
   267  		// set status, as this is taking time
   268  		statushooks.SetStatus(ctx, "Shutting down…")
   269  
   270  		// try a SIGINT
   271  		err = process.SendSignal(syscall.SIGINT)
   272  		if err != nil {
   273  			return err
   274  		}
   275  		exitSuccessful = waitForProcessExit(process, 2*time.Second)
   276  	}
   277  	if !exitSuccessful {
   278  		// process didn't quit
   279  		// desperation prevails
   280  		err = process.SendSignal(syscall.SIGQUIT)
   281  		if err != nil {
   282  			return err
   283  		}
   284  		exitSuccessful = waitForProcessExit(process, 5*time.Second)
   285  	}
   286  
   287  	if !exitSuccessful {
   288  		log.Println("[ERROR] Failed to stop service")
   289  		log.Printf("[ERROR] Service Details:\n%s\n", getPrintableProcessDetails(process, 0))
   290  		return fmt.Errorf("service shutdown timed out")
   291  	}
   292  
   293  	return nil
   294  }
   295  
   296  func waitForProcessExit(process *psutils.Process, waitFor time.Duration) bool {
   297  	utils.LogTime("db_local.waitForProcessExit start")
   298  	defer utils.LogTime("db_local.waitForProcessExit end")
   299  
   300  	checkTimer := time.NewTicker(50 * time.Millisecond)
   301  	timeoutAt := time.After(waitFor)
   302  
   303  	for {
   304  		select {
   305  		case <-checkTimer.C:
   306  			pEx, _ := utils.PidExists(int(process.Pid))
   307  			if pEx {
   308  				continue
   309  			}
   310  			return true
   311  		case <-timeoutAt:
   312  			checkTimer.Stop()
   313  			return false
   314  		}
   315  	}
   316  }
   317  
   318  func getPrintableProcessDetails(process *psutils.Process, indent int) string {
   319  	utils.LogTime("db_local.getPrintableProcessDetails start")
   320  	defer utils.LogTime("db_local.getPrintableProcessDetails end")
   321  
   322  	indentString := strings.Repeat("  ", indent)
   323  	appendTo := []string{}
   324  
   325  	if name, err := process.Name(); err == nil {
   326  		appendTo = append(appendTo, fmt.Sprintf("%s> Name: %s", indentString, name))
   327  	}
   328  	if cmdLine, err := process.Cmdline(); err == nil {
   329  		appendTo = append(appendTo, fmt.Sprintf("%s> CmdLine: %s", indentString, cmdLine))
   330  	}
   331  	if status, err := process.Status(); err == nil {
   332  		appendTo = append(appendTo, fmt.Sprintf("%s> Status: %s", indentString, status))
   333  	}
   334  	if cwd, err := process.Cwd(); err == nil {
   335  		appendTo = append(appendTo, fmt.Sprintf("%s> CWD: %s", indentString, cwd))
   336  	}
   337  	if executable, err := process.Exe(); err == nil {
   338  		appendTo = append(appendTo, fmt.Sprintf("%s> Executable: %s", indentString, executable))
   339  	}
   340  	if username, err := process.Username(); err == nil {
   341  		appendTo = append(appendTo, fmt.Sprintf("%s> Username: %s", indentString, username))
   342  	}
   343  	if indent == 0 {
   344  		// I do not care about the parent of my parent
   345  		if parent, err := process.Parent(); err == nil && parent != nil {
   346  			appendTo = append(appendTo, "", fmt.Sprintf("%s> Parent Details", indentString))
   347  			parentLog := getPrintableProcessDetails(parent, indent+1)
   348  			appendTo = append(appendTo, parentLog, "")
   349  		}
   350  
   351  		// I do not care about all the children of my parent
   352  		if children, err := process.Children(); err == nil && len(children) > 0 {
   353  			appendTo = append(appendTo, fmt.Sprintf("%s> Children Details", indentString))
   354  			for _, child := range children {
   355  				childLog := getPrintableProcessDetails(child, indent+1)
   356  				appendTo = append(appendTo, childLog, "")
   357  			}
   358  		}
   359  	}
   360  
   361  	return strings.Join(appendTo, "\n")
   362  }