github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/sqlserver/sqlserver.go (about)

     1  package sqlserver
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/Azure/go-autorest/autorest/adal"
    15  	"github.com/golang-migrate/migrate/v4"
    16  	"github.com/golang-migrate/migrate/v4/database"
    17  	"github.com/hashicorp/go-multierror"
    18  	mssql "github.com/microsoft/go-mssqldb" // mssql support
    19  )
    20  
    21  func init() {
    22  	database.Register("sqlserver", &SQLServer{})
    23  }
    24  
// DefaultMigrationsTable is the name of the migrations table in the database.
// WithInstance falls back to it when Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"
    27  
// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithInstance when no Config is supplied.
	ErrNilConfig                 = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithInstance when `SELECT DB_NAME()` yields an empty name.
	ErrNoDatabaseName            = fmt.Errorf("no database name")
	// ErrNoSchema is returned by WithInstance when `SELECT SCHEMA_NAME()` yields an empty name.
	ErrNoSchema                  = fmt.Errorf("no schema")
	// ErrDatabaseDirty is not referenced in this file; presumably kept for callers — verify before removing.
	ErrDatabaseDirty             = fmt.Errorf("database is dirty")
	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries a password and useMsi=true.
	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
)
    35  
// lockErrorMap translates the negative return status observed after executing
// sp_getapplock in Lock into a human readable message. A status that is not
// listed here maps to the empty string.
var lockErrorMap = map[mssql.ReturnStatus]string{
	-1:   "The lock request timed out.",
	-2:   "The lock request was canceled.",
	-3:   "The lock request was chosen as a deadlock victim.",
	-999: "Parameter validation or other call error.",
}
    42  
// Config holds the settings of the SQL Server driver. WithInstance fills in
// every field that is left empty.
type Config struct {
	// MigrationsTable is the table that stores the version and the dirty
	// flag. Defaults to DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is used to derive the advisory lock id. When empty it is
	// read with `SELECT DB_NAME()`.
	DatabaseName    string
	// SchemaName is used to derive the advisory lock id. When empty it is
	// read with `SELECT SCHEMA_NAME()`.
	SchemaName      string
}
    49  
// SQLServer is the database driver for Microsoft SQL Server.
type SQLServer struct {
	// Locking and unlocking need to use the same connection, so every
	// statement of the driver is executed on this single dedicated conn.
	conn     *sql.Conn
	// db is the pool conn was taken from; it is closed by Close.
	db       *sql.DB
	// isLocked records whether this driver instance currently holds the
	// migration lock; it is flipped by Lock and Unlock.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    60  
    61  // WithInstance returns a database instance from an already created database connection.
    62  //
    63  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    64  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    65  	if config == nil {
    66  		return nil, ErrNilConfig
    67  	}
    68  
    69  	if err := instance.Ping(); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	if config.DatabaseName == "" {
    74  		query := `SELECT DB_NAME()`
    75  		var databaseName string
    76  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    77  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    78  		}
    79  
    80  		if len(databaseName) == 0 {
    81  			return nil, ErrNoDatabaseName
    82  		}
    83  
    84  		config.DatabaseName = databaseName
    85  	}
    86  
    87  	if config.SchemaName == "" {
    88  		query := `SELECT SCHEMA_NAME()`
    89  		var schemaName string
    90  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    91  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    92  		}
    93  
    94  		if len(schemaName) == 0 {
    95  			return nil, ErrNoSchema
    96  		}
    97  
    98  		config.SchemaName = schemaName
    99  	}
   100  
   101  	if len(config.MigrationsTable) == 0 {
   102  		config.MigrationsTable = DefaultMigrationsTable
   103  	}
   104  
   105  	conn, err := instance.Conn(context.Background())
   106  
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	ss := &SQLServer{
   112  		conn:   conn,
   113  		db:     instance,
   114  		config: config,
   115  	}
   116  
   117  	if err := ss.ensureVersionTable(); err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	return ss, nil
   122  }
   123  
   124  // Open a connection to the database.
   125  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   126  	purl, err := nurl.Parse(url)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	useMsiParam := purl.Query().Get("useMsi")
   132  	useMsi := false
   133  	if len(useMsiParam) > 0 {
   134  		useMsi, err = strconv.ParseBool(useMsiParam)
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  	}
   139  
   140  	if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
   141  		return nil, ErrMultipleAuthOptionsPassed
   142  	}
   143  
   144  	filteredURL := migrate.FilterCustomQuery(purl).String()
   145  
   146  	var db *sql.DB
   147  	if useMsi {
   148  		resource := getAADResourceFromServerUri(purl)
   149  		tokenProvider, err := getMSITokenProvider(resource)
   150  		if err != nil {
   151  			return nil, err
   152  		}
   153  
   154  		connector, err := mssql.NewAccessTokenConnector(
   155  			filteredURL, tokenProvider)
   156  		if err != nil {
   157  			return nil, err
   158  		}
   159  
   160  		db = sql.OpenDB(connector)
   161  
   162  	} else {
   163  		db, err = sql.Open("sqlserver", filteredURL)
   164  		if err != nil {
   165  			return nil, err
   166  		}
   167  	}
   168  
   169  	migrationsTable := purl.Query().Get("x-migrations-table")
   170  
   171  	px, err := WithInstance(db, &Config{
   172  		DatabaseName:    purl.Path,
   173  		MigrationsTable: migrationsTable,
   174  	})
   175  
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	return px, nil
   181  }
   182  
   183  // Close the database connection
   184  func (ss *SQLServer) Close() error {
   185  	connErr := ss.conn.Close()
   186  	dbErr := ss.db.Close()
   187  	if connErr != nil || dbErr != nil {
   188  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   189  	}
   190  	return nil
   191  }
   192  
   193  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   194  func (ss *SQLServer) Lock() error {
   195  	return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
   196  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   197  		if err != nil {
   198  			return err
   199  		}
   200  
   201  		// This will either obtain the lock immediately and return true,
   202  		// or return false if the lock cannot be acquired immediately.
   203  		// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   204  		query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   205  
   206  		var status mssql.ReturnStatus
   207  		if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   208  			return nil
   209  		} else if err != nil {
   210  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   211  		} else {
   212  			return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   213  		}
   214  	})
   215  }
   216  
   217  // Unlock froms the migration lock from the database
   218  func (ss *SQLServer) Unlock() error {
   219  	return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
   220  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   221  		if err != nil {
   222  			return err
   223  		}
   224  
   225  		// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   226  		query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   227  		if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   228  			return &database.Error{OrigErr: err, Query: []byte(query)}
   229  		}
   230  
   231  		return nil
   232  	})
   233  }
   234  
   235  // Run the migrations for the database
   236  func (ss *SQLServer) Run(migration io.Reader) error {
   237  	migr, err := io.ReadAll(migration)
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	// run migration
   243  	query := string(migr[:])
   244  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   245  		if msErr, ok := err.(mssql.Error); ok {
   246  			message := fmt.Sprintf("migration failed: %s", msErr.Message)
   247  			if msErr.ProcName != "" {
   248  				message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   249  			}
   250  			return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)}
   251  		}
   252  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  // SetVersion for the current database
   259  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   260  
   261  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   262  	if err != nil {
   263  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   264  	}
   265  
   266  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   267  	if _, err := tx.Exec(query); err != nil {
   268  		if errRollback := tx.Rollback(); errRollback != nil {
   269  			err = multierror.Append(err, errRollback)
   270  		}
   271  		return &database.Error{OrigErr: err, Query: []byte(query)}
   272  	}
   273  
   274  	// Also re-write the schema version for nil dirty versions to prevent
   275  	// empty schema version for failed down migration on the first migration
   276  	// See: https://github.com/golang-migrate/migrate/issues/330
   277  	if version >= 0 || (version == database.NilVersion && dirty) {
   278  		var dirtyBit int
   279  		if dirty {
   280  			dirtyBit = 1
   281  		}
   282  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   283  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   284  			if errRollback := tx.Rollback(); errRollback != nil {
   285  				err = multierror.Append(err, errRollback)
   286  			}
   287  			return &database.Error{OrigErr: err, Query: []byte(query)}
   288  		}
   289  	}
   290  
   291  	if err := tx.Commit(); err != nil {
   292  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   293  	}
   294  
   295  	return nil
   296  }
   297  
   298  // Version of the current database state
   299  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   300  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   301  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   302  	switch {
   303  	case err == sql.ErrNoRows:
   304  		return database.NilVersion, false, nil
   305  
   306  	case err != nil:
   307  		// FIXME: convert to MSSQL error
   308  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   309  
   310  	default:
   311  		return version, dirty, nil
   312  	}
   313  }
   314  
   315  // Drop all tables from the database.
   316  func (ss *SQLServer) Drop() error {
   317  
   318  	// drop all referential integrity constraints
   319  	query := `
   320  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   321  
   322  	SET @Cursor = CURSOR FAST_FORWARD FOR
   323  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   324  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   325  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   326  
   327  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   328  
   329  	WHILE (@@FETCH_STATUS = 0)
   330  	BEGIN
   331  	Exec sp_executesql @Sql
   332  	FETCH NEXT FROM @Cursor INTO @Sql
   333  	END
   334  
   335  	CLOSE @Cursor DEALLOCATE @Cursor`
   336  
   337  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   338  		return &database.Error{OrigErr: err, Query: []byte(query)}
   339  	}
   340  
   341  	// drop the tables
   342  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   343  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   344  		return &database.Error{OrigErr: err, Query: []byte(query)}
   345  	}
   346  
   347  	return nil
   348  }
   349  
   350  func (ss *SQLServer) ensureVersionTable() (err error) {
   351  	if err = ss.Lock(); err != nil {
   352  		return err
   353  	}
   354  
   355  	defer func() {
   356  		if e := ss.Unlock(); e != nil {
   357  			if err == nil {
   358  				err = e
   359  			} else {
   360  				err = multierror.Append(err, e)
   361  			}
   362  		}
   363  	}()
   364  
   365  	query := `IF NOT EXISTS
   366  	(SELECT *
   367  		 FROM sysobjects
   368  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   369  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   370  	)
   371  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   372  
   373  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   374  		return &database.Error{OrigErr: err, Query: []byte(query)}
   375  	}
   376  
   377  	return nil
   378  }
   379  
   380  func getMSITokenProvider(resource string) (func() (string, error), error) {
   381  	msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
   382  	if err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	return func() (string, error) {
   387  		err := msi.EnsureFresh()
   388  		if err != nil {
   389  			return "", err
   390  		}
   391  		token := msi.OAuthToken()
   392  		return token, nil
   393  	}, nil
   394  }
   395  
   396  // The sql server resource can change across clouds so get it
   397  // dynamically based on the server uri.
   398  // ex. <server name>.database.windows.net -> https://database.windows.net
   399  func getAADResourceFromServerUri(purl *nurl.URL) string {
   400  	return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], "."))
   401  }