github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_common/schema.go (about)

     1  package db_common
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/jackc/pgx/v5"
    10  	typeHelpers "github.com/turbot/go-kit/types"
    11  	"github.com/turbot/steampipe/pkg/constants"
    12  	"github.com/turbot/steampipe/pkg/utils"
    13  )
    14  
    15  type schemaRecord struct {
    16  	TableSchema       string
    17  	TableName         string
    18  	ColumnName        string
    19  	UdtName           string
    20  	ColumnDefault     string
    21  	IsNullable        string
    22  	DataType          string
    23  	ColumnDescription string
    24  	TableDescription  string
    25  }
    26  
    27  func LoadForeignSchemaNames(ctx context.Context, conn *pgx.Conn) ([]string, error) {
    28  	res, err := conn.Query(ctx, "SELECT DISTINCT foreign_table_schema FROM information_schema.foreign_tables WHERE foreign_server_name='steampipe'")
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	var foreignSchemaNames []string
    34  	var schema string
    35  	for res.Next() {
    36  		if err := res.Scan(&schema); err != nil {
    37  			return nil, err
    38  		}
    39  		// ignore internal schema and legacy command schema
    40  		if schema != constants.InternalSchema && schema != constants.LegacyCommandSchema {
    41  			foreignSchemaNames = append(foreignSchemaNames, schema)
    42  		}
    43  	}
    44  	sort.Strings(foreignSchemaNames)
    45  	return foreignSchemaNames, nil
    46  }
    47  
    48  func LoadSchemaMetadata(ctx context.Context, conn *pgx.Conn, query string) (*SchemaMetadata, error) {
    49  	var schemaRecords []schemaRecord
    50  	rows, err := conn.Query(ctx, query)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	defer rows.Close()
    55  
    56  	schemaRecords, err = getSchemaRecordsFromRows(rows)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	// build schema metadata from query result
    62  	return buildSchemaMetadata(schemaRecords)
    63  }
    64  
    65  func buildSchemaMetadata(records []schemaRecord) (_ *SchemaMetadata, err error) {
    66  	utils.LogTime("db.buildSchemaMetadata start")
    67  	defer func() {
    68  		utils.LogTime("db.buildSchemaMetadata end")
    69  	}()
    70  	schemaMetadata := NewSchemaMetadata()
    71  
    72  	utils.LogTime("db.buildSchemaMetadata.iteration start")
    73  	for _, record := range records {
    74  		if _, schemaFound := schemaMetadata.Schemas[record.TableSchema]; !schemaFound {
    75  			schemaMetadata.Schemas[record.TableSchema] = map[string]TableSchema{}
    76  		}
    77  
    78  		if _, tblFound := schemaMetadata.Schemas[record.TableSchema][record.TableName]; !tblFound {
    79  			schemaMetadata.Schemas[record.TableSchema][record.TableName] = TableSchema{
    80  				Schema:      record.TableSchema,
    81  				Name:        record.TableName,
    82  				FullName:    fmt.Sprintf("%s.%s", record.TableSchema, record.TableName),
    83  				Description: record.TableDescription,
    84  				Columns:     map[string]ColumnSchema{},
    85  			}
    86  		}
    87  
    88  		schemaMetadata.Schemas[record.TableSchema][record.TableName].Columns[record.ColumnName] = ColumnSchema{
    89  			Name:        record.ColumnName,
    90  			NotNull:     typeHelpers.StringToBool(record.IsNullable),
    91  			Type:        record.DataType,
    92  			Default:     record.ColumnDefault,
    93  			Description: record.ColumnDescription,
    94  		}
    95  
    96  		if strings.HasPrefix(record.TableSchema, "pg_temp") {
    97  			schemaMetadata.TemporarySchemaName = record.TableSchema
    98  		}
    99  	}
   100  	utils.LogTime("db.buildSchemaMetadata.iteration end")
   101  
   102  	return schemaMetadata, err
   103  }
   104  
   105  func getSchemaRecordsFromRows(rows pgx.Rows) ([]schemaRecord, error) {
   106  	utils.LogTime("db.getSchemaRecordsFromRows start")
   107  	defer utils.LogTime("db.getSchemaRecordsFromRows end")
   108  
   109  	var records []schemaRecord
   110  
   111  	// set this to the number of cols that are getting fetched
   112  	numCols := 9
   113  
   114  	rawResult := make([][]byte, numCols)
   115  	dest := make([]interface{}, numCols) // A temporary interface{} slice
   116  	for i := range rawResult {
   117  		dest[i] = &rawResult[i] // Put pointers to each string in the interface slice
   118  	}
   119  
   120  	for rows.Next() {
   121  		err := rows.Scan(dest...)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  
   126  		t := schemaRecord{
   127  			TableName:         string(rawResult[0]),
   128  			ColumnName:        string(rawResult[1]),
   129  			ColumnDefault:     string(rawResult[2]),
   130  			IsNullable:        string(rawResult[3]),
   131  			DataType:          string(rawResult[4]),
   132  			UdtName:           string(rawResult[5]),
   133  			TableSchema:       string(rawResult[6]),
   134  			ColumnDescription: string(rawResult[7]),
   135  			TableDescription:  string(rawResult[8]),
   136  		}
   137  		// for ltree data type, we need to use UdtName
   138  		if t.DataType == "USER-DEFINED" {
   139  			t.DataType = t.UdtName
   140  		}
   141  		records = append(records, t)
   142  	}
   143  
   144  	return records, nil
   145  }