github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/sqlx/driver/postgres/schema.go (about)

     1  package postgres
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  
     8  	"github.com/machinefi/w3bstream/pkg/depends/kit/sqlx"
     9  	"github.com/machinefi/w3bstream/pkg/depends/kit/sqlx/builder"
    10  	"github.com/machinefi/w3bstream/pkg/depends/x/misc/slice"
    11  )
    12  
// regexpUsing matches the literal "USING " followed by a run of non-space
// characters (captured as group 1). databaseFromSchema applies it to an
// index definition to pick out the word after USING and to split off the
// text that follows it.
var regexpUsing = regexp.MustCompile(`USING ([^ ]+)`)
    14  
    15  func databaseFromSchema(db sqlx.DBExecutor) (*sqlx.Database, error) {
    16  	d := db.D()
    17  
    18  	var (
    19  		tableNames  = slice.ToAnySlice(d.Tables.TableNames()...)
    20  		tableSchema = SchemaDB.T(&ColumnSchema{}).WithSchema("information_schema")
    21  		columns     = make([]ColumnSchema, 0)
    22  	)
    23  
    24  	d = sqlx.NewDatabase(d.Name).WithSchema(d.Schema)
    25  
    26  	schema := "public"
    27  	if d.Schema != "" {
    28  		schema = d.Schema
    29  	}
    30  
    31  	stmt := builder.Select(tableSchema.Columns.Clone()).
    32  		From(
    33  			tableSchema,
    34  			builder.Where(
    35  				builder.And(
    36  					tableSchema.ColByFieldName("TABLE_SCHEMA").Eq(schema),
    37  					tableSchema.ColByFieldName("TABLE_NAME").In(tableNames...),
    38  				),
    39  			),
    40  		)
    41  
    42  	err := db.QueryAndScan(stmt, &columns)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	for i := range columns {
    48  		cs := columns[i]
    49  
    50  		tbl := d.Table(cs.TABLE_NAME)
    51  		if tbl == nil {
    52  			tbl = builder.T(cs.TABLE_NAME)
    53  			d.AddTable(tbl)
    54  		}
    55  
    56  		tbl.AddCol(colFromSchema(&cs))
    57  	}
    58  
    59  	if tableSchema.Columns.Len() != 0 {
    60  		v := SchemaDB.T(&IndexSchema{})
    61  		indexes := make([]IndexSchema, 0)
    62  
    63  		err = db.QueryAndScan(
    64  			builder.Select(v.Columns.Clone()).
    65  				From(
    66  					v,
    67  					builder.Where(
    68  						builder.And(
    69  							v.ColByFieldName("TABLE_SCHEMA").Eq(schema),
    70  							v.ColByFieldName("TABLE_NAME").In(tableNames...),
    71  						),
    72  					),
    73  				),
    74  			&indexes,
    75  		)
    76  
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  
    81  		for _, index := range indexes {
    82  			table := d.Table(index.TABLE_NAME)
    83  			key := &builder.Key{
    84  				Name:     strings.ToLower(index.INDEX_NAME[len(table.Name)+1:]),
    85  				Method:   strings.ToUpper(regexpUsing.FindString(index.INDEX_DEF)[6:]),
    86  				IsUnique: strings.Contains(index.INDEX_DEF, "UNIQUE"),
    87  				Def: builder.IndexDef{
    88  					Expr: strings.TrimSpace(regexpUsing.Split(index.INDEX_DEF, 2)[1]),
    89  				},
    90  			}
    91  			table.AddKey(key)
    92  		}
    93  	}
    94  
    95  	return d, nil
    96  }
    97  
    98  func colFromSchema(columnSchema *ColumnSchema) *builder.Column {
    99  	col := builder.Col(columnSchema.COLUMN_NAME)
   100  
   101  	defaultValue := columnSchema.COLUMN_DEFAULT
   102  
   103  	if defaultValue != "" {
   104  		col.AutoIncrement = strings.HasSuffix(columnSchema.COLUMN_DEFAULT, "_seq'::regclass)")
   105  
   106  		if !col.AutoIncrement {
   107  			if !strings.Contains(defaultValue, "'::") && '0' <= defaultValue[0] && defaultValue[0] <= '9' {
   108  				defaultValue = fmt.Sprintf("'%s'::integer", defaultValue)
   109  			}
   110  			col.Default = &defaultValue
   111  		}
   112  	}
   113  
   114  	dataType := columnSchema.DATA_TYPE
   115  
   116  	if col.AutoIncrement {
   117  		if strings.HasPrefix(dataType, "big") {
   118  			dataType = "bigserial"
   119  		} else {
   120  			dataType = "serial"
   121  		}
   122  	}
   123  
   124  	col.DataType = dataType
   125  
   126  	// numeric type
   127  	if columnSchema.NUMERIC_PRECISION > 0 {
   128  		col.Length = columnSchema.NUMERIC_PRECISION
   129  		col.Decimal = columnSchema.NUMERIC_SCALE
   130  	} else {
   131  		col.Length = columnSchema.CHARACTER_MAXIMUM_LENGTH
   132  	}
   133  
   134  	if columnSchema.IS_NULLABLE == "YES" {
   135  		col.Null = true
   136  	}
   137  
   138  	return col
   139  }
   140  
   141  type ColumnSchema struct {
   142  	TABLE_SCHEMA             string `db:"table_schema"`
   143  	TABLE_NAME               string `db:"table_name"`
   144  	COLUMN_NAME              string `db:"column_name"`
   145  	DATA_TYPE                string `db:"data_type"`
   146  	IS_NULLABLE              string `db:"is_nullable"`
   147  	COLUMN_DEFAULT           string `db:"column_default"`
   148  	CHARACTER_MAXIMUM_LENGTH uint64 `db:"character_maximum_length"`
   149  	NUMERIC_PRECISION        uint64 `db:"numeric_precision"`
   150  	NUMERIC_SCALE            uint64 `db:"numeric_scale"`
   151  }
   152  
   153  func (ColumnSchema) TableName() string { return "columns" }
   154  
   155  type IndexSchema struct {
   156  	TABLE_SCHEMA string `db:"schemaname"`
   157  	TABLE_NAME   string `db:"tablename"`
   158  	INDEX_NAME   string `db:"indexname"`
   159  	INDEX_DEF    string `db:"indexdef"`
   160  }
   161  
   162  func (IndexSchema) TableName() string { return "pg_indexes" }
   163  
// SchemaDB is the sqlx database, named "INFORMATION_SCHEMA", that holds the
// catalog models (ColumnSchema and IndexSchema) used to introspect a live
// database.
var SchemaDB = sqlx.NewDatabase("INFORMATION_SCHEMA")
   165  
// init registers the catalog models on SchemaDB so that SchemaDB.T can
// resolve them when databaseFromSchema builds its queries.
func init() {
	SchemaDB.Register(&ColumnSchema{})
	SchemaDB.Register(&IndexSchema{})
}