github.com/nagyist/migrate/v4@v4.14.6/database/sqlserver/sqlserver.go (about) 1 package sqlserver 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "io" 8 "io/ioutil" 9 nurl "net/url" 10 11 mssql "github.com/denisenkom/go-mssqldb" // mssql support 12 "github.com/golang-migrate/migrate/v4" 13 "github.com/golang-migrate/migrate/v4/database" 14 "github.com/hashicorp/go-multierror" 15 ) 16 17 func init() { 18 database.Register("sqlserver", &SQLServer{}) 19 } 20 21 // DefaultMigrationsTable is the name of the migrations table in the database 22 var DefaultMigrationsTable = "schema_migrations" 23 24 var ( 25 ErrNilConfig = fmt.Errorf("no config") 26 ErrNoDatabaseName = fmt.Errorf("no database name") 27 ErrNoSchema = fmt.Errorf("no schema") 28 ErrDatabaseDirty = fmt.Errorf("database is dirty") 29 ) 30 31 var lockErrorMap = map[mssql.ReturnStatus]string{ 32 -1: "The lock request timed out.", 33 -2: "The lock request was canceled.", 34 -3: "The lock request was chosen as a deadlock victim.", 35 -999: "Parameter validation or other call error.", 36 } 37 38 // Config for database 39 type Config struct { 40 MigrationsTable string 41 DatabaseName string 42 SchemaName string 43 } 44 45 // SQL Server connection 46 type SQLServer struct { 47 // Locking and unlocking need to use the same connection 48 conn *sql.Conn 49 db *sql.DB 50 isLocked bool 51 52 // Open and WithInstance need to garantuee that config is never nil 53 config *Config 54 } 55 56 // WithInstance returns a database instance from an already created database connection. 57 // 58 // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. 59 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 60 if config == nil { 61 return nil, ErrNilConfig 62 } 63 64 if err := instance.Ping(); err != nil { 65 return nil, err 66 } 67 68 if config.DatabaseName == "" { 69 query := `SELECT DB_NAME()` 70 var databaseName string 71 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 72 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 73 } 74 75 if len(databaseName) == 0 { 76 return nil, ErrNoDatabaseName 77 } 78 79 config.DatabaseName = databaseName 80 } 81 82 if config.SchemaName == "" { 83 query := `SELECT SCHEMA_NAME()` 84 var schemaName string 85 if err := instance.QueryRow(query).Scan(&schemaName); err != nil { 86 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 87 } 88 89 if len(schemaName) == 0 { 90 return nil, ErrNoSchema 91 } 92 93 config.SchemaName = schemaName 94 } 95 96 if len(config.MigrationsTable) == 0 { 97 config.MigrationsTable = DefaultMigrationsTable 98 } 99 100 conn, err := instance.Conn(context.Background()) 101 102 if err != nil { 103 return nil, err 104 } 105 106 ss := &SQLServer{ 107 conn: conn, 108 db: instance, 109 config: config, 110 } 111 112 if err := ss.ensureVersionTable(); err != nil { 113 return nil, err 114 } 115 116 return ss, nil 117 } 118 119 // Open a connection to the database 120 func (ss *SQLServer) Open(url string) (database.Driver, error) { 121 purl, err := nurl.Parse(url) 122 if err != nil { 123 return nil, err 124 } 125 126 db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String()) 127 if err != nil { 128 return nil, err 129 } 130 131 migrationsTable := purl.Query().Get("x-migrations-table") 132 133 px, err := WithInstance(db, &Config{ 134 DatabaseName: purl.Path, 135 MigrationsTable: migrationsTable, 136 }) 137 138 if err != nil { 139 return nil, err 140 } 141 142 return px, nil 143 } 144 145 // Close the database connection 146 func (ss *SQLServer) Close() error { 147 connErr := ss.conn.Close() 148 dbErr := ss.db.Close() 149 if connErr != nil || dbErr != nil { 150 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 151 } 152 return nil 153 } 154 155 // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. 156 func (ss *SQLServer) Lock() error { 157 if ss.isLocked { 158 return database.ErrLocked 159 } 160 161 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 162 if err != nil { 163 return err 164 } 165 166 // This will either obtain the lock immediately and return true, 167 // or return false if the lock cannot be acquired immediately. 168 // MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017 169 query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0` 170 171 var status mssql.ReturnStatus 172 if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 { 173 ss.isLocked = true 174 return nil 175 } else if err != nil { 176 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 177 } else { 178 return &database.Error{Err: fmt.Sprintf("try lock failed with error %v: %v", status, lockErrorMap[status]), Query: []byte(query)} 179 } 180 } 181 182 // Unlock froms the migration lock from the database 183 func (ss *SQLServer) Unlock() error { 184 if !ss.isLocked { 185 return nil 186 } 187 188 aid, err := database.GenerateAdvisoryLockId(ss.config.DatabaseName, ss.config.SchemaName) 189 if err != nil { 190 return err 191 } 192 193 // MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017 194 query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'` 195 if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil { 196 return &database.Error{OrigErr: err, Query: []byte(query)} 197 } 198 ss.isLocked = false 199 200 return nil 201 } 202 203 // Run the migrations for the database 204 func (ss *SQLServer) Run(migration io.Reader) error { 205 migr, err := ioutil.ReadAll(migration) 206 if err != nil { 207 return err 208 } 209 210 // run migration 211 query := string(migr[:]) 212 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 213 if msErr, ok := err.(mssql.Error); ok { 214 message := fmt.Sprintf("migration failed: %s", msErr.Message) 215 if msErr.ProcName != "" { 216 message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) 217 } 218 return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} 219 } 220 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 221 } 222 223 return nil 224 } 225 226 // SetVersion for the current database 227 func (ss *SQLServer) SetVersion(version int, dirty bool) error { 228 229 tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{}) 230 if err != nil { 231 return &database.Error{OrigErr: err, Err: "transaction start failed"} 232 } 233 234 query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` 235 if _, err := tx.Exec(query); err != nil { 236 if errRollback := tx.Rollback(); errRollback != nil { 237 err = multierror.Append(err, errRollback) 238 } 239 return &database.Error{OrigErr: err, Query: []byte(query)} 240 } 241 242 // Also re-write the schema version for nil dirty versions to prevent 243 // empty schema version for failed down migration on the first migration 244 // See: https://github.com/golang-migrate/migrate/issues/330 245 if version >= 0 || (version == database.NilVersion && dirty) { 246 var dirtyBit int 247 if dirty { 248 dirtyBit = 1 249 } 250 query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` 251 if _, err := tx.Exec(query, version, dirtyBit); err != nil { 252 if errRollback := tx.Rollback(); errRollback != nil { 253 err = multierror.Append(err, errRollback) 254 } 255 return &database.Error{OrigErr: err, Query: []byte(query)} 256 } 257 } 258 259 if err := tx.Commit(); err != nil { 260 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 261 } 262 263 return nil 264 } 265 266 // Version of the current database state 267 func (ss *SQLServer) Version() (version int, dirty bool, err error) { 268 query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` 269 err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 270 switch { 271 case err == sql.ErrNoRows: 272 return database.NilVersion, false, nil 273 274 case err != nil: 275 // FIXME: convert to MSSQL error 276 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 277 278 default: 279 return version, dirty, nil 280 } 281 } 282 283 // Drop all tables from the database. 284 func (ss *SQLServer) Drop() error { 285 286 // drop all referential integrity constraints 287 query := ` 288 DECLARE @Sql NVARCHAR(500) DECLARE @Cursor CURSOR 289 290 SET @Cursor = CURSOR FAST_FORWARD FOR 291 SELECT DISTINCT sql = 'ALTER TABLE [' + tc2.TABLE_NAME + '] DROP [' + rc1.CONSTRAINT_NAME + ']' 292 FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc1 293 LEFT JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc2 ON tc2.CONSTRAINT_NAME =rc1.CONSTRAINT_NAME 294 295 OPEN @Cursor FETCH NEXT FROM @Cursor INTO @Sql 296 297 WHILE (@@FETCH_STATUS = 0) 298 BEGIN 299 Exec sp_executesql @Sql 300 FETCH NEXT FROM @Cursor INTO @Sql 301 END 302 303 CLOSE @Cursor DEALLOCATE @Cursor` 304 305 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 306 return &database.Error{OrigErr: err, Query: []byte(query)} 307 } 308 309 // drop the tables 310 query = `EXEC sp_MSforeachtable 'DROP TABLE ?'` 311 if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { 312 return &database.Error{OrigErr: err, Query: []byte(query)} 313 } 314 315 return nil 316 } 317 318 func (ss *SQLServer) ensureVersionTable() (err error) { 319 if err = ss.Lock(); err != nil { 320 return err 321 } 322 323 defer func() { 324 if e := ss.Unlock(); e != nil { 325 if err == nil { 326 err = e 327 } else { 328 err = multierror.Append(err, e) 329 } 330 } 331 }() 332 333 query := `IF NOT EXISTS 334 (SELECT * 335 FROM sysobjects 336 WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') 337 AND OBJECTPROPERTY(id, N'IsUserTable') = 1 338 ) 339 CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` 340 341 if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { 342 return &database.Error{OrigErr: err, Query: []byte(query)} 343 } 344 345 return nil 346 }