github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/golang-migrate/migrate"
    16  	"github.com/golang-migrate/migrate/database"
    17  	"github.com/lib/pq"
    18  )
    19  
    20  func init() {
    21  	db := Postgres{}
    22  	database.Register("postgres", &db)
    23  	database.Register("postgresql", &db)
    24  }
    25  
    26  var DefaultMigrationsTable = "schema_migrations"
    27  
    28  var (
    29  	ErrNilConfig      = fmt.Errorf("no config")
    30  	ErrNoDatabaseName = fmt.Errorf("no database name")
    31  	ErrNoSchema       = fmt.Errorf("no schema")
    32  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    33  )
    34  
    35  type Config struct {
    36  	MigrationsTable string
    37  	DatabaseName    string
    38  }
    39  
    40  type Postgres struct {
    41  	// Locking and unlocking need to use the same connection
    42  	conn     *sql.Conn
    43  	isLocked bool
    44  
    45  	// Open and WithInstance need to garantuee that config is never nil
    46  	config *Config
    47  }
    48  
    49  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    50  	if config == nil {
    51  		return nil, ErrNilConfig
    52  	}
    53  
    54  	if err := instance.Ping(); err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	query := `SELECT CURRENT_DATABASE()`
    59  	var databaseName string
    60  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    61  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    62  	}
    63  
    64  	if len(databaseName) == 0 {
    65  		return nil, ErrNoDatabaseName
    66  	}
    67  
    68  	config.DatabaseName = databaseName
    69  
    70  	if len(config.MigrationsTable) == 0 {
    71  		config.MigrationsTable = DefaultMigrationsTable
    72  	}
    73  
    74  	conn, err := instance.Conn(context.Background())
    75  
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	px := &Postgres{
    81  		conn:   conn,
    82  		config: config,
    83  	}
    84  
    85  	if err := px.ensureVersionTable(); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	return px, nil
    90  }
    91  
    92  func (p *Postgres) Open(url string) (database.Driver, error) {
    93  	purl, err := nurl.Parse(url)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	migrationsTable := purl.Query().Get("x-migrations-table")
   104  	if len(migrationsTable) == 0 {
   105  		migrationsTable = DefaultMigrationsTable
   106  	}
   107  
   108  	px, err := WithInstance(db, &Config{
   109  		DatabaseName:    purl.Path,
   110  		MigrationsTable: migrationsTable,
   111  	})
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	return px, nil
   117  }
   118  
   119  func (p *Postgres) Close() error {
   120  	return p.conn.Close()
   121  }
   122  
   123  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   124  func (p *Postgres) Lock() error {
   125  	if p.isLocked {
   126  		return database.ErrLocked
   127  	}
   128  
   129  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   130  	if err != nil {
   131  		return err
   132  	}
   133  
   134  	// This will either obtain the lock immediately and return true,
   135  	// or return false if the lock cannot be acquired immediately.
   136  	query := `SELECT pg_advisory_lock($1)`
   137  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   138  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   139  	}
   140  
   141  	p.isLocked = true
   142  	return nil
   143  }
   144  
   145  func (p *Postgres) Unlock() error {
   146  	if !p.isLocked {
   147  		return nil
   148  	}
   149  
   150  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	query := `SELECT pg_advisory_unlock($1)`
   156  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   157  		return &database.Error{OrigErr: err, Query: []byte(query)}
   158  	}
   159  	p.isLocked = false
   160  	return nil
   161  }
   162  
   163  func (p *Postgres) Run(migration io.Reader) error {
   164  	migr, err := ioutil.ReadAll(migration)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	// run migration
   170  	query := string(migr[:])
   171  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   172  		if pgErr, ok := err.(*pq.Error); ok {
   173  			var line uint
   174  			var col uint
   175  			var lineColOK bool
   176  			if pgErr.Position != "" {
   177  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   178  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   179  				}
   180  			}
   181  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   182  			if lineColOK {
   183  				message = fmt.Sprintf("%s (column %d)", message, col)
   184  			}
   185  			if pgErr.Detail != "" {
   186  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   187  			}
   188  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
   189  		}
   190  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   191  	}
   192  
   193  	return nil
   194  }
   195  
   196  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   197  	// replace crlf with lf
   198  	s = strings.Replace(s, "\r\n", "\n", -1)
   199  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   200  	runes := []rune(s)
   201  	if pos > len(runes) {
   202  		return 0, 0, false
   203  	}
   204  	sel := runes[:pos]
   205  	line = uint(runesCount(sel, newLine) + 1)
   206  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   207  	return line, col, true
   208  }
   209  
   210  const newLine = '\n'
   211  
   212  func runesCount(input []rune, target rune) int {
   213  	var count int
   214  	for _, r := range input {
   215  		if r == target {
   216  			count++
   217  		}
   218  	}
   219  	return count
   220  }
   221  
   222  func runesLastIndex(input []rune, target rune) int {
   223  	for i := len(input) - 1; i >= 0; i-- {
   224  		if input[i] == target {
   225  			return i
   226  		}
   227  	}
   228  	return -1
   229  }
   230  
   231  func (p *Postgres) SetVersion(version int, dirty bool) error {
   232  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   233  	if err != nil {
   234  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   235  	}
   236  
   237  	query := `TRUNCATE "` + p.config.MigrationsTable + `"`
   238  	if _, err := tx.Exec(query); err != nil {
   239  		tx.Rollback()
   240  		return &database.Error{OrigErr: err, Query: []byte(query)}
   241  	}
   242  
   243  	if version >= 0 {
   244  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
   245  		if _, err := tx.Exec(query, version, dirty); err != nil {
   246  			tx.Rollback()
   247  			return &database.Error{OrigErr: err, Query: []byte(query)}
   248  		}
   249  	}
   250  
   251  	if err := tx.Commit(); err != nil {
   252  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  func (p *Postgres) Version() (version int, dirty bool, err error) {
   259  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   260  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   261  	switch {
   262  	case err == sql.ErrNoRows:
   263  		return database.NilVersion, false, nil
   264  
   265  	case err != nil:
   266  		if e, ok := err.(*pq.Error); ok {
   267  			if e.Code.Name() == "undefined_table" {
   268  				return database.NilVersion, false, nil
   269  			}
   270  		}
   271  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   272  
   273  	default:
   274  		return version, dirty, nil
   275  	}
   276  }
   277  
   278  func (p *Postgres) Drop() error {
   279  	// select all tables in current schema
   280  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   281  	tables, err := p.conn.QueryContext(context.Background(), query)
   282  	if err != nil {
   283  		return &database.Error{OrigErr: err, Query: []byte(query)}
   284  	}
   285  	defer tables.Close()
   286  
   287  	// delete one table after another
   288  	tableNames := make([]string, 0)
   289  	for tables.Next() {
   290  		var tableName string
   291  		if err := tables.Scan(&tableName); err != nil {
   292  			return err
   293  		}
   294  		if len(tableName) > 0 {
   295  			tableNames = append(tableNames, tableName)
   296  		}
   297  	}
   298  
   299  	if len(tableNames) > 0 {
   300  		// delete one by one ...
   301  		for _, t := range tableNames {
   302  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   303  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   304  				return &database.Error{OrigErr: err, Query: []byte(query)}
   305  			}
   306  		}
   307  		if err := p.ensureVersionTable(); err != nil {
   308  			return err
   309  		}
   310  	}
   311  
   312  	return nil
   313  }
   314  
   315  func (p *Postgres) ensureVersionTable() error {
   316  	// check if migration table exists
   317  	var count int
   318  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   319  	if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
   320  		return &database.Error{OrigErr: err, Query: []byte(query)}
   321  	}
   322  	if count == 1 {
   323  		return nil
   324  	}
   325  
   326  	// if not, create the empty migration table
   327  	query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
   328  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   329  		return &database.Error{OrigErr: err, Query: []byte(query)}
   330  	}
   331  	return nil
   332  }