github.com/meetsoni15/go-migrate/v4@v4.15.3-0.20221220054613-2c40bd0c4ee9/database/sqlserver/sqlserver.go (about) 1 package sqlserver 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "github.com/denisenkom/go-mssqldb/batch" 13 14 "go.uber.org/atomic" 15 16 "github.com/Azure/go-autorest/autorest/adal" 17 mssql "github.com/denisenkom/go-mssqldb" // mssql support 18 "github.com/hashicorp/go-multierror" 19 "github.com/meetsoni15/go-migrate/v4" 20 "github.com/meetsoni15/go-migrate/v4/database" 21 ) 22 23 func init() { 24 database.Register("sqlserver", &SQLServer{}) 25 } 26 27 // DefaultMigrationsTable is the name of the migrations table in the database 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ErrNoDatabaseName = fmt.Errorf("no database name") 33 ErrNoSchema = fmt.Errorf("no schema") 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") 36 ) 37 38 var lockErrorMap = map[mssql.ReturnStatus]string{ 39 -1: "The lock request timed out.", 40 -2: "The lock request was canceled.", 41 -3: "The lock request was chosen as a deadlock victim.", 42 -999: "Parameter validation or other call error.", 43 } 44 45 // Config for database 46 type Config struct { 47 MigrationsTable string 48 DatabaseName string 49 SchemaName string 50 BatchStatementEnabled bool 51 } 52 53 // SQL Server connection 54 type SQLServer struct { 55 // Locking and unlocking need to use the same connection 56 conn *sql.Conn 57 db *sql.DB 58 isLocked atomic.Bool 59 60 // Open and WithInstance need to garantuee that config is never nil 61 config *Config 62 } 63 64 // WithInstance returns a database instance from an already created database connection. 65 // 66 // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. 67 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 68 if config == nil { 69 return nil, ErrNilConfig 70 } 71 72 if err := instance.Ping(); err != nil { 73 return nil, err 74 } 75 76 if config.DatabaseName == "" { 77 query := `SELECT DB_NAME()` 78 var databaseName string 79 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 80 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 81 } 82 83 if len(databaseName) == 0 { 84 return nil, ErrNoDatabaseName 85 } 86 87 config.DatabaseName = databaseName 88 } 89 90 if config.SchemaName == "" { 91 query := `SELECT SCHEMA_NAME()` 92 var schemaName string 93 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 94 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 95 } 96 97 if len(schemaName) == 0 { 98 return nil, ErrNoSchema 99 } 100 101 config.SchemaName = schemaName 102 } 103 104 if len(config.MigrationsTable) == 0 { 105 config.MigrationsTable = DefaultMigrationsTable 106 } 107 108 conn, err := instance.Conn(context.Background()) 109 110 if err != nil { 111 return nil, err 112 } 113 114 ss := &SQLServer{ 115 conn: conn, 116 db: instance, 117 config: config, 118 } 119 120 if err := ss.ensureVersionTable(); err != nil { 121 return nil, err 122 } 123 124 return ss, nil 125 } 126 127 // Open a connection to the database. 128 func (ss *SQLServer) Open(url string) (database.Driver, error) { 129 purl, err := nurl.Parse(url) 130 if err != nil { 131 return nil, err 132 } 133 134 useMsiParam := purl.Query().Get("useMsi") 135 useMsi := false 136 if len(useMsiParam) > 0 { 137 useMsi, err = strconv.ParseBool(useMsiParam) 138 if err != nil { 139 return nil, err 140 } 141 } 142 143 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { 144 return nil, ErrMultipleAuthOptionsPassed 145 } 146 147 filteredURL := migrate.FilterCustomQuery(purl).String() 148 149 var db *sql.DB 150 if useMsi { 151 resource := getAADResourceFromServerUri(purl) 152 tokenProvider, err := getMSITokenProvider(resource) 153 if err != nil { 154 return nil, err 155 } 156 157 connector, err := mssql.NewAccessTokenConnector( 158 filteredURL, tokenProvider) 159 if err != nil { 160 return nil, err 161 } 162 163 db = sql.OpenDB(connector) 164 165 } else { 166 db, err = sql.Open("sqlserver", filteredURL) 167 if err != nil { 168 return nil, err 169 } 170 } 171 172 migrationsTable := purl.Query().Get("x-migrations-table") 173 174 batchStatementEnabled := false 175 if s := purl.Query().Get("x-batch"); len(s) > 0 { 176 batchStatementEnabled, err = strconv.ParseBool(s) 177 if err != nil { 178 return nil, fmt.Errorf("Unable to parse option x-batch: %w", err) 179 } 180 } 181 182 px, err := WithInstance(db, &Config{ 183 DatabaseName: purl.Path, 184 MigrationsTable: migrationsTable, 185 BatchStatementEnabled: batchStatementEnabled, 186 }) 187 188 if err != nil { 189 return nil, err 190 } 191 192 return px, nil 193 } 194 195 // Close the database connection 196 func (ss *SQLServer) Close() error { 197 connErr := ss.conn.Close() 198 dbErr := ss.db.Close() 199 if connErr != nil || dbErr != nil { 200 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 201 } 202 return nil 203 } 204 205 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 206 func (ss *SQLServer) Lock() error { 207 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { 208 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 209 if err != nil { 210 return err 211 } 212 213 // This will either obtain the lock immediately and return true, 214 // or return false if the lock cannot be acquired immediately. 215 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 216 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 217 218 var status mssql.ReturnStatus 219 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 220 return nil 221 } else if err != nil { 222 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 223 } else { 224 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 225 } 226 }) 227 } 228 229 // Unlock froms the migration lock from the database 230 func (ss *SQLServer) Unlock() error { 231 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { 232 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 233 if err != nil { 234 return err 235 } 236 237 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 238 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 239 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 243 return nil 244 }) 245 } 246 247 // Run the migrations for the database 248 func (ss *SQLServer) Run(migration io.Reader) error { 249 migr, err := io.ReadAll(migration) 250 if err != nil { 251 return err 252 } 253 254 // run migration 255 query := string(migr[:]) 256 scripts := []string{query} 257 258 if ss.config.BatchStatementEnabled { 259 scripts = batch.Split(query, "go") 260 } 261 262 for _, script := range scripts { 263 if _, err := ss.conn.ExecContext(context.Background(), script); err != nil { 264 if msErr, ok := err.(mssql.Error); ok { 265 message := fmt.Sprintf("migration failed: %s", msErr.Message) 266 if msErr.ProcName != "" { 267 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 268 } 269 return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)} 270 } 271 return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)} 272 } 273 } 274 275 return nil 276 } 277 278 // SetVersion for the current database 279 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 280 281 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 282 if err != nil { 283 return &database.Error{OrigErr: err, Err: "transaction start failed"} 284 } 285 286 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 287 if _, err := tx.Exec(query); err != nil { 288 if errRollback := tx.Rollback(); errRollback != nil { 289 err = multierror.Append(err, errRollback) 290 } 291 return &database.Error{OrigErr: err, Query: []byte(query)} 292 } 293 294 // Also re-write the schema version for nil dirty versions to prevent 295 // empty schema version for failed down migration on the first migration 296 // See: https://github.com/golang-migrate/migrate/issues/330 297 if version >= 0 || (version == database.NilVersion && dirty) { 298 var dirtyBit int 299 if dirty { 300 dirtyBit = 1 301 } 302 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 303 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 304 if errRollback := tx.Rollback(); errRollback != nil { 305 err = multierror.Append(err, errRollback) 306 } 307 return &database.Error{OrigErr: err, Query: []byte(query)} 308 } 309 } 310 311 if err := tx.Commit(); err != nil { 312 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 313 } 314 315 return nil 316 } 317 318 // Version of the current database state 319 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 320 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 321 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 322 switch { 323 case err == sql.ErrNoRows: 324 return database.NilVersion, false, nil 325 326 case err != nil: 327 // FIXME: convert to MSSQL error 328 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 329 330 default: 331 return version, dirty, nil 332 } 333 } 334 335 // Drop all tables from the database. 336 func (ss *SQLServer) Drop() error { 337 338 // drop all referential integrity constraints 339 query := ` 340 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 341 SET @Cursor = CURSOR FAST_FORWARD FOR 342 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 343 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 344 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 345 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 346 WHILE (@@FETCH_STATUS = 0) 347 BEGIN 348 Exec sp_executesql @Sql 349 FETCH NEXT FROM @Cursor INTO @Sql 350 END 351 CLOSE @Cursor DEALLOCATE @Cursor` 352 353 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 354 return &database.Error{OrigErr: err, Query: []byte(query)} 355 } 356 357 // drop the tables 358 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 359 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 360 return &database.Error{OrigErr: err, Query: []byte(query)} 361 } 362 363 return nil 364 } 365 366 func (ss *SQLServer) ensureVersionTable() (err error) { 367 if err = ss.Lock(); err != nil { 368 return err 369 } 370 371 defer func() { 372 if e := ss.Unlock(); e != nil { 373 if err == nil { 374 err = e 375 } else { 376 err = multierror.Append(err, e) 377 } 378 } 379 }() 380 381 query := `IF NOT EXISTS 382 (SELECT * 383 FROM sysobjects 384 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 385 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 386 ) 387 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 388 389 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 390 return &database.Error{OrigErr: err, Query: []byte(query)} 391 } 392 393 return nil 394 } 395 396 func getMSITokenProvider(resource string) (func() (string, error), error) { 397 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) 398 if err != nil { 399 return nil, err 400 } 401 402 return func() (string, error) { 403 err := msi.EnsureFresh() 404 if err != nil { 405 return "", err 406 } 407 token := msi.OAuthToken() 408 return token, nil 409 }, nil 410 } 411 412 // The sql server resource can change across clouds so get it 413 // dynamically based on the server uri. 414 // ex. <server name>.database.windows.net -> https://database.windows.net 415 func getAADResourceFromServerUri(purl *nurl.URL) string { 416 return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) 417 }