github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client_search_path.go (about)

     1  package db_client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"strings"
     8  
     9  	"github.com/jackc/pgx/v5"
    10  	"github.com/spf13/viper"
    11  	"github.com/turbot/go-kit/helpers"
    12  	"github.com/turbot/steampipe/pkg/constants"
    13  	"github.com/turbot/steampipe/pkg/db/db_common"
    14  )
    15  
    16  // SetRequiredSessionSearchPath implements Client
    17  // if either a search-path or search-path-prefix is set in config, set the search path
    18  // (otherwise fall back to user search path)
    19  // this just sets the required search path for this client
    20  // - when creating a database session, we will actually set the searchPath
    21  func (c *DbClient) SetRequiredSessionSearchPath(ctx context.Context) error {
    22  	configuredSearchPath := viper.GetStringSlice(constants.ArgSearchPath)
    23  	searchPathPrefix := viper.GetStringSlice(constants.ArgSearchPathPrefix)
    24  
    25  	// strip empty elements from search path and prefix
    26  	configuredSearchPath = helpers.RemoveFromStringSlice(configuredSearchPath, "")
    27  	searchPathPrefix = helpers.RemoveFromStringSlice(searchPathPrefix, "")
    28  
    29  	// default required path to user search path
    30  	requiredSearchPath := c.userSearchPath
    31  
    32  	// store custom search path and search path prefix
    33  	c.searchPathPrefix = searchPathPrefix
    34  
    35  	// if a search path was passed, use that
    36  	if len(configuredSearchPath) > 0 {
    37  		requiredSearchPath = configuredSearchPath
    38  	}
    39  
    40  	// add in the prefix if present
    41  	requiredSearchPath = db_common.AddSearchPathPrefix(searchPathPrefix, requiredSearchPath)
    42  
    43  	requiredSearchPath = db_common.EnsureInternalSchemaSuffix(requiredSearchPath)
    44  
    45  	// if either configuredSearchPath or searchPathPrefix are set, store requiredSearchPath as customSearchPath
    46  	if len(configuredSearchPath)+len(searchPathPrefix) > 0 {
    47  		c.customSearchPath = requiredSearchPath
    48  	} else {
    49  		// otherwise clear it
    50  		c.customSearchPath = nil
    51  	}
    52  
    53  	return nil
    54  }
    55  
    56  func (c *DbClient) LoadUserSearchPath(ctx context.Context) error {
    57  	conn, err := c.managementPool.Acquire(ctx)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	defer conn.Release()
    62  	return c.loadUserSearchPath(ctx, conn.Conn())
    63  }
    64  
    65  func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn) error {
    66  	// load the user search path
    67  	userSearchPath, err := db_common.GetUserSearchPath(ctx, connection)
    68  	if err != nil {
    69  		return err
    70  	}
    71  	// update the cached value
    72  	c.userSearchPath = userSearchPath
    73  	return nil
    74  }
    75  
    76  // GetRequiredSessionSearchPath implements Client
    77  func (c *DbClient) GetRequiredSessionSearchPath() []string {
    78  	if c.customSearchPath != nil {
    79  		return c.customSearchPath
    80  	}
    81  
    82  	return c.userSearchPath
    83  }
    84  
    85  func (c *DbClient) GetCustomSearchPath() []string {
    86  	return c.customSearchPath
    87  }
    88  
    89  // ensure the search path for the database session is as required
    90  func (c *DbClient) ensureSessionSearchPath(ctx context.Context, session *db_common.DatabaseSession) error {
    91  	log.Printf("[TRACE] ensureSessionSearchPath")
    92  
    93  	// update the stored value of user search path
    94  	// this might have changed if a connection has been added/removed
    95  	if err := c.loadUserSearchPath(ctx, session.Connection.Conn()); err != nil {
    96  		return err
    97  	}
    98  
    99  	// get the required search path which is either a custom search path (if present) or the user search path
   100  	requiredSearchPath := c.GetRequiredSessionSearchPath()
   101  
   102  	// now determine whether the session search path is the same as the required search path
   103  	// if so, return
   104  	if strings.Join(session.SearchPath, ",") == strings.Join(requiredSearchPath, ",") {
   105  		log.Printf("[TRACE] session search path is already correct - nothing to do")
   106  		return nil
   107  	}
   108  
   109  	// so we need to set the search path
   110  	log.Printf("[TRACE] session search path will be updated to  %s", strings.Join(c.customSearchPath, ","))
   111  
   112  	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
   113  		_, err := tx.Exec(ctx, fmt.Sprintf("set search_path to %s", strings.Join(db_common.PgEscapeSearchPath(requiredSearchPath), ",")))
   114  		return err
   115  	})
   116  
   117  	if err == nil {
   118  		// update the session search path property
   119  		session.SearchPath = requiredSearchPath
   120  	}
   121  	return err
   122  }