github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/postgres/catalog.go (about)

     1  package postgres
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"regexp"
     8  	"strings"
     9  	textscanner "text/scanner"
    10  
    11  	"github.com/octohelm/storage/internal/sql/adapter"
    12  	"github.com/octohelm/storage/internal/sql/scanner"
    13  	"github.com/octohelm/storage/pkg/sqlbuilder"
    14  )
    15  
    16  func (a *pgAdapter) Catalog(ctx context.Context) (*sqlbuilder.Tables, error) {
    17  	return catalog(ctx, a, a.dbName)
    18  }
    19  
    20  var reUsing = regexp.MustCompile(`USING ([^ ]+)`)
    21  
    22  func catalog(ctx context.Context, a adapter.Adapter, dbName string) (*sqlbuilder.Tables, error) {
    23  	cat := &sqlbuilder.Tables{}
    24  
    25  	tableColumnSchema := sqlbuilder.TableFromModel(&columnSchema{})
    26  
    27  	tableSchema := "public"
    28  
    29  	stmt := sqlbuilder.Select(tableColumnSchema.Cols()).From(tableColumnSchema,
    30  		sqlbuilder.Where(
    31  			sqlbuilder.And(
    32  				sqlbuilder.TypedColOf[string](tableColumnSchema, "TABLE_SCHEMA").V(sqlbuilder.Eq(tableSchema)),
    33  			),
    34  		),
    35  	)
    36  
    37  	rows, err := a.Query(ctx, stmt)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	colSchemaList := make([]columnSchema, 0)
    43  
    44  	if err := scanner.Scan(ctx, rows, &colSchemaList); err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	for i := range colSchemaList {
    49  		colSchema := colSchemaList[i]
    50  
    51  		table := cat.Table(colSchema.TABLE_NAME)
    52  		if table == nil {
    53  			table = sqlbuilder.T(colSchema.TABLE_NAME)
    54  			cat.Add(table)
    55  		}
    56  
    57  		table.(sqlbuilder.ColumnCollectionManger).AddCol(colSchema.ToColumn())
    58  	}
    59  
    60  	if cols := tableColumnSchema.Cols(); cols.Len() != 0 {
    61  		tableIndexSchema := sqlbuilder.TableFromModel(&indexSchema{})
    62  
    63  		indexList := make([]indexSchema, 0)
    64  
    65  		rows, err := a.Query(
    66  			ctx,
    67  			sqlbuilder.Select(tableIndexSchema.Cols()).
    68  				From(
    69  					tableIndexSchema,
    70  					sqlbuilder.Where(
    71  						sqlbuilder.And(
    72  							sqlbuilder.TypedColOf[string](tableIndexSchema, "TABLE_SCHEMA").V(sqlbuilder.Eq(tableSchema)),
    73  						),
    74  					),
    75  				),
    76  		)
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  
    81  		if err := scanner.Scan(ctx, rows, &indexList); err != nil {
    82  			return nil, err
    83  		}
    84  
    85  		for _, idxSchema := range indexList {
    86  			t := cat.Table(idxSchema.TABLE_NAME)
    87  			t.(sqlbuilder.KeyCollectionManager).AddKey(idxSchema.ToKey(t))
    88  		}
    89  	}
    90  
    91  	return cat, nil
    92  }
    93  
// columnSchema is one row of information_schema.columns (see TableName).
//
// The db tags name the selected view columns. catalog looks the schema column
// up by its Go field name ("TABLE_SCHEMA"), so field names must not be renamed
// independently of that lookup. ToColumn compares IS_NULLABLE against "YES"
// and treats an empty COLUMN_DEFAULT as "no default".
//
// NOTE(review): column_default and the length/precision/scale columns can be
// NULL in PostgreSQL; this struct relies on the scanner mapping NULL to the
// zero value — confirm against the scanner implementation.
type columnSchema struct {
	TABLE_SCHEMA             string `db:"table_schema"`
	TABLE_NAME               string `db:"table_name"`
	COLUMN_NAME              string `db:"column_name"`
	DATA_TYPE                string `db:"data_type"`
	IS_NULLABLE              string `db:"is_nullable"`
	COLUMN_DEFAULT           string `db:"column_default"`
	CHARACTER_MAXIMUM_LENGTH uint64 `db:"character_maximum_length"`
	NUMERIC_PRECISION        uint64 `db:"numeric_precision"`
	NUMERIC_SCALE            uint64 `db:"numeric_scale"`
}
   105  
   106  func (columnSchema) TableName() string {
   107  	return "information_schema.columns"
   108  }
   109  
   110  func (columnSchema *columnSchema) ToColumn() sqlbuilder.Column {
   111  	defaultValue := columnSchema.COLUMN_DEFAULT
   112  	def := sqlbuilder.ColumnDef{}
   113  
   114  	if defaultValue != "" {
   115  		def.AutoIncrement = strings.HasSuffix(columnSchema.COLUMN_DEFAULT, "_seq'::regclass)")
   116  
   117  		if !def.AutoIncrement {
   118  			if !strings.Contains(defaultValue, "'::") && '0' <= defaultValue[0] && defaultValue[0] <= '9' {
   119  				defaultValue = fmt.Sprintf("'%s'::integer", defaultValue)
   120  			}
   121  			def.Default = &defaultValue
   122  		}
   123  	}
   124  
   125  	dataType := columnSchema.DATA_TYPE
   126  
   127  	if def.AutoIncrement {
   128  		if strings.HasPrefix(dataType, "big") {
   129  			dataType = "bigserial"
   130  		} else {
   131  			dataType = "serial"
   132  		}
   133  	}
   134  
   135  	def.DataType = dataType
   136  
   137  	// numeric type
   138  	if columnSchema.NUMERIC_PRECISION > 0 {
   139  		def.Length = columnSchema.NUMERIC_PRECISION
   140  		def.Decimal = columnSchema.NUMERIC_SCALE
   141  	} else {
   142  		def.Length = columnSchema.CHARACTER_MAXIMUM_LENGTH
   143  	}
   144  
   145  	if columnSchema.IS_NULLABLE == "YES" {
   146  		def.Null = true
   147  	}
   148  
   149  	return sqlbuilder.Col(columnSchema.COLUMN_NAME, sqlbuilder.ColDef(def))
   150  }
   151  
// indexSchema is one row of the pg_indexes view (see TableName).
//
// The db tags use pg_indexes' own column names, while the Go field names
// mirror columnSchema. catalog looks the schema column up by the Go field name
// ("TABLE_SCHEMA"), so field names must stay in sync with that lookup.
// INDEX_DEF holds the full index definition text that ToKey parses.
type indexSchema struct {
	TABLE_SCHEMA string `db:"schemaname"`
	TABLE_NAME   string `db:"tablename"`
	INDEX_NAME   string `db:"indexname"`
	INDEX_DEF    string `db:"indexdef"`
}
   158  
   159  func (indexSchema) TableName() string {
   160  	return "pg_indexes"
   161  }
   162  
   163  func (idxSchema *indexSchema) ToKey(table sqlbuilder.Table) sqlbuilder.Key {
   164  	name := strings.ToLower(idxSchema.INDEX_NAME[len(table.TableName())+1:])
   165  	method := strings.ToUpper(reUsing.FindString(idxSchema.INDEX_DEF)[6:])
   166  	isUnique := strings.Contains(idxSchema.INDEX_DEF, "UNIQUE")
   167  
   168  	colNameAndOptions := make([]string, 0)
   169  
   170  	s := &textscanner.Scanner{}
   171  	s.Init(bytes.NewBufferString(strings.TrimSpace(reUsing.Split(idxSchema.INDEX_DEF, 2)[1])))
   172  
   173  	parts := make([]string, 0)
   174  
   175  	for t := s.Scan(); t != textscanner.EOF; t = s.Scan() {
   176  		part := s.TokenText()
   177  
   178  		switch part {
   179  		case "(":
   180  			continue
   181  		case ",", ")":
   182  			colNameAndOption := parts[0]
   183  			if len(parts) > 1 {
   184  				colNameAndOption += "/" + parts[1]
   185  			}
   186  			colNameAndOptions = append(colNameAndOptions, colNameAndOption)
   187  			// reset
   188  			parts = make([]string, 0)
   189  			continue
   190  		}
   191  		parts = append(parts, part)
   192  	}
   193  
   194  	if isUnique {
   195  		return sqlbuilder.UniqueIndex(name, nil, sqlbuilder.IndexUsing(method), sqlbuilder.IndexColNameAndOptions(colNameAndOptions...))
   196  	}
   197  	return sqlbuilder.Index(name, nil, sqlbuilder.IndexUsing(method), sqlbuilder.IndexColNameAndOptions(colNameAndOptions...))
   198  }