github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/pgx/pgx.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package pgx
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"io"
    11  	nurl "net/url"
    12  	"regexp"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"go.uber.org/atomic"
    18  
    19  	"github.com/bdollma-te/migrate/v4"
    20  	"github.com/bdollma-te/migrate/v4/database"
    21  	"github.com/bdollma-te/migrate/v4/database/multistmt"
    22  	"github.com/hashicorp/go-multierror"
    23  	"github.com/jackc/pgconn"
    24  	"github.com/jackc/pgerrcode"
    25  	_ "github.com/jackc/pgx/v4/stdlib"
    26  	"github.com/lib/pq"
    27  )
    28  
    29  const (
    30  	LockStrategyAdvisory = "advisory"
    31  	LockStrategyTable    = "table"
    32  )
    33  
    34  func init() {
    35  	db := Postgres{}
    36  	database.Register("pgx", &db)
    37  	database.Register("pgx4", &db)
    38  }
    39  
    40  var (
    41  	multiStmtDelimiter = []byte(";")
    42  
    43  	DefaultMigrationsTable       = "schema_migrations"
    44  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    45  	DefaultLockTable             = "schema_lock"
    46  	DefaultLockStrategy          = LockStrategyAdvisory
    47  )
    48  
    49  var (
    50  	ErrNilConfig      = fmt.Errorf("no config")
    51  	ErrNoDatabaseName = fmt.Errorf("no database name")
    52  	ErrNoSchema       = fmt.Errorf("no schema")
    53  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    54  )
    55  
    56  type Config struct {
    57  	MigrationsTable       string
    58  	DatabaseName          string
    59  	SchemaName            string
    60  	LockTable             string
    61  	LockStrategy          string
    62  	migrationsSchemaName  string
    63  	migrationsTableName   string
    64  	StatementTimeout      time.Duration
    65  	MigrationsTableQuoted bool
    66  	MultiStatementEnabled bool
    67  	MultiStatementMaxSize int
    68  }
    69  
// Postgres is the database.Driver returned by WithInstance and Open. A single
// zero-value instance is registered under "pgx" and "pgx4" in init; Open on
// that instance builds and returns a separate configured *Postgres.
type Postgres struct {
	// Locking and unlocking need to use the same connection
	// (pg_advisory_lock/pg_advisory_unlock act on the calling session).
	// WithInstance reserves it with (*sql.DB).Conn; Run, SetVersion, Version,
	// Drop and ensureVersionTable also execute on it.
	conn     *sql.Conn
	// db is the pool passed to WithInstance; ensureLockTable and
	// releaseTableLock use it directly and Close closes it.
	db       *sql.DB
	// isLocked records whether this driver holds the migration lock; Lock and
	// Unlock flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    79  
    80  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    81  	if config == nil {
    82  		return nil, ErrNilConfig
    83  	}
    84  
    85  	if err := instance.Ping(); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	if config.DatabaseName == "" {
    90  		query := `SELECT CURRENT_DATABASE()`
    91  		var databaseName string
    92  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    93  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    94  		}
    95  
    96  		if len(databaseName) == 0 {
    97  			return nil, ErrNoDatabaseName
    98  		}
    99  
   100  		config.DatabaseName = databaseName
   101  	}
   102  
   103  	if config.SchemaName == "" {
   104  		query := `SELECT CURRENT_SCHEMA()`
   105  		var schemaName string
   106  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
   107  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
   108  		}
   109  
   110  		if len(schemaName) == 0 {
   111  			return nil, ErrNoSchema
   112  		}
   113  
   114  		config.SchemaName = schemaName
   115  	}
   116  
   117  	if len(config.MigrationsTable) == 0 {
   118  		config.MigrationsTable = DefaultMigrationsTable
   119  	}
   120  
   121  	if len(config.LockTable) == 0 {
   122  		config.LockTable = DefaultLockTable
   123  	}
   124  
   125  	if len(config.LockStrategy) == 0 {
   126  		config.LockStrategy = DefaultLockStrategy
   127  	}
   128  
   129  	config.migrationsSchemaName = config.SchemaName
   130  	config.migrationsTableName = config.MigrationsTable
   131  	if config.MigrationsTableQuoted {
   132  		re := regexp.MustCompile(`"(.*?)"`)
   133  		result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
   134  		config.migrationsTableName = result[len(result)-1][1]
   135  		if len(result) == 2 {
   136  			config.migrationsSchemaName = result[0][1]
   137  		} else if len(result) > 2 {
   138  			return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
   139  		}
   140  	}
   141  
   142  	conn, err := instance.Conn(context.Background())
   143  
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	px := &Postgres{
   149  		conn:   conn,
   150  		db:     instance,
   151  		config: config,
   152  	}
   153  
   154  	if err := px.ensureLockTable(); err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	if err := px.ensureVersionTable(); err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	return px, nil
   163  }
   164  
   165  func (p *Postgres) Open(url string) (database.Driver, error) {
   166  	purl, err := nurl.Parse(url)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	// Driver is registered as pgx, but connection string must use postgres schema
   172  	// when making actual connection
   173  	// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
   174  	purl.Scheme = "postgres"
   175  
   176  	db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	migrationsTable := purl.Query().Get("x-migrations-table")
   182  	migrationsTableQuoted := false
   183  	if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
   184  		migrationsTableQuoted, err = strconv.ParseBool(s)
   185  		if err != nil {
   186  			return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
   187  		}
   188  	}
   189  	if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
   190  		return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
   191  	}
   192  
   193  	statementTimeoutString := purl.Query().Get("x-statement-timeout")
   194  	statementTimeout := 0
   195  	if statementTimeoutString != "" {
   196  		statementTimeout, err = strconv.Atoi(statementTimeoutString)
   197  		if err != nil {
   198  			return nil, err
   199  		}
   200  	}
   201  
   202  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   203  	if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   204  		multiStatementMaxSize, err = strconv.Atoi(s)
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  		if multiStatementMaxSize <= 0 {
   209  			multiStatementMaxSize = DefaultMultiStatementMaxSize
   210  		}
   211  	}
   212  
   213  	multiStatementEnabled := false
   214  	if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
   215  		multiStatementEnabled, err = strconv.ParseBool(s)
   216  		if err != nil {
   217  			return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
   218  		}
   219  	}
   220  
   221  	lockStrategy := purl.Query().Get("x-lock-strategy")
   222  	lockTable := purl.Query().Get("x-lock-table")
   223  
   224  	px, err := WithInstance(db, &Config{
   225  		DatabaseName:          purl.Path,
   226  		MigrationsTable:       migrationsTable,
   227  		MigrationsTableQuoted: migrationsTableQuoted,
   228  		StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
   229  		MultiStatementEnabled: multiStatementEnabled,
   230  		MultiStatementMaxSize: multiStatementMaxSize,
   231  		LockStrategy:          lockStrategy,
   232  		LockTable:             lockTable,
   233  	})
   234  
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	return px, nil
   240  }
   241  
   242  func (p *Postgres) Close() error {
   243  	connErr := p.conn.Close()
   244  	dbErr := p.db.Close()
   245  	if connErr != nil || dbErr != nil {
   246  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   247  	}
   248  	return nil
   249  }
   250  
   251  func (p *Postgres) Lock() error {
   252  	return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
   253  		switch p.config.LockStrategy {
   254  		case LockStrategyAdvisory:
   255  			return p.applyAdvisoryLock()
   256  		case LockStrategyTable:
   257  			return p.applyTableLock()
   258  		default:
   259  			return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
   260  		}
   261  	})
   262  }
   263  
   264  func (p *Postgres) Unlock() error {
   265  	return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
   266  		switch p.config.LockStrategy {
   267  		case LockStrategyAdvisory:
   268  			return p.releaseAdvisoryLock()
   269  		case LockStrategyTable:
   270  			return p.releaseTableLock()
   271  		default:
   272  			return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
   273  		}
   274  	})
   275  }
   276  
   277  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   278  func (p *Postgres) applyAdvisoryLock() error {
   279  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
   280  	if err != nil {
   281  		return err
   282  	}
   283  
   284  	// This will wait indefinitely until the lock can be acquired.
   285  	query := `SELECT pg_advisory_lock($1)`
   286  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   287  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   288  	}
   289  	return nil
   290  }
   291  
   292  func (p *Postgres) applyTableLock() error {
   293  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   294  	if err != nil {
   295  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   296  	}
   297  	defer func() {
   298  		errRollback := tx.Rollback()
   299  		if errRollback != nil {
   300  			err = multierror.Append(err, errRollback)
   301  		}
   302  	}()
   303  
   304  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   305  	if err != nil {
   306  		return err
   307  	}
   308  
   309  	query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
   310  	rows, err := tx.Query(query, aid)
   311  	if err != nil {
   312  		return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   313  	}
   314  
   315  	defer func() {
   316  		if errClose := rows.Close(); errClose != nil {
   317  			err = multierror.Append(err, errClose)
   318  		}
   319  	}()
   320  
   321  	// If row exists at all, lock is present
   322  	locked := rows.Next()
   323  	if locked {
   324  		return database.ErrLocked
   325  	}
   326  
   327  	query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
   328  	if _, err := tx.Exec(query, aid); err != nil {
   329  		return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   330  	}
   331  
   332  	return tx.Commit()
   333  }
   334  
   335  func (p *Postgres) releaseAdvisoryLock() error {
   336  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
   337  	if err != nil {
   338  		return err
   339  	}
   340  
   341  	query := `SELECT pg_advisory_unlock($1)`
   342  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   343  		return &database.Error{OrigErr: err, Query: []byte(query)}
   344  	}
   345  
   346  	return nil
   347  }
   348  
   349  func (p *Postgres) releaseTableLock() error {
   350  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   351  	if err != nil {
   352  		return err
   353  	}
   354  
   355  	query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
   356  	if _, err := p.db.Exec(query, aid); err != nil {
   357  		return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   358  	}
   359  
   360  	return nil
   361  }
   362  
   363  func (p *Postgres) Run(migration io.Reader) error {
   364  	if p.config.MultiStatementEnabled {
   365  		var err error
   366  		if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
   367  			if err = p.runStatement(m); err != nil {
   368  				return false
   369  			}
   370  			return true
   371  		}); e != nil {
   372  			return e
   373  		}
   374  		return err
   375  	}
   376  	migr, err := io.ReadAll(migration)
   377  	if err != nil {
   378  		return err
   379  	}
   380  	return p.runStatement(migr)
   381  }
   382  
   383  func (p *Postgres) runStatement(statement []byte) error {
   384  	ctx := context.Background()
   385  	if p.config.StatementTimeout != 0 {
   386  		var cancel context.CancelFunc
   387  		ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
   388  		defer cancel()
   389  	}
   390  	query := string(statement)
   391  	if strings.TrimSpace(query) == "" {
   392  		return nil
   393  	}
   394  	if _, err := p.conn.ExecContext(ctx, query); err != nil {
   395  
   396  		if pgErr, ok := err.(*pgconn.PgError); ok {
   397  			var line uint
   398  			var col uint
   399  			var lineColOK bool
   400  			line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
   401  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   402  			if lineColOK {
   403  				message = fmt.Sprintf("%s (column %d)", message, col)
   404  			}
   405  			if pgErr.Detail != "" {
   406  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   407  			}
   408  			return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
   409  		}
   410  		return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
   411  	}
   412  	return nil
   413  }
   414  
   415  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   416  	// replace crlf with lf
   417  	s = strings.Replace(s, "\r\n", "\n", -1)
   418  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   419  	runes := []rune(s)
   420  	if pos > len(runes) {
   421  		return 0, 0, false
   422  	}
   423  	sel := runes[:pos]
   424  	line = uint(runesCount(sel, newLine) + 1)
   425  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   426  	return line, col, true
   427  }
   428  
// newLine is the separator computeLineFromPos counts (via runesCount and
// runesLastIndex) when converting a character position into line and column.
const newLine = '\n'
   430  
   431  func runesCount(input []rune, target rune) int {
   432  	var count int
   433  	for _, r := range input {
   434  		if r == target {
   435  			count++
   436  		}
   437  	}
   438  	return count
   439  }
   440  
   441  func runesLastIndex(input []rune, target rune) int {
   442  	for i := len(input) - 1; i >= 0; i-- {
   443  		if input[i] == target {
   444  			return i
   445  		}
   446  	}
   447  	return -1
   448  }
   449  
   450  func (p *Postgres) SetVersion(version int, dirty bool) error {
   451  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   452  	if err != nil {
   453  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   454  	}
   455  
   456  	query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
   457  	if _, err := tx.Exec(query); err != nil {
   458  		if errRollback := tx.Rollback(); errRollback != nil {
   459  			err = multierror.Append(err, errRollback)
   460  		}
   461  		return &database.Error{OrigErr: err, Query: []byte(query)}
   462  	}
   463  
   464  	// Also re-write the schema version for nil dirty versions to prevent
   465  	// empty schema version for failed down migration on the first migration
   466  	// See: https://github.com/golang-migrate/migrate/issues/330
   467  	if version >= 0 || (version == database.NilVersion && dirty) {
   468  		query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
   469  		if _, err := tx.Exec(query, version, dirty); err != nil {
   470  			if errRollback := tx.Rollback(); errRollback != nil {
   471  				err = multierror.Append(err, errRollback)
   472  			}
   473  			return &database.Error{OrigErr: err, Query: []byte(query)}
   474  		}
   475  	}
   476  
   477  	if err := tx.Commit(); err != nil {
   478  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   479  	}
   480  
   481  	return nil
   482  }
   483  
   484  func (p *Postgres) Version() (version int, dirty bool, err error) {
   485  	query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
   486  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   487  	switch {
   488  	case err == sql.ErrNoRows:
   489  		return database.NilVersion, false, nil
   490  
   491  	case err != nil:
   492  		if e, ok := err.(*pgconn.PgError); ok {
   493  			if e.SQLState() == pgerrcode.UndefinedTable {
   494  				return database.NilVersion, false, nil
   495  			}
   496  		}
   497  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   498  
   499  	default:
   500  		return version, dirty, nil
   501  	}
   502  }
   503  
   504  func (p *Postgres) Drop() (err error) {
   505  	// select all tables in current schema
   506  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   507  	tables, err := p.conn.QueryContext(context.Background(), query)
   508  	if err != nil {
   509  		return &database.Error{OrigErr: err, Query: []byte(query)}
   510  	}
   511  	defer func() {
   512  		if errClose := tables.Close(); errClose != nil {
   513  			err = multierror.Append(err, errClose)
   514  		}
   515  	}()
   516  
   517  	// delete one table after another
   518  	tableNames := make([]string, 0)
   519  	for tables.Next() {
   520  		var tableName string
   521  		if err := tables.Scan(&tableName); err != nil {
   522  			return err
   523  		}
   524  
   525  		// do not drop lock table
   526  		if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable {
   527  			continue
   528  		}
   529  
   530  		if len(tableName) > 0 {
   531  			tableNames = append(tableNames, tableName)
   532  		}
   533  	}
   534  	if err := tables.Err(); err != nil {
   535  		return &database.Error{OrigErr: err, Query: []byte(query)}
   536  	}
   537  
   538  	if len(tableNames) > 0 {
   539  		// delete one by one ...
   540  		for _, t := range tableNames {
   541  			query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
   542  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   543  				return &database.Error{OrigErr: err, Query: []byte(query)}
   544  			}
   545  		}
   546  	}
   547  
   548  	return nil
   549  }
   550  
   551  // ensureVersionTable checks if versions table exists and, if not, creates it.
   552  // Note that this function locks the database, which deviates from the usual
   553  // convention of "caller locks" in the Postgres type.
   554  func (p *Postgres) ensureVersionTable() (err error) {
   555  	if err = p.Lock(); err != nil {
   556  		return err
   557  	}
   558  
   559  	defer func() {
   560  		if e := p.Unlock(); e != nil {
   561  			if err == nil {
   562  				err = e
   563  			} else {
   564  				err = multierror.Append(err, e)
   565  			}
   566  		}
   567  	}()
   568  
   569  	// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
   570  	// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
   571  	// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
   572  	// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
   573  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
   574  	row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
   575  
   576  	var count int
   577  	err = row.Scan(&count)
   578  	if err != nil {
   579  		return &database.Error{OrigErr: err, Query: []byte(query)}
   580  	}
   581  
   582  	if count == 1 {
   583  		return nil
   584  	}
   585  
   586  	query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
   587  	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
   588  		return &database.Error{OrigErr: err, Query: []byte(query)}
   589  	}
   590  
   591  	return nil
   592  }
   593  
   594  func (p *Postgres) ensureLockTable() error {
   595  	if p.config.LockStrategy != LockStrategyTable {
   596  		return nil
   597  	}
   598  
   599  	var count int
   600  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   601  	if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
   602  		return &database.Error{OrigErr: err, Query: []byte(query)}
   603  	}
   604  	if count == 1 {
   605  		return nil
   606  	}
   607  
   608  	query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
   609  	if _, err := p.db.Exec(query); err != nil {
   610  		return &database.Error{OrigErr: err, Query: []byte(query)}
   611  	}
   612  
   613  	return nil
   614  }
   615  
   616  // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
   617  func quoteIdentifier(name string) string {
   618  	end := strings.IndexRune(name, 0)
   619  	if end > -1 {
   620  		name = name[:end]
   621  	}
   622  	return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
   623  }