github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/sqlserver/sqlserver.go (about) 1 package sqlserver 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 "io/ioutil" 9 nurl "net/url" 10 "strconv" 11 "strings" 12 13 "go.uber.org/atomic" 14 15 "github.com/Azure/go-autorest/autorest/adal" 16 mssql "github.com/denisenkom/go-mssqldb" // mssql support 17 "github.com/getsynq/migrate 18 "github.com/getsynq/migrateabase" 19 "github.com/hashicorp/go-multierror" 20 ) 21 22 func init() { 23 database.Register("sqlserver", &SQLServer{}) 24 } 25 26 // DefaultMigrationsTable is the name of the migrations table in the database 27 var DefaultMigrationsTable = "schema_migrations" 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ErrNoDatabaseName = fmt.Errorf("no database name") 32 ErrNoSchema = fmt.Errorf("no schema") 33 ErrDatabaseDirty = fmt.Errorf("database is dirty") 34 ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") 35 ) 36 37 var lockErrorMap = map[mssql.ReturnStatus]string{ 38 -1: "The lock request timed out.", 39 -2: "The lock request was canceled.", 40 -3: "The lock request was chosen as a deadlock victim.", 41 -999: "Parameter validation or other call error.", 42 } 43 44 // Config for database 45 type Config struct { 46 MigrationsTable string 47 DatabaseName string 48 SchemaName string 49 } 50 51 // SQL Server connection 52 type SQLServer struct { 53 // Locking and unlocking need to use the same connection 54 conn *sql.Conn 55 db *sql.DB 56 isLocked atomic.Bool 57 58 // Open and WithInstance need to garantuee that config is never nil 59 config *Config 60 } 61 62 // WithInstance returns a database instance from an already created database connection. 63 // 64 // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. 65 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 66 if config == nil { 67 return nil, ErrNilConfig 68 } 69 70 if err := instance.Ping(); err != nil { 71 return nil, err 72 } 73 74 if config.DatabaseName == "" { 75 query := `SELECT DB_NAME()` 76 var databaseName string 77 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 78 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 79 } 80 81 if len(databaseName) == 0 { 82 return nil, ErrNoDatabaseName 83 } 84 85 config.DatabaseName = databaseName 86 } 87 88 if config.SchemaName == "" { 89 query := `SELECT SCHEMA_NAME()` 90 var schemaName string 91 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 92 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 93 } 94 95 if len(schemaName) == 0 { 96 return nil, ErrNoSchema 97 } 98 99 config.SchemaName = schemaName 100 } 101 102 if len(config.MigrationsTable) == 0 { 103 config.MigrationsTable = DefaultMigrationsTable 104 } 105 106 conn, err := instance.Conn(context.Background()) 107 108 if err != nil { 109 return nil, err 110 } 111 112 ss := &SQLServer{ 113 conn: conn, 114 db: instance, 115 config: config, 116 } 117 118 if err := ss.ensureVersionTable(); err != nil { 119 return nil, err 120 } 121 122 return ss, nil 123 } 124 125 // Open a connection to the database. 126 func (ss *SQLServer) Open(url string) (database.Driver, error) { 127 purl, err := nurl.Parse(url) 128 if err != nil { 129 return nil, err 130 } 131 132 useMsiParam := purl.Query().Get("useMsi") 133 useMsi := false 134 if len(useMsiParam) > 0 { 135 useMsi, err = strconv.ParseBool(useMsiParam) 136 if err != nil { 137 return nil, err 138 } 139 } 140 141 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { 142 return nil, ErrMultipleAuthOptionsPassed 143 } 144 145 filteredURL := migrate.FilterCustomQuery(purl).String() 146 147 var db *sql.DB 148 if useMsi { 149 resource := getAADResourceFromServerUri(purl) 150 tokenProvider, err := getMSITokenProvider(resource) 151 if err != nil { 152 return nil, err 153 } 154 155 connector, err := mssql.NewAccessTokenConnector( 156 filteredURL, tokenProvider) 157 if err != nil { 158 return nil, err 159 } 160 161 db = sql.OpenDB(connector) 162 163 } else { 164 db, err = sql.Open("sqlserver", filteredURL) 165 if err != nil { 166 return nil, err 167 } 168 } 169 170 migrationsTable := purl.Query().Get("x-migrations-table") 171 172 px, err := WithInstance(db, &Config{ 173 DatabaseName: purl.Path, 174 MigrationsTable: migrationsTable, 175 }) 176 177 if err != nil { 178 return nil, err 179 } 180 181 return px, nil 182 } 183 184 // Close the database connection 185 func (ss *SQLServer) Close() error { 186 connErr := ss.conn.Close() 187 dbErr := ss.db.Close() 188 if connErr != nil || dbErr != nil { 189 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 190 } 191 return nil 192 } 193 194 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 195 func (ss *SQLServer) Lock() error { 196 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { 197 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 198 if err != nil { 199 return err 200 } 201 202 // This will either obtain the lock immediately and return true, 203 // or return false if the lock cannot be acquired immediately. 204 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 205 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 206 207 var status mssql.ReturnStatus 208 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 209 return nil 210 } else if err != nil { 211 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 212 } else { 213 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 214 } 215 }) 216 } 217 218 // Unlock froms the migration lock from the database 219 func (ss *SQLServer) Unlock() error { 220 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { 221 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 222 if err != nil { 223 return err 224 } 225 226 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 227 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 228 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 229 return &database.Error{OrigErr: err, Query: []byte(query)} 230 } 231 232 return nil 233 }) 234 } 235 236 // Run the migrations for the database 237 func (ss *SQLServer) Run(migration io.Reader) error { 238 migr, err := ioutil.ReadAll(migration) 239 if err != nil { 240 return err 241 } 242 243 // run migration 244 query := string(migr[:]) 245 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 246 if msErr, ok := err.(mssql.Error); ok { 247 message := fmt.Sprintf("migration failed: %s", msErr.Message) 248 if msErr.ProcName != "" { 249 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 250 } 251 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} 252 } 253 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 254 } 255 256 return nil 257 } 258 259 // SetVersion for the current database 260 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 261 262 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 263 if err != nil { 264 return &database.Error{OrigErr: err, Err: "transaction start failed"} 265 } 266 267 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 268 if _, err := tx.Exec(query); err != nil { 269 if errRollback := tx.Rollback(); errRollback != nil { 270 err = multierror.Append(err, errRollback) 271 } 272 return &database.Error{OrigErr: err, Query: []byte(query)} 273 } 274 275 // Also re-write the schema version for nil dirty versions to prevent 276 // empty schema version for failed down migration on the first migration 277 // See: https://github.com/getsynq/migrate/330 278 if version >= 0 || (version == database.NilVersion && dirty) { 279 var dirtyBit int 280 if dirty { 281 dirtyBit = 1 282 } 283 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 284 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 285 if errRollback := tx.Rollback(); errRollback != nil { 286 err = multierror.Append(err, errRollback) 287 } 288 return &database.Error{OrigErr: err, Query: []byte(query)} 289 } 290 } 291 292 if err := tx.Commit(); err != nil { 293 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 294 } 295 296 return nil 297 } 298 299 // Version of the current database state 300 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 301 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 302 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 303 switch { 304 case err == sql.ErrNoRows: 305 return database.NilVersion, false, nil 306 307 case err != nil: 308 // FIXME: convert to MSSQL error 309 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 310 311 default: 312 return version, dirty, nil 313 } 314 } 315 316 // Drop all tables from the database. 317 func (ss *SQLServer) Drop() error { 318 319 // drop all referential integrity constraints 320 query := ` 321 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 322 323 SET @Cursor = CURSOR FAST_FORWARD FOR 324 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 325 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 326 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 327 328 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 329 330 WHILE (@@FETCH_STATUS = 0) 331 BEGIN 332 Exec sp_executesql @Sql 333 FETCH NEXT FROM @Cursor INTO @Sql 334 END 335 336 CLOSE @Cursor DEALLOCATE @Cursor` 337 338 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 339 return &database.Error{OrigErr: err, Query: []byte(query)} 340 } 341 342 // drop the tables 343 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 344 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 345 return &database.Error{OrigErr: err, Query: []byte(query)} 346 } 347 348 return nil 349 } 350 351 func (ss *SQLServer) ensureVersionTable() (err error) { 352 if err = ss.Lock(); err != nil { 353 return err 354 } 355 356 defer func() { 357 if e := ss.Unlock(); e != nil { 358 if err == nil { 359 err = e 360 } else { 361 err = multierror.Append(err, e) 362 } 363 } 364 }() 365 366 query := `IF NOT EXISTS 367 (SELECT * 368 FROM sysobjects 369 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 370 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 371 ) 372 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 373 374 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 375 return &database.Error{OrigErr: err, Query: []byte(query)} 376 } 377 378 return nil 379 } 380 381 func getMSITokenProvider(resource string) (func() (string, error), error) { 382 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) 383 if err != nil { 384 return nil, err 385 } 386 387 return func() (string, error) { 388 err := msi.EnsureFresh() 389 if err != nil { 390 return "", err 391 } 392 token := msi.OAuthToken() 393 return token, nil 394 }, nil 395 } 396 397 // The sql server resource can change across clouds so get it 398 // dynamically based on the server uri. 399 // ex. <server name>.database.windows.net -> https://database.windows.net 400 func getAADResourceFromServerUri(purl *nurl.URL) string { 401 return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) 402 }