github.com/nokia/migrate/v4@v4.16.0/database/sqlserver/sqlserver.go (about)

     1  package sqlserver
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	nurl "net/url"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"go.uber.org/atomic"
    14  
    15  	"github.com/Azure/go-autorest/autorest/adal"
    16  	mssql "github.com/denisenkom/go-mssqldb" // mssql support
    17  	"github.com/hashicorp/go-multierror"
    18  	"github.com/nokia/migrate/v4"
    19  	"github.com/nokia/migrate/v4/database"
    20  	"github.com/nokia/migrate/v4/source"
    21  )
    22  
    23  func init() {
        	// Register an empty SQLServer value with the database package under the
        	// "sqlserver" name.
    24  	database.Register("sqlserver", &SQLServer{})
    25  }
    26  
    27  // DefaultMigrationsTable is the name of the migrations table in the database.
        // WithInstance falls back to it when Config.MigrationsTable is empty.
    28  var DefaultMigrationsTable = "schema_migrations"
    29  
    30  var (
        	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
    31  	ErrNilConfig                 = fmt.Errorf("no config")
        	// ErrNoDatabaseName is returned by WithInstance when Config.DatabaseName is
        	// empty and SELECT DB_NAME() also yields an empty string.
    32  	ErrNoDatabaseName            = fmt.Errorf("no database name")
        	// ErrNoSchema is returned by WithInstance when Config.SchemaName is empty
        	// and SELECT SCHEMA_NAME() also yields an empty string.
    33  	ErrNoSchema                  = fmt.Errorf("no schema")
        	// ErrDatabaseDirty is not referenced elsewhere in this file.
        	// NOTE(review): presumably kept for callers of this package - confirm.
    34  	ErrDatabaseDirty             = fmt.Errorf("database is dirty")
        	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries a
        	// password and also sets useMsi=true.
    35  	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
    36  )
    37  
    38  var lockErrorMap = map[mssql.ReturnStatus]string{
        	// Human-readable descriptions for negative return statuses of the
        	// sp_getapplock call; Lock looks the status up here to build its
        	// error message.
    39  	-1:   "The lock request timed out.",
    40  	-2:   "The lock request was canceled.",
    41  	-3:   "The lock request was chosen as a deadlock victim.",
    42  	-999: "Parameter validation or other call error.",
    43  }
    44  
    45  // Config holds the settings WithInstance uses to set up the driver.
        // WithInstance fills in any empty field on the value it is given.
    46  type Config struct {
        	// MigrationsTable is the name of the version table; when empty,
        	// DefaultMigrationsTable is used.
    47  	MigrationsTable string
        	// DatabaseName is used (together with SchemaName) to derive the advisory
        	// lock ID; when empty it is read from SELECT DB_NAME().
    48  	DatabaseName    string
        	// SchemaName is used (together with DatabaseName) to derive the advisory
        	// lock ID; when empty it is read from SELECT SCHEMA_NAME().
    49  	SchemaName      string
    50  }
    51  
    52  // SQLServer is the SQL Server driver state returned (as a database.Driver)
        // by WithInstance and Open.
    53  type SQLServer struct {
    54  	// Locking and unlocking need to use the same connection
    55  	conn     *sql.Conn
        	// db is the pool conn was taken from; Close closes both.
    56  	db       *sql.DB
        	// isLocked is flipped by Lock and Unlock via database.CasRestoreOnErr.
    57  	isLocked atomic.Bool
    58  
    59  	// Open and WithInstance need to guarantee that config is never nil
    60  	config *Config
    61  }
    62  
    63  // WithInstance returns a database instance from an already created database connection.
    64  //
    65  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    66  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    67  	if config == nil {
    68  		return nil, ErrNilConfig
    69  	}
    70  
    71  	if err := instance.Ping(); err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	if config.DatabaseName == "" {
    76  		query := `SELECT DB_NAME()`
    77  		var databaseName string
    78  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    79  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    80  		}
    81  
    82  		if len(databaseName) == 0 {
    83  			return nil, ErrNoDatabaseName
    84  		}
    85  
    86  		config.DatabaseName = databaseName
    87  	}
    88  
    89  	if config.SchemaName == "" {
    90  		query := `SELECT SCHEMA_NAME()`
    91  		var schemaName string
    92  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    93  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    94  		}
    95  
    96  		if len(schemaName) == 0 {
    97  			return nil, ErrNoSchema
    98  		}
    99  
   100  		config.SchemaName = schemaName
   101  	}
   102  
   103  	if len(config.MigrationsTable) == 0 {
   104  		config.MigrationsTable = DefaultMigrationsTable
   105  	}
   106  
   107  	conn, err := instance.Conn(context.Background())
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	ss := &SQLServer{
   113  		conn:   conn,
   114  		db:     instance,
   115  		config: config,
   116  	}
   117  
   118  	if err := ss.ensureVersionTable(); err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	return ss, nil
   123  }
   124  
   125  // Open a connection to the database.
   126  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   127  	purl, err := nurl.Parse(url)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	useMsiParam := purl.Query().Get("useMsi")
   133  	useMsi := false
   134  	if len(useMsiParam) > 0 {
   135  		useMsi, err = strconv.ParseBool(useMsiParam)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
   142  		return nil, ErrMultipleAuthOptionsPassed
   143  	}
   144  
   145  	filteredURL := migrate.FilterCustomQuery(purl).String()
   146  
   147  	var db *sql.DB
   148  	if useMsi {
   149  		resource := getAADResourceFromServerUri(purl)
   150  		tokenProvider, err := getMSITokenProvider(resource)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  
   155  		connector, err := mssql.NewAccessTokenConnector(
   156  			filteredURL, tokenProvider)
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  
   161  		db = sql.OpenDB(connector)
   162  
   163  	} else {
   164  		db, err = sql.Open("sqlserver", filteredURL)
   165  		if err != nil {
   166  			return nil, err
   167  		}
   168  	}
   169  
   170  	migrationsTable := purl.Query().Get("x-migrations-table")
   171  
   172  	px, err := WithInstance(db, &Config{
   173  		DatabaseName:    purl.Path,
   174  		MigrationsTable: migrationsTable,
   175  	})
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	return px, nil
   181  }
   182  
   183  // Close the database connection
   184  func (ss *SQLServer) Close() error {
   185  	connErr := ss.conn.Close()
   186  	dbErr := ss.db.Close()
   187  	if connErr != nil || dbErr != nil {
   188  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   189  	}
   190  	return nil
   191  }
   192  
   193  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   194  func (ss *SQLServer) Lock() error {
   195  	return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
   196  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   197  		if err != nil {
   198  			return err
   199  		}
   200  
   201  		// This will either obtain the lock immediately and return true,
   202  		// or return false if the lock cannot be acquired immediately.
   203  		// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   204  		query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   205  
   206  		var status mssql.ReturnStatus
   207  		if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   208  			return nil
   209  		} else if err != nil {
   210  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   211  		} else {
   212  			return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   213  		}
   214  	})
   215  }
   216  
   217  // Unlock froms the migration lock from the database
   218  func (ss *SQLServer) Unlock() error {
   219  	return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
   220  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   221  		if err != nil {
   222  			return err
   223  		}
   224  
   225  		// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   226  		query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   227  		if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   228  			return &database.Error{OrigErr: err, Query: []byte(query)}
   229  		}
   230  
   231  		return nil
   232  	})
   233  }
   234  
   235  // Run the migrations for the database
   236  func (ss *SQLServer) Run(migration io.Reader) error {
   237  	migr, err := ioutil.ReadAll(migration)
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	// run migration
   243  	query := string(migr[:])
   244  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   245  		if msErr, ok := err.(mssql.Error); ok {
   246  			message := fmt.Sprintf("migration failed: %s", msErr.Message)
   247  			if msErr.ProcName != "" {
   248  				message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   249  			}
   250  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
   251  		}
   252  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  func (ss *SQLServer) RunFunctionMigration(fn source.MigrationFunc) error {
        	// Function-based migrations are not implemented by this driver; fn is
        	// ignored and database.ErrNotImpl is always returned.
   259  	return database.ErrNotImpl
   260  }
   261  
   262  // SetVersion for the current database
   263  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   264  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   265  	if err != nil {
   266  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   267  	}
   268  
   269  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   270  	if _, err := tx.Exec(query); err != nil {
   271  		if errRollback := tx.Rollback(); errRollback != nil {
   272  			err = multierror.Append(err, errRollback)
   273  		}
   274  		return &database.Error{OrigErr: err, Query: []byte(query)}
   275  	}
   276  
   277  	// Also re-write the schema version for nil dirty versions to prevent
   278  	// empty schema version for failed down migration on the first migration
   279  	// See: https://github.com/nokia/migrate/issues/330
   280  	if version >= 0 || (version == database.NilVersion && dirty) {
   281  		var dirtyBit int
   282  		if dirty {
   283  			dirtyBit = 1
   284  		}
   285  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   286  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   287  			if errRollback := tx.Rollback(); errRollback != nil {
   288  				err = multierror.Append(err, errRollback)
   289  			}
   290  			return &database.Error{OrigErr: err, Query: []byte(query)}
   291  		}
   292  	}
   293  
   294  	if err := tx.Commit(); err != nil {
   295  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  // Version of the current database state
   302  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   303  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   304  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   305  	switch {
   306  	case err == sql.ErrNoRows:
   307  		return database.NilVersion, false, nil
   308  
   309  	case err != nil:
   310  		// FIXME: convert to MSSQL error
   311  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   312  
   313  	default:
   314  		return version, dirty, nil
   315  	}
   316  }
   317  
   318  // Drop all tables from the database.
   319  func (ss *SQLServer) Drop() error {
   320  	// drop all referential integrity constraints
   321  	query := `
   322  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   323  
   324  	SET @Cursor = CURSOR FAST_FORWARD FOR
   325  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   326  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   327  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   328  
   329  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   330  
   331  	WHILE (@@FETCH_STATUS = 0)
   332  	BEGIN
   333  	Exec sp_executesql @Sql
   334  	FETCH NEXT FROM @Cursor INTO @Sql
   335  	END
   336  
   337  	CLOSE @Cursor DEALLOCATE @Cursor`
   338  
   339  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   340  		return &database.Error{OrigErr: err, Query: []byte(query)}
   341  	}
   342  
   343  	// drop the tables
   344  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   345  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   346  		return &database.Error{OrigErr: err, Query: []byte(query)}
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  func (ss *SQLServer) ensureVersionTable() (err error) {
   353  	if err = ss.Lock(); err != nil {
   354  		return err
   355  	}
   356  
   357  	defer func() {
   358  		if e := ss.Unlock(); e != nil {
   359  			if err == nil {
   360  				err = e
   361  			} else {
   362  				err = multierror.Append(err, e)
   363  			}
   364  		}
   365  	}()
   366  
   367  	query := `IF NOT EXISTS
   368  	(SELECT *
   369  		 FROM sysobjects
   370  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   371  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   372  	)
   373  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   374  
   375  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   376  		return &database.Error{OrigErr: err, Query: []byte(query)}
   377  	}
   378  
   379  	return nil
   380  }
   381  
   382  func getMSITokenProvider(resource string) (func() (string, error), error) {
   383  	msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	return func() (string, error) {
   389  		err := msi.EnsureFresh()
   390  		if err != nil {
   391  			return "", err
   392  		}
   393  		token := msi.OAuthToken()
   394  		return token, nil
   395  	}, nil
   396  }
   397  
   398  // The sql server resource can change across clouds so get it
   399  // dynamically based on the server uri.
   400  // ex. <server name>.database.windows.net -> https://database.windows.net
   401  func getAADResourceFromServerUri(purl *nurl.URL) string {
   402  	return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], "."))
   403  }