github.com/qdentity/migrate@v3.3.0+incompatible/database/postgres/postgres.go (about)

     1  // +build go1.9
     2  
     3  package postgres
     4  
     5  import (
     6  	"database/sql"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	nurl "net/url"
    11  
    12  	"context"
    13  	"github.com/golang-migrate/migrate"
    14  	"github.com/golang-migrate/migrate/database"
    15  	"github.com/lib/pq"
    16  )
    17  
    18  func init() {
    19  	db := Postgres{}
    20  	database.Register("postgres", &db)
    21  	database.Register("postgresql", &db)
    22  }
    23  
    24  var DefaultMigrationsTable = "schema_migrations"
    25  
    26  var (
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  	ErrNoDatabaseName = fmt.Errorf("no database name")
    29  	ErrNoSchema       = fmt.Errorf("no schema")
    30  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    31  )
    32  
    33  type Config struct {
    34  	MigrationsTable string
    35  	DatabaseName    string
    36  }
    37  
    38  type Postgres struct {
    39  	// Locking and unlocking need to use the same connection
    40  	conn     *sql.Conn
    41  	isLocked bool
    42  
    43  	// Open and WithInstance need to garantuee that config is never nil
    44  	config *Config
    45  }
    46  
    47  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  
    52  	if err := instance.Ping(); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	query := `SELECT CURRENT_DATABASE()`
    57  	var databaseName string
    58  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    59  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    60  	}
    61  
    62  	if len(databaseName) == 0 {
    63  		return nil, ErrNoDatabaseName
    64  	}
    65  
    66  	config.DatabaseName = databaseName
    67  
    68  	if len(config.MigrationsTable) == 0 {
    69  		config.MigrationsTable = DefaultMigrationsTable
    70  	}
    71  
    72  	conn, err := instance.Conn(context.Background())
    73  
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	px := &Postgres{
    79  		conn:   conn,
    80  		config: config,
    81  	}
    82  
    83  	if err := px.ensureVersionTable(); err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	return px, nil
    88  }
    89  
    90  func (p *Postgres) Open(url string) (database.Driver, error) {
    91  	purl, err := nurl.Parse(url)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	migrationsTable := purl.Query().Get("x-migrations-table")
   102  	if len(migrationsTable) == 0 {
   103  		migrationsTable = DefaultMigrationsTable
   104  	}
   105  
   106  	px, err := WithInstance(db, &Config{
   107  		DatabaseName:    purl.Path,
   108  		MigrationsTable: migrationsTable,
   109  	})
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	return px, nil
   115  }
   116  
   117  func (p *Postgres) Close() error {
   118  	return p.conn.Close()
   119  }
   120  
   121  // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
   122  func (p *Postgres) Lock() error {
   123  	if p.isLocked {
   124  		return database.ErrLocked
   125  	}
   126  
   127  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   128  	if err != nil {
   129  		return err
   130  	}
   131  
   132  	// This will either obtain the lock immediately and return true,
   133  	// or return false if the lock cannot be acquired immediately.
   134  	query := `SELECT pg_advisory_lock($1)`
   135  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   136  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   137  	}
   138  
   139  	p.isLocked = true
   140  	return nil
   141  }
   142  
   143  func (p *Postgres) Unlock() error {
   144  	if !p.isLocked {
   145  		return nil
   146  	}
   147  
   148  	aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	query := `SELECT pg_advisory_unlock($1)`
   154  	if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
   155  		return &database.Error{OrigErr: err, Query: []byte(query)}
   156  	}
   157  	p.isLocked = false
   158  	return nil
   159  }
   160  
   161  func (p *Postgres) Run(migration io.Reader) error {
   162  	migr, err := ioutil.ReadAll(migration)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	// run migration
   168  	query := string(migr[:])
   169  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   170  		// TODO: cast to postgress error and get line number
   171  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  func (p *Postgres) SetVersion(version int, dirty bool) error {
   178  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   179  	if err != nil {
   180  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   181  	}
   182  
   183  	query := `TRUNCATE "` + p.config.MigrationsTable + `"`
   184  	if _, err := tx.Exec(query); err != nil {
   185  		tx.Rollback()
   186  		return &database.Error{OrigErr: err, Query: []byte(query)}
   187  	}
   188  
   189  	if version >= 0 {
   190  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
   191  		if _, err := tx.Exec(query, version, dirty); err != nil {
   192  			tx.Rollback()
   193  			return &database.Error{OrigErr: err, Query: []byte(query)}
   194  		}
   195  	}
   196  
   197  	if err := tx.Commit(); err != nil {
   198  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   199  	}
   200  
   201  	return nil
   202  }
   203  
   204  func (p *Postgres) Version() (version int, dirty bool, err error) {
   205  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   206  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   207  	switch {
   208  	case err == sql.ErrNoRows:
   209  		return database.NilVersion, false, nil
   210  
   211  	case err != nil:
   212  		if e, ok := err.(*pq.Error); ok {
   213  			if e.Code.Name() == "undefined_table" {
   214  				return database.NilVersion, false, nil
   215  			}
   216  		}
   217  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   218  
   219  	default:
   220  		return version, dirty, nil
   221  	}
   222  }
   223  
   224  func (p *Postgres) Drop() error {
   225  	// select all tables in current schema
   226  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   227  	tables, err := p.conn.QueryContext(context.Background(), query)
   228  	if err != nil {
   229  		return &database.Error{OrigErr: err, Query: []byte(query)}
   230  	}
   231  	defer tables.Close()
   232  
   233  	// delete one table after another
   234  	tableNames := make([]string, 0)
   235  	for tables.Next() {
   236  		var tableName string
   237  		if err := tables.Scan(&tableName); err != nil {
   238  			return err
   239  		}
   240  		if len(tableName) > 0 {
   241  			tableNames = append(tableNames, tableName)
   242  		}
   243  	}
   244  
   245  	if len(tableNames) > 0 {
   246  		// delete one by one ...
   247  		for _, t := range tableNames {
   248  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   249  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   250  				return &database.Error{OrigErr: err, Query: []byte(query)}
   251  			}
   252  		}
   253  		if err := p.ensureVersionTable(); err != nil {
   254  			return err
   255  		}
   256  	}
   257  
   258  	return nil
   259  }
   260  
   261  func (p *Postgres) ensureVersionTable() error {
   262  	// check if migration table exists
   263  	var count int
   264  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   265  	if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
   266  		return &database.Error{OrigErr: err, Query: []byte(query)}
   267  	}
   268  	if count == 1 {
   269  		return nil
   270  	}
   271  
   272  	// if not, create the empty migration table
   273  	query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
   274  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   275  		return &database.Error{OrigErr: err, Query: []byte(query)}
   276  	}
   277  	return nil
   278  }