github.com/Elate-DevOps/migrate/v4@v4.0.12/database/sqlserver/sqlserver.go (about)

     1  package sqlserver
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/Azure/go-autorest/autorest/adal"
    15  	"github.com/Elate-DevOps/migrate/v4"
    16  	"github.com/Elate-DevOps/migrate/v4/database"
    17  	"github.com/hashicorp/go-multierror"
    18  	mssql "github.com/microsoft/go-mssqldb" // mssql support
    19  )
    20  
    21  func init() {
    22  	database.Register("sqlserver", &SQLServer{})
    23  }
    24  
    25  // DefaultMigrationsTable is the name of the migrations table in the database
    26  var DefaultMigrationsTable = "schema_migrations"
    27  
    28  var (
    29  	ErrNilConfig                 = fmt.Errorf("no config")
    30  	ErrNoDatabaseName            = fmt.Errorf("no database name")
    31  	ErrNoSchema                  = fmt.Errorf("no schema")
    32  	ErrDatabaseDirty             = fmt.Errorf("database is dirty")
    33  	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
    34  )
    35  
    36  var lockErrorMap = map[mssql.ReturnStatus]string{
    37  	-1:   "The lock request timed out.",
    38  	-2:   "The lock request was canceled.",
    39  	-3:   "The lock request was chosen as a deadlock victim.",
    40  	-999: "Parameter validation or other call error.",
    41  }
    42  
    43  // Config for database
    44  type Config struct {
    45  	MigrationsTable string
    46  	DatabaseName    string
    47  	SchemaName      string
    48  }
    49  
    50  // SQL Server connection
    51  type SQLServer struct {
    52  	// Locking and unlocking need to use the same connection
    53  	conn     *sql.Conn
    54  	db       *sql.DB
    55  	isLocked atomic.Bool
    56  
    57  	// Open and WithInstance need to garantuee that config is never nil
    58  	config *Config
    59  }
    60  
    61  // WithInstance returns a database instance from an already created database connection.
    62  //
    63  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    64  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    65  	if config == nil {
    66  		return nil, ErrNilConfig
    67  	}
    68  
    69  	if err := instance.Ping(); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	if config.DatabaseName == "" {
    74  		query := `SELECT DB_NAME()`
    75  		var databaseName string
    76  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    77  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    78  		}
    79  
    80  		if len(databaseName) == 0 {
    81  			return nil, ErrNoDatabaseName
    82  		}
    83  
    84  		config.DatabaseName = databaseName
    85  	}
    86  
    87  	if config.SchemaName == "" {
    88  		query := `SELECT SCHEMA_NAME()`
    89  		var schemaName string
    90  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    91  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    92  		}
    93  
    94  		if len(schemaName) == 0 {
    95  			return nil, ErrNoSchema
    96  		}
    97  
    98  		config.SchemaName = schemaName
    99  	}
   100  
   101  	if len(config.MigrationsTable) == 0 {
   102  		config.MigrationsTable = DefaultMigrationsTable
   103  	}
   104  
   105  	conn, err := instance.Conn(context.Background())
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	ss := &SQLServer{
   111  		conn:   conn,
   112  		db:     instance,
   113  		config: config,
   114  	}
   115  
   116  	if err := ss.ensureVersionTable(); err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	return ss, nil
   121  }
   122  
   123  // Open a connection to the database.
   124  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   125  	purl, err := nurl.Parse(url)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	useMsiParam := purl.Query().Get("useMsi")
   131  	useMsi := false
   132  	if len(useMsiParam) > 0 {
   133  		useMsi, err = strconv.ParseBool(useMsiParam)
   134  		if err != nil {
   135  			return nil, err
   136  		}
   137  	}
   138  
   139  	if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
   140  		return nil, ErrMultipleAuthOptionsPassed
   141  	}
   142  
   143  	filteredURL := migrate.FilterCustomQuery(purl).String()
   144  
   145  	var db *sql.DB
   146  	if useMsi {
   147  		resource := getAADResourceFromServerUri(purl)
   148  		tokenProvider, err := getMSITokenProvider(resource)
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  
   153  		connector, err := mssql.NewAccessTokenConnector(
   154  			filteredURL, tokenProvider)
   155  		if err != nil {
   156  			return nil, err
   157  		}
   158  
   159  		db = sql.OpenDB(connector)
   160  
   161  	} else {
   162  		db, err = sql.Open("sqlserver", filteredURL)
   163  		if err != nil {
   164  			return nil, err
   165  		}
   166  	}
   167  
   168  	migrationsTable := purl.Query().Get("x-migrations-table")
   169  
   170  	px, err := WithInstance(db, &Config{
   171  		DatabaseName:    purl.Path,
   172  		MigrationsTable: migrationsTable,
   173  	})
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	return px, nil
   179  }
   180  
   181  // Close the database connection
   182  func (ss *SQLServer) Close() error {
   183  	connErr := ss.conn.Close()
   184  	dbErr := ss.db.Close()
   185  	if connErr != nil || dbErr != nil {
   186  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   187  	}
   188  	return nil
   189  }
   190  
   191  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   192  func (ss *SQLServer) Lock() error {
   193  	return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
   194  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   195  		if err != nil {
   196  			return err
   197  		}
   198  
   199  		// This will either obtain the lock immediately and return true,
   200  		// or return false if the lock cannot be acquired immediately.
   201  		// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   202  		query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   203  
   204  		var status mssql.ReturnStatus
   205  		if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   206  			return nil
   207  		} else if err != nil {
   208  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   209  		} else {
   210  			return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   211  		}
   212  	})
   213  }
   214  
   215  // Unlock froms the migration lock from the database
   216  func (ss *SQLServer) Unlock() error {
   217  	return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
   218  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   219  		if err != nil {
   220  			return err
   221  		}
   222  
   223  		// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   224  		query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   225  		if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   226  			return &database.Error{OrigErr: err, Query: []byte(query)}
   227  		}
   228  
   229  		return nil
   230  	})
   231  }
   232  
   233  // Run the migrations for the database
   234  func (ss *SQLServer) Run(migration io.Reader) error {
   235  	migr, err := io.ReadAll(migration)
   236  	if err != nil {
   237  		return err
   238  	}
   239  
   240  	// run migration
   241  	query := string(migr[:])
   242  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   243  		if msErr, ok := err.(mssql.Error); ok {
   244  			message := fmt.Sprintf("migration failed: %s", msErr.Message)
   245  			if msErr.ProcName != "" {
   246  				message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   247  			}
   248  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
   249  		}
   250  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   251  	}
   252  
   253  	return nil
   254  }
   255  
   256  // SetVersion for the current database
   257  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   258  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   259  	if err != nil {
   260  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   261  	}
   262  
   263  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   264  	if _, err := tx.Exec(query); err != nil {
   265  		if errRollback := tx.Rollback(); errRollback != nil {
   266  			err = multierror.Append(err, errRollback)
   267  		}
   268  		return &database.Error{OrigErr: err, Query: []byte(query)}
   269  	}
   270  
   271  	// Also re-write the schema version for nil dirty versions to prevent
   272  	// empty schema version for failed down migration on the first migration
   273  	// See: https://github.com/golang-migrate/migrate/issues/330
   274  	if version >= 0 || (version == database.NilVersion && dirty) {
   275  		var dirtyBit int
   276  		if dirty {
   277  			dirtyBit = 1
   278  		}
   279  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   280  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   281  			if errRollback := tx.Rollback(); errRollback != nil {
   282  				err = multierror.Append(err, errRollback)
   283  			}
   284  			return &database.Error{OrigErr: err, Query: []byte(query)}
   285  		}
   286  	}
   287  
   288  	if err := tx.Commit(); err != nil {
   289  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   290  	}
   291  
   292  	return nil
   293  }
   294  
   295  // Version of the current database state
   296  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   297  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   298  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   299  	switch {
   300  	case err == sql.ErrNoRows:
   301  		return database.NilVersion, false, nil
   302  
   303  	case err != nil:
   304  		// FIXME: convert to MSSQL error
   305  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   306  
   307  	default:
   308  		return version, dirty, nil
   309  	}
   310  }
   311  
   312  // Drop all tables from the database.
   313  func (ss *SQLServer) Drop() error {
   314  	// drop all referential integrity constraints
   315  	query := `
   316  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   317  
   318  	SET @Cursor = CURSOR FAST_FORWARD FOR
   319  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   320  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   321  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   322  
   323  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   324  
   325  	WHILE (@@FETCH_STATUS = 0)
   326  	BEGIN
   327  	Exec sp_executesql @Sql
   328  	FETCH NEXT FROM @Cursor INTO @Sql
   329  	END
   330  
   331  	CLOSE @Cursor DEALLOCATE @Cursor`
   332  
   333  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   334  		return &database.Error{OrigErr: err, Query: []byte(query)}
   335  	}
   336  
   337  	// drop the tables
   338  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   339  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   340  		return &database.Error{OrigErr: err, Query: []byte(query)}
   341  	}
   342  
   343  	return nil
   344  }
   345  
   346  func (ss *SQLServer) ensureVersionTable() (err error) {
   347  	if err = ss.Lock(); err != nil {
   348  		return err
   349  	}
   350  
   351  	defer func() {
   352  		if e := ss.Unlock(); e != nil {
   353  			if err == nil {
   354  				err = e
   355  			} else {
   356  				err = multierror.Append(err, e)
   357  			}
   358  		}
   359  	}()
   360  
   361  	query := `IF NOT EXISTS
   362  	(SELECT *
   363  		 FROM sysobjects
   364  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   365  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   366  	)
   367  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   368  
   369  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   370  		return &database.Error{OrigErr: err, Query: []byte(query)}
   371  	}
   372  
   373  	return nil
   374  }
   375  
   376  func getMSITokenProvider(resource string) (func() (string, error), error) {
   377  	msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	return func() (string, error) {
   383  		err := msi.EnsureFresh()
   384  		if err != nil {
   385  			return "", err
   386  		}
   387  		token := msi.OAuthToken()
   388  		return token, nil
   389  	}, nil
   390  }
   391  
   392  // The sql server resource can change across clouds so get it
   393  // dynamically based on the server uri.
   394  // ex. <server name>.database.windows.net -> https://database.windows.net
   395  func getAADResourceFromServerUri(purl *nurl.URL) string {
   396  	return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], "."))
   397  }