github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client.go (about)

     1  package db_client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/jackc/pgx/v5"
    11  	"github.com/jackc/pgx/v5/pgconn"
    12  	"github.com/jackc/pgx/v5/pgxpool"
    13  	"github.com/spf13/viper"
    14  	"github.com/turbot/steampipe/pkg/constants"
    15  	"github.com/turbot/steampipe/pkg/db/db_common"
    16  	"github.com/turbot/steampipe/pkg/serversettings"
    17  	"github.com/turbot/steampipe/pkg/steampipeconfig"
    18  	"github.com/turbot/steampipe/pkg/utils"
    19  	"golang.org/x/exp/maps"
    20  	"golang.org/x/sync/semaphore"
    21  )
    22  
// DbClient is a client for a steampipe database service.
// It holds two pgx connection pools (one for user initiated queries, one for
// system/plumbing queries) and tracks the database sessions and search path
// state associated with them.
type DbClient struct {
	// the connection string passed to NewDbClient
	connectionString string

	// connection pool for user initiated queries
	userPool *pgxpool.Pool

	// connection pool used to run system/plumbing queries (connection state, server settings)
	managementPool *pgxpool.Pool

	// the settings of the server that this client is connected to
	// (left nil when the server does not expose the server settings table, e.g. pre-0.21.0 services)
	serverSettings *db_common.ServerSettings

	// this flag is set if the service that this client
	// is connected to is running in the same physical system
	isLocalService bool

	// weighted semaphore bounding the number of parallel session initializations
	// (created in NewDbClient with weight constants.MaxParallelClientInits)
	parallelSessionInitLock *semaphore.Weighted

	// map of database sessions, keyed to the backend_pid in postgres
	// used to update session search path where necessary
	// set to nil by Close, since no session is valid once the pools are closed
	// TODO: there's no code which cleans up this map when connections get dropped by pgx
	// https://github.com/turbot/steampipe/issues/3737
	sessions map[uint32]*db_common.DatabaseSession

	// guards access to the 'sessions' map
	sessionsMutex *sync.Mutex

	// if a custom search path or a prefix is used, store it here
	customSearchPath []string
	searchPathPrefix []string
	// the default user search path
	userSearchPath []string
	// disableTiming: set whilst in process of querying the timing
	// (shouldFetchTiming returns false while it is set).
	// onConnectionCallback: the callback passed to NewDbClient, wrapped by NewDbClient
	// (NOTE(review): presumably invoked for each new pool connection - confirm in establishConnectionPool).
	disableTiming        bool
	onConnectionCallback DbConnectionCallback
}
    61  
    62  func NewDbClient(ctx context.Context, connectionString string, onConnectionCallback DbConnectionCallback, opts ...ClientOption) (_ *DbClient, err error) {
    63  	utils.LogTime("db_client.NewDbClient start")
    64  	defer utils.LogTime("db_client.NewDbClient end")
    65  
    66  	wg := &sync.WaitGroup{}
    67  	// wrap onConnectionCallback to use wait group
    68  	var wrappedOnConnectionCallback DbConnectionCallback
    69  	if onConnectionCallback != nil {
    70  		wrappedOnConnectionCallback = func(ctx context.Context, conn *pgx.Conn) error {
    71  			wg.Add(1)
    72  			defer wg.Done()
    73  			return onConnectionCallback(ctx, conn)
    74  		}
    75  	}
    76  
    77  	client := &DbClient{
    78  		// a weighted semaphore to control the maximum number parallel
    79  		// initializations under way
    80  		parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
    81  		sessions:                make(map[uint32]*db_common.DatabaseSession),
    82  		sessionsMutex:           &sync.Mutex{},
    83  		// store the callback
    84  		onConnectionCallback: wrappedOnConnectionCallback,
    85  		connectionString:     connectionString,
    86  	}
    87  
    88  	defer func() {
    89  		if err != nil {
    90  			// try closing the client
    91  			client.Close(ctx)
    92  		}
    93  	}()
    94  
    95  	config := clientConfig{}
    96  	for _, o := range opts {
    97  		o(&config)
    98  	}
    99  
   100  	if err := client.establishConnectionPool(ctx, config); err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	// load up the server settings
   105  	if err := client.loadServerSettings(ctx); err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	// set user search path
   110  	if err := client.LoadUserSearchPath(ctx); err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	// populate customSearchPath
   115  	if err := client.SetRequiredSessionSearchPath(ctx); err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return client, nil
   120  }
   121  
   122  func (c *DbClient) closePools() {
   123  	if c.userPool != nil {
   124  		c.userPool.Close()
   125  	}
   126  	if c.managementPool != nil {
   127  		c.managementPool.Close()
   128  	}
   129  }
   130  
   131  func (c *DbClient) loadServerSettings(ctx context.Context) error {
   132  	serverSettings, err := serversettings.Load(ctx, c.managementPool)
   133  	if err != nil {
   134  		if notFound := db_common.IsRelationNotFoundError(err); notFound {
   135  			// when connecting to pre-0.21.0 services, the steampipe_server_settings table will not be available.
   136  			// this is expected and not an error
   137  			// code which uses steampipe_server_settings should handle this
   138  			log.Printf("[TRACE] could not find %s.%s table. skipping\n", constants.InternalSchema, constants.ServerSettingsTable)
   139  			return nil
   140  		}
   141  		return err
   142  	}
   143  	c.serverSettings = serverSettings
   144  	log.Println("[TRACE] loaded server settings:", serverSettings)
   145  	return nil
   146  }
   147  
   148  func (c *DbClient) shouldFetchTiming() bool {
   149  	// check for override flag (this is to prevent timing being fetched when we read the timing metadata table)
   150  	if c.disableTiming {
   151  		return false
   152  	}
   153  	// only fetch timing if timing flag is set, or output is JSON
   154  	return (viper.GetString(constants.ArgTiming) != constants.ArgOff) ||
   155  		(viper.GetString(constants.ArgOutput) == constants.OutputFormatJSON)
   156  
   157  }
   158  func (c *DbClient) shouldFetchVerboseTiming() bool {
   159  	return (viper.GetString(constants.ArgTiming) == constants.ArgVerbose) ||
   160  		(viper.GetString(constants.ArgOutput) == constants.OutputFormatJSON)
   161  }
   162  
   163  // ServerSettings returns the settings of the steampipe service that this DbClient is connected to
   164  //
   165  // Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.
   166  // Code which read server_settings should take this into account.
   167  func (c *DbClient) ServerSettings() *db_common.ServerSettings {
   168  	return c.serverSettings
   169  }
   170  
   171  // RegisterNotificationListener has an empty implementation
   172  // NOTE: we do not (currently) support notifications from remote connections
   173  func (c *DbClient) RegisterNotificationListener(func(notification *pgconn.Notification)) {}
   174  
   175  // Close implements Client
   176  
   177  // closes the connection to the database and shuts down the backend
   178  func (c *DbClient) Close(context.Context) error {
   179  	log.Printf("[TRACE] DbClient.Close %v", c.userPool)
   180  	c.closePools()
   181  	// nullify active sessions, since with the closing of the pools
   182  	// none of the sessions will be valid anymore
   183  	c.sessions = nil
   184  
   185  	return nil
   186  }
   187  
   188  // GetSchemaFromDB  retrieves schemas for all steampipe connections (EXCEPT DISABLED CONNECTIONS)
   189  // NOTE: it optimises the schema extraction by extracting schema information for
   190  // connections backed by distinct plugins and then fanning back out.
   191  func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
   192  	log.Printf("[INFO] DbClient GetSchemaFromDB")
   193  	mgmtConn, err := c.managementPool.Acquire(ctx)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	defer mgmtConn.Release()
   198  
   199  	// for optimisation purposes, try to load connection state and build a map of schemas to load
   200  	// (if we are connected to a remote server running an older CLI,
   201  	// this load may fail, in which case bypass the optimisation)
   202  	connectionStateMap, err := steampipeconfig.LoadConnectionState(ctx, mgmtConn.Conn(), steampipeconfig.WithWaitUntilLoading())
   203  	// NOTE: if we failed to load connection state, this may be because we are connected to an older version of the CLI
   204  	// use legacy (v0.19.x) schema loading code
   205  	if err != nil {
   206  		return c.GetSchemaFromDBLegacy(ctx, mgmtConn)
   207  	}
   208  
   209  	// build a ConnectionSchemaMap object to identify the schemas to load
   210  	connectionSchemaMap := steampipeconfig.NewConnectionSchemaMap(ctx, connectionStateMap, c.GetRequiredSessionSearchPath())
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	// get the unique schema - we use this to limit the schemas we load from the database
   216  	schemas := maps.Keys(connectionSchemaMap)
   217  
   218  	// build a query to retrieve these schemas
   219  	query := c.buildSchemasQuery(schemas...)
   220  
   221  	// build schema metadata from query result
   222  	metadata, err := db_common.LoadSchemaMetadata(ctx, mgmtConn.Conn(), query)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	// we now need to add in all other schemas which have the same schemas as those we have loaded
   228  	for loadedSchema, otherSchemas := range connectionSchemaMap {
   229  		// all 'otherSchema's have the same schema as loadedSchema
   230  		exemplarSchema, ok := metadata.Schemas[loadedSchema]
   231  		if !ok {
   232  			// should can happen in the case of a dynamic plugin with no tables - use empty schema
   233  			exemplarSchema = make(map[string]db_common.TableSchema)
   234  		}
   235  
   236  		for _, s := range otherSchemas {
   237  			metadata.Schemas[s] = exemplarSchema
   238  		}
   239  	}
   240  
   241  	return metadata, nil
   242  }
   243  
   244  func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *pgxpool.Conn) (*db_common.SchemaMetadata, error) {
   245  	// build a query to retrieve these schemas
   246  	query := c.buildSchemasQueryLegacy()
   247  
   248  	// build schema metadata from query result
   249  	return db_common.LoadSchemaMetadata(ctx, conn.Conn(), query)
   250  }
   251  
   252  // refreshDbClient terminates the current connection and opens up a new connection to the service.
   253  func (c *DbClient) ResetPools(ctx context.Context) {
   254  	log.Println("[TRACE] db_client.ResetPools start")
   255  	defer log.Println("[TRACE] db_client.ResetPools end")
   256  
   257  	c.userPool.Reset()
   258  	c.managementPool.Reset()
   259  }
   260  
   261  func (c *DbClient) buildSchemasQuery(schemas ...string) string {
   262  	for idx, s := range schemas {
   263  		schemas[idx] = fmt.Sprintf("'%s'", s)
   264  	}
   265  
   266  	// build the schemas filter clause
   267  	schemaClause := ""
   268  	if len(schemas) > 0 {
   269  		schemaClause = fmt.Sprintf(`
   270      cols.table_schema in (%s)
   271  	OR`, strings.Join(schemas, ","))
   272  	}
   273  
   274  	query := fmt.Sprintf(`
   275  SELECT
   276  		table_name,
   277  		column_name,
   278  		column_default,
   279  		is_nullable,
   280  		data_type,
   281  		udt_name,
   282  		table_schema,
   283  		(COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
   284  		(COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
   285  FROM
   286      information_schema.columns cols
   287  LEFT JOIN
   288      pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
   289  LEFT JOIN
   290      pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
   291  WHERE %s
   292  	LEFT(cols.table_schema,8) = 'pg_temp_'
   293  `, schemaClause)
   294  	return query
   295  }
   296  func (c *DbClient) buildSchemasQueryLegacy() string {
   297  
   298  	query := `
   299  WITH distinct_schema AS (
   300  	SELECT DISTINCT(foreign_table_schema) 
   301  	FROM 
   302  		information_schema.foreign_tables 
   303  	WHERE 
   304  		foreign_table_schema <> 'steampipe_command'
   305  )
   306  SELECT
   307      table_name,
   308      column_name,
   309      column_default,
   310      is_nullable,
   311      data_type,
   312      udt_name,
   313      table_schema,
   314      (COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
   315      (COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
   316  FROM
   317      information_schema.columns cols
   318  LEFT JOIN
   319      pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
   320  LEFT JOIN
   321      pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
   322  WHERE
   323  	cols.table_schema in (select * from distinct_schema)
   324  	OR
   325      LEFT(cols.table_schema,8) = 'pg_temp_'
   326  
   327  `
   328  	return query
   329  }