github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/sqlserver/sqlserver.go (about) 1 package sqlserver 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "go.uber.org/atomic" 13 14 "github.com/Azure/go-autorest/autorest/adal" 15 "github.com/golang-migrate/migrate/v4" 16 "github.com/golang-migrate/migrate/v4/database" 17 "github.com/hashicorp/go-multierror" 18 mssql "github.com/microsoft/go-mssqldb" // mssql support 19 ) 20 21 func init() { 22 database.Register("sqlserver", &SQLServer{}) 23 } 24 25 // DefaultMigrationsTable is the name of the migrations table in the database 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ErrNoDatabaseName = fmt.Errorf("no database name") 31 ErrNoSchema = fmt.Errorf("no schema") 32 ErrDatabaseDirty = fmt.Errorf("database is dirty") 33 ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") 34 ) 35 36 var lockErrorMap = map[mssql.ReturnStatus]string{ 37 -1: "The lock request timed out.", 38 -2: "The lock request was canceled.", 39 -3: "The lock request was chosen as a deadlock victim.", 40 -999: "Parameter validation or other call error.", 41 } 42 43 // Config for database 44 type Config struct { 45 MigrationsTable string 46 DatabaseName string 47 SchemaName string 48 } 49 50 // SQL Server connection 51 type SQLServer struct { 52 // Locking and unlocking need to use the same connection 53 conn *sql.Conn 54 db *sql.DB 55 isLocked atomic.Bool 56 57 // Open and WithInstance need to garantuee that config is never nil 58 config *Config 59 } 60 61 // WithInstance returns a database instance from an already created database connection. 62 // 63 // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. 64 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 65 if config == nil { 66 return nil, ErrNilConfig 67 } 68 69 if err := instance.Ping(); err != nil { 70 return nil, err 71 } 72 73 if config.DatabaseName == "" { 74 query := `SELECT DB_NAME()` 75 var databaseName string 76 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 77 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 78 } 79 80 if len(databaseName) == 0 { 81 return nil, ErrNoDatabaseName 82 } 83 84 config.DatabaseName = databaseName 85 } 86 87 if config.SchemaName == "" { 88 query := `SELECT SCHEMA_NAME()` 89 var schemaName string 90 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 91 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 92 } 93 94 if len(schemaName) == 0 { 95 return nil, ErrNoSchema 96 } 97 98 config.SchemaName = schemaName 99 } 100 101 if len(config.MigrationsTable) == 0 { 102 config.MigrationsTable = DefaultMigrationsTable 103 } 104 105 conn, err := instance.Conn(context.Background()) 106 107 if err != nil { 108 return nil, err 109 } 110 111 ss := &SQLServer{ 112 conn: conn, 113 db: instance, 114 config: config, 115 } 116 117 if err := ss.ensureVersionTable(); err != nil { 118 return nil, err 119 } 120 121 return ss, nil 122 } 123 124 // Open a connection to the database. 125 func (ss *SQLServer) Open(url string) (database.Driver, error) { 126 purl, err := nurl.Parse(url) 127 if err != nil { 128 return nil, err 129 } 130 131 useMsiParam := purl.Query().Get("useMsi") 132 useMsi := false 133 if len(useMsiParam) > 0 { 134 useMsi, err = strconv.ParseBool(useMsiParam) 135 if err != nil { 136 return nil, err 137 } 138 } 139 140 if _, isPasswordSet := purl.User.Password(); useMsi && isPasswordSet { 141 return nil, ErrMultipleAuthOptionsPassed 142 } 143 144 filteredURL := migrate.FilterCustomQuery(purl).String() 145 146 var db *sql.DB 147 if useMsi { 148 resource := getAADResourceFromServerUri(purl) 149 tokenProvider, err := getMSITokenProvider(resource) 150 if err != nil { 151 return nil, err 152 } 153 154 connector, err := mssql.NewAccessTokenConnector( 155 filteredURL, tokenProvider) 156 if err != nil { 157 return nil, err 158 } 159 160 db = sql.OpenDB(connector) 161 162 } else { 163 db, err = sql.Open("sqlserver", filteredURL) 164 if err != nil { 165 return nil, err 166 } 167 } 168 169 migrationsTable := purl.Query().Get("x-migrations-table") 170 171 px, err := WithInstance(db, &Config{ 172 DatabaseName: purl.Path, 173 MigrationsTable: migrationsTable, 174 }) 175 176 if err != nil { 177 return nil, err 178 } 179 180 return px, nil 181 } 182 183 // Close the database connection 184 func (ss *SQLServer) Close() error { 185 connErr := ss.conn.Close() 186 dbErr := ss.db.Close() 187 if connErr != nil || dbErr != nil { 188 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 189 } 190 return nil 191 } 192 193 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 194 func (ss *SQLServer) Lock() error { 195 return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { 196 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 197 if err != nil { 198 return err 199 } 200 201 // This will either obtain the lock immediately and return true, 202 // or return false if the lock cannot be acquired immediately. 203 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 204 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 205 206 var status mssql.ReturnStatus 207 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 208 return nil 209 } else if err != nil { 210 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 211 } else { 212 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 213 } 214 }) 215 } 216 217 // Unlock froms the migration lock from the database 218 func (ss *SQLServer) Unlock() error { 219 return database.CasRestoreOnErr(&ss.isLocked, true, false, database.ErrNotLocked, func() error { 220 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 221 if err != nil { 222 return err 223 } 224 225 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 226 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 227 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 231 return nil 232 }) 233 } 234 235 // Run the migrations for the database 236 func (ss *SQLServer) Run(migration io.Reader) error { 237 migr, err := io.ReadAll(migration) 238 if err != nil { 239 return err 240 } 241 242 // run migration 243 query := string(migr[:]) 244 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 245 if msErr, ok := err.(mssql.Error); ok { 246 message := fmt.Sprintf("migration failed: %s", msErr.Message) 247 if msErr.ProcName != "" { 248 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 249 } 250 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} 251 } 252 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 253 } 254 255 return nil 256 } 257 258 // SetVersion for the current database 259 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 260 261 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 262 if err != nil { 263 return &database.Error{OrigErr: err, Err: "transaction start failed"} 264 } 265 266 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 267 if _, err := tx.Exec(query); err != nil { 268 if errRollback := tx.Rollback(); errRollback != nil { 269 err = multierror.Append(err, errRollback) 270 } 271 return &database.Error{OrigErr: err, Query: []byte(query)} 272 } 273 274 // Also re-write the schema version for nil dirty versions to prevent 275 // empty schema version for failed down migration on the first migration 276 // See: https://github.com/golang-migrate/migrate/issues/330 277 if version >= 0 || (version == database.NilVersion && dirty) { 278 var dirtyBit int 279 if dirty { 280 dirtyBit = 1 281 } 282 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 283 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 284 if errRollback := tx.Rollback(); errRollback != nil { 285 err = multierror.Append(err, errRollback) 286 } 287 return &database.Error{OrigErr: err, Query: []byte(query)} 288 } 289 } 290 291 if err := tx.Commit(); err != nil { 292 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 293 } 294 295 return nil 296 } 297 298 // Version of the current database state 299 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 300 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 301 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 302 switch { 303 case err == sql.ErrNoRows: 304 return database.NilVersion, false, nil 305 306 case err != nil: 307 // FIXME: convert to MSSQL error 308 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 309 310 default: 311 return version, dirty, nil 312 } 313 } 314 315 // Drop all tables from the database. 316 func (ss *SQLServer) Drop() error { 317 318 // drop all referential integrity constraints 319 query := ` 320 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 321 322 SET @Cursor = CURSOR FAST_FORWARD FOR 323 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 324 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 325 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 326 327 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 328 329 WHILE (@@FETCH_STATUS = 0) 330 BEGIN 331 Exec sp_executesql @Sql 332 FETCH NEXT FROM @Cursor INTO @Sql 333 END 334 335 CLOSE @Cursor DEALLOCATE @Cursor` 336 337 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 338 return &database.Error{OrigErr: err, Query: []byte(query)} 339 } 340 341 // drop the tables 342 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 343 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 344 return &database.Error{OrigErr: err, Query: []byte(query)} 345 } 346 347 return nil 348 } 349 350 func (ss *SQLServer) ensureVersionTable() (err error) { 351 if err = ss.Lock(); err != nil { 352 return err 353 } 354 355 defer func() { 356 if e := ss.Unlock(); e != nil { 357 if err == nil { 358 err = e 359 } else { 360 err = multierror.Append(err, e) 361 } 362 } 363 }() 364 365 query := `IF NOT EXISTS 366 (SELECT * 367 FROM sysobjects 368 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 369 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 370 ) 371 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 372 373 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 374 return &database.Error{OrigErr: err, Query: []byte(query)} 375 } 376 377 return nil 378 } 379 380 func getMSITokenProvider(resource string) (func() (string, error), error) { 381 msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) 382 if err != nil { 383 return nil, err 384 } 385 386 return func() (string, error) { 387 err := msi.EnsureFresh() 388 if err != nil { 389 return "", err 390 } 391 token := msi.OAuthToken() 392 return token, nil 393 }, nil 394 } 395 396 // The sql server resource can change across clouds so get it 397 // dynamically based on the server uri. 398 // ex. <server name>.database.windows.net -> https://database.windows.net 399 func getAADResourceFromServerUri(purl *nurl.URL) string { 400 return fmt.Sprintf("%s%s", "https://", strings.Join(strings.Split(purl.Hostname(), ".")[1:], ".")) 401 }