github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/postgres/postgres.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package postgres
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	nurl "net/url"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"go.uber.org/atomic"
    19  
    20  	"github.com/hashicorp/go-multierror"
    21  	"github.com/lib/pq"
    22  	migrate "github.com/seashell-org/golang-migrate/v4"
    23  	"github.com/seashell-org/golang-migrate/v4/database"
    24  	"github.com/seashell-org/golang-migrate/v4/database/multistmt"
    25  )
    26  
    27  func init() {
    28  	db := Postgres{}
    29  	database.Register("postgres", &db)
    30  	database.Register("postgresql", &db)
    31  }
    32  
var (
	// multiStmtDelimiter is the statement separator handed to multistmt.Parse
	// by Run when multi-statement mode is enabled.
	multiStmtDelimiter = []byte(";")

	// DefaultMigrationsTable is the table name WithConnection falls back to
	// when Config.MigrationsTable is empty.
	DefaultMigrationsTable       = "schema_migrations"
	// DefaultMultiStatementMaxSize is the value Open uses for
	// Config.MultiStatementMaxSize when the x-multi-statement-max-size URL
	// option is absent or not positive.
	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)
    39  
// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithConnection when the config argument is nil.
	ErrNilConfig      = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithConnection when CURRENT_DATABASE()
	// yields an empty name.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is returned by WithConnection when CURRENT_SCHEMA() yields
	// an empty name.
	ErrNoSchema       = fmt.Errorf("no schema")
	// ErrDatabaseDirty signals a dirty database.
	// NOTE(review): not referenced anywhere in this file; presumably kept for callers.
	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
)
    46  
// Config carries the settings for a Postgres driver instance.
// WithConnection fills in empty fields and the unexported derived fields,
// mutating the Config it is given.
type Config struct {
	// MigrationsTable names the version table; DefaultMigrationsTable is used
	// when it is empty.
	MigrationsTable       string
	// MigrationsTableQuoted makes WithConnection parse MigrationsTable as one
	// or two double-quoted identifiers ("table" or "schema"."table").
	MigrationsTableQuoted bool
	// MultiStatementEnabled makes Run split the migration with multistmt.Parse
	// and execute the pieces one by one instead of sending it in one Exec.
	MultiStatementEnabled bool
	// DatabaseName is looked up with CURRENT_DATABASE() when empty. It is one
	// of the inputs to the advisory lock ID.
	DatabaseName          string
	// SchemaName is looked up with CURRENT_SCHEMA() when empty. It is the
	// default schema for the migrations table.
	SchemaName            string
	// migrationsSchemaName and migrationsTableName are the unquoted schema and
	// table of the version table, derived in WithConnection.
	migrationsSchemaName  string
	migrationsTableName   string
	// StatementTimeout, when non-zero, bounds each statement executed by
	// runStatement via a context timeout.
	StatementTimeout      time.Duration
	// MultiStatementMaxSize is passed to multistmt.Parse as its size argument
	// (presumably the maximum size of one statement in bytes — confirm against multistmt).
	MultiStatementMaxSize int
}
    58  
// Postgres is the PostgreSQL migration driver. All queries in this file go
// through the single dedicated connection held in conn.
type Postgres struct {
	// Locking and unlocking need to use the same connection
	conn     *sql.Conn
	// db is set only by WithInstance; when non-nil, Close closes it as well.
	db       *sql.DB
	// isLocked records whether this instance currently holds the advisory
	// lock; Lock and Unlock flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    68  
    69  func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) {
    70  	if config == nil {
    71  		return nil, ErrNilConfig
    72  	}
    73  
    74  	if err := conn.PingContext(ctx); err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	if config.DatabaseName == "" {
    79  		query := `SELECT CURRENT_DATABASE()`
    80  		var databaseName string
    81  		if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
    82  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    83  		}
    84  
    85  		if len(databaseName) == 0 {
    86  			return nil, ErrNoDatabaseName
    87  		}
    88  
    89  		config.DatabaseName = databaseName
    90  	}
    91  
    92  	if config.SchemaName == "" {
    93  		query := `SELECT CURRENT_SCHEMA()`
    94  		var schemaName string
    95  		if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
    96  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    97  		}
    98  
    99  		if len(schemaName) == 0 {
   100  			return nil, ErrNoSchema
   101  		}
   102  
   103  		config.SchemaName = schemaName
   104  	}
   105  
   106  	if len(config.MigrationsTable) == 0 {
   107  		config.MigrationsTable = DefaultMigrationsTable
   108  	}
   109  
   110  	config.migrationsSchemaName = config.SchemaName
   111  	config.migrationsTableName = config.MigrationsTable
   112  	if config.MigrationsTableQuoted {
   113  		re := regexp.MustCompile(`"(.*?)"`)
   114  		result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
   115  		config.migrationsTableName = result[len(result)-1][1]
   116  		if len(result) == 2 {
   117  			config.migrationsSchemaName = result[0][1]
   118  		} else if len(result) > 2 {
   119  			return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
   120  		}
   121  	}
   122  
   123  	px := &Postgres{
   124  		conn:   conn,
   125  		config: config,
   126  	}
   127  
   128  	if err := px.ensureVersionTable(); err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	return px, nil
   133  }
   134  
   135  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
   136  	ctx := context.Background()
   137  
   138  	if err := instance.Ping(); err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	conn, err := instance.Conn(ctx)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	px, err := WithConnection(ctx, conn, config)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	px.db = instance
   152  	return px, nil
   153  }
   154  
   155  func (p *Postgres) Open(url string) (database.Driver, error) {
   156  	purl, err := nurl.Parse(url)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	migrationsTable := purl.Query().Get("x-migrations-table")
   167  	migrationsTableQuoted := false
   168  	if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
   169  		migrationsTableQuoted, err = strconv.ParseBool(s)
   170  		if err != nil {
   171  			return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
   172  		}
   173  	}
   174  	if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
   175  		return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
   176  	}
   177  
   178  	statementTimeoutString := purl.Query().Get("x-statement-timeout")
   179  	statementTimeout := 0
   180  	if statementTimeoutString != "" {
   181  		statementTimeout, err = strconv.Atoi(statementTimeoutString)
   182  		if err != nil {
   183  			return nil, err
   184  		}
   185  	}
   186  
   187  	multiStatementMaxSize := DefaultMultiStatementMaxSize
   188  	if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
   189  		multiStatementMaxSize, err = strconv.Atoi(s)
   190  		if err != nil {
   191  			return nil, err
   192  		}
   193  		if multiStatementMaxSize <= 0 {
   194  			multiStatementMaxSize = DefaultMultiStatementMaxSize
   195  		}
   196  	}
   197  
   198  	multiStatementEnabled := false
   199  	if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
   200  		multiStatementEnabled, err = strconv.ParseBool(s)
   201  		if err != nil {
   202  			return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
   203  		}
   204  	}
   205  
   206  	px, err := WithInstance(db, &Config{
   207  		DatabaseName:          purl.Path,
   208  		MigrationsTable:       migrationsTable,
   209  		MigrationsTableQuoted: migrationsTableQuoted,
   210  		StatementTimeout:      time.Duration(statementTimeout) * time.Millisecond,
   211  		MultiStatementEnabled: multiStatementEnabled,
   212  		MultiStatementMaxSize: multiStatementMaxSize,
   213  	})
   214  
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	return px, nil
   220  }
   221  
   222  func (p *Postgres) Close() error {
   223  	connErr := p.conn.Close()
   224  	var dbErr error
   225  	if p.db != nil {
   226  		dbErr = p.db.Close()
   227  	}
   228  
   229  	if connErr != nil || dbErr != nil {
   230  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   231  	}
   232  	return nil
   233  }
   234  
   235  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   236  func (p *Postgres) Lock() error {
   237  	return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
   238  		aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
   239  		if err != nil {
   240  			return err
   241  		}
   242  
   243  		// This will wait indefinitely until the lock can be acquired.
   244  		query := `SELECT pg_advisory_lock($1)`
   245  		if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   246  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   247  		}
   248  
   249  		return nil
   250  	})
   251  }
   252  
   253  func (p *Postgres) Unlock() error {
   254  	return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
   255  		aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
   256  		if err != nil {
   257  			return err
   258  		}
   259  
   260  		query := `SELECT pg_advisory_unlock($1)`
   261  		if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   262  			return &database.Error{OrigErr: err, Query: []byte(query)}
   263  		}
   264  		return nil
   265  	})
   266  }
   267  
   268  func (p *Postgres) Run(migration io.Reader) error {
   269  	if p.config.MultiStatementEnabled {
   270  		var err error
   271  		if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
   272  			if err = p.runStatement(m); err != nil {
   273  				return false
   274  			}
   275  			return true
   276  		}); e != nil {
   277  			return e
   278  		}
   279  		return err
   280  	}
   281  	migr, err := ioutil.ReadAll(migration)
   282  	if err != nil {
   283  		return err
   284  	}
   285  	return p.runStatement(migr)
   286  }
   287  
   288  func (p *Postgres) runStatement(statement []byte) error {
   289  	ctx := context.Background()
   290  	if p.config.StatementTimeout != 0 {
   291  		var cancel context.CancelFunc
   292  		ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
   293  		defer cancel()
   294  	}
   295  	query := string(statement)
   296  	if strings.TrimSpace(query) == "" {
   297  		return nil
   298  	}
   299  	if _, err := p.conn.ExecContext(ctx, query); err != nil {
   300  		if pgErr, ok := err.(*pq.Error); ok {
   301  			var line uint
   302  			var col uint
   303  			var lineColOK bool
   304  			if pgErr.Position != "" {
   305  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   306  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   307  				}
   308  			}
   309  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   310  			if lineColOK {
   311  				message = fmt.Sprintf("%s (column %d)", message, col)
   312  			}
   313  			if pgErr.Detail != "" {
   314  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   315  			}
   316  			return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
   317  		}
   318  		return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
   319  	}
   320  	return nil
   321  }
   322  
   323  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   324  	// replace crlf with lf
   325  	s = strings.Replace(s, "\r\n", "\n", -1)
   326  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   327  	runes := []rune(s)
   328  	if pos > len(runes) {
   329  		return 0, 0, false
   330  	}
   331  	sel := runes[:pos]
   332  	line = uint(runesCount(sel, newLine) + 1)
   333  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   334  	return line, col, true
   335  }
   336  
   337  const newLine = '\n'
   338  
   339  func runesCount(input []rune, target rune) int {
   340  	var count int
   341  	for _, r := range input {
   342  		if r == target {
   343  			count++
   344  		}
   345  	}
   346  	return count
   347  }
   348  
   349  func runesLastIndex(input []rune, target rune) int {
   350  	for i := len(input) - 1; i >= 0; i-- {
   351  		if input[i] == target {
   352  			return i
   353  		}
   354  	}
   355  	return -1
   356  }
   357  
   358  func (p *Postgres) SetVersion(version int, dirty bool) error {
   359  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   360  	if err != nil {
   361  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   362  	}
   363  
   364  	query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName)
   365  	if _, err := tx.Exec(query); err != nil {
   366  		if errRollback := tx.Rollback(); errRollback != nil {
   367  			err = multierror.Append(err, errRollback)
   368  		}
   369  		return &database.Error{OrigErr: err, Query: []byte(query)}
   370  	}
   371  
   372  	// Also re-write the schema version for nil dirty versions to prevent
   373  	// empty schema version for failed down migration on the first migration
   374  	// See: https://github.com/seashell-org/golang-migrate/issues/330
   375  	if version >= 0 || (version == database.NilVersion && dirty) {
   376  		query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
   377  		if _, err := tx.Exec(query, version, dirty); err != nil {
   378  			if errRollback := tx.Rollback(); errRollback != nil {
   379  				err = multierror.Append(err, errRollback)
   380  			}
   381  			return &database.Error{OrigErr: err, Query: []byte(query)}
   382  		}
   383  	}
   384  
   385  	if err := tx.Commit(); err != nil {
   386  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   387  	}
   388  
   389  	return nil
   390  }
   391  
   392  func (p *Postgres) Version() (version int, dirty bool, err error) {
   393  	query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
   394  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   395  	switch {
   396  	case err == sql.ErrNoRows:
   397  		return database.NilVersion, false, nil
   398  
   399  	case err != nil:
   400  		if e, ok := err.(*pq.Error); ok {
   401  			if e.Code.Name() == "undefined_table" {
   402  				return database.NilVersion, false, nil
   403  			}
   404  		}
   405  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   406  
   407  	default:
   408  		return version, dirty, nil
   409  	}
   410  }
   411  
   412  func (p *Postgres) Drop() (err error) {
   413  	// select all tables in current schema
   414  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   415  	tables, err := p.conn.QueryContext(context.Background(), query)
   416  	if err != nil {
   417  		return &database.Error{OrigErr: err, Query: []byte(query)}
   418  	}
   419  	defer func() {
   420  		if errClose := tables.Close(); errClose != nil {
   421  			err = multierror.Append(err, errClose)
   422  		}
   423  	}()
   424  
   425  	// delete one table after another
   426  	tableNames := make([]string, 0)
   427  	for tables.Next() {
   428  		var tableName string
   429  		if err := tables.Scan(&tableName); err != nil {
   430  			return err
   431  		}
   432  		if len(tableName) > 0 {
   433  			tableNames = append(tableNames, tableName)
   434  		}
   435  	}
   436  	if err := tables.Err(); err != nil {
   437  		return &database.Error{OrigErr: err, Query: []byte(query)}
   438  	}
   439  
   440  	if len(tableNames) > 0 {
   441  		// delete one by one ...
   442  		for _, t := range tableNames {
   443  			query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
   444  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   445  				return &database.Error{OrigErr: err, Query: []byte(query)}
   446  			}
   447  		}
   448  	}
   449  
   450  	return nil
   451  }
   452  
   453  // ensureVersionTable checks if versions table exists and, if not, creates it.
   454  // Note that this function locks the database, which deviates from the usual
   455  // convention of "caller locks" in the Postgres type.
   456  func (p *Postgres) ensureVersionTable() (err error) {
   457  	if err = p.Lock(); err != nil {
   458  		return err
   459  	}
   460  
   461  	defer func() {
   462  		if e := p.Unlock(); e != nil {
   463  			if err == nil {
   464  				err = e
   465  			} else {
   466  				err = multierror.Append(err, e)
   467  			}
   468  		}
   469  	}()
   470  
   471  	// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
   472  	// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
   473  	// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
   474  	// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
   475  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
   476  	row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
   477  
   478  	var count int
   479  	err = row.Scan(&count)
   480  	if err != nil {
   481  		return &database.Error{OrigErr: err, Query: []byte(query)}
   482  	}
   483  
   484  	if count == 1 {
   485  		return nil
   486  	}
   487  
   488  	query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
   489  	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
   490  		return &database.Error{OrigErr: err, Query: []byte(query)}
   491  	}
   492  
   493  	return nil
   494  }