github.com/kunlun-qilian/sqlx/v2@v2.24.0/connectors/postgresql/postgresql_connector.go (about)

     1  package postgresql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	typex "github.com/go-courier/x/types"
    14  
    15  	"github.com/kunlun-qilian/sqlx/v2"
    16  	"github.com/kunlun-qilian/sqlx/v2/builder"
    17  	"github.com/kunlun-qilian/sqlx/v2/migration"
    18  	"github.com/lib/pq"
    19  )
    20  
    21  var _ interface {
    22  	driver.Connector
    23  	builder.Dialect
    24  } = (*PostgreSQLConnector)(nil)
    25  
// PostgreSQLConnector opens PostgreSQL connections (driver.Connector) and
// renders PostgreSQL DDL for sqlx (builder.Dialect).
type PostgreSQLConnector struct {
	// Host is the DSN prefix. The connection string is built as
	// Host + "/" + DBName, followed by "?" + Extra when Extra is non-empty.
	// Presumably a URL such as "postgres://user:pass@localhost:5432" —
	// TODO confirm against callers.
	Host       string
	// DBName is the database to connect to. Connect issues CREATE DATABASE
	// for it when the server reports it as unknown.
	DBName     string
	// Extra, when non-empty, is appended to the DSN after "?".
	Extra      string
	// Extensions are installed on every new connection with
	// "CREATE EXTENSION IF NOT EXISTS <name>;". Names are interpolated
	// verbatim, so any quoting PostgreSQL requires must already be present.
	Extensions []string
}
    32  
    33  func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) {
    34  	d := c.Driver()
    35  
    36  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
    37  	if err != nil {
    38  		if c.IsErrorUnknownDatabase(err) {
    39  			connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra))
    40  			if err != nil {
    41  				return nil, err
    42  			}
    43  			if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
    44  				return nil, err
    45  			}
    46  			if err := connectForCreateDB.Close(); err != nil {
    47  				return nil, err
    48  			}
    49  			return c.Connect(ctx)
    50  		}
    51  		return nil, err
    52  	}
    53  	for _, ex := range c.Extensions {
    54  		if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil {
    55  			return nil, err
    56  		}
    57  	}
    58  
    59  	return conn, nil
    60  }
    61  
    62  func (PostgreSQLConnector) Driver() driver.Driver {
    63  	return &PostgreSQLLoggingDriver{}
    64  }
    65  
    66  func dsn(host string, dbName string, extra string) string {
    67  	if extra != "" {
    68  		extra = "?" + extra
    69  	}
    70  	return host + "/" + dbName + extra
    71  }
    72  
    73  func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector {
    74  	c.DBName = dbName
    75  	return &c
    76  }
    77  
    78  func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    79  	output := migration.MigrationOutputFromContext(ctx)
    80  
    81  	prevDB, err := dbFromInformationSchema(db)
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	d := db.D()
    87  	dialect := db.Dialect()
    88  
    89  	exec := func(expr builder.SqlExpr) error {
    90  		if expr == nil || expr.IsNil() {
    91  			return nil
    92  		}
    93  
    94  		if output != nil {
    95  			_, _ = io.WriteString(output, builder.ResolveExpr(expr).Query())
    96  			_, _ = io.WriteString(output, "\n")
    97  			return nil
    98  		}
    99  
   100  		_, err := db.ExecExpr(expr)
   101  		return err
   102  	}
   103  
   104  	if prevDB == nil {
   105  		prevDB = &sqlx.Database{
   106  			Name: d.Name,
   107  		}
   108  		if err := exec(dialect.CreateDatabase(d.Name)); err != nil {
   109  			return err
   110  		}
   111  	}
   112  
   113  	if d.Schema != "" {
   114  		if err := exec(dialect.CreateSchema(d.Schema)); err != nil {
   115  			return err
   116  		}
   117  		prevDB = prevDB.WithSchema(d.Schema)
   118  	}
   119  
   120  	for _, name := range d.Tables.TableNames() {
   121  		table := d.Table(name)
   122  
   123  		prevTable := prevDB.Table(name)
   124  
   125  		if prevTable == nil {
   126  			for _, expr := range dialect.CreateTableIsNotExists(table) {
   127  				if err := exec(expr); err != nil {
   128  					return err
   129  				}
   130  			}
   131  			continue
   132  		}
   133  
   134  		exprList := table.Diff(prevTable, dialect)
   135  
   136  		for _, expr := range exprList {
   137  			if err := exec(expr); err != nil {
   138  				return err
   139  			}
   140  		}
   141  	}
   142  
   143  	return nil
   144  }
   145  
   146  func (PostgreSQLConnector) DriverName() string {
   147  	return "postgres"
   148  }
   149  
   150  func (PostgreSQLConnector) PrimaryKeyName() string {
   151  	return "pkey"
   152  }
   153  
   154  func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool {
   155  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" {
   156  		return true
   157  	}
   158  	return false
   159  }
   160  
   161  func (PostgreSQLConnector) IsErrorConflict(err error) bool {
   162  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" {
   163  		return true
   164  	}
   165  	return false
   166  }
   167  
   168  func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr {
   169  	e := builder.Expr("CREATE DATABASE ")
   170  	e.WriteQuery(dbName)
   171  	e.WriteEnd()
   172  	return e
   173  }
   174  
   175  func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr {
   176  	e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ")
   177  	e.WriteQuery(schema)
   178  	e.WriteEnd()
   179  	return e
   180  }
   181  
   182  func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr {
   183  	e := builder.Expr("DROP DATABASE IF EXISTS ")
   184  	e.WriteQuery(dbName)
   185  	e.WriteEnd()
   186  	return e
   187  }
   188  
   189  func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   190  	if key.IsPrimary() {
   191  		e := builder.Expr("ALTER TABLE ")
   192  		e.WriteExpr(key.Table)
   193  		e.WriteQuery(" ADD PRIMARY KEY ")
   194  		e.WriteExpr(key.Def.TableExpr(key.Table))
   195  		e.WriteEnd()
   196  		return e
   197  	}
   198  
   199  	e := builder.Expr("CREATE ")
   200  	if key.IsUnique {
   201  		e.WriteQuery("UNIQUE ")
   202  	}
   203  	e.WriteQuery("INDEX ")
   204  
   205  	e.WriteQuery(key.Table.Name)
   206  	e.WriteQuery("_")
   207  	e.WriteQuery(key.Name)
   208  
   209  	e.WriteQuery(" ON ")
   210  	e.WriteExpr(key.Table)
   211  
   212  	if m := strings.ToUpper(key.Method); m != "" {
   213  		if m == "SPATIAL" {
   214  			m = "GIST"
   215  		}
   216  		e.WriteQuery(" USING ")
   217  		e.WriteQuery(m)
   218  	}
   219  
   220  	e.WriteQueryByte(' ')
   221  	e.WriteExpr(key.Def.TableExpr(key.Table))
   222  
   223  	e.WriteEnd()
   224  	return e
   225  }
   226  
   227  func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   228  	if key.IsPrimary() {
   229  		e := builder.Expr("ALTER TABLE ")
   230  		e.WriteExpr(key.Table)
   231  		e.WriteQuery(" DROP CONSTRAINT ")
   232  		e.WriteExpr(key.Table)
   233  		e.WriteQuery("_pkey")
   234  		e.WriteEnd()
   235  		return e
   236  	}
   237  	e := builder.Expr("DROP ")
   238  
   239  	e.WriteQuery("INDEX IF EXISTS ")
   240  	e.WriteExpr(key.Table)
   241  	e.WriteQueryByte('_')
   242  	e.WriteQuery(key.Name)
   243  	e.WriteEnd()
   244  
   245  	return e
   246  }
   247  
   248  func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) {
   249  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   250  	expr.WriteExpr(t)
   251  	expr.WriteQueryByte(' ')
   252  	expr.WriteGroup(func(e *builder.Ex) {
   253  		if t.Columns.IsNil() {
   254  			return
   255  		}
   256  
   257  		t.Columns.Range(func(col *builder.Column, idx int) {
   258  			if col.DeprecatedActions != nil {
   259  				return
   260  			}
   261  
   262  			if idx > 0 {
   263  				e.WriteQueryByte(',')
   264  			}
   265  			e.WriteQueryByte('\n')
   266  			e.WriteQueryByte('\t')
   267  
   268  			e.WriteExpr(col)
   269  			e.WriteQueryByte(' ')
   270  			e.WriteExpr(c.DataType(col.ColumnType))
   271  		})
   272  
   273  		t.Keys.Range(func(key *builder.Key, idx int) {
   274  			if key.IsPrimary() {
   275  				e.WriteQueryByte(',')
   276  				e.WriteQueryByte('\n')
   277  				e.WriteQueryByte('\t')
   278  				e.WriteQuery("PRIMARY KEY ")
   279  				e.WriteExpr(key.Def.TableExpr(key.Table))
   280  			}
   281  		})
   282  
   283  		expr.WriteQueryByte('\n')
   284  	})
   285  
   286  	t.Keys.Range(func(key *builder.Key, idx int) {
   287  		if key.IsPartition() {
   288  			expr.WriteQuery(" PARTITION BY ")
   289  			expr.WriteQuery(key.Method)
   290  			expr.WriteQueryByte(' ')
   291  			expr.WriteExpr(key.Def.TableExpr(key.Table))
   292  		}
   293  	})
   294  
   295  	expr.WriteEnd()
   296  	exprs = append(exprs, expr)
   297  
   298  	t.Keys.Range(func(key *builder.Key, idx int) {
   299  		if !key.IsPrimary() && !key.IsPartition() {
   300  			exprs = append(exprs, c.AddIndex(key))
   301  		}
   302  	})
   303  
   304  	return
   305  }
   306  
   307  func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr {
   308  	e := builder.Expr("DROP TABLE IF EXISTS ")
   309  	e.WriteExpr(t)
   310  	e.WriteEnd()
   311  	return e
   312  }
   313  
   314  func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   315  	e := builder.Expr("TRUNCATE TABLE ")
   316  	e.WriteExpr(t)
   317  	e.WriteEnd()
   318  	return e
   319  }
   320  
   321  func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   322  	e := builder.Expr("ALTER TABLE ")
   323  	e.WriteExpr(col.Table)
   324  	e.WriteQuery(" ADD COLUMN ")
   325  	e.WriteExpr(col)
   326  	e.WriteQueryByte(' ')
   327  	e.WriteExpr(c.DataType(col.ColumnType))
   328  	e.WriteEnd()
   329  	return e
   330  }
   331  
   332  func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   333  	e := builder.Expr("ALTER TABLE ")
   334  	e.WriteExpr(col.Table)
   335  	e.WriteQuery(" RENAME COLUMN ")
   336  	e.WriteExpr(col)
   337  	e.WriteQuery(" TO ")
   338  	e.WriteExpr(target)
   339  	e.WriteEnd()
   340  	return e
   341  }
   342  
   343  func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr {
   344  	if col.AutoIncrement {
   345  		return nil
   346  	}
   347  
   348  	e := builder.Expr("ALTER TABLE ")
   349  	e.WriteExpr(col.Table)
   350  
   351  	dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType)
   352  	prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType)
   353  
   354  	isFirstSub := true
   355  	isEmpty := true
   356  
   357  	prepareAppendSubCmd := func() {
   358  		if !isFirstSub {
   359  			e.WriteQueryByte(',')
   360  		}
   361  		isFirstSub = false
   362  		isEmpty = false
   363  	}
   364  
   365  	if dbDataType != prevDbDataType {
   366  		prepareAppendSubCmd()
   367  
   368  		e.WriteQuery(" ALTER COLUMN ")
   369  		e.WriteExpr(col)
   370  		e.WriteQuery(" TYPE ")
   371  		e.WriteQuery(dbDataType)
   372  
   373  		e.WriteQuery(" /* FROM ")
   374  		e.WriteQuery(prevDbDataType)
   375  		e.WriteQuery(" */")
   376  	}
   377  
   378  	if col.Null != prev.Null {
   379  		prepareAppendSubCmd()
   380  
   381  		e.WriteQuery(" ALTER COLUMN ")
   382  		e.WriteExpr(col)
   383  		if !col.Null {
   384  			e.WriteQuery(" SET NOT NULL")
   385  		} else {
   386  			e.WriteQuery(" DROP NOT NULL")
   387  		}
   388  	}
   389  
   390  	defaultValue := normalizeDefaultValue(col.Default, dbDataType)
   391  	prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType)
   392  
   393  	if defaultValue != prevDefaultValue {
   394  		prepareAppendSubCmd()
   395  
   396  		e.WriteQuery(" ALTER COLUMN ")
   397  		e.WriteExpr(col)
   398  		if col.Default != nil {
   399  			e.WriteQuery(" SET DEFAULT ")
   400  			e.WriteQuery(defaultValue)
   401  
   402  			e.WriteQuery(" /* FROM ")
   403  			e.WriteQuery(prevDefaultValue)
   404  			e.WriteQuery(" */")
   405  		} else {
   406  			e.WriteQuery(" DROP DEFAULT")
   407  		}
   408  	}
   409  
   410  	if isEmpty {
   411  		return nil
   412  	}
   413  
   414  	e.WriteEnd()
   415  
   416  	return e
   417  }
   418  
   419  func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   420  	e := builder.Expr("ALTER TABLE ")
   421  	e.WriteExpr(col.Table)
   422  	e.WriteQuery(" DROP COLUMN ")
   423  	e.WriteQuery(col.Name)
   424  	e.WriteEnd()
   425  	return e
   426  }
   427  
   428  func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   429  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   430  	return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType))
   431  }
   432  
   433  func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string {
   434  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   435  	return dbDataType + autocompleteSize(dbDataType, columnType)
   436  }
   437  
// dbDataType maps the Go type typ to a PostgreSQL type name. The result may
// be an alias such as "varchar"; callers pass it through dealias.
//
// Resolution order:
//  1. An explicit columnType.DataType wins.
//  2. If a value of typ implements builder.DataTypeDescriber, it supplies its
//     own type for DriverName().
//  3. Otherwise the reflect.Kind of typ decides. Pointers are mapped through
//     their element type.
//  4. Otherwise the type's Name() decides (pq array types, Null* types,
//     Time/NullTime).
//
// It panics when no mapping is found.
func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string {
	if columnType.DataType != "" {
		return columnType.DataType
	}

	if rv, ok := typex.TryNew(typ); ok {
		if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok {
			return dtd.DataType(c.DriverName())
		}
	}

	switch typ.Kind() {
	case reflect.Ptr:
		// NOTE(review): this only terminates if dataType maps the typ it is
		// given (the element type) rather than columnType.Type.
		return c.dataType(typ.Elem(), columnType)
	case reflect.Bool:
		return "boolean"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
		// NOTE(review): PostgreSQL integer/serial are signed 32-bit. Go int/uint
		// are commonly 64-bit, and uint32 exceeds the int32 range, so large
		// values may not fit. Left unchanged because altering it would change
		// generated schemas.
		if columnType.AutoIncrement {
			return "serial"
		}
		return "integer"
	case reflect.Int64, reflect.Uint64:
		// NOTE(review): uint64 values above the int64 range do not fit in bigint.
		if columnType.AutoIncrement {
			return "bigserial"
		}
		return "bigint"
	case reflect.Float64:
		return "double precision"
	case reflect.Float32:
		return "real"
	case reflect.Slice:
		// Only byte slices map here. Other slices (e.g. pq array types) leave
		// the switch and are resolved by name below.
		if typ.Elem().Kind() == reflect.Uint8 {
			return "bytea"
		}
	case reflect.String:
		// Declared lengths below 65535/3 (21845) become varchar; longer ones text.
		size := columnType.Length
		if size < 65535/3 {
			return "varchar"
		}
		return "text"
	}

	switch typ.Name() {
	case "Hstore":
		return "hstore"
	// pq array types: map a sample element's type via dataType and append "[]".
	case "ByteaArray":
		return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]"
	case "BoolArray":
		return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]"
	case "Float64Array":
		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]"
	case "Int64Array":
		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]"
	case "StringArray":
		return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]"
	// Presumably the database/sql Null* wrappers, matched by name only — confirm.
	case "NullInt64":
		return "bigint"
	case "NullFloat64":
		return "double precision"
	case "NullBool":
		return "boolean"
	case "Time", "NullTime":
		return "timestamp with time zone"
	}

	panic(fmt.Errorf("unsupport type %s", typ))
}
   505  
   506  func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string {
   507  	buf := bytes.NewBuffer(nil)
   508  
   509  	if !columnType.Null {
   510  		buf.WriteString(" NOT NULL")
   511  	}
   512  
   513  	if columnType.Default != nil {
   514  		buf.WriteString(" DEFAULT ")
   515  		buf.WriteString(normalizeDefaultValue(columnType.Default, dataType))
   516  	}
   517  
   518  	return buf.String()
   519  }
   520  
   521  func normalizeDefaultValue(defaultValue *string, dataType string) string {
   522  	if defaultValue == nil {
   523  		return ""
   524  	}
   525  
   526  	dv := *defaultValue
   527  
   528  	if dv[0] == '\'' {
   529  		if strings.Contains(dv, "'::") {
   530  			return dv
   531  		}
   532  		return dv + "::" + dataType
   533  	}
   534  
   535  	_, err := strconv.ParseFloat(dv, 64)
   536  	if err == nil {
   537  		return "'" + dv + "'::" + dataType
   538  	}
   539  
   540  	return dv
   541  }
   542  
   543  func autocompleteSize(dataType string, columnType *builder.ColumnType) string {
   544  	switch dataType {
   545  	case "character varying", "character":
   546  		size := columnType.Length
   547  		if size == 0 {
   548  			size = 255
   549  		}
   550  		return sizeModifier(size, columnType.Decimal)
   551  	case "decimal", "numeric", "real", "double precision":
   552  		if columnType.Length > 0 {
   553  			return sizeModifier(columnType.Length, columnType.Decimal)
   554  		}
   555  	}
   556  	return ""
   557  }
   558  
   559  func dealias(dataType string) string {
   560  	switch dataType {
   561  	case "varchar":
   562  		return "character varying"
   563  	case "timestamp":
   564  		return "timestamp without time zone"
   565  	}
   566  	return dataType
   567  }
   568  
   569  func sizeModifier(length uint64, decimal uint64) string {
   570  	if length > 0 {
   571  		size := strconv.FormatUint(length, 10)
   572  		if decimal > 0 {
   573  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   574  		}
   575  		return "(" + size + ")"
   576  	}
   577  	return ""
   578  }