github.com/nagyist/migrate/v4@v4.14.6/database/sqlserver/sqlserver.go (about)

     1  package sqlserver
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	nurl "net/url"
    10  
    11  	mssql "github.com/denisenkom/go-mssqldb" // mssql support
    12  	"github.com/golang-migrate/migrate/v4"
    13  	"github.com/golang-migrate/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  )
    16  
    17  func init() {
    18  	database.Register("sqlserver", &SQLServer{})
    19  }
    20  
    21  // DefaultMigrationsTable is the name of the migrations table in the database
    22  var DefaultMigrationsTable = "schema_migrations"
    23  
    24  var (
    25  	ErrNilConfig      = fmt.Errorf("no config")
    26  	ErrNoDatabaseName = fmt.Errorf("no database name")
    27  	ErrNoSchema       = fmt.Errorf("no schema")
    28  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    29  )
    30  
    31  var lockErrorMap = map[mssql.ReturnStatus]string{
    32  	-1:   "The lock request timed out.",
    33  	-2:   "The lock request was canceled.",
    34  	-3:   "The lock request was chosen as a deadlock victim.",
    35  	-999: "Parameter validation or other call error.",
    36  }
    37  
    38  // Config for database
    39  type Config struct {
    40  	MigrationsTable string
    41  	DatabaseName    string
    42  	SchemaName      string
    43  }
    44  
    45  // SQL Server connection
    46  type SQLServer struct {
    47  	// Locking and unlocking need to use the same connection
    48  	conn     *sql.Conn
    49  	db       *sql.DB
    50  	isLocked bool
    51  
    52  	// Open and WithInstance need to garantuee that config is never nil
    53  	config *Config
    54  }
    55  
    56  // WithInstance returns a database instance from an already created database connection.
    57  //
    58  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    59  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    60  	if config == nil {
    61  		return nil, ErrNilConfig
    62  	}
    63  
    64  	if err := instance.Ping(); err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	if config.DatabaseName == "" {
    69  		query := `SELECT DB_NAME()`
    70  		var databaseName string
    71  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    72  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    73  		}
    74  
    75  		if len(databaseName) == 0 {
    76  			return nil, ErrNoDatabaseName
    77  		}
    78  
    79  		config.DatabaseName = databaseName
    80  	}
    81  
    82  	if config.SchemaName == "" {
    83  		query := `SELECT SCHEMA_NAME()`
    84  		var schemaName string
    85  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    86  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    87  		}
    88  
    89  		if len(schemaName) == 0 {
    90  			return nil, ErrNoSchema
    91  		}
    92  
    93  		config.SchemaName = schemaName
    94  	}
    95  
    96  	if len(config.MigrationsTable) == 0 {
    97  		config.MigrationsTable = DefaultMigrationsTable
    98  	}
    99  
   100  	conn, err := instance.Conn(context.Background())
   101  
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	ss := &SQLServer{
   107  		conn:   conn,
   108  		db:     instance,
   109  		config: config,
   110  	}
   111  
   112  	if err := ss.ensureVersionTable(); err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	return ss, nil
   117  }
   118  
   119  // Open a connection to the database
   120  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   121  	purl, err := nurl.Parse(url)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String())
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	migrationsTable := purl.Query().Get("x-migrations-table")
   132  
   133  	px, err := WithInstance(db, &Config{
   134  		DatabaseName:    purl.Path,
   135  		MigrationsTable: migrationsTable,
   136  	})
   137  
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	return px, nil
   143  }
   144  
   145  // Close the database connection
   146  func (ss *SQLServer) Close() error {
   147  	connErr := ss.conn.Close()
   148  	dbErr := ss.db.Close()
   149  	if connErr != nil || dbErr != nil {
   150  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   151  	}
   152  	return nil
   153  }
   154  
   155  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   156  func (ss *SQLServer) Lock() error {
   157  	if ss.isLocked {
   158  		return database.ErrLocked
   159  	}
   160  
   161  	aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	// This will either obtain the lock immediately and return true,
   167  	// or return false if the lock cannot be acquired immediately.
   168  	// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   169  	query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   170  
   171  	var status mssql.ReturnStatus
   172  	if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   173  		ss.isLocked = true
   174  		return nil
   175  	} else if err != nil {
   176  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   177  	} else {
   178  		return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   179  	}
   180  }
   181  
   182  // Unlock froms the migration lock from the database
   183  func (ss *SQLServer) Unlock() error {
   184  	if !ss.isLocked {
   185  		return nil
   186  	}
   187  
   188  	aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   194  	query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   195  	if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   196  		return &database.Error{OrigErr: err, Query: []byte(query)}
   197  	}
   198  	ss.isLocked = false
   199  
   200  	return nil
   201  }
   202  
   203  // Run the migrations for the database
   204  func (ss *SQLServer) Run(migration io.Reader) error {
   205  	migr, err := ioutil.ReadAll(migration)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	// run migration
   211  	query := string(migr[:])
   212  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   213  		if msErr, ok := err.(mssql.Error); ok {
   214  			message := fmt.Sprintf("migration failed: %s", msErr.Message)
   215  			if msErr.ProcName != "" {
   216  				message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   217  			}
   218  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
   219  		}
   220  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   221  	}
   222  
   223  	return nil
   224  }
   225  
   226  // SetVersion for the current database
   227  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   228  
   229  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   230  	if err != nil {
   231  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   232  	}
   233  
   234  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   235  	if _, err := tx.Exec(query); err != nil {
   236  		if errRollback := tx.Rollback(); errRollback != nil {
   237  			err = multierror.Append(err, errRollback)
   238  		}
   239  		return &database.Error{OrigErr: err, Query: []byte(query)}
   240  	}
   241  
   242  	// Also re-write the schema version for nil dirty versions to prevent
   243  	// empty schema version for failed down migration on the first migration
   244  	// See: https://github.com/golang-migrate/migrate/issues/330
   245  	if version >= 0 || (version == database.NilVersion && dirty) {
   246  		var dirtyBit int
   247  		if dirty {
   248  			dirtyBit = 1
   249  		}
   250  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   251  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   252  			if errRollback := tx.Rollback(); errRollback != nil {
   253  				err = multierror.Append(err, errRollback)
   254  			}
   255  			return &database.Error{OrigErr: err, Query: []byte(query)}
   256  		}
   257  	}
   258  
   259  	if err := tx.Commit(); err != nil {
   260  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   261  	}
   262  
   263  	return nil
   264  }
   265  
   266  // Version of the current database state
   267  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   268  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   269  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   270  	switch {
   271  	case err == sql.ErrNoRows:
   272  		return database.NilVersion, false, nil
   273  
   274  	case err != nil:
   275  		// FIXME: convert to MSSQL error
   276  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   277  
   278  	default:
   279  		return version, dirty, nil
   280  	}
   281  }
   282  
   283  // Drop all tables from the database.
   284  func (ss *SQLServer) Drop() error {
   285  
   286  	// drop all referential integrity constraints
   287  	query := `
   288  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   289  
   290  	SET @Cursor = CURSOR FAST_FORWARD FOR
   291  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   292  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   293  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   294  
   295  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   296  
   297  	WHILE (@@FETCH_STATUS = 0)
   298  	BEGIN
   299  	Exec sp_executesql @Sql
   300  	FETCH NEXT FROM @Cursor INTO @Sql
   301  	END
   302  
   303  	CLOSE @Cursor DEALLOCATE @Cursor`
   304  
   305  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   306  		return &database.Error{OrigErr: err, Query: []byte(query)}
   307  	}
   308  
   309  	// drop the tables
   310  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   311  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   312  		return &database.Error{OrigErr: err, Query: []byte(query)}
   313  	}
   314  
   315  	return nil
   316  }
   317  
   318  func (ss *SQLServer) ensureVersionTable() (err error) {
   319  	if err = ss.Lock(); err != nil {
   320  		return err
   321  	}
   322  
   323  	defer func() {
   324  		if e := ss.Unlock(); e != nil {
   325  			if err == nil {
   326  				err = e
   327  			} else {
   328  				err = multierror.Append(err, e)
   329  			}
   330  		}
   331  	}()
   332  
   333  	query := `IF NOT EXISTS
   334  	(SELECT *
   335  		 FROM sysobjects
   336  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   337  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   338  	)
   339  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   340  
   341  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   342  		return &database.Error{OrigErr: err, Query: []byte(query)}
   343  	}
   344  
   345  	return nil
   346  }