github.com/go-courier/sqlx/v2@v2.23.13/connectors/postgresql/postgresql_connector.go (about)

     1  package postgresql
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  
    13  	typex "github.com/go-courier/x/types"
    14  
    15  	"github.com/go-courier/sqlx/v2"
    16  	"github.com/go-courier/sqlx/v2/builder"
    17  	"github.com/go-courier/sqlx/v2/migration"
    18  	"github.com/lib/pq"
    19  )
    20  
    21  var _ interface {
    22  	driver.Connector
    23  	builder.Dialect
    24  } = (*PostgreSQLConnector)(nil)
    25  
    26  type PostgreSQLConnector struct {
    27  	Host       string
    28  	DBName     string
    29  	Extra      string
    30  	Extensions []string
    31  }
    32  
    33  func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) {
    34  	d := c.Driver()
    35  
    36  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
    37  	if err != nil {
    38  		if c.IsErrorUnknownDatabase(err) {
    39  			connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra))
    40  			if err != nil {
    41  				return nil, err
    42  			}
    43  			if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
    44  				return nil, err
    45  			}
    46  			if err := connectForCreateDB.Close(); err != nil {
    47  				return nil, err
    48  			}
    49  			return c.Connect(ctx)
    50  		}
    51  		return nil, err
    52  	}
    53  	for _, ex := range c.Extensions {
    54  		if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil {
    55  			return nil, err
    56  		}
    57  	}
    58  
    59  	return conn, nil
    60  }
    61  
    62  func (PostgreSQLConnector) Driver() driver.Driver {
    63  	return &PostgreSQLLoggingDriver{}
    64  }
    65  
    66  func dsn(host string, dbName string, extra string) string {
    67  	if extra != "" {
    68  		extra = "?" + extra
    69  	}
    70  	return host + "/" + dbName + extra
    71  }
    72  
    73  func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector {
    74  	c.DBName = dbName
    75  	return &c
    76  }
    77  
    78  func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    79  	output := migration.MigrationOutputFromContext(ctx)
    80  
    81  	prevDB, err := dbFromInformationSchema(db)
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	d := db.D()
    87  	dialect := db.Dialect()
    88  
    89  	exec := func(expr builder.SqlExpr) error {
    90  		if expr == nil || expr.IsNil() {
    91  			return nil
    92  		}
    93  
    94  		if output != nil {
    95  			_, _ = io.WriteString(output, builder.ResolveExpr(expr).Query())
    96  			_, _ = io.WriteString(output, "\n")
    97  			return nil
    98  		}
    99  
   100  		_, err := db.ExecExpr(expr)
   101  		return err
   102  	}
   103  
   104  	if prevDB == nil {
   105  		prevDB = &sqlx.Database{
   106  			Name: d.Name,
   107  		}
   108  		if err := exec(dialect.CreateDatabase(d.Name)); err != nil {
   109  			return err
   110  		}
   111  	}
   112  
   113  	if d.Schema != "" {
   114  		if err := exec(dialect.CreateSchema(d.Schema)); err != nil {
   115  			return err
   116  		}
   117  		prevDB = prevDB.WithSchema(d.Schema)
   118  	}
   119  
   120  	for _, name := range d.Tables.TableNames() {
   121  		table := d.Table(name)
   122  
   123  		prevTable := prevDB.Table(name)
   124  
   125  		if prevTable == nil {
   126  			for _, expr := range dialect.CreateTableIsNotExists(table) {
   127  				if err := exec(expr); err != nil {
   128  					return err
   129  				}
   130  			}
   131  			continue
   132  		}
   133  
   134  		exprList := table.Diff(prevTable, dialect)
   135  
   136  		for _, expr := range exprList {
   137  			if err := exec(expr); err != nil {
   138  				return err
   139  			}
   140  		}
   141  	}
   142  
   143  	return nil
   144  }
   145  
   146  func (PostgreSQLConnector) DriverName() string {
   147  	return "postgres"
   148  }
   149  
   150  func (PostgreSQLConnector) PrimaryKeyName() string {
   151  	return "pkey"
   152  }
   153  
   154  func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool {
   155  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" {
   156  		return true
   157  	}
   158  	return false
   159  }
   160  
   161  func (PostgreSQLConnector) IsErrorConflict(err error) bool {
   162  	if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" {
   163  		return true
   164  	}
   165  	return false
   166  }
   167  
   168  func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr {
   169  	e := builder.Expr("CREATE DATABASE ")
   170  	e.WriteQuery(dbName)
   171  	e.WriteEnd()
   172  	return e
   173  }
   174  
   175  func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr {
   176  	e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ")
   177  	e.WriteQuery(schema)
   178  	e.WriteEnd()
   179  	return e
   180  }
   181  
   182  func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr {
   183  	e := builder.Expr("DROP DATABASE IF EXISTS ")
   184  	e.WriteQuery(dbName)
   185  	e.WriteEnd()
   186  	return e
   187  }
   188  
   189  func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   190  	if key.IsPrimary() {
   191  		e := builder.Expr("ALTER TABLE ")
   192  		e.WriteExpr(key.Table)
   193  		e.WriteQuery(" ADD PRIMARY KEY ")
   194  		e.WriteExpr(key.Def.TableExpr(key.Table))
   195  		e.WriteEnd()
   196  		return e
   197  	}
   198  
   199  	e := builder.Expr("CREATE ")
   200  	if key.IsUnique {
   201  		e.WriteQuery("UNIQUE ")
   202  	}
   203  	e.WriteQuery("INDEX ")
   204  
   205  	e.WriteQuery(key.Table.Name)
   206  	e.WriteQuery("_")
   207  	e.WriteQuery(key.Name)
   208  
   209  	e.WriteQuery(" ON ")
   210  	e.WriteExpr(key.Table)
   211  
   212  	if m := strings.ToUpper(key.Method); m != "" {
   213  		if m == "SPATIAL" {
   214  			m = "GIST"
   215  		}
   216  		e.WriteQuery(" USING ")
   217  		e.WriteQuery(m)
   218  	}
   219  
   220  	e.WriteQueryByte(' ')
   221  	e.WriteExpr(key.Def.TableExpr(key.Table))
   222  
   223  	e.WriteEnd()
   224  	return e
   225  }
   226  
   227  func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   228  	if key.IsPrimary() {
   229  		e := builder.Expr("ALTER TABLE ")
   230  		e.WriteExpr(key.Table)
   231  		e.WriteQuery(" DROP CONSTRAINT ")
   232  		e.WriteExpr(key.Table)
   233  		e.WriteQuery("_pkey")
   234  		e.WriteEnd()
   235  		return e
   236  	}
   237  	e := builder.Expr("DROP ")
   238  
   239  	e.WriteQuery("INDEX IF EXISTS ")
   240  	e.WriteExpr(key.Table)
   241  	e.WriteQueryByte('_')
   242  	e.WriteQuery(key.Name)
   243  	e.WriteEnd()
   244  
   245  	return e
   246  }
   247  
   248  func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) {
   249  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   250  	expr.WriteExpr(t)
   251  	expr.WriteQueryByte(' ')
   252  	expr.WriteGroup(func(e *builder.Ex) {
   253  		if t.Columns.IsNil() {
   254  			return
   255  		}
   256  
   257  		t.Columns.Range(func(col *builder.Column, idx int) {
   258  			if col.DeprecatedActions != nil {
   259  				return
   260  			}
   261  
   262  			if idx > 0 {
   263  				e.WriteQueryByte(',')
   264  			}
   265  			e.WriteQueryByte('\n')
   266  			e.WriteQueryByte('\t')
   267  
   268  			e.WriteExpr(col)
   269  			e.WriteQueryByte(' ')
   270  			e.WriteExpr(c.DataType(col.ColumnType))
   271  		})
   272  
   273  		t.Keys.Range(func(key *builder.Key, idx int) {
   274  			if key.IsPrimary() {
   275  				e.WriteQueryByte(',')
   276  				e.WriteQueryByte('\n')
   277  				e.WriteQueryByte('\t')
   278  				e.WriteQuery("PRIMARY KEY ")
   279  				e.WriteExpr(key.Def.TableExpr(key.Table))
   280  			}
   281  		})
   282  
   283  		expr.WriteQueryByte('\n')
   284  	})
   285  
   286  	expr.WriteEnd()
   287  	exprs = append(exprs, expr)
   288  
   289  	t.Keys.Range(func(key *builder.Key, idx int) {
   290  		if !key.IsPrimary() {
   291  			exprs = append(exprs, c.AddIndex(key))
   292  		}
   293  	})
   294  
   295  	return
   296  }
   297  
   298  func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr {
   299  	e := builder.Expr("DROP TABLE IF EXISTS ")
   300  	e.WriteExpr(t)
   301  	e.WriteEnd()
   302  	return e
   303  }
   304  
   305  func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   306  	e := builder.Expr("TRUNCATE TABLE ")
   307  	e.WriteExpr(t)
   308  	e.WriteEnd()
   309  	return e
   310  }
   311  
   312  func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   313  	e := builder.Expr("ALTER TABLE ")
   314  	e.WriteExpr(col.Table)
   315  	e.WriteQuery(" ADD COLUMN ")
   316  	e.WriteExpr(col)
   317  	e.WriteQueryByte(' ')
   318  	e.WriteExpr(c.DataType(col.ColumnType))
   319  	e.WriteEnd()
   320  	return e
   321  }
   322  
   323  func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   324  	e := builder.Expr("ALTER TABLE ")
   325  	e.WriteExpr(col.Table)
   326  	e.WriteQuery(" RENAME COLUMN ")
   327  	e.WriteExpr(col)
   328  	e.WriteQuery(" TO ")
   329  	e.WriteExpr(target)
   330  	e.WriteEnd()
   331  	return e
   332  }
   333  
   334  func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr {
   335  	if col.AutoIncrement {
   336  		return nil
   337  	}
   338  
   339  	e := builder.Expr("ALTER TABLE ")
   340  	e.WriteExpr(col.Table)
   341  
   342  	dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType)
   343  	prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType)
   344  
   345  	isFirstSub := true
   346  	isEmpty := true
   347  
   348  	prepareAppendSubCmd := func() {
   349  		if !isFirstSub {
   350  			e.WriteQueryByte(',')
   351  		}
   352  		isFirstSub = false
   353  		isEmpty = false
   354  	}
   355  
   356  	if dbDataType != prevDbDataType {
   357  		prepareAppendSubCmd()
   358  
   359  		e.WriteQuery(" ALTER COLUMN ")
   360  		e.WriteExpr(col)
   361  		e.WriteQuery(" TYPE ")
   362  		e.WriteQuery(dbDataType)
   363  
   364  		e.WriteQuery(" /* FROM ")
   365  		e.WriteQuery(prevDbDataType)
   366  		e.WriteQuery(" */")
   367  	}
   368  
   369  	if col.Null != prev.Null {
   370  		prepareAppendSubCmd()
   371  
   372  		e.WriteQuery(" ALTER COLUMN ")
   373  		e.WriteExpr(col)
   374  		if !col.Null {
   375  			e.WriteQuery(" SET NOT NULL")
   376  		} else {
   377  			e.WriteQuery(" DROP NOT NULL")
   378  		}
   379  	}
   380  
   381  	defaultValue := normalizeDefaultValue(col.Default, dbDataType)
   382  	prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType)
   383  
   384  	if defaultValue != prevDefaultValue {
   385  		prepareAppendSubCmd()
   386  
   387  		e.WriteQuery(" ALTER COLUMN ")
   388  		e.WriteExpr(col)
   389  		if col.Default != nil {
   390  			e.WriteQuery(" SET DEFAULT ")
   391  			e.WriteQuery(defaultValue)
   392  
   393  			e.WriteQuery(" /* FROM ")
   394  			e.WriteQuery(prevDefaultValue)
   395  			e.WriteQuery(" */")
   396  		} else {
   397  			e.WriteQuery(" DROP DEFAULT")
   398  		}
   399  	}
   400  
   401  	if isEmpty {
   402  		return nil
   403  	}
   404  
   405  	e.WriteEnd()
   406  
   407  	return e
   408  }
   409  
   410  func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   411  	e := builder.Expr("ALTER TABLE ")
   412  	e.WriteExpr(col.Table)
   413  	e.WriteQuery(" DROP COLUMN ")
   414  	e.WriteQuery(col.Name)
   415  	e.WriteEnd()
   416  	return e
   417  }
   418  
   419  func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   420  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   421  	return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType))
   422  }
   423  
   424  func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string {
   425  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   426  	return dbDataType + autocompleteSize(dbDataType, columnType)
   427  }
   428  
   429  func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string {
   430  	if columnType.DataType != "" {
   431  		return columnType.DataType
   432  	}
   433  
   434  	if rv, ok := typex.TryNew(typ); ok {
   435  		if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok {
   436  			return dtd.DataType(c.DriverName())
   437  		}
   438  	}
   439  
   440  	switch typ.Kind() {
   441  	case reflect.Ptr:
   442  		return c.dataType(typ.Elem(), columnType)
   443  	case reflect.Bool:
   444  		return "boolean"
   445  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
   446  		if columnType.AutoIncrement {
   447  			return "serial"
   448  		}
   449  		return "integer"
   450  	case reflect.Int64, reflect.Uint64:
   451  		if columnType.AutoIncrement {
   452  			return "bigserial"
   453  		}
   454  		return "bigint"
   455  	case reflect.Float64:
   456  		return "double precision"
   457  	case reflect.Float32:
   458  		return "real"
   459  	case reflect.Slice:
   460  		if typ.Elem().Kind() == reflect.Uint8 {
   461  			return "bytea"
   462  		}
   463  	case reflect.String:
   464  		size := columnType.Length
   465  		if size < 65535/3 {
   466  			return "varchar"
   467  		}
   468  		return "text"
   469  	}
   470  
   471  	switch typ.Name() {
   472  	case "Hstore":
   473  		return "hstore"
   474  	case "ByteaArray":
   475  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]"
   476  	case "BoolArray":
   477  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]"
   478  	case "Float64Array":
   479  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]"
   480  	case "Int64Array":
   481  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]"
   482  	case "StringArray":
   483  		return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]"
   484  	case "NullInt64":
   485  		return "bigint"
   486  	case "NullFloat64":
   487  		return "double precision"
   488  	case "NullBool":
   489  		return "boolean"
   490  	case "Time", "NullTime":
   491  		return "timestamp with time zone"
   492  	}
   493  
   494  	panic(fmt.Errorf("unsupport type %s", typ))
   495  }
   496  
   497  func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string {
   498  	buf := bytes.NewBuffer(nil)
   499  
   500  	if !columnType.Null {
   501  		buf.WriteString(" NOT NULL")
   502  	}
   503  
   504  	if columnType.Default != nil {
   505  		buf.WriteString(" DEFAULT ")
   506  		buf.WriteString(normalizeDefaultValue(columnType.Default, dataType))
   507  	}
   508  
   509  	return buf.String()
   510  }
   511  
   512  func normalizeDefaultValue(defaultValue *string, dataType string) string {
   513  	if defaultValue == nil {
   514  		return ""
   515  	}
   516  
   517  	dv := *defaultValue
   518  
   519  	if dv[0] == '\'' {
   520  		if strings.Contains(dv, "'::") {
   521  			return dv
   522  		}
   523  		return dv + "::" + dataType
   524  	}
   525  
   526  	_, err := strconv.ParseFloat(dv, 64)
   527  	if err == nil {
   528  		return "'" + dv + "'::" + dataType
   529  	}
   530  
   531  	return dv
   532  }
   533  
   534  func autocompleteSize(dataType string, columnType *builder.ColumnType) string {
   535  	switch dataType {
   536  	case "character varying", "character":
   537  		size := columnType.Length
   538  		if size == 0 {
   539  			size = 255
   540  		}
   541  		return sizeModifier(size, columnType.Decimal)
   542  	case "decimal", "numeric", "real", "double precision":
   543  		if columnType.Length > 0 {
   544  			return sizeModifier(columnType.Length, columnType.Decimal)
   545  		}
   546  	}
   547  	return ""
   548  }
   549  
   550  func dealias(dataType string) string {
   551  	switch dataType {
   552  	case "varchar":
   553  		return "character varying"
   554  	case "timestamp":
   555  		return "timestamp without time zone"
   556  	}
   557  	return dataType
   558  }
   559  
   560  func sizeModifier(length uint64, decimal uint64) string {
   561  	if length > 0 {
   562  		size := strconv.FormatUint(length, 10)
   563  		if decimal > 0 {
   564  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   565  		}
   566  		return "(" + size + ")"
   567  	}
   568  	return ""
   569  }