github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client_search_path.go (about) 1 package db_client 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "strings" 8 9 "github.com/jackc/pgx/v5" 10 "github.com/spf13/viper" 11 "github.com/turbot/go-kit/helpers" 12 "github.com/turbot/steampipe/pkg/constants" 13 "github.com/turbot/steampipe/pkg/db/db_common" 14 ) 15 16 // SetRequiredSessionSearchPath implements Client 17 // if either a search-path or search-path-prefix is set in config, set the search path 18 // (otherwise fall back to user search path) 19 // this just sets the required search path for this client 20 // - when creating a database session, we will actually set the searchPath 21 func (c *DbClient) SetRequiredSessionSearchPath(ctx context.Context) error { 22 configuredSearchPath := viper.GetStringSlice(constants.ArgSearchPath) 23 searchPathPrefix := viper.GetStringSlice(constants.ArgSearchPathPrefix) 24 25 // strip empty elements from search path and prefix 26 configuredSearchPath = helpers.RemoveFromStringSlice(configuredSearchPath, "") 27 searchPathPrefix = helpers.RemoveFromStringSlice(searchPathPrefix, "") 28 29 // default required path to user search path 30 requiredSearchPath := c.userSearchPath 31 32 // store custom search path and search path prefix 33 c.searchPathPrefix = searchPathPrefix 34 35 // if a search path was passed, use that 36 if len(configuredSearchPath) > 0 { 37 requiredSearchPath = configuredSearchPath 38 } 39 40 // add in the prefix if present 41 requiredSearchPath = db_common.AddSearchPathPrefix(searchPathPrefix, requiredSearchPath) 42 43 requiredSearchPath = db_common.EnsureInternalSchemaSuffix(requiredSearchPath) 44 45 // if either configuredSearchPath or searchPathPrefix are set, store requiredSearchPath as customSearchPath 46 if len(configuredSearchPath)+len(searchPathPrefix) > 0 { 47 c.customSearchPath = requiredSearchPath 48 } else { 49 // otherwise clear it 50 c.customSearchPath = nil 51 } 52 53 return nil 54 } 55 56 func (c *DbClient) LoadUserSearchPath(ctx context.Context) error { 57 conn, err := c.managementPool.Acquire(ctx) 58 if err != nil { 59 return err 60 } 61 defer conn.Release() 62 return c.loadUserSearchPath(ctx, conn.Conn()) 63 } 64 65 func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn) error { 66 // load the user search path 67 userSearchPath, err := db_common.GetUserSearchPath(ctx, connection) 68 if err != nil { 69 return err 70 } 71 // update the cached value 72 c.userSearchPath = userSearchPath 73 return nil 74 } 75 76 // GetRequiredSessionSearchPath implements Client 77 func (c *DbClient) GetRequiredSessionSearchPath() []string { 78 if c.customSearchPath != nil { 79 return c.customSearchPath 80 } 81 82 return c.userSearchPath 83 } 84 85 func (c *DbClient) GetCustomSearchPath() []string { 86 return c.customSearchPath 87 } 88 89 // ensure the search path for the database session is as required 90 func (c *DbClient) ensureSessionSearchPath(ctx context.Context, session *db_common.DatabaseSession) error { 91 log.Printf("[TRACE] ensureSessionSearchPath") 92 93 // update the stored value of user search path 94 // this might have changed if a connection has been added/removed 95 if err := c.loadUserSearchPath(ctx, session.Connection.Conn()); err != nil { 96 return err 97 } 98 99 // get the required search path which is either a custom search path (if present) or the user search path 100 requiredSearchPath := c.GetRequiredSessionSearchPath() 101 102 // now determine whether the session search path is the same as the required search path 103 // if so, return 104 if strings.Join(session.SearchPath, ",") == strings.Join(requiredSearchPath, ",") { 105 log.Printf("[TRACE] session search path is already correct - nothing to do") 106 return nil 107 } 108 109 // so we need to set the search path 110 log.Printf("[TRACE] session search path will be updated to %s", strings.Join(c.customSearchPath, ",")) 111 112 err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error { 113 _, err := tx.Exec(ctx, fmt.Sprintf("set search_path to %s", strings.Join(db_common.PgEscapeSearchPath(requiredSearchPath), ","))) 114 return err 115 }) 116 117 if err == nil { 118 // update the session search path property 119 session.SearchPath = requiredSearchPath 120 } 121 return err 122 }