github.com/fr-nvriep/migrate/v4@v4.3.2/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/fr-nvriep/migrate/v4"
    16  	"github.com/fr-nvriep/migrate/v4/database"
    17  	multierror "github.com/hashicorp/go-multierror"
    18  	"github.com/lib/pq"
    19  )
    20  
    21  func init() {
    22  	db := Postgres{}
    23  	database.Register("postgres", &db)
    24  	database.Register("postgresql", &db)
    25  }
    26  
    27  var DefaultMigrationsTable = "schema_migrations"
    28  
    29  var (
    30  	ErrNilConfig      = fmt.Errorf("no config")
    31  	ErrNoDatabaseName = fmt.Errorf("no database name")
    32  	ErrNoSchema       = fmt.Errorf("no schema")
    33  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    34  )
    35  
// Config holds the settings of a Postgres driver instance.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty
	// flag. WithInstance sets it to DefaultMigrationsTable when empty.
	MigrationsTable string
	// DatabaseName is overwritten by WithInstance with the result of
	// SELECT CURRENT_DATABASE(); a caller-supplied value is not used.
	DatabaseName    string
	// SchemaName is overwritten by WithInstance with the result of
	// SELECT CURRENT_SCHEMA().
	SchemaName      string
}
    41  
// Postgres is the PostgreSQL implementation of database.Driver returned by
// WithInstance and Open.
type Postgres struct {
	// Locking and unlocking need to use the same connection
	// (advisory locks belong to the session that took them), so the methods
	// of this type issue their statements through conn rather than db.
	conn     *sql.Conn
	// db is the pool conn was taken from; Close closes both.
	db       *sql.DB
	// isLocked records whether this driver currently holds the advisory
	// lock taken by Lock.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    51  
    52  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    53  	if config == nil {
    54  		return nil, ErrNilConfig
    55  	}
    56  
    57  	if err := instance.Ping(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	query := `SELECT CURRENT_DATABASE()`
    62  	var databaseName string
    63  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    64  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    65  	}
    66  
    67  	if len(databaseName) == 0 {
    68  		return nil, ErrNoDatabaseName
    69  	}
    70  
    71  	config.DatabaseName = databaseName
    72  
    73  	query = `SELECT CURRENT_SCHEMA()`
    74  	var schemaName string
    75  	if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    76  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    77  	}
    78  
    79  	if len(schemaName) == 0 {
    80  		return nil, ErrNoSchema
    81  	}
    82  
    83  	config.SchemaName = schemaName
    84  
    85  	if len(config.MigrationsTable) == 0 {
    86  		config.MigrationsTable = DefaultMigrationsTable
    87  	}
    88  
    89  	conn, err := instance.Conn(context.Background())
    90  
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	px := &Postgres{
    96  		conn:   conn,
    97  		db:     instance,
    98  		config: config,
    99  	}
   100  
   101  	if err := px.ensureVersionTable(); err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	return px, nil
   106  }
   107  
   108  func (p *Postgres) Open(url string) (database.Driver, error) {
   109  	purl, err := nurl.Parse(url)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	migrationsTable := purl.Query().Get("x-migrations-table")
   120  
   121  	px, err := WithInstance(db, &Config{
   122  		DatabaseName:    purl.Path,
   123  		MigrationsTable: migrationsTable,
   124  	})
   125  
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	return px, nil
   131  }
   132  
   133  func (p *Postgres) Close() error {
   134  	connErr := p.conn.Close()
   135  	dbErr := p.db.Close()
   136  	if connErr != nil || dbErr != nil {
   137  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   138  	}
   139  	return nil
   140  }
   141  
   142  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   143  func (p *Postgres) Lock() error {
   144  	if p.isLocked {
   145  		return database.ErrLocked
   146  	}
   147  
   148  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	// This will either obtain the lock immediately and return true,
   154  	// or return false if the lock cannot be acquired immediately.
   155  	query := `SELECT pg_advisory_lock($1)`
   156  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   157  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   158  	}
   159  
   160  	p.isLocked = true
   161  	return nil
   162  }
   163  
   164  func (p *Postgres) Unlock() error {
   165  	if !p.isLocked {
   166  		return nil
   167  	}
   168  
   169  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	query := `SELECT pg_advisory_unlock($1)`
   175  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   176  		return &database.Error{OrigErr: err, Query: []byte(query)}
   177  	}
   178  	p.isLocked = false
   179  	return nil
   180  }
   181  
   182  func (p *Postgres) Run(migration io.Reader) error {
   183  	migr, err := ioutil.ReadAll(migration)
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	// run migration
   189  	query := string(migr[:])
   190  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   191  		if pgErr, ok := err.(*pq.Error); ok {
   192  			var line uint
   193  			var col uint
   194  			var lineColOK bool
   195  			if pgErr.Position != "" {
   196  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   197  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   198  				}
   199  			}
   200  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   201  			if lineColOK {
   202  				message = fmt.Sprintf("%s (column %d)", message, col)
   203  			}
   204  			if pgErr.Detail != "" {
   205  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   206  			}
   207  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
   208  		}
   209  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   210  	}
   211  
   212  	return nil
   213  }
   214  
   215  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   216  	// replace crlf with lf
   217  	s = strings.Replace(s, "\r\n", "\n", -1)
   218  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   219  	runes := []rune(s)
   220  	if pos > len(runes) {
   221  		return 0, 0, false
   222  	}
   223  	sel := runes[:pos]
   224  	line = uint(runesCount(sel, newLine) + 1)
   225  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   226  	return line, col, true
   227  }
   228  
   229  const newLine = '\n'
   230  
   231  func runesCount(input []rune, target rune) int {
   232  	var count int
   233  	for _, r := range input {
   234  		if r == target {
   235  			count++
   236  		}
   237  	}
   238  	return count
   239  }
   240  
   241  func runesLastIndex(input []rune, target rune) int {
   242  	for i := len(input) - 1; i >= 0; i-- {
   243  		if input[i] == target {
   244  			return i
   245  		}
   246  	}
   247  	return -1
   248  }
   249  
   250  func (p *Postgres) SetVersion(version int, dirty bool) error {
   251  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   252  	if err != nil {
   253  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   254  	}
   255  
   256  	query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable)
   257  	if _, err := tx.Exec(query); err != nil {
   258  		if errRollback := tx.Rollback(); errRollback != nil {
   259  			err = multierror.Append(err, errRollback)
   260  		}
   261  		return &database.Error{OrigErr: err, Query: []byte(query)}
   262  	}
   263  
   264  	if version >= 0 {
   265  		query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)`
   266  		if _, err := tx.Exec(query, version, dirty); err != nil {
   267  			if errRollback := tx.Rollback(); errRollback != nil {
   268  				err = multierror.Append(err, errRollback)
   269  			}
   270  			return &database.Error{OrigErr: err, Query: []byte(query)}
   271  		}
   272  	}
   273  
   274  	if err := tx.Commit(); err != nil {
   275  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   276  	}
   277  
   278  	return nil
   279  }
   280  
   281  func (p *Postgres) Version() (version int, dirty bool, err error) {
   282  	query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
   283  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   284  	switch {
   285  	case err == sql.ErrNoRows:
   286  		return database.NilVersion, false, nil
   287  
   288  	case err != nil:
   289  		if e, ok := err.(*pq.Error); ok {
   290  			if e.Code.Name() == "undefined_table" {
   291  				return database.NilVersion, false, nil
   292  			}
   293  		}
   294  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   295  
   296  	default:
   297  		return version, dirty, nil
   298  	}
   299  }
   300  
   301  func (p *Postgres) Drop() (err error) {
   302  	// select all tables in current schema
   303  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   304  	tables, err := p.conn.QueryContext(context.Background(), query)
   305  	if err != nil {
   306  		return &database.Error{OrigErr: err, Query: []byte(query)}
   307  	}
   308  	defer func() {
   309  		if errClose := tables.Close(); errClose != nil {
   310  			err = multierror.Append(err, errClose)
   311  		}
   312  	}()
   313  
   314  	// delete one table after another
   315  	tableNames := make([]string, 0)
   316  	for tables.Next() {
   317  		var tableName string
   318  		if err := tables.Scan(&tableName); err != nil {
   319  			return err
   320  		}
   321  		if len(tableName) > 0 {
   322  			tableNames = append(tableNames, tableName)
   323  		}
   324  	}
   325  
   326  	if len(tableNames) > 0 {
   327  		// delete one by one ...
   328  		for _, t := range tableNames {
   329  			query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
   330  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   331  				return &database.Error{OrigErr: err, Query: []byte(query)}
   332  			}
   333  		}
   334  	}
   335  
   336  	return nil
   337  }
   338  
   339  // ensureVersionTable checks if versions table exists and, if not, creates it.
   340  // Note that this function locks the database, which deviates from the usual
   341  // convention of "caller locks" in the Postgres type.
   342  func (p *Postgres) ensureVersionTable() (err error) {
   343  	if err = p.Lock(); err != nil {
   344  		return err
   345  	}
   346  
   347  	defer func() {
   348  		if e := p.Unlock(); e != nil {
   349  			if err == nil {
   350  				err = e
   351  			} else {
   352  				err = multierror.Append(err, e)
   353  			}
   354  		}
   355  	}()
   356  
   357  	// check if migration table exists
   358  	var count int
   359  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   360  	if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
   361  		return &database.Error{OrigErr: err, Query: []byte(query)}
   362  	}
   363  	if count == 1 {
   364  		return nil
   365  	}
   366  
   367  	query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
   368  	if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
   369  		return &database.Error{OrigErr: err, Query: []byte(query)}
   370  	}
   371  
   372  	return nil
   373  }