gitee.com/go-genie/sqlx@v1.0.3/connectors/postgresql/postgresql_connector.go (about)

     1  package postgresql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"gitee.com/go-genie/sqlx/generator"
     9  	"io"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  
    14  	typex "gitee.com/go-genie/xx/types"
    15  
    16  	"gitee.com/go-genie/sqlx"
    17  	"gitee.com/go-genie/sqlx/builder"
    18  	"gitee.com/go-genie/sqlx/command"
    19  	"github.com/lib/pq"
    20  )
    21  
    22  var _ interface {
    23  	driver.Connector
    24  	builder.Dialect
    25  } = (*PostgreSQLConnector)(nil)
    26  
    27  type PostgreSQLConnector struct {
    28  	Host       string
    29  	DBName     string
    30  	Extra      string
    31  	Extensions []string
    32  }
    33  
    34  func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) {
    35  	d := c.Driver()
    36  
    37  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
    38  	if err != nil {
    39  		if c.IsErrorUnknownDatabase(err) {
    40  			connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra))
    41  			if err != nil {
    42  				return nil, err
    43  			}
    44  			if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
    45  				return nil, err
    46  			}
    47  			if err := connectForCreateDB.Close(); err != nil {
    48  				return nil, err
    49  			}
    50  			return c.Connect(ctx)
    51  		}
    52  		return nil, err
    53  	}
    54  	for _, ex := range c.Extensions {
    55  		if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil {
    56  			return nil, err
    57  		}
    58  	}
    59  
    60  	return conn, nil
    61  }
    62  
    63  func (PostgreSQLConnector) Driver() driver.Driver {
    64  	return &PostgreSQLLoggingDriver{}
    65  }
    66  
    67  func dsn(host string, dbName string, extra string) string {
    68  	if extra != "" {
    69  		extra = "?" + extra
    70  	}
    71  	return host + "/" + dbName + extra
    72  }
    73  
    74  func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector {
    75  	c.DBName = dbName
    76  	return &c
    77  }
    78  
    79  func (c *PostgreSQLConnector) Generate(ctx context.Context, db sqlx.DBExecutor) error {
    80  	//output := command.MigrationOutputFromContext(ctx)
    81  
    82  	prevDB, err := dbFromInformationSchema(db, COMMAND_GENERATE)
    83  	if err != nil {
    84  		return err
    85  	}
    86  
    87  	//cwd, _ := os.Getwd()
    88  	//
    89  	//pkg, err := packagesx.Load(cwd)
    90  	//if err != nil {
    91  	//	panic(err)
    92  	//}
    93  
    94  	models := generator.NewModelsFromDataBase(prevDB)
    95  
    96  	for _, item := range models {
    97  		err = item.Generator()
    98  		if err != nil {
    99  			panic(err)
   100  		}
   101  	}
   102  
   103  	return nil
   104  }
   105  
   106  func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
   107  	output := command.MigrationOutputFromContext(ctx)
   108  
   109  	prevDB, err := dbFromInformationSchema(db, COMMAND_MIGRATE)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	d := db.D()
   115  	dialect := db.Dialect()
   116  
   117  	exec := func(expr builder.SqlExpr) error {
   118  		if expr == nil || expr.IsNil() {
   119  			return nil
   120  		}
   121  
   122  		if output != nil {
   123  			_, _ = io.WriteString(output, builder.ResolveExpr(expr).Query())
   124  			_, _ = io.WriteString(output, "\n")
   125  			return nil
   126  		}
   127  
   128  		_, err := db.ExecExpr(expr)
   129  		return err
   130  	}
   131  
   132  	if prevDB == nil {
   133  		prevDB = &sqlx.Database{
   134  			Name: d.Name,
   135  		}
   136  		if err := exec(dialect.CreateDatabase(d.Name)); err != nil {
   137  			return err
   138  		}
   139  	}
   140  
   141  	if d.Schema != "" {
   142  		if err := exec(dialect.CreateSchema(d.Schema)); err != nil {
   143  			return err
   144  		}
   145  		prevDB = prevDB.WithSchema(d.Schema)
   146  	}
   147  
   148  	for _, name := range d.Tables.TableNames() {
   149  		table := d.Table(name)
   150  
   151  		prevTable := prevDB.Table(name)
   152  
   153  		if prevTable == nil {
   154  			for _, expr := range dialect.CreateTableIsNotExists(table) {
   155  				if err := exec(expr); err != nil {
   156  					return err
   157  				}
   158  			}
   159  			continue
   160  		}
   161  
   162  		exprList := table.Diff(prevTable, dialect)
   163  
   164  		for _, expr := range exprList {
   165  			if err := exec(expr); err != nil {
   166  				return err
   167  			}
   168  		}
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (PostgreSQLConnector) DriverName() string {
   175  	return "postgres"
   176  }
   177  
   178  func (PostgreSQLConnector) PrimaryKeyName() string {
   179  	return "pkey"
   180  }
   181  
   182  func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool {
   183  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" {
   184  		return true
   185  	}
   186  	return false
   187  }
   188  
   189  func (PostgreSQLConnector) IsErrorConflict(err error) bool {
   190  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" {
   191  		return true
   192  	}
   193  	return false
   194  }
   195  
   196  func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr {
   197  	e := builder.Expr("CREATE DATABASE ")
   198  	e.WriteQuery(dbName)
   199  	e.WriteEnd()
   200  	return e
   201  }
   202  
   203  func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr {
   204  	e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ")
   205  	e.WriteQuery(schema)
   206  	e.WriteEnd()
   207  	return e
   208  }
   209  
   210  func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr {
   211  	e := builder.Expr("DROP DATABASE IF EXISTS ")
   212  	e.WriteQuery(dbName)
   213  	e.WriteEnd()
   214  	return e
   215  }
   216  
   217  func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   218  	if key.IsPrimary() {
   219  		e := builder.Expr("ALTER TABLE ")
   220  		e.WriteExpr(key.Table)
   221  		e.WriteQuery(" ADD PRIMARY KEY ")
   222  		e.WriteExpr(key.Def.TableExpr(key.Table))
   223  		e.WriteEnd()
   224  		return e
   225  	}
   226  
   227  	e := builder.Expr("CREATE ")
   228  	if key.IsUnique {
   229  		e.WriteQuery("UNIQUE ")
   230  	}
   231  	e.WriteQuery("INDEX ")
   232  
   233  	e.WriteQuery(key.Table.Name)
   234  	e.WriteQuery("_")
   235  	e.WriteQuery(key.Name)
   236  
   237  	e.WriteQuery(" ON ")
   238  	e.WriteExpr(key.Table)
   239  
   240  	if m := strings.ToUpper(key.Method); m != "" {
   241  		if m == "SPATIAL" {
   242  			m = "GIST"
   243  		}
   244  		e.WriteQuery(" USING ")
   245  		e.WriteQuery(m)
   246  	}
   247  
   248  	e.WriteQueryByte(' ')
   249  	e.WriteExpr(key.Def.TableExpr(key.Table))
   250  
   251  	e.WriteEnd()
   252  	return e
   253  }
   254  
   255  func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   256  	if key.IsPrimary() {
   257  		e := builder.Expr("ALTER TABLE ")
   258  		e.WriteExpr(key.Table)
   259  		e.WriteQuery(" DROP CONSTRAINT ")
   260  		e.WriteExpr(key.Table)
   261  		e.WriteQuery("_pkey")
   262  		e.WriteEnd()
   263  		return e
   264  	}
   265  	e := builder.Expr("DROP ")
   266  
   267  	e.WriteQuery("INDEX IF EXISTS ")
   268  	e.WriteExpr(key.Table)
   269  	e.WriteQueryByte('_')
   270  	e.WriteQuery(key.Name)
   271  	e.WriteEnd()
   272  
   273  	return e
   274  }
   275  
   276  func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) {
   277  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   278  	expr.WriteExpr(t)
   279  	expr.WriteQueryByte(' ')
   280  	expr.WriteGroup(func(e *builder.Ex) {
   281  		if t.Columns.IsNil() {
   282  			return
   283  		}
   284  
   285  		t.Columns.Range(func(col *builder.Column, idx int) {
   286  			if col.DeprecatedActions != nil {
   287  				return
   288  			}
   289  
   290  			if idx > 0 {
   291  				e.WriteQueryByte(',')
   292  			}
   293  			e.WriteQueryByte('\n')
   294  			e.WriteQueryByte('\t')
   295  
   296  			e.WriteExpr(col)
   297  			e.WriteQueryByte(' ')
   298  			e.WriteExpr(c.DataType(col.ColumnType))
   299  		})
   300  
   301  		t.Keys.Range(func(key *builder.Key, idx int) {
   302  			if key.IsPrimary() {
   303  				e.WriteQueryByte(',')
   304  				e.WriteQueryByte('\n')
   305  				e.WriteQueryByte('\t')
   306  				e.WriteQuery("PRIMARY KEY ")
   307  				e.WriteExpr(key.Def.TableExpr(key.Table))
   308  			}
   309  		})
   310  
   311  		expr.WriteQueryByte('\n')
   312  	})
   313  
   314  	expr.WriteEnd()
   315  	exprs = append(exprs, expr)
   316  
   317  	t.Keys.Range(func(key *builder.Key, idx int) {
   318  		if !key.IsPrimary() {
   319  			exprs = append(exprs, c.AddIndex(key))
   320  		}
   321  	})
   322  
   323  	return
   324  }
   325  
   326  func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr {
   327  	e := builder.Expr("DROP TABLE IF EXISTS ")
   328  	e.WriteExpr(t)
   329  	e.WriteEnd()
   330  	return e
   331  }
   332  
   333  func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   334  	e := builder.Expr("TRUNCATE TABLE ")
   335  	e.WriteExpr(t)
   336  	e.WriteEnd()
   337  	return e
   338  }
   339  
   340  func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   341  	e := builder.Expr("ALTER TABLE ")
   342  	e.WriteExpr(col.Table)
   343  	e.WriteQuery(" ADD COLUMN ")
   344  	e.WriteExpr(col)
   345  	e.WriteQueryByte(' ')
   346  	e.WriteExpr(c.DataType(col.ColumnType))
   347  	e.WriteEnd()
   348  	return e
   349  }
   350  
   351  func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   352  	e := builder.Expr("ALTER TABLE ")
   353  	e.WriteExpr(col.Table)
   354  	e.WriteQuery(" RENAME COLUMN ")
   355  	e.WriteExpr(col)
   356  	e.WriteQuery(" TO ")
   357  	e.WriteExpr(target)
   358  	e.WriteEnd()
   359  	return e
   360  }
   361  
   362  func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr {
   363  	if col.AutoIncrement {
   364  		return nil
   365  	}
   366  
   367  	e := builder.Expr("ALTER TABLE ")
   368  	e.WriteExpr(col.Table)
   369  
   370  	dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType)
   371  	prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType)
   372  
   373  	isFirstSub := true
   374  	isEmpty := true
   375  
   376  	prepareAppendSubCmd := func() {
   377  		if !isFirstSub {
   378  			e.WriteQueryByte(',')
   379  		}
   380  		isFirstSub = false
   381  		isEmpty = false
   382  	}
   383  
   384  	if dbDataType != prevDbDataType {
   385  		prepareAppendSubCmd()
   386  
   387  		e.WriteQuery(" ALTER COLUMN ")
   388  		e.WriteExpr(col)
   389  		e.WriteQuery(" TYPE ")
   390  		e.WriteQuery(dbDataType)
   391  
   392  		e.WriteQuery(" /* FROM ")
   393  		e.WriteQuery(prevDbDataType)
   394  		e.WriteQuery(" */")
   395  	}
   396  
   397  	if col.Null != prev.Null {
   398  		prepareAppendSubCmd()
   399  
   400  		e.WriteQuery(" ALTER COLUMN ")
   401  		e.WriteExpr(col)
   402  		if !col.Null {
   403  			e.WriteQuery(" SET NOT NULL")
   404  		} else {
   405  			e.WriteQuery(" DROP NOT NULL")
   406  		}
   407  	}
   408  
   409  	defaultValue := normalizeDefaultValue(col.Default, dbDataType)
   410  	prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType)
   411  
   412  	if defaultValue != prevDefaultValue {
   413  		prepareAppendSubCmd()
   414  
   415  		e.WriteQuery(" ALTER COLUMN ")
   416  		e.WriteExpr(col)
   417  		if col.Default != nil {
   418  			e.WriteQuery(" SET DEFAULT ")
   419  			e.WriteQuery(defaultValue)
   420  
   421  			e.WriteQuery(" /* FROM ")
   422  			e.WriteQuery(prevDefaultValue)
   423  			e.WriteQuery(" */")
   424  		} else {
   425  			e.WriteQuery(" DROP DEFAULT")
   426  		}
   427  	}
   428  
   429  	if isEmpty {
   430  		return nil
   431  	}
   432  
   433  	e.WriteEnd()
   434  
   435  	return e
   436  }
   437  
   438  func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   439  	e := builder.Expr("ALTER TABLE ")
   440  	e.WriteExpr(col.Table)
   441  	e.WriteQuery(" DROP COLUMN ")
   442  	e.WriteQuery(col.Name)
   443  	e.WriteEnd()
   444  	return e
   445  }
   446  
   447  func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   448  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   449  	return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType))
   450  }
   451  
   452  func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string {
   453  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   454  	return dbDataType + autocompleteSize(dbDataType, columnType)
   455  }
   456  
   457  func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string {
   458  	if columnType.DataType != "" {
   459  		return columnType.DataType
   460  	}
   461  
   462  	if rv, ok := typex.TryNew(typ); ok {
   463  		if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok {
   464  			return dtd.DataType(c.DriverName())
   465  		}
   466  	}
   467  
   468  	switch typ.Kind() {
   469  	case reflect.Ptr:
   470  		return c.dataType(typ.Elem(), columnType)
   471  	case reflect.Bool:
   472  		return "boolean"
   473  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
   474  		if columnType.AutoIncrement {
   475  			return "serial"
   476  		}
   477  		return "integer"
   478  	case reflect.Int64, reflect.Uint64:
   479  		if columnType.AutoIncrement {
   480  			return "bigserial"
   481  		}
   482  		return "bigint"
   483  	case reflect.Float64:
   484  		return "double precision"
   485  	case reflect.Float32:
   486  		return "real"
   487  	case reflect.Slice:
   488  		if typ.Elem().Kind() == reflect.Uint8 {
   489  			return "bytea"
   490  		}
   491  	case reflect.String:
   492  		size := columnType.Length
   493  		if size < 65535/3 {
   494  			return "varchar"
   495  		}
   496  		return "text"
   497  	}
   498  
   499  	switch typ.Name() {
   500  	case "Hstore":
   501  		return "hstore"
   502  	case "ByteaArray":
   503  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]"
   504  	case "BoolArray":
   505  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]"
   506  	case "Float64Array":
   507  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]"
   508  	case "Int64Array":
   509  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]"
   510  	case "StringArray":
   511  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]"
   512  	case "NullInt64":
   513  		return "bigint"
   514  	case "NullFloat64":
   515  		return "double precision"
   516  	case "NullBool":
   517  		return "boolean"
   518  	case "Time", "NullTime":
   519  		return "timestamp with time zone"
   520  	}
   521  
   522  	panic(fmt.Errorf("unsupport type %s", typ))
   523  }
   524  
   525  func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string {
   526  	buf := bytes.NewBuffer(nil)
   527  
   528  	if !columnType.Null {
   529  		buf.WriteString(" NOT NULL")
   530  	}
   531  
   532  	if columnType.Default != nil {
   533  		buf.WriteString(" DEFAULT ")
   534  		buf.WriteString(normalizeDefaultValue(columnType.Default, dataType))
   535  	}
   536  
   537  	return buf.String()
   538  }
   539  
   540  func normalizeDefaultValue(defaultValue *string, dataType string) string {
   541  	if defaultValue == nil {
   542  		return ""
   543  	}
   544  
   545  	dv := *defaultValue
   546  
   547  	if dv[0] == '\'' {
   548  		if strings.Contains(dv, "'::") {
   549  			return dv
   550  		}
   551  		return dv + "::" + dataType
   552  	}
   553  
   554  	_, err := strconv.ParseFloat(dv, 64)
   555  	if err == nil {
   556  		return "'" + dv + "'::" + dataType
   557  	}
   558  
   559  	return dv
   560  }
   561  
   562  func autocompleteSize(dataType string, columnType *builder.ColumnType) string {
   563  	switch dataType {
   564  	case "character varying", "character":
   565  		size := columnType.Length
   566  		if size == 0 {
   567  			size = 255
   568  		}
   569  		return sizeModifier(size, columnType.Decimal)
   570  	case "decimal", "numeric", "real", "double precision":
   571  		if columnType.Length > 0 {
   572  			return sizeModifier(columnType.Length, columnType.Decimal)
   573  		}
   574  	}
   575  	return ""
   576  }
   577  
   578  func dealias(dataType string) string {
   579  	switch dataType {
   580  	case "varchar":
   581  		return "character varying"
   582  	case "timestamp":
   583  		return "timestamp without time zone"
   584  	}
   585  	return dataType
   586  }
   587  
   588  func sizeModifier(length uint64, decimal uint64) string {
   589  	if length > 0 {
   590  		size := strconv.FormatUint(length, 10)
   591  		if decimal > 0 {
   592  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   593  		}
   594  		return "(" + size + ")"
   595  	}
   596  	return ""
   597  }