github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_common/wait_connection.go

package db_common

import (
	"context"
	"log"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/pkg/errors"
	"github.com/sethvargo/go-retry"
	"github.com/turbot/steampipe/pkg/constants"
	"github.com/turbot/steampipe/pkg/error_helpers"
	"github.com/turbot/steampipe/pkg/statushooks"
	"github.com/turbot/steampipe/pkg/utils"
)

// ErrServiceInRecoveryMode is returned by WaitForRecovery when the database
// service is still in recovery mode.
var ErrServiceInRecoveryMode = errors.New("service is in recovery mode")

// waitConfig holds the retry interval and overall timeout used by the Wait*
// helpers in this file.
type waitConfig struct {
	retryInterval time.Duration
	timeout       time.Duration
}

// WaitOption is a functional option that modifies a waitConfig.
type WaitOption func(w *waitConfig)

// WithRetryInterval sets the interval between retry attempts.
func WithRetryInterval(d time.Duration) WaitOption {
	return func(w *waitConfig) {
		w.retryInterval = d
	}
}

// WithTimeout sets the maximum duration to keep retrying.
func WithTimeout(d time.Duration) WaitOption {
	return func(w *waitConfig) {
		w.timeout = d
	}
}

// WaitForConnection attempts to open a connection to the service, retrying at
// the configured interval until it succeeds or the timeout elapses.
func WaitForConnection(ctx context.Context, connStr string, options ...WaitOption) (conn *pgx.Conn, err error) {
	utils.LogTime("db_common.WaitForConnection start")
	defer utils.LogTime("db_common.WaitForConnection end")

	config := &waitConfig{
		retryInterval: constants.DBConnectionRetryBackoff,
		timeout:       constants.DBStartTimeout,
	}

	for _, o := range options {
		o(config)
	}

	backoff := retry.WithMaxDuration(
		config.timeout,
		retry.NewConstant(config.retryInterval),
	)

	// create a connection to the service,
	// retrying after a backoff, but only up to a maximum duration
	err = retry.Do(ctx, backoff, func(rCtx context.Context) error {
		log.Println("[TRACE] Trying to create client with: ", connStr)
		dbConnection, err := pgx.Connect(rCtx, connStr)
		if err != nil {
			log.Println("[TRACE] could not connect:", err)
			return retry.RetryableError(err)
		}
		log.Println("[TRACE] connected to database")
		conn = dbConnection
		return nil
	})

	return conn, err
}
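
// A minimal usage sketch (illustrative only, not part of the original file); it
// assumes a reachable service and a valid Postgres connection string (connStr
// here is hypothetical):
//
//	conn, err := WaitForConnection(ctx, connStr,
//		WithRetryInterval(250*time.Millisecond),
//		WithTimeout(30*time.Second),
//	)
//	if err != nil {
//		return err
//	}
//	defer conn.Close(ctx)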

// WaitForPool waits for the database to start accepting connections on the given pool.
// It acquires a connection from the pool and delegates to WaitForConnectionPing,
// returning an error if the database does not respond within the configured timeout.
func WaitForPool(ctx context.Context, db *pgxpool.Pool, waitOptions ...WaitOption) (err error) {
	utils.LogTime("db_common.WaitForPool start")
	defer utils.LogTime("db_common.WaitForPool end")

	connection, err := db.Acquire(ctx)
	if err != nil {
		return err
	}
	defer connection.Release()
	return WaitForConnectionPing(ctx, connection.Conn(), waitOptions...)
}
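
// Illustrative only (assumes pool is an already-constructed *pgxpool.Pool; not
// part of the original file):
//
//	if err := WaitForPool(ctx, pool, WithTimeout(time.Minute)); err != nil {
//		return err
//	}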

// WaitForConnectionPing PINGs the DB, retrying every constants.ServicePingInterval,
// for up to constants.DBStartTimeout (both overridable via WaitOptions).
// It returns the ping error if the database does not respond successfully before the timeout.
func WaitForConnectionPing(ctx context.Context, connection *pgx.Conn, waitOptions ...WaitOption) (err error) {
	utils.LogTime("db_common.WaitForConnectionPing start")
	defer utils.LogTime("db_common.WaitForConnectionPing end")

	config := &waitConfig{
		retryInterval: constants.ServicePingInterval,
		timeout:       constants.DBStartTimeout,
	}

	for _, o := range waitOptions {
		o(config)
	}

	retryBackoff := retry.WithMaxDuration(
		config.timeout,
		retry.NewConstant(config.retryInterval),
	)

	retryErr := retry.Do(ctx, retryBackoff, func(ctx context.Context) error {
		log.Println("[TRACE] Pinging")
		pingErr := connection.Ping(ctx)
		if pingErr != nil {
			log.Println("[TRACE] Pinging failed -> trying again")
			return retry.RetryableError(pingErr)
		}
		return nil
	})

	return retryErr
}
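
// Sketch of overriding the defaults with options (illustrative; the values are
// arbitrary and not part of the original file):
//
//	err := WaitForConnectionPing(ctx, conn,
//		WithRetryInterval(50*time.Millisecond),
//		WithTimeout(5*time.Second),
//	)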

// WaitForRecovery polls pg_is_in_recovery() until the service reports that it
// has left recovery mode. If a timeout is set (via WithTimeout) and the service
// is still recovering when it expires, ErrServiceInRecoveryMode is returned;
// with the default zero timeout it keeps waiting indefinitely.
func WaitForRecovery(ctx context.Context, connection *pgx.Conn, waitOptions ...WaitOption) (err error) {
	utils.LogTime("db_common.WaitForRecovery start")
	defer utils.LogTime("db_common.WaitForRecovery end")

	config := &waitConfig{
		retryInterval: constants.ServicePingInterval,
		timeout:       time.Duration(0),
	}

	for _, o := range waitOptions {
		o(config)
	}

	// a zero timeout means retry indefinitely; otherwise cap the total retry duration
	var retryBackoff retry.Backoff
	if config.timeout == 0 {
		retryBackoff = retry.NewConstant(config.retryInterval)
	} else {
		retryBackoff = retry.WithMaxDuration(
			config.timeout,
			retry.NewConstant(config.retryInterval),
		)
	}

	// ensure we set the "recovering" status only once, even though it is
	// set from inside the retry loop
	recoveryStatusUpdateOnce := &sync.Once{}

	retryErr := retry.Do(ctx, retryBackoff, func(ctx context.Context) error {
		log.Println("[TRACE] checking for recovery mode")
		row := connection.QueryRow(ctx, "select pg_is_in_recovery();")
		var isInRecovery bool
		if scanErr := row.Scan(&isInRecovery); scanErr != nil {
			if error_helpers.IsContextCancelledError(scanErr) {
				return scanErr
			}
			log.Println("[ERROR] checking for recovery mode", scanErr)
			return retry.RetryableError(scanErr)
		}
		if isInRecovery {
			log.Println("[TRACE] service is in recovery")

			recoveryStatusUpdateOnce.Do(func() {
				statushooks.SetStatus(ctx, "Database is recovering. This may take some time.")
			})

			return retry.RetryableError(ErrServiceInRecoveryMode)
		}
		return nil
	})

	return retryErr
}
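
// Illustrative usage (not part of the original file); it assumes the caller holds
// an open *pgx.Conn and wants to bound the wait using the timeout constant the
// original doc comment referred to:
//
//	if err := WaitForRecovery(ctx, conn, WithTimeout(constants.DBRecoveryWaitTimeout)); err != nil {
//		if errors.Is(err, ErrServiceInRecoveryMode) {
//			// still recovering when the timeout expired
//		}
//		return err
//	}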