github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client_execute.go (about)

     1  package db_client
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"log"
     8  	"net/netip"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/google/uuid"
    13  	"github.com/jackc/pgx/v5"
    14  	"github.com/jackc/pgx/v5/pgtype"
    15  	"github.com/spf13/viper"
    16  	"github.com/turbot/go-kit/helpers"
    17  	"github.com/turbot/steampipe/pkg/constants"
    18  	"github.com/turbot/steampipe/pkg/db/db_common"
    19  	"github.com/turbot/steampipe/pkg/error_helpers"
    20  	"github.com/turbot/steampipe/pkg/query/queryresult"
    21  	"github.com/turbot/steampipe/pkg/statushooks"
    22  	"github.com/turbot/steampipe/pkg/utils"
    23  	"golang.org/x/text/language"
    24  	"golang.org/x/text/message"
    25  )
    26  
    27  // ExecuteSync implements Client
    28  // execute a query against this client and wait for the result
    29  func (c *DbClient) ExecuteSync(ctx context.Context, query string, args ...any) (*queryresult.SyncQueryResult, error) {
    30  	// acquire a session
    31  	sessionResult := c.AcquireSession(ctx)
    32  	if sessionResult.Error != nil {
    33  		return nil, sessionResult.Error
    34  	}
    35  
    36  	defer func() {
    37  		// we need to do this in a closure, otherwise the ctx will be evaluated immediately
    38  		// and not in call-time
    39  		sessionResult.Session.Close(error_helpers.IsContextCanceled(ctx))
    40  	}()
    41  	return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, args...)
    42  }
    43  
    44  // ExecuteSyncInSession implements Client
    45  // execute a query against this client and wait for the result
    46  func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*queryresult.SyncQueryResult, error) {
    47  	if query == "" {
    48  		return &queryresult.SyncQueryResult{}, nil
    49  	}
    50  
    51  	result, err := c.ExecuteInSession(ctx, session, nil, query, args...)
    52  	if err != nil {
    53  		return nil, error_helpers.WrapError(err)
    54  	}
    55  
    56  	syncResult := &queryresult.SyncQueryResult{Cols: result.Cols}
    57  	for row := range *result.RowChan {
    58  		select {
    59  		case <-ctx.Done():
    60  		default:
    61  			// save the first row error to return
    62  			if row.Error != nil && err == nil {
    63  				err = error_helpers.WrapError(row.Error)
    64  			}
    65  			syncResult.Rows = append(syncResult.Rows, row)
    66  		}
    67  	}
    68  	if c.shouldFetchTiming() {
    69  		syncResult.TimingResult = <-result.TimingResult
    70  	}
    71  
    72  	return syncResult, err
    73  }
    74  
    75  // Execute implements Client
    76  // execute the query in the given Context
    77  // NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
    78  func (c *DbClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
    79  	// acquire a session
    80  	sessionResult := c.AcquireSession(ctx)
    81  	if sessionResult.Error != nil {
    82  		return nil, sessionResult.Error
    83  	}
    84  
    85  	// define callback to close session when the async execution is complete
    86  	closeSessionCallback := func() { sessionResult.Session.Close(error_helpers.IsContextCanceled(ctx)) }
    87  	return c.ExecuteInSession(ctx, sessionResult.Session, closeSessionCallback, query, args...)
    88  }
    89  
    90  // ExecuteInSession implements Client
    91  // execute the query in the given Context using the provided DatabaseSession
    92  // ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
    93  // NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
    94  func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onComplete func(), query string, args ...any) (res *queryresult.Result, err error) {
    95  	if query == "" {
    96  		return queryresult.NewResult(nil), nil
    97  	}
    98  
    99  	// fail-safes
   100  	if session == nil {
   101  		return nil, fmt.Errorf("nil session passed to ExecuteInSession")
   102  	}
   103  	if session.Connection == nil {
   104  		return nil, fmt.Errorf("nil database connection passed to ExecuteInSession")
   105  	}
   106  	startTime := time.Now()
   107  	// get a context with a timeout for the query to execute within
   108  	// we don't use the cancelFn from this timeout context, since usage will lead to 'pgx'
   109  	// prematurely closing the database connection that this query executed in
   110  	ctxExecute := c.getExecuteContext(ctx)
   111  
   112  	var tx *sql.Tx
   113  
   114  	defer func() {
   115  		if err != nil {
   116  			err = error_helpers.HandleQueryTimeoutError(err)
   117  			// stop spinner in case of error
   118  			statushooks.Done(ctxExecute)
   119  			// error - rollback transaction if we have one
   120  			if tx != nil {
   121  				_ = tx.Rollback()
   122  			}
   123  			// in case of error call the onComplete callback
   124  			if onComplete != nil {
   125  				onComplete()
   126  			}
   127  		}
   128  	}()
   129  
   130  	// start query
   131  	var rows pgx.Rows
   132  	rows, err = c.startQueryWithRetries(ctxExecute, session, query, args...)
   133  	if err != nil {
   134  		return
   135  	}
   136  
   137  	colDefs := fieldDescriptionsToColumns(rows.FieldDescriptions(), session.Connection.Conn())
   138  
   139  	result := queryresult.NewResult(colDefs)
   140  
   141  	// read the rows in a go routine
   142  	go func() {
   143  		// define a callback which fetches the timing information
   144  		// this will be invoked after reading rows is complete but BEFORE closing the rows object (which closes the connection)
   145  		timingCallback := func() {
   146  			c.getQueryTiming(ctxExecute, startTime, session, result.TimingResult)
   147  		}
   148  
   149  		// read in the rows and stream to the query result object
   150  		c.readRows(ctxExecute, rows, result, timingCallback)
   151  
   152  		// call the completion callback - if one was provided
   153  		if onComplete != nil {
   154  			onComplete()
   155  		}
   156  	}()
   157  
   158  	return result, nil
   159  }
   160  
   161  func (c *DbClient) getExecuteContext(ctx context.Context) context.Context {
   162  	queryTimeout := time.Duration(viper.GetInt(constants.ArgDatabaseQueryTimeout)) * time.Second
   163  	// if timeout is zero, do not set a timeout
   164  	if queryTimeout == 0 {
   165  		return ctx
   166  	}
   167  	// create a context with a deadline
   168  	shouldBeDoneBy := time.Now().Add(queryTimeout)
   169  	//nolint:golint,lostcancel //we don't use this cancel fn because, pgx prematurely cancels the PG connection when this cancel gets called in 'defer'
   170  	newCtx, _ := context.WithDeadline(ctx, shouldBeDoneBy)
   171  
   172  	return newCtx
   173  }
   174  
   175  func (c *DbClient) getQueryTiming(ctx context.Context, startTime time.Time, session *db_common.DatabaseSession, resultChannel chan *queryresult.TimingResult) {
   176  	// do not fetch if timing is disabled, unless output not JSON
   177  	if !c.shouldFetchTiming() {
   178  		return
   179  	}
   180  
   181  	var timingResult = &queryresult.TimingResult{
   182  		DurationMs: time.Since(startTime).Milliseconds(),
   183  	}
   184  	// disable fetching timing information to avoid recursion
   185  	c.disableTiming = true
   186  
   187  	// whatever happens, we need to reenable timing, and send the result back with at least the duration
   188  	defer func() {
   189  		c.disableTiming = false
   190  		resultChannel <- timingResult
   191  	}()
   192  
   193  	// load the timing summary
   194  	summary, err := c.loadTimingSummary(ctx, session)
   195  	if err != nil {
   196  		log.Printf("[WARN] getQueryTiming: failed to read scan metadata, err: %s", err)
   197  		return
   198  	}
   199  
   200  	// only load the individual scan  metadata if output is JSON or timing is verbose
   201  	var scans []*queryresult.ScanMetadataRow
   202  	if c.shouldFetchVerboseTiming() {
   203  		scans, err = c.loadTimingMetadata(ctx, session)
   204  		if err != nil {
   205  			log.Printf("[WARN] getQueryTiming: failed to read scan metadata, err: %s", err)
   206  			return
   207  		}
   208  	}
   209  
   210  	// populate hydrate calls and rows fetched
   211  	timingResult.Initialise(summary, scans)
   212  }
   213  
   214  func (c *DbClient) loadTimingSummary(ctx context.Context, session *db_common.DatabaseSession) (*queryresult.QueryRowSummary, error) {
   215  	var summary = &queryresult.QueryRowSummary{}
   216  	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
   217  		query := fmt.Sprintf(`select uncached_rows_fetched,
   218  cached_rows_fetched,
   219  hydrate_calls, 
   220  scan_count,
   221  connection_count from %s.%s `, constants.InternalSchema, constants.ForeignTableScanMetadataSummary)
   222  		//query := fmt.Sprintf("select id, 'table' as table, cache_hit, rows_fetched, hydrate_calls, start_time, duration, columns, 'limit' as limit, quals from %s.%s where id > %d", constants.InternalSchema, constants.ForeignTableScanMetadata, session.ScanMetadataMaxId)
   223  		rows, err := tx.Query(ctx, query)
   224  		if err != nil {
   225  			return err
   226  		}
   227  
   228  		// scan into summary
   229  		summary, err = pgx.CollectOneRow(rows, pgx.RowToAddrOfStructByName[queryresult.QueryRowSummary])
   230  		// no rows counts as an error
   231  		if err != nil {
   232  			return err
   233  		}
   234  		return nil
   235  	})
   236  	return summary, err
   237  }
   238  
   239  func (c *DbClient) loadTimingMetadata(ctx context.Context, session *db_common.DatabaseSession) ([]*queryresult.ScanMetadataRow, error) {
   240  	var scans []*queryresult.ScanMetadataRow
   241  
   242  	err := db_common.ExecuteSystemClientCall(ctx, session.Connection.Conn(), func(ctx context.Context, tx pgx.Tx) error {
   243  		query := fmt.Sprintf(`
   244  select connection,
   245  "table",
   246  cache_hit, 
   247  rows_fetched, 
   248  hydrate_calls, 
   249  start_time,
   250  duration_ms,
   251  columns,
   252  "limit",
   253  quals from %s.%s order by duration_ms desc`, constants.InternalSchema, constants.ForeignTableScanMetadata)
   254  		rows, err := tx.Query(ctx, query)
   255  		if err != nil {
   256  			return err
   257  		}
   258  
   259  		scans, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[queryresult.ScanMetadataRow])
   260  		return err
   261  	})
   262  	return scans, err
   263  }
   264  
   265  // run query in a goroutine, so we can check for cancellation
   266  // in case the client becomes unresponsive and does not respect context cancellation
   267  func (c *DbClient) startQuery(ctx context.Context, conn *pgx.Conn, query string, args ...any) (rows pgx.Rows, err error) {
   268  	doneChan := make(chan bool)
   269  	go func() {
   270  		// start asynchronous query
   271  		rows, err = conn.Query(ctx, query, args...)
   272  		close(doneChan)
   273  	}()
   274  
   275  	select {
   276  	case <-doneChan:
   277  	case <-ctx.Done():
   278  		err = ctx.Err()
   279  	}
   280  	return
   281  }
   282  
   283  func (c *DbClient) readRows(ctx context.Context, rows pgx.Rows, result *queryresult.Result, timingCallback func()) {
   284  	// defer this, so that these get cleaned up even if there is an unforeseen error
   285  	defer func() {
   286  		// we are done fetching results. time for display. clear the status indication
   287  		statushooks.Done(ctx)
   288  		// call the timing callback BEFORE closing the rows
   289  		timingCallback()
   290  		// close the sql rows object
   291  		rows.Close()
   292  		if err := rows.Err(); err != nil {
   293  			result.StreamError(err)
   294  		}
   295  		// close the channels in the result object
   296  		result.Close()
   297  
   298  	}()
   299  
   300  	rowCount := 0
   301  Loop:
   302  	for rows.Next() {
   303  		select {
   304  		case <-ctx.Done():
   305  			statushooks.SetStatus(ctx, "Cancelling query")
   306  			break Loop
   307  		default:
   308  			rowResult, err := readRow(rows, result.Cols)
   309  			if err != nil {
   310  				// the error will be streamed in the defer
   311  				break Loop
   312  			}
   313  
   314  			// TACTICAL
   315  			// determine whether to stop the spinner as soon as we stream a row or to wait for completion
   316  			if isStreamingOutput() {
   317  				statushooks.Done(ctx)
   318  			}
   319  
   320  			result.StreamRow(rowResult)
   321  
   322  			// update the status message with the count of rows that have already been fetched
   323  			// this will not show if the spinner is not active
   324  			statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
   325  			rowCount++
   326  		}
   327  	}
   328  }
   329  
   330  func readRow(rows pgx.Rows, cols []*queryresult.ColumnDef) ([]interface{}, error) {
   331  	columnValues, err := rows.Values()
   332  	if err != nil {
   333  		return nil, error_helpers.WrapError(err)
   334  	}
   335  	return populateRow(columnValues, cols)
   336  }
   337  
   338  func populateRow(columnValues []interface{}, cols []*queryresult.ColumnDef) ([]interface{}, error) {
   339  	result := make([]interface{}, len(columnValues))
   340  	for i, columnValue := range columnValues {
   341  		if columnValue != nil {
   342  			result[i] = columnValue
   343  			switch cols[i].DataType {
   344  			case "_TEXT":
   345  				if arr, ok := columnValue.([]interface{}); ok {
   346  					elements := utils.Map(arr, func(e interface{}) string { return e.(string) })
   347  					result[i] = strings.Join(elements, ",")
   348  				}
   349  			case "INET":
   350  				if inet, ok := columnValue.(netip.Prefix); ok {
   351  					result[i] = strings.TrimSuffix(inet.String(), "/32")
   352  				}
   353  			case "UUID":
   354  				if bytes, ok := columnValue.([16]uint8); ok {
   355  					if u, err := uuid.FromBytes(bytes[:]); err == nil {
   356  						result[i] = u
   357  					}
   358  				}
   359  			case "TIME":
   360  				if t, ok := columnValue.(pgtype.Time); ok {
   361  					result[i] = time.UnixMicro(t.Microseconds).UTC().Format("15:04:05")
   362  				}
   363  			case "INTERVAL":
   364  				if interval, ok := columnValue.(pgtype.Interval); ok {
   365  					var sb strings.Builder
   366  					years := interval.Months / 12
   367  					months := interval.Months % 12
   368  					if years > 0 {
   369  						sb.WriteString(fmt.Sprintf("%d %s ", years, utils.Pluralize("year", int(years))))
   370  					}
   371  					if months > 0 {
   372  						sb.WriteString(fmt.Sprintf("%d %s ", months, utils.Pluralize("mon", int(months))))
   373  					}
   374  					if interval.Days > 0 {
   375  						sb.WriteString(fmt.Sprintf("%d %s ", interval.Days, utils.Pluralize("day", int(interval.Days))))
   376  					}
   377  					if interval.Microseconds > 0 {
   378  						d := time.Duration(interval.Microseconds) * time.Microsecond
   379  						formatStr := time.Unix(0, 0).UTC().Add(d).Format("15:04:05")
   380  						sb.WriteString(formatStr)
   381  					}
   382  					result[i] = sb.String()
   383  				}
   384  
   385  			case "NUMERIC":
   386  				if numeric, ok := columnValue.(pgtype.Numeric); ok {
   387  					if f, err := numeric.Float64Value(); err == nil {
   388  						result[i] = f.Float64
   389  					}
   390  				}
   391  			}
   392  		}
   393  	}
   394  	return result, nil
   395  }
   396  
   397  func isStreamingOutput() bool {
   398  	outputFormat := viper.GetString(constants.ArgOutput)
   399  
   400  	return helpers.StringSliceContains([]string{constants.OutputFormatCSV, constants.OutputFormatLine}, outputFormat)
   401  }
   402  
   403  func humanizeRowCount(count int) string {
   404  	p := message.NewPrinter(language.English)
   405  	return p.Sprintf("%d", count)
   406  }