// Source: github.com/nokia/migrate/v4@v4.16.0/database/sqlserver/sqlserver.go

// Package sqlserver provides a migrate database driver that is registered
// under the name "sqlserver".
package sqlserver

import (
	"context"
	"database/sql"
	"fmt"
	"io"
	"io/ioutil"
	nurl "net/url"
	"strconv"
	"strings"

	"go.uber.org/atomic"

	"github.com/Azure/go-autorest/autorest/adal"
	mssql "github.com/denisenkom/go-mssqldb" // mssql support
	"github.com/hashicorp/go-multierror"
	"github.com/nokia/migrate/v4"
	"github.com/nokia/migrate/v4/database"
	"github.com/nokia/migrate/v4/source"
)

// init registers an empty SQLServer value with the database package under
// the driver name "sqlserver".
func init() {
	database.Register("sqlserver", &SQLServer{})
}

// DefaultMigrationsTable is the name of the migrations table in the database
var DefaultMigrationsTable = "schema_migrations"

// Sentinel errors returned by this driver.
var (
	// ErrNilConfig is returned by WithInstance when it is given a nil *Config.
	ErrNilConfig = fmt.Errorf("no config")
	// ErrNoDatabaseName is returned by WithInstance when no database name was
	// configured and the server reports an empty one.
	ErrNoDatabaseName = fmt.Errorf("no database name")
	// ErrNoSchema is returned by WithInstance when no schema name was
	// configured and the server reports an empty one.
	ErrNoSchema = fmt.Errorf("no schema")
	// ErrDatabaseDirty describes a dirty database.
	// NOTE(review): not referenced elsewhere in this file; presumably kept for callers.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")
	// ErrMultipleAuthOptionsPassed is returned by Open when the URL carries a
	// password and also sets useMsi=true.
	ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.")
)

// lockErrorMap translates the negative return statuses of the lock procedure
// called by Lock into human-readable messages.
var lockErrorMap = map[mssql.ReturnStatus]string{
	-1:   "The lock request timed out.",
	-2:   "The lock request was canceled.",
	-3:   "The lock request was chosen as a deadlock victim.",
	-999: "Parameter validation or other call error.",
}

// Config for database
type Config struct {
	// MigrationsTable is the name of the table that stores the version;
	// WithInstance falls back to DefaultMigrationsTable when it is empty.
	MigrationsTable string
	// DatabaseName is looked up with DB_NAME() by WithInstance when empty.
	DatabaseName string
	// SchemaName is looked up with SCHEMA_NAME() by WithInstance when empty.
	SchemaName string
}

// SQL Server connection
type SQLServer struct {
	// Locking and unlocking need to use the same connection
	conn *sql.Conn
	db   *sql.DB
	// isLocked records whether Lock has been taken and not yet released.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}

// WithInstance returns a database instance from an already created database connection.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
66 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 67 if config == nil { 68 return nil, ErrNilConfig 69 } 70 71 if err := instance.Ping(); err != nil { 72 return nil, err 73 } 74 75 if config.DatabaseName == "" { 76 query := `SELECT DB_NAME()` 77 var databaseName string 78 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 79 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 80 } 81 82 if len(databaseName) == 0 { 83 return nil, ErrNoDatabaseName 84 } 85 86 config.DatabaseName = databaseName 87 } 88 89 if config.SchemaName == "" { 90 query := `SELECT SCHEMA_NAME()` 91 var schemaName string 92 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 93 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 94 } 95 96 if len(schemaName) == 0 { 97 return nil, ErrNoSchema 98 } 99 100 config.SchemaName = schemaName 101 } 102 103 if len(config.MigrationsTable) == 0 { 104 config.MigrationsTable = DefaultMigrationsTable 105 } 106 107 conn, err := instance.Conn(context.Background()) 108 if err != nil { 109 return nil, err 110 } 111 112 ss := &SQLServer{ 113 conn: conn, 114 db: instance, 115 config: config, 116 } 117 118 if err := ss.ensureVersionTable(); err != nil { 119 return nil, err 120 } 121 122 return ss, nil 123 } 124 125 // Open a connection to the database. 
126 func (ss *SQLServer) Open(url string) (database.Driver, error) { 127 purl, err := nurl.Parse(url) 128 if err != nil { 129 return nil, err 130 } 131 132 useMsiParam := purl.Query().Get("useMsi") 133 useMsi := false 134 if len(useMsiParam) > 0 { 135 useMsi, err = strconv.ParseBool(useMsiParam) 136 if err != nil { 137 return nil, err 138 } 139 } 140 141 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { 142 return nil, ErrMultipleAuthOptionsPassed 143 } 144 145 filteredURL := migrate.FilterCustomQuery(purl).String() 146 147 var db *sql.DB 148 if useMsi { 149 resource := getAADResourceFromServerUri(purl) 150 tokenProvider, err := getMSITokenProvider(resource) 151 if err != nil { 152 return nil, err 153 } 154 155 connector, err := mssql.NewAccessTokenConnector( 156 filteredURL, tokenProvider) 157 if err != nil { 158 return nil, err 159 } 160 161 db = sql.OpenDB(connector) 162 163 } else { 164 db, err = sql.Open("sqlserver", filteredURL) 165 if err != nil { 166 return nil, err 167 } 168 } 169 170 migrationsTable := purl.Query().Get("x-migrations-table") 171 172 px, err := WithInstance(db, &Config{ 173 DatabaseName: purl.Path, 174 MigrationsTable: migrationsTable, 175 }) 176 if err != nil { 177 return nil, err 178 } 179 180 return px, nil 181 } 182 183 // Close the database connection 184 func (ss *SQLServer) Close() error { 185 connErr := ss.conn.Close() 186 dbErr := ss.db.Close() 187 if connErr != nil || dbErr != nil { 188 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 189 } 190 return nil 191 } 192 193 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 
194 func (ss *SQLServer) Lock() error { 195 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { 196 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 197 if err != nil { 198 return err 199 } 200 201 // This will either obtain the lock immediately and return true, 202 // or return false if the lock cannot be acquired immediately. 203 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 204 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 205 206 var status mssql.ReturnStatus 207 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 208 return nil 209 } else if err != nil { 210 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 211 } else { 212 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 213 } 214 }) 215 } 216 217 // Unlock froms the migration lock from the database 218 func (ss *SQLServer) Unlock() error { 219 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { 220 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 221 if err != nil { 222 return err 223 } 224 225 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 226 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 227 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 231 return nil 232 }) 233 } 234 235 // Run the migrations for the database 236 func (ss *SQLServer) Run(migration 
io.Reader) error { 237 migr, err := ioutil.ReadAll(migration) 238 if err != nil { 239 return err 240 } 241 242 // run migration 243 query := string(migr[:]) 244 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 245 if msErr, ok := err.(mssql.Error); ok { 246 message := fmt.Sprintf("migration failed: %s", msErr.Message) 247 if msErr.ProcName != "" { 248 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 249 } 250 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} 251 } 252 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 253 } 254 255 return nil 256 } 257 258 func (ss *SQLServer) RunFunctionMigration(fn source.MigrationFunc) error { 259 return database.ErrNotImpl 260 } 261 262 // SetVersion for the current database 263 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 264 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 265 if err != nil { 266 return &database.Error{OrigErr: err, Err: "transaction start failed"} 267 } 268 269 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 270 if _, err := tx.Exec(query); err != nil { 271 if errRollback := tx.Rollback(); errRollback != nil { 272 err = multierror.Append(err, errRollback) 273 } 274 return &database.Error{OrigErr: err, Query: []byte(query)} 275 } 276 277 // Also re-write the schema version for nil dirty versions to prevent 278 // empty schema version for failed down migration on the first migration 279 // See: https://github.com/nokia/migrate/issues/330 280 if version >= 0 || (version == database.NilVersion && dirty) { 281 var dirtyBit int 282 if dirty { 283 dirtyBit = 1 284 } 285 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 286 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 287 if errRollback := tx.Rollback(); errRollback != nil { 288 err = multierror.Append(err, errRollback) 289 } 290 return 
&database.Error{OrigErr: err, Query: []byte(query)} 291 } 292 } 293 294 if err := tx.Commit(); err != nil { 295 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 296 } 297 298 return nil 299 } 300 301 // Version of the current database state 302 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 303 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 304 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 305 switch { 306 case err == sql.ErrNoRows: 307 return database.NilVersion, false, nil 308 309 case err != nil: 310 // FIXME: convert to MSSQL error 311 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 312 313 default: 314 return version, dirty, nil 315 } 316 } 317 318 // Drop all tables from the database. 319 func (ss *SQLServer) Drop() error { 320 // drop all referential integrity constraints 321 query := ` 322 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 323 324 SET @Cursor = CURSOR FAST_FORWARD FOR 325 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 326 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 327 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 328 329 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 330 331 WHILE (@@FETCH_STATUS = 0) 332 BEGIN 333 Exec sp_executesql @Sql 334 FETCH NEXT FROM @Cursor INTO @Sql 335 END 336 337 CLOSE @Cursor DEALLOCATE @Cursor` 338 339 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 340 return &database.Error{OrigErr: err, Query: []byte(query)} 341 } 342 343 // drop the tables 344 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 345 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 346 return &database.Error{OrigErr: err, Query: []byte(query)} 347 } 348 349 return nil 350 } 351 352 func (ss *SQLServer) ensureVersionTable() (err error) { 353 if err = 
ss.Lock(); err != nil { 354 return err 355 } 356 357 defer func() { 358 if e := ss.Unlock(); e != nil { 359 if err == nil { 360 err = e 361 } else { 362 err = multierror.Append(err, e) 363 } 364 } 365 }() 366 367 query := `IF NOT EXISTS 368 (SELECT * 369 FROM sysobjects 370 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 371 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 372 ) 373 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 374 375 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 376 return &database.Error{OrigErr: err, Query: []byte(query)} 377 } 378 379 return nil 380 } 381 382 func getMSITokenProvider(resource string) (func() (string, error), error) { 383 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) 384 if err != nil { 385 return nil, err 386 } 387 388 return func() (string, error) { 389 err := msi.EnsureFresh() 390 if err != nil { 391 return "", err 392 } 393 token := msi.OAuthToken() 394 return token, nil 395 }, nil 396 } 397 398 // The sql server resource can change across clouds so get it 399 // dynamically based on the server uri. 400 // ex. <server name>.database.windows.net -> https://database.windows.net 401 func getAADResourceFromServerUri(purl *nurl.URL) string { 402 return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) 403 }