github.com/Elate-DevOps/migrate/v4@v4.0.12/database/sqlserver/sqlserver.go (about) 1 package sqlserver 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "go.uber.org/atomic" 13 14 "github.com/Azure/go-autorest/autorest/adal" 15 "github.com/Elate-DevOps/migrate/v4" 16 "github.com/Elate-DevOps/migrate/v4/database" 17 "github.com/hashicorp/go-multierror" 18 mssql "github.com/microsoft/go-mssqldb" // mssql support 19 ) 20 21 func init() { 22 database.Register("sqlserver", &SQLServer{}) 23 } 24 25 // DefaultMigrationsTable is the name of the migrations table in the database 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrDatabaseDirty = fmt.Errorf("database is dirty") 33 ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") 34 ) 35 36 var lockErrorMap = map[mssql.ReturnStatus]string{ 37 -1: "The lock request timed out.", 38 -2: "The lock request was canceled.", 39 -3: "The lock request was chosen as a deadlock victim.", 40 -999: "Parameter validation or other call error.", 41 } 42 43 // Config for database 44 type Config struct { 45 MigrationsTable string 46 DatabaseName string 47 SchemaName string 48 } 49 50 // SQL Server connection 51 type SQLServer struct { 52 // Locking and unlocking need to use the same connection 53 conn *sql.Conn 54 db *sql.DB 55 isLocked atomic.Bool 56 57 // Open and WithInstance need to garantuee that config is never nil 58 config *Config 59 } 60 61 // WithInstance returns a database instance from an already created database connection. 62 // 63 // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. 64 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 65 if config == nil { 66 return nil, ErrNilConfig 67 } 68 69 if err := instance.Ping(); err != nil { 70 return nil, err 71 } 72 73 if config.DatabaseName == "" { 74 query := `SELECT DB_NAME()` 75 var databaseName string 76 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 77 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 78 } 79 80 if len(databaseName) == 0 { 81 return nil, ErrNoDatabaseName 82 } 83 84 config.DatabaseName = databaseName 85 } 86 87 if config.SchemaName == "" { 88 query := `SELECT SCHEMA_NAME()` 89 var schemaName string 90 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 91 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 92 } 93 94 if len(schemaName) == 0 { 95 return nil, ErrNoSchema 96 } 97 98 config.SchemaName = schemaName 99 } 100 101 if len(config.MigrationsTable) == 0 { 102 config.MigrationsTable = DefaultMigrationsTable 103 } 104 105 conn, err := instance.Conn(context.Background()) 106 if err != nil { 107 return nil, err 108 } 109 110 ss := &SQLServer{ 111 conn: conn, 112 db: instance, 113 config: config, 114 } 115 116 if err := ss.ensureVersionTable(); err != nil { 117 return nil, err 118 } 119 120 return ss, nil 121 } 122 123 // Open a connection to the database. 124 func (ss *SQLServer) Open(url string) (database.Driver, error) { 125 purl, err := nurl.Parse(url) 126 if err != nil { 127 return nil, err 128 } 129 130 useMsiParam := purl.Query().Get("useMsi") 131 useMsi := false 132 if len(useMsiParam) > 0 { 133 useMsi, err = strconv.ParseBool(useMsiParam) 134 if err != nil { 135 return nil, err 136 } 137 } 138 139 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { 140 return nil, ErrMultipleAuthOptionsPassed 141 } 142 143 filteredURL := migrate.FilterCustomQuery(purl).String() 144 145 var db *sql.DB 146 if useMsi { 147 resource := getAADResourceFromServerUri(purl) 148 tokenProvider, err := getMSITokenProvider(resource) 149 if err != nil { 150 return nil, err 151 } 152 153 connector, err := mssql.NewAccessTokenConnector( 154 filteredURL, tokenProvider) 155 if err != nil { 156 return nil, err 157 } 158 159 db = sql.OpenDB(connector) 160 161 } else { 162 db, err = sql.Open("sqlserver", filteredURL) 163 if err != nil { 164 return nil, err 165 } 166 } 167 168 migrationsTable := purl.Query().Get("x-migrations-table") 169 170 px, err := WithInstance(db, &Config{ 171 DatabaseName: purl.Path, 172 MigrationsTable: migrationsTable, 173 }) 174 if err != nil { 175 return nil, err 176 } 177 178 return px, nil 179 } 180 181 // Close the database connection 182 func (ss *SQLServer) Close() error { 183 connErr := ss.conn.Close() 184 dbErr := ss.db.Close() 185 if connErr != nil || dbErr != nil { 186 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 187 } 188 return nil 189 } 190 191 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 192 func (ss *SQLServer) Lock() error { 193 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { 194 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 195 if err != nil { 196 return err 197 } 198 199 // This will either obtain the lock immediately and return true, 200 // or return false if the lock cannot be acquired immediately. 201 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 202 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 203 204 var status mssql.ReturnStatus 205 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 206 return nil 207 } else if err != nil { 208 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 209 } else { 210 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 211 } 212 }) 213 } 214 215 // Unlock froms the migration lock from the database 216 func (ss *SQLServer) Unlock() error { 217 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { 218 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 219 if err != nil { 220 return err 221 } 222 223 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 224 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 225 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 226 return &database.Error{OrigErr: err, Query: []byte(query)} 227 } 228 229 return nil 230 }) 231 } 232 233 // Run the migrations for the database 234 func (ss *SQLServer) Run(migration io.Reader) error { 235 migr, err := io.ReadAll(migration) 236 if err != nil { 237 return err 238 } 239 240 // run migration 241 query := string(migr[:]) 242 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 243 if msErr, ok := err.(mssql.Error); ok { 244 message := fmt.Sprintf("migration failed: %s", msErr.Message) 245 if msErr.ProcName != "" { 246 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 247 } 248 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} 249 } 250 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 251 } 252 253 return nil 254 } 255 256 // SetVersion for the current database 257 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 258 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 259 if err != nil { 260 return &database.Error{OrigErr: err, Err: "transaction start failed"} 261 } 262 263 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 264 if _, err := tx.Exec(query); err != nil { 265 if errRollback := tx.Rollback(); errRollback != nil { 266 err = multierror.Append(err, errRollback) 267 } 268 return &database.Error{OrigErr: err, Query: []byte(query)} 269 } 270 271 // Also re-write the schema version for nil dirty versions to prevent 272 // empty schema version for failed down migration on the first migration 273 // See: https://github.com/golang-migrate/migrate/issues/330 274 if version >= 0 || (version == database.NilVersion && dirty) { 275 var dirtyBit int 276 if dirty { 277 dirtyBit = 1 278 } 279 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 280 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 281 if errRollback := tx.Rollback(); errRollback != nil { 282 err = multierror.Append(err, errRollback) 283 } 284 return &database.Error{OrigErr: err, Query: []byte(query)} 285 } 286 } 287 288 if err := tx.Commit(); err != nil { 289 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 290 } 291 292 return nil 293 } 294 295 // Version of the current database state 296 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 297 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 298 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 299 switch { 300 case err == sql.ErrNoRows: 301 return database.NilVersion, false, nil 302 303 case err != nil: 304 // FIXME: convert to MSSQL error 305 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 306 307 default: 308 return version, dirty, nil 309 } 310 } 311 312 // Drop all tables from the database. 313 func (ss *SQLServer) Drop() error { 314 // drop all referential integrity constraints 315 query := ` 316 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 317 318 SET @Cursor = CURSOR FAST_FORWARD FOR 319 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 320 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 321 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 322 323 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 324 325 WHILE (@@FETCH_STATUS = 0) 326 BEGIN 327 Exec sp_executesql @Sql 328 FETCH NEXT FROM @Cursor INTO @Sql 329 END 330 331 CLOSE @Cursor DEALLOCATE @Cursor` 332 333 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 334 return &database.Error{OrigErr: err, Query: []byte(query)} 335 } 336 337 // drop the tables 338 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 339 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 340 return &database.Error{OrigErr: err, Query: []byte(query)} 341 } 342 343 return nil 344 } 345 346 func (ss *SQLServer) ensureVersionTable() (err error) { 347 if err = ss.Lock(); err != nil { 348 return err 349 } 350 351 defer func() { 352 if e := ss.Unlock(); e != nil { 353 if err == nil { 354 err = e 355 } else { 356 err = multierror.Append(err, e) 357 } 358 } 359 }() 360 361 query := `IF NOT EXISTS 362 (SELECT * 363 FROM sysobjects 364 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 365 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 366 ) 367 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 368 369 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 370 return &database.Error{OrigErr: err, Query: []byte(query)} 371 } 372 373 return nil 374 } 375 376 func getMSITokenProvider(resource string) (func() (string, error), error) { 377 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) 378 if err != nil { 379 return nil, err 380 } 381 382 return func() (string, error) { 383 err := msi.EnsureFresh() 384 if err != nil { 385 return "", err 386 } 387 token := msi.OAuthToken() 388 return token, nil 389 }, nil 390 } 391 392 // The sql server resource can change across clouds so get it 393 // dynamically based on the server uri. 394 // ex. <server name>.database.windows.net -> https://database.windows.net 395 func getAADResourceFromServerUri(purl *nurl.URL) string { 396 return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) 397 }