github.com/brandonmartin/migrate/v4@v4.14.2/database/oracle/oracle.go (about)

     1  package oracle
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strings"
    13  
    14  	"github.com/godror/godror"
    15  
    16  	_ "github.com/godror/godror"
    17  	"github.com/golang-migrate/migrate/v4"
    18  	"github.com/golang-migrate/migrate/v4/database"
    19  	multierror "github.com/hashicorp/go-multierror"
    20  )
    21  
    22  func init() {
    23  	db := Oracle{}
    24  	database.Register("oracle", &db)
    25  }
    26  
    27  const (
    28  	defaultMigrationsTable         = "SCHEMA_MIGRATIONS"
    29  	defaultStatementSeparator      = ";"
    30  	plsqlDefaultStatementSeparator = "---"
    31  	plsqlStatementEndToken         = "END;"
    32  )
    33  
    34  var (
    35  	ErrNilConfig      = fmt.Errorf("no config")
    36  	ErrNoDatabaseName = fmt.Errorf("no database name")
    37  )
    38  
// Config holds the options of the Oracle driver. WithInstance fills in
// defaults for empty fields and records the connected database's name, so
// the Config passed to it is modified in place.
type Config struct {
	// MigrationsTable is the table holding the current version and dirty
	// flag. WithInstance sets it to "SCHEMA_MIGRATIONS" when empty.
	MigrationsTable string
	// DisableMultiStatements, when true, makes parseStatements return the
	// whole migration as a single statement instead of splitting it.
	DisableMultiStatements bool
	// PLSQLStatementSeparator is the exact line that separates statements in
	// migrations containing "END;". WithInstance sets it to "---" when empty.
	PLSQLStatementSeparator string

	// databaseName is overwritten by WithInstance with the value of
	// SYS_CONTEXT('USERENV','DB_NAME').
	databaseName string
}
    46  
// Oracle is the database.Driver implementation returned by WithInstance and
// Open.
type Oracle struct {
	// conn is one dedicated connection taken from db in WithInstance; the
	// driver's methods run their statements on it because locking and
	// unlocking need to use the same connection.
	conn *sql.Conn
	// db is the pool conn was taken from; Close closes both.
	db *sql.DB
	// isLocked records whether Lock succeeded without a later Unlock.
	isLocked bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    56  
    57  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    58  	if config == nil {
    59  		return nil, ErrNilConfig
    60  	}
    61  
    62  	if err := instance.Ping(); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	query := `SELECT SYS_CONTEXT('USERENV','DB_NAME') FROM DUAL`
    67  	var dbName string
    68  	if err := instance.QueryRow(query).Scan(&dbName); err != nil {
    69  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    70  	}
    71  
    72  	if dbName == "" {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	config.databaseName = dbName
    77  
    78  	if config.MigrationsTable == "" {
    79  		config.MigrationsTable = defaultMigrationsTable
    80  	}
    81  
    82  	if config.PLSQLStatementSeparator == "" {
    83  		config.PLSQLStatementSeparator = plsqlDefaultStatementSeparator
    84  	}
    85  
    86  	conn, err := instance.Conn(context.Background())
    87  
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	ora := &Oracle{
    93  		conn:   conn,
    94  		db:     instance,
    95  		config: config,
    96  	}
    97  
    98  	if err := ora.ensureVersionTable(); err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	return ora, nil
   103  }
   104  
   105  func (ora *Oracle) Open(url string) (database.Driver, error) {
   106  	purl, err := nurl.Parse(url)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	db, err := sql.Open("godror", migrate.FilterCustomQuery(purl).String())
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	migrationsTable := strings.ToUpper(purl.Query().Get("x-migrations-table"))
   116  	statementSeparator := purl.Query().Get("x-statement-separator")
   117  	disableMultiStatement := false
   118  	if purl.Query().Get("x-disable-multi-statements") == "true" {
   119  		disableMultiStatement = true
   120  	}
   121  
   122  	oraInst, err := WithInstance(db, &Config{
   123  		databaseName:            purl.Path,
   124  		MigrationsTable:         migrationsTable,
   125  		DisableMultiStatements:  disableMultiStatement,
   126  		PLSQLStatementSeparator: statementSeparator,
   127  	})
   128  
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	return oraInst, nil
   134  }
   135  
   136  func (ora *Oracle) Close() error {
   137  	connErr := ora.conn.Close()
   138  	dbErr := ora.db.Close()
   139  	if connErr != nil || dbErr != nil {
   140  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   141  	}
   142  	return nil
   143  }
   144  
   145  func (ora *Oracle) Lock() error {
   146  	if ora.isLocked {
   147  		return database.ErrLocked
   148  	}
   149  
   150  	// https://docs.oracle.com/cd/B28359_01/appdev.111/b28419/d_lock.htm#ARPLS021
   151  	query := `
   152  declare
   153      v_lockhandle varchar2(200);
   154      v_result     number;
   155  begin
   156  
   157      dbms_lock.allocate_unique('control_lock', v_lockhandle);
   158  
   159      v_result := dbms_lock.request(v_lockhandle, dbms_lock.x_mode);
   160  
   161      if v_result <> 0 then
   162          dbms_output.put_line(
   163                  case
   164                      when v_result=1 then 'Timeout'
   165                      when v_result=2 then 'Deadlock'
   166                      when v_result=3 then 'Parameter Error'
   167                      when v_result=4 then 'Already owned'
   168                      when v_result=5 then 'Illegal Lock Handle'
   169                      end);
   170      end if;
   171  
   172  end;
   173  `
   174  	if _, err := ora.conn.ExecContext(context.Background(), query); err != nil {
   175  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   176  	}
   177  
   178  	ora.isLocked = true
   179  	return nil
   180  }
   181  
   182  func (ora *Oracle) Unlock() error {
   183  	if !ora.isLocked {
   184  		return nil
   185  	}
   186  
   187  	query := `
   188  declare
   189    v_lockhandle varchar2(200);
   190    v_result     number;
   191  begin
   192  
   193    dbms_lock.allocate_unique('control_lock', v_lockhandle);
   194  
   195    v_result := dbms_lock.release(v_lockhandle);
   196  
   197    if v_result <> 0 then 
   198      dbms_output.put_line(
   199             case 
   200                when v_result=1 then 'Timeout'
   201                when v_result=2 then 'Deadlock'
   202                when v_result=3 then 'Parameter Error'
   203                when v_result=4 then 'Already owned'
   204                when v_result=5 then 'Illegal Lock Handle'
   205              end);
   206    end if;
   207  
   208  end;
   209  `
   210  	if _, err := ora.conn.ExecContext(context.Background(), query); err != nil {
   211  		return &database.Error{OrigErr: err, Query: []byte(query)}
   212  	}
   213  	ora.isLocked = false
   214  	return nil
   215  }
   216  
   217  func (ora *Oracle) Run(migration io.Reader) error {
   218  	queries, err := parseStatements(migration, ora.config)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	for _, query := range queries {
   223  		if _, err := ora.conn.ExecContext(context.Background(), query); err != nil {
   224  			if oraErr, ok := godror.AsOraErr(err); ok {
   225  				return database.Error{OrigErr: oraErr, Err: oraErr.Message(), Query: []byte(query)}
   226  			}
   227  			return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(query)}
   228  		}
   229  	}
   230  
   231  	return nil
   232  }
   233  
   234  func (ora *Oracle) SetVersion(version int, dirty bool) error {
   235  	tx, err := ora.conn.BeginTx(context.Background(), &sql.TxOptions{})
   236  	if err != nil {
   237  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   238  	}
   239  
   240  	query := "TRUNCATE TABLE " + ora.config.MigrationsTable
   241  	if _, err := tx.Exec(query); err != nil {
   242  		if errRollback := tx.Rollback(); errRollback != nil {
   243  			err = multierror.Append(err, errRollback)
   244  		}
   245  		return &database.Error{OrigErr: err, Query: []byte(query)}
   246  	}
   247  
   248  	if version >= 0 {
   249  		query = `INSERT INTO ` + ora.config.MigrationsTable + ` (VERSION, DIRTY) VALUES (:1, :2)`
   250  		if _, err := tx.Exec(query, version, b2i(dirty)); err != nil {
   251  			if errRollback := tx.Rollback(); errRollback != nil {
   252  				err = multierror.Append(err, errRollback)
   253  			}
   254  			return &database.Error{OrigErr: err, Query: []byte(query)}
   255  		}
   256  	}
   257  
   258  	if err := tx.Commit(); err != nil {
   259  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   260  	}
   261  
   262  	return nil
   263  }
   264  
   265  func (ora *Oracle) Version() (version int, dirty bool, err error) {
   266  	query := "SELECT VERSION, DIRTY FROM " + ora.config.MigrationsTable + " WHERE ROWNUM = 1 ORDER BY VERSION desc"
   267  	err = ora.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   268  	switch {
   269  	case err == sql.ErrNoRows:
   270  		return database.NilVersion, false, nil
   271  
   272  	case err != nil:
   273  		if _, ok := godror.AsOraErr(err); ok {
   274  			return database.NilVersion, false, nil
   275  		}
   276  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   277  
   278  	default:
   279  		return version, dirty, nil
   280  	}
   281  }
   282  
   283  func (ora *Oracle) Drop() (err error) {
   284  	// select all tables in current schema
   285  	query := fmt.Sprintf(`SELECT TABLE_NAME FROM USER_TABLES`)
   286  	tables, err := ora.conn.QueryContext(context.Background(), query)
   287  	if err != nil {
   288  		return &database.Error{OrigErr: err, Query: []byte(query)}
   289  	}
   290  	defer func() {
   291  		if errClose := tables.Close(); errClose != nil {
   292  			err = multierror.Append(err, errClose)
   293  		}
   294  	}()
   295  
   296  	// delete one table after another
   297  	tableNames := make([]string, 0)
   298  	for tables.Next() {
   299  		var tableName string
   300  		if err := tables.Scan(&tableName); err != nil {
   301  			return err
   302  		}
   303  		if len(tableName) > 0 {
   304  			tableNames = append(tableNames, tableName)
   305  		}
   306  	}
   307  
   308  	query = `
   309  BEGIN
   310     EXECUTE IMMEDIATE 'DROP TABLE %s';
   311  EXCEPTION
   312     WHEN OTHERS THEN
   313        IF SQLCODE != -942 THEN
   314           RAISE;
   315        END IF;
   316  END;
   317  `
   318  	if len(tableNames) > 0 {
   319  		// delete one by one ...
   320  		for _, t := range tableNames {
   321  			if _, err := ora.conn.ExecContext(context.Background(), fmt.Sprintf(query, t)); err != nil {
   322  				return &database.Error{OrigErr: err, Query: []byte(query)}
   323  			}
   324  		}
   325  	}
   326  
   327  	return nil
   328  }
   329  
   330  // ensureVersionTable checks if versions table exists and, if not, creates it.
   331  // Note that this function locks the database, which deviates from the usual
   332  // convention of "caller locks" in the Postgres type.
   333  func (ora *Oracle) ensureVersionTable() (err error) {
   334  	if err = ora.Lock(); err != nil {
   335  		return err
   336  	}
   337  
   338  	defer func() {
   339  		if e := ora.Unlock(); e != nil {
   340  			if err == nil {
   341  				err = e
   342  			} else {
   343  				err = multierror.Append(err, e)
   344  			}
   345  		}
   346  	}()
   347  
   348  	query := `
   349  declare
   350  v_sql LONG;
   351  begin
   352  
   353  v_sql:='create table %s
   354    (
   355    VERSION NUMBER(20) NOT NULL PRIMARY KEY,
   356    DIRTY NUMBER(1) NOT NULL
   357    )';
   358  execute immediate v_sql;
   359  
   360  EXCEPTION
   361      WHEN OTHERS THEN
   362        IF SQLCODE = -955 THEN
   363          NULL; -- suppresses ORA-00955 exception
   364        ELSE
   365           RAISE;
   366        END IF;
   367  END;
   368  `
   369  	if _, err = ora.conn.ExecContext(context.Background(), fmt.Sprintf(query, ora.config.MigrationsTable)); err != nil {
   370  		return &database.Error{OrigErr: err, Query: []byte(query)}
   371  	}
   372  
   373  	return nil
   374  }
   375  
   376  func parseStatements(rd io.Reader, c *Config) ([]string, error) {
   377  	migr, err := ioutil.ReadAll(rd)
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	// If multi-statements has been disable explicitly,
   383  	// i.e, there is no multi-statement enabled(neither normal multi-statements nor multi-PL/SQL-statements),
   384  	// return the whole migration as a blob.
   385  	if c.DisableMultiStatements {
   386  		return []string{string(migr)}, nil
   387  	}
   388  
   389  	// Either normal multi-statements or multi-PL/SQL-statements has been enabled.
   390  	plsqlEnabled := false
   391  	if strings.Contains(string(migr), plsqlStatementEndToken) {
   392  		plsqlEnabled = true
   393  	}
   394  	var queries []string
   395  	var buf bytes.Buffer
   396  	scanner := bufio.NewScanner(bytes.NewBuffer(migr))
   397  	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
   398  	for scanner.Scan() {
   399  		line := scanner.Text()
   400  		if plsqlEnabled && line == c.PLSQLStatementSeparator {
   401  			query := buf.String()
   402  			if query != "" {
   403  				queries = append(queries, query)
   404  			}
   405  			buf.Reset()
   406  		}
   407  		// ignore comment
   408  		if strings.HasPrefix(line, "--") {
   409  			continue
   410  		}
   411  		if _, err := buf.WriteString(line + "\n"); err != nil {
   412  			return nil, err
   413  		}
   414  	}
   415  	if plsqlEnabled {
   416  		query := buf.String()
   417  		if query != "" {
   418  			queries = append(queries, query)
   419  		}
   420  	} else {
   421  		queries = strings.Split(buf.String(), defaultStatementSeparator)
   422  	}
   423  
   424  	results := make([]string, 0)
   425  	sLen := len(plsqlStatementEndToken)
   426  	for _, query := range queries {
   427  		query = strings.TrimSpace(query)
   428  		query = strings.TrimPrefix(query, "\n")
   429  		query = strings.TrimSuffix(query, "\n")
   430  		if len(query) > sLen && strings.ToUpper(query[len(query)-sLen:]) != plsqlStatementEndToken {
   431  			query = strings.TrimSuffix(query, ";")
   432  		}
   433  		if query == "" {
   434  			continue
   435  		}
   436  		results = append(results, query)
   437  	}
   438  	return results, nil
   439  }
   440  
   441  func b2i(b bool) int {
   442  	if b {
   443  		return 1
   444  	}
   445  	return 0
   446  }