github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client.go (about) 1 package db_client 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "strings" 8 "sync" 9 10 "github.com/jackc/pgx/v5" 11 "github.com/jackc/pgx/v5/pgconn" 12 "github.com/jackc/pgx/v5/pgxpool" 13 "github.com/spf13/viper" 14 "github.com/turbot/steampipe/pkg/constants" 15 "github.com/turbot/steampipe/pkg/db/db_common" 16 "github.com/turbot/steampipe/pkg/serversettings" 17 "github.com/turbot/steampipe/pkg/steampipeconfig" 18 "github.com/turbot/steampipe/pkg/utils" 19 "golang.org/x/exp/maps" 20 "golang.org/x/sync/semaphore" 21 ) 22 23 // DbClient wraps over `sql.DB` and gives an interface to the database 24 type DbClient struct { 25 connectionString string 26 27 // connection userPool for user initiated queries 28 userPool *pgxpool.Pool 29 30 // connection used to run system/plumbing queries (connection state, server settings) 31 managementPool *pgxpool.Pool 32 33 // the settings of the server that this client is connected to 34 serverSettings *db_common.ServerSettings 35 36 // this flag is set if the service that this client 37 // is connected to is running in the same physical system 38 isLocalService bool 39 40 // concurrency management for db session access 41 parallelSessionInitLock *semaphore.Weighted 42 43 // map of database sessions, keyed to the backend_pid in postgres 44 // used to update session search path where necessary 45 // TODO: there's no code which cleans up this map when connections get dropped by pgx 46 // https://github.com/turbot/steampipe/issues/3737 47 sessions map[uint32]*db_common.DatabaseSession 48 49 // allows locked access to the 'sessions' map 50 sessionsMutex *sync.Mutex 51 52 // if a custom search path or a prefix is used, store it here 53 customSearchPath []string 54 searchPathPrefix []string 55 // the default user search path 56 userSearchPath []string 57 // disable timing - set whilst in process of querying the timing 58 disableTiming bool 59 onConnectionCallback DbConnectionCallback 60 } 61 62 func NewDbClient(ctx context.Context, connectionString string, onConnectionCallback DbConnectionCallback, opts ...ClientOption) (_ *DbClient, err error) { 63 utils.LogTime("db_client.NewDbClient start") 64 defer utils.LogTime("db_client.NewDbClient end") 65 66 wg := &sync.WaitGroup{} 67 // wrap onConnectionCallback to use wait group 68 var wrappedOnConnectionCallback DbConnectionCallback 69 if onConnectionCallback != nil { 70 wrappedOnConnectionCallback = func(ctx context.Context, conn *pgx.Conn) error { 71 wg.Add(1) 72 defer wg.Done() 73 return onConnectionCallback(ctx, conn) 74 } 75 } 76 77 client := &DbClient{ 78 // a weighted semaphore to control the maximum number parallel 79 // initializations under way 80 parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits), 81 sessions: make(map[uint32]*db_common.DatabaseSession), 82 sessionsMutex: &sync.Mutex{}, 83 // store the callback 84 onConnectionCallback: wrappedOnConnectionCallback, 85 connectionString: connectionString, 86 } 87 88 defer func() { 89 if err != nil { 90 // try closing the client 91 client.Close(ctx) 92 } 93 }() 94 95 config := clientConfig{} 96 for _, o := range opts { 97 o(&config) 98 } 99 100 if err := client.establishConnectionPool(ctx, config); err != nil { 101 return nil, err 102 } 103 104 // load up the server settings 105 if err := client.loadServerSettings(ctx); err != nil { 106 return nil, err 107 } 108 109 // set user search path 110 if err := client.LoadUserSearchPath(ctx); err != nil { 111 return nil, err 112 } 113 114 // populate customSearchPath 115 if err := client.SetRequiredSessionSearchPath(ctx); err != nil { 116 return nil, err 117 } 118 119 return client, nil 120 } 121 122 func (c *DbClient) closePools() { 123 if c.userPool != nil { 124 c.userPool.Close() 125 } 126 if c.managementPool != nil { 127 c.managementPool.Close() 128 } 129 } 130 131 func (c *DbClient) loadServerSettings(ctx context.Context) error { 132 serverSettings, err := serversettings.Load(ctx, c.managementPool) 133 if err != nil { 134 if notFound := db_common.IsRelationNotFoundError(err); notFound { 135 // when connecting to pre-0.21.0 services, the steampipe_server_settings table will not be available. 136 // this is expected and not an error 137 // code which uses steampipe_server_settings should handle this 138 log.Printf("[TRACE] could not find %s.%s table. skipping\n", constants.InternalSchema, constants.ServerSettingsTable) 139 return nil 140 } 141 return err 142 } 143 c.serverSettings = serverSettings 144 log.Println("[TRACE] loaded server settings:", serverSettings) 145 return nil 146 } 147 148 func (c *DbClient) shouldFetchTiming() bool { 149 // check for override flag (this is to prevent timing being fetched when we read the timing metadata table) 150 if c.disableTiming { 151 return false 152 } 153 // only fetch timing if timing flag is set, or output is JSON 154 return (viper.GetString(constants.ArgTiming) != constants.ArgOff) || 155 (viper.GetString(constants.ArgOutput) == constants.OutputFormatJSON) 156 157 } 158 func (c *DbClient) shouldFetchVerboseTiming() bool { 159 return (viper.GetString(constants.ArgTiming) == constants.ArgVerbose) || 160 (viper.GetString(constants.ArgOutput) == constants.OutputFormatJSON) 161 } 162 163 // ServerSettings returns the settings of the steampipe service that this DbClient is connected to 164 // 165 // Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected. 166 // Code which read server_settings should take this into account. 167 func (c *DbClient) ServerSettings() *db_common.ServerSettings { 168 return c.serverSettings 169 } 170 171 // RegisterNotificationListener has an empty implementation 172 // NOTE: we do not (currently) support notifications from remote connections 173 func (c *DbClient) RegisterNotificationListener(func(notification *pgconn.Notification)) {} 174 175 // Close implements Client 176 177 // closes the connection to the database and shuts down the backend 178 func (c *DbClient) Close(context.Context) error { 179 log.Printf("[TRACE] DbClient.Close %v", c.userPool) 180 c.closePools() 181 // nullify active sessions, since with the closing of the pools 182 // none of the sessions will be valid anymore 183 c.sessions = nil 184 185 return nil 186 } 187 188 // GetSchemaFromDB retrieves schemas for all steampipe connections (EXCEPT DISABLED CONNECTIONS) 189 // NOTE: it optimises the schema extraction by extracting schema information for 190 // connections backed by distinct plugins and then fanning back out. 191 func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) { 192 log.Printf("[INFO] DbClient GetSchemaFromDB") 193 mgmtConn, err := c.managementPool.Acquire(ctx) 194 if err != nil { 195 return nil, err 196 } 197 defer mgmtConn.Release() 198 199 // for optimisation purposes, try to load connection state and build a map of schemas to load 200 // (if we are connected to a remote server running an older CLI, 201 // this load may fail, in which case bypass the optimisation) 202 connectionStateMap, err := steampipeconfig.LoadConnectionState(ctx, mgmtConn.Conn(), steampipeconfig.WithWaitUntilLoading()) 203 // NOTE: if we failed to load connection state, this may be because we are connected to an older version of the CLI 204 // use legacy (v0.19.x) schema loading code 205 if err != nil { 206 return c.GetSchemaFromDBLegacy(ctx, mgmtConn) 207 } 208 209 // build a ConnectionSchemaMap object to identify the schemas to load 210 connectionSchemaMap := steampipeconfig.NewConnectionSchemaMap(ctx, connectionStateMap, c.GetRequiredSessionSearchPath()) 211 if err != nil { 212 return nil, err 213 } 214 215 // get the unique schema - we use this to limit the schemas we load from the database 216 schemas := maps.Keys(connectionSchemaMap) 217 218 // build a query to retrieve these schemas 219 query := c.buildSchemasQuery(schemas...) 220 221 // build schema metadata from query result 222 metadata, err := db_common.LoadSchemaMetadata(ctx, mgmtConn.Conn(), query) 223 if err != nil { 224 return nil, err 225 } 226 227 // we now need to add in all other schemas which have the same schemas as those we have loaded 228 for loadedSchema, otherSchemas := range connectionSchemaMap { 229 // all 'otherSchema's have the same schema as loadedSchema 230 exemplarSchema, ok := metadata.Schemas[loadedSchema] 231 if !ok { 232 // should can happen in the case of a dynamic plugin with no tables - use empty schema 233 exemplarSchema = make(map[string]db_common.TableSchema) 234 } 235 236 for _, s := range otherSchemas { 237 metadata.Schemas[s] = exemplarSchema 238 } 239 } 240 241 return metadata, nil 242 } 243 244 func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *pgxpool.Conn) (*db_common.SchemaMetadata, error) { 245 // build a query to retrieve these schemas 246 query := c.buildSchemasQueryLegacy() 247 248 // build schema metadata from query result 249 return db_common.LoadSchemaMetadata(ctx, conn.Conn(), query) 250 } 251 252 // refreshDbClient terminates the current connection and opens up a new connection to the service. 253 func (c *DbClient) ResetPools(ctx context.Context) { 254 log.Println("[TRACE] db_client.ResetPools start") 255 defer log.Println("[TRACE] db_client.ResetPools end") 256 257 c.userPool.Reset() 258 c.managementPool.Reset() 259 } 260 261 func (c *DbClient) buildSchemasQuery(schemas ...string) string { 262 for idx, s := range schemas { 263 schemas[idx] = fmt.Sprintf("'%s'", s) 264 } 265 266 // build the schemas filter clause 267 schemaClause := "" 268 if len(schemas) > 0 { 269 schemaClause = fmt.Sprintf(` 270 cols.table_schema in (%s) 271 OR`, strings.Join(schemas, ",")) 272 } 273 274 query := fmt.Sprintf(` 275 SELECT 276 table_name, 277 column_name, 278 column_default, 279 is_nullable, 280 data_type, 281 udt_name, 282 table_schema, 283 (COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment, 284 (COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment 285 FROM 286 information_schema.columns cols 287 LEFT JOIN 288 pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema 289 LEFT JOIN 290 pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid 291 WHERE %s 292 LEFT(cols.table_schema,8) = 'pg_temp_' 293 `, schemaClause) 294 return query 295 } 296 func (c *DbClient) buildSchemasQueryLegacy() string { 297 298 query := ` 299 WITH distinct_schema AS ( 300 SELECT DISTINCT(foreign_table_schema) 301 FROM 302 information_schema.foreign_tables 303 WHERE 304 foreign_table_schema <> 'steampipe_command' 305 ) 306 SELECT 307 table_name, 308 column_name, 309 column_default, 310 is_nullable, 311 data_type, 312 udt_name, 313 table_schema, 314 (COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment, 315 (COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment 316 FROM 317 information_schema.columns cols 318 LEFT JOIN 319 pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema 320 LEFT JOIN 321 pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid 322 WHERE 323 cols.table_schema in (select * from distinct_schema) 324 OR 325 LEFT(cols.table_schema,8) = 'pg_temp_' 326 327 ` 328 return query 329 }