github.com/olitvin/migrate/v4@v4.14.3-0.20210330111251-992b37ee04c8/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	multierror "github.com/hashicorp/go-multierror"
    17  	"github.com/lib/pq"
    18  	"github.com/olitvin/migrate/v4"
    19  	"github.com/olitvin/migrate/v4/database"
    20  	"github.com/olitvin/migrate/v4/database/multistmt"
    21  )
    22  
    23  func init() {
    24  	db := Postgres{}
    25  	database.Register("postgres", &db)
    26  	database.Register("postgresql", &db)
    27  }
    28  
    29  var (
    30  	multiStmtDelimiter = []byte(";")
    31  
    32  	DefaultMigrationsTable       = "schema_migrations"
    33  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    34  )
    35  
    36  var (
    37  	ErrNilConfig      = fmt.Errorf("no config")
    38  	ErrNoDatabaseName = fmt.Errorf("no database name")
    39  	ErrNoSchema       = fmt.Errorf("no schema")
    40  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    41  )
    42  
    43  type Config struct {
    44  	MigrationsTable       string
    45  	DatabaseName          string
    46  	SchemaName            string
    47  	StatementTimeout      time.Duration
    48  	MultiStatementEnabled bool
    49  	MultiStatementMaxSize int
    50  }
    51  
    52  type Postgres struct {
    53  	// Locking and unlocking need to use the same connection
    54  	conn     *sql.Conn
    55  	db       *sql.DB
    56  	isLocked bool
    57  
    58  	// Open and WithInstance need to guarantee that config is never nil
    59  	config *Config
    60  }
    61  
    62  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    63  	if config == nil {
    64  		return nil, ErrNilConfig
    65  	}
    66  
    67  	if err := instance.Ping(); err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	if config.DatabaseName == "" {
    72  		query := `SELECT CURRENT_DATABASE()`
    73  		var databaseName string
    74  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    75  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    76  		}
    77  
    78  		if len(databaseName) == 0 {
    79  			return nil, ErrNoDatabaseName
    80  		}
    81  
    82  		config.DatabaseName = databaseName
    83  	}
    84  
    85  	if config.SchemaName == "" {
    86  		query := `SELECT CURRENT_SCHEMA()`
    87  		var schemaName string
    88  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    89  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    90  		}
    91  
    92  		if len(schemaName) == 0 {
    93  			return nil, ErrNoSchema
    94  		}
    95  
    96  		config.SchemaName = schemaName
    97  	}
    98  
    99  	if len(config.MigrationsTable) == 0 {
   100  		config.MigrationsTable = DefaultMigrationsTable
   101  	}
   102  
   103  	conn, err := instance.Conn(context.Background())
   104  
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	px := &Postgres{
   110  		conn:   conn,
   111  		db:     instance,
   112  		config: config,
   113  	}
   114  
   115  	if err := px.ensureVersionTable(); err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return px, nil
   120  }
   121  
   122  func (p *Postgres) Open(url string) (database.Driver, error) {
   123  	purl, err := nurl.Parse(url)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	migrationsTable := purl.Query().Get("x-migrations-table")
   134  	statementTimeoutString := purl.Query().Get("x-statement-timeout")
   135  	statementTimeout := 0
   136  	if statementTimeoutString != "" {
   137  		statementTimeout, err = strconv.Atoi(statementTimeoutString)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  	}
   142  
   143  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   144  	if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   145  		multiStatementMaxSize, err = strconv.Atoi(s)
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  		if multiStatementMaxSize <= 0 {
   150  			multiStatementMaxSize = DefaultMultiStatementMaxSize
   151  		}
   152  	}
   153  
   154  	multiStatementEnabled := false
   155  	if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
   156  		multiStatementEnabled, err = strconv.ParseBool(s)
   157  		if err != nil {
   158  			return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
   159  		}
   160  	}
   161  
   162  	px, err := WithInstance(db, &Config{
   163  		DatabaseName:          purl.Path,
   164  		MigrationsTable:       migrationsTable,
   165  		StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
   166  		MultiStatementEnabled: multiStatementEnabled,
   167  		MultiStatementMaxSize: multiStatementMaxSize,
   168  	})
   169  
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	return px, nil
   175  }
   176  
   177  func (p *Postgres) Close() error {
   178  	connErr := p.conn.Close()
   179  	dbErr := p.db.Close()
   180  	if connErr != nil || dbErr != nil {
   181  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   182  	}
   183  	return nil
   184  }
   185  
   186  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   187  func (p *Postgres) Lock() error {
   188  	if p.isLocked {
   189  		return database.ErrLocked
   190  	}
   191  
   192  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	// This will wait indefinitely until the lock can be acquired.
   198  	query := `SELECT pg_advisory_lock($1)`
   199  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   200  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   201  	}
   202  
   203  	p.isLocked = true
   204  	return nil
   205  }
   206  
   207  func (p *Postgres) Unlock() error {
   208  	if !p.isLocked {
   209  		return nil
   210  	}
   211  
   212  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	query := `SELECT pg_advisory_unlock($1)`
   218  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   219  		return &database.Error{OrigErr: err, Query: []byte(query)}
   220  	}
   221  	p.isLocked = false
   222  	return nil
   223  }
   224  
   225  func (p *Postgres) Run(migration io.Reader) error {
   226  	if p.config.MultiStatementEnabled {
   227  		var err error
   228  		if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
   229  			if err = p.runStatement(m); err != nil {
   230  				return false
   231  			}
   232  			return true
   233  		}); e != nil {
   234  			return e
   235  		}
   236  		return err
   237  	}
   238  	migr, err := ioutil.ReadAll(migration)
   239  	if err != nil {
   240  		return err
   241  	}
   242  	return p.runStatement(migr)
   243  }
   244  
   245  func (p *Postgres) runStatement(statement []byte) error {
   246  	ctx := context.Background()
   247  	if p.config.StatementTimeout != 0 {
   248  		var cancel context.CancelFunc
   249  		ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
   250  		defer cancel()
   251  	}
   252  	query := string(statement)
   253  	if strings.TrimSpace(query) == "" {
   254  		return nil
   255  	}
   256  	if _, err := p.conn.ExecContext(ctx, query); err != nil {
   257  		if pgErr, ok := err.(*pq.Error); ok {
   258  			var line uint
   259  			var col uint
   260  			var lineColOK bool
   261  			if pgErr.Position != "" {
   262  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   263  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   264  				}
   265  			}
   266  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   267  			if lineColOK {
   268  				message = fmt.Sprintf("%s (column %d)", message, col)
   269  			}
   270  			if pgErr.Detail != "" {
   271  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   272  			}
   273  			return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
   274  		}
   275  		return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
   276  	}
   277  	return nil
   278  }
   279  
   280  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   281  	// replace crlf with lf
   282  	s = strings.Replace(s, "\r\n", "\n", -1)
   283  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   284  	runes := []rune(s)
   285  	if pos > len(runes) {
   286  		return 0, 0, false
   287  	}
   288  	sel := runes[:pos]
   289  	line = uint(runesCount(sel, newLine) + 1)
   290  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   291  	return line, col, true
   292  }
   293  
   294  const newLine = '\n'
   295  
   296  func runesCount(input []rune, target rune) int {
   297  	var count int
   298  	for _, r := range input {
   299  		if r == target {
   300  			count++
   301  		}
   302  	}
   303  	return count
   304  }
   305  
   306  func runesLastIndex(input []rune, target rune) int {
   307  	for i := len(input) - 1; i >= 0; i-- {
   308  		if input[i] == target {
   309  			return i
   310  		}
   311  	}
   312  	return -1
   313  }
   314  
   315  func (p *Postgres) SetVersion(version int, dirty bool) error {
   316  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   317  	if err != nil {
   318  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   319  	}
   320  
   321  	query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable)
   322  	if _, err := tx.Exec(query); err != nil {
   323  		if errRollback := tx.Rollback(); errRollback != nil {
   324  			err = multierror.Append(err, errRollback)
   325  		}
   326  		return &database.Error{OrigErr: err, Query: []byte(query)}
   327  	}
   328  
   329  	// Also re-write the schema version for nil dirty versions to prevent
   330  	// empty schema version for failed down migration on the first migration
   331  	// See: https://github.com/olitvin/migrate/issues/330
   332  	if version >= 0 || (version == database.NilVersion && dirty) {
   333  		query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) +
   334  			` (version, dirty) VALUES ($1, $2)`
   335  		if _, err := tx.Exec(query, version, dirty); err != nil {
   336  			if errRollback := tx.Rollback(); errRollback != nil {
   337  				err = multierror.Append(err, errRollback)
   338  			}
   339  			return &database.Error{OrigErr: err, Query: []byte(query)}
   340  		}
   341  	}
   342  
   343  	if err := tx.Commit(); err != nil {
   344  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   345  	}
   346  
   347  	return nil
   348  }
   349  
   350  func (p *Postgres) Version() (version int, dirty bool, err error) {
   351  	query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
   352  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   353  	switch {
   354  	case err == sql.ErrNoRows:
   355  		return database.NilVersion, false, nil
   356  
   357  	case err != nil:
   358  		if e, ok := err.(*pq.Error); ok {
   359  			if e.Code.Name() == "undefined_table" {
   360  				return database.NilVersion, false, nil
   361  			}
   362  		}
   363  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   364  
   365  	default:
   366  		return version, dirty, nil
   367  	}
   368  }
   369  
   370  func (p *Postgres) Drop() (err error) {
   371  	// select all tables in current schema
   372  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   373  	tables, err := p.conn.QueryContext(context.Background(), query)
   374  	if err != nil {
   375  		return &database.Error{OrigErr: err, Query: []byte(query)}
   376  	}
   377  	defer func() {
   378  		if errClose := tables.Close(); errClose != nil {
   379  			err = multierror.Append(err, errClose)
   380  		}
   381  	}()
   382  
   383  	// delete one table after another
   384  	tableNames := make([]string, 0)
   385  	for tables.Next() {
   386  		var tableName string
   387  		if err := tables.Scan(&tableName); err != nil {
   388  			return err
   389  		}
   390  		if len(tableName) > 0 {
   391  			tableNames = append(tableNames, tableName)
   392  		}
   393  	}
   394  	if err := tables.Err(); err != nil {
   395  		return &database.Error{OrigErr: err, Query: []byte(query)}
   396  	}
   397  
   398  	if len(tableNames) > 0 {
   399  		// delete one by one ...
   400  		for _, t := range tableNames {
   401  			query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
   402  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   403  				return &database.Error{OrigErr: err, Query: []byte(query)}
   404  			}
   405  		}
   406  	}
   407  
   408  	return nil
   409  }
   410  
   411  // ensureVersionTable checks if versions table exists and, if not, creates it.
   412  // Note that this function locks the database, which deviates from the usual
   413  // convention of "caller locks" in the Postgres type.
   414  func (p *Postgres) ensureVersionTable() (err error) {
   415  	if err = p.Lock(); err != nil {
   416  		return err
   417  	}
   418  
   419  	defer func() {
   420  		if e := p.Unlock(); e != nil {
   421  			if err == nil {
   422  				err = e
   423  			} else {
   424  				err = multierror.Append(err, e)
   425  			}
   426  		}
   427  	}()
   428  
   429  	// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
   430  	// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
   431  	// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
   432  	// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
   433  	var count int
   434  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   435  	row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
   436  
   437  	err = row.Scan(&count)
   438  	if err != nil {
   439  		return &database.Error{OrigErr: err, Query: []byte(query)}
   440  	}
   441  
   442  	if count == 1 {
   443  		return nil
   444  	}
   445  
   446  	query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
   447  	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
   448  		return &database.Error{OrigErr: err, Query: []byte(query)}
   449  	}
   450  
   451  	return nil
   452  }