github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client_execute.go

package db_client

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"net/netip"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/spf13/viper"
	"github.com/turbot/go-kit/helpers"
	"github.com/turbot/steampipe/pkg/constants"
	"github.com/turbot/steampipe/pkg/db/db_common"
	"github.com/turbot/steampipe/pkg/error_helpers"
	"github.com/turbot/steampipe/pkg/query/queryresult"
	"github.com/turbot/steampipe/pkg/statushooks"
	"github.com/turbot/steampipe/pkg/utils"
	"golang.org/x/text/language"
	"golang.org/x/text/message"
)

// ExecuteSync implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSync(ctx context.Context, query string, args ...any) (*queryresult.SyncQueryResult, error) {
	// acquire a session
	sessionResult := c.AcquireSession(ctx)
	if sessionResult.Error != nil {
		return nil, sessionResult.Error
	}

	defer func() {
		// we need to do this in a closure, otherwise ctx would be evaluated immediately
		// rather than at call time
		sessionResult.Session.Close(error_helpers.IsContextCanceled(ctx))
	}()
	return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, args...)
}

// ExecuteSyncInSession implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*queryresult.SyncQueryResult, error) {
	if query == "" {
		return &queryresult.SyncQueryResult{}, nil
	}

	result, err := c.ExecuteInSession(ctx, session, nil, query, args...)
	if err != nil {
		return nil, error_helpers.WrapError(err)
	}

	syncResult := &queryresult.SyncQueryResult{Cols: result.Cols}
	for row := range *result.RowChan {
		select {
		case <-ctx.Done():
		default:
			// save the first row error to return
			if row.Error != nil && err == nil {
				err = error_helpers.WrapError(row.Error)
			}
			syncResult.Rows = append(syncResult.Rows, row)
		}
	}
	if c.shouldFetchTiming() {
		syncResult.TimingResult = <-result.TimingResult
	}

	return syncResult, err
}

// Execute implements Client
// execute the query in the given Context
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
	// acquire a session
	sessionResult := c.AcquireSession(ctx)
	if sessionResult.Error != nil {
		return nil, sessionResult.Error
	}

	// define callback to close session when the async execution is complete
	closeSessionCallback := func() { sessionResult.Session.Close(error_helpers.IsContextCanceled(ctx)) }
	return c.ExecuteInSession(ctx, sessionResult.Session, closeSessionCallback, query, args...)
}
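// drainResultExample is an illustrative sketch, not part of the original file: it shows one way
// a caller of Execute might satisfy the NOTE above that the returned Result MUST be fully read.
// It only uses fields and methods already referenced in this file (RowChan, TimingResult,
// shouldFetchTiming); the per-row handling is left as a stub.
func drainResultExample(ctx context.Context, c *DbClient, query string) error {
	result, err := c.Execute(ctx, query)
	if err != nil {
		return err
	}
	// read every row so the underlying connection is released
	for row := range *result.RowChan {
		if row.Error != nil {
			// record the error but keep draining, as ExecuteSyncInSession does
			if err == nil {
				err = row.Error
			}
			continue
		}
		// process the row here
	}
	// drain the timing channel, mirroring ExecuteSyncInSession
	if c.shouldFetchTiming() {
		<-result.TimingResult
	}
	return err
}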
// ExecuteInSession implements Client
// execute the query in the given Context using the provided DatabaseSession
// ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onComplete func(), query string, args ...any) (res *queryresult.Result, err error) {
	if query == "" {
		return queryresult.NewResult(nil), nil
	}

	// fail-safes
	if session == nil {
		return nil, fmt.Errorf("nil session passed to ExecuteInSession")
	}
	if session.Connection == nil {
		return nil, fmt.Errorf("nil database connection passed to ExecuteInSession")
	}
	startTime := time.Now()
	// get a context with a timeout for the query to execute within
	// we don't use the cancelFn from this timeout context, since usage will lead to 'pgx'
	// prematurely closing the database connection that this query executed in
	ctxExecute := c.getExecuteContext(ctx)

	var tx *sql.Tx

	defer func() {
		if err != nil {
			err = error_helpers.HandleQueryTimeoutError(err)
			// stop spinner in case of error
			statushooks.Done(ctxExecute)
			// error - rollback transaction if we have one
			if tx != nil {
				_ = tx.Rollback()
			}
			// in case of error call the onComplete callback
			if onComplete != nil {
				onComplete()
			}
		}
	}()

	// start query
	var rows pgx.Rows
	rows, err = c.startQueryWithRetries(ctxExecute, session, query, args...)
	if err != nil {
		return
	}

	colDefs := fieldDescriptionsToColumns(rows.FieldDescriptions(), session.Connection.Conn())

	result := queryresult.NewResult(colDefs)

	// read the rows in a go routine
	go func() {
		// define a callback which fetches the timing information
		// this will be invoked after reading rows is complete but BEFORE closing the rows object (which closes the connection)
		timingCallback := func() {
			c.getQueryTiming(ctxExecute, startTime, session, result.TimingResult)
		}

		// read in the rows and stream to the query result object
		c.readRows(ctxExecute, rows, result, timingCallback)

		// call the completion callback - if one was provided
		if onComplete != nil {
			onComplete()
		}
	}()

	return result, nil
}
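// executeInSessionExample is an illustrative sketch, not part of the original file: because
// ExecuteInSession takes no responsibility for the DatabaseSession lifecycle, a caller must
// acquire the session itself and arrange for it to be closed - here via the onComplete
// callback, which is the same pattern Execute uses above. Any other per-query cleanup could
// be hooked into the same callback.
func executeInSessionExample(ctx context.Context, c *DbClient, query string) (*queryresult.Result, error) {
	sessionResult := c.AcquireSession(ctx)
	if sessionResult.Error != nil {
		return nil, sessionResult.Error
	}
	// the session stays open until the asynchronous result has been fully read (or an error
	// occurs), at which point ExecuteInSession invokes this callback
	onComplete := func() { sessionResult.Session.Close(error_helpers.IsContextCanceled(ctx)) }
	return c.ExecuteInSession(ctx, sessionResult.Session, onComplete, query)
}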
func (c *DbClient) getExecuteContext(ctx context.Context) context.Context {
	queryTimeout := time.Duration(viper.GetInt(constants.ArgDatabaseQueryTimeout)) * time.Second
	// if timeout is zero, do not set a timeout
	if queryTimeout == 0 {
		return ctx
	}
	// create a context with a deadline
	shouldBeDoneBy := time.Now().Add(queryTimeout)
	//nolint:golint,lostcancel // we don't use this cancel fn because pgx prematurely cancels the PG connection when this cancel gets called in 'defer'
	newCtx, _ := context.WithDeadline(ctx, shouldBeDoneBy)

	return newCtx
}

func (c *DbClient) getQueryTiming(ctx context.Context, startTime time.Time, session *db_common.DatabaseSession, resultChannel chan *queryresult.TimingResult) {
	// do not fetch timing information if timing is disabled
	if !c.shouldFetchTiming() {
		return
	}

	var timingResult = &queryresult.TimingResult{
		DurationMs: time.Since(startTime).Milliseconds(),
	}
	// disable fetching timing information to avoid recursion
	c.disableTiming = true

	// whatever happens, we need to re-enable timing, and send the result back with at least the duration
	defer func() {
		c.disableTiming = false
		resultChannel <- timingResult
	}()

	// load the timing summary
	summary, err := c.loadTimingSummary(ctx, session)
	if err != nil {
		log.Printf("[WARN] getQueryTiming: failed to read scan metadata summary, err: %s", err)
		return
	}

	// only load the individual scan metadata if output is JSON or timing is verbose
	var scans []*queryresult.ScanMetadataRow
	if c.shouldFetchVerboseTiming() {
		scans, err = c.loadTimingMetadata(ctx, session)
		if err != nil {
			log.Printf("[WARN] getQueryTiming: failed to read scan metadata, err: %s", err)
			return
		}
	}

	// populate hydrate calls and rows fetched
	timingResult.Initialise(summary, scans)
}

func (c *DbClient) loadTimingSummary(ctx context.Context, session *db_common.DatabaseSession) (*queryresult.QueryRowSummary, error) {
	var summary = &queryresult.QueryRowSummary{}
	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
		query := fmt.Sprintf(`select uncached_rows_fetched,
			cached_rows_fetched,
			hydrate_calls,
			scan_count,
			connection_count from %s.%s `, constants.InternalSchema, constants.ForeignTableScanMetadataSummary)
		//query := fmt.Sprintf("select id, 'table' as table, cache_hit, rows_fetched, hydrate_calls, start_time, duration, columns, 'limit' as limit, quals from %s.%s where id > %d", constants.InternalSchema, constants.ForeignTableScanMetadata, session.ScanMetadataMaxId)
		rows, err := tx.Query(ctx, query)
		if err != nil {
			return err
		}

		// scan into summary
		summary, err = pgx.CollectOneRow(rows, pgx.RowToAddrOfStructByName[queryresult.QueryRowSummary])
		// no rows counts as an error
		if err != nil {
			return err
		}
		return nil
	})
	return summary, err
}

func (c *DbClient) loadTimingMetadata(ctx context.Context, session *db_common.DatabaseSession) ([]*queryresult.ScanMetadataRow, error) {
	var scans []*queryresult.ScanMetadataRow

	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
		query := fmt.Sprintf(`
			select connection,
				"table",
				cache_hit,
				rows_fetched,
				hydrate_calls,
				start_time,
				duration_ms,
				columns,
				"limit",
				quals from %s.%s order by duration_ms desc`, constants.InternalSchema, constants.ForeignTableScanMetadata)
		rows, err := tx.Query(ctx, query)
		if err != nil {
			return err
		}

		scans, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[queryresult.ScanMetadataRow])
		return err
	})
	return scans, err
}
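// queryScanMetadataSummaryExample is an illustrative sketch, not part of the original file:
// the timing summary read by loadTimingSummary above comes from a table in the internal
// schema, so the same data can in principle be fetched through the public ExecuteSync API
// using the constants already referenced in this file. Whether that is useful outside of
// debugging is an assumption.
func queryScanMetadataSummaryExample(ctx context.Context, c *DbClient) (*queryresult.SyncQueryResult, error) {
	query := fmt.Sprintf("select * from %s.%s", constants.InternalSchema, constants.ForeignTableScanMetadataSummary)
	return c.ExecuteSync(ctx, query)
}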
// run query in a goroutine, so we can check for cancellation
// in case the client becomes unresponsive and does not respect context cancellation
func (c *DbClient) startQuery(ctx context.Context, conn *pgx.Conn, query string, args ...any) (rows pgx.Rows, err error) {
	doneChan := make(chan bool)
	go func() {
		// start asynchronous query
		rows, err = conn.Query(ctx, query, args...)
		close(doneChan)
	}()

	select {
	case <-doneChan:
	case <-ctx.Done():
		err = ctx.Err()
	}
	return
}

func (c *DbClient) readRows(ctx context.Context, rows pgx.Rows, result *queryresult.Result, timingCallback func()) {
	// defer this, so that these get cleaned up even if there is an unforeseen error
	defer func() {
		// we are done fetching results. time for display. clear the status indication
		statushooks.Done(ctx)
		// call the timing callback BEFORE closing the rows
		timingCallback()
		// close the sql rows object
		rows.Close()
		if err := rows.Err(); err != nil {
			result.StreamError(err)
		}
		// close the channels in the result object
		result.Close()
	}()

	rowCount := 0
Loop:
	for rows.Next() {
		select {
		case <-ctx.Done():
			statushooks.SetStatus(ctx, "Cancelling query")
			break Loop
		default:
			rowResult, err := readRow(rows, result.Cols)
			if err != nil {
				// the error will be streamed in the defer
				break Loop
			}

			// TACTICAL
			// determine whether to stop the spinner as soon as we stream a row or to wait for completion
			if isStreamingOutput() {
				statushooks.Done(ctx)
			}

			result.StreamRow(rowResult)

			// update the status message with the count of rows that have already been fetched
			// this will not show if the spinner is not active
			statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
			rowCount++
		}
	}
}
func readRow(rows pgx.Rows, cols []*queryresult.ColumnDef) ([]interface{}, error) {
	columnValues, err := rows.Values()
	if err != nil {
		return nil, error_helpers.WrapError(err)
	}
	return populateRow(columnValues, cols)
}

func populateRow(columnValues []interface{}, cols []*queryresult.ColumnDef) ([]interface{}, error) {
	result := make([]interface{}, len(columnValues))
	for i, columnValue := range columnValues {
		if columnValue != nil {
			result[i] = columnValue
			switch cols[i].DataType {
			case "_TEXT":
				if arr, ok := columnValue.([]interface{}); ok {
					elements := utils.Map(arr, func(e interface{}) string { return e.(string) })
					result[i] = strings.Join(elements, ",")
				}
			case "INET":
				if inet, ok := columnValue.(netip.Prefix); ok {
					result[i] = strings.TrimSuffix(inet.String(), "/32")
				}
			case "UUID":
				if bytes, ok := columnValue.([16]uint8); ok {
					if u, err := uuid.FromBytes(bytes[:]); err == nil {
						result[i] = u
					}
				}
			case "TIME":
				if t, ok := columnValue.(pgtype.Time); ok {
					result[i] = time.UnixMicro(t.Microseconds).UTC().Format("15:04:05")
				}
			case "INTERVAL":
				if interval, ok := columnValue.(pgtype.Interval); ok {
					var sb strings.Builder
					years := interval.Months / 12
					months := interval.Months % 12
					if years > 0 {
						sb.WriteString(fmt.Sprintf("%d %s ", years, utils.Pluralize("year", int(years))))
					}
					if months > 0 {
						sb.WriteString(fmt.Sprintf("%d %s ", months, utils.Pluralize("mon", int(months))))
					}
					if interval.Days > 0 {
						sb.WriteString(fmt.Sprintf("%d %s ", interval.Days, utils.Pluralize("day", int(interval.Days))))
					}
					if interval.Microseconds > 0 {
						d := time.Duration(interval.Microseconds) * time.Microsecond
						formatStr := time.Unix(0, 0).UTC().Add(d).Format("15:04:05")
						sb.WriteString(formatStr)
					}
					result[i] = sb.String()
				}

			case "NUMERIC":
				if numeric, ok := columnValue.(pgtype.Numeric); ok {
					if f, err := numeric.Float64Value(); err == nil {
						result[i] = f.Float64
					}
				}
			}
		}
	}
	return result, nil
}

func isStreamingOutput() bool {
	outputFormat := viper.GetString(constants.ArgOutput)

	return helpers.StringSliceContains([]string{constants.OutputFormatCSV, constants.OutputFormatLine}, outputFormat)
}

func humanizeRowCount(count int) string {
	p := message.NewPrinter(language.English)
	return p.Sprintf("%d", count)
}
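// populateIntervalExample is an illustrative sketch, not part of the original file: it shows
// how populateRow renders a Postgres INTERVAL column, assuming a ColumnDef whose DataType is
// "INTERVAL". With 14 months, 3 days and 4 hours the formatted value would be
// "1 year 2 mons 3 days 04:00:00".
func populateIntervalExample() ([]interface{}, error) {
	cols := []*queryresult.ColumnDef{{DataType: "INTERVAL"}}
	values := []interface{}{
		pgtype.Interval{
			Months:       14,                                       // rendered as "1 year 2 mons"
			Days:         3,                                        // rendered as "3 days"
			Microseconds: int64(4 * time.Hour / time.Microsecond),  // rendered as "04:00:00"
			Valid:        true,
		},
	}
	return populateRow(values, cols)
}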