github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/create_connection.go (about)

     1  package db_local
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/jackc/pgx/v5"
    11  	"github.com/jackc/pgx/v5/pgxpool"
    12  	"github.com/spf13/viper"
    13  	"github.com/turbot/steampipe-plugin-sdk/v5/sperr"
    14  	"github.com/turbot/steampipe/pkg/constants"
    15  	"github.com/turbot/steampipe/pkg/constants/runtime"
    16  	"github.com/turbot/steampipe/pkg/db/db_common"
    17  	"github.com/turbot/steampipe/pkg/filepaths"
    18  	"github.com/turbot/steampipe/pkg/statushooks"
    19  	"github.com/turbot/steampipe/pkg/utils"
    20  )
    21  
    22  func getLocalSteampipeConnectionString(opts *CreateDbOptions) (string, error) {
    23  	if opts == nil {
    24  		opts = &CreateDbOptions{}
    25  	}
    26  	utils.LogTime("db.createDbClient start")
    27  	defer utils.LogTime("db.createDbClient end")
    28  
    29  	// load the db status
    30  	info, err := GetState()
    31  	if err != nil {
    32  		return "", err
    33  	}
    34  	if info == nil {
    35  		return "", fmt.Errorf("steampipe service is not running")
    36  	}
    37  	if info.ResolvedListenAddresses == nil {
    38  		return "", fmt.Errorf("steampipe service is in unknown state")
    39  	}
    40  
    41  	// if no database name is passed, use constants.DatabaseUser
    42  	if len(opts.Username) == 0 {
    43  		opts.Username = constants.DatabaseUser
    44  	}
    45  	// if no username name is passed, deduce it from the db status
    46  	if len(opts.DatabaseName) == 0 {
    47  		opts.DatabaseName = info.Database
    48  	}
    49  	// if we still don't have it, fallback to default "postgres"
    50  	if len(opts.DatabaseName) == 0 {
    51  		opts.DatabaseName = "postgres"
    52  	}
    53  
    54  	psqlInfoMap := map[string]string{
    55  		"host":   utils.GetFirstListenAddress(info.ResolvedListenAddresses),
    56  		"port":   fmt.Sprintf("%d", info.Port),
    57  		"user":   opts.Username,
    58  		"dbname": opts.DatabaseName,
    59  	}
    60  	log.Println("[TRACE] SQLInfoMap >>>", psqlInfoMap)
    61  	psqlInfoMap = utils.MergeMaps(psqlInfoMap, dsnSSLParams())
    62  	log.Println("[TRACE] SQLInfoMap >>>", psqlInfoMap)
    63  
    64  	psqlInfo := []string{}
    65  	for k, v := range psqlInfoMap {
    66  		psqlInfo = append(psqlInfo, fmt.Sprintf("%s=%s", k, v))
    67  	}
    68  	log.Println("[TRACE] PSQLInfo >>>", psqlInfo)
    69  
    70  	return strings.Join(psqlInfo, " "), nil
    71  }
    72  
    73  type CreateDbOptions struct {
    74  	DatabaseName, Username string
    75  }
    76  
    77  // CreateLocalDbConnection connects and returns a connection to the given database using
    78  // the provided username
    79  // if the database is not provided (empty), it connects to the default database in the service
    80  // that was created during installation.
    81  // NOTE: this connection will use the ServiceConnectionAppName
    82  func CreateLocalDbConnection(ctx context.Context, opts *CreateDbOptions) (*pgx.Conn, error) {
    83  	utils.LogTime("db.CreateLocalDbConnection start")
    84  	defer utils.LogTime("db.CreateLocalDbConnection end")
    85  
    86  	psqlInfo, err := getLocalSteampipeConnectionString(opts)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	connConfig, err := pgx.ParseConfig(psqlInfo)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	// set an app name so that we can track database connections from this Steampipe execution
    97  	// this is used to determine whether the database can safely be closed
    98  	// and also in pipes to allow accurate usage tracking (it excludes system calls)
    99  	connConfig.Config.RuntimeParams = map[string]string{
   100  		constants.RuntimeParamsKeyApplicationName: runtime.ServiceConnectionAppName,
   101  	}
   102  	err = db_common.AddRootCertToConfig(&connConfig.Config, filepaths.GetRootCertLocation())
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	conn, err := pgx.ConnectConfig(ctx, connConfig)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	if err := db_common.WaitForConnectionPing(ctx, conn); err != nil {
   113  		return nil, err
   114  	}
   115  	return conn, nil
   116  }
   117  
   118  // CreateConnectionPool creates a connection pool using the provided options
   119  // NOTE: this connection pool will use the ServiceConnectionAppName
   120  func CreateConnectionPool(ctx context.Context, opts *CreateDbOptions, maxConnections int) (*pgxpool.Pool, error) {
   121  	utils.LogTime("db_client.establishConnectionPool start")
   122  	defer utils.LogTime("db_client.establishConnectionPool end")
   123  
   124  	psqlInfo, err := getLocalSteampipeConnectionString(opts)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	poolConfig, err := pgxpool.ParseConfig(psqlInfo)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	const (
   135  		connMaxIdleTime = 1 * time.Minute
   136  		connMaxLifetime = 10 * time.Minute
   137  	)
   138  
   139  	poolConfig.MinConns = 0
   140  	poolConfig.MaxConns = int32(maxConnections)
   141  	poolConfig.MaxConnLifetime = connMaxLifetime
   142  	poolConfig.MaxConnIdleTime = connMaxIdleTime
   143  
   144  	poolConfig.ConnConfig.Config.RuntimeParams = map[string]string{
   145  		constants.RuntimeParamsKeyApplicationName: runtime.ServiceConnectionAppName,
   146  	}
   147  
   148  	// this returns connection pool
   149  	dbPool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	err = db_common.WaitForPool(
   155  		ctx,
   156  		dbPool,
   157  		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
   158  		db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second),
   159  	)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	return dbPool, nil
   164  }
   165  
   166  // createMaintenanceClient connects to the postgres server using the
   167  // maintenance database (postgres) and superuser
   168  // this is used in a couple of places
   169  //  1. During installation to setup the DBMS with foreign_server, extension et.al.
   170  //  2. During service start and stop to query the DBMS for parameters (connected clients, database name etc.)
   171  //
   172  // this is called immediately after the service process is started and hence
   173  // all special handling related to service startup failures SHOULD be handled here
   174  func createMaintenanceClient(ctx context.Context, port int) (*pgx.Conn, error) {
   175  	utils.LogTime("db_local.createMaintenanceClient start")
   176  	defer utils.LogTime("db_local.createMaintenanceClient end")
   177  
   178  	connStr := fmt.Sprintf("host=127.0.0.1 port=%d user=%s dbname=postgres sslmode=disable application_name=%s", port, constants.DatabaseSuperUser, runtime.ServiceConnectionAppName)
   179  
   180  	timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second)
   181  	defer cancel()
   182  
   183  	statushooks.SetStatus(ctx, "Waiting for connection")
   184  	conn, err := db_common.WaitForConnection(
   185  		timeoutCtx,
   186  		connStr,
   187  		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
   188  		db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second),
   189  	)
   190  	if err != nil {
   191  		log.Println("[TRACE] could not connect to service")
   192  		return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed"))
   193  	}
   194  
   195  	// wait for db to start accepting queries on this connection
   196  	err = db_common.WaitForConnectionPing(
   197  		timeoutCtx,
   198  		conn,
   199  		db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
   200  		db_common.WithTimeout(viper.GetDuration(constants.ArgDatabaseStartTimeout)*time.Second),
   201  	)
   202  	if err != nil {
   203  		conn.Close(ctx)
   204  		log.Println("[TRACE] Ping timed out")
   205  		return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed"))
   206  	}
   207  
   208  	// wait for recovery to complete
   209  	// the database may enter recovery mode if it detects that
   210  	// it wasn't shutdown gracefully.
   211  	// For large databases, this can take long
   212  	// We want to wait for a LONG time for this to complete
   213  	// Use the context that was given - since that is tied to os.Signal
   214  	// and can be interrupted
   215  	err = db_common.WaitForRecovery(
   216  		ctx,
   217  		conn,
   218  		db_common.WithRetryInterval(constants.DBRecoveryRetryBackoff),
   219  	)
   220  	if err != nil {
   221  		conn.Close(ctx)
   222  		log.Println("[TRACE] WaitForRecovery timed out")
   223  		return nil, sperr.Wrap(err, sperr.WithMessage("could not complete recovery"))
   224  	}
   225  
   226  	return conn, nil
   227  }