github.com/eden-framework/sqlx@v0.0.2/postgresqlconnector/postgresql_connector.go (about)

     1  package postgresqlconnector
     2  
import (
	"bytes"
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"log"
	"reflect"
	"strconv"
	"strings"

	"github.com/eden-framework/sqlx"
	"github.com/eden-framework/sqlx/builder"
	"github.com/eden-framework/sqlx/migration"
	"github.com/lib/pq"
	"github.com/sirupsen/logrus"
)
    19  
    20  var _ interface {
    21  	driver.Connector
    22  	builder.Dialect
    23  } = (*PostgreSQLConnector)(nil)
    24  
    25  type PostgreSQLConnector struct {
    26  	Host       string
    27  	DBName     string
    28  	Extra      string
    29  	Extensions []string
    30  }
    31  
    32  func dsn(host string, dbName string, extra string) string {
    33  	if extra != "" {
    34  		extra = "?" + extra
    35  	}
    36  	return host + "/" + dbName + extra
    37  }
    38  
    39  func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector {
    40  	c.DBName = dbName
    41  	return &c
    42  }
    43  
    44  func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error {
    45  	opts := migration.MigrationOptsFromContext(ctx)
    46  
    47  	prevDB := dbFromInformationSchema(db)
    48  
    49  	d := db.D()
    50  	dialect := db.Dialect()
    51  
    52  	if prevDB == nil {
    53  		prevDB = &sqlx.Database{
    54  			Name: d.Name,
    55  		}
    56  		if _, err := db.ExecExpr(dialect.CreateDatabase(d.Name)); err != nil {
    57  			return err
    58  		}
    59  	}
    60  
    61  	if d.Schema != "" {
    62  		if _, err := db.ExecExpr(dialect.CreateSchema(d.Schema)); err != nil {
    63  			return err
    64  		}
    65  		prevDB = prevDB.WithSchema(d.Schema)
    66  	}
    67  
    68  	for _, name := range d.Tables.TableNames() {
    69  		table := d.Table(name)
    70  
    71  		prevTable := prevDB.Table(name)
    72  
    73  		if prevTable == nil {
    74  			for _, expr := range dialect.CreateTableIsNotExists(table) {
    75  				if _, err := db.ExecExpr(expr); err != nil {
    76  					return err
    77  				}
    78  			}
    79  			continue
    80  		}
    81  
    82  		exprList := table.Diff(prevTable, dialect)
    83  
    84  		for _, expr := range exprList {
    85  			if !(expr == nil || expr.IsNil()) {
    86  				if opts.DryRun {
    87  					log.Printf(builder.ResolveExpr(expr).Query())
    88  				} else {
    89  					if _, err := db.ExecExpr(expr); err != nil {
    90  						return err
    91  					}
    92  				}
    93  			}
    94  		}
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) {
   101  	d := c.Driver()
   102  	conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra))
   103  	if err != nil {
   104  		if c.IsErrorUnknownDatabase(err) {
   105  			connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra))
   106  			if err != nil {
   107  				return nil, err
   108  			}
   109  			if _, err := connectForCreateDB.(driver.ExecerContext).
   110  				ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil {
   111  				return nil, err
   112  			}
   113  			if err := connectForCreateDB.Close(); err != nil {
   114  				return nil, err
   115  			}
   116  			return c.Connect(ctx)
   117  		}
   118  		return nil, err
   119  	}
   120  	for _, ex := range c.Extensions {
   121  		if _, err := conn.(driver.ExecerContext).
   122  			ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil {
   123  			return nil, err
   124  		}
   125  	}
   126  	return conn, nil
   127  }
   128  
   129  func (PostgreSQLConnector) Driver() driver.Driver {
   130  	return &PostgreSQLLoggingDriver{Driver: &pq.Driver{}, Logger: logrus.StandardLogger()}
   131  }
   132  
   133  func (PostgreSQLConnector) DriverName() string {
   134  	return "postgres"
   135  }
   136  
   137  func (PostgreSQLConnector) PrimaryKeyName() string {
   138  	return "pkey"
   139  }
   140  
   141  func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool {
   142  	if e, ok := err.(*pq.Error); ok && e.Code == "3D000" {
   143  		return true
   144  	}
   145  	return false
   146  }
   147  
   148  func (PostgreSQLConnector) IsErrorConflict(err error) bool {
   149  	if e, ok := err.(*pq.Error); ok && e.Code == "23505" {
   150  		return true
   151  	}
   152  	return false
   153  }
   154  
   155  func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr {
   156  	e := builder.Expr("CREATE DATABASE ")
   157  	e.WriteString(dbName)
   158  	e.WriteEnd()
   159  	return e
   160  }
   161  
   162  func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr {
   163  	e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ")
   164  	e.WriteString(schema)
   165  	e.WriteEnd()
   166  	return e
   167  }
   168  
   169  func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr {
   170  	e := builder.Expr("DROP DATABASE IF EXISTS ")
   171  	e.WriteString(dbName)
   172  	e.WriteEnd()
   173  	return e
   174  }
   175  
   176  func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr {
   177  	if key.IsPrimary() {
   178  		e := builder.Expr("ALTER TABLE ")
   179  		e.WriteExpr(key.Table)
   180  		e.WriteString(" ADD PRIMARY KEY ")
   181  		e.WriteGroup(func(e *builder.Ex) {
   182  			e.WriteExpr(key.Columns)
   183  		})
   184  		e.WriteEnd()
   185  		return e
   186  	}
   187  
   188  	e := builder.Expr("CREATE ")
   189  	if key.IsUnique {
   190  		e.WriteString("UNIQUE ")
   191  	}
   192  	e.WriteString("INDEX ")
   193  
   194  	e.WriteString(key.Table.Name)
   195  	e.WriteString("_")
   196  	e.WriteString(key.Name)
   197  
   198  	e.WriteString(" ON ")
   199  	e.WriteExpr(key.Table)
   200  
   201  	if map[string]bool{
   202  		"BTREE":   true,
   203  		"HASH":    true,
   204  		"SPATIAL": true,
   205  		"GIST":    true,
   206  	}[strings.ToUpper(key.Method)] {
   207  		e.WriteString(" USING ")
   208  
   209  		if key.Method == "SPATIAL" {
   210  			e.WriteString("GIST")
   211  		} else {
   212  			e.WriteString(key.Method)
   213  		}
   214  	}
   215  
   216  	e.WriteByte(' ')
   217  	e.WriteGroup(func(e *builder.Ex) {
   218  		e.WriteExpr(key.Columns)
   219  	})
   220  
   221  	e.WriteEnd()
   222  	return e
   223  }
   224  
   225  func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr {
   226  	if key.IsPrimary() {
   227  		e := builder.Expr("ALTER TABLE ")
   228  		e.WriteExpr(key.Table)
   229  		e.WriteString(" DROP CONSTRAINT ")
   230  		e.WriteExpr(key.Table)
   231  		e.WriteString("_pkey")
   232  		e.WriteEnd()
   233  		return e
   234  	}
   235  	e := builder.Expr("DROP ")
   236  
   237  	e.WriteString("INDEX IF EXISTS ")
   238  	e.WriteExpr(key.Table)
   239  	e.WriteByte('_')
   240  	e.WriteString(key.Name)
   241  
   242  	return e
   243  }
   244  
   245  func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) {
   246  	expr := builder.Expr("CREATE TABLE IF NOT EXISTS ")
   247  	expr.WriteExpr(t)
   248  	expr.WriteByte(' ')
   249  	expr.WriteGroup(func(e *builder.Ex) {
   250  		if t.Columns.IsNil() {
   251  			return
   252  		}
   253  
   254  		t.Columns.Range(func(col *builder.Column, idx int) {
   255  			if col.DeprecatedActions != nil {
   256  				return
   257  			}
   258  
   259  			if idx > 0 {
   260  				e.WriteByte(',')
   261  			}
   262  			e.WriteByte('\n')
   263  			e.WriteByte('\t')
   264  
   265  			e.WriteExpr(col)
   266  			e.WriteByte(' ')
   267  			e.WriteExpr(c.DataType(col.ColumnType))
   268  		})
   269  
   270  		t.Keys.Range(func(key *builder.Key, idx int) {
   271  			if key.IsPrimary() {
   272  				e.WriteByte(',')
   273  				e.WriteByte('\n')
   274  				e.WriteByte('\t')
   275  				e.WriteString("PRIMARY KEY ")
   276  				e.WriteGroup(func(e *builder.Ex) {
   277  					e.WriteExpr(key.Columns)
   278  				})
   279  			}
   280  		})
   281  
   282  		expr.WriteByte('\n')
   283  	})
   284  
   285  	expr.WriteEnd()
   286  	exprs = append(exprs, expr)
   287  
   288  	t.Keys.Range(func(key *builder.Key, idx int) {
   289  		if !key.IsPrimary() {
   290  			exprs = append(exprs, c.AddIndex(key))
   291  		}
   292  	})
   293  
   294  	return
   295  }
   296  
   297  func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr {
   298  	e := builder.Expr("DROP TABLE IF EXISTS ")
   299  	e.WriteExpr(t)
   300  	e.WriteEnd()
   301  	return e
   302  }
   303  
   304  func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr {
   305  	e := builder.Expr("TRUNCATE TABLE ")
   306  	e.WriteExpr(t)
   307  	e.WriteEnd()
   308  	return e
   309  }
   310  
   311  func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr {
   312  	e := builder.Expr("ALTER TABLE ")
   313  	e.WriteExpr(col.Table)
   314  	e.WriteString(" ADD COLUMN ")
   315  	e.WriteExpr(col)
   316  	e.WriteByte(' ')
   317  	e.WriteExpr(c.DataType(col.ColumnType))
   318  	e.WriteEnd()
   319  	return e
   320  }
   321  
   322  func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr {
   323  	e := builder.Expr("ALTER TABLE ")
   324  	e.WriteExpr(col.Table)
   325  	e.WriteString(" RENAME COLUMN ")
   326  	e.WriteExpr(col)
   327  	e.WriteString(" TO ")
   328  	e.WriteExpr(target)
   329  	e.WriteEnd()
   330  	return e
   331  }
   332  
   333  func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column) builder.SqlExpr {
   334  	if col.AutoIncrement {
   335  		return nil
   336  	}
   337  
   338  	e := builder.Expr("ALTER TABLE ")
   339  	e.WriteExpr(col.Table)
   340  
   341  	e.WriteString(" ALTER COLUMN ")
   342  	e.WriteExpr(col)
   343  	e.WriteString(" TYPE ")
   344  	e.WriteString(c.dataType(col.ColumnType.Type, col.ColumnType))
   345  
   346  	{
   347  		e.WriteString(", ALTER COLUMN ")
   348  		e.WriteExpr(col)
   349  		if !col.Null {
   350  			e.WriteString(" SET NOT NULL")
   351  		} else {
   352  			e.WriteString(" DROP NOT NULL")
   353  		}
   354  	}
   355  
   356  	{
   357  		e.WriteString(", ALTER COLUMN ")
   358  		e.WriteExpr(col)
   359  		if col.Default != nil {
   360  			e.WriteString(" SET DEFAULT ")
   361  			e.WriteString(*col.Default)
   362  		} else {
   363  			e.WriteString(" DROP DEFAULT")
   364  		}
   365  	}
   366  
   367  	e.WriteEnd()
   368  
   369  	return e
   370  }
   371  
   372  func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr {
   373  	e := builder.Expr("ALTER TABLE ")
   374  	e.WriteExpr(col.Table)
   375  	e.WriteString(" DROP COLUMN ")
   376  	e.WriteString(col.Name)
   377  	e.WriteEnd()
   378  	return e
   379  }
   380  
   381  func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr {
   382  	return builder.Expr(c.dataType(columnType.Type, columnType) + c.dataTypeModify(columnType))
   383  }
   384  
   385  func (c *PostgreSQLConnector) dataType(typ reflect.Type, columnType *builder.ColumnType) string {
   386  	if columnType.GetDataType != nil {
   387  		return columnType.GetDataType(c.DriverName())
   388  	}
   389  
   390  	switch typ.Kind() {
   391  	case reflect.Ptr:
   392  		return c.dataType(typ.Elem(), columnType)
   393  	case reflect.Bool:
   394  		return "boolean"
   395  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
   396  		if columnType.AutoIncrement {
   397  			return "serial"
   398  		}
   399  		return "integer"
   400  	case reflect.Int64, reflect.Uint64:
   401  		if columnType.AutoIncrement {
   402  			return "bigserial"
   403  		}
   404  		return "bigint"
   405  	case reflect.Float64:
   406  		return "double precision"
   407  	case reflect.Float32:
   408  		return "real"
   409  	case reflect.Slice:
   410  		if typ.Elem().Kind() == reflect.Uint8 {
   411  			return "bytea"
   412  		}
   413  	case reflect.String:
   414  		size := columnType.Length
   415  		if size == 0 {
   416  			size = 255
   417  		}
   418  		if size < 65535/3 {
   419  			return "varchar" + sizeModifier(size, 0)
   420  		}
   421  		return "text"
   422  	}
   423  
   424  	switch typ.Name() {
   425  	case "Hstore":
   426  		return "hstore"
   427  	case "ByteaArray":
   428  		return c.dataType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0]), columnType) + "[]"
   429  	case "BoolArray":
   430  		return c.dataType(reflect.TypeOf(pq.BoolArray{true}[0]), columnType) + "[]"
   431  	case "Float64Array":
   432  		return c.dataType(reflect.TypeOf(pq.Float64Array{0}[0]), columnType) + "[]"
   433  	case "Int64Array":
   434  		return c.dataType(reflect.TypeOf(pq.Int64Array{0}[0]), columnType) + "[]"
   435  	case "StringArray":
   436  		return c.dataType(reflect.TypeOf(pq.StringArray{""}[0]), columnType) + "[]"
   437  	case "NullInt64":
   438  		return "bigint"
   439  	case "NullFloat64":
   440  		return "double precision"
   441  	case "NullBool":
   442  		return "boolean"
   443  	case "Time", "NullTime":
   444  		return "timestamp with time zone"
   445  	}
   446  
   447  	panic(fmt.Errorf("unsupport type %s", typ))
   448  }
   449  
   450  func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType) string {
   451  	buf := bytes.NewBuffer(nil)
   452  
   453  	if !columnType.Null {
   454  		buf.WriteString(" NOT NULL")
   455  	}
   456  
   457  	if columnType.Default != nil {
   458  		buf.WriteString(" DEFAULT ")
   459  		buf.WriteString(*columnType.Default)
   460  	}
   461  
   462  	return buf.String()
   463  }
   464  
   465  func sizeModifier(length uint64, decimal uint64) string {
   466  	if length > 0 {
   467  		size := strconv.FormatUint(length, 10)
   468  		if decimal > 0 {
   469  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   470  		}
   471  		return "(" + size + ")"
   472  	}
   473  	return ""
   474  }