github.com/kunlun-qilian/sqlx/v3@v3.0.0/connectors/postgresql/schemas.go (about)

     1  package postgresql
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  
     8  	"github.com/kunlun-qilian/sqlx/v3"
     9  	"github.com/kunlun-qilian/sqlx/v3/builder"
    10  )
    11  
    12  func toInterfaces(list ...string) []interface{} {
    13  	s := make([]interface{}, len(list))
    14  	for i, v := range list {
    15  		s[i] = v
    16  	}
    17  	return s
    18  }
    19  
    20  var reUsing = regexp.MustCompile(`USING ([^ ]+)`)
    21  
    22  func dbFromInformationSchema(db sqlx.DBExecutor) (*sqlx.Database, error) {
    23  	d := db.D()
    24  
    25  	dbName := d.Name
    26  	dbSchema := d.Schema
    27  	tableNames := d.Tables.TableNames()
    28  
    29  	d = sqlx.NewDatabase(dbName).WithSchema(dbSchema)
    30  
    31  	tableColumnSchema := SchemaDatabase.T(&ColumnSchema{}).WithSchema("information_schema")
    32  	columnSchemaList := make([]ColumnSchema, 0)
    33  
    34  	tableSchema := "public"
    35  	if d.Schema != "" {
    36  		tableSchema = d.Schema
    37  	}
    38  
    39  	stmt := builder.Select(tableColumnSchema.Columns.Clone()).From(tableColumnSchema,
    40  		builder.Where(
    41  			builder.And(
    42  				tableColumnSchema.F("TABLE_SCHEMA").Eq(tableSchema),
    43  				tableColumnSchema.F("TABLE_NAME").In(toInterfaces(tableNames...)...),
    44  			),
    45  		),
    46  	)
    47  
    48  	err := db.QueryExprAndScan(stmt, &columnSchemaList)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	for i := range columnSchemaList {
    54  		columnSchema := columnSchemaList[i]
    55  
    56  		table := d.Table(columnSchema.TABLE_NAME)
    57  		if table == nil {
    58  			table = builder.T(columnSchema.TABLE_NAME)
    59  			d.AddTable(table)
    60  		}
    61  
    62  		table.AddCol(colFromColumnSchema(&columnSchema))
    63  	}
    64  
    65  	if tableColumnSchema.Columns.Len() != 0 {
    66  		tableIndexSchema := SchemaDatabase.T(&IndexSchema{})
    67  
    68  		indexList := make([]IndexSchema, 0)
    69  
    70  		err = db.QueryExprAndScan(
    71  			builder.Select(tableIndexSchema.Columns.Clone()).
    72  				From(
    73  					tableIndexSchema,
    74  					builder.Where(
    75  						builder.And(
    76  							tableIndexSchema.F("TABLE_SCHEMA").Eq(tableSchema),
    77  							tableIndexSchema.F("TABLE_NAME").In(toInterfaces(tableNames...)...),
    78  						),
    79  					),
    80  				),
    81  			&indexList,
    82  		)
    83  
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  
    88  		for _, indexSchema := range indexList {
    89  			table := d.Table(indexSchema.TABLE_NAME)
    90  
    91  			key := &builder.Key{}
    92  			key.Name = strings.ToLower(indexSchema.INDEX_NAME[len(table.Name)+1:])
    93  			key.Method = strings.ToUpper(reUsing.FindString(indexSchema.INDEX_DEF)[6:])
    94  			key.IsUnique = strings.Contains(indexSchema.INDEX_DEF, "UNIQUE")
    95  
    96  			key.Def.Expr = strings.Replace(strings.TrimSpace(reUsing.Split(indexSchema.INDEX_DEF, 2)[1]), ", ", ",", -1)
    97  
    98  			table.AddKey(key)
    99  		}
   100  	}
   101  
   102  	return d, nil
   103  }
   104  
   105  var SchemaDatabase = sqlx.NewDatabase("INFORMATION_SCHEMA")
   106  
   107  func init() {
   108  	SchemaDatabase.Register(&ColumnSchema{})
   109  	SchemaDatabase.Register(&IndexSchema{})
   110  }
   111  
   112  func colFromColumnSchema(columnSchema *ColumnSchema) *builder.Column {
   113  	col := builder.Col(columnSchema.COLUMN_NAME)
   114  
   115  	defaultValue := columnSchema.COLUMN_DEFAULT
   116  
   117  	if defaultValue != "" {
   118  		col.AutoIncrement = strings.HasSuffix(columnSchema.COLUMN_DEFAULT, "_seq'::regclass)")
   119  
   120  		if !col.AutoIncrement {
   121  			if !strings.Contains(defaultValue, "'::") && '0' <= defaultValue[0] && defaultValue[0] <= '9' {
   122  				defaultValue = fmt.Sprintf("'%s'::integer", defaultValue)
   123  			}
   124  			col.Default = &defaultValue
   125  		}
   126  	}
   127  
   128  	dataType := columnSchema.DATA_TYPE
   129  
   130  	if col.AutoIncrement {
   131  		if strings.HasPrefix(dataType, "big") {
   132  			dataType = "bigserial"
   133  		} else {
   134  			dataType = "serial"
   135  		}
   136  	}
   137  
   138  	col.DataType = dataType
   139  
   140  	// numeric type
   141  	if columnSchema.NUMERIC_PRECISION > 0 {
   142  		col.Length = columnSchema.NUMERIC_PRECISION
   143  		col.Decimal = columnSchema.NUMERIC_SCALE
   144  	} else {
   145  		col.Length = columnSchema.CHARACTER_MAXIMUM_LENGTH
   146  	}
   147  
   148  	if columnSchema.IS_NULLABLE == "YES" {
   149  		col.Null = true
   150  	}
   151  
   152  	return col
   153  }
   154  
// ColumnSchema maps one row of information_schema.columns. TableName returns
// "columns", and dbFromInformationSchema qualifies the table with the
// "information_schema" schema. colFromColumnSchema turns a row into a
// builder.Column.
//
// The upper-case, underscored field names mirror the catalog column names and
// are part of the exported API. The db tags name the catalog columns.
//
// NOTE(review): in PostgreSQL, column_default, character_maximum_length,
// numeric_precision and numeric_scale can be NULL. These non-pointer fields
// assume the sqlx scanner maps NULL to the zero value — confirm.
type ColumnSchema struct {
	TABLE_SCHEMA             string `db:"table_schema"`
	TABLE_NAME               string `db:"table_name"`
	COLUMN_NAME              string `db:"column_name"`
	DATA_TYPE                string `db:"data_type"`
	IS_NULLABLE              string `db:"is_nullable"`
	COLUMN_DEFAULT           string `db:"column_default"`
	CHARACTER_MAXIMUM_LENGTH uint64 `db:"character_maximum_length"`
	NUMERIC_PRECISION        uint64 `db:"numeric_precision"`
	NUMERIC_SCALE            uint64 `db:"numeric_scale"`
}

// TableName returns "columns", the catalog relation this model is read from.
func (ColumnSchema) TableName() string {
	return "columns"
}
   170  
// IndexSchema maps one row of the subquery returned by IndexSchema.TableName.
// Each row gives the schema and table an index belongs to, the index name,
// and the index definition text produced by pg_get_indexdef.
type IndexSchema struct {
	TABLE_SCHEMA string `db:"schemaname"`
	TABLE_NAME   string `db:"tablename"`
	INDEX_NAME   string `db:"indexname"`
	INDEX_DEF    string `db:"indexdef"`
}

// TableName returns a parenthesized subquery aliased as pg_indexes instead of
// a plain table name. Presumably builder emits it verbatim as the FROM target,
// since dbFromInformationSchema passes this model's table straight to From.
//
// The subquery joins pg_index to pg_class twice: c for the indexed relation
// (x.indrelid) and i for the index itself (x.indexrelid). It also joins
// pg_namespace to get the schema name. It keeps relations with relkind 'r',
// 'm' or 'p' and indexes with relkind 'i' or 'I'. The pg_tablespace join (t)
// is not referenced by any selected column.
func (IndexSchema) TableName() string {
	return `
	(SELECT n.nspname AS schemaname,
    c.relname AS tablename,
    i.relname AS indexname,
    pg_get_indexdef(i.oid) AS indexdef
	FROM pg_index x
	JOIN pg_class c ON c.oid = x.indrelid
	JOIN pg_class i ON i.oid = x.indexrelid
	LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
	LEFT JOIN pg_tablespace t ON t.oid = i.reltablespace
	WHERE (c.relkind in ('r','m','p'))
    AND i.relkind in ('i', 'I')) as pg_indexes
	`
}