github.com/solvedata/migrate/v4@v4.8.7-0.20201127053940-c9fba4ce569f/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/solvedata/migrate/v4"
    16  	"github.com/solvedata/migrate/v4/database"
    17  	multierror "github.com/hashicorp/go-multierror"
    18  	"github.com/lib/pq"
    19  )
    20  
    21  func init() {
    22  	db := Postgres{}
    23  	database.Register("postgres", &db)
    24  	database.Register("postgresql", &db)
    25  }
    26  
    27  var DefaultMigrationsTable = "schema_migrations"
    28  
    29  var (
    30  	ErrNilConfig      = fmt.Errorf("no config")
    31  	ErrNoDatabaseName = fmt.Errorf("no database name")
    32  	ErrNoSchema       = fmt.Errorf("no schema")
    33  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    34  )
    35  
// Config holds the settings of a Postgres driver instance.
//
// WithInstance fills in any empty field in place: DatabaseName from
// CURRENT_DATABASE(), SchemaName from CURRENT_SCHEMA(), and MigrationsTable
// from DefaultMigrationsTable.
type Config struct {
	// MigrationsTable names the table holding the (version, dirty) row. It is
	// quoted with pq.QuoteIdentifier as one identifier.
	MigrationsTable string
	// DatabaseName is combined with SchemaName to derive the advisory lock ID
	// used by Lock and Unlock.
	DatabaseName    string
	// SchemaName is combined with DatabaseName to derive the advisory lock ID.
	SchemaName      string
}
    41  
// Postgres is the PostgreSQL implementation of database.Driver.
//
// init registers a zero-value Postgres under the "postgres" and "postgresql"
// schemes. Its Open method builds a fully initialised instance via
// WithInstance.
type Postgres struct {
	// Locking and unlocking need to use the same connection
	// (pg advisory locks are held per session). For that reason every
	// statement this driver issues goes through conn, a single connection
	// taken from db.
	conn     *sql.Conn
	// db is the pool conn was taken from; Close closes both.
	db       *sql.DB
	// isLocked records whether this instance currently holds the advisory
	// lock acquired by Lock.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    51  
    52  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    53  	if config == nil {
    54  		return nil, ErrNilConfig
    55  	}
    56  
    57  	if err := instance.Ping(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	if config.DatabaseName == "" {
    62  		query := `SELECT CURRENT_DATABASE()`
    63  		var databaseName string
    64  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    65  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    66  		}
    67  
    68  		if len(databaseName) == 0 {
    69  			return nil, ErrNoDatabaseName
    70  		}
    71  
    72  		config.DatabaseName = databaseName
    73  	}
    74  
    75  	if config.SchemaName == "" {
    76  		query := `SELECT CURRENT_SCHEMA()`
    77  		var schemaName string
    78  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    79  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    80  		}
    81  
    82  		if len(schemaName) == 0 {
    83  			return nil, ErrNoSchema
    84  		}
    85  
    86  		config.SchemaName = schemaName
    87  	}
    88  
    89  	if len(config.MigrationsTable) == 0 {
    90  		config.MigrationsTable = DefaultMigrationsTable
    91  	}
    92  
    93  	conn, err := instance.Conn(context.Background())
    94  
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	px := &Postgres{
   100  		conn:   conn,
   101  		db:     instance,
   102  		config: config,
   103  	}
   104  
   105  	if err := px.ensureVersionTable(); err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	return px, nil
   110  }
   111  
   112  func (p *Postgres) Open(url string) (database.Driver, error) {
   113  	purl, err := nurl.Parse(url)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	migrationsTable := purl.Query().Get("x-migrations-table")
   124  
   125  	px, err := WithInstance(db, &Config{
   126  		DatabaseName:    purl.Path,
   127  		MigrationsTable: migrationsTable,
   128  	})
   129  
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	return px, nil
   135  }
   136  
   137  func (p *Postgres) Close() error {
   138  	connErr := p.conn.Close()
   139  	dbErr := p.db.Close()
   140  	if connErr != nil || dbErr != nil {
   141  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   142  	}
   143  	return nil
   144  }
   145  
   146  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   147  func (p *Postgres) Lock() error {
   148  	if p.isLocked {
   149  		return database.ErrLocked
   150  	}
   151  
   152  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	// This will wait indefinitely until the lock can be acquired.
   158  	query := `SELECT pg_advisory_lock($1)`
   159  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   160  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   161  	}
   162  
   163  	p.isLocked = true
   164  	return nil
   165  }
   166  
   167  func (p *Postgres) Unlock() error {
   168  	if !p.isLocked {
   169  		return nil
   170  	}
   171  
   172  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	query := `SELECT pg_advisory_unlock($1)`
   178  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   179  		return &database.Error{OrigErr: err, Query: []byte(query)}
   180  	}
   181  	p.isLocked = false
   182  	return nil
   183  }
   184  
   185  func (p *Postgres) Run(migration io.Reader) error {
   186  	migr, err := ioutil.ReadAll(migration)
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	// run migration
   192  	query := string(migr[:])
   193  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   194  		if pgErr, ok := err.(*pq.Error); ok {
   195  			var line uint
   196  			var col uint
   197  			var lineColOK bool
   198  			if pgErr.Position != "" {
   199  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   200  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   201  				}
   202  			}
   203  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   204  			if lineColOK {
   205  				message = fmt.Sprintf("%s (column %d)", message, col)
   206  			}
   207  			if pgErr.Detail != "" {
   208  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   209  			}
   210  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
   211  		}
   212  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   213  	}
   214  
   215  	return nil
   216  }
   217  
   218  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   219  	// replace crlf with lf
   220  	s = strings.Replace(s, "\r\n", "\n", -1)
   221  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   222  	runes := []rune(s)
   223  	if pos > len(runes) {
   224  		return 0, 0, false
   225  	}
   226  	sel := runes[:pos]
   227  	line = uint(runesCount(sel, newLine) + 1)
   228  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   229  	return line, col, true
   230  }
   231  
   232  const newLine = '\n'
   233  
   234  func runesCount(input []rune, target rune) int {
   235  	var count int
   236  	for _, r := range input {
   237  		if r == target {
   238  			count++
   239  		}
   240  	}
   241  	return count
   242  }
   243  
   244  func runesLastIndex(input []rune, target rune) int {
   245  	for i := len(input) - 1; i >= 0; i-- {
   246  		if input[i] == target {
   247  			return i
   248  		}
   249  	}
   250  	return -1
   251  }
   252  
   253  func (p *Postgres) SetVersion(version int, dirty bool) error {
   254  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   255  	if err != nil {
   256  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   257  	}
   258  
   259  	query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable)
   260  	if _, err := tx.Exec(query); err != nil {
   261  		if errRollback := tx.Rollback(); errRollback != nil {
   262  			err = multierror.Append(err, errRollback)
   263  		}
   264  		return &database.Error{OrigErr: err, Query: []byte(query)}
   265  	}
   266  
   267  	if version >= 0 {
   268  		query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)`
   269  		if _, err := tx.Exec(query, version, dirty); err != nil {
   270  			if errRollback := tx.Rollback(); errRollback != nil {
   271  				err = multierror.Append(err, errRollback)
   272  			}
   273  			return &database.Error{OrigErr: err, Query: []byte(query)}
   274  		}
   275  	}
   276  
   277  	if err := tx.Commit(); err != nil {
   278  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  func (p *Postgres) Version() (version int, dirty bool, err error) {
   285  	query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
   286  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   287  	switch {
   288  	case err == sql.ErrNoRows:
   289  		return database.NilVersion, false, nil
   290  
   291  	case err != nil:
   292  		if e, ok := err.(*pq.Error); ok {
   293  			if e.Code.Name() == "undefined_table" {
   294  				return database.NilVersion, false, nil
   295  			}
   296  		}
   297  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   298  
   299  	default:
   300  		return version, dirty, nil
   301  	}
   302  }
   303  
   304  func (p *Postgres) Drop() (err error) {
   305  	// select all tables in current schema
   306  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   307  	tables, err := p.conn.QueryContext(context.Background(), query)
   308  	if err != nil {
   309  		return &database.Error{OrigErr: err, Query: []byte(query)}
   310  	}
   311  	defer func() {
   312  		if errClose := tables.Close(); errClose != nil {
   313  			err = multierror.Append(err, errClose)
   314  		}
   315  	}()
   316  
   317  	// delete one table after another
   318  	tableNames := make([]string, 0)
   319  	for tables.Next() {
   320  		var tableName string
   321  		if err := tables.Scan(&tableName); err != nil {
   322  			return err
   323  		}
   324  		if len(tableName) > 0 {
   325  			tableNames = append(tableNames, tableName)
   326  		}
   327  	}
   328  
   329  	if len(tableNames) > 0 {
   330  		// delete one by one ...
   331  		for _, t := range tableNames {
   332  			query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
   333  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   334  				return &database.Error{OrigErr: err, Query: []byte(query)}
   335  			}
   336  		}
   337  	}
   338  
   339  	return nil
   340  }
   341  
   342  // ensureVersionTable checks if versions table exists and, if not, creates it.
   343  // Note that this function locks the database, which deviates from the usual
   344  // convention of "caller locks" in the Postgres type.
   345  func (p *Postgres) ensureVersionTable() (err error) {
   346  	if err = p.Lock(); err != nil {
   347  		return err
   348  	}
   349  
   350  	defer func() {
   351  		if e := p.Unlock(); e != nil {
   352  			if err == nil {
   353  				err = e
   354  			} else {
   355  				err = multierror.Append(err, e)
   356  			}
   357  		}
   358  	}()
   359  
   360  	query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
   361  	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
   362  		return &database.Error{OrigErr: err, Query: []byte(query)}
   363  	}
   364  
   365  	return nil
   366  }