github.com/SmoothieNoIce/migrate@v3.5.4+incompatible/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/golang-migrate/migrate"
    16  	"github.com/golang-migrate/migrate/database"
    17  	"github.com/lib/pq"
    18  )
    19  
    20  func init() {
    21  	db := Postgres{}
    22  	database.Register("postgres", &db)
    23  	database.Register("postgresql", &db)
    24  }
    25  
    26  var DefaultMigrationsTable = "schema_migrations"
    27  
    28  var (
    29  	ErrNilConfig      = fmt.Errorf("no config")
    30  	ErrNoDatabaseName = fmt.Errorf("no database name")
    31  	ErrNoSchema       = fmt.Errorf("no schema")
    32  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    33  )
    34  
    35  type Config struct {
    36  	MigrationsTable string
    37  	DatabaseName    string
    38  }
    39  
    40  type Postgres struct {
    41  	// Locking and unlocking need to use the same connection
    42  	conn     *sql.Conn
    43  	db       *sql.DB
    44  	isLocked bool
    45  
    46  	// Open and WithInstance need to garantuee that config is never nil
    47  	config *Config
    48  }
    49  
    50  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    51  	if config == nil {
    52  		return nil, ErrNilConfig
    53  	}
    54  
    55  	if err := instance.Ping(); err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	query := `SELECT CURRENT_DATABASE()`
    60  	var databaseName string
    61  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    62  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    63  	}
    64  
    65  	if len(databaseName) == 0 {
    66  		return nil, ErrNoDatabaseName
    67  	}
    68  
    69  	config.DatabaseName = databaseName
    70  
    71  	if len(config.MigrationsTable) == 0 {
    72  		config.MigrationsTable = DefaultMigrationsTable
    73  	}
    74  
    75  	conn, err := instance.Conn(context.Background())
    76  
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	px := &Postgres{
    82  		conn:   conn,
    83  		db:     instance,
    84  		config: config,
    85  	}
    86  
    87  	if err := px.ensureVersionTable(); err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	return px, nil
    92  }
    93  
    94  func (p *Postgres) Open(url string) (database.Driver, error) {
    95  	purl, err := nurl.Parse(url)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	migrationsTable := purl.Query().Get("x-migrations-table")
   106  	if len(migrationsTable) == 0 {
   107  		migrationsTable = DefaultMigrationsTable
   108  	}
   109  
   110  	px, err := WithInstance(db, &Config{
   111  		DatabaseName:    purl.Path,
   112  		MigrationsTable: migrationsTable,
   113  	})
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return px, nil
   119  }
   120  
   121  func (p *Postgres) Close() error {
   122  	connErr := p.conn.Close()
   123  	dbErr := p.db.Close()
   124  	if connErr != nil || dbErr != nil {
   125  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   126  	}
   127  	return nil
   128  }
   129  
   130  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   131  func (p *Postgres) Lock() error {
   132  	if p.isLocked {
   133  		return database.ErrLocked
   134  	}
   135  
   136  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	// This will either obtain the lock immediately and return true,
   142  	// or return false if the lock cannot be acquired immediately.
   143  	query := `SELECT pg_advisory_lock($1)`
   144  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   145  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   146  	}
   147  
   148  	p.isLocked = true
   149  	return nil
   150  }
   151  
   152  func (p *Postgres) Unlock() error {
   153  	if !p.isLocked {
   154  		return nil
   155  	}
   156  
   157  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	query := `SELECT pg_advisory_unlock($1)`
   163  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   164  		return &database.Error{OrigErr: err, Query: []byte(query)}
   165  	}
   166  	p.isLocked = false
   167  	return nil
   168  }
   169  
   170  func (p *Postgres) Run(migration io.Reader) error {
   171  	migr, err := ioutil.ReadAll(migration)
   172  	if err != nil {
   173  		return err
   174  	}
   175  
   176  	// run migration
   177  	query := string(migr[:])
   178  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   179  		if pgErr, ok := err.(*pq.Error); ok {
   180  			var line uint
   181  			var col uint
   182  			var lineColOK bool
   183  			if pgErr.Position != "" {
   184  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   185  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   186  				}
   187  			}
   188  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   189  			if lineColOK {
   190  				message = fmt.Sprintf("%s (column %d)", message, col)
   191  			}
   192  			if pgErr.Detail != "" {
   193  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   194  			}
   195  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
   196  		}
   197  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   204  	// replace crlf with lf
   205  	s = strings.Replace(s, "\r\n", "\n", -1)
   206  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   207  	runes := []rune(s)
   208  	if pos > len(runes) {
   209  		return 0, 0, false
   210  	}
   211  	sel := runes[:pos]
   212  	line = uint(runesCount(sel, newLine) + 1)
   213  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   214  	return line, col, true
   215  }
   216  
   217  const newLine = '\n'
   218  
   219  func runesCount(input []rune, target rune) int {
   220  	var count int
   221  	for _, r := range input {
   222  		if r == target {
   223  			count++
   224  		}
   225  	}
   226  	return count
   227  }
   228  
   229  func runesLastIndex(input []rune, target rune) int {
   230  	for i := len(input) - 1; i >= 0; i-- {
   231  		if input[i] == target {
   232  			return i
   233  		}
   234  	}
   235  	return -1
   236  }
   237  
   238  func (p *Postgres) SetVersion(version int, dirty bool) error {
   239  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   240  	if err != nil {
   241  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   242  	}
   243  
   244  	query := `TRUNCATE "` + p.config.MigrationsTable + `"`
   245  	if _, err := tx.Exec(query); err != nil {
   246  		tx.Rollback()
   247  		return &database.Error{OrigErr: err, Query: []byte(query)}
   248  	}
   249  
   250  	if version >= 0 {
   251  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
   252  		if _, err := tx.Exec(query, version, dirty); err != nil {
   253  			tx.Rollback()
   254  			return &database.Error{OrigErr: err, Query: []byte(query)}
   255  		}
   256  	}
   257  
   258  	if err := tx.Commit(); err != nil {
   259  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   260  	}
   261  
   262  	return nil
   263  }
   264  
   265  func (p *Postgres) Version() (version int, dirty bool, err error) {
   266  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   267  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   268  	switch {
   269  	case err == sql.ErrNoRows:
   270  		return database.NilVersion, false, nil
   271  
   272  	case err != nil:
   273  		if e, ok := err.(*pq.Error); ok {
   274  			if e.Code.Name() == "undefined_table" {
   275  				return database.NilVersion, false, nil
   276  			}
   277  		}
   278  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   279  
   280  	default:
   281  		return version, dirty, nil
   282  	}
   283  }
   284  
   285  func (p *Postgres) Drop() error {
   286  	// select all tables in current schema
   287  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   288  	tables, err := p.conn.QueryContext(context.Background(), query)
   289  	if err != nil {
   290  		return &database.Error{OrigErr: err, Query: []byte(query)}
   291  	}
   292  	defer tables.Close()
   293  
   294  	// delete one table after another
   295  	tableNames := make([]string, 0)
   296  	for tables.Next() {
   297  		var tableName string
   298  		if err := tables.Scan(&tableName); err != nil {
   299  			return err
   300  		}
   301  		if len(tableName) > 0 {
   302  			tableNames = append(tableNames, tableName)
   303  		}
   304  	}
   305  
   306  	if len(tableNames) > 0 {
   307  		// delete one by one ...
   308  		for _, t := range tableNames {
   309  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   310  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   311  				return &database.Error{OrigErr: err, Query: []byte(query)}
   312  			}
   313  		}
   314  		if err := p.ensureVersionTable(); err != nil {
   315  			return err
   316  		}
   317  	}
   318  
   319  	return nil
   320  }
   321  
   322  func (p *Postgres) ensureVersionTable() error {
   323  	// check if migration table exists
   324  	var count int
   325  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   326  	if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
   327  		return &database.Error{OrigErr: err, Query: []byte(query)}
   328  	}
   329  	if count == 1 {
   330  		return nil
   331  	}
   332  
   333  	// if not, create the empty migration table
   334  	query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
   335  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   336  		return &database.Error{OrigErr: err, Query: []byte(query)}
   337  	}
   338  	return nil
   339  }