github.com/Elate-DevOps/migrate/v4@v4.0.12/database/redshift/redshift.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package redshift
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"io"
    11  	nurl "net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"go.uber.org/atomic"
    16  
    17  	"github.com/Elate-DevOps/migrate/v4"
    18  	"github.com/Elate-DevOps/migrate/v4/database"
    19  	"github.com/hashicorp/go-multierror"
    20  	"github.com/lib/pq"
    21  )
    22  
    23  func init() {
    24  	db := Redshift{}
    25  	database.Register("redshift", &db)
    26  }
    27  
    28  var DefaultMigrationsTable = "schema_migrations"
    29  
    30  var (
    31  	ErrNilConfig      = fmt.Errorf("no config")
    32  	ErrNoDatabaseName = fmt.Errorf("no database name")
    33  )
    34  
    35  type Config struct {
    36  	MigrationsTable string
    37  	DatabaseName    string
    38  }
    39  
    40  type Redshift struct {
    41  	isLocked atomic.Bool
    42  	conn     *sql.Conn
    43  	db       *sql.DB
    44  
    45  	// Open and WithInstance need to guarantee that config is never nil
    46  	config *Config
    47  }
    48  
    49  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    50  	if config == nil {
    51  		return nil, ErrNilConfig
    52  	}
    53  
    54  	if err := instance.Ping(); err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	if config.DatabaseName == "" {
    59  		query := `SELECT CURRENT_DATABASE()`
    60  		var databaseName string
    61  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    62  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    63  		}
    64  
    65  		if len(databaseName) == 0 {
    66  			return nil, ErrNoDatabaseName
    67  		}
    68  
    69  		config.DatabaseName = databaseName
    70  	}
    71  
    72  	if len(config.MigrationsTable) == 0 {
    73  		config.MigrationsTable = DefaultMigrationsTable
    74  	}
    75  
    76  	conn, err := instance.Conn(context.Background())
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	px := &Redshift{
    82  		conn:   conn,
    83  		db:     instance,
    84  		config: config,
    85  	}
    86  
    87  	if err := px.ensureVersionTable(); err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	return px, nil
    92  }
    93  
    94  func (p *Redshift) Open(url string) (database.Driver, error) {
    95  	purl, err := nurl.Parse(url)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	purl.Scheme = "postgres"
   100  
   101  	db, err := sql.Open("postgres", migrate.FilterCustomQuery(purl).String())
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	migrationsTable := purl.Query().Get("x-migrations-table")
   107  
   108  	px, err := WithInstance(db, &Config{
   109  		DatabaseName:    purl.Path,
   110  		MigrationsTable: migrationsTable,
   111  	})
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	return px, nil
   117  }
   118  
   119  func (p *Redshift) Close() error {
   120  	connErr := p.conn.Close()
   121  	dbErr := p.db.Close()
   122  	if connErr != nil || dbErr != nil {
   123  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   124  	}
   125  	return nil
   126  }
   127  
   128  // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html
   129  func (p *Redshift) Lock() error {
   130  	if !p.isLocked.CAS(false, true) {
   131  		return database.ErrLocked
   132  	}
   133  	return nil
   134  }
   135  
   136  func (p *Redshift) Unlock() error {
   137  	if !p.isLocked.CAS(true, false) {
   138  		return database.ErrNotLocked
   139  	}
   140  	return nil
   141  }
   142  
   143  func (p *Redshift) Run(migration io.Reader) error {
   144  	migr, err := io.ReadAll(migration)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	// run migration
   150  	query := string(migr[:])
   151  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   152  		if pgErr, ok := err.(*pq.Error); ok {
   153  			var line uint
   154  			var col uint
   155  			var lineColOK bool
   156  			if pgErr.Position != "" {
   157  				if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
   158  					line, col, lineColOK = computeLineFromPos(query, int(pos))
   159  				}
   160  			}
   161  			message := fmt.Sprintf("migration failed: %s", pgErr.Message)
   162  			if lineColOK {
   163  				message = fmt.Sprintf("%s (column %d)", message, col)
   164  			}
   165  			if pgErr.Detail != "" {
   166  				message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
   167  			}
   168  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
   169  		}
   170  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
   177  	// replace crlf with lf
   178  	s = strings.Replace(s, "\r\n", "\n", -1)
   179  	// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
   180  	runes := []rune(s)
   181  	if pos > len(runes) {
   182  		return 0, 0, false
   183  	}
   184  	sel := runes[:pos]
   185  	line = uint(runesCount(sel, newLine) + 1)
   186  	col = uint(pos - 1 - runesLastIndex(sel, newLine))
   187  	return line, col, true
   188  }
   189  
   190  const newLine = '\n'
   191  
   192  func runesCount(input []rune, target rune) int {
   193  	var count int
   194  	for _, r := range input {
   195  		if r == target {
   196  			count++
   197  		}
   198  	}
   199  	return count
   200  }
   201  
   202  func runesLastIndex(input []rune, target rune) int {
   203  	for i := len(input) - 1; i >= 0; i-- {
   204  		if input[i] == target {
   205  			return i
   206  		}
   207  	}
   208  	return -1
   209  }
   210  
   211  func (p *Redshift) SetVersion(version int, dirty bool) error {
   212  	tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
   213  	if err != nil {
   214  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   215  	}
   216  
   217  	query := `DELETE FROM "` + p.config.MigrationsTable + `"`
   218  	if _, err := tx.Exec(query); err != nil {
   219  		if errRollback := tx.Rollback(); errRollback != nil {
   220  			err = multierror.Append(err, errRollback)
   221  		}
   222  		return &database.Error{OrigErr: err, Query: []byte(query)}
   223  	}
   224  
   225  	// Also re-write the schema version for nil dirty versions to prevent
   226  	// empty schema version for failed down migration on the first migration
   227  	// See: https://github.com/golang-migrate/migrate/issues/330
   228  	if version >= 0 || (version == database.NilVersion && dirty) {
   229  		query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
   230  		if _, err := tx.Exec(query, version, dirty); err != nil {
   231  			if errRollback := tx.Rollback(); errRollback != nil {
   232  				err = multierror.Append(err, errRollback)
   233  			}
   234  			return &database.Error{OrigErr: err, Query: []byte(query)}
   235  		}
   236  	}
   237  
   238  	if err := tx.Commit(); err != nil {
   239  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func (p *Redshift) Version() (version int, dirty bool, err error) {
   246  	query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
   247  	err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   248  	switch {
   249  	case err == sql.ErrNoRows:
   250  		return database.NilVersion, false, nil
   251  
   252  	case err != nil:
   253  		if e, ok := err.(*pq.Error); ok {
   254  			if e.Code.Name() == "undefined_table" {
   255  				return database.NilVersion, false, nil
   256  			}
   257  		}
   258  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   259  
   260  	default:
   261  		return version, dirty, nil
   262  	}
   263  }
   264  
   265  func (p *Redshift) Drop() (err error) {
   266  	// select all tables in current schema
   267  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   268  	tables, err := p.conn.QueryContext(context.Background(), query)
   269  	if err != nil {
   270  		return &database.Error{OrigErr: err, Query: []byte(query)}
   271  	}
   272  	defer func() {
   273  		if errClose := tables.Close(); errClose != nil {
   274  			err = multierror.Append(err, errClose)
   275  		}
   276  	}()
   277  
   278  	// delete one table after another
   279  	tableNames := make([]string, 0)
   280  	for tables.Next() {
   281  		var tableName string
   282  		if err := tables.Scan(&tableName); err != nil {
   283  			return err
   284  		}
   285  		if len(tableName) > 0 {
   286  			tableNames = append(tableNames, tableName)
   287  		}
   288  	}
   289  	if err := tables.Err(); err != nil {
   290  		return &database.Error{OrigErr: err, Query: []byte(query)}
   291  	}
   292  
   293  	if len(tableNames) > 0 {
   294  		// delete one by one ...
   295  		for _, t := range tableNames {
   296  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   297  			if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   298  				return &database.Error{OrigErr: err, Query: []byte(query)}
   299  			}
   300  		}
   301  	}
   302  
   303  	return nil
   304  }
   305  
   306  // ensureVersionTable checks if versions table exists and, if not, creates it.
   307  // Note that this function locks the database, which deviates from the usual
   308  // convention of "caller locks" in the Redshift type.
   309  func (p *Redshift) ensureVersionTable() (err error) {
   310  	if err = p.Lock(); err != nil {
   311  		return err
   312  	}
   313  
   314  	defer func() {
   315  		if e := p.Unlock(); e != nil {
   316  			if err == nil {
   317  				err = e
   318  			} else {
   319  				err = multierror.Append(err, e)
   320  			}
   321  		}
   322  	}()
   323  
   324  	// check if migration table exists
   325  	var count int
   326  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   327  	if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
   328  		return &database.Error{OrigErr: err, Query: []byte(query)}
   329  	}
   330  	if count == 1 {
   331  		return nil
   332  	}
   333  
   334  	// if not, create the empty migration table
   335  	query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)`
   336  	if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
   337  		return &database.Error{OrigErr: err, Query: []byte(query)}
   338  	}
   339  	return nil
   340  }