github.com/meetsoni15/go-migrate/v4@v4.15.3-0.20221220054613-2c40bd0c4ee9/database/sqlserver/sqlserver.go (about)

     1  package sqlserver
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/denisenkom/go-mssqldb/batch"
    13  
    14  	"go.uber.org/atomic"
    15  
    16  	"github.com/Azure/go-autorest/autorest/adal"
    17  	mssql "github.com/denisenkom/go-mssqldb" // mssql support
    18  	"github.com/hashicorp/go-multierror"
    19  	"github.com/meetsoni15/go-migrate/v4"
    20  	"github.com/meetsoni15/go-migrate/v4/database"
    21  )
    22  
    23  func init() {
    24  	database.Register("sqlserver", &SQLServer{})
    25  }
    26  
    27  // DefaultMigrationsTable is the name of the migrations table in the database
    28  var DefaultMigrationsTable = "schema_migrations"
    29  
    30  var (
    31  	ErrNilConfig                 = fmt.Errorf("no config")
    32  	ErrNoDatabaseName            = fmt.Errorf("no database name")
    33  	ErrNoSchema                  = fmt.Errorf("no schema")
    34  	ErrDatabaseDirty             = fmt.Errorf("database is dirty")
    35  	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
    36  )
    37  
    38  var lockErrorMap = map[mssql.ReturnStatus]string{
    39  	-1:   "The lock request timed out.",
    40  	-2:   "The lock request was canceled.",
    41  	-3:   "The lock request was chosen as a deadlock victim.",
    42  	-999: "Parameter validation or other call error.",
    43  }
    44  
    45  // Config for database
    46  type Config struct {
    47  	MigrationsTable       string
    48  	DatabaseName          string
    49  	SchemaName            string
    50  	BatchStatementEnabled bool
    51  }
    52  
    53  // SQL Server connection
    54  type SQLServer struct {
    55  	// Locking and unlocking need to use the same connection
    56  	conn     *sql.Conn
    57  	db       *sql.DB
    58  	isLocked atomic.Bool
    59  
    60  	// Open and WithInstance need to garantuee that config is never nil
    61  	config *Config
    62  }
    63  
    64  // WithInstance returns a database instance from an already created database connection.
    65  //
    66  // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
    67  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    68  	if config == nil {
    69  		return nil, ErrNilConfig
    70  	}
    71  
    72  	if err := instance.Ping(); err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	if config.DatabaseName == "" {
    77  		query := `SELECT DB_NAME()`
    78  		var databaseName string
    79  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    80  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    81  		}
    82  
    83  		if len(databaseName) == 0 {
    84  			return nil, ErrNoDatabaseName
    85  		}
    86  
    87  		config.DatabaseName = databaseName
    88  	}
    89  
    90  	if config.SchemaName == "" {
    91  		query := `SELECT SCHEMA_NAME()`
    92  		var schemaName string
    93  		if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
    94  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    95  		}
    96  
    97  		if len(schemaName) == 0 {
    98  			return nil, ErrNoSchema
    99  		}
   100  
   101  		config.SchemaName = schemaName
   102  	}
   103  
   104  	if len(config.MigrationsTable) == 0 {
   105  		config.MigrationsTable = DefaultMigrationsTable
   106  	}
   107  
   108  	conn, err := instance.Conn(context.Background())
   109  
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	ss := &SQLServer{
   115  		conn:   conn,
   116  		db:     instance,
   117  		config: config,
   118  	}
   119  
   120  	if err := ss.ensureVersionTable(); err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	return ss, nil
   125  }
   126  
   127  // Open a connection to the database.
   128  func (ss *SQLServer) Open(url string) (database.Driver, error) {
   129  	purl, err := nurl.Parse(url)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	useMsiParam := purl.Query().Get("useMsi")
   135  	useMsi := false
   136  	if len(useMsiParam) > 0 {
   137  		useMsi, err = strconv.ParseBool(useMsiParam)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  	}
   142  
   143  	if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet {
   144  		return nil, ErrMultipleAuthOptionsPassed
   145  	}
   146  
   147  	filteredURL := migrate.FilterCustomQuery(purl).String()
   148  
   149  	var db *sql.DB
   150  	if useMsi {
   151  		resource := getAADResourceFromServerUri(purl)
   152  		tokenProvider, err := getMSITokenProvider(resource)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  
   157  		connector, err := mssql.NewAccessTokenConnector(
   158  			filteredURL, tokenProvider)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  
   163  		db = sql.OpenDB(connector)
   164  
   165  	} else {
   166  		db, err = sql.Open("sqlserver", filteredURL)
   167  		if err != nil {
   168  			return nil, err
   169  		}
   170  	}
   171  
   172  	migrationsTable := purl.Query().Get("x-migrations-table")
   173  
   174  	batchStatementEnabled := false
   175  	if s := purl.Query().Get("x-batch"); len(s) > 0 {
   176  		batchStatementEnabled, err = strconv.ParseBool(s)
   177  		if err != nil {
   178  			return nil, fmt.Errorf("Unable to parse option x-batch: %w", err)
   179  		}
   180  	}
   181  
   182  	px, err := WithInstance(db, &Config{
   183  		DatabaseName:          purl.Path,
   184  		MigrationsTable:       migrationsTable,
   185  		BatchStatementEnabled: batchStatementEnabled,
   186  	})
   187  
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	return px, nil
   193  }
   194  
   195  // Close the database connection
   196  func (ss *SQLServer) Close() error {
   197  	connErr := ss.conn.Close()
   198  	dbErr := ss.db.Close()
   199  	if connErr != nil || dbErr != nil {
   200  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   201  	}
   202  	return nil
   203  }
   204  
   205  // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
   206  func (ss *SQLServer) Lock() error {
   207  	return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error {
   208  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   209  		if err != nil {
   210  			return err
   211  		}
   212  
   213  		// This will either obtain the lock immediately and return true,
   214  		// or return false if the lock cannot be acquired immediately.
   215  		// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
   216  		query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
   217  
   218  		var status mssql.ReturnStatus
   219  		if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
   220  			return nil
   221  		} else if err != nil {
   222  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   223  		} else {
   224  			return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)}
   225  		}
   226  	})
   227  }
   228  
   229  // Unlock froms the migration lock from the database
   230  func (ss *SQLServer) Unlock() error {
   231  	return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error {
   232  		aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName)
   233  		if err != nil {
   234  			return err
   235  		}
   236  
   237  		// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
   238  		query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
   239  		if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
   240  			return &database.Error{OrigErr: err, Query: []byte(query)}
   241  		}
   242  
   243  		return nil
   244  	})
   245  }
   246  
   247  // Run the migrations for the database
   248  func (ss *SQLServer) Run(migration io.Reader) error {
   249  	migr, err := io.ReadAll(migration)
   250  	if err != nil {
   251  		return err
   252  	}
   253  
   254  	// run migration
   255  	query := string(migr[:])
   256  	scripts := []string{query}
   257  
   258  	if ss.config.BatchStatementEnabled {
   259  		scripts = batch.Split(query, "go")
   260  	}
   261  
   262  	for _, script := range scripts {
   263  		if _, err := ss.conn.ExecContext(context.Background(), script); err != nil {
   264  			if msErr, ok := err.(mssql.Error); ok {
   265  				message := fmt.Sprintf("migration failed: %s", msErr.Message)
   266  				if msErr.ProcName != "" {
   267  					message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName)
   268  				}
   269  				return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)}
   270  			}
   271  			return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)}
   272  		}
   273  	}
   274  
   275  	return nil
   276  }
   277  
   278  // SetVersion for the current database
   279  func (ss *SQLServer) SetVersion(version int, dirty bool) error {
   280  
   281  	tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
   282  	if err != nil {
   283  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   284  	}
   285  
   286  	query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"`
   287  	if _, err := tx.Exec(query); err != nil {
   288  		if errRollback := tx.Rollback(); errRollback != nil {
   289  			err = multierror.Append(err, errRollback)
   290  		}
   291  		return &database.Error{OrigErr: err, Query: []byte(query)}
   292  	}
   293  
   294  	// Also re-write the schema version for nil dirty versions to prevent
   295  	// empty schema version for failed down migration on the first migration
   296  	// See: https://github.com/golang-migrate/migrate/issues/330
   297  	if version >= 0 || (version == database.NilVersion && dirty) {
   298  		var dirtyBit int
   299  		if dirty {
   300  			dirtyBit = 1
   301  		}
   302  		query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
   303  		if _, err := tx.Exec(query, version, dirtyBit); err != nil {
   304  			if errRollback := tx.Rollback(); errRollback != nil {
   305  				err = multierror.Append(err, errRollback)
   306  			}
   307  			return &database.Error{OrigErr: err, Query: []byte(query)}
   308  		}
   309  	}
   310  
   311  	if err := tx.Commit(); err != nil {
   312  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   313  	}
   314  
   315  	return nil
   316  }
   317  
   318  // Version of the current database state
   319  func (ss *SQLServer) Version() (version int, dirty bool, err error) {
   320  	query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
   321  	err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   322  	switch {
   323  	case err == sql.ErrNoRows:
   324  		return database.NilVersion, false, nil
   325  
   326  	case err != nil:
   327  		// FIXME: convert to MSSQL error
   328  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   329  
   330  	default:
   331  		return version, dirty, nil
   332  	}
   333  }
   334  
   335  // Drop all tables from the database.
   336  func (ss *SQLServer) Drop() error {
   337  
   338  	// drop all referential integrity constraints
   339  	query := `
   340  	DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR
   341  	SET @Cursor = CURSOR FAST_FORWARD FOR
   342  	SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']'
   343  	FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1
   344  	LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME
   345  	OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql
   346  	WHILE (@@FETCH_STATUS = 0)
   347  	BEGIN
   348  	Exec sp_executesql @Sql
   349  	FETCH NEXT FROM @Cursor INTO @Sql
   350  	END
   351  	CLOSE @Cursor DEALLOCATE @Cursor`
   352  
   353  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   354  		return &database.Error{OrigErr: err, Query: []byte(query)}
   355  	}
   356  
   357  	// drop the tables
   358  	query = `EXEC sp_MSforeachtable 'DROP TABLE ?'`
   359  	if _, err := ss.conn.ExecContext(context.Background(), query); err != nil {
   360  		return &database.Error{OrigErr: err, Query: []byte(query)}
   361  	}
   362  
   363  	return nil
   364  }
   365  
   366  func (ss *SQLServer) ensureVersionTable() (err error) {
   367  	if err = ss.Lock(); err != nil {
   368  		return err
   369  	}
   370  
   371  	defer func() {
   372  		if e := ss.Unlock(); e != nil {
   373  			if err == nil {
   374  				err = e
   375  			} else {
   376  				err = multierror.Append(err, e)
   377  			}
   378  		}
   379  	}()
   380  
   381  	query := `IF NOT EXISTS
   382  	(SELECT *
   383  		 FROM sysobjects
   384  		WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]')
   385  			AND OBJECTPROPERTY(id, N'IsUserTable') = 1
   386  	)
   387  	CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );`
   388  
   389  	if _, err = ss.conn.ExecContext(context.Background(), query); err != nil {
   390  		return &database.Error{OrigErr: err, Query: []byte(query)}
   391  	}
   392  
   393  	return nil
   394  }
   395  
   396  func getMSITokenProvider(resource string) (func() (string, error), error) {
   397  	msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil)
   398  	if err != nil {
   399  		return nil, err
   400  	}
   401  
   402  	return func() (string, error) {
   403  		err := msi.EnsureFresh()
   404  		if err != nil {
   405  			return "", err
   406  		}
   407  		token := msi.OAuthToken()
   408  		return token, nil
   409  	}, nil
   410  }
   411  
   412  // The sql server resource can change across clouds so get it
   413  // dynamically based on the server uri.
   414  // ex. <server name>.database.windows.net -> https://database.windows.net
   415  func getAADResourceFromServerUri(purl *nurl.URL) string {
   416  	return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], "."))
   417  }