github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/create_connection.go (about) 1 package db_local 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "strings" 8 "time" 9 10 "github.com/jackc/pgx/v5" 11 "github.com/jackc/pgx/v5/pgxpool" 12 "github.com/spf13/viper" 13 "github.com/turbot/steampipe-plugin-sdk/v5/sperr" 14 "github.com/turbot/steampipe/pkg/constants" 15 "github.com/turbot/steampipe/pkg/constants/runtime" 16 "github.com/turbot/steampipe/pkg/db/db_common" 17 "github.com/turbot/steampipe/pkg/filepaths" 18 "github.com/turbot/steampipe/pkg/statushooks" 19 "github.com/turbot/steampipe/pkg/utils" 20 ) 21 22 func getLocalSteampipeConnectionString(opts *CreateDbOptions) (string, error) { 23 if opts == nil { 24 opts = &CreateDbOptions{} 25 } 26 utils.LogTime("db.createDbClient start") 27 defer utils.LogTime("db.createDbClient end") 28 29 // load the db status 30 info, err := GetState() 31 if err != nil { 32 return "", err 33 } 34 if info == nil { 35 return "", fmt.Errorf("steampipe service is not running") 36 } 37 if info.ResolvedListenAddresses == nil { 38 return "", fmt.Errorf("steampipe service is in unknown state") 39 } 40 41 // if no database name is passed, use constants.DatabaseUser 42 if len(opts.Username) == 0 { 43 opts.Username = constants.DatabaseUser 44 } 45 // if no username name is passed, deduce it from the db status 46 if len(opts.DatabaseName) == 0 { 47 opts.DatabaseName = info.Database 48 } 49 // if we still don't have it, fallback to default "postgres" 50 if len(opts.DatabaseName) == 0 { 51 opts.DatabaseName = "postgres" 52 } 53 54 psqlInfoMap := map[string]string{ 55 "host": utils.GetFirstListenAddress(info.ResolvedListenAddresses), 56 "port": fmt.Sprintf("%d", info.Port), 57 "user": opts.Username, 58 "dbname": opts.DatabaseName, 59 } 60 log.Println("[TRACE] SQLInfoMap >>>", psqlInfoMap) 61 psqlInfoMap = utils.MergeMaps(psqlInfoMap, dsnSSLParams()) 62 log.Println("[TRACE] SQLInfoMap >>>", psqlInfoMap) 63 64 psqlInfo := []string{} 65 for k, v := range psqlInfoMap { 66 psqlInfo = append(psqlInfo, fmt.Sprintf("%s=%s", k, v)) 67 } 68 log.Println("[TRACE] PSQLInfo >>>", psqlInfo) 69 70 return strings.Join(psqlInfo, " "), nil 71 } 72 73 type CreateDbOptions struct { 74 DatabaseName, Username string 75 } 76 77 // CreateLocalDbConnection connects and returns a connection to the given database using 78 // the provided username 79 // if the database is not provided (empty), it connects to the default database in the service 80 // that was created during installation. 81 // NOTE: this connection will use the ServiceConnectionAppName 82 func CreateLocalDbConnection(ctx context.Context, opts *CreateDbOptions) (*pgx.Conn, error) { 83 utils.LogTime("db.CreateLocalDbConnection start") 84 defer utils.LogTime("db.CreateLocalDbConnection end") 85 86 psqlInfo, err := getLocalSteampipeConnectionString(opts) 87 if err != nil { 88 return nil, err 89 } 90 91 connConfig, err := pgx.ParseConfig(psqlInfo) 92 if err != nil { 93 return nil, err 94 } 95 96 // set an app name so that we can track database connections from this Steampipe execution 97 // this is used to determine whether the database can safely be closed 98 // and also in pipes to allow accurate usage tracking (it excludes system calls) 99 connConfig.Config.RuntimeParams = map[string]string{ 100 constants.RuntimeParamsKeyApplicationName: runtime.ServiceConnectionAppName, 101 } 102 err = db_common.AddRootCertToConfig(&connConfig.Config, filepaths.GetRootCertLocation()) 103 if err != nil { 104 return nil, err 105 } 106 107 conn, err := pgx.ConnectConfig(ctx, connConfig) 108 if err != nil { 109 return nil, err 110 } 111 112 if err := db_common.WaitForConnectionPing(ctx, conn); err != nil { 113 return nil, err 114 } 115 return conn, nil 116 } 117 118 // CreateConnectionPool creates a connection pool using the provided options 119 // NOTE: this connection pool will use the ServiceConnectionAppName 120 func CreateConnectionPool(ctx context.Context, opts *CreateDbOptions, maxConnections int) (*pgxpool.Pool, error) { 121 utils.LogTime("db_client.establishConnectionPool start") 122 defer utils.LogTime("db_client.establishConnectionPool end") 123 124 psqlInfo, err := getLocalSteampipeConnectionString(opts) 125 if err != nil { 126 return nil, err 127 } 128 129 poolConfig, err := pgxpool.ParseConfig(psqlInfo) 130 if err != nil { 131 return nil, err 132 } 133 134 const ( 135 connMaxIdleTime = 1 * time.Minute 136 connMaxLifetime = 10 * time.Minute 137 ) 138 139 poolConfig.MinConns = 0 140 poolConfig.MaxConns = int32(maxConnections) 141 poolConfig.MaxConnLifetime = connMaxLifetime 142 poolConfig.MaxConnIdleTime = connMaxIdleTime 143 144 poolConfig.ConnConfig.Config.RuntimeParams = map[string]string{ 145 constants.RuntimeParamsKeyApplicationName: runtime.ServiceConnectionAppName, 146 } 147 148 // this returns connection pool 149 dbPool, err := pgxpool.NewWithConfig(context.Background(), poolConfig) 150 if err != nil { 151 return nil, err 152 } 153 154 err = db_common.WaitForPool( 155 ctx, 156 dbPool, 157 db_common.WithRetryInterval(constants.DBConnectionRetryBackoff), 158 db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second), 159 ) 160 if err != nil { 161 return nil, err 162 } 163 return dbPool, nil 164 } 165 166 // createMaintenanceClient connects to the postgres server using the 167 // maintenance database (postgres) and superuser 168 // this is used in a couple of places 169 // 1. During installation to setup the DBMS with foreign_server, extension et.al. 170 // 2. During service start and stop to query the DBMS for parameters (connected clients, database name etc.) 171 // 172 // this is called immediately after the service process is started and hence 173 // all special handling related to service startup failures SHOULD be handled here 174 func createMaintenanceClient(ctx context.Context, port int) (*pgx.Conn, error) { 175 utils.LogTime("db_local.createMaintenanceClient start") 176 defer utils.LogTime("db_local.createMaintenanceClient end") 177 178 connStr := fmt.Sprintf("host=127.0.0.1 port=%d user=%s dbname=postgres sslmode=disable application_name=%s", port, constants.DatabaseSuperUser, runtime.ServiceConnectionAppName) 179 180 timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second) 181 defer cancel() 182 183 statushooks.SetStatus(ctx, "Waiting for connection") 184 conn, err := db_common.WaitForConnection( 185 timeoutCtx, 186 connStr, 187 db_common.WithRetryInterval(constants.DBConnectionRetryBackoff), 188 db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second), 189 ) 190 if err != nil { 191 log.Println("[TRACE] could not connect to service") 192 return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed")) 193 } 194 195 // wait for db to start accepting queries on this connection 196 err = db_common.WaitForConnectionPing( 197 timeoutCtx, 198 conn, 199 db_common.WithRetryInterval(constants.DBConnectionRetryBackoff), 200 db_common.WithTimeout(viper.GetDuration(constants.ArgDatabaseStartTimeout)*time.Second), 201 ) 202 if err != nil { 203 conn.Close(ctx) 204 log.Println("[TRACE] Ping timed out") 205 return nil, sperr.Wrap(err, sperr.WithMessage("connection setup failed")) 206 } 207 208 // wait for recovery to complete 209 // the database may enter recovery mode if it detects that 210 // it wasn't shutdown gracefully. 211 // For large databases, this can take long 212 // We want to wait for a LONG time for this to complete 213 // Use the context that was given - since that is tied to os.Signal 214 // and can be interrupted 215 err = db_common.WaitForRecovery( 216 ctx, 217 conn, 218 db_common.WithRetryInterval(constants.DBRecoveryRetryBackoff), 219 ) 220 if err != nil { 221 conn.Close(ctx) 222 log.Println("[TRACE] WaitForRecovery timed out") 223 return nil, sperr.Wrap(err, sperr.WithMessage("could not complete recovery")) 224 } 225 226 return conn, nil 227 }