github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/postgres/postgres.go (about)

     1  package postgres
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"net/url"
     9  	"runtime"
    10  	"strings"
    11  
    12  	"github.com/amacneil/dbmate/pkg/dbmate"
    13  	"github.com/amacneil/dbmate/pkg/dbutil"
    14  
    15  	"github.com/lib/pq"
    16  )
    17  
    18  func init() {
    19  	dbmate.RegisterDriver(NewDriver, "postgres")
    20  	dbmate.RegisterDriver(NewDriver, "postgresql")
    21  }
    22  
// Driver provides top level database functions for PostgreSQL.
//
// Values are built by NewDriver from a dbmate.DriverConfig; the driver is
// registered for the "postgres" and "postgresql" URL schemes.
type Driver struct {
	// migrationsTableName names the table recording applied migrations. A
	// "schema.table" form is accepted: the part before the first "." is used
	// as the schema when resolving the quoted name.
	migrationsTableName string
	// databaseURL is the target database URL. Its "socket", "host" and "port"
	// query parameters are normalized when building the connection string, and
	// "search_path" selects dump schemas and the default migrations schema.
	databaseURL         *url.URL
	// log receives human-readable progress messages such as "Creating: <name>".
	log                 io.Writer
}
    29  
    30  // NewDriver initializes the driver
    31  func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
    32  	return &Driver{
    33  		migrationsTableName: config.MigrationsTableName,
    34  		databaseURL:         config.DatabaseURL,
    35  		log:                 config.Log,
    36  	}
    37  }
    38  
    39  func connectionString(u *url.URL) string {
    40  	hostname := u.Hostname()
    41  	port := u.Port()
    42  	query := u.Query()
    43  
    44  	// support socket parameter for consistency with mysql
    45  	if query.Get("socket") != "" {
    46  		query.Set("host", query.Get("socket"))
    47  		query.Del("socket")
    48  	}
    49  
    50  	// default hostname
    51  	if hostname == "" && query.Get("host") == "" {
    52  		switch runtime.GOOS {
    53  		case "linux":
    54  			query.Set("host", "/var/run/postgresql")
    55  		case "darwin", "freebsd", "dragonfly", "openbsd", "netbsd":
    56  			query.Set("host", "/tmp")
    57  		default:
    58  			hostname = "localhost"
    59  		}
    60  	}
    61  
    62  	// host param overrides url hostname
    63  	if query.Get("host") != "" {
    64  		hostname = ""
    65  	}
    66  
    67  	// always specify a port
    68  	if query.Get("port") != "" {
    69  		port = query.Get("port")
    70  		query.Del("port")
    71  	}
    72  	if port == "" {
    73  		port = "5432"
    74  	}
    75  
    76  	// generate output URL
    77  	out, _ := url.Parse(u.String())
    78  	out.Host = fmt.Sprintf("%s:%s", hostname, port)
    79  	out.RawQuery = query.Encode()
    80  
    81  	return out.String()
    82  }
    83  
    84  func connectionArgsForDump(u *url.URL) []string {
    85  	u = dbutil.MustParseURL(connectionString(u))
    86  
    87  	// find schemas from search_path
    88  	query := u.Query()
    89  	schemas := strings.Split(query.Get("search_path"), ",")
    90  	query.Del("search_path")
    91  	u.RawQuery = query.Encode()
    92  
    93  	out := []string{}
    94  	for _, schema := range schemas {
    95  		schema = strings.TrimSpace(schema)
    96  		if schema != "" {
    97  			out = append(out, "--schema", schema)
    98  		}
    99  	}
   100  	out = append(out, u.String())
   101  
   102  	return out
   103  }
   104  
   105  // Open creates a new database connection
   106  func (drv *Driver) Open() (*sql.DB, error) {
   107  	return sql.Open("postgres", connectionString(drv.databaseURL))
   108  }
   109  
   110  func (drv *Driver) openPostgresDB() (*sql.DB, error) {
   111  	// clone databaseURL
   112  	postgresURL, err := url.Parse(connectionString(drv.databaseURL))
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	// connect to postgres database
   118  	postgresURL.Path = "postgres"
   119  
   120  	return sql.Open("postgres", postgresURL.String())
   121  }
   122  
   123  // CreateDatabase creates the specified database
   124  func (drv *Driver) CreateDatabase() error {
   125  	name := dbutil.DatabaseName(drv.databaseURL)
   126  	fmt.Fprintf(drv.log, "Creating: %s\n", name)
   127  
   128  	db, err := drv.openPostgresDB()
   129  	if err != nil {
   130  		return err
   131  	}
   132  	defer dbutil.MustClose(db)
   133  
   134  	_, err = db.Exec(fmt.Sprintf("create database %s",
   135  		pq.QuoteIdentifier(name)))
   136  
   137  	return err
   138  }
   139  
   140  // DropDatabase drops the specified database (if it exists)
   141  func (drv *Driver) DropDatabase() error {
   142  	name := dbutil.DatabaseName(drv.databaseURL)
   143  	fmt.Fprintf(drv.log, "Dropping: %s\n", name)
   144  
   145  	db, err := drv.openPostgresDB()
   146  	if err != nil {
   147  		return err
   148  	}
   149  	defer dbutil.MustClose(db)
   150  
   151  	_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
   152  		pq.QuoteIdentifier(name)))
   153  
   154  	return err
   155  }
   156  
   157  func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
   158  	migrationsTable, err := drv.quotedMigrationsTableName(db)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	// load applied migrations
   164  	migrations, err := dbutil.QueryColumn(db,
   165  		"select quote_literal(version) from "+migrationsTable+" order by version asc")
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	// build migrations table data
   171  	var buf bytes.Buffer
   172  	buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
   173  
   174  	if len(migrations) > 0 {
   175  		buf.WriteString("INSERT INTO " + migrationsTable + " (version) VALUES\n    (" +
   176  			strings.Join(migrations, "),\n    (") +
   177  			");\n")
   178  	}
   179  
   180  	return buf.Bytes(), nil
   181  }
   182  
   183  // DumpSchema returns the current database schema
   184  func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
   185  	// load schema
   186  	args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
   187  		"--no-privileges", "--no-owner"}, connectionArgsForDump(drv.databaseURL)...)
   188  	schema, err := dbutil.RunCommand("pg_dump", args...)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	migrations, err := drv.schemaMigrationsDump(db)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	schema = append(schema, migrations...)
   199  	return dbutil.TrimLeadingSQLComments(schema)
   200  }
   201  
   202  // DatabaseExists determines whether the database exists
   203  func (drv *Driver) DatabaseExists() (bool, error) {
   204  	name := dbutil.DatabaseName(drv.databaseURL)
   205  
   206  	db, err := drv.openPostgresDB()
   207  	if err != nil {
   208  		return false, err
   209  	}
   210  	defer dbutil.MustClose(db)
   211  
   212  	exists := false
   213  	err = db.QueryRow("select true from pg_database where datname = $1", name).
   214  		Scan(&exists)
   215  	if err == sql.ErrNoRows {
   216  		return false, nil
   217  	}
   218  
   219  	return exists, err
   220  }
   221  
   222  // MigrationsTableExists checks if the schema_migrations table exists
   223  func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) {
   224  	schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db)
   225  	if err != nil {
   226  		return false, err
   227  	}
   228  
   229  	exists := false
   230  	err = db.QueryRow("SELECT 1 FROM information_schema.tables "+
   231  		"WHERE  table_schema = $1 "+
   232  		"AND    table_name   = $2",
   233  		schema, migrationsTable).
   234  		Scan(&exists)
   235  	if err == sql.ErrNoRows {
   236  		return false, nil
   237  	}
   238  
   239  	return exists, err
   240  }
   241  
   242  // CreateMigrationsTable creates the schema_migrations table
   243  func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
   244  	schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db)
   245  	if err != nil {
   246  		return err
   247  	}
   248  
   249  	// first attempt at creating migrations table
   250  	createTableStmt := fmt.Sprintf(
   251  		"create table if not exists %s.%s (version varchar(128) primary key)",
   252  		schema, migrationsTable)
   253  	_, err = db.Exec(createTableStmt)
   254  	if err == nil {
   255  		// table exists or created successfully
   256  		return nil
   257  	}
   258  
   259  	// catch 'schema does not exist' error
   260  	pqErr, ok := err.(*pq.Error)
   261  	if !ok || pqErr.Code != "3F000" {
   262  		// unknown error
   263  		return err
   264  	}
   265  
   266  	// in theory we could attempt to create the schema every time, but we avoid that
   267  	// in case the user doesn't have permissions to create schemas
   268  	fmt.Fprintf(drv.log, "Creating schema: %s\n", schema)
   269  	_, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema))
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	// second and final attempt at creating migrations table
   275  	_, err = db.Exec(createTableStmt)
   276  	return err
   277  }
   278  
   279  // SelectMigrations returns a list of applied migrations
   280  // with an optional limit (in descending order)
   281  func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
   282  	migrationsTable, err := drv.quotedMigrationsTableName(db)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	query := "select version from " + migrationsTable + " order by version desc"
   288  	if limit >= 0 {
   289  		query = fmt.Sprintf("%s limit %d", query, limit)
   290  	}
   291  	rows, err := db.Query(query)
   292  	if err != nil {
   293  		return nil, err
   294  	}
   295  
   296  	defer dbutil.MustClose(rows)
   297  
   298  	migrations := map[string]bool{}
   299  	for rows.Next() {
   300  		var version string
   301  		if err := rows.Scan(&version); err != nil {
   302  			return nil, err
   303  		}
   304  
   305  		migrations[version] = true
   306  	}
   307  
   308  	if err = rows.Err(); err != nil {
   309  		return nil, err
   310  	}
   311  
   312  	return migrations, nil
   313  }
   314  
   315  // InsertMigration adds a new migration record
   316  func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
   317  	migrationsTable, err := drv.quotedMigrationsTableName(db)
   318  	if err != nil {
   319  		return err
   320  	}
   321  
   322  	_, err = db.Exec("insert into "+migrationsTable+" (version) values ($1)", version)
   323  
   324  	return err
   325  }
   326  
   327  // DeleteMigration removes a migration record
   328  func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
   329  	migrationsTable, err := drv.quotedMigrationsTableName(db)
   330  	if err != nil {
   331  		return err
   332  	}
   333  
   334  	_, err = db.Exec("delete from "+migrationsTable+" where version = $1", version)
   335  
   336  	return err
   337  }
   338  
   339  // Ping verifies a connection to the database server. It does not verify whether the
   340  // specified database exists.
   341  func (drv *Driver) Ping() error {
   342  	// attempt connection to primary database, not "postgres" database
   343  	// to support servers with no "postgres" database
   344  	// (see https://github.com/amacneil/dbmate/issues/78)
   345  	db, err := drv.Open()
   346  	if err != nil {
   347  		return err
   348  	}
   349  	defer dbutil.MustClose(db)
   350  
   351  	err = db.Ping()
   352  	if err == nil {
   353  		return nil
   354  	}
   355  
   356  	// ignore 'database does not exist' error
   357  	pqErr, ok := err.(*pq.Error)
   358  	if ok && pqErr.Code == "3D000" {
   359  		return nil
   360  	}
   361  
   362  	return err
   363  }
   364  
   365  func (drv *Driver) quotedMigrationsTableName(db dbutil.Transaction) (string, error) {
   366  	schema, name, err := drv.quotedMigrationsTableNameParts(db)
   367  	if err != nil {
   368  		return "", err
   369  	}
   370  
   371  	return schema + "." + name, nil
   372  }
   373  
   374  func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string, string, error) {
   375  	schema := ""
   376  	tableNameParts := strings.Split(drv.migrationsTableName, ".")
   377  	if len(tableNameParts) > 1 {
   378  		// schema specified as part of table name
   379  		schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
   380  	}
   381  
   382  	if schema == "" {
   383  		// no schema specified with table name, try URL search path if available
   384  		searchPath := strings.Split(drv.databaseURL.Query().Get("search_path"), ",")
   385  		schema = strings.TrimSpace(searchPath[0])
   386  	}
   387  
   388  	var err error
   389  	if schema == "" {
   390  		// if no URL available, use current schema
   391  		// this is a hack because we don't always have the URL context available
   392  		schema, err = dbutil.QueryValue(db, "select current_schema()")
   393  		if err != nil {
   394  			return "", "", err
   395  		}
   396  	}
   397  
   398  	// fall back to public schema as last resort
   399  	if schema == "" {
   400  		schema = "public"
   401  	}
   402  
   403  	// quote all parts
   404  	// use server rather than client to do this to avoid unnecessary quotes
   405  	// (which would change schema.sql diff)
   406  	tableNameParts = append([]string{schema}, tableNameParts...)
   407  	quotedNameParts, err := dbutil.QueryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
   408  	if err != nil {
   409  		return "", "", err
   410  	}
   411  
   412  	// if more than one part, we already have a schema
   413  	return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil
   414  }